In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import logging
from functools import partial

import numpy as np
import torch.nn.functional as F
from torch import nn
import torch
import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from parse_WiC_slice import get_WiC_dataloaders
from task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING
from sklearn.metrics import f1_score

In [3]:
logger = logging.getLogger(__name__)

# Initalize Emmental

In [4]:
emmental.init(
    "logs",
    config={
        "model_config": {"device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 10,
            "valid_split": "val",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "min_lr": 0,
            "lr_scheduler_config": {"warmup_percentage": 0.1, "lr_scheduler": None},
        },
        "logging_config": {
            "counter_unit": "batch",
            "evaluation_freq": 100,
            "checkpointing": True,
            "checkpointer_config": {"checkpoint_metric": {"WiC/SuperGLUE/val/accuracy":"max"}},
        },
    },
)

[2019-05-30 13:46:58,445][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_05_30/13_46_58
[2019-05-30 13:46:58,455][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch0/bradenjh/emmental/src/emmental/emmental-default-config.yaml.
[2019-05-30 13:46:58,456][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [5]:
Meta.config

{'meta_config': {'seed': 0, 'verbose': True, 'log_path': None},
 'model_config': {'model_path': None, 'device': 0, 'dataparallel': False},
 'learner_config': {'fp16': False,
  'n_epochs': 10,
  'train_split': 'train',
  'valid_split': 'val',
  'test_split': 'test',
  'ignore_index': -100,
  'optimizer_config': {'optimizer': 'adam',
   'lr': 1e-05,
   'l2': 0.0,
   'grad_clip': 1.0,
   'sgd_config': {'momentum': 0.9},
   'adam_config': {'betas': (0.9, 0.999)}},
  'lr_scheduler_config': {'lr_scheduler': None,
   'warmup_steps': None,
   'warmup_unit': 'batch',
   'warmup_percentage': 0.1,
   'min_lr': 0.0,
   'linear_config': {'min_lr': 0.0},
   'exponential_config': {'gamma': 0.9},
   'plateau_config': {'factor': 0.5, 'patience': 10, 'threshold': 0.0001}},
  'task_scheduler': 'round_robin',
  'global_evaluation_metric_dict': None,
  'min_lr': 0},
 'logging_config': {'counter_unit': 'batch',
  'evaluation_freq': 100,
  'writer_config': {'writer': 'tensorboard', 'verbose': True},
  'check

In [16]:
import os

TASK_NAME = "WiC"
DATA_DIR = os.environ["SUPERGLUEDATA"]
BERT_MODEL_NAME = "bert-large-cased"
BATCH_SIZE = 4

BERT_OUTPUT_DIM = 768 if "base" in BERT_MODEL_NAME else 1024
TASK_CARDINALITY = (
    len(SuperGLUE_LABEL_MAPPING[TASK_NAME].keys())
    if SuperGLUE_LABEL_MAPPING[TASK_NAME] is not None
    else 1
)

In [17]:
BERT_OUTPUT_DIM, TASK_CARDINALITY

(1024, 2)

# Extract train/dev dataset from file

In [18]:
def slice_verb(dataset):
    slice_name = "slice_verb"
    ind, pred = [], []
    cnt = 0
    for idx, pos in enumerate(dataset.X_dict["poses"]):
        if pos == "V":
            ind.append(1)
            pred.append(dataset.Y_dict["labels"][idx])
            cnt += 1
        else:
            ind.append(2)
            pred.append(Meta.config["learner_config"]["ignore_index"])
    ind = torch.from_numpy(np.array(ind)).view(-1)
    pred = torch.from_numpy(np.array(pred)).view(-1)
    logger.info(f"Total {cnt} / {len(dataset)} in the slice {slice_name}")
    print(ind.size(), pred.size())
    return ind, pred

In [19]:
def slice_noun(dataset):
    slice_name = "slice_noun"
    ind, pred = [], []
    cnt = 0
    for idx, pos in enumerate(dataset.X_dict["poses"]):
        if pos == "N":
            ind.append(1)
            pred.append(dataset.Y_dict["labels"][idx])
            cnt += 1
        else:
            ind.append(2)
            pred.append(Meta.config["learner_config"]["ignore_index"])
    ind = torch.from_numpy(np.array(ind)).view(-1)
    pred = torch.from_numpy(np.array(pred)).view(-1)
    logger.info(f"Total {cnt} / {len(dataset)} in the slice {slice_name}")
    print(ind.size(), pred.size())
    return ind, pred

In [20]:
def slice_first_pos(dataset):
    slice_name = "slice_first_pos"
    ind, pred = [], []
    cnt = 0
    for idx, pos in enumerate(dataset.X_dict["poses"]):
        if dataset.X_dict["sent1_ori_idxs"][idx] == 1 and dataset.X_dict["sent2_ori_idxs"][idx] == 1:
            ind.append(1)
            pred.append(dataset.Y_dict["labels"][idx])
            cnt += 1
        else:
            ind.append(2)
            pred.append(Meta.config["learner_config"]["ignore_index"])
    ind = torch.from_numpy(np.array(ind)).view(-1)
    pred = torch.from_numpy(np.array(pred)).view(-1)
    logger.info(f"Total {cnt} / {len(dataset)} in the slice {slice_name}")
    print(ind.size(), pred.size())
    return ind, pred

In [21]:
def slice_base(dataset):
    return torch.from_numpy(np.array([1] * len(dataset))), dataset.Y_dict["labels"]

In [22]:
slice_func_dict = {
    "slice_base": slice_base,
    "slice_verb": slice_verb,
#     "slice_noun": slice_noun,
#     "slice_first_pos": slice_first_pos,
}
slice_func_dict.keys()

dict_keys(['slice_base', 'slice_verb'])

In [23]:
dataloaders = get_WiC_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=["train", "val", "test"],
    max_sequence_length=128,
    max_data_samples=None,
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
    slice_func_dict=slice_func_dict,
)

[2019-05-30 13:47:40,123][INFO] tokenizer:8 - Loading Tokenizer bert-large-cased
[2019-05-30 13:47:40,385][INFO] pytorch_pretrained_bert.tokenization:190 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt from cache at /afs/cs.stanford.edu/u/bradenjh/.pytorch_pretrained_bert/cee054f6aafe5e2cf816d2228704e326446785f940f5451a5b26033516a4ac3d.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
[2019-05-30 13:47:42,403][INFO] __main__:15 - Total 2634 / 5428 in the slice slice_verb
[2019-05-30 13:47:42,404][INFO] parse_WiC_slice:181 - Loaded train for WiC.


max len 68
torch.Size([5428]) torch.Size([5428])


[2019-05-30 13:47:42,646][INFO] __main__:15 - Total 243 / 638 in the slice slice_verb
[2019-05-30 13:47:42,646][INFO] parse_WiC_slice:181 - Loaded val for WiC.


max len 57
torch.Size([638]) torch.Size([638])


[2019-05-30 13:47:43,183][INFO] __main__:15 - Total 569 / 1400 in the slice slice_verb
[2019-05-30 13:47:43,184][INFO] parse_WiC_slice:181 - Loaded test for WiC.


max len 60
torch.Size([1400]) torch.Size([1400])


In [24]:
dataloaders["train"].dataset.Y_dict

{'labels': tensor([2, 2, 2,  ..., 1, 1, 1]),
 'WiC_slice_ind_slice_base': tensor([1, 1, 1,  ..., 1, 1, 1]),
 'WiC_slice_pred_slice_base': tensor([2, 2, 2,  ..., 1, 1, 1]),
 'WiC_slice_ind_slice_verb': tensor([1, 1, 1,  ..., 1, 1, 2]),
 'WiC_slice_pred_slice_verb': tensor([   2,    2,    2,  ...,    1,    1, -100])}

In [25]:
dataloaders["train"].task_to_label_dict

{'WiC_slice_ind_slice_base': 'WiC_slice_ind_slice_base',
 'WiC_slice_pred_slice_base': 'WiC_slice_pred_slice_base',
 'WiC_slice_ind_slice_verb': 'WiC_slice_ind_slice_verb',
 'WiC_slice_pred_slice_verb': 'WiC_slice_pred_slice_verb',
 'WiC': 'labels'}

In [26]:
for key, value in dataloaders["train"].dataset.Y_dict.items():
    print(key, value.size())

labels torch.Size([5428])
WiC_slice_ind_slice_base torch.Size([5428])
WiC_slice_pred_slice_base torch.Size([5428])
WiC_slice_ind_slice_verb torch.Size([5428])
WiC_slice_pred_slice_verb torch.Size([5428])


# Build Emmental task

In [27]:
def ce_loss(module_name, immediate_ouput_dict, Y, active):
    return F.cross_entropy(
        immediate_ouput_dict[module_name][0][active], (Y.view(-1) - 1)[active]
    )

In [28]:
def output(module_name, immediate_ouput_dict):
    return F.softmax(immediate_ouput_dict[module_name][0], dim=1)

In [29]:
class FeatureConcateModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, feature, idx1, idx2):
#         import pdb; pdb.set_trace()
        last_layer = feature[-1]
        emb = last_layer[:,0,:]
        idx1 = idx1.unsqueeze(-1).unsqueeze(-1).expand([-1, -1, last_layer.size(-1)])
        idx2 = idx2.unsqueeze(-1).unsqueeze(-1).expand([-1, -1, last_layer.size(-1)])
        word1_emb = last_layer.gather(dim=1, index=idx1).squeeze(dim=1)
        word2_emb = last_layer.gather(dim=1, index=idx2).squeeze(dim=1)
        input = torch.cat([emb, word1_emb, word2_emb], dim=-1)
        return input

In [30]:
class SliceModule(nn.Module):
    def __init__(self, feature_dim, class_cardinality):
        super().__init__()
        self.linear = nn.Linear(feature_dim, class_cardinality)

    def forward(self, feature):
        return self.linear.forward(feature)

In [31]:
H = BERT_OUTPUT_DIM

In [32]:
shared_classification_module = nn.Linear(H, TASK_CARDINALITY)

In [33]:
bert_module = BertModule(BERT_MODEL_NAME)

[2019-05-30 13:47:51,470][INFO] pytorch_pretrained_bert.modeling:580 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz from cache at ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233
[2019-05-30 13:47:51,471][INFO] pytorch_pretrained_bert.modeling:588 - extracting archive file ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233 to temp dir /tmp/tmpchqebob3
[2019-05-30 13:48:03,606][INFO] pytorch_pretrained_bert.modeling:598 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pooler_fc_size": 768,
  "pooler_num_a

In [34]:
tasks = []

In [35]:
# Add ind task

type = "ind"

for slice_name in slice_func_dict.keys():
    task = EmmentalTask(
        name=f"{TASK_NAME}_slice_{type}_{slice_name}",
        module_pool=nn.ModuleDict(
            {
#                 "bert_module": bert_module,
                "feature": FeatureConcateModule(),
                f"{TASK_NAME}_slice_{type}_{slice_name}_head": SliceModule(
                    3 * BERT_OUTPUT_DIM, 2
                ),
            }
        ),
        task_flow=[
            {
                "name": "input",
                "module": "bert_module",
                "inputs": [("_input_", "token_ids"), ("_input_", "token_segments")],
            },
            {
                "name": f"feature",
                "module": f"feature",
                "inputs": [
                    ("input", 0),
                    ("_input_", "sent1_idxs"),
                    ("_input_", "sent2_idxs"),
                ],
            },
            {
                "name": f"{TASK_NAME}_slice_{type}_{slice_name}_head",
                "module": f"{TASK_NAME}_slice_{type}_{slice_name}_head",
                "inputs": [("feature", 0)],
            },
        ],
        loss_func=partial(ce_loss, f"{TASK_NAME}_slice_{type}_{slice_name}_head"),
        output_func=partial(output, f"{TASK_NAME}_slice_{type}_{slice_name}_head"),
        scorer=Scorer(metrics=["accuracy"]),
    )
    tasks.append(task)

[2019-05-30 13:48:13,281][INFO] emmental.task:34 - Created task: WiC_slice_ind_slice_base
[2019-05-30 13:48:13,282][INFO] emmental.task:34 - Created task: WiC_slice_ind_slice_verb


In [36]:
# Add ind task

type = "pred"

for slice_name in slice_func_dict.keys():
    task = EmmentalTask(
        name=f"{TASK_NAME}_slice_{type}_{slice_name}",
        module_pool=nn.ModuleDict(
            {
#                 "bert_module": bert_module,
                "feature": FeatureConcateModule(),
                f"{TASK_NAME}_slice_feat_{slice_name}": nn.Linear(3 * BERT_OUTPUT_DIM, H),
                f"{TASK_NAME}_slice_{type}_{slice_name}_head": shared_classification_module,
            }
        ),
        task_flow=[
            {
                "name": "input",
                "module": "bert_module",
                "inputs": [("_input_", "token_ids"), ("_input_", "token_segments")],
            },
            {
                "name": f"feature",
                "module": f"feature",
                "inputs": [
                    ("input", 0),
                    ("_input_", "sent1_idxs"),
                    ("_input_", "sent2_idxs"),
                ],
            },
            {
                "name": f"{TASK_NAME}_slice_feat_{slice_name}",
                "module": f"{TASK_NAME}_slice_feat_{slice_name}",
                "inputs": [("feature", 0)],
            },
            {
                "name": f"{TASK_NAME}_slice_{type}_{slice_name}_head",
                "module": f"{TASK_NAME}_slice_{type}_{slice_name}_head",
                "inputs": [(f"{TASK_NAME}_slice_feat_{slice_name}", 0)],
            },
        ],
        loss_func=partial(ce_loss, f"{TASK_NAME}_slice_{type}_{slice_name}_head"),
        output_func=partial(output, f"{TASK_NAME}_slice_{type}_{slice_name}_head"),
        scorer=Scorer(metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]),
    )
    tasks.append(task)

[2019-05-30 13:48:13,350][INFO] emmental.task:34 - Created task: WiC_slice_pred_slice_base
[2019-05-30 13:48:13,380][INFO] emmental.task:34 - Created task: WiC_slice_pred_slice_verb


In [37]:
class MasterModule(nn.Module):
    def __init__(self, feature_dim, class_cardinality):
        super().__init__()
        self.linear = nn.Linear(feature_dim, class_cardinality)

    def forward(self, immediate_ouput_dict):
        slice_ind_names = sorted(
            [
                flow_name
                for flow_name in immediate_ouput_dict.keys()
                if "_slice_ind_" in flow_name
            ]
        )
        slice_pred_names = sorted(
            [
                flow_name
                for flow_name in immediate_ouput_dict.keys()
                if "_slice_pred_" in flow_name
            ]
        )

        Q = torch.cat(
            [
                F.softmax(immediate_ouput_dict[slice_ind_name][0])[:, 0].unsqueeze(1)
                for slice_ind_name in slice_ind_names
            ],
            dim=-1,
        )
        P = torch.cat(
            [
                F.softmax(immediate_ouput_dict[slice_pred_name][0])[:, 0].unsqueeze(1)
                for slice_pred_name in slice_pred_names
            ],
            dim=-1,
        )

        slice_feat_names = sorted(
            [
                flow_name
                for flow_name in immediate_ouput_dict.keys()
                if "_slice_feat_" in flow_name
            ]
        )

        slice_reps = torch.cat(
            [
                immediate_ouput_dict[slice_feat_name][0].unsqueeze(1)
                for slice_feat_name in slice_feat_names
            ],
            dim=1,
        )

        A = F.softmax(Q * P, dim=1).unsqueeze(-1).expand([-1, -1, slice_reps.size(-1)])

        reweighted_rep = torch.sum(A * slice_reps, 1)

        return self.linear.forward(reweighted_rep)

In [38]:
master_task = EmmentalTask(
    name=f"{TASK_NAME}",
    module_pool=nn.ModuleDict(
        {
            "bert_module": bert_module,
            f"{TASK_NAME}_pred_head": MasterModule(H, TASK_CARDINALITY),
        }
    ),
    task_flow=[
        {
            "name": f"{TASK_NAME}_pred_head",
            "module": f"{TASK_NAME}_pred_head",
            "inputs": [],
        }
    ],
    loss_func=partial(ce_loss, f"{TASK_NAME}_pred_head"),
    output_func=partial(output, f"{TASK_NAME}_pred_head"),
    scorer=Scorer(metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]),
)
tasks.append(master_task)

[2019-05-30 13:48:13,447][INFO] emmental.task:34 - Created task: WiC


In [39]:
mtl_model = EmmentalModel(name="SuperGLUE_single_task", tasks=tasks)

[2019-05-30 13:48:13,486][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-30 13:48:16,602][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-30 13:48:16,604][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-30 13:48:16,609][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-30 13:48:16,614][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-30 13:48:17,024][INFO] emmental.model:44 - Created emmental model SuperGLUE_single_task that contains task {'WiC_slice_ind_slice_base', 'WiC', 'WiC_slice_ind_slice_verb', 'WiC_slice_pred_slice_verb', 'WiC_slice_pred_slice_base'}.
[2019-05-30 13:48:17,025][INFO] emmental.model:58 - Moving model to GPU (cuda:0).


In [40]:
emmental_learner = EmmentalLearner()

In [41]:
# for X, Y in dataloaders["train"]:
#     import pdb; pdb.set_trace()
# # #     print(X, Y)
#     pass

In [42]:
emmental_learner.learn(mtl_model, dataloaders.values())

[2019-05-30 13:48:17,144][INFO] emmental.logging.logging_manager:33 - Evaluating every 100 batch.
[2019-05-30 13:48:17,145][INFO] emmental.logging.logging_manager:43 - Checkpointing every 100 batch.
[2019-05-30 13:48:17,146][INFO] emmental.logging.checkpointer:42 - Save checkpoints at logs/2019_05_30/13_46_58 every 100 batch
[2019-05-30 13:48:17,146][INFO] emmental.logging.checkpointer:73 - No checkpoints saved before 0 batch.
[2019-05-30 13:48:17,171][INFO] emmental.learner:152 - Warmup 1357 batchs.
[2019-05-30 13:48:17,174][INFO] emmental.learner:303 - Start learning...


HBox(children=(IntProgress(value=0, description='Epoch 0:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 13:49:01,155][INFO] emmental.logging.checkpointer:93 - checkpoint_runway condition has been met. Start checkpoining.
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
[2019-05-30 13:49:12,499][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 100 batch at logs/2019_05_30/13_46_58/checkpoint_100.pth.
[2019-05-30 13:49:23,652][INFO] emmental.logging.checkpointer:118 - Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_30/13_46_58/best_model_WiC_SuperGLUE_val_accuracy.pth
[2019-05-30 13:50:18,929][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 200 batch at logs/2019_05_30/13_46_58/checkpoint_200.pth.
[2019-05-30 13:51:15,175][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 300 batch at logs/2019_05_30/13_46_58/checkpoint_300.pth.
[2019-05-30 13:51:27,939][INFO] emmental.logging.checkpointer:118 - Save best model of me




HBox(children=(IntProgress(value=0, description='Epoch 1:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 14:03:03,524][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 1400 batch at logs/2019_05_30/13_46_58/checkpoint_1400.pth.
[2019-05-30 14:03:15,056][INFO] emmental.logging.checkpointer:118 - Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_30/13_46_58/best_model_WiC_SuperGLUE_val_accuracy.pth
[2019-05-30 14:04:10,254][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 1500 batch at logs/2019_05_30/13_46_58/checkpoint_1500.pth.
[2019-05-30 14:05:05,210][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 1600 batch at logs/2019_05_30/13_46_58/checkpoint_1600.pth.
[2019-05-30 14:05:17,206][INFO] emmental.logging.checkpointer:118 - Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_30/13_46_58/best_model_WiC_SuperGLUE_val_accuracy.pth
[2019-05-30 14:06:12,115][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 1700 batch at logs/2019_05_30/13_46_58/checkpoint_1700.pth.
[2019-05-30 14:07:11,129][INFO




HBox(children=(IntProgress(value=0, description='Epoch 2:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 14:17:13,565][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 2800 batch at logs/2019_05_30/13_46_58/checkpoint_2800.pth.
[2019-05-30 14:18:13,091][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 2900 batch at logs/2019_05_30/13_46_58/checkpoint_2900.pth.
[2019-05-30 14:18:26,743][INFO] emmental.logging.checkpointer:118 - Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_30/13_46_58/best_model_WiC_SuperGLUE_val_accuracy.pth
[2019-05-30 14:19:25,571][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 3000 batch at logs/2019_05_30/13_46_58/checkpoint_3000.pth.
[2019-05-30 14:20:26,100][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 3100 batch at logs/2019_05_30/13_46_58/checkpoint_3100.pth.
[2019-05-30 14:21:25,365][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 3200 batch at logs/2019_05_30/13_46_58/checkpoint_3200.pth.
[2019-05-30 14:22:20,713][INFO] emmental.logging.checkpointer:102 - Save c




HBox(children=(IntProgress(value=0, description='Epoch 3:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 14:30:23,806][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 4100 batch at logs/2019_05_30/13_46_58/checkpoint_4100.pth.
[2019-05-30 14:31:24,724][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 4200 batch at logs/2019_05_30/13_46_58/checkpoint_4200.pth.
[2019-05-30 14:32:21,643][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 4300 batch at logs/2019_05_30/13_46_58/checkpoint_4300.pth.
[2019-05-30 14:32:33,105][INFO] emmental.logging.checkpointer:118 - Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_30/13_46_58/best_model_WiC_SuperGLUE_val_accuracy.pth
[2019-05-30 14:33:33,375][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 4400 batch at logs/2019_05_30/13_46_58/checkpoint_4400.pth.
[2019-05-30 14:34:31,104][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 4500 batch at logs/2019_05_30/13_46_58/checkpoint_4500.pth.
[2019-05-30 14:35:25,183][INFO] emmental.logging.checkpointer:102 - Save c




HBox(children=(IntProgress(value=0, description='Epoch 4:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 14:44:01,585][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 5500 batch at logs/2019_05_30/13_46_58/checkpoint_5500.pth.
[2019-05-30 14:44:57,699][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 5600 batch at logs/2019_05_30/13_46_58/checkpoint_5600.pth.
[2019-05-30 14:45:56,293][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 5700 batch at logs/2019_05_30/13_46_58/checkpoint_5700.pth.
[2019-05-30 14:46:56,476][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 5800 batch at logs/2019_05_30/13_46_58/checkpoint_5800.pth.
[2019-05-30 14:47:55,990][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 5900 batch at logs/2019_05_30/13_46_58/checkpoint_5900.pth.
[2019-05-30 14:48:55,515][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 6000 batch at logs/2019_05_30/13_46_58/checkpoint_6000.pth.
[2019-05-30 14:49:55,455][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 6100 batch at logs/2019_05_30/1




HBox(children=(IntProgress(value=0, description='Epoch 5:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 14:56:42,840][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 6800 batch at logs/2019_05_30/13_46_58/checkpoint_6800.pth.
[2019-05-30 14:57:44,519][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 6900 batch at logs/2019_05_30/13_46_58/checkpoint_6900.pth.
[2019-05-30 14:58:44,040][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 7000 batch at logs/2019_05_30/13_46_58/checkpoint_7000.pth.
[2019-05-30 14:59:42,339][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 7100 batch at logs/2019_05_30/13_46_58/checkpoint_7100.pth.
[2019-05-30 15:00:42,127][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 7200 batch at logs/2019_05_30/13_46_58/checkpoint_7200.pth.
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=




HBox(children=(IntProgress(value=0, description='Epoch 6:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 15:10:34,148][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8200 batch at logs/2019_05_30/13_46_58/checkpoint_8200.pth.
[2019-05-30 15:11:32,030][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8300 batch at logs/2019_05_30/13_46_58/checkpoint_8300.pth.
[2019-05-30 15:12:28,645][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8400 batch at logs/2019_05_30/13_46_58/checkpoint_8400.pth.
[2019-05-30 15:13:22,082][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8500 batch at logs/2019_05_30/13_46_58/checkpoint_8500.pth.
[2019-05-30 15:14:16,187][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8600 batch at logs/2019_05_30/13_46_58/checkpoint_8600.pth.
[2019-05-30 15:15:11,691][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8700 batch at logs/2019_05_30/13_46_58/checkpoint_8700.pth.
[2019-05-30 15:16:08,689][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 8800 batch at logs/2019_05_30/1




HBox(children=(IntProgress(value=0, description='Epoch 7:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 15:22:43,173][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 9500 batch at logs/2019_05_30/13_46_58/checkpoint_9500.pth.
[2019-05-30 15:23:43,294][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 9600 batch at logs/2019_05_30/13_46_58/checkpoint_9600.pth.
[2019-05-30 15:24:44,282][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 9700 batch at logs/2019_05_30/13_46_58/checkpoint_9700.pth.
[2019-05-30 15:25:42,962][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 9800 batch at logs/2019_05_30/13_46_58/checkpoint_9800.pth.
[2019-05-30 15:26:42,143][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 9900 batch at logs/2019_05_30/13_46_58/checkpoint_9900.pth.
[2019-05-30 15:27:41,674][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 10000 batch at logs/2019_05_30/13_46_58/checkpoint_10000.pth.
[2019-05-30 15:28:40,879][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 10100 batch at logs/2019_05_3




HBox(children=(IntProgress(value=0, description='Epoch 8:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 15:36:28,352][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 10900 batch at logs/2019_05_30/13_46_58/checkpoint_10900.pth.
[2019-05-30 15:37:22,819][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 11000 batch at logs/2019_05_30/13_46_58/checkpoint_11000.pth.
[2019-05-30 15:38:19,116][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 11100 batch at logs/2019_05_30/13_46_58/checkpoint_11100.pth.
[2019-05-30 15:39:15,441][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 11200 batch at logs/2019_05_30/13_46_58/checkpoint_11200.pth.
[2019-05-30 15:40:13,208][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 11300 batch at logs/2019_05_30/13_46_58/checkpoint_11300.pth.
[2019-05-30 15:41:12,215][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 11400 batch at logs/2019_05_30/13_46_58/checkpoint_11400.pth.
[2019-05-30 15:42:11,547][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 11500 batch at logs




HBox(children=(IntProgress(value=0, description='Epoch 9:', max=1357, style=ProgressStyle(description_width='i…

[2019-05-30 15:51:32,341][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 12300 batch at logs/2019_05_30/13_46_58/checkpoint_12300.pth.
[2019-05-30 15:52:31,766][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 12400 batch at logs/2019_05_30/13_46_58/checkpoint_12400.pth.
[2019-05-30 15:53:30,391][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 12500 batch at logs/2019_05_30/13_46_58/checkpoint_12500.pth.
[2019-05-30 15:54:28,461][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 12600 batch at logs/2019_05_30/13_46_58/checkpoint_12600.pth.
[2019-05-30 15:55:29,440][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 12700 batch at logs/2019_05_30/13_46_58/checkpoint_12700.pth.
[2019-05-30 15:56:28,761][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 12800 batch at logs/2019_05_30/13_46_58/checkpoint_12800.pth.
[2019-05-30 16:00:28,776][INFO] emmental.logging.checkpointer:102 - Save checkpoint of 13200 batch at logs




[2019-05-30 16:04:45,573][INFO] emmental.logging.checkpointer:188 - Loading the best model from logs/2019_05_30/13_46_58/best_model_WiC_SuperGLUE_val_accuracy.pth.
[2019-05-30 16:04:51,629][INFO] emmental.model:58 - Moving model to GPU (cuda:0).


In [43]:
mtl_model.score(dataloaders["val"])



{'WiC_slice_ind_slice_base/SuperGLUE/val/accuracy': 1.0,
 'WiC_slice_pred_slice_base/SuperGLUE/val/accuracy': 0.7476489028213166,
 'WiC_slice_ind_slice_verb/SuperGLUE/val/accuracy': 0.9968652037617555,
 'WiC_slice_pred_slice_verb/SuperGLUE/val/accuracy': 0.7283950617283951,
 'WiC/SuperGLUE/val/accuracy': 0.7476489028213166}

In [44]:
mtl_model.score(dataloaders["train"])



{'WiC_slice_ind_slice_base/SuperGLUE/train/accuracy': 1.0,
 'WiC_slice_pred_slice_base/SuperGLUE/train/accuracy': 0.9786293294030951,
 'WiC_slice_ind_slice_verb/SuperGLUE/train/accuracy': 1.0,
 'WiC_slice_pred_slice_verb/SuperGLUE/train/accuracy': 0.9787395596051632,
 'WiC/SuperGLUE/train/accuracy': 0.9784450994841563}

In [45]:
mtl_model.score(dataloaders["val"])



{'WiC_slice_ind_slice_base/SuperGLUE/val/accuracy': 1.0,
 'WiC_slice_pred_slice_base/SuperGLUE/val/accuracy': 0.7476489028213166,
 'WiC_slice_ind_slice_verb/SuperGLUE/val/accuracy': 0.9968652037617555,
 'WiC_slice_pred_slice_verb/SuperGLUE/val/accuracy': 0.7283950617283951,
 'WiC/SuperGLUE/val/accuracy': 0.7476489028213166}

In [None]:
# mtl_model.module_pool

In [47]:
# mtl_model1 = EmmentalModel()

In [48]:
# mtl_model1.load("logs/2019_05_20/17_58_01/best_model_WiC_SuperGLUE_val_accuracy.pth")

In [49]:
# mtl_model.score(dataloaders["val"])

In [50]:
# res_dict = mtl_model.predict(dataloaders["val"], return_preds=True)

In [51]:
# res_dict_golds, res_dict_probs, res_dict_preds = res_dict

In [52]:
# res_dict_golds["WiC"]

In [53]:
# res_dict_probs["WiC"]

In [54]:
# res_dict_preds["WiC"]

In [55]:
# tot = 0
# for g, p in zip(res_dict_golds["WiC"], res_dict_preds["WiC"]):
#     if g == p:
#         tot += 1
# print(tot/len(res_dict_golds["WiC"]))

In [56]:
# dataset = dataloaders["val"].dataset

# for idx in range(len(dataloaders["val"].dataset)):
#     if res_dict_golds["WiC"][idx] != res_dict_preds["WiC"][idx]:
#         print("####".join([
#             str(idx),
#             "True" if res_dict_golds["WiC"][idx] == 1 else "False",
#             dataset.X_dict["words"][idx],
#             dataset.X_dict["poses"][idx],
#             dataset.X_dict["sent1"][idx],
#             str(dataset.X_dict["sent1_idxs"][idx].item()),
#             dataset.X_dict["sent2"][idx],
#             str(dataset.X_dict["sent2_idxs"][idx].item()),
#         ])
#         )
# #     print(dataloaders["val"].dataset.Y_dict["labels"][idx].item())