In [None]:
from utils.train_utils import *
from logic_data.utils import *

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
if __name__ == '__main__':
    # Build `args` either from the command line (script usage) or, when parsing
    # fails (e.g. inside a notebook kernel, where sys.argv carries kernel flags
    # and the required options are missing), from hard-coded notebook defaults.
    is_notebook = False
    try:
        cmd = argparse.ArgumentParser('The testing components of')
        cmd.add_argument('--gpu', default=-1, type=int, help='use id of gpu, -1 if cpu.')
        cmd.add_argument('--train_batch_size', default=128, type=int, help='training batch size')
        cmd.add_argument('--eval_batch_size', default=128, type=int, help='evaluation batch size')
        cmd.add_argument('--lr', default=0.01, type=float, help='learning rate')
        cmd.add_argument(
            '--data_path', required=True, type=str,
            help='data directory (decoder_config.json is read from here)'
        )
        cmd.add_argument('--train_data_path', required=True, type=str, help='path to the training corpus')
        cmd.add_argument('--test_data_path', required=True, type=str, help='path to the test corpus')
        cmd.add_argument(
            '--encoder_config_path',
            type=str, help='path to the encoder config'
        )
        cmd.add_argument(
            '--decoder_config_path',
            type=str, help='path to the decoder config'
        )
        cmd.add_argument('--max_seq_len', default=512, type=int)
        cmd.add_argument('--seed', default=42, type=int)
        cmd.add_argument('--gradient_accumulation_steps', default=1, type=int)
        cmd.add_argument('--output_dir', required=True, type=str, help='save dir')
        cmd.add_argument('--local_rank', default=-1, type=int, help='multi gpu training')
        cmd.add_argument('--epochs', default=10, type=int, help='training epochs')
        cmd.add_argument('--model_path', type=str, required=False, default=None)
        cmd.add_argument('--warm_up', type=float, default=0.1)
        cmd.add_argument('--is_wandb', default=False, action='store_true')
        cmd.add_argument('--log_step', default=10, type=int)
        cmd.add_argument('--valid_steps', default=500, type=int)
        cmd.add_argument('--early_stopping', default=5, type=int)
        cmd.add_argument(
            '--device', default="cuda", type=str,
            help='torch device string, e.g. "cuda" or "cuda:0"'
        )
        cmd.add_argument('--do_prealign_eval', default=False, action='store_true')
        cmd.add_argument('--do_align', default=False, action='store_true')
        cmd.add_argument('--do_eval', default=False, action='store_true')
        cmd.add_argument('--do_test', default=False, action='store_true')

        cmd.add_argument('--n_training_program', default=5, type=int)
        cmd.add_argument('--n_fewshot', default=6, type=int)
        cmd.add_argument('--aligning_layer_n', default=0, type=int)
        # These two are read by the evaluation loop; without them a successful
        # CLI parse would later fail with AttributeError. Defaults mirror the
        # notebook fallback below.
        cmd.add_argument(
            '--aligning_basis_n', default=600, type=int,
            help='end index of the [0, n] basis range placed in the intervention config'
        )
        cmd.add_argument(
            '--aligning_var_n', default=1, type=int,
            help='number of aligned variables (only 1 is implemented)'
        )

        args = cmd.parse_args(sys.argv[1:])
    except (SystemExit, Exception):
        # argparse signals bad/missing arguments with SystemExit, so it must be
        # caught explicitly; KeyboardInterrupt is deliberately NOT swallowed.
        is_notebook = True
        args = argparse.Namespace(
            gpu=1,
            eval_batch_size=64,
            gradient_accumulation_steps=2,
            data_path="./logic_data",
            encoder_config_path=None,
            decoder_config_path=None,
            max_seq_len=512,
            output_dir="./results_notebook/",
            epochs=10,
            warm_up=0.1,
            is_wandb=False,
            log_step=10,
            valid_steps=100,  # -1 not do training eval!
            early_stopping=999,  # large == never early stop!
            device="cuda:0",
            do_prealign_eval=True,  # do it once at least!
            do_align=True,
            do_eval=True,
            do_test=True,
            # alignment search setting
            aligning_layer_n=0,
            aligning_basis_n=600,
            aligning_var_n=1,
        )
        # args.model_path is intentionally left unset here; the evaluation
        # loop assigns it per checkpoint.

        print("Using in a notebook env.")

In [None]:
# Accumulator for evaluation rows: the evaluation loop appends one
# [source_clauses, clauses, seed, task_acc, iia_acc] list per clause pair.
# NOTE(review): because this lives in its own cell, re-running only the loop
# cell keeps appending to the same list -- re-run this cell first for a clean table.
results = []

In [None]:
def _count_correct(logits, labels):
    """Count correct answer-token predictions in one batch.

    The label at sequence position -3 is compared with the argmax of the
    logits at position -4 (the logits are read one position earlier than the
    label they are scored against).

    Returns:
        (n_correct, n_examples) as plain Python ints.
    """
    predicted = torch.argmax(logits[:, -4], dim=-1)
    hits = (labels[:, -3] == predicted)
    return hits.sum().item(), len(hits)


# Clause programs naming the source checkpoint. They only affect the loaded
# model when `control` is False; they are always recorded in `results`.
source_training_clauses = [
    '( c != a ) or ( a == b )',
    '( c != a ) and ( a == b )',
    '( c != a ) or ( a != b )',
    '( c != a ) and ( b != c )',
    '( c != a ) and ( a != b )',
    '( c == a ) or ( a != b )',
    '( c == a ) or ( b != c )',
    '( c == a ) or ( b == c )',
    '( b != a ) and ( b != c )',
    '( c == a ) and ( b != c )',
    '( c == a ) or ( a == b )',
]

# Clause programs the test data is sampled from: the same eleven programs.
target_training_clauses = list(source_training_clauses)

align_type = "left_aligment"  # spelling is part of the checkpoint directory name
level = "l2"
shared_train = level == "l1"
n_fewshot = 10
n_examples = n_fewshot + 1
n_testing_examples = 1000
control = True  # True: one checkpoint per seed, independent of the source clauses

logger = logging.getLogger()

for seed in [66,]: # [42, 66, 77, 88, 99]
    for source_clauses in source_training_clauses:
        for clauses in target_training_clauses:
            args.seed = seed
            set_seed(args.seed)
            source_clauses_no_space = source_clauses.replace(" ", "")
            clauses_no_space = clauses.replace(" ", "")
            print(f"source clauses: {source_clauses}; transferring clauses: {clauses}")
            print("seed: ", seed)
            # finding model files, it needs to exist on the disk.
            if control:
                model_name = f"logic_pipeline.model.gpt2.n_rule.11.n_shot.{n_fewshot}.seed.{seed}/model-last"
            else:
                model_name = f"logic_pipeline.model.gpt2.n_rule.11.n_shot.{n_fewshot}.seed.{seed}.clauses.{source_clauses_no_space}.align.{align_type}.seed.{seed}/model-last"
            args.model_path = os.path.join(args.output_dir, model_name)

            # Sample base/source pairs for the target clauses.
            aligment_sampled_data = left_aligment_sampler(
                clauses, n_testing_examples, n_examples = n_examples, shared_train=shared_train
            )
            test_data = {
                "base_input_ids" : aligment_sampled_data[0],
                "base_output_ids" : aligment_sampled_data[1],
                "source_input_ids" : aligment_sampled_data[3],
                "source_output_ids" : aligment_sampled_data[4],
                "counterfacut_output_ids": aligment_sampled_data[6],
                "clauses" : aligment_sampled_data[2],
                "intervention_ids": [0] * len(aligment_sampled_data[0]),
            }
            test_dataset = Dataset.from_dict(
                {
                    "input_ids": test_data["base_input_ids"],
                    "labels": test_data["base_output_ids"],
                    "source_input_ids": test_data["source_input_ids"],
                    "counterfactual_labels": test_data["counterfacut_output_ids"],
                    "intervention_ids": test_data["intervention_ids"],
                }
            ).with_format("torch")
            test_dataloader = DataLoader(test_dataset, batch_size=args.eval_batch_size)

            # Model
            torch.cuda.empty_cache()

            configuration = GPT2Config.from_pretrained(os.path.join(args.data_path, "decoder_config.json"))

            # Only the logic pipeline has a known arity; fail loudly instead of
            # hitting an unbound (or stale) `arity` further down.
            if "logic" not in model_name:
                raise ValueError(f"cannot infer arity for model {model_name!r}")
            arity = 3
            start_idx = 1 + (arity + 4) * int(n_fewshot)
            end_idx = start_idx + (arity+1)
            alignment_config = {
                "layer" : args.aligning_layer_n,
                "token_range" : [start_idx, end_idx] # this is kind of fixed?
            }
            if args.aligning_var_n == 1:
                intervention_config = {
                    0: [[0, args.aligning_basis_n]]
                }
            else:
                # Previously this fell through and crashed with NameError on
                # the logging line below.
                raise NotImplementedError(
                    f"aligning_var_n={args.aligning_var_n} is not supported (only 1 is implemented)"
                )
            logging.info(f"intervention_config = {intervention_config}")
            logging.info(f"alignment_config = {alignment_config}")

            model = AlignableGPT2LMHeadModel(configuration, alignment_config=alignment_config)
            if args.model_path is not None:
                logging.info("Loading pretrained model.")
                # Load onto CPU first so a checkpoint saved from another GPU
                # index still loads; the model is moved to `device` below.
                raw_weights = torch.load(
                    os.path.join(args.model_path, 'pytorch_model.bin'),
                    map_location="cpu",
                )
                model.load_state_dict(raw_weights, strict=False)

            # we need to set off gradients!
            for name, param in model.named_parameters():
                param.requires_grad = False

            device = torch.device(args.device)
            if "cuda:" not in args.device:
                n_gpu = torch.cuda.device_count()
            else:
                n_gpu = 1
            logging.info(f'__Number CUDA Devices: {n_gpu}')

            if n_gpu > 1:
                model = torch.nn.DataParallel(model)
            _ = model.to(device)

            aligner = LogicSolverAligner(
                model, device=device,
                logger=logger,
                is_master=True,
                n_gpu=n_gpu,
                is_wandb=args.is_wandb,
                model_name=model_name,
                intervention_config=intervention_config
            )
            num_params = count_parameters(model)
            logging.info(f'Number of {model_name} model params: {num_params}')

            # Pure evaluation: no autograd graph is needed for either pass.
            with torch.no_grad():
                # Pass 1: plain task accuracy on the base inputs.
                total_count = 0
                correct_count = 0
                _ = model.eval()
                epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
                for step, inputs in enumerate(epoch_iterator):
                    input_ids = inputs['input_ids'].to(device)
                    labels = inputs['labels'].to(device)
                    outputs = model(input_ids=input_ids)

                    n_correct, n_seen = _count_correct(outputs.logits, labels)
                    correct_count += n_correct
                    total_count += n_seen
                    epoch_iterator.set_postfix({'acc': round(correct_count/total_count, 2)})
                task_acc = round(correct_count/total_count, 2)

                # Pass 2: accuracy against the counterfactual labels when the
                # source run's rotated hidden states are passed into the base run.
                total_count = 0
                correct_count = 0
                aligner.model.eval()
                epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
                for step, inputs in enumerate(epoch_iterator):
                    inputs = {
                        k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                        for k, v in inputs.items()
                    }
                    if aligner.preload_intervention_corr is None:
                        # A real exception rather than `assert False`, which
                        # disappears under `python -O`.
                        raise NotImplementedError(
                            "evaluation without a preloaded intervention correspondence is not implemented"
                        )
                    intervention_corr = aligner.preload_intervention_corr.expand(
                        inputs['input_ids'].shape[0],-1
                    ).to(device)

                    # aligning forward!
                    source_hidden_states = aligner.model(
                        input_ids=inputs['source_input_ids']
                    ).rotated_hidden_states
                    outputs = aligner.model(
                        input_ids=inputs['input_ids'],
                        source_hidden_states=source_hidden_states,
                        intervention_corr=intervention_corr,
                        labels=inputs['counterfactual_labels']
                    )

                    n_correct, n_seen = _count_correct(
                        outputs.logits, inputs['counterfactual_labels']
                    )
                    correct_count += n_correct
                    total_count += n_seen
                    epoch_iterator.set_postfix({'acc': round(correct_count/total_count, 2)})
                iia_acc = round(correct_count/total_count, 2)

            results.append([source_clauses, clauses, seed, task_acc, iia_acc])


In [None]:
# Tabulate the evaluation rows. Each row appended by the evaluation loop is
# [source_clauses, clauses, seed, task_acc, iia_acc]; the last value is stored
# under the column name 'iit_acc'.
result_columns = ['src_clauses', 'tgt_clauses', 'seed', 'task_acc', 'iit_acc']
result_df = pd.DataFrame(data=results, columns=result_columns)

In [None]:
# Source-by-target grid of the 'iit_acc' column for the heatmap.
# pivot_table with a mean is used instead of pivot so that several rows for
# the same (source, target) pair -- e.g. more than one seed, or rows
# accumulated by re-running the evaluation cell -- are averaged instead of
# raising "Index contains duplicate entries". With one row per pair the
# values are the same as with pivot.
heatmap_df = result_df.pivot_table(
    index='src_clauses', columns='tgt_clauses', values='iit_acc', aggfunc='mean'
)

In [None]:
# Draw the source-by-target accuracy grid on an explicit Axes, with the colour
# scale pinned to [0, 1] so figures from different runs are comparable.
fig, ax = plt.subplots()
sns.heatmap(
    heatmap_df,
    annot=True,
    fmt="g",
    cmap='viridis',
    vmin=0.0,
    vmax=1.0,
    ax=ax,
)
# Tilt the clause strings on the x axis so they stay readable.
plt.setp(ax.get_xticklabels(), rotation=70)
plt.show()