In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import logging

import torch.nn.functional as F
from torch import nn
from torch.nn import MSELoss

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader, EmmentalDataset
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from modules.classification_module import ClassificationModule
from modules.regression_module import RegressionModule
from preprocessor import preprocessor
from task_config import LABEL_MAPPING, GLUE_TASK_NAMES
from glue_tasks import get_gule_task

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
# Notebook-level logger; the dataset-loading and dataloader-building loops
# report their progress through it.
logger = logging.getLogger(__name__)

In [4]:
# Notebook-wide settings.
# NOTE(review): leftover task subset; the visible cells iterate over
# GLUE_TASK_NAMES (imported from task_config) instead.
# TASK_NAMES = ["RTE", "STS-B"]
DATA_DIR = "data"  # passed to preprocessor() as data_dir
BERT_MODEL_NAME = "bert-base-uncased"  # passed to preprocessor() and get_gule_task()
BATCH_SIZE = 16  # batch_size for every EmmentalDataLoader

# Initialize Emmental

In [5]:
# Start Emmental with "logs" as its log location. In the recorded run this set
# a timestamped logging directory under logs/ and loaded the default config.
emmental.init("logs")

[2019-05-07 15:33:22,635][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_05_07/15_33_22
[2019-05-07 15:33:22,650][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch1/senwu/mmtl/emmental/src/emmental/emmental-default-config.yaml.


# Extract train/dev/test dataset from file

In [6]:
# Run the BERT preprocessor over every GLUE task/split pair and wrap the
# returned arrays in an EmmentalDataset, stored as datasets[task_name][split].
datasets = {}

for task_name in GLUE_TASK_NAMES:
    for split in ["train", "dev", "test"]:
        # In the recorded run no split shows more than 1000 items, so
        # max_data_samples presumably caps the examples read per split.
        bert_token_ids, bert_token_segments, bert_token_masks, labels = preprocessor(
            data_dir=DATA_DIR,
            task_name=task_name,
            split=split,
            bert_model_name=BERT_MODEL_NAME,
            max_data_samples=1000,
            max_sequence_length=128,
        )

        # Model inputs and targets, keyed by the field names used downstream.
        X_dict = dict(
            token_ids=bert_token_ids,
            token_segments=bert_token_segments,
            token_masks=bert_token_masks,
        )
        Y_dict = dict(labels=labels)

        # Create the per-task mapping on first use, then register this split.
        split_datasets = datasets.setdefault(task_name, {})
        split_datasets[split] = EmmentalDataset(
            name="GLUE", X_dict=X_dict, Y_dict=Y_dict
        )

        logger.info(f"Loaded {split} for {task_name}.")

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:25,095][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:25,512][INFO] __main__:24 - Loaded train for CoLA.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:25,826][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:26,233][INFO] __main__:24 - Loaded dev for CoLA.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:26,549][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:26,945][INFO] __main__:24 - Loaded test for CoLA.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:42,655][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:44,048][INFO] __main__:24 - Loaded train for MNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:44,725][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:45,950][INFO] __main__:24 - Loaded dev for MNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:46,640][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:47,883][INFO] __main__:24 - Loaded test for MNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:48,267][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:49,939][INFO] __main__:24 - Loaded train for MRPC.


HBox(children=(IntProgress(value=0, max=408), HTML(value='')))




[2019-05-07 15:33:50,263][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:50,996][INFO] __main__:24 - Loaded dev for MRPC.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:51,348][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:53,043][INFO] __main__:24 - Loaded test for MRPC.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:55,501][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:57,047][INFO] __main__:24 - Loaded train for QNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:57,460][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:33:58,961][INFO] __main__:24 - Loaded dev for QNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:33:59,398][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:01,651][INFO] __main__:24 - Loaded test for QNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:07,373][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:08,315][INFO] __main__:24 - Loaded train for QQP.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:09,214][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:10,320][INFO] __main__:24 - Loaded dev for QQP.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:16,024][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:16,961][INFO] __main__:24 - Loaded test for QQP.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:17,318][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:19,541][INFO] __main__:24 - Loaded train for RTE.


HBox(children=(IntProgress(value=0, max=277), HTML(value='')))




[2019-05-07 15:34:19,906][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:20,527][INFO] __main__:24 - Loaded dev for RTE.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:20,902][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:22,459][INFO] __main__:24 - Loaded test for RTE.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:41,526][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:42,361][INFO] __main__:24 - Loaded train for SNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:43,020][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:43,884][INFO] __main__:24 - Loaded dev for SNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))






[2019-05-07 15:34:44,547][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:45,392][INFO] __main__:24 - Loaded test for SNLI.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:46,243][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:46,703][INFO] __main__:24 - Loaded train for SST-2.


HBox(children=(IntProgress(value=0, max=872), HTML(value='')))




[2019-05-07 15:34:47,049][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:47,751][INFO] __main__:24 - Loaded dev for SST-2.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:48,081][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:48,874][INFO] __main__:24 - Loaded test for SST-2.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:49,276][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:49,833][INFO] __main__:24 - Loaded train for STS-B.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:50,172][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:51,275][INFO] __main__:24 - Loaded dev for STS-B.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




[2019-05-07 15:34:51,601][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:52,371][INFO] __main__:24 - Loaded test for STS-B.


HBox(children=(IntProgress(value=0, max=635), HTML(value='')))




[2019-05-07 15:34:52,686][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:53,403][INFO] __main__:24 - Loaded train for WNLI.


HBox(children=(IntProgress(value=0, max=71), HTML(value='')))




[2019-05-07 15:34:53,701][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:53,822][INFO] __main__:24 - Loaded dev for WNLI.


HBox(children=(IntProgress(value=0, max=146), HTML(value='')))




[2019-05-07 15:34:54,121][INFO] pytorch_pretrained_bert.tokenization:146 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
[2019-05-07 15:34:54,406][INFO] __main__:24 - Loaded test for WNLI.


# Build Emmental dataloader

In [7]:
# Build one EmmentalDataLoader per (task, split) pair and collect them in a
# single flat list. Only the training split is shuffled.
dataloaders = []

for task_name in GLUE_TASK_NAMES:
    for split in ["train", "dev", "test"]:
        dataloader = EmmentalDataLoader(
            # "labels" is presumably the Y_dict key set when the datasets
            # were built — confirm against EmmentalDataLoader.
            task_to_label_dict={task_name: "labels"},
            dataset=datasets[task_name][split],
            split=split,
            batch_size=BATCH_SIZE,
            shuffle=(split == "train"),
        )
        dataloaders.append(dataloader)
        logger.info(f"Built dataloader for {task_name} {split} set.")

[2019-05-07 15:36:29,950][INFO] __main__:14 - Built dataloader for CoLA train set.
[2019-05-07 15:36:29,952][INFO] __main__:14 - Built dataloader for CoLA dev set.
[2019-05-07 15:36:29,953][INFO] __main__:14 - Built dataloader for CoLA test set.
[2019-05-07 15:36:29,954][INFO] __main__:14 - Built dataloader for MNLI train set.
[2019-05-07 15:36:29,956][INFO] __main__:14 - Built dataloader for MNLI dev set.
[2019-05-07 15:36:29,958][INFO] __main__:14 - Built dataloader for MNLI test set.
[2019-05-07 15:36:29,958][INFO] __main__:14 - Built dataloader for MRPC train set.
[2019-05-07 15:36:29,960][INFO] __main__:14 - Built dataloader for MRPC dev set.
[2019-05-07 15:36:29,961][INFO] __main__:14 - Built dataloader for MRPC test set.
[2019-05-07 15:36:29,962][INFO] __main__:14 - Built dataloader for QNLI train set.
[2019-05-07 15:36:29,963][INFO] __main__:14 - Built dataloader for QNLI dev set.
[2019-05-07 15:36:29,964][INFO] __main__:14 - Built dataloader for QNLI test set.
[2019-05-07 15:3

# Build Emmental task

In [11]:
# Build the Emmental tasks from the GLUE task names and the BERT model name.
# `tasks` is later consumed through tasks.values(), so it is presumably a
# mapping keyed by task name — TODO confirm in glue_tasks.
# NOTE(review): "gule" looks like a typo for "glue"; the function is defined in
# glue_tasks, so it would have to be renamed there and in the import.
tasks = get_gule_task(GLUE_TASK_NAMES, BERT_MODEL_NAME)

[2019-05-07 15:59:02,833][INFO] pytorch_pretrained_bert.modeling:564 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
[2019-05-07 15:59:02,837][INFO] pytorch_pretrained_bert.modeling:572 - extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpuxt7ykoe
[2019-05-07 15:59:08,700][INFO] pytorch_pretrained_bert.modeling:579 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

[2019-05-07 15:59:1

In [12]:
# Optimizer, schedule and epoch settings for the learner; validation uses the
# "dev" split.
learner_config = {
    "n_epochs": 3,
    "valid_split": "dev",
    "optimizer_config": {"optimizer": "adam", "lr": 5e-5},
    "lr_scheduler_config": {
        "warmup_steps": 70,
        "warmup_unit": "batch",
        "lr_scheduler": "linear",
    },
}

# Evaluate every 50 units (the recorded log reports "every 50 batch") with
# checkpointing disabled. The original cell carried a commented-out
# "checkpointer_config" ("checkpoint_metric": f"{TASK_NAME}/GLUE/train/accuracy",
# "checkpoint_freq": 10); TASK_NAME is not defined in the visible cells.
logging_config = {
    "evaluation_freq": 50,
    "checkpointing": None,
}

# Device 1: the recorded run moved the model to cuda:1.
Meta.update_config(
    config={
        "meta_config": {"device": 1},
        "learner_config": learner_config,
        "logging_config": logging_config,
    }
)


[2019-05-07 16:00:34,801][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [14]:
# Combine every task built by get_gule_task into one multi-task model.
all_tasks = tasks.values()
mtl_model = EmmentalModel(name="GLUE_multi_task", tasks=all_tasks)

[2019-05-07 16:00:52,190][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,359][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,368][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,373][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,378][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,382][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,386][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,390][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,393][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,397][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-07 16:00:57,402][INFO] emmental.model:44 - Created emmental model GLUE_multi_task that contains task {'RTE', 'MRPC', 'QNLI', 'STS-B', 'MNLI', 'SNLI', 'SST-2', 'WNLI', 'CoL

In [15]:
# Create the learner that drives training. It takes no arguments here; the
# recorded log suggests it reads the settings applied via Meta.update_config.
emmental_learner = EmmentalLearner()

In [16]:
# Train the multi-task model on the full dataloader list (train, dev and test
# loaders for every task). The recorded run shows three 607-step progress bars,
# presumably one per epoch (n_epochs=3).
emmental_learner.learn(mtl_model, dataloaders)

[2019-05-07 16:04:03,149][INFO] emmental.logging.logging_manager:33 - Evaluating every 50 batch.
[2019-05-07 16:04:03,150][INFO] emmental.logging.logging_manager:51 - No checkpointing.
[2019-05-07 16:04:03,196][INFO] emmental.learner:286 - Start learning...


HBox(children=(IntProgress(value=0, max=607), HTML(value='')))

  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)





HBox(children=(IntProgress(value=0, max=607), HTML(value='')))




HBox(children=(IntProgress(value=0, max=607), HTML(value='')))




In [17]:
# Score the trained model on every dataloader; the result is displayed as a
# dict keyed "<task>/GLUE/<split>/<metric>".
# NOTE(review): every test-split metric in the recorded output is 0.0 —
# presumably the GLUE test files carry no gold labels; confirm before reporting.
mtl_model.score(dataloaders)

  r = r_num / r_den
  c /= stddev[:, None]
  c /= stddev[None, :]
  return (self.a < x) & (x < self.b)
  return (self.a < x) & (x < self.b)
  cond2 = cond0 & (x <= self.a)


{'CoLA/GLUE/train/matthews_corrcoef': 0.38573432553207543,
 'CoLA/GLUE/dev/matthews_corrcoef': 0.1857313532962773,
 'CoLA/GLUE/test/matthews_corrcoef': 0.0,
 'MNLI/GLUE/train/accuracy': 0.518,
 'MNLI/GLUE/dev/accuracy': 0.468,
 'MNLI/GLUE/test/accuracy': 0.0,
 'MRPC/GLUE/train/accuracy': 0.776,
 'MRPC/GLUE/train/f1': 0.8502673796791445,
 'MRPC/GLUE/dev/accuracy': 0.7426470588235294,
 'MRPC/GLUE/dev/f1': 0.8356807511737089,
 'MRPC/GLUE/test/accuracy': 0.0,
 'MRPC/GLUE/test/f1': 0.0,
 'QNLI/GLUE/train/accuracy': 0.748,
 'QNLI/GLUE/dev/accuracy': 0.622,
 'QNLI/GLUE/test/accuracy': 0.0,
 'QQP/GLUE/train/accuracy': 0.718,
 'QQP/GLUE/train/f1': 0.6466165413533834,
 'QQP/GLUE/dev/accuracy': 0.709,
 'QQP/GLUE/dev/f1': 0.6196078431372549,
 'QQP/GLUE/test/accuracy': 0.0,
 'QQP/GLUE/test/f1': 0.0,
 'RTE/GLUE/train/accuracy': 0.678,
 'RTE/GLUE/dev/accuracy': 0.6462093862815884,
 'RTE/GLUE/test/accuracy': 0.0,
 'SNLI/GLUE/train/accuracy': 0.568,
 'SNLI/GLUE/dev/accuracy': 0.562,
 'SNLI/GLUE/test/ac