In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from functools import partial

import numpy as np
import torch.nn.functional as F
from torch import nn
import torch
import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from parse_WiC_slice import get_WiC_dataloaders
from task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING
from sklearn.metrics import f1_score

In [3]:
logger = logging.getLogger(__name__)

# Initalize Emmental

In [4]:
emmental.init(
    "logs",
    config={
        "model_config": {"device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 10,
            "valid_split": "val",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "min_lr": 0,
            "lr_scheduler_config": {"warmup_percentage": 0.1, "lr_scheduler": None},
        },
        "logging_config": {
            "counter_unit": "batch",
            "evaluation_freq": 100,
            "checkpointing": True,
            "checkpointer_config": {"checkpoint_metric": {"WiC/SuperGLUE/val/accuracy":"max"}},
        },
    },
)

05/22/2019 15:23:57 - INFO - emmental.meta -   Setting logging directory to: logs/2019_05_22/15_23_57
05/22/2019 15:23:57 - INFO - emmental.meta -   Loading Emmental default config from /dfs/scratch1/senwu/mmtl/emmental/src/emmental/emmental-default-config.yaml.
05/22/2019 15:23:57 - INFO - emmental.meta -   Updating Emmental config from user provided config.


In [5]:
Meta.config

{'meta_config': {'seed': 0, 'verbose': True, 'log_path': None},
 'model_config': {'model_path': None, 'device': 0, 'dataparallel': False},
 'learner_config': {'fp16': False,
  'n_epochs': 10,
  'train_split': 'train',
  'valid_split': 'val',
  'test_split': 'test',
  'ignore_index': -100,
  'optimizer_config': {'optimizer': 'adam',
   'lr': 1e-05,
   'l2': 0.0,
   'grad_clip': 1.0,
   'sgd_config': {'momentum': 0.9},
   'adam_config': {'betas': (0.9, 0.999)}},
  'lr_scheduler_config': {'lr_scheduler': None,
   'warmup_steps': None,
   'warmup_unit': 'batch',
   'warmup_percentage': 0.1,
   'min_lr': 0.0,
   'linear_config': {'min_lr': 0.0},
   'exponential_config': {'gamma': 0.9},
   'plateau_config': {'factor': 0.5, 'patience': 10, 'threshold': 0.0001}},
  'task_scheduler': 'round_robin',
  'global_evaluation_metric_dict': None,
  'min_lr': 0},
 'logging_config': {'counter_unit': 'batch',
  'evaluation_freq': 100,
  'writer_config': {'writer': 'tensorboard', 'verbose': True},
  'check

In [6]:
TASK_NAME = "WiC"
DATA_DIR = "data"
BERT_MODEL_NAME = "bert-large-cased"
BATCH_SIZE = 4

BERT_OUTPUT_DIM = 768 if "base" in BERT_MODEL_NAME else 1024
TASK_CARDINALITY = (
    len(SuperGLUE_LABEL_MAPPING[TASK_NAME].keys())
    if SuperGLUE_LABEL_MAPPING[TASK_NAME] is not None
    else 1
)

In [7]:
BERT_OUTPUT_DIM, TASK_CARDINALITY

(1024, 2)

# Extract train/dev dataset from file

In [8]:
def slice_verb(dataset):
    slice_name = "slice_verb"
    ind, pred = [], []
    cnt = 0
    for idx, pos in enumerate(dataset.X_dict["poses"]):
        if pos == "V":
            ind.append(1)
            pred.append(dataset.Y_dict["labels"][idx])
            cnt += 1
        else:
            ind.append(2)
            pred.append(Meta.config["learner_config"]["ignore_index"])
    ind = torch.from_numpy(np.array(ind)).view(-1)
    pred = torch.from_numpy(np.array(pred)).view(-1)
    logger.info(f"Total {cnt} / {len(dataset)} in the slice {slice_name}")
    print(ind.size(), pred.size())
    return ind, pred

In [9]:
def slice_noun(dataset):
    slice_name = "slice_noun"
    ind, pred = [], []
    cnt = 0
    for idx, pos in enumerate(dataset.X_dict["poses"]):
        if pos == "N":
            ind.append(1)
            pred.append(dataset.Y_dict["labels"][idx])
            cnt += 1
        else:
            ind.append(2)
            pred.append(Meta.config["learner_config"]["ignore_index"])
    ind = torch.from_numpy(np.array(ind)).view(-1)
    pred = torch.from_numpy(np.array(pred)).view(-1)
    logger.info(f"Total {cnt} / {len(dataset)} in the slice {slice_name}")
    print(ind.size(), pred.size())
    return ind, pred

In [10]:
def slice_first_pos(dataset):
    slice_name = "slice_first_pos"
    ind, pred = [], []
    cnt = 0
    for idx, pos in enumerate(dataset.X_dict["poses"]):
        if dataset.X_dict["sent1_ori_idxs"][idx] == 1 and dataset.X_dict["sent2_ori_idxs"][idx] == 1:
            ind.append(1)
            pred.append(dataset.Y_dict["labels"][idx])
            cnt += 1
        else:
            ind.append(2)
            pred.append(Meta.config["learner_config"]["ignore_index"])
    ind = torch.from_numpy(np.array(ind)).view(-1)
    pred = torch.from_numpy(np.array(pred)).view(-1)
    logger.info(f"Total {cnt} / {len(dataset)} in the slice {slice_name}")
    print(ind.size(), pred.size())
    return ind, pred

In [11]:
def slice_base(dataset):
    return torch.from_numpy(np.array([1] * len(dataset))), dataset.Y_dict["labels"]

In [12]:
slice_func_dict = {
    "slice_base": slice_base,
    "slice_verb": slice_verb,
    "slice_noun": slice_noun,
    "slice_first_pos": slice_first_pos,
}
slice_func_dict.keys()

dict_keys(['slice_base', 'slice_verb', 'slice_noun', 'slice_first_pos'])

In [13]:
dataloaders = get_WiC_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=["train", "val", "test"],
    max_sequence_length=128,
    max_data_samples=None,
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
    slice_func_dict=slice_func_dict,
)

05/22/2019 15:23:59 - INFO - tokenizer -   Loading Tokenizer bert-large-cased
05/22/2019 15:23:59 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/cee054f6aafe5e2cf816d2228704e326446785f940f5451a5b26033516a4ac3d.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
05/22/2019 15:24:03 - INFO - __main__ -   Total 2634 / 5428 in the slice slice_verb
05/22/2019 15:24:03 - INFO - __main__ -   Total 2794 / 5428 in the slice slice_noun


max len 68
torch.Size([5428]) torch.Size([5428])
torch.Size([5428]) torch.Size([5428])


05/22/2019 15:24:03 - INFO - __main__ -   Total 869 / 5428 in the slice slice_first_pos
05/22/2019 15:24:03 - INFO - parse_WiC_slice -   Loaded train for WiC.


torch.Size([5428]) torch.Size([5428])


05/22/2019 15:24:03 - INFO - __main__ -   Total 243 / 638 in the slice slice_verb
05/22/2019 15:24:03 - INFO - __main__ -   Total 395 / 638 in the slice slice_noun
05/22/2019 15:24:03 - INFO - __main__ -   Total 129 / 638 in the slice slice_first_pos
05/22/2019 15:24:03 - INFO - parse_WiC_slice -   Loaded val for WiC.


max len 57
torch.Size([638]) torch.Size([638])
torch.Size([638]) torch.Size([638])
torch.Size([638]) torch.Size([638])


05/22/2019 15:24:04 - INFO - __main__ -   Total 569 / 1400 in the slice slice_verb
05/22/2019 15:24:04 - INFO - __main__ -   Total 831 / 1400 in the slice slice_noun
05/22/2019 15:24:04 - INFO - __main__ -   Total 240 / 1400 in the slice slice_first_pos
05/22/2019 15:24:04 - INFO - parse_WiC_slice -   Loaded test for WiC.


max len 60
torch.Size([1400]) torch.Size([1400])
torch.Size([1400]) torch.Size([1400])
torch.Size([1400]) torch.Size([1400])


In [14]:
dataloaders["train"].dataset.Y_dict

{'labels': tensor([2, 2, 2,  ..., 1, 1, 1]),
 'WiC_slice_ind_slice_base': tensor([1, 1, 1,  ..., 1, 1, 1]),
 'WiC_slice_pred_slice_base': tensor([2, 2, 2,  ..., 1, 1, 1]),
 'WiC_slice_ind_slice_verb': tensor([1, 1, 1,  ..., 1, 1, 2]),
 'WiC_slice_pred_slice_verb': tensor([   2,    2,    2,  ...,    1,    1, -100]),
 'WiC_slice_ind_slice_noun': tensor([2, 2, 2,  ..., 2, 2, 1]),
 'WiC_slice_pred_slice_noun': tensor([-100, -100, -100,  ..., -100, -100,    1]),
 'WiC_slice_ind_slice_first_pos': tensor([2, 2, 2,  ..., 2, 2, 1]),
 'WiC_slice_pred_slice_first_pos': tensor([-100, -100, -100,  ..., -100, -100,    1])}

In [15]:
dataloaders["train"].task_to_label_dict

{'WiC_slice_ind_slice_base': 'WiC_slice_ind_slice_base',
 'WiC_slice_pred_slice_base': 'WiC_slice_pred_slice_base',
 'WiC_slice_ind_slice_verb': 'WiC_slice_ind_slice_verb',
 'WiC_slice_pred_slice_verb': 'WiC_slice_pred_slice_verb',
 'WiC_slice_ind_slice_noun': 'WiC_slice_ind_slice_noun',
 'WiC_slice_pred_slice_noun': 'WiC_slice_pred_slice_noun',
 'WiC_slice_ind_slice_first_pos': 'WiC_slice_ind_slice_first_pos',
 'WiC_slice_pred_slice_first_pos': 'WiC_slice_pred_slice_first_pos',
 'WiC': 'labels'}

In [16]:
for key, value in dataloaders["train"].dataset.Y_dict.items():
    print(key, value.size())

labels torch.Size([5428])
WiC_slice_ind_slice_base torch.Size([5428])
WiC_slice_pred_slice_base torch.Size([5428])
WiC_slice_ind_slice_verb torch.Size([5428])
WiC_slice_pred_slice_verb torch.Size([5428])
WiC_slice_ind_slice_noun torch.Size([5428])
WiC_slice_pred_slice_noun torch.Size([5428])
WiC_slice_ind_slice_first_pos torch.Size([5428])
WiC_slice_pred_slice_first_pos torch.Size([5428])


# Build Emmental task

In [17]:
def ce_loss(module_name, immediate_ouput_dict, Y, active):
    return F.cross_entropy(
        immediate_ouput_dict[module_name][0][active], (Y.view(-1) - 1)[active]
    )

In [18]:
def output(module_name, immediate_ouput_dict):
    return F.softmax(immediate_ouput_dict[module_name][0], dim=1)

In [19]:
class FeatureConcateModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, feature, idx1, idx2):
#         import pdb; pdb.set_trace()
        last_layer = feature[-1]
        emb = last_layer[:,0,:]
        idx1 = idx1.unsqueeze(-1).unsqueeze(-1).expand([-1, -1, last_layer.size(-1)])
        idx2 = idx2.unsqueeze(-1).unsqueeze(-1).expand([-1, -1, last_layer.size(-1)])
        word1_emb = last_layer.gather(dim=1, index=idx1).squeeze(dim=1)
        word2_emb = last_layer.gather(dim=1, index=idx2).squeeze(dim=1)
        input = torch.cat([emb, word1_emb, word2_emb], dim=-1)
        return input

In [20]:
class SliceModule(nn.Module):
    def __init__(self, feature_dim, class_cardinality):
        super().__init__()
        self.linear = nn.Linear(feature_dim, class_cardinality)

    def forward(self, feature):
        return self.linear.forward(feature)

In [21]:
H = BERT_OUTPUT_DIM

In [22]:
shared_classification_module = nn.Linear(H, TASK_CARDINALITY)

In [23]:
bert_module = BertModule(BERT_MODEL_NAME)

05/22/2019 15:24:05 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz from cache at ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233
05/22/2019 15:24:05 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233 to temp dir /tmp/tmpumrkv2e0
05/22/2019 15:24:22 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads"

In [24]:
tasks = []

In [25]:
# Add ind task

type = "ind"

for slice_name in slice_func_dict.keys():
    task = EmmentalTask(
        name=f"{TASK_NAME}_slice_{type}_{slice_name}",
        module_pool=nn.ModuleDict(
            {
#                 "bert_module": bert_module,
                "feature": FeatureConcateModule(),
                f"{TASK_NAME}_slice_{type}_{slice_name}_head": SliceModule(
                    3 * BERT_OUTPUT_DIM, 2
                ),
            }
        ),
        task_flow=[
            {
                "name": "input",
                "module": "bert_module",
                "inputs": [("_input_", "token_ids"), ("_input_", "token_segments")],
            },
            {
                "name": f"feature",
                "module": f"feature",
                "inputs": [
                    ("input", 0),
                    ("_input_", "sent1_idxs"),
                    ("_input_", "sent2_idxs"),
                ],
            },
            {
                "name": f"{TASK_NAME}_slice_{type}_{slice_name}_head",
                "module": f"{TASK_NAME}_slice_{type}_{slice_name}_head",
                "inputs": [("feature", 0)],
            },
        ],
        loss_func=partial(ce_loss, f"{TASK_NAME}_slice_{type}_{slice_name}_head"),
        output_func=partial(output, f"{TASK_NAME}_slice_{type}_{slice_name}_head"),
        scorer=Scorer(metrics=["accuracy"]),
    )
    tasks.append(task)

05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC_slice_ind_slice_base
05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC_slice_ind_slice_verb
05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC_slice_ind_slice_noun
05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC_slice_ind_slice_first_pos


In [26]:
# Add ind task

type = "pred"

for slice_name in slice_func_dict.keys():
    task = EmmentalTask(
        name=f"{TASK_NAME}_slice_{type}_{slice_name}",
        module_pool=nn.ModuleDict(
            {
#                 "bert_module": bert_module,
                "feature": FeatureConcateModule(),
                f"{TASK_NAME}_slice_feat_{slice_name}": nn.Linear(3 * BERT_OUTPUT_DIM, H),
                f"{TASK_NAME}_slice_{type}_{slice_name}_head": shared_classification_module,
            }
        ),
        task_flow=[
            {
                "name": "input",
                "module": "bert_module",
                "inputs": [("_input_", "token_ids"), ("_input_", "token_segments")],
            },
            {
                "name": f"feature",
                "module": f"feature",
                "inputs": [
                    ("input", 0),
                    ("_input_", "sent1_idxs"),
                    ("_input_", "sent2_idxs"),
                ],
            },
            {
                "name": f"{TASK_NAME}_slice_feat_{slice_name}",
                "module": f"{TASK_NAME}_slice_feat_{slice_name}",
                "inputs": [("feature", 0)],
            },
            {
                "name": f"{TASK_NAME}_slice_{type}_{slice_name}_head",
                "module": f"{TASK_NAME}_slice_{type}_{slice_name}_head",
                "inputs": [(f"{TASK_NAME}_slice_feat_{slice_name}", 0)],
            },
        ],
        loss_func=partial(ce_loss, f"{TASK_NAME}_slice_{type}_{slice_name}_head"),
        output_func=partial(output, f"{TASK_NAME}_slice_{type}_{slice_name}_head"),
        scorer=Scorer(metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]),
    )
    tasks.append(task)

05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC_slice_pred_slice_base
05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC_slice_pred_slice_verb
05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC_slice_pred_slice_noun
05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC_slice_pred_slice_first_pos


In [27]:
class MasterModule(nn.Module):
    def __init__(self, feature_dim, class_cardinality):
        super().__init__()
        self.linear = nn.Linear(feature_dim, class_cardinality)

    def forward(self, immediate_ouput_dict):
        slice_ind_names = sorted(
            [
                flow_name
                for flow_name in immediate_ouput_dict.keys()
                if "_slice_ind_" in flow_name
            ]
        )
        slice_pred_names = sorted(
            [
                flow_name
                for flow_name in immediate_ouput_dict.keys()
                if "_slice_pred_" in flow_name
            ]
        )

        Q = torch.cat(
            [
                F.softmax(immediate_ouput_dict[slice_ind_name][0])[:, 0].unsqueeze(1)
                for slice_ind_name in slice_ind_names
            ],
            dim=-1,
        )
        P = torch.cat(
            [
                F.softmax(immediate_ouput_dict[slice_pred_name][0])[:, 0].unsqueeze(1)
                for slice_pred_name in slice_pred_names
            ],
            dim=-1,
        )

        slice_feat_names = sorted(
            [
                flow_name
                for flow_name in immediate_ouput_dict.keys()
                if "_slice_feat_" in flow_name
            ]
        )

        slice_reps = torch.cat(
            [
                immediate_ouput_dict[slice_feat_name][0].unsqueeze(1)
                for slice_feat_name in slice_feat_names
            ],
            dim=1,
        )

        A = F.softmax(Q * P, dim=1).unsqueeze(-1).expand([-1, -1, slice_reps.size(-1)])

        reweighted_rep = torch.sum(A * slice_reps, 1)

        return self.linear.forward(reweighted_rep)

In [28]:
master_task = EmmentalTask(
    name=f"{TASK_NAME}",
    module_pool=nn.ModuleDict(
        {
            "bert_module": bert_module,
            f"{TASK_NAME}_pred_head": MasterModule(H, TASK_CARDINALITY),
        }
    ),
    task_flow=[
        {
            "name": f"{TASK_NAME}_pred_head",
            "module": f"{TASK_NAME}_pred_head",
            "inputs": [],
        }
    ],
    loss_func=partial(ce_loss, f"{TASK_NAME}_pred_head"),
    output_func=partial(output, f"{TASK_NAME}_pred_head"),
    scorer=Scorer(metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]),
)
tasks.append(master_task)

05/22/2019 15:24:54 - INFO - emmental.task -   Created task: WiC


In [29]:
mtl_model = EmmentalModel(name="SuperGLUE_single_task", tasks=tasks)

05/22/2019 15:24:54 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:24:59 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:24:59 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:24:59 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:24:59 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:24:59 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:24:59 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:24:59 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:24:59 - INFO - emmental.model -   Moving model to GPU (cuda:0).
05/22/2019 15:25:00 - INFO - emmental.model -   Created emmental model SuperGLUE_single_task that contains task {'WiC_slice_pred_slice_base', 'WiC_slice_pred_slice_verb', 'WiC_slice_ind_slice_noun', 'WiC_slice_ind_slice_first_pos', 'WiC_slice_ind_slice_verb', 'WiC_slice_ind_slice_base', 'WiC', 'Wi

In [30]:
emmental_learner = EmmentalLearner()

In [31]:
# for X, Y in dataloaders["train"]:
#     import pdb; pdb.set_trace()
# # #     print(X, Y)
#     pass

In [32]:
emmental_learner.learn(mtl_model, dataloaders.values())

05/22/2019 15:25:00 - INFO - emmental.logging.logging_manager -   Evaluating every 100 batch.
05/22/2019 15:25:00 - INFO - emmental.logging.logging_manager -   Checkpointing every 100 batch.
05/22/2019 15:25:00 - INFO - emmental.logging.checkpointer -   Save checkpoints at logs/2019_05_22/15_23_57 every 100 batch
05/22/2019 15:25:00 - INFO - emmental.logging.checkpointer -   No checkpoints saved before 0 batch.
05/22/2019 15:25:00 - INFO - root -   Generating grammar tables from /usr/lib/python3.6/lib2to3/Grammar.txt
05/22/2019 15:25:00 - INFO - root -   Generating grammar tables from /usr/lib/python3.6/lib2to3/PatternGrammar.txt
05/22/2019 15:25:00 - INFO - emmental.learner -   Warmup 1357 batchs.
05/22/2019 15:25:00 - INFO - emmental.learner -   Start learning...


HBox(children=(IntProgress(value=0, description='Epoch 0:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 15:26:08 - INFO - emmental.logging.checkpointer -   checkpoint_runway condition has been met. Start checkpoining.
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
05/22/2019 15:26:20 - INFO - emmental.logging.checkpointer -   Save checkpoint of 100 batch at logs/2019_05_22/15_23_57/checkpoint_100.pth.
05/22/2019 15:26:33 - INFO - emmental.logging.checkpointer -   Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_22/15_23_57/best_model_WiC_SuperGLUE_val_accuracy.pth
05/22/2019 15:27:53 - INFO - emmental.logging.checkpointer -   Save checkpoint of 200 batch at logs/2019_05_22/15_23_57/checkpoint_200.pth.
05/22/2019 15:28:05 - INFO - emmental.logging.checkpointer -   Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_22/15_23_57/best_model_WiC_SuperGLUE_val_accuracy.pth
05/22/2019 15:29:24 - INFO - emmental.logging.checkpointer -   Save




HBox(children=(IntProgress(value=0, description='Epoch 1:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 15:45:19 - INFO - emmental.logging.checkpointer -   Save checkpoint of 1400 batch at logs/2019_05_22/15_23_57/checkpoint_1400.pth.
05/22/2019 15:45:31 - INFO - emmental.logging.checkpointer -   Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_22/15_23_57/best_model_WiC_SuperGLUE_val_accuracy.pth
05/22/2019 15:46:50 - INFO - emmental.logging.checkpointer -   Save checkpoint of 1500 batch at logs/2019_05_22/15_23_57/checkpoint_1500.pth.
05/22/2019 15:48:10 - INFO - emmental.logging.checkpointer -   Save checkpoint of 1600 batch at logs/2019_05_22/15_23_57/checkpoint_1600.pth.
05/22/2019 15:48:21 - INFO - emmental.logging.checkpointer -   Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_22/15_23_57/best_model_WiC_SuperGLUE_val_accuracy.pth
05/22/2019 15:49:40 - INFO - emmental.logging.checkpointer -   Save checkpoint of 1700 batch at logs/2019_05_22/15_23_57/checkpoint_1700.pth.
05/22/2019 15:51:00 - INFO - emmental.logging.checkpointer -




HBox(children=(IntProgress(value=0, description='Epoch 2:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 16:04:33 - INFO - emmental.logging.checkpointer -   Save checkpoint of 2800 batch at logs/2019_05_22/15_23_57/checkpoint_2800.pth.
05/22/2019 16:06:00 - INFO - emmental.logging.checkpointer -   Save checkpoint of 2900 batch at logs/2019_05_22/15_23_57/checkpoint_2900.pth.
05/22/2019 16:06:18 - INFO - emmental.logging.checkpointer -   Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_22/15_23_57/best_model_WiC_SuperGLUE_val_accuracy.pth
05/22/2019 16:07:47 - INFO - emmental.logging.checkpointer -   Save checkpoint of 3000 batch at logs/2019_05_22/15_23_57/checkpoint_3000.pth.
05/22/2019 16:09:14 - INFO - emmental.logging.checkpointer -   Save checkpoint of 3100 batch at logs/2019_05_22/15_23_57/checkpoint_3100.pth.
05/22/2019 16:10:42 - INFO - emmental.logging.checkpointer -   Save checkpoint of 3200 batch at logs/2019_05_22/15_23_57/checkpoint_3200.pth.
05/22/2019 16:12:12 - INFO - emmental.logging.checkpointer -   Save checkpoint of 3300 batch at logs/201




HBox(children=(IntProgress(value=0, description='Epoch 3:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 16:23:52 - INFO - emmental.logging.checkpointer -   Save checkpoint of 4100 batch at logs/2019_05_22/15_23_57/checkpoint_4100.pth.
05/22/2019 16:24:09 - INFO - emmental.logging.checkpointer -   Save best model of metric WiC/SuperGLUE/val/accuracy at logs/2019_05_22/15_23_57/best_model_WiC_SuperGLUE_val_accuracy.pth
05/22/2019 16:25:39 - INFO - emmental.logging.checkpointer -   Save checkpoint of 4200 batch at logs/2019_05_22/15_23_57/checkpoint_4200.pth.
05/22/2019 16:27:07 - INFO - emmental.logging.checkpointer -   Save checkpoint of 4300 batch at logs/2019_05_22/15_23_57/checkpoint_4300.pth.
05/22/2019 16:28:37 - INFO - emmental.logging.checkpointer -   Save checkpoint of 4400 batch at logs/2019_05_22/15_23_57/checkpoint_4400.pth.
05/22/2019 16:30:06 - INFO - emmental.logging.checkpointer -   Save checkpoint of 4500 batch at logs/2019_05_22/15_23_57/checkpoint_4500.pth.
05/22/2019 16:31:37 - INFO - emmental.logging.checkpointer -   Save checkpoint of 4600 batch at logs/201




HBox(children=(IntProgress(value=0, description='Epoch 4:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 16:45:20 - INFO - emmental.logging.checkpointer -   Save checkpoint of 5500 batch at logs/2019_05_22/15_23_57/checkpoint_5500.pth.
05/22/2019 16:46:47 - INFO - emmental.logging.checkpointer -   Save checkpoint of 5600 batch at logs/2019_05_22/15_23_57/checkpoint_5600.pth.
05/22/2019 16:48:18 - INFO - emmental.logging.checkpointer -   Save checkpoint of 5700 batch at logs/2019_05_22/15_23_57/checkpoint_5700.pth.
05/22/2019 16:49:51 - INFO - emmental.logging.checkpointer -   Save checkpoint of 5800 batch at logs/2019_05_22/15_23_57/checkpoint_5800.pth.
05/22/2019 16:51:24 - INFO - emmental.logging.checkpointer -   Save checkpoint of 5900 batch at logs/2019_05_22/15_23_57/checkpoint_5900.pth.
05/22/2019 16:52:56 - INFO - emmental.logging.checkpointer -   Save checkpoint of 6000 batch at logs/2019_05_22/15_23_57/checkpoint_6000.pth.
05/22/2019 16:54:29 - INFO - emmental.logging.checkpointer -   Save checkpoint of 6100 batch at logs/2019_05_22/15_23_57/checkpoint_6100.pth.
05/22/




HBox(children=(IntProgress(value=0, description='Epoch 5:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 17:05:18 - INFO - emmental.logging.checkpointer -   Save checkpoint of 6800 batch at logs/2019_05_22/15_23_57/checkpoint_6800.pth.
05/22/2019 17:06:51 - INFO - emmental.logging.checkpointer -   Save checkpoint of 6900 batch at logs/2019_05_22/15_23_57/checkpoint_6900.pth.
05/22/2019 17:08:31 - INFO - emmental.logging.checkpointer -   Save checkpoint of 7000 batch at logs/2019_05_22/15_23_57/checkpoint_7000.pth.
05/22/2019 17:10:07 - INFO - emmental.logging.checkpointer -   Save checkpoint of 7100 batch at logs/2019_05_22/15_23_57/checkpoint_7100.pth.
05/22/2019 17:11:43 - INFO - emmental.logging.checkpointer -   Save checkpoint of 7200 batch at logs/2019_05_22/15_23_57/checkpoint_7200.pth.
05/22/2019 17:13:13 - INFO - emmental.logging.checkpointer -   Save checkpoint of 7300 batch at logs/2019_05_22/15_23_57/checkpoint_7300.pth.
05/22/2019 17:14:39 - INFO - emmental.logging.checkpointer -   Save checkpoint of 7400 batch at logs/2019_05_22/15_23_57/checkpoint_7400.pth.
05/22/




HBox(children=(IntProgress(value=0, description='Epoch 6:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 17:25:22 - INFO - emmental.logging.checkpointer -   Save checkpoint of 8200 batch at logs/2019_05_22/15_23_57/checkpoint_8200.pth.
05/22/2019 17:26:40 - INFO - emmental.logging.checkpointer -   Save checkpoint of 8300 batch at logs/2019_05_22/15_23_57/checkpoint_8300.pth.
05/22/2019 17:28:00 - INFO - emmental.logging.checkpointer -   Save checkpoint of 8400 batch at logs/2019_05_22/15_23_57/checkpoint_8400.pth.
05/22/2019 17:29:19 - INFO - emmental.logging.checkpointer -   Save checkpoint of 8500 batch at logs/2019_05_22/15_23_57/checkpoint_8500.pth.
05/22/2019 17:30:38 - INFO - emmental.logging.checkpointer -   Save checkpoint of 8600 batch at logs/2019_05_22/15_23_57/checkpoint_8600.pth.
05/22/2019 17:31:57 - INFO - emmental.logging.checkpointer -   Save checkpoint of 8700 batch at logs/2019_05_22/15_23_57/checkpoint_8700.pth.
05/22/2019 17:33:18 - INFO - emmental.logging.checkpointer -   Save checkpoint of 8800 batch at logs/2019_05_22/15_23_57/checkpoint_8800.pth.
05/22/




HBox(children=(IntProgress(value=0, description='Epoch 7:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 17:42:34 - INFO - emmental.logging.checkpointer -   Save checkpoint of 9500 batch at logs/2019_05_22/15_23_57/checkpoint_9500.pth.
05/22/2019 17:43:52 - INFO - emmental.logging.checkpointer -   Save checkpoint of 9600 batch at logs/2019_05_22/15_23_57/checkpoint_9600.pth.
05/22/2019 17:45:11 - INFO - emmental.logging.checkpointer -   Save checkpoint of 9700 batch at logs/2019_05_22/15_23_57/checkpoint_9700.pth.
05/22/2019 17:46:30 - INFO - emmental.logging.checkpointer -   Save checkpoint of 9800 batch at logs/2019_05_22/15_23_57/checkpoint_9800.pth.
05/22/2019 17:47:49 - INFO - emmental.logging.checkpointer -   Save checkpoint of 9900 batch at logs/2019_05_22/15_23_57/checkpoint_9900.pth.
05/22/2019 17:49:07 - INFO - emmental.logging.checkpointer -   Save checkpoint of 10000 batch at logs/2019_05_22/15_23_57/checkpoint_10000.pth.
05/22/2019 17:50:28 - INFO - emmental.logging.checkpointer -   Save checkpoint of 10100 batch at logs/2019_05_22/15_23_57/checkpoint_10100.pth.
05




HBox(children=(IntProgress(value=0, description='Epoch 8:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 18:00:59 - INFO - emmental.logging.checkpointer -   Save checkpoint of 10900 batch at logs/2019_05_22/15_23_57/checkpoint_10900.pth.
05/22/2019 18:02:19 - INFO - emmental.logging.checkpointer -   Save checkpoint of 11000 batch at logs/2019_05_22/15_23_57/checkpoint_11000.pth.
05/22/2019 18:03:38 - INFO - emmental.logging.checkpointer -   Save checkpoint of 11100 batch at logs/2019_05_22/15_23_57/checkpoint_11100.pth.
05/22/2019 18:04:57 - INFO - emmental.logging.checkpointer -   Save checkpoint of 11200 batch at logs/2019_05_22/15_23_57/checkpoint_11200.pth.
05/22/2019 18:06:16 - INFO - emmental.logging.checkpointer -   Save checkpoint of 11300 batch at logs/2019_05_22/15_23_57/checkpoint_11300.pth.
05/22/2019 18:07:35 - INFO - emmental.logging.checkpointer -   Save checkpoint of 11400 batch at logs/2019_05_22/15_23_57/checkpoint_11400.pth.
05/22/2019 18:08:54 - INFO - emmental.logging.checkpointer -   Save checkpoint of 11500 batch at logs/2019_05_22/15_23_57/checkpoint_115




HBox(children=(IntProgress(value=0, description='Epoch 9:', max=1357, style=ProgressStyle(description_width='i…

05/22/2019 18:19:29 - INFO - emmental.logging.checkpointer -   Save checkpoint of 12300 batch at logs/2019_05_22/15_23_57/checkpoint_12300.pth.
05/22/2019 18:20:47 - INFO - emmental.logging.checkpointer -   Save checkpoint of 12400 batch at logs/2019_05_22/15_23_57/checkpoint_12400.pth.
05/22/2019 18:22:06 - INFO - emmental.logging.checkpointer -   Save checkpoint of 12500 batch at logs/2019_05_22/15_23_57/checkpoint_12500.pth.
05/22/2019 18:23:29 - INFO - emmental.logging.checkpointer -   Save checkpoint of 12600 batch at logs/2019_05_22/15_23_57/checkpoint_12600.pth.
05/22/2019 18:24:49 - INFO - emmental.logging.checkpointer -   Save checkpoint of 12700 batch at logs/2019_05_22/15_23_57/checkpoint_12700.pth.
05/22/2019 18:26:07 - INFO - emmental.logging.checkpointer -   Save checkpoint of 12800 batch at logs/2019_05_22/15_23_57/checkpoint_12800.pth.
05/22/2019 18:27:26 - INFO - emmental.logging.checkpointer -   Save checkpoint of 12900 batch at logs/2019_05_22/15_23_57/checkpoint_129




05/22/2019 18:37:04 - INFO - emmental.logging.checkpointer -   Loading the best model from logs/2019_05_22/15_23_57/best_model_WiC_SuperGLUE_val_accuracy.pth.
05/22/2019 18:37:07 - INFO - emmental.model -   Moving model to GPU (cuda:0).


In [33]:
mtl_model.score(dataloaders["val"])

[autoreload of pytorch_pretrained_bert failed: Traceback (most recent call last):
  File "/lfs/raiders3/0/senwu/.venv_emmental/lib/python3.6/site-packages/IPython/extensions/autoreload.py", line 244, in check
    superreload(m, reload, self.old_objects)
  File "/lfs/raiders3/0/senwu/.venv_emmental/lib/python3.6/site-packages/IPython/extensions/autoreload.py", line 378, in superreload
    module = reload(module)
  File "/lfs/raiders3/0/senwu/.venv_emmental/lib/python3.6/imp.py", line 315, in reload
    return importlib.reload(module)
  File "/lfs/raiders3/0/senwu/.venv_emmental/lib/python3.6/importlib/__init__.py", line 166, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 618, in _exec
  File "<frozen importlib._bootstrap_external>", line 678, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/lfs/raiders3/0/senwu/.venv_emmental/lib/python3.6/site-packages/pytorch_pretrained_bert/__init__.py", 

AttributeError: 'BertLayerNorm' object has no attribute 'weight'

In [None]:
mtl_model.score(dataloaders["train"])

In [None]:
mtl_model.score(dataloaders["val"])

In [None]:
mtl_model.module_pool

In [None]:
# mtl_model1 = EmmentalModel()

In [None]:
# mtl_model1.load("logs/2019_05_20/17_58_01/best_model_WiC_SuperGLUE_val_accuracy.pth")

In [None]:
# mtl_model.score(dataloaders["val"])

In [None]:
# res_dict = mtl_model.predict(dataloaders["val"], return_preds=True)

In [None]:
# res_dict_golds, res_dict_probs, res_dict_preds = res_dict

In [None]:
# res_dict_golds["WiC"]

In [None]:
# res_dict_probs["WiC"]

In [None]:
# res_dict_preds["WiC"]

In [None]:
# tot = 0
# for g, p in zip(res_dict_golds["WiC"], res_dict_preds["WiC"]):
#     if g == p:
#         tot += 1
# print(tot/len(res_dict_golds["WiC"]))

In [None]:
# dataset = dataloaders["val"].dataset

# for idx in range(len(dataloaders["val"].dataset)):
#     if res_dict_golds["WiC"][idx] != res_dict_preds["WiC"][idx]:
#         print("####".join([
#             str(idx),
#             "True" if res_dict_golds["WiC"][idx] == 1 else "False",
#             dataset.X_dict["words"][idx],
#             dataset.X_dict["poses"][idx],
#             dataset.X_dict["sent1"][idx],
#             str(dataset.X_dict["sent1_idxs"][idx].item()),
#             dataset.X_dict["sent2"][idx],
#             str(dataset.X_dict["sent2_idxs"][idx].item()),
#         ])
#         )
# #     print(dataloaders["val"].dataset.Y_dict["labels"][idx].item())