In [1]:
import os, random, argparse, sys, pickle, time
import torch
from transformers import AutoConfig, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import numpy as np
import pandas as pd
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from models.modelings_alignable_gpt2 import *
from logic_data.constants import *
from datasets import Dataset 
from torch.utils.data import DataLoader
import gc

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

import logging
logging.basicConfig(level = logging.INFO)

In [2]:
class LogicSolverAligner(object):
    def __init__(
        self, model,
        is_master,
        device,
        logger,
        lr=5e-5,
        apex_enable=False,
        n_gpu=1,
        early_stopping=5,
        do_statistic=False,
        is_wandb=False,
        model_name="",
        intervention_config=None
    ):
        self.model = model
        self.is_master = is_master
        self.logger = logger
        self.is_wandb = is_wandb
        self.model_name = model_name
        
        self.device = device
        self.lr = lr
        self.n_gpu = n_gpu
    
        self.early_stopping = early_stopping
    
        self.intervention_config = intervention_config
        self.preload_intervention_corr = None
        # this is to make things a little faster.
        if len(list(self.intervention_config.keys())) == 1:
            self.preload_intervention_corr = self.intervention_config[
                list(self.intervention_config.keys())[0]
            ]
            self.preload_intervention_corr = torch.tensor(self.preload_intervention_corr).long()
    
    def train(
        self, train_dataloader, dev_dataloader,
        optimizer, scheduler, output_dir,
        log_step, valid_steps, epochs, 
        gradient_accumulation_steps,
    ):
        # okay, have to honest, not sure whether we do train mode align or eval align;
        # i guess it is good to try both, but ... only trying train here and move on.
        self.model.train()
        train_iterator = trange(
            0, int(epochs), desc="Epoch"
        )
        total_step = 0
        total_log_step = 0
        best_eval_acc = -1
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True)
            for step, inputs in enumerate(epoch_iterator):
                for k, v in inputs.items():
                    if v is not None and isinstance(v, torch.Tensor):
                        inputs[k] = v.to(self.device)
                        
                if self.preload_intervention_corr is not None:
                    intervention_corr = self.preload_intervention_corr.expand(
                        inputs['input_ids'].shape[0],-1
                    ).to(self.device)
                else:
                    assert False # not implemented
                
                # aligning forward!
                source_hidden_states = self.model(
                   input_ids=inputs['source_input_ids']
                ).rotated_hidden_states
                outputs = self.model(
                    input_ids=inputs['input_ids'],
                    source_hidden_states=source_hidden_states,
                    intervention_corr=intervention_corr,
                    labels=inputs['counterfactual_labels']
                )
                loss = outputs.loss.mean() if self.n_gpu > 1 else outputs.loss
                
                actual_test_labels = inputs['counterfactual_labels'][:, -3]
                pred_test_labels = torch.argmax(outputs.logits[:, -4], dim=-1)
                correct_labels = (actual_test_labels==pred_test_labels)
                
                step_accuracy = correct_labels.sum() / correct_labels.shape[0]
                step_accuracy = step_accuracy.tolist()

                if total_step % log_step == 0 and self.is_wandb:
                    wandb.log(
                        {
                            "train/loss": loss.item(),
                            "train/step_accuracy": step_accuracy
                        },
                        step=total_log_step
                    )
                    
                    if total_step % valid_steps == 0:
                        total_count = 0
                        correct_count = 0
                        self.model.eval()
                        for step, inputs in enumerate(dev_dataloader):
                            for k, v in inputs.items():
                                if v is not None and isinstance(v, torch.Tensor):
                                    inputs[k] = v.to(self.device)
                            if self.preload_intervention_corr is not None:
                                intervention_corr = self.preload_intervention_corr.expand(
                                    inputs['input_ids'].shape[0],-1
                                ).to(self.device)
                            else:
                                assert False # not implemented

                            # aligning forward!
                            source_hidden_states = self.model(
                               input_ids=inputs['source_input_ids']
                            ).rotated_hidden_states
                            outputs = self.model(
                                input_ids=inputs['input_ids'],
                                source_hidden_states=source_hidden_states,
                                intervention_corr=intervention_corr,
                                labels=inputs['counterfactual_labels']
                            )

                            actual_test_labels = inputs['counterfactual_labels'][:, -3]
                            pred_test_labels = torch.argmax(outputs.logits[:, -4], dim=-1)
                            correct_labels = (actual_test_labels==pred_test_labels)

                            total_count += len(correct_labels)
                            correct_count += correct_labels.sum().tolist()

                        current_acc = round(correct_count/total_count, 2)
                        wandb.log(
                            {
                                "eval/accuracy": current_acc
                            },
                            step=total_log_step
                        )
                        if current_acc > best_eval_acc:
                            best_eval_acc = current_acc
                            if self.is_master:
                                if self.n_gpu > 1:
                                    self.model.module.save_pretrained(os.path.join(output_dir, 'model-best'))
                                else:
                                    self.model.save_pretrained(os.path.join(output_dir, 'model-best'))
                        self.model.train()
                        
                    
                    total_log_step += 1
                loss_str = round(loss.item(), 2)
                epoch_iterator.set_postfix({'loss': loss_str})
                
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps
                
                if total_step % gradient_accumulation_steps == 0:
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                    self.model.zero_grad()
                    
                total_step += 1
                
        logging.info("Training is finished ...") 
        if self.is_master:
            if self.n_gpu > 1:
                self.model.module.save_pretrained(os.path.join(output_dir, 'model-last'))
            else:
                self.model.save_pretrained(os.path.join(output_dir, 'model-last'))

In [68]:
if __name__ == '__main__':
    is_notebook = False
    try:
        cmd = argparse.ArgumentParser('The testing components of')
        cmd.add_argument('--gpu', default=-1, type=int, help='use id of gpu, -1 if cpu.')
        cmd.add_argument('--train_batch_size', default=128, type=int, help='training batch size')
        cmd.add_argument('--eval_batch_size', default=128, type=int, help='training batch size')
        cmd.add_argument('--lr', default=0.01, type=float, help='learning rate')
        cmd.add_argument('--data_path', required=True, type=str, help='path to the training corpus')
        cmd.add_argument('--train_data_path', required=True, type=str, help='path to the training corpus')
        cmd.add_argument('--test_data_path', required=True, type=str, help='path to the training corpus')
        cmd.add_argument(
            '--encoder_config_path', 
            type=str, help='path to the encoder config'
        )
        cmd.add_argument(
            '--decoder_config_path', 
            type=str, help='path to the decoder config'
        )
        cmd.add_argument('--max_seq_len', default=512, type=int)
        cmd.add_argument('--seed', default=42, type=int)
        cmd.add_argument('--gradient_accumulation_steps', default=1, type=int)
        cmd.add_argument('--output_dir', required=True, type=str, help='save dir')
        cmd.add_argument('--local_rank', default=-1, type=int, help='multi gpu training')
        cmd.add_argument('--epochs', default=10, type=int, help='training epochs')
        cmd.add_argument('--model_path', type=str, required=False, default=None)
        cmd.add_argument('--warm_up', type=float, default=0.1)
        cmd.add_argument('--is_wandb', default=False, action='store_true')
        cmd.add_argument('--log_step', default=10, type=int)
        cmd.add_argument('--valid_steps', default=500, type=int)
        cmd.add_argument('--early_stopping', default=5, type=int)
        cmd.add_argument('--device', default="cuda", type=str, help='')
        cmd.add_argument('--do_prealign_eval', default=False, action='store_true')
        cmd.add_argument('--do_align', default=False, action='store_true')
        cmd.add_argument('--do_eval', default=False, action='store_true')
        cmd.add_argument('--do_test', default=False, action='store_true')
        
        cmd.add_argument('--n_training_program', default=5, type=int)
        cmd.add_argument('--n_fewshot', default=6, type=int)
        cmd.add_argument('--aligning_layer_n', default=0, type=int)
        
        args = cmd.parse_args(sys.argv[1:])
    except:
        is_notebook = True
        parser = argparse.ArgumentParser()
        args = parser.parse_args([])
        args.gpu = 1
        args.train_batch_size = 64
        args.eval_batch_size = 64
        args.gradient_accumulation_steps = 2
        args.lr = 1e-3
        args.data_path = "./logic_data"
        args.train_data_path = \
            ""
        args.test_data_path = \
            ""
        args.encoder_config_path = None
        args.decoder_config_path = None
        args.max_seq_len = 512
        args.seed = 42
        args.output_dir = "./results_notebook/"
        args.epochs = 10
        args.warm_up = 0.1
        args.is_wandb = False
        args.log_step = 10
        args.valid_steps = 100 # -1 not do training eval!
        args.early_stopping = 999 # large == never early stop!
        args.device = "cuda:0"
        args.do_prealign_eval = True # do it once at least!
        args.do_align = True
        args.do_eval = True
        args.do_test = True
        args.model_path = "./results_notebook/logic_pipeline.model.gpt2.n_rule.7.n_shot.14.seed.42/model-last/"
        # args.model_path = None
        
        # alignment search setting
        args.aligning_layer_n = 0
        args.aligning_basis_n = 600
        args.aligning_var_n = 1
        
        print("Using in a notebook env.")

Using in a notebook env.


usage: The testing components of [-h] [--gpu GPU]
                                 [--train_batch_size TRAIN_BATCH_SIZE]
                                 [--eval_batch_size EVAL_BATCH_SIZE] [--lr LR]
                                 --data_path DATA_PATH --train_data_path
                                 TRAIN_DATA_PATH --test_data_path
                                 TEST_DATA_PATH
                                 [--encoder_config_path ENCODER_CONFIG_PATH]
                                 [--decoder_config_path DECODER_CONFIG_PATH]
                                 [--max_seq_len MAX_SEQ_LEN] [--seed SEED]
                                 [--gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS]
                                 --output_dir OUTPUT_DIR
                                 [--local_rank LOCAL_RANK] [--epochs EPOCHS]
                                 [--model_path MODEL_PATH] [--warm_up WARM_UP]
                                 [--is_wandb] [--log_step LOG_STEP]
          

In [77]:
args.train_data_path = \
f"./logic_data/left_aligment_test_data.l1.unrolling.{s}.clauses.(c==b)and(c!=a).pkl"
args.test_data_path = \
f"./logic_data/left_aligment_test_data.l1.unrolling.{s}.clauses.(c==b)and(c!=a).pkl"
if args.model_path is None:
    model_name = "logic_pipeline.model.gpt2.n_rule.7.n_shot.14.seed.42"
else:
    model_name = args.model_path.strip("/").split("/")[-2]
align_dataname = args.train_data_path.split("/")[-1].split(".pkl")[0]
run_name = f"{model_name}.data.{align_dataname}.seed.{args.seed}"
logger = logging.getLogger()

# Dataloader
train_data = pickle.load(open(args.train_data_path, 'rb'))
test_data = pickle.load(open(args.test_data_path, 'rb'))

train_dataset = Dataset.from_dict(
    {
        "input_ids": train_data["base_input_ids"], 
        "labels": train_data["base_output_ids"],
        "source_input_ids": train_data["source_input_ids"], 
        "counterfactual_labels": train_data["counterfacut_output_ids"],
        "intervention_ids": train_data["intervention_ids"],
    }
).with_format("torch")
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)

test_dataset = Dataset.from_dict(
    {
        "input_ids": test_data["base_input_ids"], 
        "labels": test_data["base_output_ids"],
        "source_input_ids": test_data["source_input_ids"], 
        "counterfactual_labels": test_data["counterfacut_output_ids"],
        "intervention_ids": test_data["intervention_ids"],
    }
).with_format("torch")
test_dataloader = DataLoader(test_dataset, batch_size=args.eval_batch_size)

# Model
torch.cuda.empty_cache()

configuration = GPT2Config.from_pretrained(os.path.join(args.data_path, "decoder_config.json"))

if "logic" in model_name:
    arity = 3
idx = 0
for token in model_name.split("."):
    if token == "n_shot":
        break
    idx += 1
n_shot = model_name.split(".")[idx+1]

if "unrolling" in args.train_data_path:
    idx = 0
    for token in args.train_data_path.split("."):
        if token == "unrolling":
            break
        idx += 1
    n_shot = args.train_data_path.split(".")[idx+1]
    # reset to unrolling position

start_idx = 1 + (arity + 4) * int(n_shot)
end_idx = start_idx + (arity+1)

alignment_config = {
    "layer" : args.aligning_layer_n,
    "token_range" : [start_idx, end_idx] # this is kind of fixed?
}
if args.aligning_var_n == 1:
    intervention_config = {
        0: [[0, args.aligning_basis_n]]
    }
elif args.aligning_var_n == 2:
    pass
logging.info(f"intervention_config = {intervention_config}")
logging.info(f"alignment_config = {alignment_config}")

model = AlignableGPT2LMHeadModel(configuration, alignment_config=alignment_config)
if args.model_path is not None:
    logging.info("Loading pretrained model.")
    raw_weights = torch.load(os.path.join(args.model_path, 'pytorch_model.bin'))
    model.load_state_dict(raw_weights, strict=False)

# we need to set off gradients!
for name, param in model.named_parameters():
    if "rotate_layer" not in name:
        param.requires_grad = False

set_seed(args.seed)
device = torch.device(args.device)
if "cuda:" not in args.device:
    n_gpu = torch.cuda.device_count()
    logging.info(f'__Number CUDA Devices: {n_gpu}')
else:
    n_gpu = 1
    logging.info(f'__Number CUDA Devices: {n_gpu}')

if n_gpu > 1:
    model = torch.nn.DataParallel(model)
_ = model.to(device)

t_total = int(len(train_dataloader) * args.epochs)

warm_up_steps = args.warm_up * t_total
optimizer = torch.optim.AdamW(
    model.parameters(), lr=args.lr
)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_up_steps,
                                            num_training_steps=t_total)
is_master = True                                    
if not os.path.exists(args.output_dir) and is_master:
    os.mkdir(args.output_dir)
os.environ["WANDB_PROJECT"] = f"ToM-DAS"

output_dir = os.path.join(args.output_dir, run_name)
if args.do_align and args.is_wandb:
    import wandb
    run = wandb.init(
        project="ToM-DAS-GPT2", 
        entity="wuzhengx",
        name=run_name,
    )
    wandb.config.update(args)
if not os.path.exists(args.output_dir) and is_master:
    os.mkdir(args.output_dir)

if args.do_prealign_eval:
    # before doing alignment, we need to check factual performance on the dataset.
    total_count = 0
    correct_count = 0
    if args.do_eval:
        _ = model.eval()
        epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
        for step, inputs in enumerate(epoch_iterator):
            input_ids = inputs['input_ids'].to(device)
            labels = inputs['labels'].to(device)
            outputs = model(input_ids=input_ids)
            actual_test_labels = labels[:, -3]
            pred_test_labels = torch.argmax(outputs.logits[:, -4], dim=-1)
            correct_labels = (actual_test_labels==pred_test_labels)

            total_count += len(correct_labels)
            correct_count += correct_labels.sum().tolist()

            current_acc = round(correct_count/total_count, 2)
            epoch_iterator.set_postfix({'acc': current_acc})


INFO:root:intervention_config = {0: [[0, 600]]}
INFO:root:alignment_config = {'layer': 0, 'token_range': [1, 5]}
INFO:root:Loading pretrained model.
INFO:root:__Number CUDA Devices: 1
Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 27.29it/s, acc=0.79]


In [10]:
aligner = LogicSolverAligner(
    model, device=device, 
    logger=logger,
    is_master=is_master, 
    n_gpu=n_gpu,
    is_wandb=args.is_wandb, 
    model_name=model_name,
    intervention_config=intervention_config
)
num_params = count_parameters(model)
logging.info(f'Number of {model_name} model params: {num_params}')

INFO:root:Number of logic_pipeline.model.gpt2.n_rule.7.n_shot.14.seed.42 model params: 1440000


In [6]:
# Train
if args.do_align:
    logging.info(f"OUTPUT DIR: {output_dir}")
    aligner.train(
        train_dataloader, test_dataloader,
        optimizer, scheduler, 
        log_step=args.log_step, valid_steps=args.valid_steps,
        output_dir=output_dir, epochs=args.epochs, 
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )

INFO:root:OUTPUT DIR: ./results_notebook/logic_pipeline.model.gpt2.n_rule.7.n_shot.14.seed.42.data.left_aligment_train_data.l1.clauses.(c==a)or(b!=a).seed.42
Epoch: 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                       | 262/313 [00:36<00:07,  7.25it/s, loss=0]
Epoch:   0%|                                                                                                                                                                   | 0/10 [00:36<?, ?it/s]


KeyboardInterrupt: 

In [7]:
if args.is_wandb:
    wandb.finish()

In [7]:
# Test
if args.do_test: 
    total_count = 0
    correct_count = 0
    if args.do_test:
        aligner.model.eval()
        epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
        for step, inputs in enumerate(epoch_iterator):
            for k, v in inputs.items():
                if v is not None and isinstance(v, torch.Tensor):
                    inputs[k] = v.to(device)
            if aligner.preload_intervention_corr is not None:
                intervention_corr = aligner.preload_intervention_corr.expand(
                    inputs['input_ids'].shape[0],-1
                ).to(device)
            else:
                assert False # not implemented

            # aligning forward!
            source_hidden_states = aligner.model(
               input_ids=inputs['source_input_ids']
            ).rotated_hidden_states
            outputs = aligner.model(
                input_ids=inputs['input_ids'],
                source_hidden_states=source_hidden_states,
                intervention_corr=intervention_corr,
                labels=inputs['counterfactual_labels']
            )

            actual_test_labels = inputs['counterfactual_labels'][:, -3]
            pred_test_labels = torch.argmax(outputs.logits[:, -4], dim=-1)
            correct_labels = (actual_test_labels==pred_test_labels)

            total_count += len(correct_labels)
            correct_count += correct_labels.sum().tolist()

            current_acc = round(correct_count/total_count, 2)
            epoch_iterator.set_postfix({'acc': current_acc})

Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:01<00:00, 10.23it/s, acc=0.68]


### Zero-shot representation transfer


In [11]:
args.test_data_path = "./logic_data/left_aligment_train_data.l3.clauses.(c==a)or(b!=a)+(c==a)and(a==b).pkl"
test_data = pickle.load(open(args.test_data_path, 'rb'))

test_dataset = Dataset.from_dict(
    {
        "input_ids": test_data["base_input_ids"], 
        "labels": test_data["base_output_ids"],
        "source_input_ids": test_data["source_input_ids"], 
        "counterfactual_labels": test_data["counterfacut_output_ids"],
        "intervention_ids": test_data["intervention_ids"],
    }
).with_format("torch")
test_dataloader = DataLoader(test_dataset, batch_size=args.eval_batch_size)

# Model
torch.cuda.empty_cache()

total_count = 0
correct_count = 0
if args.do_test:
    aligner.model.eval()
    epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
    for step, inputs in enumerate(epoch_iterator):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)
        if aligner.preload_intervention_corr is not None:
            intervention_corr = aligner.preload_intervention_corr.expand(
                inputs['input_ids'].shape[0],-1
            ).to(device)
        else:
            assert False # not implemented

        # aligning forward!
        source_hidden_states = aligner.model(
           input_ids=inputs['source_input_ids']
        ).rotated_hidden_states
        outputs = aligner.model(
            input_ids=inputs['input_ids'],
            source_hidden_states=source_hidden_states,
            intervention_corr=intervention_corr,
            labels=inputs['counterfactual_labels']
        )

        actual_test_labels = inputs['counterfactual_labels'][:, -3]
        pred_test_labels = torch.argmax(outputs.logits[:, -4], dim=-1)
        correct_labels = (actual_test_labels==pred_test_labels)

        total_count += len(correct_labels)
        correct_count += correct_labels.sum().tolist()

        current_acc = round(correct_count/total_count, 2)
        epoch_iterator.set_postfix({'acc': current_acc})

Iteration: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:28<00:00, 11.13it/s, acc=0.71]
