In [1]:
# Re-import edited local modules (modules/, superglue/, preprocessor, task_config)
# automatically before each cell runs, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from functools import partial

import os
import torch.nn.functional as F
from torch import nn

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader, EmmentalDataset
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from modules.classification_module import ClassificationModule
from preprocessor import preprocessor
from task_config import LABEL_MAPPING

In [3]:
# Notebook-level logger (not referenced by the cells below; library code logs through its own loggers).
logger = logging.getLogger(name=__name__)

In [4]:
# Source task whose fine-tuned checkpoint is loaded further down; TASK_NAME is
# later rebound to the target task ("CB").
TASK_NAME = "MNLI"
# GLUE data root. Nothing reads it before DATA_DIR is reassigned to the
# SuperGLUE root, so use .get() rather than [] to avoid a KeyError on machines
# where GLUEDATA is not set.
DATA_DIR = os.environ.get("GLUEDATA")
BERT_MODEL_NAME = "bert-large-uncased"
BATCH_SIZE = 32

# Initialize Emmental

In [5]:
# Fine-tuning configuration for CB.
# - evaluation runs every 0.5 epoch (counter_unit="epoch", evaluation_freq=0.5);
# - checkpoint_freq is counted in evaluation intervals, so 1 means a checkpoint
#   every 0.5 epoch (the learner log reports "Checkpointing every 0.5 epoch");
# - the best model is chosen by maximizing CB val accuracy.
cb_finetune_config = {
    "meta_config": {"seed": 1},
    "model_config": {"device": 0, "dataparallel": True},
    "learner_config": {
        "n_epochs": 10,
        "valid_split": "val",
        "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
        "lr_scheduler_config": {"lr_scheduler": "linear", "min_lr": 1e-7},
    },
    "logging_config": {
        "counter_unit": "epoch",
        "evaluation_freq": 0.5,
        "checkpointing": True,
        "checkpointer_config": {
            "checkpoint_metric": {"CB/SuperGLUE/val/accuracy": "max"},
            "checkpoint_freq": 1,
        },
    },
}
emmental.init("logs/CB_finetune", config=cb_finetune_config)

[2019-05-31 07:44:08,086][INFO] emmental.meta:95 - Setting logging directory to: logs/CB_finetune/2019_05_31/07_44_08
[2019-05-31 07:44:08,095][INFO] emmental.meta:56 - Loading Emmental default config from /home/hazymturk/vincent/emmental/src/emmental/emmental-default-config.yaml.
[2019-05-31 07:44:08,097][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


# Build Emmental task

In [8]:
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss of the task's prediction head over the active examples.

    Args:
        task_name: task whose head output is read from ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: per-module outputs; the head's logits are element 0
            of its entry. (Parameter spelling kept for interface compatibility.)
        Y: gold labels, flattened with ``view(-1)``. They are shifted down by one,
            so they are presumably 1-based label ids — confirm against the label mapping.
        active: mask/index over the batch selecting examples that count toward the loss.

    Returns:
        Scalar cross-entropy tensor.
    """
    head_key = f"{task_name}_pred_head"
    logits = immediate_ouput_dict[head_key][0]
    zero_based_gold = Y.view(-1) - 1
    return F.cross_entropy(logits[active], zero_based_gold[active])

In [9]:
def output(task_name, immediate_ouput_dict):
    """Class probabilities from the task's prediction head.

    Reads element 0 of ``immediate_ouput_dict[f"{task_name}_pred_head"]`` and
    applies a softmax over dimension 1 (the class dimension of the head logits).
    """
    logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
    return F.softmax(logits, dim=1)

In [10]:
from sklearn.metrics import f1_score
def macro_f1(golds, probs, preds):
    """Custom Scorer metric: macro-averaged F1 of the hard predictions.

    ``probs`` is accepted but unused; presumably it exists to match the
    (golds, probs, preds) signature Scorer uses for custom metric functions.

    Returns:
        ``{"macro_f1": score}``.
    """
    score = f1_score(golds, preds, average="macro")
    return dict(macro_f1=score)

In [11]:
# Start from an empty model (its task set is empty per the log); the MNLI checkpoint
# loaded next supplies its tasks and weights. Note the loaded model reports itself as
# "GLUE_multi_task" in later logs, not the name given here.
mtl_model = EmmentalModel(name="GLUE_single_task")

[2019-05-31 07:44:21,502][INFO] emmental.model:44 - Created emmental model GLUE_single_task that contains task set().
[2019-05-31 07:44:21,503][INFO] emmental.model:58 - Moving model to GPU (cuda:0).


## Load Pretrained MNLI Model

In [12]:
# MNLI-fine-tuned checkpoint to start CB fine-tuning from. The default is the
# original author's absolute path; set the MNLI_CHECKPOINT environment variable
# to point at a local copy instead of editing this cell.
MNLI_CHECKPOINT = os.environ.get(
    "MNLI_CHECKPOINT",
    "/home/hazymturk/vincent/emmental-tutorials/glue/logs/2019_05_30/07_16_23/best_model_MNLI_GLUE_dev_accuracy.pth",
)
mtl_model.load(MNLI_CHECKPOINT)

[2019-05-31 07:44:29,280][INFO] emmental.model:412 - [GLUE_multi_task] Model loaded from /home/hazymturk/vincent/emmental-tutorials/glue/logs/2019_05_30/07_16_23/best_model_MNLI_GLUE_dev_accuracy.pth
[2019-05-31 07:44:29,282][INFO] emmental.model:58 - Moving model to GPU (cuda:0).


In [14]:
# Rebind the working task from the source task (MNLI) to the target task; the cells
# below build, train, and score the CB task under this name.
TASK_NAME = "CB"

In [15]:
from superglue.parse_CB import get_CB_dataloaders

In [16]:
from superglue.task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING

In [17]:
# Encoder feature size: 768 for "base" checkpoints, otherwise 1024 (the logged
# bert-large config reports hidden_size 1024).
BERT_OUTPUT_DIM = 1024
if "base" in BERT_MODEL_NAME:
    BERT_OUTPUT_DIM = 768

# Number of output classes; a task without a label mapping gets a single output.
task_label_mapping = SuperGLUE_LABEL_MAPPING[TASK_NAME]
if task_label_mapping is None:
    TASK_CARDINALITY = 1
else:
    TASK_CARDINALITY = len(task_label_mapping.keys())

pred_head_name = f"{TASK_NAME}_pred_head"

# Module construction order (encoder first, then the linear head) is kept so the
# seeded RNG is consumed exactly as before.
cb_module_pool = nn.ModuleDict(
    {
        "bert_module": BertModule(BERT_MODEL_NAME),
        pred_head_name: nn.Linear(BERT_OUTPUT_DIM, TASK_CARDINALITY),
    }
)

# input -> BERT; the head consumes element 1 of the BERT module's output
# (presumably the pooled sequence representation — confirm in modules.bert_module).
cb_task_flow = [
    {
        "name": "input",
        "module": "bert_module",
        "inputs": [("_input_", "token_ids"), ("_input_", "token_segments")],
    },
    {
        "name": pred_head_name,
        "module": pred_head_name,
        "inputs": [("input", 1)],
    },
]

emmental_task = EmmentalTask(
    name=TASK_NAME,
    module_pool=cb_module_pool,
    task_flow=cb_task_flow,
    loss_func=partial(ce_loss, TASK_NAME),
    output_func=partial(output, TASK_NAME),
    scorer=Scorer(
        metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME],
        customize_metric_funcs={"macro_f1": macro_f1},
    ),
)

[2019-05-31 07:44:30,220][INFO] pytorch_pretrained_bert.modeling:583 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz from cache at ./cache/214d4777e8e3eb234563136cd3a49f6bc34131de836848454373fa43f10adc5e.abfbb80ee795a608acbf35c7bf2d2d58574df3887cdd94b355fc67e03fddba05
[2019-05-31 07:44:30,223][INFO] pytorch_pretrained_bert.modeling:591 - extracting archive file ./cache/214d4777e8e3eb234563136cd3a49f6bc34131de836848454373fa43f10adc5e.abfbb80ee795a608acbf35c7bf2d2d58574df3887cdd94b355fc67e03fddba05 to temp dir /tmp/tmpzew5ubka
[2019-05-31 07:44:41,233][INFO] pytorch_pretrained_bert.modeling:601 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "type_vocab_size": 2,
  "vocab_size":

### Replace the MNLI Task with the CB Task

In [18]:
# Drop the MNLI task that came with the loaded checkpoint so that, once CB is added,
# only the CB task is trained and scored.
mtl_model.remove_task("MNLI")

[2019-05-31 07:44:46,415][INFO] emmental.model:141 - Removing Task MNLI.


In [19]:
# Register the CB task. Its module_pool also declares a "bert_module", and the model
# already holds one from the MNLI checkpoint. NOTE(review): confirm in
# EmmentalModel.add_task which instance is kept, i.e. that CB fine-tunes the
# MNLI-trained encoder rather than the freshly loaded pretrained BERT.
mtl_model.add_task(emmental_task)

[2019-05-31 07:44:46,447][INFO] emmental.model:58 - Moving model to GPU (cuda:0).


In [20]:
# CB ships with SuperGLUE, so DATA_DIR is repointed at the SuperGLUE data root.
DATA_DIR = os.environ["SUPERGLUEDATA"]

# Per-example token budget for the premise/hypothesis pair (presumably the
# truncation/padding length — confirm in superglue.parse_CB).
CB_MAX_SEQUENCE_LENGTH = 200
CB_SPLITS = ["train", "val"]

# Returns loaders keyed by split name ("train", "val"); max_data_samples=None
# presumably means no subsampling.
dataloaders = get_CB_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=CB_SPLITS,
    max_sequence_length=CB_MAX_SEQUENCE_LENGTH,
    max_data_samples=None,
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
)

[2019-05-31 07:44:46,472][INFO] superglue.tokenizer:8 - Loading Tokenizer bert-large-uncased
[2019-05-31 07:44:47,281][INFO] pytorch_pretrained_bert.tokenization:190 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt from cache at /home/hazymturk/.cache/torch/pytorch_pretrained_bert/9b3c03a36e83b13d5ba95ac965c9f9074a99e14340c523ab405703179e79fc46.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


/home/hazymturk/vincent/emmental-tutorials/superglue/data/CB/train.jsonl
{'premise': 'It was a complex language. Not written down but handed down. One might say it was peeled down.', 'hypothesis': 'the language was peeled down', 'label': 'entailment', 'idx': 0}


[2019-05-31 07:44:47,576][INFO] superglue.parse_CB:123 - Loaded train for CB.
[2019-05-31 07:44:47,641][INFO] superglue.parse_CB:123 - Loaded val for CB.


/home/hazymturk/vincent/emmental-tutorials/superglue/data/CB/val.jsonl
{'premise': "Valence the void-brain, Valence the virtuous valet. Why couldn't the figger choose his own portion of titanic anatomy to shaft? Did he think he was helping?", 'hypothesis': 'Valence was helping', 'label': 'contradiction', 'idx': 0}


In [21]:
# Optional baseline: uncomment to score the model on CB val before fine-tuning
# (the CB prediction head is still untrained at this point).
# mtl_model.score(dataloaders["val"])

In [22]:
# Fine-tune on CB. Both the train and val loaders are passed; with valid_split="val"
# in the init config, val drives evaluation and best-checkpoint selection
# (CB/SuperGLUE/val/accuracy). The best checkpoint is reloaded at the end (see log).
emmental_learner = EmmentalLearner()
emmental_learner.learn(mtl_model, dataloaders.values())

[2019-05-31 07:44:47,688][INFO] emmental.logging.logging_manager:33 - Evaluating every 0.5 epoch.
[2019-05-31 07:44:47,689][INFO] emmental.logging.logging_manager:43 - Checkpointing every 0.5 epoch.
[2019-05-31 07:44:47,690][INFO] emmental.logging.checkpointer:42 - Save checkpoints at logs/CB_finetune/2019_05_31/07_44_08 every 0.5 epoch
[2019-05-31 07:44:47,692][INFO] emmental.logging.checkpointer:73 - No checkpoints saved before 0 epoch.
[2019-05-31 07:44:47,698][INFO] emmental.learner:303 - Start learning...


HBox(children=(IntProgress(value=0, description='Epoch 0:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:45:08,157][INFO] emmental.logging.checkpointer:93 - checkpoint_runway condition has been met. Start checkpoining.
[2019-05-31 07:45:12,220][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 0.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_0.5.pth.
[2019-05-31 07:45:15,911][INFO] emmental.logging.checkpointer:118 - Save best model of metric CB/SuperGLUE/val/accuracy at logs/CB_finetune/2019_05_31/07_44_08/best_model_CB_SuperGLUE_val_accuracy.pth
[2019-05-31 07:45:24,328][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 1.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_1.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 1:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:45:34,972][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 1.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_1.5.pth.
[2019-05-31 07:45:43,861][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 2.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_2.0.pth.
[2019-05-31 07:46:47,158][INFO] emmental.logging.checkpointer:118 - Save best model of metric CB/SuperGLUE/val/accuracy at logs/CB_finetune/2019_05_31/07_44_08/best_model_CB_SuperGLUE_val_accuracy.pth





HBox(children=(IntProgress(value=0, description='Epoch 2:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:47:00,796][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 2.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_2.5.pth.
[2019-05-31 07:47:11,702][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 3.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_3.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 3:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:47:28,055][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 3.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_3.5.pth.
[2019-05-31 07:47:47,510][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 4.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_4.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 4:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:48:24,266][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 4.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_4.5.pth.
[2019-05-31 07:48:56,850][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 5.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_5.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 5:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:49:29,741][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 5.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_5.5.pth.
[2019-05-31 07:49:56,488][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 6.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_6.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 6:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:50:28,713][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 6.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_6.5.pth.
[2019-05-31 07:51:01,563][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 7.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_7.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 7:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:51:32,757][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 7.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_7.5.pth.
[2019-05-31 07:52:06,781][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_8.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 8:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:52:38,442][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_8.5.pth.
[2019-05-31 07:53:14,954][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 9.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_9.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 9:', max=8, style=ProgressStyle(description_width='init…

[2019-05-31 07:53:44,361][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 9.5 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_9.5.pth.
[2019-05-31 07:54:13,745][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 10.0 epoch at logs/CB_finetune/2019_05_31/07_44_08/checkpoint_10.0.pth.
[2019-05-31 07:54:13,752][INFO] emmental.logging.checkpointer:148 - Clear all immediate checkpoints.





[2019-05-31 07:54:52,420][INFO] emmental.logging.checkpointer:188 - Loading the best model from logs/CB_finetune/2019_05_31/07_44_08/best_model_CB_SuperGLUE_val_accuracy.pth.
[2019-05-31 07:54:54,841][INFO] emmental.model:58 - Moving model to GPU (cuda:0).


In [23]:
# Score the reloaded best checkpoint on CB val. Note: val also selected that
# checkpoint, so these numbers are optimistic rather than a held-out test score.
mtl_model.score(dataloaders["val"])

{'CB/SuperGLUE/val/accuracy': 0.9107142857142857,
 'CB/SuperGLUE/val/macro_f1': 0.861923583662714}

In [24]:
# Output location for the fine-tuned CB model. The "0530" prefix is a hand-typed
# date tag, not derived from the run's log directory.
CB_FINETUNED_PATH = "./0530_CB_finetuned"
mtl_model.save(CB_FINETUNED_PATH)

[2019-05-31 07:55:05,846][INFO] emmental.model:386 - [GLUE_multi_task] Model saved in ./0530_CB_finetuned
