In [1]:
# Automatically reload imported local modules (parse_WSC, task_config,
# modules.bert_module) before each cell runs, so edits to them take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from functools import partial

import torch.nn.functional as F
from torch import nn
import torch
import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from parse_WSC import get_WSC_dataloaders
from task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING
from sklearn.metrics import f1_score

In [3]:
# Notebook-level logger, named after the running module.
logger = logging.getLogger(name=__name__)

# Initialize Emmental

In [4]:
# Emmental configuration for this run, assembled section by section.
# Learner: 10 epochs of Adam (lr 1e-5) with a linear LR schedule,
# validating on the "val" split.
learner_cfg = {
    "n_epochs": 10,
    "valid_split": "val",
    "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
    "lr_scheduler_config": {"lr_scheduler": "linear", "min_lr": 1e-7},
}
# Logging: evaluate and checkpoint every epoch, keeping the checkpoint with
# the highest WSC validation accuracy.
logging_cfg = {
    "counter_unit": "epoch",
    "evaluation_freq": 1,
    "checkpointing": True,
    "checkpointer_config": {
        "checkpoint_metric": {"WSC/SuperGLUE/val/accuracy": "max"},
        "checkpoint_freq": 1,
    },
}
emmental.init(
    "logs",
    config={
        "meta_config": {"seed": 111},
        # GPU index 1 (logged as cuda:1), without DataParallel.
        "model_config": {"device": 1, "dataparallel": False},
        "learner_config": learner_cfg,
        "logging_config": logging_cfg,
    },
)

[2019-05-29 04:39:39,809][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_05_29/04_39_39
[2019-05-29 04:39:39,826][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch1/senwu/mmtl/emmental/src/emmental/emmental-default-config.yaml.
[2019-05-29 04:39:39,827][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [5]:
# Show the effective config: Emmental's defaults merged with the overrides
# passed to emmental.init above.
# NOTE(review): the saved output shows 'lr_scheduler': None although "linear"
# is requested in the init cell — the output may be stale; re-run to confirm.
Meta.config

{'meta_config': {'seed': 111, 'verbose': True, 'log_path': None},
 'model_config': {'model_path': None, 'device': 1, 'dataparallel': False},
 'learner_config': {'fp16': False,
  'n_epochs': 10,
  'train_split': 'train',
  'valid_split': 'val',
  'test_split': 'test',
  'ignore_index': 0,
  'optimizer_config': {'optimizer': 'adam',
   'lr': 1e-05,
   'l2': 0.0,
   'grad_clip': 5.0,
   'sgd_config': {'momentum': 0.9},
   'adam_config': {'betas': (0.9, 0.999)}},
  'lr_scheduler_config': {'lr_scheduler': None,
   'warmup_steps': None,
   'warmup_unit': 'batch',
   'warmup_percentage': None,
   'min_lr': 1e-07,
   'linear_config': {'min_lr': 0.0},
   'exponential_config': {'gamma': 0.9},
   'plateau_config': {'factor': 0.5, 'patience': 10, 'threshold': 0.0001}},
  'task_scheduler': 'round_robin',
  'global_evaluation_metric_dict': None},
 'logging_config': {'counter_unit': 'epoch',
  'evaluation_freq': 1,
  'writer_config': {'writer': 'tensorboard', 'verbose': True},
  'checkpointing': True

In [6]:
TASK_NAME = "WSC"
DATA_DIR = "data"
BERT_MODEL_NAME = "bert-large-cased"
BATCH_SIZE = 4

# Hidden size of the chosen BERT encoder: 768 for "*-base" models, 1024 otherwise.
if "base" in BERT_MODEL_NAME:
    BERT_OUTPUT_DIM = 768
else:
    BERT_OUTPUT_DIM = 1024

# Number of label classes for the task; a task with no label mapping gets 1.
_task_label_mapping = SuperGLUE_LABEL_MAPPING[TASK_NAME]
if _task_label_mapping is None:
    TASK_CARDINALITY = 1
else:
    TASK_CARDINALITY = len(_task_label_mapping.keys())

# Load train/val/test dataloaders from file

In [7]:
# Maximum number of tokens per example passed to the loader.
MAX_SEQUENCE_LENGTH = 256

# One dataloader per split, tokenized with the BERT tokenizer. Later cells index
# the result by split name (dataloaders["val"], dataloaders["train"]).
dataloaders = get_WSC_dataloaders(
    splits=["train", "val", "test"],
    task_name=TASK_NAME,
    data_dir=DATA_DIR,
    tokenizer_name=BERT_MODEL_NAME,
    max_sequence_length=MAX_SEQUENCE_LENGTH,
    max_data_samples=None,
    batch_size=BATCH_SIZE,
)

[2019-05-29 04:39:39,959][INFO] tokenizer:8 - Loading Tokenizer bert-large-cased
[2019-05-29 04:39:40,236][INFO] pytorch_pretrained_bert.tokenization:190 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt from cache at /lfs/local/0/senwu/.pytorch_pretrained_bert/cee054f6aafe5e2cf816d2228704e326446785f940f5451a5b26033516a4ac3d.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
[2019-05-29 04:39:40,322][INFO] parse_WSC:128 - Loaded train for WSC. True
[2019-05-29 04:39:40,333][INFO] parse_WSC:128 - Loaded val for WSC. False
[2019-05-29 04:39:40,347][INFO] parse_WSC:128 - Loaded test for WSC. False


data/WSC/train.jsonl.retokenized.bert-large-cased
{'idx': 0, 'label': False, 'target': {'span2_index': 13, 'span1_index': 0, 'span1_text': 'Mark', 'span2_text': 'He', 'span1': [0, 1], 'span2': [15, 16]}, 'text': 'Mark told Pete many lies about himself , which Pete included in his book . He should have been more skeptical .'}
max len 114
data/WSC/val.jsonl.retokenized.bert-large-cased
{'idx': 0, 'label': False, 'target': {'span2_index': 47, 'span1_index': 32, 'span1_text': 'anyone', 'span2_text': 'him', 'span1': [37, 38], 'span2': [52, 53]}, 'text': 'Bernard , who had not told the government official that he was less than 21 when he filed for a homestead claim , did not consider that he had done anything dish ##ones ##t . Still , anyone who knew that he was 19 years old could take his claim away from him .'}
max len 65
data/WSC/test.jsonl.retokenized.bert-large-cased
{'idx': 0, 'target': {'span1_text': 'Maude and Dora', 'span1_index': 0, 'span2_text': 'they', 'span2_index': 40, 'span1':

# Build Emmental task

In [9]:
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss for the task's prediction head.

    Args:
        task_name: task whose ``{task_name}_pred_head`` output is scored.
        immediate_ouput_dict: per-module outputs; element 0 of the head's
            entry holds the logits.
        Y: gold labels. They are shifted down by one before scoring, so they
            are expected to be 1-indexed.
        active: mask/index selecting which examples contribute to the loss.

    Returns:
        Mean cross-entropy over the selected examples.
    """
    head_logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
    zero_based_labels = Y.view(-1) - 1
    return F.cross_entropy(head_logits[active], zero_based_labels[active])

In [10]:
def output(task_name, immediate_ouput_dict):
    """Class probabilities for the task's prediction head.

    Applies a softmax over dim 1 of element 0 of the head's output (the logits).
    """
    head_logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
    return F.softmax(head_logits, dim=1)

In [11]:
from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor


class SpanClassifierModule(nn.Module):
    """Classify a pair of token spans from contextual token embeddings.

    For each span: a Conv1d projection of ``feature[-1]``, then a
    SelfAttentiveSpanExtractor that pools the span's tokens into one vector.
    The pooled span vectors are concatenated and fed to a linear classifier.

    Args:
        d_inp: dimensionality of the input token embeddings.
        proj_dim: output channels of each projection (and the input size of
            each span extractor).
        num_spans: number of spans pooled and concatenated; must be 1 or 2
            because forward() receives exactly two span index tensors.
        cnn_context: tokens of context on each side seen by the projection;
            0 gives kernel_size=1, i.e. a per-token projection.
        n_classes: number of output logits.
    """

    def _make_span_extractor(self):
        """Create a self-attentive span pooler over ``proj_dim`` features."""
        return SelfAttentiveSpanExtractor(self.proj_dim)

    def _make_cnn_layer(self, d_inp):
        """
        Make a CNN layer as a projection of local context.
        CNN maps [batch_size, max_len, d_inp]
        to [batch_size, max_len, proj_dim] with no change in length.
        """
        k = 1 + 2 * self.cnn_context
        padding = self.cnn_context
        return nn.Conv1d(
            d_inp,
            self.proj_dim,
            kernel_size=k,
            stride=1,
            padding=padding,
            dilation=1,
            groups=1,
            bias=True,
        )

    def __init__(
        self,
        d_inp=BERT_OUTPUT_DIM,
        proj_dim=BERT_OUTPUT_DIM // 2,
        num_spans=2,
        cnn_context=0,
        n_classes=2,
    ):
        super().__init__()

        # forward() only ever supplies two spans, and the classifier sizing
        # below needs at least one extractor; fail early with a clear message.
        if not 1 <= num_spans <= 2:
            raise ValueError(
                f"num_spans must be 1 or 2 (forward() receives two spans), got {num_spans}"
            )

        self.cnn_context = cnn_context
        self.num_spans = num_spans
        self.proj_dim = proj_dim

        # One word-level projection per span, so each span has its own weights.
        self.projs = nn.ModuleList(
            [self._make_cnn_layer(d_inp) for _ in range(num_spans)]
        )

        # Lee's self-pooling operator (https://arxiv.org/abs/1812.10860)
        self.span_extractors = nn.ModuleList(
            [self._make_span_extractor() for _ in range(num_spans)]
        )

        # Classifier gets concatenated projections of spans. All extractors are
        # built identically, so the first one's output dim is representative;
        # indexing [1] here would fail when num_spans == 1.
        clf_input_dim = self.span_extractors[0].get_output_dim() * num_spans
        self.classifier = nn.Linear(clf_input_dim, n_classes)

    def forward(self, feature, span1_idxs, span2_idxs, mask):
        """Return class logits for the (span1, span2) pair.

        Args:
            feature: encoder output; only ``feature[-1]`` is used and it is
                treated as [batch_size, max_len, d_inp] (presumably the last
                BERT layer — TODO confirm against BertModule).
            span1_idxs, span2_idxs: span boundaries per example; each gets a
                num_targets axis of size 1 via ``unsqueeze(1)`` (assumes
                [batch_size, 2] inputs — verify against the dataloader).
            mask: token mask, passed to each span extractor as ``sequence_mask``.

        Returns:
            Logits; the size-1 num_targets axis is squeezed away, giving
            [batch_size, n_classes] under the shape assumptions above.
        """
        # Conv1d expects channels first: [batch_size, d_inp, max_len].
        sent_embs_t = feature[-1].transpose(1, 2)

        # Project once per span, back to [batch_size, max_len, proj_dim].
        se_projs = [
            proj(sent_embs_t).transpose(2, 1).contiguous() for proj in self.projs
        ]

        # NOTE(review): mask reshaping/casting kept exactly as before; confirm the
        # shape SelfAttentiveSpanExtractor expects for sequence_mask.
        sequence_mask = mask.unsqueeze(2).long()
        span_idxs = [span1_idxs.unsqueeze(1), span2_idxs.unsqueeze(1)]

        # spans are [batch_size, num_targets, span_modules]; zip stops after
        # num_spans extractors, matching the original range(num_spans) loop.
        span_embs = torch.cat(
            [
                extractor(se_proj, idxs, sequence_mask=sequence_mask)
                for extractor, se_proj, idxs in zip(
                    self.span_extractors, se_projs, span_idxs
                )
            ],
            dim=2,
        )

        # [batch_size, num_targets, n_classes] -> drop the size-1 num_targets axis.
        logits = self.classifier(span_embs).squeeze(1)

        return logits

In [12]:
# Single-task setup: BERT encodes the tokens, then the span-pair head classifies
# from the encoder output ("input", 0) plus both span index tensors and the mask.
emmental_task = EmmentalTask(
    name=TASK_NAME,
    module_pool=nn.ModuleDict(
        {
            "bert_module": BertModule(BERT_MODEL_NAME),
            # Size the head from the config cell rather than relying on the
            # class defaults, which hard-code n_classes=2 and freeze
            # BERT_OUTPUT_DIM at the time the class cell was run.
            f"{TASK_NAME}_pred_head": SpanClassifierModule(
                d_inp=BERT_OUTPUT_DIM,
                proj_dim=BERT_OUTPUT_DIM // 2,
                n_classes=TASK_CARDINALITY,
            ),
        }
    ),
    task_flow=[
        {
            "name": "input",
            "module": "bert_module",
            "inputs": [
                ("_input_", "token_ids"),
                ("_input_", "token_segments"),
                ("_input_", "token_masks"),
            ],
        },
        {
            "name": f"{TASK_NAME}_pred_head",
            "module": f"{TASK_NAME}_pred_head",
            "inputs": [
                ("input", 0),
                ("_input_", "span1_idxs"),
                ("_input_", "span2_idxs"),
                ("_input_", "token_masks"),
            ],
        },
    ],
    loss_func=partial(ce_loss, TASK_NAME),
    output_func=partial(output, TASK_NAME),
    scorer=Scorer(metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]),
)

[2019-05-29 04:39:41,715][INFO] pytorch_pretrained_bert.modeling:580 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz from cache at ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233
[2019-05-29 04:39:41,717][INFO] pytorch_pretrained_bert.modeling:588 - extracting archive file ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233 to temp dir /tmp/tmpft5jn_nv
[2019-05-29 04:39:59,223][INFO] pytorch_pretrained_bert.modeling:598 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pooler_fc_size": 768,
  "pooler_num_a

In [13]:
# Wrap the single WSC task in an Emmental model; per the cell log, construction
# also moves the model to the configured GPU.
mtl_model = EmmentalModel(name="SuperGLUE_single_task", tasks=[emmental_task])

[2019-05-29 04:40:30,729][INFO] emmental.model:58 - Moving model to GPU (cuda:1).
[2019-05-29 04:40:36,681][INFO] emmental.model:44 - Created emmental model SuperGLUE_single_task that contains task {'WSC'}.
[2019-05-29 04:40:36,684][INFO] emmental.model:58 - Moving model to GPU (cuda:1).


In [14]:
# Training driver; presumably reads the learner/logging settings set via
# emmental.init (the training log shows per-epoch evaluation and checkpointing).
emmental_learner = EmmentalLearner()

In [15]:
# Train for the configured epochs. All split dataloaders are passed; the learner
# presumably selects them via train_split / valid_split in Meta.config. Per the
# log, the best validation-accuracy checkpoint is reloaded at the end.
emmental_learner.learn(mtl_model, dataloaders.values())

[2019-05-29 04:40:36,798][INFO] emmental.logging.logging_manager:33 - Evaluating every 1 epoch.
[2019-05-29 04:40:36,799][INFO] emmental.logging.logging_manager:43 - Checkpointing every 1 epoch.
[2019-05-29 04:40:36,801][INFO] emmental.logging.checkpointer:42 - Save checkpoints at logs/2019_05_29/04_39_39 every 1 epoch
[2019-05-29 04:40:36,801][INFO] emmental.logging.checkpointer:73 - No checkpoints saved before 0 epoch.
[2019-05-29 04:40:36,848][INFO] root:123 - Generating grammar tables from /usr/lib/python3.6/lib2to3/Grammar.txt
[2019-05-29 04:40:37,009][INFO] root:123 - Generating grammar tables from /usr/lib/python3.6/lib2to3/PatternGrammar.txt
[2019-05-29 04:40:37,042][INFO] emmental.learner:299 - Start learning...


HBox(children=(IntProgress(value=0, description='Epoch 0:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:41:40,334][INFO] emmental.logging.checkpointer:93 - checkpoint_runway condition has been met. Start checkpoining.
  "type " + obj.__name__ + ". It won't be checked "
[2019-05-29 04:41:50,708][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 1.0 epoch at logs/2019_05_29/04_39_39/checkpoint_1.0.pth.
[2019-05-29 04:42:01,037][INFO] emmental.logging.checkpointer:119 - Save best model of metric WSC/SuperGLUE/val/accuracy at logs/2019_05_29/04_39_39/best_model_WSC_SuperGLUE_val_accuracy.pth





HBox(children=(IntProgress(value=0, description='Epoch 1:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:43:12,779][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 2.0 epoch at logs/2019_05_29/04_39_39/checkpoint_2.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 2:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:44:24,336][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 3.0 epoch at logs/2019_05_29/04_39_39/checkpoint_3.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 3:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:45:37,542][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 4.0 epoch at logs/2019_05_29/04_39_39/checkpoint_4.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 4:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:46:55,739][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 5.0 epoch at logs/2019_05_29/04_39_39/checkpoint_5.0.pth.
[2019-05-29 04:47:07,353][INFO] emmental.logging.checkpointer:119 - Save best model of metric WSC/SuperGLUE/val/accuracy at logs/2019_05_29/04_39_39/best_model_WSC_SuperGLUE_val_accuracy.pth





HBox(children=(IntProgress(value=0, description='Epoch 5:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:48:19,917][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 6.0 epoch at logs/2019_05_29/04_39_39/checkpoint_6.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 6:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:49:33,174][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 7.0 epoch at logs/2019_05_29/04_39_39/checkpoint_7.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 7:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:50:47,215][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 8.0 epoch at logs/2019_05_29/04_39_39/checkpoint_8.0.pth.





HBox(children=(IntProgress(value=0, description='Epoch 8:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:52:00,143][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 9.0 epoch at logs/2019_05_29/04_39_39/checkpoint_9.0.pth.
[2019-05-29 04:52:13,020][INFO] emmental.logging.checkpointer:119 - Save best model of metric WSC/SuperGLUE/val/accuracy at logs/2019_05_29/04_39_39/best_model_WSC_SuperGLUE_val_accuracy.pth





HBox(children=(IntProgress(value=0, description='Epoch 9:', max=139, style=ProgressStyle(description_width='in…

[2019-05-29 04:53:26,605][INFO] emmental.logging.checkpointer:103 - Save checkpoint of 10.0 epoch at logs/2019_05_29/04_39_39/checkpoint_10.0.pth.
[2019-05-29 04:53:38,338][INFO] emmental.logging.checkpointer:119 - Save best model of metric WSC/SuperGLUE/val/accuracy at logs/2019_05_29/04_39_39/best_model_WSC_SuperGLUE_val_accuracy.pth
[2019-05-29 04:53:38,394][INFO] emmental.logging.checkpointer:149 - Clear all immediate checkpoints.





[2019-05-29 04:53:42,480][INFO] emmental.logging.checkpointer:189 - Loading the best model from logs/2019_05_29/04_39_39/best_model_WSC_SuperGLUE_val_accuracy.pth.
[2019-05-29 04:53:45,349][INFO] emmental.model:58 - Moving model to GPU (cuda:1).


In [16]:
# Validation accuracy of the reloaded best checkpoint.
mtl_model.score(dataloaders["val"])

{'WSC/SuperGLUE/val/accuracy': 0.6730769230769231}

In [17]:
# Training accuracy of the same checkpoint; compare with the validation score
# to gauge overfitting.
mtl_model.score(dataloaders["train"])

{'WSC/SuperGLUE/train/accuracy': 0.9007220216606499}