In [1]:
import os, random, argparse, sys, pickle, time
import torch
from transformers import AutoConfig, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import numpy as np
import pandas as pd
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from models.modelings_gpt2 import *
from logic_data.constants import *
from datasets import Dataset 
from torch.utils.data import DataLoader

def set_seed(seed: int):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as a straight-line sequence of calls would use.
    for apply_seed in seeders:
        apply_seed(seed)

def count_parameters(model):
    """Return the total number of elements across `model`'s trainable parameters.

    Parameters whose `requires_grad` flag is False are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total

import logging
logging.basicConfig(level = logging.INFO)

In [2]:
class LogicSolverTrainer(object):
    """Training loop with periodic validation and best/last checkpointing.

    The wrapped model is called as `model(**batch)` and must return an object
    exposing `.loss` and `.logits`; checkpoints are written through the model's
    `save_pretrained` method.

    Accuracy is measured on a single token per sequence: the label at
    `ANSWER_LABEL_POS` is compared with the argmax of the logits at
    `ANSWER_LOGIT_POS`.

    Args:
        model: the model to train (optionally wrapped so that `.module` holds
            the underlying model when `n_gpu > 1`).
        is_master: only the master process writes checkpoints.
        device: device every tensor in a batch is moved to.
        logger: logger used for progress messages.
        lr: stored on the instance; the optimizer passed to `train` owns the
            actual learning rate.
        apex_enable: accepted for interface compatibility; unused.
        n_gpu: when > 1 the loss is averaged and `model.module` is saved.
        early_stopping: stored on the instance; NOTE(review): no early
            stopping is currently performed by `train`.
        do_statistic: accepted for interface compatibility; unused.
        is_wandb: log training/eval metrics to Weights & Biases.
        model_name: free-form model label stored on the instance.
    """

    # Index of the scored token in `labels` and of the logit that predicts it.
    # The one-position offset is consistent with next-token prediction (the
    # logit at position t scores the token at t+1).
    # NOTE(review): presumably the answer token is third from the end of every
    # sequence -- confirm against the data generator.
    ANSWER_LABEL_POS = -3
    ANSWER_LOGIT_POS = -4

    def __init__(
        self, model,
        is_master,
        device,
        logger,
        lr=5e-5,
        apex_enable=False,
        n_gpu=1,
        early_stopping=5,
        do_statistic=False,
        is_wandb=False,
        model_name="",
    ):
        self.model = model
        self.is_master = is_master
        self.logger = logger
        self.is_wandb = is_wandb
        self.model_name = model_name

        self.device = device
        self.lr = lr
        self.n_gpu = n_gpu

        self.early_stopping = early_stopping

    def _to_device(self, batch):
        """Return a copy of `batch` with every tensor value moved to `self.device`."""
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def _count_correct(self, logits, labels):
        """Return `(n_correct, n_examples)` for the scored answer position."""
        target = labels[:, self.ANSWER_LABEL_POS]
        predicted = torch.argmax(logits[:, self.ANSWER_LOGIT_POS], dim=-1)
        hits = (target == predicted)
        return int(hits.sum().item()), int(hits.shape[0])

    def _evaluate_accuracy(self, dataloader):
        """Return answer-token accuracy over `dataloader`, or None if it is empty.

        Puts the model in eval mode; the caller is responsible for switching
        back to train mode afterwards.
        """
        self.model.eval()
        correct_count = 0
        total_count = 0
        # No autograd graph is needed for validation; building one only costs memory.
        with torch.no_grad():
            for batch in dataloader:
                batch = self._to_device(batch)
                outputs = self.model(**batch)
                n_correct, n_seen = self._count_correct(outputs.logits, batch['labels'])
                correct_count += n_correct
                total_count += n_seen
        if total_count == 0:
            return None
        return correct_count / total_count

    def _save(self, output_dir, tag):
        """Save the (unwrapped) model under `output_dir/tag` on the master process."""
        if not self.is_master:
            return
        target = self.model.module if self.n_gpu > 1 else self.model
        target.save_pretrained(os.path.join(output_dir, tag))

    def train(
        self, train_dataloader, dev_dataloader,
        optimizer, scheduler, output_dir,
        log_step, valid_steps, epochs, 
        gradient_accumulation_steps,
    ):
        """Train for `epochs` passes over `train_dataloader`.

        Args:
            train_dataloader: iterable of dict batches fed to the model as kwargs.
            dev_dataloader: iterable of dict batches used for validation.
            optimizer, scheduler: stepped together once per accumulation window.
            output_dir: directory under which `model-best` and `model-last` are saved.
            log_step: log training metrics to wandb every `log_step` batches
                (only when `is_wandb` is set).
            valid_steps: validate every `valid_steps` batches; values <= 0
                disable in-training validation.
            epochs: number of passes over the training data.
            gradient_accumulation_steps: number of batches whose gradients are
                accumulated before each optimizer step (values < 1 are treated as 1).
        """
        wandb = None
        if self.is_wandb:
            # Imported lazily so the trainer works without wandb installed/configured.
            import wandb

        accum_steps = max(1, int(gradient_accumulation_steps))

        self.model.train()
        train_iterator = trange(
            0, int(epochs), desc="Epoch"
        )
        total_step = 0
        total_log_step = 0
        best_eval_acc = -1
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True)
            for batch in epoch_iterator:
                batch = self._to_device(batch)
                outputs = self.model(**batch)
                loss = outputs.loss.mean() if self.n_gpu > 1 else outputs.loss
                loss_value = loss.item()

                should_log = total_step % log_step == 0 and self.is_wandb
                should_validate = valid_steps > 0 and total_step % valid_steps == 0

                if should_log:
                    # Accuracy is only computed when it is actually logged.
                    n_correct, n_seen = self._count_correct(outputs.logits, batch['labels'])
                    wandb.log(
                        {
                            "train/loss": loss_value,
                            "train/step_accuracy": n_correct / n_seen
                        },
                        step=total_log_step
                    )

                if should_validate:
                    current_acc = self._evaluate_accuracy(dev_dataloader)
                    if current_acc is not None:
                        if self.is_wandb:
                            wandb.log(
                                {
                                    "eval/accuracy": current_acc
                                },
                                step=total_log_step
                            )
                        # Compare at full precision so small improvements still
                        # refresh the best checkpoint.
                        if current_acc > best_eval_acc:
                            best_eval_acc = current_acc
                            self._save(output_dir, 'model-best')
                    self.model.train()

                if should_log:
                    total_log_step += 1

                epoch_iterator.set_postfix({'loss': round(loss_value, 2)})

                # Every batch contributes (scaled) gradients; parameters are
                # updated once per accumulation window.
                (loss / accum_steps).backward()
                if (total_step + 1) % accum_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    self.model.zero_grad()

                total_step += 1

        (self.logger or logging).info("Training is finished ...")
        self._save(output_dir, 'model-last')

In [3]:
if __name__ == '__main__':
    # Parse command-line arguments. Inside a notebook kernel sys.argv holds the
    # kernel's own flags, so argparse rejects them and calls sys.exit(); that
    # exit is used as the signal to fall back to hard-coded notebook defaults.
    is_notebook = False
    try:
        cmd = argparse.ArgumentParser('The testing components of')
        cmd.add_argument('--gpu', default=-1, type=int, help='use id of gpu, -1 if cpu.')
        cmd.add_argument('--train_batch_size', default=128, type=int, help='training batch size')
        cmd.add_argument('--eval_batch_size', default=128, type=int, help='evaluation batch size')
        cmd.add_argument('--lr', default=0.01, type=float, help='learning rate')
        cmd.add_argument('--data_path', required=True, type=str, help='path to the training corpus')
        cmd.add_argument(
            '--encoder_config_path', 
            type=str, help='path to the encoder config'
        )
        cmd.add_argument(
            '--decoder_config_path', 
            type=str, help='path to the decoder config'
        )
        cmd.add_argument('--max_seq_len', default=512, type=int)
        cmd.add_argument('--seed', default=42, type=int)
        cmd.add_argument('--gradient_accumulation_steps', default=1, type=int)
        cmd.add_argument('--output_dir', required=True, type=str, help='save dir')
        cmd.add_argument('--local_rank', default=-1, type=int, help='multi gpu training')
        cmd.add_argument('--epochs', default=10, type=int, help='training epochs')
        cmd.add_argument('--model_path', type=str, required=False, default=None)
        cmd.add_argument('--warm_up', type=float, default=0.1)
        cmd.add_argument('--is_wandb', default=False, action='store_true')
        cmd.add_argument('--log_step', default=10, type=int)
        cmd.add_argument('--valid_steps', default=500, type=int)
        cmd.add_argument('--early_stopping', default=5, type=int)
        cmd.add_argument('--device', default="cuda", type=str,
                         help='torch device string, e.g. "cuda", "cuda:0" or "cpu"')
        cmd.add_argument('--do_train', default=False, action='store_true')
        cmd.add_argument('--do_eval', default=False, action='store_true')
        cmd.add_argument('--do_test', default=False, action='store_true')
        
        cmd.add_argument('--n_training_program', default=5, type=int)
        cmd.add_argument('--n_fewshot', default=6, type=int)
        
        args = cmd.parse_args(sys.argv[1:])
    except SystemExit as exit_signal:
        # Exit code 0 means the user explicitly asked for -h/--help: honour it
        # instead of silently starting a run with the notebook defaults.
        if exit_signal.code == 0:
            raise
        is_notebook = True
        args = argparse.Namespace(
            gpu=1,
            train_batch_size=64,
            eval_batch_size=64,
            gradient_accumulation_steps=1,
            lr=1e-4,
            data_path="./logic_data/",
            encoder_config_path=None,
            decoder_config_path=None,
            max_seq_len=512,
            seed=42,
            output_dir="./results_notebook/",
            local_rank=-1,  # same default as the CLI flag
            epochs=200,
            warm_up=0.1,
            is_wandb=True,
            log_step=10,
            valid_steps=100,  # -1 not do training eval!
            early_stopping=999,  # large == never early stop!
            device="cuda:0",
            do_train=True,
            do_eval=True,
            do_test=True,
            model_path=None,
            n_training_program=7,
            n_fewshot=6,
        )
        
        # args.model_path = "./results_notebook/logic_pipeline.model.gpt2.n_rule.11.n_shot.6.seed.42/model-last/"
        print("Using in a notebook env.")

Using in a notebook env.


usage: The testing components of [-h] [--gpu GPU]
                                 [--train_batch_size TRAIN_BATCH_SIZE]
                                 [--eval_batch_size EVAL_BATCH_SIZE] [--lr LR]
                                 --data_path DATA_PATH
                                 [--encoder_config_path ENCODER_CONFIG_PATH]
                                 [--decoder_config_path DECODER_CONFIG_PATH]
                                 [--max_seq_len MAX_SEQ_LEN] [--seed SEED]
                                 [--gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS]
                                 --output_dir OUTPUT_DIR
                                 [--local_rank LOCAL_RANK] [--epochs EPOCHS]
                                 [--model_path MODEL_PATH] [--warm_up WARM_UP]
                                 [--is_wandb] [--log_step LOG_STEP]
                                 [--valid_steps VALID_STEPS]
                                 [--early_stopping EARLY_STOPPING]
              

In [None]:
model_name = "gpt2"
run_name = f"logic_pipeline.model.{model_name}.n_rule.{args.n_training_program}.n_shot.{args.n_fewshot}.seed.{args.seed}"
logger = logging.getLogger()


def load_split_pickle(split):
    """Load the pickled `split` ("train", "dev" or "test") from `args.data_path`."""
    split_path = os.path.join(
        args.data_path,
        f"{split}_data.n_rule.{args.n_training_program}.n_shot.{args.n_fewshot}.pkl"
    )
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(split_path, 'rb') as handle:
        return pickle.load(handle)


def to_torch_dataset(split_data):
    """Wrap a split's `input_ids` / `output_ids` as a torch-formatted Dataset."""
    return Dataset.from_dict(
        {"input_ids": split_data["input_ids"], "labels": split_data["output_ids"]}
    ).with_format("torch")


# Dataloader
train_data = load_split_pickle("train")
dev_data = load_split_pickle("dev")
test_data = load_split_pickle("test")

train_dataset = to_torch_dataset(train_data)
# NOTE(review): the training loader is not shuffled -- confirm the pickled
# data is already in random order, otherwise consider shuffle=True.
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)

dev_dataset = to_torch_dataset(dev_data)
dev_dataloader = DataLoader(dev_dataset, batch_size=args.eval_batch_size)

test_dataset = to_torch_dataset(test_data)
test_dataloader = DataLoader(test_dataset, batch_size=args.eval_batch_size)


# Model
torch.cuda.empty_cache()

# Seed before the model is constructed so weight initialisation is reproducible.
set_seed(args.seed)

configuration = GPT2Config.from_pretrained(os.path.join(args.data_path, "decoder_config.json"))
model = CustomizedGPT2LMHeadModel(configuration)

if args.model_path is not None:
    logging.info("Loading pretrained model.")
    # Load onto CPU first so checkpoints saved from a GPU also open on
    # CPU-only machines; the model is moved to `device` below.
    raw_weights = torch.load(
        os.path.join(args.model_path, 'pytorch_model.bin'), map_location="cpu"
    )
    model.load_state_dict(raw_weights)
    
device = torch.device(args.device)
if device.type != "cuda":
    # Not a CUDA device: never wrap the model in DataParallel.
    n_gpu = 0
elif device.index is None:
    # Plain "cuda": use every visible GPU.
    n_gpu = torch.cuda.device_count()
else:
    # Explicit "cuda:N": single GPU.
    n_gpu = 1
logging.info(f'__Number CUDA Devices: {n_gpu}')

if n_gpu > 1:
    model = torch.nn.DataParallel(model)
_ = model.to(device)

# The scheduler is stepped once per optimizer update, i.e. once per
# accumulation window, not once per batch.
accumulation_steps = max(1, args.gradient_accumulation_steps)
total_batches = int(len(train_dataloader) * args.epochs)
t_total = -(-total_batches // accumulation_steps)  # ceiling division

warm_up_steps = int(args.warm_up * t_total)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=args.lr
)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_up_steps,
                                            num_training_steps=t_total)
is_master = True                                    
if is_master:
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(args.output_dir, exist_ok=True)

os.environ["WANDB_PROJECT"] = "ToM-DAS"

output_dir = os.path.join(args.output_dir, run_name)
if args.do_train and args.is_wandb:
    import wandb
    run = wandb.init(
        project="ToM-DAS-GPT2", 
        entity="wuzhengx",
        name=run_name,
    )
    wandb.config.update(args)
    
trainer = LogicSolverTrainer(
    model, device=device, 
    logger=logger,
    is_master=is_master, 
    lr=args.lr,
    n_gpu=n_gpu,
    early_stopping=args.early_stopping,
    is_wandb=args.is_wandb, 
    model_name=model_name,
)
num_params = count_parameters(model)
logging.info(f'Number of {model_name} model params: {num_params}')

# Train
if args.do_train:
    logging.info(f"OUTPUT DIR: {output_dir}")
    trainer.train(
        train_dataloader, dev_dataloader,
        optimizer, scheduler, 
        log_step=args.log_step, valid_steps=args.valid_steps,
        output_dir=output_dir, epochs=args.epochs, 
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )

INFO:root:__Number CUDA Devices: 1
[34m[1mwandb[0m: Currently logged in as: [33mwuzhengx[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:root:Number of gpt2 model params: 18367800
INFO:root:OUTPUT DIR: ./results_notebook/logic_pipeline.model.gpt2.n_rule.7.n_shot.6.seed.42
Epoch: 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:34<00:00, 16.46it/s, loss=4.29]
Epoch: 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:31<00:00, 17.03it/s, loss=0.75]
Epoch: 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:31<00:00, 17.11it/s, loss=0.54]
Epoch: 3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:30<00:00, 17.21it/s, loss=0.52]
Epoch: 4: 100%|████████████████████████████████████████████████

Epoch: 40: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:30<00:00, 17.23it/s, loss=0.48]
Epoch: 41: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:31<00:00, 17.15it/s, loss=0.48]
Epoch: 42: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:31<00:00, 17.12it/s, loss=0.48]
Epoch: 43: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:31<00:00, 17.18it/s, loss=0.48]
Epoch: 44: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:31<00:00, 17.11it/s, loss=0.48]
Epoch

Epoch: 81: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:30<00:00, 17.24it/s, loss=0.41]
Epoch: 82: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:31<00:00, 17.13it/s, loss=0.38]
Epoch: 83: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [01:30<00:00, 17.27it/s, loss=0.38]
Epoch: 84:  77%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌                              | 1210/1563 [01:11<00:48,  7.30it/s, loss=0.49]

In [None]:
# Close the wandb run. `wandb` is only imported (and a run only started) when
# training was requested with wandb enabled, so both flags must be checked.
if args.do_train and args.is_wandb:
    wandb.finish()

In [None]:
# Dev: answer-token accuracy on the dev split, shown as a running value in
# the progress bar.
if args.do_eval:
    total_count = 0
    correct_count = 0
    trainer.model.eval()
    epoch_iterator = tqdm(dev_dataloader, desc="Iteration", position=0, leave=True)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for step, inputs in enumerate(epoch_iterator):
            for k, v in inputs.items():
                if v is not None and isinstance(v, torch.Tensor):
                    inputs[k] = v.to(device)
            outputs = trainer.model(**inputs)

            # The logit at position -4 is compared with the label at position -3.
            actual_test_labels = inputs['labels'][:, -3]
            pred_test_labels = torch.argmax(outputs.logits[:, -4], dim=-1)
            correct_labels = (actual_test_labels==pred_test_labels)

            total_count += len(correct_labels)
            correct_count += correct_labels.sum().tolist()

            current_acc = round(correct_count/total_count, 2)
            epoch_iterator.set_postfix({'acc': current_acc})

In [None]:
# Test: answer-token accuracy on the test split, shown as a running value in
# the progress bar.
if args.do_test:
    total_count = 0
    correct_count = 0
    trainer.model.eval()
    epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for step, inputs in enumerate(epoch_iterator):
            for k, v in inputs.items():
                if v is not None and isinstance(v, torch.Tensor):
                    inputs[k] = v.to(device)
            outputs = trainer.model(**inputs)

            # The logit at position -4 is compared with the label at position -3.
            actual_test_labels = inputs['labels'][:, -3]
            pred_test_labels = torch.argmax(outputs.logits[:, -4], dim=-1)
            correct_labels = (actual_test_labels==pred_test_labels)

            total_count += len(correct_labels)
            correct_count += correct_labels.sum().tolist()

            current_acc = round(correct_count/total_count, 2)
            epoch_iterator.set_postfix({'acc': current_acc})