In [1]:
import argparse
import json
import gc
from copy import deepcopy

import numpy as np
import transformers
from transformers import BertModel, BertTokenizer
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from datasets import load_dataset, load_metric


In [27]:
def get_label_lists(datasets, args):
    """"""
    label_lists = []
    for task in args.tasks:
        is_regression = task == "stsb"
        if is_regression:
            label_lists.append([None])
        else:
            label_lists.append(datasets[task].features["label"].names)
    return label_lists

def get_num_labels(label_lists):
    """ Get a list of number of labels for the tasks """
    return [len(label_list) for label_list in label_lists]

def preprocess(datasets, tokenizer, args):
    
    def preprocess_function(examples):
        inputs = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*inputs, padding=args.padding, max_length=args.max_length, truncation=True)
        return result

    for task, dataset in datasets.items():
        sentence1_key, sentence2_key = task_to_keys[task]
        datasets[task] = datasets[task].map(preprocess_function, batched=True)
    return datasets

def get_split_datasets(datasets, split="train", seed=None):
    if split == "train":
        split_datasets = {task:dataset[split].shuffle(seed=seed) for task, dataset in datasets.items()}
    else:
        split_datasets = {task:dataset[split] for task, dataset in datasets.items()}
    return split_datasets

def support_query_split(datasets, args):
    support_datasets = {}
    query_datasets   = {}
    for task, dataset in datasets.items():
        support_query_split    = dataset.train_test_split(test_size=args.query_size)
        support_datasets[task] = support_query_split["train"]
        query_datasets[task]   = support_query_split["test"]
    return support_datasets, query_datasets

def get_dataloaders(datasets, split, args):
    dataloaders = []
    for task, dataset in datasets.items():
        num_rows = dataset.num_rows if args.num_rows == -1 else args.num_rows
        all_input_ids      = np.zeros([num_rows, args.max_length])
        all_attention_mask = np.zeros([num_rows, args.max_length])
        all_token_type_ids = np.zeros([num_rows, args.max_length])
        for i in range(num_rows):
            features = dataset[i]
            curr_len = len(features["attention_mask"])
            all_input_ids[i,:curr_len]      = features["input_ids"]
            all_attention_mask[i,:curr_len] = features["attention_mask"]
            all_token_type_ids[i,:curr_len] = features["token_type_ids"]
        all_input_ids      = torch.tensor(all_input_ids, dtype=torch.long)
        all_attention_mask = torch.tensor(all_attention_mask, dtype=torch.long)
        all_token_type_ids = torch.tensor(all_token_type_ids, dtype=torch.long)
        all_label          = torch.tensor(dataset[:num_rows]["label"], dtype=torch.long)

        data = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label)
        if split in ["train", "support"]:
            sampler    = RandomSampler(data)
            dataloader = DataLoader(data, sampler=sampler, batch_size=args.train_batch_size)
        else:
            sampler    = SequentialSampler(data)
            dataloader = DataLoader(data, sampler=sampler, batch_size=args.eval_batch_size)
        dataloaders.append(dataloader)
    return dataloaders

In [None]:
class TrainingArgs:
    def __init__(self):
        #self.tasks = ["sst2", "qqp", "mnli", "qnli"]
        self.tasks = ["rte", "cola"]
        self.task_shared = True
        self.padding = True
        self.max_length = 512
        self.do_lower_case = True
        self.seed = 1123
        self.query_size = 0.2
        self.num_rows = -1

        # BERT hyperparameters
        self.input_dim = 768

        # MAML hyperparameters
        self.num_update_steps = 1   #5
        self.num_sample_tasks = 2   #8
        self.outer_learning_rate = 5e-5
        self.inner_learning_rate = 1e-3

        # train/eval hyperparameters
        self.num_train_epochs = 1   # 5
        self.train_batch_size = 8
        self.eval_batch_size = 8

args = TrainingArgs()

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp" : ("question1", "question2"),
    "rte" : ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

"""
task_cluster_dict = {
    "mrpc": 0,
    "cola": 1,
    "mnli": 0,
    "sst2": 1,
    "rte" : 0,
    "wnli": 0,
    "qqp" : 0,
    "qnli": 2,
    "stsb": 3
}
task_clusters = [task_cluster_dict[task] for task in tasks] if args.task_shared else None
"""

print("Loading datasets.")
train_datasets      = {task:load_dataset("glue", task, split="train") for task in args.tasks}
label_lists   = get_label_lists(datasets, args)
num_labels    = get_num_labels(label_lists)

print("Preprocessing datasets.")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=args.do_lower_case)
train_datasets  = preprocess(train_datasets, tokenizer, args)

print("Retrieving training set.")
#train_datasets = get_split_datasets(datasets, "train", seed=args.seed)
support_datasets, query_datasets = support_query_split(train_datasets, args)
support_dataloaders = get_dataloaders(support_datasets, "support", args)
query_dataloaders   = get_dataloaders(query_datasets, "query", args)

"""
print("Retrieving evaluation set.")
eval_datasets = get_split_datasets(datasets, "validation")
eval_dataloaders = get_dataloaders(eval_datasets, "validation", args)
"""

In [20]:
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss



class Classifier(nn.Module):
    def __init__(self, embedder, input_dim, n_classes, dropout=0.2):
        super(Classifier, self).__init__()
        self.n_classes = n_classes
        self.embedder = embedder
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(input_dim, n_classes)
    
    def forward(self, input_ids, attention_mask, token_type_ids, labels):
        outputs = self.embedder(input_ids, attention_mask, token_type_ids)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        
        if self.n_classes == 1:
            loss_function = MSELoss()
            loss = loss_function(logits.view(-1), labels.view(-1))
        else:
            loss_function = CrossEntropyLoss()
            loss = loss_function(logits.view(-1, self.n_classes), labels.view(-1))
        return logits, loss

In [24]:
def get_train_steps(dataloaders):
    return [len(dataloader.dataset) // (train_batch_size*(num_update_steps+1)) for dataloader in dataloaders]


In [5]:
model = BertModel.from_pretrained("bert-base-uncased")
outer_optimizer = Adam(model.parameters(), lr=outer_learning_rate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_steps_per_task = get_train_steps(dataloaders)

classifiers = [Classifier(model, input_dim, num_labels[task_id]) for task_id in range(len(tasks))]

for epoch_id in range(num_train_epochs):
    print("Start Epoch {}".format(epoch_id))
    model.train()
    sum_gradients = []
    
    # Get sample tasks based on Probability Proportional to Size (PPS)
    sample_task_ids = []
    for task_id in range(len(tasks)):
        sample_task_ids += [task_id] * train_steps_per_task[task_id]
    sample_task_ids = np.random.choice(sample_task_ids, len(sample_task_ids), replace = False)
    
    for sample_task_id, task_id in enumerate(sample_task_ids):
        print(sample_task_id)
        classifier = classifiers[task_id]
        classifier.embedder = model
        inner_optimizer = Adam(classifier.parameters(), lr=inner_learning_rate)
        classifier.train()
        
        # Inner updates with support sets
        for step_id in range(num_update_steps):
            all_loss = []
            for inner_step, batch in enumerate(support_dataloaders[task_id]):
                print(step_id, inner_step)
                input_ids, attention_mask, token_type_ids, labels = tuple(t.to(device) for t in batch)
                outputs = classifier(input_ids, attention_mask, token_type_ids, labels = labels)
                loss = outputs[1]
                loss.backward()
                inner_optimizer.step()
                inner_optimizer.zero_grad()
                all_loss.append(loss.item())
        print("| inner_loss {:8.6f} |".format(np.mean(all_loss)))
        
        # Outer update with query set
        query_batch = iter(query_dataloaders[task_id]).next()
        q_input_ids, q_attention_mask, q_token_type_ids, q_labels = tuple(t.to(device) for t in query_batch)
        q_outputs = classifier(q_input_ids, q_attention_mask, q_token_type_ids, labels=q_labels)
        
        # Compute the cumulative gradients of original BERT parameters
        q_loss = q_outputs[1]
        q_loss.backward()
        classifier.to(torch.device("cpu"))
        for i, (name, params) in enumerate(classifier.named_parameters()):
            if name.startswith("embedder"):
                if sample_task_id == 0:
                    sum_gradients.append(deepcopy(params.grad))
                else:
                    sum_gradients[i] += deepcopy(params.grad)
        
    # Update BERT parameters after sampling num_sample_tasks
    if sample_task_id % num_sample_tasks == (num_sample_tasks-1):
        # Compute average gradient across tasks
        for i in range(len(sum_gradients)):
            sum_gradients[i] = sum_gradients[i] / num_sample_tasks

        # Assign gradients for original BERT model and Update weights
        for i, params in enumerate(model.parameters()):
            params.grad = sum_gradients[i]

        outer_optimizer.step()
        outer_optimizer.zero_grad()
        
    #gc.collect()
                
    

Start Epoch 0
0
0 0
0 1
| inner_loss 0.559568 |
1
0 0
0 1
| inner_loss 0.811201 |


In [7]:
model.eval()
model.save_pretrained("../checkpoints/metabert_maml")

In [22]:
import sys
def sizeof_fmt(num, suffix='B'):
    ''' by Fred Cirera,  https://stackoverflow.com/a/1094933/1870254, modified'''
    for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']:
        if abs(num) < 1024.0:
            return "%3.1f %s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f %s%s" % (num, 'Yi', suffix)

for name, size in sorted(((name, sys.getsizeof(value)) for name, value in locals().items()),
                         key= lambda x: -x[1])[:20]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))

                           _i2:  3.4 KiB
                          _i12:  3.2 KiB
                           _i3:  3.2 KiB
                          _i10:  3.2 KiB
                           _i4:  3.2 KiB
                           _i8:  3.2 KiB
                           _i6:  3.2 KiB
                          _i18:  2.2 KiB
                           _i5:  2.2 KiB
                           _i7:  2.2 KiB
                           _i9:  2.2 KiB
                          _i11:  2.2 KiB
                          _i13:  2.2 KiB
                 BertTokenizer:  2.0 KiB
                    DataLoader:  1.4 KiB
                  TrainingArgs:  1.4 KiB
                     BertModel:  1.0 KiB
                          Adam:  1.0 KiB
                 TensorDataset:  1.0 KiB
                 RandomSampler:  1.0 KiB


In [23]:
for var, obj in list(locals().items()):
    print(var, sys.getsizeof(obj))

__name__ 57
__doc__ 298
__package__ 16
__loader__ 16
__spec__ 16
__builtin__ 88
__builtins__ 88
_ih 272
_oh 248
_dh 80
In 272
Out 248
get_ipython 72
exit 64
quit 64
_ 215
__ 264
___ 248
_i 587
_ii 192
_iii 1050
_i1 380
argparse 88
json 88
gc 88
deepcopy 144
np 88
transformers 88
BertModel 1064
BertTokenizer 2008
torch 88
Adam 1064
TensorDataset 1064
DataLoader 1472
RandomSampler 1064
SequentialSampler 1064
load_dataset 144
load_metric 144
_i2 3499
_i3 3315
_i4 3307
get_label_lists 144
get_num_labels 144
preprocess 144
get_split_datasets 144
support_query_split 144
get_dataloaders 144
_i5 2291
TrainingArgs 1472
args 64
task_to_keys 376
datasets 248
label_lists 104
_i6 3292
_i7 2291
num_labels 104
tokenizer 64
_i8 3296
_i9 2291
train_datasets 248
support_datasets 248
query_datasets 248
_i10 3311
_i11 2291
_i12 3321
_i13 2291
support_dataloaders 104
query_dataloaders 104
_i14 57
_14 248
_i15 132
_i16 220
task 53
dataset 264
_i17 65
_17 264
_i18 2299
_18 215
_i19 587
sys 88
sizeof_fmt 144
