# Setup 

Slicing workflow:
1. Look at model errors for slice ideas.
2. Write a slicing function.
3. Check the slice size (is it large enough to matter?).
4. Check the trained model's performance on that slice (are we underperforming?).
5. Train a model on just that slice (can we do better on it?).
6. Train a full model that includes that slice (does that gain persist?).

In [1]:
# Reload edited project modules (e.g. slicing, dataloaders) automatically before each cell runs,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import sys
import os
# Make the project-local packages imported below (dataloaders, modules, slicing, task_config,
# models) importable. '../' is resolved against the kernel's working directory, so this assumes
# the notebook is started one level below the project root. Re-running appends a duplicate entry.
sys.path.append('../')

In [3]:
import logging
from functools import partial

import torch.nn.functional as F
from torch import nn

import emmental
from dataloaders import get_dataloaders
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from slicing.slicing_function import slicing_function
from task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING

In [4]:
# Notebook-level logger, named after the current module.
logger: logging.Logger = logging.getLogger(name=__name__)

In [10]:
# Trained WiC checkpoint to analyze.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
PKL_PATH = "/dfs/scratch1/bradenjh/emmental-tutorials/superglue/logs/2019_06_04/14_22_40/best_model_WiC_SuperGLUE_val_accuracy.pth"
SPLIT = "val"

TASK_NAME = "WiC"
DATA_DIR = os.environ["SUPERGLUEDATA"]  # raises KeyError if the env var is not set
BERT_MODEL_NAME = "bert-large-cased"
BATCH_SIZE = 4

# Metric used to select the best checkpoint. Derived from TASK_NAME so it names a metric
# this task actually reports (it was hard-coded to the RTE task before, which WiC never emits).
CHECKPOINT_METRIC = f"{TASK_NAME}/SuperGLUE/val/accuracy"

emmental.init(
    "logs",
    config={
        "model_config": {"device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 5,
            "valid_split": "val",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "min_lr": 0,
            "lr_scheduler_config": {
                "warmup_percentage": 0.1,
                "lr_scheduler": None,
            },
        },
        "logging_config": {
            "counter_unit": "batch",
            "evaluation_freq": 0.25,
            "checkpointing": True,
            "checkpointer_config": {
                "checkpoint_metric": {CHECKPOINT_METRIC: "max"},
                "checkpoint_freq": 1,
            },
        },
    },
)

# Later cells index this by position, assuming one dataloader per entry of `splits` in the
# same order (0 = train, 1 = val, 2 = test) — TODO confirm against get_dataloaders.
dataloaders = get_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=["train", "val", "test"],
    max_sequence_length=256,
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
)

[2019-06-10 14:12:38,408][INFO] emmental.meta:99 - Logging was already initialized to use logs/2019_06_10/14_10_22.  To configure logging manually, call emmental.init_logging before initialiting Meta.
[2019-06-10 14:12:38,419][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch1/bradenjh/emmental/src/emmental/emmental-default-config.yaml.
[2019-06-10 14:12:38,419][INFO] emmental.meta:143 - Updating Emmental config from user provided config.
[2019-06-10 14:12:38,420][INFO] tokenizer:9 - Loading Tokenizer bert-large-cased
[2019-06-10 14:12:38,683][INFO] pytorch_pretrained_bert.tokenization:190 - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt from cache at /dfs/scratch1/bradenjh/.pytorch_pretrained_bert/cee054f6aafe5e2cf816d2228704e326446785f940f5451a5b26033516a4ac3d.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
[2019-06-10 14:12:38,717][INFO] parsers.wic:20 - Loading data from /dfs/scratch1/

# 1. Error Analysis

Done offline (not in this notebook): load the saved model and export its results to a CSV for inspection.

# 2. Write Slice Function

In [66]:
from slicing.slicing_function import slicing_function

@slicing_function(fields=["pos", "sentence1", "sentence2", "sentence1_idx", "sentence2_idx"])
def sf(example):
    """Is the target word a verb whose surface form differs between the two sentences?

    Takes the whitespace-split token at ``sentence1_idx`` in ``sentence1`` and the token at
    ``sentence2_idx`` in ``sentence2``. The example is in the slice when those two tokens
    differ and ``pos == "V"``.

    Raises IndexError if either index is past the end of its sentence's whitespace tokens.
    NOTE(review): assumes the *_idx fields index whitespace-split tokens — TODO confirm.
    NOTE(review): if a noun slice was intended, change the POS tag rather than this docstring.
    """
    form1 = example.sentence1.split()[example.sentence1_idx]
    form2 = example.sentence2.split()[example.sentence2_idx]
    return (form1 != form2) and example.pos == "V"

# Slicing functions written in this notebook.
slicing_functions = [
    sf,
]

# 3. Apply the slicing function and check the slice size

In [67]:
# Apply the slicing function to the training split (dataloaders[0], the first entry of `splits`).
# Returns per-example slice indicators and a second array that later cells treat as class
# labels — presumably gold labels; TODO confirm against slicing_function.
inds, preds = sf(dataloaders[0].dataset)

[2019-06-10 14:40:47,747][INFO] slicing.slicing_function:43 - Total 1830 / 5428 examples are in slice sf


Is the slice large enough to matter?

In [68]:
# Slice size: how many examples the slicing function selects (indicator == 1).
count = (inds == 1).sum()
total = len(inds)
slice_frac = float(count) / total if total else float("nan")
print(f"Slice size: {count}/{total} ({slice_frac}) examples")

# Class balance inside the slice. `preds` is the second value returned by `sf`
# (presumably gold labels — TODO confirm).
slice_preds = preds[inds == 1].numpy()
n_positive = (slice_preds == 1).sum()
n_in_slice = len(slice_preds)
# Guard against an empty slice, which would otherwise raise ZeroDivisionError.
positive_frac = n_positive / n_in_slice if n_in_slice else float("nan")
print(f"Slice polarity: {n_positive}/{n_in_slice} ({positive_frac}) positive class")

Slice labels 1830/5428 (0.3371407516580693) examples
Slice polarity: 794/1830 (0.43387978142076505) positive class


# 4. Static check: evaluate the trained model on the slices

Load model

In [24]:
import models

# Build the task(s) for TASK_NAME from the project's `models` registry, wrap them in an
# EmmentalModel, then load the trained weights from PKL_PATH.
tasks = [models.model[TASK_NAME](BERT_MODEL_NAME)]
model = EmmentalModel(name="SuperGLUE", tasks=tasks)
# NOTE(review): .pth checkpoints are typically unpickled on load — only load trusted files.
model.load(PKL_PATH)

[2019-06-10 14:20:30,814][INFO] pytorch_pretrained_bert.modeling:580 - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz from cache at ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233
[2019-06-10 14:20:30,816][INFO] pytorch_pretrained_bert.modeling:588 - extracting archive file ./cache/7fb0534b83c42daee7d3ddb0ebaa81387925b71665d6ea195c5447f1077454cd.eea60d9ebb03c75bb36302aa9d241d3b7a04bba39c360cf035e8bf8140816233 to temp dir /tmp/tmpzxj_w4mp
[2019-06-10 14:20:42,799][INFO] pytorch_pretrained_bert.modeling:598 - Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pooler_fc_size": 768,
  "pooler_num_a

Score the full validation set (`slice_base`) and each slice

In [25]:
from slicing import score_slices, slice_func_dict as all_task_slice_funcs

# Use the project's predefined slices for this task. The notebook's own
# `slicing_functions` (e.g. `sf`) are not part of this dict.
slice_func_dict = all_task_slice_funcs[TASK_NAME]

# Evaluate the loaded model on the validation split (dataloaders[1]), overall and per slice.
slice_scores = score_slices(
    model,
    [dataloaders[1]],
    [TASK_NAME],
    slice_func_dict,
)
slice_scores

[2019-06-10 14:20:59,889][INFO] root:201 - Evaluating on task WiC, val split
[2019-06-10 14:21:06,177][INFO] root:208 - Evaluating slice slice_base
[2019-06-10 14:21:06,188][INFO] slicing.slicing_function:43 - Total 638 / 638 examples are in slice slice_base
[2019-06-10 14:21:06,188][INFO] root:208 - Evaluating slice slice_verb
[2019-06-10 14:21:06,191][INFO] slicing.slicing_function:43 - Total 243 / 638 examples are in slice slice_verb
[2019-06-10 14:21:06,192][INFO] root:208 - Evaluating slice slice_noun
[2019-06-10 14:21:06,196][INFO] slicing.slicing_function:43 - Total 395 / 638 examples are in slice slice_noun
[2019-06-10 14:21:06,197][INFO] root:208 - Evaluating slice slice_trigram
[2019-06-10 14:21:06,203][INFO] slicing.slicing_function:43 - Total 16 / 638 examples are in slice slice_trigram
[2019-06-10 14:21:06,203][INFO] root:208 - Evaluating slice slice_mismatch_verb
[2019-06-10 14:21:06,208][INFO] slicing.slicing_function:43 - Total 192 / 638 examples are in slice slice_mism

{'accuracy': 0.7445141065830722,
 'WiC:slice_base/SuperGLUE/val/accuracy': 0.7445141065830722,
 'WiC:slice_verb/SuperGLUE/val/accuracy': 0.7325102880658436,
 'WiC:slice_noun/SuperGLUE/val/accuracy': 0.7518987341772152,
 'WiC:slice_trigram/SuperGLUE/val/accuracy': 0.75,
 'WiC:slice_mismatch_verb/SuperGLUE/val/accuracy': 0.734375,
 'WiC:slice_mismatch_noun/SuperGLUE/val/accuracy': 0.7927927927927928}