In [1]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import sys
# Put the parent directory on the import path. The project-local imports used
# later (modules, task_config, dataloaders, slicing) are presumably resolved
# from there — confirm against the repository layout.
sys.path.append("..")

In [3]:
import logging
from functools import partial

import emmental
import pandas as pd
import torch
import torch.nn.functional as F
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from sklearn.metrics import f1_score
from torch import nn

from modules.bert_module import BertModule
from task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING

In [4]:
# Module-level logger named after this notebook's __name__.
logger = logging.getLogger(__name__)

# Initialize Emmental

In [5]:
# Initialise Emmental with a "logs" log directory and the overrides below.
emmental.init(
    "logs",
    config={
        "model_config": {"device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 4,
            "valid_split": "val",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "lr_scheduler_config": {
                "warmup_percentage": 0.1,
                "lr_scheduler": None,
                # min_lr is a scheduler setting: the merged config printed by
                # Meta.config lists it under lr_scheduler_config. It used to be
                # passed one level up, where it only added a stray
                # learner_config["min_lr"] key.
                "min_lr": 0.0,
            },
        },
        "logging_config": {
            # Evaluate every 100 batches.
            "counter_unit": "batch",
            "evaluation_freq": 100,
            "checkpointing": True,
            "checkpointer_config": {
                # Checkpoint on the highest WiC validation accuracy.
                "checkpoint_metric": {"WiC/SuperGLUE/val/accuracy":"max"},
                "checkpoint_freq": 1,
            },
        },
    },
)

[2019-06-05 15:01:41,593][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_06_05/15_01_41
[2019-06-05 15:01:41,604][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch0/bradenjh/emmental/src/emmental/emmental-default-config.yaml.
[2019-06-05 15:01:41,604][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [6]:
# Display the merged configuration (defaults plus the overrides passed to emmental.init).
Meta.config

{'meta_config': {'seed': 0, 'verbose': True, 'log_path': None},
 'model_config': {'model_path': None, 'device': 0, 'dataparallel': False},
 'learner_config': {'fp16': False,
  'n_epochs': 4,
  'train_split': 'train',
  'valid_split': 'val',
  'test_split': 'test',
  'ignore_index': 0,
  'optimizer_config': {'optimizer': 'adam',
   'lr': 1e-05,
   'l2': 0.0,
   'grad_clip': 1.0,
   'sgd_config': {'momentum': 0.9},
   'adam_config': {'betas': (0.9, 0.999)}},
  'lr_scheduler_config': {'lr_scheduler': None,
   'warmup_steps': None,
   'warmup_unit': 'batch',
   'warmup_percentage': 0.1,
   'min_lr': 0.0,
   'linear_config': {'min_lr': 0.0},
   'exponential_config': {'gamma': 0.9},
   'plateau_config': {'factor': 0.5, 'patience': 10, 'threshold': 0.0001}},
  'task_scheduler': 'round_robin',
  'global_evaluation_metric_dict': None,
  'min_lr': 0},
 'logging_config': {'counter_unit': 'batch',
  'evaluation_freq': 100,
  'writer_config': {'writer': 'tensorboard', 'verbose': True},
  'checkpoin

In [7]:
import os

# Task whose data, labels and metrics are used throughout the notebook.
TASK_NAME = "WiC"
# Root directory of the SuperGLUE data. Taken from the SUPERGLUEDATA environment
# variable when it is set, so the notebook is not tied to one machine; otherwise
# falls back to the original hard-coded location.
DATA_DIR = os.environ.get("SUPERGLUEDATA", "/dfs/scratch0/bradenjh/superglue")
# Pretrained BERT variant used for both the tokenizer and the encoder.
BERT_MODEL_NAME = "bert-large-cased"
BATCH_SIZE = 4

# Extract train/val/test datasets from file

In [8]:
from dataloaders import get_dataloaders

# Build dataloaders for the train, val and test splits of TASK_NAME, tokenised
# with the BERT_MODEL_NAME tokenizer. max_data_samples=None presumably means
# "use every example" — confirm in get_dataloaders.
dataloaders = get_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=["train", "val", "test"],
    max_sequence_length=256,
    max_data_samples=None,
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
)

[2019-06-05 15:01:41,796][INFO] tokenizer:9 - Loading Tokenizer bert-large-cased
[2019-06-05 15:01:42,062][INFO] pytorch_pretrained_bert.tokenization:190 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt from cache at /afs/cs.stanford.edu/u/bradenjh/.pytorch_pretrained_bert/cee054f6aafe5e2cf816d2228704e326446785f940f5451a5b26033516a4ac3d.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
[2019-06-05 15:01:42,094][INFO] parsers.wic:20 - Loading data from /dfs/scratch0/bradenjh/superglue/WiC/train.jsonl.
[2019-06-05 15:01:42,126][INFO] parsers.wic:23 - Sample 0: {'label': False, 'word': 'carry', 'pos': 'V', 'sentence1': 'You must carry your camping gear .', 'sentence2': 'Sound carries well over water .', 'sentence1_idx': '2', 'sentence2_idx': '1', 'idx': 0}
[2019-06-05 15:01:42,127][INFO] parsers.wic:23 - Sample 1: {'label': False, 'word': 'go', 'pos': 'V', 'sentence1': 'Messages must go through diplomatic channels .', 

In [10]:
from slicing.WiC_slices import slice_func_dict

# Apply the "slice_verb" slicing function to the dataset of one split.
# NOTE(review): `dataloaders` is indexed by position here, but later cells use
# dataloaders["val"] and dataloaders.values(). Confirm what get_dataloaders
# returns and make the indexing consistent, otherwise one of these cells fails
# on Restart & Run All.
f = slice_func_dict["slice_verb"]
f(dataloaders[2].dataset)

[2019-06-05 15:02:41,675][INFO] slicing.slicing_function:38 - Total 569 / 1400 examples are in slice slice_verb


(tensor([False,  True, False,  ..., False,  True,  True], dtype=torch.bool),
 tensor([0, 1, 0,  ..., 0, 1, 1]))

# Build Emmental task

In [None]:
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss for the active examples of a task's prediction head.

    Args:
        task_name: Task name; the head's outputs are looked up under
            ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's
            outputs; element 0 of the head's entry is used as the logits.
        Y: Label tensor. It is flattened and shifted down by one before being
            passed to ``F.cross_entropy``.
        active: Index/mask selecting which examples contribute to the loss.

    Returns:
        The value returned by ``F.cross_entropy`` on the selected examples.
    """
    head_outputs = immediate_ouput_dict[f"{task_name}_pred_head"]
    logits = head_outputs[0][active]
    # Shift the labels down by one so they can be used as class indices.
    targets = (Y.view(-1) - 1)[active]
    return F.cross_entropy(logits, targets)

In [None]:
def output(task_name, immediate_ouput_dict):
    """Turn a task's prediction-head logits into probabilities.

    Args:
        task_name: Task name; the head's outputs are looked up under
            ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's
            outputs; element 0 of the head's entry is used as the logits.

    Returns:
        Softmax of the logits taken along dimension 1.
    """
    logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
    return F.softmax(logits, dim=1)

In [None]:
def macro_f1(golds, probs, preds):
    """Compute the macro-averaged F1 of ``preds`` against ``golds``.

    ``probs`` is accepted but not used.

    Returns:
        A one-entry dict ``{"macro_f1": score}``.
    """
    score = f1_score(golds, preds, average="macro")
    return {"macro_f1": score}

In [None]:
class LinearModule(nn.Module):
    """Linear prediction head over three embeddings taken from the last layer.

    The head concatenates the embedding at sequence position 0 with the
    embeddings at the per-example positions ``idx1`` and ``idx2``, then applies
    a single linear layer. ``feature_dim`` must therefore equal three times the
    hidden size of the encoder output.

    Args:
        feature_dim: Input size of the linear layer.
        class_cardinality: Number of output logits.
    """

    def __init__(self, feature_dim, class_cardinality):
        super().__init__()

        self.linear = nn.Linear(feature_dim, class_cardinality)

    def forward(self, feature, idx1, idx2):
        """Compute logits from the last element of ``feature``.

        Args:
            feature: Sequence whose last element is a 3-D tensor indexed as
                (batch, position, hidden). Presumably the per-layer encoder
                outputs — confirm against BertModule.
            idx1: 1-D integer tensor with one sequence position per example.
            idx2: 1-D integer tensor with one sequence position per example.

        Returns:
            Tensor of shape (batch, class_cardinality).
        """
        last_layer = feature[-1]
        hidden_dim = last_layer.size(-1)
        # Embedding at sequence position 0 of every example.
        emb = last_layer[:, 0, :]
        # Expand each position to (batch, 1, hidden) so gather picks the whole
        # hidden vector at that position.
        idx1 = idx1.unsqueeze(-1).unsqueeze(-1).expand([-1, -1, hidden_dim])
        idx2 = idx2.unsqueeze(-1).unsqueeze(-1).expand([-1, -1, hidden_dim])
        word1_emb = last_layer.gather(dim=1, index=idx1).squeeze(dim=1)
        word2_emb = last_layer.gather(dim=1, index=idx2).squeeze(dim=1)
        features = torch.cat([emb, word1_emb, word2_emb], dim=-1)
        # Call the module rather than .forward() so hooks registered on the
        # linear layer are run.
        return self.linear(features)

In [None]:
# Hidden size of the encoder output: 768 when the model name contains "base",
# 1024 otherwise.
BERT_OUTPUT_DIM = 768 if "base" in BERT_MODEL_NAME else 1024
# Number of classes for the task: the size of its label mapping, or 1 when the
# task has no label mapping.
TASK_CARDINALITY = (
    len(SuperGLUE_LABEL_MAPPING[TASK_NAME].keys())
    if SuperGLUE_LABEL_MAPPING[TASK_NAME] is not None
    else 1
)

# Single-task definition: a BERT encoder followed by a linear prediction head.
emmental_task = EmmentalTask(
    name=TASK_NAME,
    module_pool=nn.ModuleDict(
        {
            "bert_module": BertModule(BERT_MODEL_NAME),
            # 3 * BERT_OUTPUT_DIM because LinearModule concatenates three
            # embeddings before its linear layer.
            f"{TASK_NAME}_pred_head": LinearModule(3 * BERT_OUTPUT_DIM, TASK_CARDINALITY),
        }
    ),
    task_flow=[
        {
            # Encoder step, fed the token ids and segment ids from the batch.
            "name": "input",
            "module": "bert_module",
            "inputs": [("_input_", "token_ids"), ("_input_", "token_segments")],
        },
        {
            # Head step, fed output 0 of the "input" step plus the two index
            # fields from the batch.
            "name": f"{TASK_NAME}_pred_head",
            "module": f"{TASK_NAME}_pred_head",
            "inputs": [("input", 0), ("_input_", "sent1_idxs"), ("_input_", "sent2_idxs")],
        },
    ],
    # Bind the task name so both callables look up this task's head.
    loss_func=partial(ce_loss, TASK_NAME),
    output_func=partial(output, TASK_NAME),
    scorer=Scorer(
        metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
    ),
)

In [None]:
# Wrap the single task in an EmmentalModel.
mtl_model = EmmentalModel(name="SuperGLUE_single_task", tasks=[emmental_task])

In [None]:
# No arguments are passed; presumably the learner reads its settings from the
# config given to emmental.init — confirm.
emmental_learner = EmmentalLearner()

In [None]:
# emmental_learner.learn(mtl_model, dataloaders.values())

In [None]:
# NOTE(review): the learn() call in the preceding cell is commented out, so this
# scores mtl_model without any training having been run in this notebook.
mtl_model.score(dataloaders["val"])

In [None]:
# mtl_model.score(dataloaders["train"])

In [None]:
# Checkpoint file to load; earlier candidates are kept commented out.
# NOTE(review): absolute, user-specific path — the notebook only runs where this file exists.
# PKL_PATH = "/dfs/scratch0/bradenjh/emmental-tutorials/superglue/models/WiC_verb_trigram_v2.pth"
# PKL_PATH = "/dfs/scratch0/bradenjh/emmental-tutorials/superglue/logs/2019_05_29/13_59_55/best_model_WiC_SuperGLUE_val_accuracy.pth"
PKL_PATH = "/dfs/scratch0/bradenjh/emmental-tutorials/superglue/logs/2019_06_04/11_01_21/best_model_WiC_SuperGLUE_val_accuracy.pth"

In [None]:
# mtl_model.save(PKL_PATH)

In [None]:
# Build a second model from the same task object, load the checkpoint at
# PKL_PATH into it, and score it on the validation split.
new_model = EmmentalModel(name="SuperGLUE_single_task", tasks=[emmental_task])
new_model.load(PKL_PATH)
new_model.score(dataloaders["val"])

In [None]:
import json

from task_config import (
    SuperGLUE_LABEL_MAPPING, 
    SuperGLUE_TASK_METRIC_MAPPING, 
    SuperGLUE_TASK_SPLIT_MAPPING
)

# Split used by make_analysis_df for both the predictions and the raw jsonl file.
SPLIT = "val"

def make_analysis_df(model):
    """Build a per-example error-analysis table for the ``SPLIT`` split.

    Runs ``model.predict`` on ``dataloaders[SPLIT]``, reads the raw jsonl file
    of the same split, and attaches three columns to every raw example:
    ``prob`` (column 0 of the predicted probabilities), ``pred`` (True when the
    predicted label is 1) and ``correct`` ("Y" when ``pred`` equals the raw
    ``label``, else "N").

    Args:
        model: Object exposing ``predict(dataloader, return_preds=True)`` that
            returns ``(gold_dict, prob_dict, pred_dict)`` keyed by task name.

    Returns:
        A pandas DataFrame with one row per raw example and a fixed column order.

    Raises:
        ValueError: If the number of raw examples differs from the number of
            predictions, since rows and predictions are matched by position.
    """
    # Get predictions
    _, prob_dict, pred_dict = model.predict(dataloaders[SPLIT], return_preds=True)
    probs = prob_dict[TASK_NAME][:, 0]
    preds = pred_dict[TASK_NAME]

    # Load raw data
    jsonl_path = os.path.join(
        DATA_DIR, TASK_NAME, SuperGLUE_TASK_SPLIT_MAPPING[TASK_NAME][SPLIT]
    )
    # Close the file deterministically and ignore blank lines.
    with open(jsonl_path, encoding="utf-8") as jsonl_file:
        rows = [json.loads(line) for line in jsonl_file if line.strip()]

    if len(rows) != len(preds):
        raise ValueError(
            f"{jsonl_path} has {len(rows)} examples but the model returned "
            f"{len(preds)} predictions; they cannot be aligned by position."
        )

    # Add new columns
    for i, row in enumerate(rows):
        row["prob"] = probs[i]
        row["pred"] = bool(preds[i] == 1)
        row["correct"] = "Y" if row["pred"] == row["label"] else "N"

    # Make tsv
    df = pd.DataFrame(rows)
    return df[['idx', 'label', 'pred', 'prob', 'correct', 'word', 'pos',
               'sentence1_idx', 'sentence2_idx',
               'sentence1', 'sentence2']]

# Analysis table for the model loaded from the checkpoint.
df = make_analysis_df(new_model)

In [None]:
# TUTORIALS_ROOT = "/dfs/scratch0/bradenjh/emmental-tutorials/"

# out_path = os.path.join(TUTORIALS_ROOT, "superglue", "analysis", f"WiC_{SPLIT}_analysis_v0.csv")
# df.to_csv(out_path)
# print(f"Wrote error analysis to {out_path}")

In [None]:
# Show the first row to check the columns.
df.head(1)

In [None]:
# import nltk
# from nltk.corpus import stopwords as nltk_stopwords
# stopwords = set(nltk_stopwords.words('english'))

# Accumulators for the heuristic below.
# NOTE(review): `preds` is never appended to in this cell; the code that did so
# is commented out further down, so it stays empty.
preds = []
labels = []

def get_ngrams(tokens, window=1):
    """Yield every contiguous slice of ``window`` items from ``tokens``.

    Slices are produced left to right. Nothing is yielded when ``window`` is
    larger than ``len(tokens)``.
    """
    start = 0
    last_start = len(tokens) - window
    while start <= last_start:
        yield tokens[start:start + window]
        start += 1
        
# Record a 1/2 label per example and print every example whose two target
# tokens (looked up by whitespace-token index) differ after lower-casing.
for index, row in df.iterrows():
    target = row["word"]
    if row["label"] == True:
        labels.append(1)
    else:
        labels.append(2)
    sent1_target = row["sentence1"].split()[int(row["sentence1_idx"])]
    sent2_target = row["sentence2"].split()[int(row["sentence2_idx"])]
    if sent1_target.lower() == sent2_target.lower():
        continue
    print(index, sent1_target.lower(), sent2_target.lower())
#     print(row)
#     print(word, )
#     for sent in ["sentence1", "sentence2"]:
#         tokens = row[sent].split()
#         for i, tok in enumerate(tokens):
#             if target in tok:
#                 idx = i
#                 break
#         trigrams.append([' '.join(ngram) 
#                          for ngram in get_ngrams(tokens[idx-2:idx+2], window=3) 
#                          if len(ngram) == 3])
#     if (set(trigrams[0]).intersection(set(trigrams[1]))):
#         preds.append(1)
#         print(trigrams)
#         print(f"{target}: {row['sentence1']}.....{row['sentence2']}")
#         print()        
#     else:
#         preds.append(0)


In [None]:
# Import only the metrics used here. A star import pulls every public name from
# metal.metrics into the notebook namespace and may shadow names imported
# earlier (for example the sklearn f1_score used by macro_f1).
from metal.metrics import accuracy_score, coverage_score

# NOTE(review): `preds` is still empty at this point because the code that
# filled it is commented out in the previous cell; restore it before relying on
# these numbers.
print(accuracy_score(labels, preds, ignore_in_pred=[0]))
print(coverage_score(labels, preds))