In [8]:
import argparse
import gc
import json
from copy import deepcopy

import numpy as np
import torch
import transformers
from datasets import load_dataset, load_metric
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertModel, BertTokenizer


In [2]:
def parse_arguments(argv=None):
    """Parse the experiment's command-line options and echo them as JSON.

    Args:
        argv: optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]`` as before; passing a list makes the
            function usable from a notebook kernel or a test.

    Returns:
        argparse.Namespace with ``tasks`` (list of task-name strings),
        ``task_shared`` (bool flag) and ``seed`` (int).
    """
    parser = argparse.ArgumentParser()

    # nargs="+" collects whitespace-separated task names ("--tasks rte cola").
    # The previous type=list turned a single argument string into a list of
    # its characters ("rte" -> ['r', 't', 'e']).
    parser.add_argument("--tasks", nargs="+", type=str,
                        default=["mrpc", "cola", "mnli", "sst2", "rte", "qqp", "qnli", "stsb"])
    parser.add_argument("--task_shared", action="store_true")
    parser.add_argument("--seed", type=int, default=1123)
    args = parser.parse_args(argv)

    print("input args:\n", json.dumps(vars(args), indent=4, separators=(",", ":")))
    return args

def get_label_lists(datasets, tasks):
    """Collect each task's label vocabulary, in the order given by ``tasks``.

    "stsb" is treated as the regression task: it has no class names and gets
    the placeholder ``[None]`` (so ``len()`` of its entry is 1) without touching
    ``datasets``. Every other task contributes the ``names`` of the ``label``
    feature of its ``"train"`` split.
    """
    def labels_for(task):
        if task != "stsb":
            return datasets[task]["train"].features["label"].names
        return [None]

    return [labels_for(task) for task in tasks]

def get_num_labels(label_lists, task_clusters, task_shared):
    """Get a list with the number of classifier outputs for each task.

    Args:
        label_lists: per-task label lists as built by ``get_label_lists``; the
            regression placeholder ``[None]`` yields a single output.
        task_clusters: per-task cluster ids, or None. Only read when
            ``task_shared`` is true (the notebook builds it only in that case).
        task_shared: when true, every task uses the output size of its cluster
            from the table below; otherwise each task uses its own label count.

    Returns:
        list of ints, one per task.

    Raises:
        ValueError: if ``task_shared`` is true but ``task_clusters`` is None.
    """
    # The original branches were swapped: it read `task_clusters` only when
    # task_shared was False, which is exactly when the caller passes None.
    if not task_shared:
        return [len(label_list) for label_list in label_lists]

    if task_clusters is None:
        raise ValueError("task_clusters must be provided when task_shared is True")
    # Output size per cluster id (3 for cluster 0, 2 for clusters 1 and 2, 1 for cluster 3).
    cluster_num_labels = {0: 3, 1: 2, 2: 2, 3: 1}
    return [cluster_num_labels[task_cluster] for task_cluster in task_clusters]

def preprocess(datasets, tokenizer):
    """Tokenize every task in ``datasets``; the dict is updated in place and returned.

    Each task's one or two text columns come from the module-level
    ``task_to_keys``; ``padding`` and ``max_length`` are read from module-level
    settings, and sequences are always truncated (``truncation=True``).
    """

    def make_encoder(first_key, second_key):
        # Bind the column names explicitly rather than closing over loop variables.
        def encode(batch):
            if second_key is None:
                texts = (batch[first_key],)
            else:
                texts = (batch[first_key], batch[second_key])
            return tokenizer(*texts, padding=padding, max_length=max_length, truncation=True)
        return encode

    for task in list(datasets):
        first_key, second_key = task_to_keys[task]
        datasets[task] = datasets[task].map(make_encoder(first_key, second_key), batched=True)
    return datasets

def get_split_datasets(datasets, split="train", seed=None):
    """Pick one split out of every task's dataset dict.

    The ``"train"`` split is shuffled with ``seed``; any other split is returned
    unchanged. The result is a new dict keyed by task name, in the same order.
    """
    shuffle = split == "train"
    selected = {}
    for task in datasets:
        part = datasets[task][split]
        selected[task] = part.shuffle(seed=seed) if shuffle else part
    return selected

def support_query_split(datasets, test_size=None, seed=None):
    """Split each task's training set into a support set and a query set.

    Args:
        datasets: mapping task name -> dataset exposing ``train_test_split``
            (e.g. a Hugging Face ``datasets.Dataset``).
        test_size: size of the query part, forwarded to ``train_test_split``.
            Defaults to the module-level ``query_size``.
        seed: optional seed forwarded to ``train_test_split`` so the split can
            be reproduced; None keeps the previous unseeded behaviour.

    Returns:
        (support_datasets, query_datasets): two dicts keyed by task name,
        holding the "train" and "test" parts of each split respectively.
    """
    if test_size is None:
        test_size = query_size

    support_datasets = {}
    query_datasets = {}
    for task, dataset in datasets.items():
        # Named `splits` so the local no longer shadows this function's name.
        splits = dataset.train_test_split(test_size=test_size, seed=seed)
        support_datasets[task] = splits["train"]
        query_datasets[task] = splits["test"]
    return support_datasets, query_datasets

def _pad_rows(rows, num_rows, width):
    """Right-pad variable-length integer rows with zeros into a (num_rows, width) int64 tensor."""
    padded = np.zeros([num_rows, width], dtype=np.int64)
    for i, row in enumerate(rows):
        padded[i, :len(row)] = row
    return torch.from_numpy(padded)


def get_dataloaders(datasets, split="train", seq_len=None, batch_size=None):
    """Build one torch DataLoader per task from tokenized datasets.

    ``input_ids``, ``attention_mask`` and ``token_type_ids`` are right-padded
    with zeros to a common width, so each task yields (num_rows, seq_len)
    int64 tensors plus a label tensor.

    Args:
        datasets: mapping task name -> tokenized dataset supporting
            ``num_rows`` and ``[:]`` slicing into a dict of columns (as a
            Hugging Face ``Dataset`` does).
        split: "train" or "support" -> shuffled batches (RandomSampler) of
            ``train_batch_size``; anything else -> sequential batches of
            ``eval_batch_size``.
        seq_len: padded width; defaults to the module-level ``max_length``.
        batch_size: overrides the split-dependent default batch size.

    Returns:
        list of DataLoaders, in the iteration order of ``datasets``.
    """
    if seq_len is None:
        seq_len = max_length
    shuffle = split in ["train", "support"]
    if batch_size is None:
        batch_size = train_batch_size if shuffle else eval_batch_size

    dataloaders = []
    for task, dataset in datasets.items():
        # Read whole columns at once instead of materialising one row at a time.
        columns = dataset[:]
        features = [
            _pad_rows(columns[key], dataset.num_rows, seq_len)
            for key in ("input_ids", "attention_mask", "token_type_ids")
        ]
        # "stsb" is the regression task (see get_label_lists): keep its float
        # scores instead of truncating them to integers.
        label_dtype = torch.float if task == "stsb" else torch.long
        labels = torch.tensor(columns["label"], dtype=label_dtype)

        data = TensorDataset(*features, labels)
        sampler = RandomSampler(data) if shuffle else SequentialSampler(data)
        dataloaders.append(DataLoader(data, sampler=sampler, batch_size=batch_size))
    return dataloaders

In [4]:
# Experiment configuration. A larger alternative task set: ["sst2", "qqp", "mnli", "qnli"].
tasks = ["rte", "cola"]
task_shared = True
padding = True
max_length = 512
do_lower_case = True
seed = 1123
query_size = 0.2

# Seed the global RNGs too, not only the dataset shuffle below: task sampling
# draws from np.random, while head initialisation, dropout and RandomSampler
# draw from torch.
np.random.seed(seed)
torch.manual_seed(seed)

# BERT hyperparameters
input_dim = 768

# MAML hyperparameters
num_update_steps = 5
num_sample_tasks = 8
outer_learning_rate = 5e-5
inner_learning_rate = 1e-3

# train/eval hyperparameters
num_train_epochs = 1
train_batch_size = 32
eval_batch_size = 32

# Text columns of each GLUE task (second entry is None for single-sentence tasks).
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp" : ("question1", "question2"),
    "rte" : ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

# Cluster id of each task; only used when task_shared is True.
task_cluster_dict = {
    "mrpc": 0,
    "cola": 1,
    "mnli": 0,
    "sst2": 1,
    "rte" : 0,
    "wnli": 0,
    "qqp" : 0,
    "qnli": 2,
    "stsb": 3
}
task_clusters = [task_cluster_dict[task] for task in tasks] if task_shared else None

print("Loading datasets.")
datasets      = {task:load_dataset("glue", task) for task in tasks}
label_lists   = get_label_lists(datasets, tasks)
num_labels    = get_num_labels(label_lists, task_clusters, task_shared)

print("Preprocessing datasets.")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=do_lower_case)
datasets  = preprocess(datasets, tokenizer)

print("Retrieving training set.")
train_datasets = get_split_datasets(datasets, "train", seed=seed)
support_datasets, query_datasets = support_query_split(train_datasets)
support_dataloaders = get_dataloaders(support_datasets, "support")
query_dataloaders   = get_dataloaders(query_datasets, "query")
print("Retrieving evaluation set.")
eval_datasets = get_split_datasets(datasets, "validation")
eval_dataloaders = get_dataloaders(eval_datasets, "validation")

Loading datasets.


Reusing dataset glue (/Users/tttyuntian/.cache/huggingface/datasets/glue/rte/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Reusing dataset glue (/Users/tttyuntian/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)


Preprocessing datasets.


Loading cached processed dataset at /Users/tttyuntian/.cache/huggingface/datasets/glue/rte/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-8efda5aff4c45ac0.arrow
Loading cached processed dataset at /Users/tttyuntian/.cache/huggingface/datasets/glue/rte/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-80d279cd989e639f.arrow
Loading cached processed dataset at /Users/tttyuntian/.cache/huggingface/datasets/glue/rte/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-6e17d613e5b9f8b1.arrow
Loading cached processed dataset at /Users/tttyuntian/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-fd85b48d6ce18d90.arrow
Loading cached processed dataset at /Users/tttyuntian/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-44dc8d48946779cb.arrow
Loading cached processed dataset at /Users/

Retrieving training set.
Retrieving evaluation set.


In [5]:
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss



class Classifier(nn.Module):
    """Task head on top of a BERT-style encoder.

    The encoder's second output (``outputs[1]``, the pooled representation for
    ``BertModel``) passes through dropout and a linear layer with ``n_classes``
    outputs. ``n_classes == 1`` is treated as regression (MSE loss); any other
    value as classification (cross-entropy).
    """

    def __init__(self, embedder, input_dim, n_classes, dropout=0.2):
        super().__init__()
        self.n_classes = n_classes
        self.embedder = embedder
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(input_dim, n_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        """Return ``(logits, loss)``.

        ``logits`` has shape (batch, n_classes). ``loss`` is a scalar tensor, or
        None when ``labels`` is None (e.g. for pure inference).
        """
        # Keyword arguments keep the call correct regardless of the encoder's
        # positional parameter order.
        outputs = self.embedder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)

        if labels is None:
            return logits, None
        if self.n_classes == 1:
            # MSELoss needs targets in the logits' floating dtype; integer
            # regression labels would otherwise cause a dtype mismatch.
            loss = MSELoss()(logits.view(-1), labels.view(-1).to(logits.dtype))
        else:
            loss = CrossEntropyLoss()(logits.view(-1, self.n_classes), labels.view(-1))
        return logits, loss

In [None]:
model = BertModel.from_pretrained("bert-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batches are moved to `device` below, so the meta-model has to live there too.
model.to(device)
outer_optimizer = Adam(model.parameters(), lr=outer_learning_rate)

# Episodes per task and epoch: each episode consumes `num_update_steps` support
# batches plus one query batch.
train_steps_per_task = [
    dataset.num_rows // (train_batch_size * (num_update_steps + 1))
    for dataset in train_datasets.values()
]

# One head per task. At the start of every episode the head's embedder is replaced
# by a fresh deepcopy of `model`, so inner-loop updates never modify the
# meta-parameters; the heads themselves keep their inner-loop updates across episodes.
classifiers = [Classifier(model, input_dim, num_labels[task_id]) for task_id in range(len(tasks))]

# Persistent per-task iterators so consecutive episodes draw different batches
# (a fresh iter() over the sequential query loader always yielded the first batch).
support_iters = [iter(loader) for loader in support_dataloaders]
query_iters = [iter(loader) for loader in query_dataloaders]
num_model_params = sum(1 for _ in model.parameters())


def _next_batch(loaders, iterators, task_id):
    """Return the next batch for `task_id`, restarting its loader once it is exhausted."""
    try:
        return next(iterators[task_id])
    except StopIteration:
        iterators[task_id] = iter(loaders[task_id])
        return next(iterators[task_id])


def _run_episode(task_id):
    """Adapt a copy of `model` on support batches of `task_id`, then backprop the query loss.

    Returns (classifier, mean inner loss, query loss); the classifier's embedder is
    the adapted copy and holds the query-loss gradients.
    """
    classifier = classifiers[task_id]
    classifier.embedder = deepcopy(model)
    classifier.to(device)
    classifier.train()
    inner_optimizer = Adam(classifier.parameters(), lr=inner_learning_rate)
    # Clear gradients left on the head by the previous episode's query backward.
    inner_optimizer.zero_grad()

    # Inner updates: one support batch per update step.
    inner_losses = []
    for _ in range(num_update_steps):
        batch = _next_batch(support_dataloaders, support_iters, task_id)
        input_ids, attention_mask, token_type_ids, labels = (t.to(device) for t in batch)
        _, loss = classifier(input_ids, attention_mask, token_type_ids, labels=labels)
        loss.backward()
        inner_optimizer.step()
        inner_optimizer.zero_grad()
        inner_losses.append(loss.item())

    # Query loss on the adapted parameters. First-order MAML: its gradient with
    # respect to the adapted encoder serves as the meta-gradient for `model`.
    query_batch = _next_batch(query_dataloaders, query_iters, task_id)
    q_input_ids, q_attention_mask, q_token_type_ids, q_labels = (t.to(device) for t in query_batch)
    _, q_loss = classifier(q_input_ids, q_attention_mask, q_token_type_ids, labels=q_labels)
    q_loss.backward()
    return classifier, float(np.mean(inner_losses)), q_loss.item()


def _accumulate(meta_grads, adapted_embedder):
    """Add the adapted encoder's gradients into `meta_grads` (aligned with model.parameters())."""
    # deepcopy preserves parameter order, so index i matches model.parameters().
    for i, param in enumerate(adapted_embedder.parameters()):
        if param.grad is None:
            continue
        grad = param.grad.detach()
        if meta_grads[i] is None:
            meta_grads[i] = grad.clone()
        else:
            meta_grads[i].add_(grad)


def _outer_step(meta_grads, num_tasks):
    """Apply the task-averaged meta-gradient to `model` via `outer_optimizer`."""
    for param, grad in zip(model.parameters(), meta_grads):
        param.grad = None if grad is None else grad / num_tasks
    outer_optimizer.step()
    outer_optimizer.zero_grad()


for epoch_id in range(num_train_epochs):
    print("Start Epoch {}".format(epoch_id))
    model.train()

    # Probability Proportional to Size (PPS): each task appears once per episode it
    # can fill, and the episode order is shuffled.
    sample_task_ids = np.random.permutation(np.repeat(np.arange(len(tasks)), train_steps_per_task))

    meta_grads = [None] * num_model_params
    tasks_in_batch = 0
    for task_id in sample_task_ids:
        task_id = int(task_id)
        classifier, inner_loss, query_loss = _run_episode(task_id)
        print("| task {} | inner_loss {:8.6f} | query_loss {:8.6f} |".format(
            tasks[task_id], inner_loss, query_loss))
        _accumulate(meta_grads, classifier.embedder)
        tasks_in_batch += 1

        # Drop the adapted copy so only one extra BERT copy is alive at a time.
        classifier.embedder = model
        del classifier
        gc.collect()

        # Update BERT after every `num_sample_tasks` episodes.
        if tasks_in_batch == num_sample_tasks:
            _outer_step(meta_grads, tasks_in_batch)
            meta_grads = [None] * num_model_params
            tasks_in_batch = 0

    # Apply a final partial meta-batch instead of silently discarding its episodes.
    if tasks_in_batch:
        _outer_step(meta_grads, tasks_in_batch)
                
    