In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import logging

import torch.nn.functional as F
from torch import nn
from torch.nn import MSELoss

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader, EmmentalDataset
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from modules.classification_module import ClassificationModule
from modules.regression_module import RegressionModule
from preprocessor import preprocessor
from task_config import LABEL_MAPPING

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
logger = logging.getLogger(__name__)

In [4]:
TASK_NAMES = ["RTE", "STS-B"]
DATA_DIR = "data"
BERT_MODEL_NAME = "bert-base-uncased"
BATCH_SIZE = 16

# Initalize Emmental

In [5]:
emmental.init("logs")

[2019-04-23 20:51:11,024][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_04_23/20_51_10
[2019-04-23 20:51:11,040][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch1/senwu/mmtl/emmental/src/emmental/emmental-default-config.yaml.


# Extract train/dev/test dataset from file

In [6]:
datasets = {}

for task_name in TASK_NAMES:
    for split in ["train", "dev", "test"]:
        bert_token_ids, bert_token_segments, bert_token_masks, labels = preprocessor(
            data_dir=DATA_DIR,
            task_name=task_name,
            split=split,
            bert_model_name=BERT_MODEL_NAME,
            max_data_samples=None,
            max_sequence_length=200,
        )
        X_dict = {
            "token_ids": bert_token_ids,
            "token_segments": bert_token_segments,
            "token_masks": bert_token_masks,
        }
        Y_dict = {"labels": labels}

        if task_name not in datasets: datasets[task_name] = {}
        
        datasets[task_name][split] = EmmentalDataset(name="GLUE", X_dict=X_dict, Y_dict=Y_dict)

        logger.info(f"Loaded {split} for {task_name}.")

HBox(children=(IntProgress(value=0, max=2490), HTML(value='')))




[2019-04-23 20:51:13,184][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-23 20:51:18,472][INFO] __main__:24 - Loaded train for RTE.


HBox(children=(IntProgress(value=0, max=277), HTML(value='')))




[2019-04-23 20:51:18,789][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-23 20:51:19,399][INFO] __main__:24 - Loaded dev for RTE.


HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))




[2019-04-23 20:51:19,946][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-23 20:51:25,986][INFO] __main__:24 - Loaded test for RTE.


HBox(children=(IntProgress(value=0, max=5749), HTML(value='')))




[2019-04-23 20:51:26,443][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-23 20:51:31,139][INFO] __main__:24 - Loaded train for STS-B.


HBox(children=(IntProgress(value=0, max=1500), HTML(value='')))




[2019-04-23 20:51:31,506][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-23 20:51:32,881][INFO] __main__:24 - Loaded dev for STS-B.


HBox(children=(IntProgress(value=0, max=1379), HTML(value='')))




[2019-04-23 20:51:33,212][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-23 20:51:34,532][INFO] __main__:24 - Loaded test for STS-B.


# Build Emmental dataloader

In [7]:
dataloaders = []

for task_name in TASK_NAMES:
    for split in ["train", "dev", "test"]:
        dataloaders.append(
            EmmentalDataLoader(
                task_name=task_name,
                dataset=datasets[task_name][split],
                label_name="labels",
                split=split,
                batch_size=BATCH_SIZE,
            )
        )
        logger.info(f"Built dataloader for {task_name} {split} set.")

[2019-04-23 20:51:34,575][INFO] __main__:14 - Built dataloader for RTE train set.
[2019-04-23 20:51:34,576][INFO] __main__:14 - Built dataloader for RTE dev set.
[2019-04-23 20:51:34,577][INFO] __main__:14 - Built dataloader for RTE test set.
[2019-04-23 20:51:34,578][INFO] __main__:14 - Built dataloader for STS-B train set.
[2019-04-23 20:51:34,580][INFO] __main__:14 - Built dataloader for STS-B dev set.
[2019-04-23 20:51:34,581][INFO] __main__:14 - Built dataloader for STS-B test set.


# Build Emmental task

In [8]:
def mse_loss(immediate_ouput, Y):
    mse = MSELoss()
    return mse(immediate_ouput[-1][0].view(-1), Y.view(-1))

In [9]:
def ce_loss(immediate_ouput, Y):
    return F.cross_entropy(immediate_ouput[-1][0], Y.view(-1) - 1)

In [10]:
def output(immediate_ouput):
    return immediate_ouput[-1][0]

In [11]:
BERT_OUTPUT_DIM = 768 if "uncased" in BERT_MODEL_NAME else 1024

TASK_CARDINALITY = len(LABEL_MAPPING["RTE"].keys()) if LABEL_MAPPING["RTE"] is not None else 1
RTE_task = EmmentalTask(
    name="RTE",
    module_pool=nn.ModuleDict(
        {
            "bert_module": BertModule(BERT_MODEL_NAME),
            "classification_module": ClassificationModule(BERT_OUTPUT_DIM, TASK_CARDINALITY),
        }
    ),
    task_flow=[
        {"module": "bert_module", "inputs": [(0, 'token_ids'), (0, 'token_segments')]},
        {"module": "classification_module", "inputs": [(1, 1)]},
    ],
    loss_func=ce_loss,
    output_func=output,
    scorer=Scorer(metrics=['accuracy']),
)

STSB_task = EmmentalTask(
    name="STS-B",
    module_pool=nn.ModuleDict(
        {
            "bert_module": BertModule(BERT_MODEL_NAME),
            "regression_module": RegressionModule(BERT_OUTPUT_DIM),
        }
    ),
    task_flow=[
        {"module": "bert_module", "inputs": [(0, 'token_ids'), (0, 'token_segments')]},
        {"module": "regression_module", "inputs": [(1, 1)]},
    ],
    loss_func=mse_loss,
    output_func=output,
    scorer=Scorer(metrics=['pearson_spearman']),
)

[2019-04-23 20:51:34,996][INFO] pytorch_pretrained_bert.modeling:564 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
[2019-04-23 20:51:34,998][INFO] pytorch_pretrained_bert.modeling:572 - extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpvn8ltbn8
[2019-04-23 20:51:40,849][INFO] pytorch_pretrained_bert.modeling:579 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

[2019-04-23 20:51:5

In [12]:
Meta.update_config(
    config={
        "meta_config": {"device": 1},
        "learner_config": {
            "n_epochs": 3,
            "valid_split": "dev",
            "optimizer_config": {"optimizer": "adam", "lr": 5e-5},
            "lr_scheduler_config": {"warmup_steps": 100, "warmup_unit": "batch", "lr_scheduler":"linear"},
        },
        "logging_config": {
            "evaluation_freq": 20,
            "checkpointer_config": {
                "checkpoint_metric": f"RTE/GLUE/train/accuracy",
                "checkpoint_freq": 10,
            },
        },
    }
)

[2019-04-23 20:52:12,835][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [13]:
mtl_model = EmmentalModel(name = 'GLUE_multi_task', tasks=[RTE_task, STSB_task])

[2019-04-23 20:52:12,874][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-23 20:52:17,756][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-23 20:52:17,762][INFO] emmental.model:44 - Created emmental model GLUE_multi_task that contains task {'RTE', 'STS-B'}.
[2019-04-23 20:52:17,763][INFO] emmental.model:57 - Moving model to GPU.


In [14]:
emmental_learner = EmmentalLearner()

In [15]:
emmental_learner.learn(mtl_model, dataloaders)

[2019-04-23 20:52:17,847][INFO] emmental.logging.logging_manager:33 - Evaluating every 20 batch.
[2019-04-23 20:52:17,848][INFO] emmental.logging.logging_manager:40 - Checkpointing every 200 batch.
[2019-04-23 20:52:17,884][INFO] emmental.logging.checkpointer:41 - Save checkpoints at logs/2019_04_23/20_51_10 every 200 batch
[2019-04-23 20:52:17,885][INFO] emmental.logging.checkpointer:65 - No checkpoints saved before 0 batch.
[2019-04-23 20:52:17,890][INFO] emmental.learner:249 - Start learning...


HBox(children=(IntProgress(value=0, max=516), HTML(value='')))

[2019-04-23 20:57:55,092][INFO] emmental.logging.checkpointer:87 - Save checkpoint of 200 batch at logs/2019_04_23/20_51_10/checkpoint_200.pth.
[2019-04-23 21:03:32,947][INFO] emmental.logging.checkpointer:87 - Save checkpoint of 400 batch at logs/2019_04_23/20_51_10/checkpoint_400.pth.





HBox(children=(IntProgress(value=0, max=516), HTML(value='')))

[2019-04-23 21:09:11,176][INFO] emmental.logging.checkpointer:87 - Save checkpoint of 600 batch at logs/2019_04_23/20_51_10/checkpoint_600.pth.
[2019-04-23 21:14:49,036][INFO] emmental.logging.checkpointer:87 - Save checkpoint of 800 batch at logs/2019_04_23/20_51_10/checkpoint_800.pth.
[2019-04-23 21:20:26,628][INFO] emmental.logging.checkpointer:87 - Save checkpoint of 1000 batch at logs/2019_04_23/20_51_10/checkpoint_1000.pth.





HBox(children=(IntProgress(value=0, max=516), HTML(value='')))

[2019-04-23 21:26:03,907][INFO] emmental.logging.checkpointer:87 - Save checkpoint of 1200 batch at logs/2019_04_23/20_51_10/checkpoint_1200.pth.
[2019-04-23 21:31:40,926][INFO] emmental.logging.checkpointer:87 - Save checkpoint of 1400 batch at logs/2019_04_23/20_51_10/checkpoint_1400.pth.





In [17]:
mtl_model.score(dataloaders)

  r = r_num / r_den
  c /= stddev[:, None]
  c /= stddev[None, :]
  return (self.a < x) & (x < self.b)
  return (self.a < x) & (x < self.b)
  cond2 = cond0 & (x <= self.a)


{'RTE/GLUE/train/accuracy': 0.8618473895582329,
 'RTE/GLUE/dev/accuracy': 0.6173285198555957,
 'RTE/GLUE/test/accuracy': 0.0,
 'STS-B/GLUE/train/pearson_correlation': 0.9574412,
 'STS-B/GLUE/train/pearson_pvalue': 0.0,
 'STS-B/GLUE/train/spearman_correlation': 0.9526191808218984,
 'STS-B/GLUE/train/spearman_pvalue': 0.0,
 'STS-B/GLUE/train/pearson_spearman': 0.9550301957843318,
 'STS-B/GLUE/dev/pearson_correlation': 0.87345105,
 'STS-B/GLUE/dev/pearson_pvalue': 0.0,
 'STS-B/GLUE/dev/spearman_correlation': 0.8707788884160069,
 'STS-B/GLUE/dev/spearman_pvalue': 0.0,
 'STS-B/GLUE/dev/pearson_spearman': 0.8721149712561145,
 'STS-B/GLUE/test/pearson_correlation': 0.0,
 'STS-B/GLUE/test/pearson_pvalue': 1.0,
 'STS-B/GLUE/test/spearman_correlation': 0.0,
 'STS-B/GLUE/test/spearman_pvalue': nan,
 'STS-B/GLUE/test/pearson_spearman': 0.0}