In [1]:
import dataclasses
import logging
import os
import sys
sys.path.append("..")
import math
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.dataloader import DataLoader

from transformers import (AutoConfig, AutoModelForSequenceClassification,
                          AutoTokenizer, EvalPrediction, GlueDataset, DefaultDataCollator) 
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (HfArgumentParser, Trainer, TrainingArguments,
                          glue_compute_metrics, glue_output_modes,
                          glue_tasks_num_labels, set_seed)

from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers.data.processors.glue import MrpcProcessor, ColaProcessor, MnliProcessor, Sst2Processor, RteProcessor, WnliProcessor, QqpProcessor, QnliProcessor, StsbProcessor


from tqdm import tqdm, trange

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )

In [175]:
model_args = ModelArguments(model_name_or_path = 'albert-base-v2')
data_args = DataTrainingArguments(task_name = 'MNLI', data_dir = '/home/nlp/data/glue_data/MNLI')
training_args = TrainingArguments(output_dir = '/home/nlp/experiments/meta/',
                                 do_eval = True,
                                 per_device_train_batch_size=64)


if (
    os.path.exists(training_args.output_dir)
    and os.listdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
):
    raise ValueError(
        f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
    )

# Set seed
set_seed(training_args.seed)

try:
    num_labels = glue_tasks_num_labels[data_args.task_name]
    output_mode = glue_output_modes[data_args.task_name]
except KeyError:
    raise ValueError("Task not found: %s" % (data_args.task_name))

In [176]:
training_args

TrainingArguments(output_dir='/home/nlp/experiments/meta/', overwrite_output_dir=False, do_train=False, do_eval=True, do_predict=False, evaluate_during_training=False, per_device_train_batch_size=64, per_device_eval_batch_size=8, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=1, learning_rate=5e-05, weight_decay=0.0, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=3.0, max_steps=-1, warmup_steps=0, logging_dir=None, logging_first_step=False, logging_steps=500, save_steps=500, save_total_limit=None, no_cuda=False, seed=42, fp16=False, fp16_opt_level='O1', local_rank=-1, tpu_num_cores=None, tpu_metrics_debug=False)

In [177]:
config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    num_labels=num_labels,
    finetuning_task=data_args.task_name,
    cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_args.model_name_or_path,
    from_tf=bool(".ckpt" in model_args.model_name_or_path),
    config=config,
    cache_dir=model_args.cache_dir,
)


In [9]:
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
    def compute_metrics_fn(p: EvalPrediction) -> Dict:
        if output_mode == "classification":
            preds = np.argmax(p.predictions, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(p.predictions)
        return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

    return compute_metrics_fn

In [10]:
TrainingArguments

transformers.training_args.TrainingArguments

In [11]:
@dataclass
class Arguments(TrainingArguments):
    tasks = ['mrpc', 'cola', 'mnli', 'sst-2', 'rte', 'qqp', 'qnli', 'sts-b']
    target_task: str = field(default = 'mrpc')
    task_shared: bool = field(default = True)

    reload_model = None
    num_update_steps: int = field(default = 5)
    num_sample_tasks: int = field(default = 5)
    inner_learning_rate: float = field(default = 1e-3)
    glue_dir : str = field(default = None)
    max_len: int = field(default = 80)
    output_dir: str
    eval_steps = 10

In [12]:
args = Arguments(glue_dir='/home/nlp/data/glue_data', output_dir='/home/nlp/experiments/meta')

In [13]:
processor_dict = {
          'mrpc': MrpcProcessor,
          'cola': ColaProcessor,
          'mnli': MnliProcessor,
          'sst-2': Sst2Processor,
          'rte': RteProcessor,
          'wnli': WnliProcessor,
          'qqp': QqpProcessor,
          'qnli': QnliProcessor,
          'sts-b': StsbProcessor
        }
processors = [processor_dict[task]() for task in args.tasks]

In [14]:
GLUE_PATH = os.path.join('home', 'nlp', 'data', 'glue_data')
dataset_dict = {
          'mrpc': args.glue_dir+'/MRPC',
          'cola': args.glue_dir+'/CoLA',
          'mnli': args.glue_dir+'/MNLI',
          'sst-2': args.glue_dir+'/SST-2',
          'rte':  args.glue_dir+'/RTE',
          'wnli': args.glue_dir+'/WNLI',
          'qqp':  args.glue_dir+'/QQP',
          'qnli': args.glue_dir+'/QNLI',
          'sts-b': args.glue_dir+'/STS-B'
        }
data_dirs =  [dataset_dict[task] for task in args.tasks]

In [15]:
for i, task in enumerate(args.tasks):
        if task == args.target_task:
            target_task_id = i
            break

task_cluster_dict = {
      'mrpc': 0,
      'cola': 1,
      'mnli': 0,
      'sst-2': 1,
      'rte': 0,
      'wnli': 0,
      'qqp': 0,
      'qnli': 2,
      'sts-b': 3
    }
task_clusters = [task_cluster_dict[task] for task in args.tasks] if args.task_shared else None

In [16]:
label_lists = [processor.get_labels() for processor in processors]

In [17]:
task_clusters

[0, 1, 0, 1, 0, 0, 2, 3]

In [18]:
label_lists

[['0', '1'],
 ['0', '1'],
 ['contradiction', 'entailment', 'neutral'],
 ['0', '1'],
 ['entailment', 'not_entailment'],
 ['0', '1'],
 ['entailment', 'not_entailment'],
 [None]]

In [19]:
if not args.task_shared:
    num_labels = [len(label_list) for label_list in label_lists]
else:
    cluster_num_labels = {0: 3, 1: 2, 2: 2, 3: 1}
    num_labels = [cluster_num_labels[task_cluster] for task_cluster in task_clusters]

In [20]:
args.tasks

['mrpc', 'cola', 'mnli', 'sst-2', 'rte', 'qqp', 'qnli', 'sts-b']

In [169]:
model.zero_grad()

In [178]:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

RuntimeError: stack expects a non-empty TensorList

In [166]:
args.gradient_accumulation_steps

1

In [21]:
num_labels

[3, 2, 3, 2, 3, 3, 2, 1]

In [22]:
 glue_output_modes[args.tasks[5]]

'classification'

In [160]:
args.max_grad_norm

1.0

In [23]:
data_dirs

['/home/nlp/data/glue_data/MRPC',
 '/home/nlp/data/glue_data/CoLA',
 '/home/nlp/data/glue_data/MNLI',
 '/home/nlp/data/glue_data/SST-2',
 '/home/nlp/data/glue_data/RTE',
 '/home/nlp/data/glue_data/QQP',
 '/home/nlp/data/glue_data/QNLI',
 '/home/nlp/data/glue_data/STS-B']

In [24]:
train_dataset_list, eval_dataset_list = [], []
for task, data_dir in tqdm(zip(args.tasks, data_dirs)):
    data_args.task_name = task
    data_args.data_dir = data_dir
    train_dataset_list.append(GlueDataset(data_args, tokenizer))
    eval_dataset_list.append(GlueDataset(data_args, tokenizer, mode = "dev"))

8it [00:24,  3.11s/it]


In [25]:
train_sampler_list = []
for dataset in train_dataset_list:
    train_sampler_list.append(RandomSampler(dataset))

In [26]:
train_dataloader_list, eval_dataloader_list = [], []
data_collator = DefaultDataCollator()

for train_dataset, eval_dataset, sampler in \
    tqdm(zip(train_dataset_list, eval_dataset_list, train_sampler_list)):
    
    train_dataloader_list.append(DataLoader(train_dataset,
            batch_size=training_args.train_batch_size,
            sampler=sampler,
            collate_fn=data_collator.collate_batch,
            drop_last=True))
    
    eval_dataloader_list.append(DataLoader(eval_dataset,
            batch_size=training_args.train_batch_size,
            sampler=sampler,
            collate_fn=data_collator.collate_batch,
            drop_last=True))

8it [00:00, 9906.83it/s]


In [27]:
train_examples = [processor.get_train_examples(data_dir) for processor, data_dir in tqdm(zip(processors, data_dirs))]

8it [00:10,  1.28s/it]


In [28]:
train_steps_per_task = [ math.floor((len(train_example)/training_args.per_device_train_batch_size)/(args.num_update_steps+1)) for train_example in train_examples]
total_steps = sum(train_steps_per_task) * training_args.num_train_epochs
print(f'Total steps: {total_steps}')

Total steps: 7401.0


In [37]:
label_lists

[['0', '1'],
 ['0', '1'],
 ['contradiction', 'entailment', 'neutral'],
 ['0', '1'],
 ['entailment', 'not_entailment'],
 ['0', '1'],
 ['entailment', 'not_entailment'],
 [None]]

In [38]:
args.tasks

['mrpc', 'cola', 'mnli', 'sst-2', 'rte', 'qqp', 'qnli', 'sts-b']

In [29]:
train_steps_per_task

[9, 22, 1022, 175, 6, 947, 272, 14]

In [39]:
import pandas as pd

In [77]:
metrics =  ['eval_loss', 'eval_acc', 'eval_f1', 'eval_acc_and_f1']
columns = args.tasks

In [79]:
columns

['mrpc', 'cola', 'mnli', 'sst-2', 'rte', 'qqp', 'qnli', 'sts-b']

In [147]:
df = pd.DataFrame(columns=columns, index=metrics)

In [150]:
for i in range(len(columns)):
    for j in range(len(metrics)):
        df[columns[i]][metrics[j]] = []

In [153]:
df[args.tasks[0]]['eval_loss'].append(2)

In [159]:
args.gradient_accumulation_steps

1

In [154]:
df

Unnamed: 0,mrpc,cola,mnli,sst-2,rte,qqp,qnli,sts-b
eval_loss,"[2, 2]",[],[],[],[],[],[],[]
eval_acc,[],[],[],[],[],[],[],[]
eval_f1,[],[],[],[],[],[],[],[]
eval_acc_and_f1,[],[],[],[],[],[],[],[]


In [30]:
t_total = int(len(train_dataloader_list) // training_args.gradient_accumulation_steps * training_args.num_train_epochs)
num_train_epochs = training_args.num_train_epochs

In [32]:
train_dataloaders_iters = [iter(train_dataloader) for train_dataloader in train_dataloader_list]

extra_ids = []
for t_id in range(len(args.tasks)):
    extra_ids += [t_id] * train_steps_per_task[t_id]  #math.ceil(len(train_examples[t_id]))
extra_ids = np.random.choice(extra_ids, len(extra_ids), replace=False) 

In [34]:
len(extra_ids)

2467

In [None]:
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

In [31]:
 def empty_memory():
    import gc
    gc.collect()
    torch.cuda.empty_cache()

In [26]:
class MetaTrainer(Trainer):
    def __init__(self, 
                 model, 
                 args, 
                 train_dataloader_list,
                 eval_dataloader_list,
                 compute_metrics,
                 prediction_loss_only = False):
        
        self.model = model.to(args.device)
        self.args = args
        self.compute_metrics_list = [compute_metrics(task) for task in self.args.tasks]
        self.train_dataloader_list = train_dataloader_list
        self.eval_dataloader_list = eval_dataloader_list
        self.data_collator = DefaultDataCollator()
        self.prediction_loss_only = prediction_loss_only
        self.eval_results = {}
        self._setup_wandb()
        
    def update_model_params(self, model, fast_params):
        for idx, param in enumerate(model.parameters()):
            param.data = fast_params[idx]
        return model

    def train(self):
        model = self.model
        
        self.global_step = 0
        nb_tr_steps = 0
        tr_loss = 0
        optimizer, scheduler = self.get_optimizers(int(len(self.train_dataloader_list[0]) \
        // self.args.gradient_accumulation_steps * self.args.num_train_epochs))
        
        model.zero_grad()
        
        for _ in trange(int(training_args.num_train_epochs), desc='Epoch'):
            model.train()
            tr_loss = 0
            nb_tr_examples = 0 

            train_dataloaders_iters = [iter(train_dataloader) for train_dataloader in train_dataloader_list]

            extra_ids = []
            for t_id in range(len(args.tasks)):
                extra_ids += [t_id] * train_steps_per_task[t_id]  #math.ceil(len(train_examples[t_id]))
            extra_ids = np.random.choice(extra_ids, len(extra_ids), replace=False) 

            meta_loss = 0

            for i, task_id in tqdm(enumerate(extra_ids), desc='Task IDs'):
                for update_step in range(args.num_update_steps+1):

                    try:
                        inputs = next(train_dataloaders_iters[task_id])
                    except StopIteration:
                        break

                    for k, v in inputs.items():
                        inputs[k] = v.to(args.device)
                        if not isinstance(inputs['labels'], torch.cuda.LongTensor):
                            inputs['labels'] = inputs['labels'].long()

                    if update_step == args.num_update_steps:
                        if update_step == 0:
                            raise ValueError('update_step cannot be 0!')

                        for param, f_param in zip(model.parameters(), fast_params):
                            if not param.requires_grad:
                                continue
                            cur_grad = (param - f_param)/update_step/args.inner_learning_rate
                            if param.grad is None:
                                param.grad = torch.zeros(cur_grad.size()).cuda()
                                param.grad.add_(cur_grad/inputs['labels'].size(0))

                    elif update_step == 0:
                        loss = model(**inputs)[0]
                        grad = torch.autograd.grad(loss, model.parameters(), allow_unused=True)
                        fast_params = list(map(lambda p: p[1] - args.inner_learning_rate * p[0] if p[0] is not None else p[1], zip(grad, model.parameters())))

                    elif update_step < args.num_update_steps:
                        model = self.update_model_params(model, fast_params)
                        loss = model(**inputs)[0]
                        grad = torch.autograd.grad(loss, fast_params, allow_unused=True)
                        fast_params = list(map(lambda p: p[1] - args.inner_learning_rate * p[0] if p[0] is not None else p[1], zip(grad, fast_params)))

                if i % args.num_sample_tasks == (args.num_sample_tasks-1):
                    optimizer.step()
                    optimizer.zero_grad()
                    meta_loss = 0

                self.global_step += 1
        
                if self.global_step % args.eval_steps == 0:
                    for idx, eval_dataloader in enumerate(self.eval_dataloader_list):
                        self.compute_metrics = self.compute_metrics_list[idx]
                        result = self.evaluate(eval_dataloader.dataset)
                        
                        for key, value in result.items():
                            logger.info("%s  %s = %s",args.tasks[idx], key, value)
                        

In [27]:
trainer = MetaTrainer(model, args, train_dataloader_list,
                     eval_dataloader_list, build_compute_metrics_fn)

In [28]:
eval_dataloader_list.pop()

<torch.utils.data.dataloader.DataLoader at 0x7ff7c1958f10>

In [29]:
trainer.train()

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]
Task IDs: 0it [00:00, ?it/s][A
Task IDs: 1it [00:03,  3.75s/it][A
Task IDs: 2it [00:06,  3.31s/it][A
Task IDs: 3it [00:08,  3.00s/it][A
Task IDs: 4it [00:10,  2.79s/it][A
Task IDs: 5it [00:12,  2.64s/it][A
Task IDs: 6it [00:15,  2.55s/it][A
Task IDs: 7it [00:17,  2.47s/it][A
Task IDs: 8it [00:19,  2.42s/it][A
Task IDs: 9it [00:22,  2.39s/it][A

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=26.0, style=ProgressStyle(description_wi…




{"eval_loss": 1.1688626500276418, "eval_pearson": -0.07593542849008117, "eval_spearmanr": -0.07026298108615246, "eval_corr": -0.07309920478811682, "step": 10}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=66.0, style=ProgressStyle(description_wi…


{"eval_loss": 1.7732348694945828, "eval_pearson": 0.047892375448013634, "eval_spearmanr": 0.04789237544801361, "eval_corr": 0.04789237544801363, "step": 10}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=614.0, style=ProgressStyle(description_w…


{"eval_loss": 1.125631565185634, "eval_pearson": 0.03112561928614942, "eval_spearmanr": 0.031110041573893667, "eval_corr": 0.031117830430021545, "step": 10}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=55.0, style=ProgressStyle(description_wi…


{"eval_loss": 1.5897113214839589, "eval_pearson": -0.03872658021023936, "eval_spearmanr": -0.03872658021023939, "eval_corr": -0.03872658021023938, "step": 10}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=18.0, style=ProgressStyle(description_wi…


{"eval_loss": 1.3225060535801783, "eval_pearson": -0.07009389681708172, "eval_spearmanr": -0.07272434521096278, "eval_corr": -0.07140912101402225, "step": 10}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=2527.0, style=ProgressStyle(description_…


{"eval_loss": 0.9290406081956698, "eval_pearson": 0.17356125955788465, "eval_spearmanr": 0.1917711813336911, "eval_corr": 0.18266622044578787, "step": 10}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=342.0, style=ProgressStyle(description_w…


Task IDs: 10it [02:16, 36.08s/it][A


{"eval_loss": 1.299326989734382, "eval_pearson": -0.1945533855892897, "eval_spearmanr": -0.1945780773062058, "eval_corr": -0.19456573144774775, "step": 10}



Task IDs: 11it [02:19, 25.97s/it][A
Task IDs: 12it [02:21, 18.89s/it][A
Task IDs: 13it [02:23, 13.94s/it][A
Task IDs: 14it [02:26, 10.47s/it][A
Task IDs: 15it [02:28,  8.05s/it][A
Task IDs: 16it [02:31,  6.35s/it][A
Task IDs: 17it [02:33,  5.16s/it][A
Task IDs: 18it [02:35,  4.32s/it][A
Task IDs: 19it [02:38,  3.74s/it][A

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=26.0, style=ProgressStyle(description_wi…


{"eval_loss": 0.754741173524123, "eval_pearson": 0.05367693570305286, "eval_spearmanr": 0.05343511206110535, "eval_corr": 0.053556023882079105, "step": 20}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=66.0, style=ProgressStyle(description_wi…


{"eval_loss": 1.2134370496778777, "eval_pearson": -0.0027618219290598003, "eval_spearmanr": 0.0009026285023506604, "eval_corr": -0.00092959671335457, "step": 20}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=614.0, style=ProgressStyle(description_w…


{"eval_loss": 1.2278344829230043, "eval_pearson": -0.01632747615200033, "eval_spearmanr": -0.0161722522877514, "eval_corr": -0.016249864219875863, "step": 20}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=55.0, style=ProgressStyle(description_wi…


{"eval_loss": 1.1502182028510355, "eval_pearson": 0.005972333635109359, "eval_spearmanr": 0.009997267686546754, "eval_corr": 0.007984800660828056, "step": 20}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=18.0, style=ProgressStyle(description_wi…


{"eval_loss": 0.8318254517184364, "eval_pearson": 0.0029989208045044153, "eval_spearmanr": 0.0029989208045043867, "eval_corr": 0.002998920804504401, "step": 20}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=2527.0, style=ProgressStyle(description_…

Task IDs: 19it [03:38, 11.49s/it]
Epoch:   0%|          | 0/3 [03:38<?, ?it/s]







KeyboardInterrupt: 

In [35]:
data_args

GlueDataTrainingArguments(task_name='sts-b', data_dir='/home/nlp/data/glue_data/STS-B', max_seq_length=128, overwrite_cache=False)

In [36]:
label_lists

[['0', '1'],
 ['0', '1'],
 ['contradiction', 'entailment', 'neutral'],
 ['0', '1'],
 ['entailment', 'not_entailment'],
 ['0', '1'],
 ['entailment', 'not_entailment'],
 [None]]

In [None]:
# del trainer
# import gc
# gc.collect()
# torch.cuda.empty_cache()