In [1]:
%load_ext autoreload
# Re-import edited project modules (preprocessor, task_config, glue_tasks, modules.*)
# before each cell executes, so library edits are picked up without restarting the kernel.
%autoreload 2

In [2]:
import logging

import torch.nn.functional as F
from torch import nn
from torch.nn import MSELoss

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader, EmmentalDataset
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from modules.classification_module import ClassificationModule
from modules.regression_module import RegressionModule
from preprocessor import preprocessor
from task_config import LABEL_MAPPING, GLUE_TASK_NAMES
from glue_tasks import get_gule_task

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
# Notebook-level logger; the "Loaded ..." / "Built dataloader ..." progress messages go through it.
logger = logging.getLogger(name=__name__)

In [4]:
# Run-wide settings read by the data-loading, dataloader and task-building cells below.
DATA_DIR: str = "data"  # passed to `preprocessor` as data_dir
BERT_MODEL_NAME: str = "bert-base-uncased"  # passed to `preprocessor` and `get_gule_task`
BATCH_SIZE: int = 16  # batch size for every EmmentalDataLoader

# Initialize Emmental

In [5]:
# Initialize Emmental with "logs" as the log root. Per the output below, it logs to a
# timestamped subdirectory (logs/<date>/<time>) and loads Emmental's default config.
emmental.init("logs")

[2019-04-25 17:01:28,372][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_04_25/17_01_28
[2019-04-25 17:01:28,388][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch1/senwu/mmtl/emmental/src/emmental/emmental-default-config.yaml.


# Extract train/dev/test datasets from files

In [6]:
# Tokenize every split of every GLUE task and wrap it as an EmmentalDataset,
# producing datasets[task_name][split].
#
# MAX_DATA_SAMPLES caps each split for a quick debugging run (the progress bars
# below stop at 1000). NOTE(review): presumably passing None loads the full split;
# confirm that `preprocessor` accepts None before relying on it.
MAX_DATA_SAMPLES = 1000
MAX_SEQUENCE_LENGTH = 100

datasets = {}

for task_name in GLUE_TASK_NAMES:
    # One inner dict per task, created on first use.
    task_datasets = datasets.setdefault(task_name, {})
    for split in ["train", "dev", "test"]:
        bert_token_ids, bert_token_segments, bert_token_masks, labels = preprocessor(
            data_dir=DATA_DIR,
            task_name=task_name,
            split=split,
            bert_model_name=BERT_MODEL_NAME,
            max_data_samples=MAX_DATA_SAMPLES,
            max_sequence_length=MAX_SEQUENCE_LENGTH,
        )
        X_dict = {
            "token_ids": bert_token_ids,
            "token_segments": bert_token_segments,
            "token_masks": bert_token_masks,
        }
        # "labels" must match the label_name given to EmmentalDataLoader later.
        Y_dict = {"labels": labels}

        # The dataset name "GLUE" shows up in every score key (e.g. "RTE/GLUE/dev/accuracy").
        task_datasets[split] = EmmentalDataset(name="GLUE", X_dict=X_dict, Y_dict=Y_dict)

        logger.info(f"Loaded {split} for {task_name}.")

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:28,850][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:29,246][INFO] __main__:24 - Loaded train for CoLA.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:29,562][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:29,922][INFO] __main__:24 - Loaded dev for CoLA.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:30,260][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:30,611][INFO] __main__:24 - Loaded test for CoLA.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:46,461][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:47,681][INFO] __main__:24 - Loaded train for MNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:48,484][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:49,675][INFO] __main__:24 - Loaded dev for MNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:50,360][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:51,577][INFO] __main__:24 - Loaded test for MNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:51,964][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:53,618][INFO] __main__:24 - Loaded train for MRPC.


HBox(children=(IntProgress(value=0, max=408), HTML(value='')))




[2019-04-25 17:01:53,937][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:54,650][INFO] __main__:24 - Loaded dev for MRPC.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:54,990][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:01:56,631][INFO] __main__:24 - Loaded test for MRPC.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:01:59,088][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:00,635][INFO] __main__:24 - Loaded train for QNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:01,062][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:02,538][INFO] __main__:24 - Loaded dev for QNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:02,955][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:05,201][INFO] __main__:24 - Loaded test for QNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:10,967][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:11,871][INFO] __main__:24 - Loaded train for QQP.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:12,775][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:13,680][INFO] __main__:24 - Loaded dev for QQP.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:19,625][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:20,529][INFO] __main__:24 - Loaded test for QQP.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:20,886][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:23,102][INFO] __main__:24 - Loaded train for RTE.


HBox(children=(IntProgress(value=0, max=277), HTML(value='')))




[2019-04-25 17:02:23,410][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:24,027][INFO] __main__:24 - Loaded dev for RTE.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:24,401][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:25,946][INFO] __main__:24 - Loaded test for RTE.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:45,184][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:45,995][INFO] __main__:24 - Loaded train for SNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:46,660][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:47,508][INFO] __main__:24 - Loaded dev for SNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))






[2019-04-25 17:02:48,377][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:49,198][INFO] __main__:24 - Loaded test for SNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:50,064][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:50,488][INFO] __main__:24 - Loaded train for SST-2.


HBox(children=(IntProgress(value=0, max=872), HTML(value='')))




[2019-04-25 17:02:50,805][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:51,493][INFO] __main__:24 - Loaded dev for SST-2.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-04-25 17:02:51,819][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-04-25 17:02:52,583][INFO] __main__:24 - Loaded test for SST-2.


# Build Emmental dataloaders

In [7]:
# One EmmentalDataLoader per (task, split) pair, in task-major order
# (task 1 train/dev/test, task 2 train/dev/test, ...).
dataloaders = []

for task_name in GLUE_TASK_NAMES:
    split_datasets = datasets[task_name]
    for split in ["train", "dev", "test"]:
        loader = EmmentalDataLoader(
            task_name=task_name,
            dataset=split_datasets[split],
            label_name="labels",  # key used in Y_dict when the datasets were built
            split=split,
            batch_size=BATCH_SIZE,
        )
        dataloaders.append(loader)
        logger.info(f"Built dataloader for {task_name} {split} set.")

# Build Emmental tasks

In [8]:
# One Emmental task per GLUE task, each built by `get_gule_task` for BERT_MODEL_NAME.
tasks = list(
    get_gule_task(glue_task_name, BERT_MODEL_NAME)
    for glue_task_name in GLUE_TASK_NAMES
)

[2019-04-25 17:04:35,192][INFO] pytorch_pretrained_bert.modeling:579 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

[2019-04-25 17:04:46,583][INFO] emmental.task:34 - Created task: RTE
[2019-04-25 17:04:46,854][INFO] pytorch_pretrained_bert.modeling:564 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
[2019-04-25 17:04:46,856][INFO] pytorch_pretrained_bert.modeling:572 - extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec

In [9]:
# def mse_loss(immediate_ouput, Y):
#     mse = MSELoss()
#     return mse(immediate_ouput[-1][0].view(-1), Y.view(-1))

In [10]:
# def ce_loss(immediate_ouput, Y):
#     return F.cross_entropy(immediate_ouput[-1][0], Y.view(-1) - 1)

In [11]:
# def output(immediate_ouput):
#     return immediate_ouput[-1][0]

In [12]:
# BERT_OUTPUT_DIM = 768 if "uncased" in BERT_MODEL_NAME else 1024

# TASK_CARDINALITY = len(LABEL_MAPPING["RTE"].keys()) if LABEL_MAPPING["RTE"] is not None else 1
# RTE_task = EmmentalTask(
#     name="RTE",
#     module_pool=nn.ModuleDict(
#         {
#             "bert_module": BertModule(BERT_MODEL_NAME),
#             "classification_module": ClassificationModule(BERT_OUTPUT_DIM, TASK_CARDINALITY),
#         }
#     ),
#     task_flow=[
#         {"module": "bert_module", "inputs": [(0, 'token_ids'), (0, 'token_segments')]},
#         {"module": "classification_module", "inputs": [(1, 1)]},
#     ],
#     loss_func=ce_loss,
#     output_func=output,
#     scorer=Scorer(metrics=['accuracy']),
# )

# STSB_task = EmmentalTask(
#     name="STS-B",
#     module_pool=nn.ModuleDict(
#         {
#             "bert_module": BertModule(BERT_MODEL_NAME),
#             "regression_module": RegressionModule(BERT_OUTPUT_DIM),
#         }
#     ),
#     task_flow=[
#         {"module": "bert_module", "inputs": [(0, 'token_ids'), (0, 'token_segments')]},
#         {"module": "regression_module", "inputs": [(1, 1)]},
#     ],
#     loss_func=mse_loss,
#     output_func=output,
#     scorer=Scorer(metrics=['pearson_spearman']),
# )

In [13]:
# Override Emmental's default config for this run.
Meta.update_config(
    config={
        # Device index 0; the model log below reports "Moving model to GPU".
        "meta_config": {"device": 0},
        "learner_config": {
            "n_epochs": 3,
            "valid_split": "dev",
            "optimizer_config": {"optimizer": "adam", "lr": 5e-5},
            "lr_scheduler_config": {
                "warmup_steps": 100,
                "warmup_unit": "batch",
                "lr_scheduler": "linear",
            },
        },
        "logging_config": {
            # Echoed by the learner as "Evaluating every 100 batch."
            "evaluation_freq": 100,
            "checkpointer_config": {
                # NOTE(review): checkpoints are selected by RTE *train* accuracy; the
                # dev split ("RTE/GLUE/dev/accuracy") is presumably intended. Confirm.
                "checkpoint_metric": "RTE/GLUE/train/accuracy",
                # Appears to count evaluation intervals, not batches: the learner logs
                # "Checkpointing every 10000000000 batch" (100 * 100_000_000), so
                # checkpointing is effectively disabled for this run.
                "checkpoint_freq": 100_000_000,
            },
        },
    }
)

[2019-04-25 17:05:58,078][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [14]:
# Assemble a single multi-task model from every task built above.
mtl_model = EmmentalModel(tasks=tasks, name="GLUE_multi_task")

[2019-04-25 17:05:58,112][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:02,994][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,003][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,008][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,013][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,017][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,022][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,027][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,032][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,036][INFO] emmental.model:57 - Moving model to GPU.
[2019-04-25 17:06:03,040][INFO] emmental.model:44 - Created emmental model GLUE_multi_task that contains task {'QQP', 'WNLI', 'SST-2', 'MRPC', 'MNLI', 'STS-B', 'SNLI', 'CoLA', 'RTE', 'QNLI'}.
[2019-04-25 17:06:03,040][INFO] emmental.model:57 - Moving model to GPU.


In [15]:
# Training driver. Its settings appear to come from the Meta config set above:
# the learn log echoes evaluation_freq=100 ("Evaluating every 100 batch.").
emmental_learner = EmmentalLearner()

In [17]:
# Train the multi-task model for n_epochs (3) passes; the output shows one progress bar per epoch.
# NOTE(review): `dataloaders` also holds the "test" loaders. Presumably only "train" is
# used for updates and "dev" (valid_split) for evaluation. Confirm.
emmental_learner.learn(mtl_model, dataloaders)

[2019-04-25 17:29:19,145][INFO] emmental.logging.logging_manager:33 - Evaluating every 100 batch.
[2019-04-25 17:29:19,146][INFO] emmental.logging.logging_manager:40 - Checkpointing every 10000000000 batch.
[2019-04-25 17:29:19,210][INFO] emmental.logging.checkpointer:41 - Save checkpoints at logs/2019_04_25/17_01_28 every 10000000000 batch
[2019-04-25 17:29:19,211][INFO] emmental.logging.checkpointer:65 - No checkpoints saved before 0 batch.
[2019-04-25 17:29:19,217][INFO] emmental.learner:249 - Start learning...


HBox(children=(IntProgress(value=0, max=607), HTML(value='')))

  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


HBox(children=(IntProgress(value=0, max=607), HTML(value='')))

HBox(children=(IntProgress(value=0, max=607), HTML(value='')))

In [18]:
# Score every dataloader (train, dev and test of each task). As the output shows, this
# returns a dict keyed "<task>/<dataset>/<split>/<metric>".
# NOTE(review): every test-split metric is 0.0. Presumably the GLUE test files carry no
# gold labels; consider scoring only the train/dev loaders. TODO confirm.
mtl_model.score(dataloaders)

  r = r_num / r_den
  c /= stddev[:, None]
  c /= stddev[None, :]
  return (self.a < x) & (x < self.b)
  return (self.a < x) & (x < self.b)
  cond2 = cond0 & (x <= self.a)


{'CoLA/GLUE/train/matthews_corrcoef': 0.8037668823449967,
 'CoLA/GLUE/dev/matthews_corrcoef': 0.3310809477308711,
 'CoLA/GLUE/test/matthews_corrcoef': 0.0,
 'MNLI/GLUE/train/accuracy': 0.917,
 'MNLI/GLUE/dev/accuracy': 0.616,
 'MNLI/GLUE/test/accuracy': 0.0,
 'MRPC/GLUE/train/accuracy': 0.968,
 'MRPC/GLUE/train/f1': 0.9762611275964392,
 'MRPC/GLUE/dev/accuracy': 0.7843137254901961,
 'MRPC/GLUE/dev/f1': 0.8508474576271187,
 'MRPC/GLUE/test/accuracy': 0.0,
 'MRPC/GLUE/test/f1': 0.0,
 'QNLI/GLUE/train/accuracy': 0.967,
 'QNLI/GLUE/dev/accuracy': 0.677,
 'QNLI/GLUE/test/accuracy': 0.0,
 'QQP/GLUE/train/accuracy': 0.972,
 'QQP/GLUE/train/f1': 0.9615384615384616,
 'QQP/GLUE/dev/accuracy': 0.775,
 'QQP/GLUE/dev/f1': 0.6853146853146853,
 'QQP/GLUE/test/accuracy': 0.0,
 'QQP/GLUE/test/f1': 0.0,
 'RTE/GLUE/train/accuracy': 0.933,
 'RTE/GLUE/dev/accuracy': 0.6570397111913358,
 'RTE/GLUE/test/accuracy': 0.0,
 'SNLI/GLUE/train/accuracy': 0.935,
 'SNLI/GLUE/dev/accuracy': 0.754,
 'SNLI/GLUE/test/acc