In [1]:
cd ..

/data/2022F/CS330/project


In [2]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell is executed, so edits to the project's .py files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import itertools

import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

import model
import utils
import train

In [4]:
import pickle
from pathlib import Path

# Output directory (relative to the current working directory) used by the
# cells below for TensorBoard logs and pickled experiment results.
# Missing parent directories are created; an existing directory is left as is.
pathCache = Path('cache/experiment-lang')
pathCache.mkdir(parents=True, exist_ok=True)

## Task Setup
Set up the pretrained model and tokenizer, the datasets, the loss function, and the per-task metrics.

In [5]:
import transformers
import experiment_lang

# Device used for the data generators, the models and training.
# The second GPU is preferred (the original hard-coded choice); on machines
# with a single GPU or none, fall back to 'cuda:0' or the CPU instead of
# failing on the first `.to(device)` call.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Load the 'prajjwal1/bert-medium' checkpoint as a sequence-classification
# model together with its tokenizer (via the project helper).
bert, tokenizer = experiment_lang.get_model_and_tokenizer(
    'prajjwal1/bert-medium',
    transformers.AutoModelForSequenceClassification,
)
# Stop tokens derived from the tokenizer; not referenced again in the visible
# cells of this notebook.
stop_tokens = experiment_lang.get_stop_tokens(tokenizer)

def loss_func(logits, labels):
    """Binary cross-entropy (on raw logits) against class labels.

    Args:
        logits: un-normalised model outputs.
        labels: tensor of labels; a trailing singleton axis is appended and
            the values are cast to float before the loss is computed, so
            ``logits`` must have the resulting shape ``labels.shape + (1,)``.

    Returns:
        The value of ``F.binary_cross_entropy_with_logits`` with its default
        reduction.
    """
    targets = labels.unsqueeze(-1).float()
    return F.binary_cross_entropy_with_logits(logits, targets)
def get_language_task(name, batch_size=32):
    """Build the task description dict for one classification dataset.

    Args:
        name: dataset identifier passed to ``experiment_lang.get_dataset``
            (e.g. ``'glue/mrpc'``).
        batch_size: batch size handed to both the training and the validation
            data generator. Defaults to 32, the value previously hard-coded.

    Returns:
        dict with the keys
        ``'train_gen'`` / ``'val_gen'`` (data generators for the two splits),
        ``'loss'`` (the notebook's ``loss_func``),
        ``'predict'`` (maps logits to booleans via ``sigmoid(logits) > 0.5``),
        ``'metric'`` (``experiment_lang.get_acc``).
    """
    ds_train, ds_val = experiment_lang.get_dataset(name)

    def make_generator(dataset):
        # Both splits share the notebook-level tokenizer and device.
        return experiment_lang.ClassificationDataGenerator(
            dataset, tokenizer, device, batchSize=batch_size
        )

    return {
        'train_gen': make_generator(ds_train),
        'val_gen': make_generator(ds_val),
        'loss': loss_func,
        'predict': lambda logits: torch.sigmoid(logits) > 0.5,
        'metric': experiment_lang.get_acc,
    }

# Task registry: short task key -> task description dict built by
# get_language_task. Insertion order (mrpc, then qnli) is preserved.
tasks = {
    key: get_language_task(dataset_id)
    for key, dataset_id in (('mrpc', 'glue/mrpc'), ('qnli', 'glue/qnli'))
}
# Task keys as a list, in the same order as the registry.
task_keys = [*tasks]



Some weights of the model checkpoint at prajjwal1/bert-medium were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not init

## Experiments

In [7]:
import torch.utils.tensorboard as TUTb


In [None]:
# Experiment: TaskAwareBert built in 'separate' mode, trained on all tasks.
# NOTE(review): the same `bert` instance is passed to every experiment cell.
# If TaskAwareBert does not copy it, later experiments start from weights that
# were already fine-tuned here — confirm against experiment_lang.TaskAwareBert.
modelSeparate = experiment_lang.TaskAwareBert(bert, list(tasks.keys()), 'separate').to(device)

# One TensorBoard writer per task, logging under cache/experiment-lang/separate-<task>.
writers = {k: TUTb.SummaryWriter(pathCache / f"separate-{k}") for k in tasks.keys()}
try:
    separate_exp = train.train_and_evaluate(
        model=modelSeparate,
        tasks=tasks,
        steps=500,
        lr=1e-4,
        eval_every=50,
        DEVICE=device,
        writers=writers,
    )
finally:
    # Flush pending events and release the log files even if training fails.
    for writer in writers.values():
        writer.close()

# Persist the returned histories so they can be reloaded without retraining.
with open(pathCache / "separate.pickle", "wb") as f:
    pickle.dump(separate_exp, f)

In [None]:
# Experiment: TaskAwareBert built in 'shared' mode, trained on all tasks.
# NOTE(review): the same `bert` instance is passed to every experiment cell.
# If TaskAwareBert does not copy it, this run starts from weights that an
# earlier experiment already fine-tuned — confirm against
# experiment_lang.TaskAwareBert.
modelShared = experiment_lang.TaskAwareBert(bert, list(tasks.keys()), 'shared').to(device)

# One TensorBoard writer per task, logging under cache/experiment-lang/shared-<task>.
writers = {k: TUTb.SummaryWriter(pathCache / f"shared-{k}") for k in tasks.keys()}
try:
    fully_shared_exp = train.train_and_evaluate(
        model=modelShared,
        tasks=tasks,
        steps=500,
        lr=1e-4,
        eval_every=50,
        DEVICE=device,
        writers=writers,
    )
finally:
    # Flush pending events and release the log files even if training fails.
    for writer in writers.values():
        writer.close()

# Persist the returned histories so they can be reloaded without retraining.
with open(pathCache / "shared.pickle", "wb") as f:
    pickle.dump(fully_shared_exp, f)


In [None]:
# Experiment: TaskAwareBert built in 'surgical' mode, trained on all tasks.
# NOTE(review): the same `bert` instance is passed to every experiment cell.
# If TaskAwareBert does not copy it, this run starts from weights that an
# earlier experiment already fine-tuned — confirm against
# experiment_lang.TaskAwareBert.
modelSurgical = experiment_lang.TaskAwareBert(bert, list(tasks.keys()), 'surgical').to(device)

# One TensorBoard writer per task, logging under cache/experiment-lang/surgical-<task>.
writers = {k: TUTb.SummaryWriter(pathCache / f"surgical-{k}") for k in tasks.keys()}
try:
    surgical_exp = train.train_and_evaluate(
        model=modelSurgical,
        tasks=tasks,
        steps=500,
        lr=1e-4,
        eval_every=50,
        DEVICE=device,
        writers=writers,
    )
finally:
    # Flush pending events and release the log files even if training fails.
    for writer in writers.values():
        writer.close()

# Persist the returned histories so they can be reloaded without retraining.
with open(pathCache / "surgical.pickle", "wb") as f:
    pickle.dump(surgical_exp, f)

In [None]:

# Compare the three experiments: one row of panels per task, one column per
# curve type (train loss, val loss, train accuracy, val accuracy).
experiments = [
    ('shared', fully_shared_exp),
    ('separate', separate_exp),
    ('surgical', surgical_exp),
]
# (panel title, use a logarithmic y axis) — in column order.
panels = [
    ('Train Loss', True),
    ('Val Loss', True),
    ('Train Accuracy', False),
    ('Val Accuracy', False),
]

# The grid height follows the number of tasks (3 inches per row, i.e. the
# original 15x6 figure for two tasks). squeeze=False keeps `axes` 2-D even
# when there is only a single task.
fig, axes = plt.subplots(
    len(tasks), len(panels), figsize=(15, 3 * len(tasks)), squeeze=False
)
for iTask, task_name in enumerate(tasks):
    for exp_name, exp in experiments:
        # Each experiment is a 4-tuple of histories; every history is a
        # sequence of (step, {task_name: value}) pairs. The names are kept at
        # notebook level because a later cell reads `eval_metrics`.
        losses, metrics, eval_losses, eval_metrics = exp
        histories = (losses, eval_losses, metrics, eval_metrics)

        for ax, history in zip(axes[iTask], histories):
            # float() turns 0-d tensors / numpy scalars into plain numbers.
            points = [(s, float(v[task_name])) for s, v in history]
            xs, ys = zip(*points)
            ax.plot(xs, ys, label=f'{exp_name}_{task_name}')

    # Decorate each panel once, after all experiments have been drawn.
    for ax, (title, log_scale) in zip(axes[iTask], panels):
        ax.set_title(title)
        if log_scale:
            ax.set_yscale('log')
        ax.legend()
fig.tight_layout()

In [19]:
# Validation-metric history of the surgical experiment on QNLI as
# (step, value) pairs. The history is taken explicitly from `surgical_exp`
# (its last element) instead of relying on the `eval_metrics` variable left
# over from the final iteration of the plotting loop.
*_other_histories, surgical_eval_metrics = surgical_exp
em = [(s, float(m['qnli'])) for s, m in surgical_eval_metrics]

In [22]:
# Display the `training` attribute of the surgical model (presumably the
# torch.nn.Module train/eval flag — TODO confirm against TaskAwareBert).
# The recorded output below is True.
modelSurgical.training

True