In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before
# each cell runs, so edits to local .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from functools import partial

import torch.nn.functional as F
from torch import nn

import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from parse_RTE import get_RTE_dataloaders
from task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
# Logger for this notebook, registered under the current module's name.
logger = logging.getLogger(name=__name__)

# Initialize Emmental

In [4]:
# Initialize Emmental with "logs" as the logging root; the captured output
# shows a timestamped run directory created under it and the default config
# being loaded.
emmental.init("logs")

[2019-05-13 01:36:23,777][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_05_13/01_36_23
[2019-05-13 01:36:23,792][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch1/senwu/mmtl/emmental/src/emmental/emmental-default-config.yaml.


In [5]:
# Notebook-wide settings; edit these to change the experiment.
# Task identifier: names the Emmental task and indexes the SuperGLUE
# label/metric mappings.
TASK_NAME = "RTE"
# Root data directory handed to the dataloader builder (files are read from
# data/RTE/*.jsonl according to the captured output).
DATA_DIR = "data"
# Pretrained BERT checkpoint, used for both the tokenizer and the encoder.
BERT_MODEL_NAME = "bert-base-uncased"
# Batch size passed to the dataloader builder.
BATCH_SIZE = 16

# Extract train/dev dataset from file

In [6]:
# Build one dataloader per requested split. The result is used as a mapping
# keyed by split name further down (`dataloaders["val"]`).
dataloaders = get_RTE_dataloaders(
    # Where the data lives and which splits to read.
    task_name=TASK_NAME,
    data_dir=DATA_DIR,
    splits=["train", "val"],
    # Tokenization settings.
    tokenizer_name=BERT_MODEL_NAME,
    max_sequence_length=256,
    # Batching; None is passed for max_data_samples (presumably "no cap" --
    # confirm in parse_RTE).
    batch_size=BATCH_SIZE,
    max_data_samples=None,
)

[2019-05-13 01:36:23,887][INFO] tokenizer:8 - Loading Tokenizer bert-base-uncased
[2019-05-13 01:36:24,170][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


data/RTE/train.jsonl
{'premise': 'No Weapons of Mass Destruction Found in Iraq Yet.', 'hypothesis': 'Weapons of Mass Destruction Found in Iraq.', 'label': 'not_entailment', 'idx': 0}


[2019-05-13 01:36:29,639][INFO] parse_RTE:123 - Loaded train for RTE.


data/RTE/val.jsonl
{'premise': 'Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation.', 'hypothesis': 'Christopher Reeve had an accident.', 'label': 'not_entailment', 'idx': 0}


[2019-05-13 01:36:30,222][INFO] parse_RTE:123 - Loaded val for RTE.


# Build Emmental task

In [7]:
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss over the active examples of one task.

    Args:
        task_name: Task name; the prediction head's output is looked up under
            the key ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's
            outputs; element 0 of the head's entry is used as the logits.
        Y: Label tensor. It is flattened and shifted down by one before use
            (labels are presumably 1-indexed -- confirm against the dataset).
        active: Index/mask applied to both logits and labels to select the
            examples that contribute to the loss.

    Returns:
        The value of ``F.cross_entropy`` on the selected logits and labels.
    """
    head_outputs = immediate_ouput_dict[f"{task_name}_pred_head"]
    logits = head_outputs[0]
    # Shift labels so the smallest class id becomes 0, as cross_entropy expects.
    targets = Y.view(-1) - 1
    return F.cross_entropy(logits[active], targets[active])

In [8]:
def output(task_name, immediate_ouput_dict):
    """Turn the task head's logits into class probabilities.

    Args:
        task_name: Task name; the prediction head's output is looked up under
            the key ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's
            outputs; element 0 of the head's entry is used as the logits.

    Returns:
        Softmax of the logits taken along dimension 1.
    """
    logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
    return F.softmax(logits, dim=1)

In [9]:
# Width of the encoder output fed to the prediction head. BERT "base"
# checkpoints have hidden size 768 and "large" checkpoints 1024; cased vs.
# uncased only changes the vocabulary. Keying on "large" therefore sizes
# bert-large-uncased and bert-base-cased correctly, while the value for
# bert-base-uncased stays 768.
BERT_OUTPUT_DIM = 1024 if "large" in BERT_MODEL_NAME else 768

# Number of output classes for the task: the size of its label mapping, or a
# single output when the task has no label mapping (mapping is None).
TASK_CARDINALITY = (
    len(SuperGLUE_LABEL_MAPPING[TASK_NAME])
    if SuperGLUE_LABEL_MAPPING[TASK_NAME] is not None
    else 1
)

# Single-task definition: a BERT encoder followed by a linear prediction head.
emmental_task = EmmentalTask(
    name=TASK_NAME,
    # Modules available to the task flow, addressed by name.
    module_pool=nn.ModuleDict(
        {
            "bert_module": BertModule(BERT_MODEL_NAME),
            f"{TASK_NAME}_pred_head": nn.Linear(BERT_OUTPUT_DIM, TASK_CARDINALITY),
        }
    ),
    # Ordered steps; each step names the module to run and where its inputs
    # come from as (source step, key-or-index) pairs.
    task_flow=[
        # Encoder step, fed the "token_ids" and "token_segments" entries of
        # "_input_" (presumably the raw batch -- confirm in Emmental docs).
        dict(
            name="input",
            module="bert_module",
            inputs=[("_input_", "token_ids"), ("_input_", "token_segments")],
        ),
        # Head step, fed element 1 of the encoder step's output (presumably
        # BERT's pooled representation -- confirm in BertModule).
        dict(
            name=f"{TASK_NAME}_pred_head",
            module=f"{TASK_NAME}_pred_head",
            inputs=[("input", 1)],
        ),
    ],
    # Bind the task name so both callables can locate the head's output.
    loss_func=partial(ce_loss, TASK_NAME),
    output_func=partial(output, TASK_NAME),
    scorer=Scorer(metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]),
)

[2019-05-13 01:36:30,629][INFO] pytorch_pretrained_bert.modeling:564 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
[2019-05-13 01:36:30,632][INFO] pytorch_pretrained_bert.modeling:572 - extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmp_tcjdzoe
[2019-05-13 01:36:36,418][INFO] pytorch_pretrained_bert.modeling:579 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

[2019-05-13 01:36:5

In [10]:
# Override Emmental's default configuration for this run.
Meta.update_config(
    config={
        # Device index; the model log below reports "cuda:1" for this value.
        # NOTE(review): hard-coded index -- adjust for the GPUs on the host.
        "meta_config": {"device": 1},
        "learner_config": {
            "n_epochs": 3,
            # Name of the split used for validation during training.
            "valid_split": "val",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "lr_scheduler_config": {
                # 156 matches the per-epoch progress-bar length in the
                # training output, so warmup presumably spans one epoch.
                "warmup_steps": 156,
                "warmup_unit": "batch",
                "lr_scheduler": "linear",
            },
        },
        "logging_config": {
            # Logged as "Evaluating every 20 batch." when training starts.
            "evaluation_freq": 20,
            # Logged as "No checkpointing." when training starts.
            "checkpointing": None,
        },
    }
)

[2019-05-13 01:36:51,673][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [11]:
# Wrap the single RTE task in an Emmental model; per the log output, the
# model is moved to the configured GPU during construction.
mtl_model = EmmentalModel(name="SuperGLUE_single_task", tasks=[emmental_task])

[2019-05-13 01:36:51,710][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-13 01:37:00,042][INFO] emmental.model:44 - Created emmental model SuperGLUE_single_task that contains task {'RTE'}.
[2019-05-13 01:37:00,045][INFO] emmental.model:58 - Moving model to GPU (cuda:1).


In [12]:
# Create the learner with no explicit arguments (presumably it reads its
# settings from the Meta config set earlier -- confirm in emmental.learner).
emmental_learner = EmmentalLearner()

In [13]:
# Train the model, handing the learner every dataloader (train and val).
emmental_learner.learn(mtl_model, dataloaders.values())

[2019-05-13 01:37:00,133][INFO] emmental.logging.logging_manager:33 - Evaluating every 20 batch.
[2019-05-13 01:37:00,133][INFO] emmental.logging.logging_manager:51 - No checkpointing.
[2019-05-13 01:37:00,174][INFO] emmental.learner:286 - Start learning...


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




In [14]:
# Score the trained model on the validation dataloader; the result shown
# below is a dict mapping a metric identifier to its value.
mtl_model.score(dataloaders["val"])

{'RTE/SuperGLUE/val/accuracy': 0.6353790613718412}