In [1]:
# Reload edited local modules (modules/, preprocessor, task_config) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from functools import partial

import torch.nn.functional as F
from torch import nn

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader, EmmentalDataset
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from modules.classification_module import ClassificationModule
from preprocessor import preprocessor
from task_config import LABEL_MAPPING

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
# Module-level logger used for progress messages in the cells below.
logger = logging.getLogger(name=__name__)

In [4]:
# Run configuration.
TASK_NAME: str = "RTE"  # GLUE task to fine-tune on
DATA_DIR: str = "data"  # passed to preprocessor() as data_dir
BERT_MODEL_NAME: str = "bert-base-uncased"  # pretrained BERT checkpoint
BATCH_SIZE: int = 16  # minibatch size for every dataloader

# Initialize Emmental

In [5]:
# Emmental creates a timestamped run directory underneath this root.
log_root = "logs"
emmental.init(log_root)

[2019-05-07 13:56:02,623][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_05_07/13_56_02
[2019-05-07 13:56:02,639][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch1/senwu/mmtl/emmental/src/emmental/emmental-default-config.yaml.


# Extract train/dev/test datasets from files

In [6]:
# Tokenize each split with the BERT tokenizer and wrap it as an EmmentalDataset.
max_seq_len = 200  # must stay within BERT's 512 position embeddings
datasets = {}

for split in ["train", "dev", "test"]:
    token_ids, token_segments, token_masks, gold_labels = preprocessor(
        data_dir=DATA_DIR,
        task_name=TASK_NAME,
        split=split,
        bert_model_name=BERT_MODEL_NAME,
        max_data_samples=None,
        max_sequence_length=max_seq_len,
    )
    inputs = dict(
        token_ids=token_ids,
        token_segments=token_segments,
        token_masks=token_masks,
    )
    targets = dict(labels=gold_labels)
    datasets[split] = EmmentalDataset(name="GLUE", X_dict=inputs, Y_dict=targets)

    logger.info(f"Loaded {split} for {TASK_NAME}.")

HBox(children=(IntProgress(value=0, max=2490), HTML(value='')))




[2019-05-07 13:56:03,111][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 13:56:08,532][INFO] __main__:21 - Loaded train for RTE.


HBox(children=(IntProgress(value=0, max=277), HTML(value='')))




[2019-05-07 13:56:08,966][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 13:56:09,607][INFO] __main__:21 - Loaded dev for RTE.


HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))




[2019-05-07 13:56:10,003][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 13:56:16,358][INFO] __main__:21 - Loaded test for RTE.


# Build Emmental dataloader

In [7]:
def build_dataloader(split, **loader_kwargs):
    """Wrap ``datasets[split]`` in an EmmentalDataLoader for TASK_NAME.

    Args:
        split: Key into ``datasets`` ("train", "dev" or "test").
        **loader_kwargs: Extra keyword arguments forwarded unchanged to
            EmmentalDataLoader (e.g. ``shuffle=True`` for training).

    Returns:
        The constructed EmmentalDataLoader. Its "labels" field is routed
        to TASK_NAME.
    """
    dataloader = EmmentalDataLoader(
        task_to_label_dict={TASK_NAME: "labels"},
        dataset=datasets[split],
        split=split,
        batch_size=BATCH_SIZE,
        **loader_kwargs,
    )
    logger.info(f"Built dataloader for {split} set.")
    return dataloader


# Only the training data is shuffled; dev/test keep the loader's default order.
train_dataloader = build_dataloader("train", shuffle=True)
dev_dataloader = build_dataloader("dev")
test_dataloader = build_dataloader("test")

[2019-05-07 13:56:16,423][INFO] __main__:9 - Built dataloader for train set.
[2019-05-07 13:56:16,425][INFO] __main__:18 - Built dataloader for dev set.
[2019-05-07 13:56:16,426][INFO] __main__:27 - Built dataloader for test set.


# Build Emmental task

In [8]:
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss of the task's prediction head over the active rows.

    Args:
        task_name: Task whose ``f"{task_name}_pred_head"`` output is scored.
        immediate_ouput_dict: Maps module names to their outputs; element 0
            of the prediction head's entry is used as the logits.
        Y: Gold labels, flattened before use. They are shifted down by one,
            so they are presumably 1-indexed (TODO confirm against Emmental's
            label convention).
        active: Row selector (e.g. a boolean mask) choosing which examples
            contribute to the loss.

    Returns:
        Scalar loss tensor from ``F.cross_entropy``.
    """
    logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
    zero_based_gold = Y.view(-1) - 1
    return F.cross_entropy(logits[active], zero_based_gold[active])

In [9]:
def output(task_name, immediate_ouput_dict):
    """Turn the task's prediction-head logits into class probabilities.

    Args:
        task_name: Task whose ``f"{task_name}_pred_head"`` output is used.
        immediate_ouput_dict: Maps module names to their outputs; element 0
            of the prediction head's entry holds the logits.

    Returns:
        Softmax of the logits taken along dimension 1.
    """
    logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
    return logits.softmax(dim=1)

In [10]:
# Hidden size of the BERT encoder depends on model size, not casing:
# "base" checkpoints are 768-d (see the logged config) and "large" ones are 1024-d.
BERT_OUTPUT_DIM = 1024 if "large" in BERT_MODEL_NAME else 768

# Number of classes comes from the task's label mapping; tasks without one
# (mapping is None) get a single output.
label_mapping = LABEL_MAPPING[TASK_NAME]
TASK_CARDINALITY = len(label_mapping) if label_mapping is not None else 1

# Task graph: BERT encodes the tokens, and the classification head reads
# element 1 of the BERT module's output (presumably the pooled [CLS]
# representation; TODO confirm in modules.bert_module).
# NOTE(review): "token_masks" is built into every dataset but is not passed to
# bert_module. Confirm BertModule derives the attention mask itself; otherwise
# padding tokens are attended to.
emmental_task = EmmentalTask(
    name=TASK_NAME,
    module_pool=nn.ModuleDict(
        {
            "bert_module": BertModule(BERT_MODEL_NAME),
            f"{TASK_NAME}_pred_head": ClassificationModule(
                BERT_OUTPUT_DIM, TASK_CARDINALITY
            ),
        }
    ),
    task_flow=[
        {
            "name": "input",
            "module": "bert_module",
            "inputs": [("_input_", "token_ids"), ("_input_", "token_segments")],
        },
        {
            "name": f"{TASK_NAME}_pred_head",
            "module": f"{TASK_NAME}_pred_head",
            "inputs": [("input", 1)],
        },
    ],
    loss_func=partial(ce_loss, TASK_NAME),
    output_func=partial(output, TASK_NAME),
    scorer=Scorer(metrics=["accuracy"]),
)

[2019-05-07 13:56:16,816][INFO] pytorch_pretrained_bert.modeling:564 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
[2019-05-07 13:56:16,819][INFO] pytorch_pretrained_bert.modeling:572 - extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpdxvz6r3j
[2019-05-07 13:56:22,664][INFO] pytorch_pretrained_bert.modeling:579 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

[2019-05-07 13:56:3

In [11]:
Meta.update_config(
    config={
        # CUDA device index for the model (the run logs "cuda:1").
        "meta_config": {"device": 1},
        "learner_config": {
            "n_epochs": 3,
            "valid_split": "dev",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "lr_scheduler_config": {
                # Warm up over one epoch of training batches. For RTE train
                # (2490 examples, batch size 16) this is 156, the value that was
                # previously hard-coded. Deriving it keeps the warmup in sync
                # when BATCH_SIZE or the data changes.
                "warmup_steps": len(train_dataloader),
                "warmup_unit": "batch",
                "lr_scheduler": "linear",
            },
        },
        "logging_config": {
            "evaluation_freq": 20,
            # Checkpointing is disabled. To enable it, set "checkpointing" and
            # supply a "checkpointer_config" (checkpoint_metric, checkpoint_freq).
            "checkpointing": None,
        },
    }
)

[2019-05-07 13:56:35,726][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [12]:
# Single-task model: the RTE task is the only member of the task list.
model_tasks = [emmental_task]
mtl_model = EmmentalModel(name="GLUE_single_task", tasks=model_tasks)

[2019-05-07 13:56:35,766][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 13:56:40,992][INFO] emmental.model:44 - Created emmental model GLUE_single_task that contains task {'RTE'}.
[2019-05-07 13:56:40,995][INFO] emmental.model:58 - Moving model to GPU (cuda:1).


In [13]:
# Training driver; its learn() call below logs the evaluation/checkpointing
# settings that were set via Meta.update_config.
emmental_learner = EmmentalLearner()

In [14]:
# Fine-tune on train, with periodic evaluation on the "dev" loader
# (valid_split="dev" in the learner config).
loaders_for_learning = [train_dataloader, dev_dataloader]
emmental_learner.learn(mtl_model, loaders_for_learning)

[2019-05-07 13:56:41,086][INFO] emmental.logging.logging_manager:33 - Evaluating every 20 batch.
[2019-05-07 13:56:41,087][INFO] emmental.logging.logging_manager:51 - No checkpointing.
[2019-05-07 13:56:41,133][INFO] emmental.learner:286 - Start learning...


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




In [15]:
# Final dev-set metrics, returned as a dict keyed "task/dataset/split/metric".
mtl_model.score(dev_dataloader)

{'RTE/GLUE/dev/accuracy': 0.6462093862815884}