In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import sys
import os
sys.path.append('../')

In [3]:
import logging
from functools import partial

import torch.nn.functional as F
from torch import nn

import emmental
from dataloaders import get_dataloaders
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from slicing.slicing_function import slicing_function
from task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING

In [4]:
# Logger for this notebook, named after the current module (__name__).
logger = logging.getLogger(__name__)

In [5]:
# --- Experiment configuration --------------------------------------------
# Everything a reader might want to tune lives in this cell.
TASK_NAME = "RTE"
SPLIT = "val"  # NOTE(review): not referenced by the cells visible here — confirm it is still needed.
BERT_MODEL_NAME = "bert-large-cased"
BATCH_SIZE = 4

# Raises KeyError if SUPERGLUEDATA is not set in the environment.
DATA_DIR = os.environ["SUPERGLUEDATA"]

# NOTE(review): machine-specific absolute path; change this single root when
# running elsewhere. The checkpoint path below is derived from it so the root
# is no longer spelled out twice.
TUTORIALS_ROOT = "/dfs/scratch0/paroma/emmental-tutorials/"
PKL_PATH = os.path.join(
    TUTORIALS_ROOT,
    "superglue/logs/2019_06_04/17_38_41/best_model_RTE_SuperGLUE_val_accuracy.pth",
)

# Initialise Emmental with "logs" as its first argument (the log output shows
# a timestamped directory being created under it) and override the defaults
# with the settings below.
emmental.init(
    "logs",
    config={
        "model_config": {"device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 5,
            "valid_split": "val",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "min_lr": 0,
            "lr_scheduler_config": {
                "warmup_percentage": 0.1,
                "lr_scheduler": None,
            },
        },
        "logging_config": {
            "counter_unit": "batch",
            "evaluation_freq": 0.25,
            "checkpointing": True,
            "checkpointer_config": {
                # Keyed on TASK_NAME so the watched metric follows the task
                # instead of staying pinned to a hard-coded "RTE".
                "checkpoint_metric": {f"{TASK_NAME}/SuperGLUE/val/accuracy": "max"},
                "checkpoint_freq": 1,
            },
        },
    },
)

# Build the dataloaders for all three splits of the task.
dataloaders = get_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=["train", "val", "test"],
    max_sequence_length=256,
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
)

[2019-06-05 20:07:54,820][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_06_05/20_07_54
[2019-06-05 20:07:54,842][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch0/bradenjh/emmental/src/emmental/emmental-default-config.yaml.
[2019-06-05 20:07:54,843][INFO] emmental.meta:143 - Updating Emmental config from user provided config.
[2019-06-05 20:07:54,895][INFO] tokenizer:9 - Loading Tokenizer bert-large-cased
[2019-06-05 20:07:55,202][INFO] pytorch_pretrained_bert.tokenization:190 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt from cache at /afs/cs.stanford.edu/u/paroma/.pytorch_pretrained_bert/cee054f6aafe5e2cf816d2228704e326446785f940f5451a5b26033516a4ac3d.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
[2019-06-05 20:07:55,325][INFO] parsers.rte:20 - Loading data from /dfs/scratch1/senwu/mmtl/emmental-tutorials/superglue/data/RTE/train.jsonl.
[2019-06-05 20:07:55,345][

In [6]:
# Rebuild the task for TASK_NAME, wrap it in an EmmentalModel, then restore
# the saved weights from PKL_PATH so its quality can be sanity-checked below.
import models

build_task = models.model[TASK_NAME]
tasks = [build_task(BERT_MODEL_NAME)]
model = EmmentalModel(name="SuperGLUE", tasks=tasks)
model.load(PKL_PATH)

[2019-06-05 20:08:02,021][INFO] pytorch_pretrained_bert.modeling:580 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz from cache at ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233
[2019-06-05 20:08:02,024][INFO] pytorch_pretrained_bert.modeling:588 - extracting archive file ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233 to temp dir /tmp/tmpmu4rad7v
[2019-06-05 20:08:13,282][INFO] pytorch_pretrained_bert.modeling:598 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pooler_fc_size": 768,
  "pooler_num_a

In [7]:
# Sanity check: score the restored model on every loaded dataloader.
model.score(dataloaders)

{'RTE/SuperGLUE/train/accuracy': 0.9891566265060241,
 'RTE/SuperGLUE/val/accuracy': 0.7689530685920578,
 'RTE/SuperGLUE/test/accuracy': 0.5793333333333334}


## TIME TO SLICE!!!

In [35]:
@slicing_function()
def slice_base(example):
    """Baseline slice: return 1 for every example, so no example is excluded."""
    return 1

@slicing_function(fields=["sentence1", "sentence2"])
def slice_nasa(example):
    """Return True when ``example.sentence1`` contains "NASA" (case-sensitive).

    NOTE(review): ``sentence2`` is requested through ``fields`` but is never
    inspected here — confirm whether the second sentence should be checked too.
    """
    first_sentence = example.sentence1
    return "NASA" in first_sentence

In [36]:
import slicing

# Slicing functions to evaluate.
slices = [slice_base, slice_nasa]

# Index each slicing function by its name so results can be labelled per slice.
# The loop variable is deliberately not called `slice`, which is a built-in.
slice_func_dict = {fn.__name__: fn for fn in slices}

In [37]:
# Display the name -> slicing function mapping built above.
slice_func_dict

{'slice_base': <function __main__.slice_base(example)>,
 'slice_nasa': <function __main__.slice_nasa(example)>}

In [39]:
# Score every dataloader overall and on each registered slice.
# Results are collected in `metric_score_dict`, keyed
# "task/dataset/split/metric" for the full split and
# "task/dataset/split+slice/metric" for a slice.
return_uids = False

metric_score_dict = dict()
for dataloader in dataloaders:
    # Honour the `return_uids` flag instead of hard-coding False, so that
    # `preds["uids"]` is actually requested whenever it is read below.
    preds = model.predict(dataloader, return_preds=True, return_uids=return_uids)

    # Slice membership is computed from the dataset alone, so do it once per
    # dataloader rather than once per task and per metric.
    # NOTE(review): the masks follow dataset order; this assumes the
    # predictions come back in that same order (i.e. the dataloader does not
    # shuffle) — confirm, especially for the train split.
    slice_masks = {
        slice_id: slice_func(dataloader.dataset)[0].numpy().astype(bool)
        for slice_id, slice_func in slice_func_dict.items()
    }

    for task_name in preds["golds"].keys():
        scorer = model.scorers[task_name]
        golds = preds["golds"][task_name]
        probs = preds["probs"][task_name]
        task_preds = preds["preds"][task_name]
        uids = preds["uids"][task_name] if return_uids else None

        # Metrics over the whole split.
        overall_score = scorer.score(golds, probs, task_preds, uids)
        for metric_name, metric_value in overall_score.items():
            identifier = "/".join(
                [task_name, dataloader.data_name, dataloader.split, metric_name]
            )
            metric_score_dict[identifier] = metric_value

        # Metrics restricted to each slice. This loop is a sibling of the
        # overall-metric loop (not nested inside it), so each slice is scored
        # exactly once and the loop variables above are not clobbered.
        # NOTE(review): a slice with no members in this split is scored on
        # empty arrays — confirm the scorer tolerates that.
        for slice_id, slice_idx in slice_masks.items():
            slice_score = scorer.score(
                golds[slice_idx], probs[slice_idx], task_preds[slice_idx]
            )
            for metric_name, metric_value in slice_score.items():
                identifier = "/".join(
                    [
                        task_name,
                        dataloader.data_name,
                        dataloader.split + "+" + slice_id,
                        metric_name,
                    ]
                )
                metric_score_dict[identifier] = metric_value


[2019-06-05 20:55:47,514][INFO] slicing.slicing_function:38 - Total 2490 / 2490 examples are in slice slice_base
[2019-06-05 20:55:47,522][INFO] slicing.slicing_function:38 - Total 5 / 2490 examples are in slice slice_nasa
[2019-06-05 20:55:52,080][INFO] slicing.slicing_function:38 - Total 277 / 277 examples are in slice slice_base
[2019-06-05 20:55:52,081][INFO] slicing.slicing_function:38 - Total 2 / 277 examples are in slice slice_nasa
[2019-06-05 20:56:29,753][INFO] slicing.slicing_function:38 - Total 3000 / 3000 examples are in slice slice_base
[2019-06-05 20:56:29,768][INFO] slicing.slicing_function:38 - Total 20 / 3000 examples are in slice slice_nasa


In [None]:
# Display the collected overall and per-slice metrics.
metric_score_dict