# Setup 

Slicing workflow:
1. look at errors for ideas
2. write a slice function
3. check slice size (large enough to matter?)
4. check performance on that slice with trained model (are we underperforming?)
5. train a model on just that slice (can we do better on it?)
6. train a full model including that slice (does that gain persist?)

In [None]:
# IPython autoreload extension: mode 2 reloads all (non-excluded) modules
# before each cell executes, so edits to project code are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
sys.path.append('../')

In [None]:
import logging
from functools import partial

import torch
from torch import nn
import torch.nn.functional as F

import emmental
from dataloaders import get_dataloaders
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from slicing.slicing_function import slicing_function
from task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING

In [None]:
# Logger named after this module (__name__).
logger = logging.getLogger(__name__)

In [None]:
# Checkpoint of a WiC model selected on SuperGLUE val accuracy; restored with
# model.load(PKL_PATH) in the "Load model" cell.
PKL_PATH = "/dfs/scratch1/bradenjh/emmental-tutorials/superglue/logs/2019_06_04/14_22_40/best_model_WiC_SuperGLUE_val_accuracy.pth"
SPLIT = "val"  # NOTE(review): not referenced in the cells shown here.

TASK_NAME = "WiC"
DATA_DIR = os.environ["SUPERGLUEDATA"]  # raises KeyError if the env var is unset
BERT_MODEL_NAME = "bert-large-cased"
BATCH_SIZE = 4

emmental.init(
    "logs",
    config={
        "model_config": {"device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 5,
            "valid_split": "val",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "min_lr": 0,
            "lr_scheduler_config": {
                "warmup_percentage": 0.1,
                "lr_scheduler": None,
            },
        },
        "logging_config": {
            "counter_unit": "epoch",
            "evaluation_freq": 0.25,
            "checkpointing": True,
            "checkpointer_config": {
                # Key must name the task being trained; it was hard-coded to
                # "RTE/..." although TASK_NAME is WiC, so the checkpointer
                # tracked a metric that is never reported.
                "checkpoint_metric": {f"{TASK_NAME}/SuperGLUE/val/accuracy": "max"},
                "checkpoint_freq": 1,
            },
        },
    },
)

# Later cells index dataloaders[0] / dataloaders[1] as train / val, which
# assumes get_dataloaders returns loaders in the order of `splits`.
dataloaders = get_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=["train", "val", "test"],
    max_sequence_length=256,
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
)

# 1. Error Analysis

This step is done offline: load the saved model and export its predictions to a CSV to inspect the errors.

# 2. Write Slice Function

In [None]:
from slicing.slicing_function import slicing_function

@slicing_function(fields=["pos", "sentence1", "sentence2", "sentence1_idx", "sentence2_idx"])
def sf(example):
    """Slice membership: the target word differs in surface form and is a verb.

    The target word is found by whitespace-splitting each sentence and indexing
    with ``sentence1_idx`` / ``sentence2_idx``. The comparison is exact string
    inequality, so it is case- and attached-punctuation-sensitive.
    NOTE(review): e.g. "Run" vs "run" or "run." vs "run" count as different
    forms; confirm that is intended.
    """
    word1 = example.sentence1.split()[example.sentence1_idx]
    word2 = example.sentence2.split()[example.sentence2_idx]
    if word1 == word2:
        return False
    return example.pos == "V"

# Every slicing function evaluated and trained on in the cells below.
slicing_functions = [sf]

# 3. Apply SF

In [None]:
# Apply the slicing function to dataloaders[0] (the train split, assuming
# get_dataloaders preserves the order of `splits`). The next cell compares
# `inds` against 1 to select slice members and indexes `preds` with that mask.
inds, preds = sf(dataloaders[0].dataset)

Is the slice large enough?

In [None]:
# Slice size: number and fraction of examples whose slice indicator is 1.
# int(...) so the message shows a number rather than "tensor(N)".
count = int((inds == 1).sum())
total = len(inds)
slice_frac = count / total if total else 0.0
print(f"Slice labels {count}/{total} ({slice_frac}) examples")

# Balance within the slice: how many slice members have a `preds` value of 1.
slice_preds = preds[inds == 1].numpy()
num_positive = int((slice_preds == 1).sum())
slice_size = len(slice_preds)
# Guard against an empty slice, which previously raised ZeroDivisionError.
positive_frac = num_positive / slice_size if slice_size else 0.0
print(f"Slice polarity: {num_positive}/{slice_size} ({positive_frac}) positive class")

# 4. Static slice check

How does the vanilla model perform on the slice?

Load model

In [None]:
import models
# Build the task(s) for TASK_NAME from the project's `models.model` registry,
# wrap them in an EmmentalModel, and restore the trained weights from PKL_PATH.
tasks = [models.model[TASK_NAME](BERT_MODEL_NAME)]
model = EmmentalModel(name="SuperGLUE", tasks=tasks)  # plain str: no placeholders needed
model.load(PKL_PATH)

Score the base task and each slice

In [None]:
from slicing import score_slices

# Name -> slicing function for the functions defined in this notebook.
# (To use the project's registry instead, import `slice_func_dict` from
# `slicing` and index it with TASK_NAME. It was previously imported here and
# then immediately shadowed by this dict.)
slice_func_dict = {fn.__name__: fn for fn in slicing_functions}

# Score the loaded model on dataloaders[1] (the val split, assuming
# get_dataloaders preserves the order of `splits`).
slice_scores = score_slices(model, [dataloaders[1]], [TASK_NAME], slice_func_dict)
slice_scores

# 5. Dynamic Slice Check

Does a model trained on just the slice task improve?

In [None]:
import copy

import numpy as np  # was missing: np.ndarray is checked when filtering fields
from torch.utils.data import RandomSampler

from emmental.data import emmental_collate_fn, EmmentalDataLoader

# Reload fresh dataloaders for this experiment.
# NOTE(review): presumably because earlier cells (e.g. score_slices) may
# modify the originals; confirm whether this reload is still needed.
dataloaders = get_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=["train", "val", "test"],
    max_sequence_length=256,
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
)


def _filter_field(name, values, keep_mask, keep_bool, keep_idx):
    """Return only the slice members of one dataset field, keeping its container type."""
    if isinstance(values, torch.Tensor):
        return values[keep_mask.to(values.device)]
    if isinstance(values, np.ndarray):
        return values[keep_bool]
    if isinstance(values, list):
        return [values[i] for i in keep_idx]
    # Unknown container: left unfiltered as before, but no longer silently,
    # since it will be misaligned with the sliced fields.
    logger.warning("Dataset field %r of type %s was not sliced", name, type(values).__name__)
    return values


def _slice_dataloader(dataloader, slice_fn):
    """Return a new EmmentalDataLoader restricted to the examples in slice_fn's slice.

    The source dataloader and its dataset are not modified. The dataset is
    shallow-copied and given new X_dict / Y_dict dicts, so applying several
    slicing functions does not compound. Batch size and shuffling are carried
    over from the source dataloader.
    """
    dataset = dataloader.dataset
    inds, _ = slice_fn(dataset)
    # An example is in the slice iff its indicator equals 1 (as `inds == 1` elsewhere).
    keep_mask = torch.as_tensor(inds).reshape(-1) == 1
    keep_bool = np.asarray(keep_mask.tolist(), dtype=bool)
    keep_idx = np.flatnonzero(keep_bool).tolist()

    sliced = copy.copy(dataset)
    sliced.X_dict = {
        key: _filter_field(key, values, keep_mask, keep_bool, keep_idx)
        for key, values in dataset.X_dict.items()
    }
    sliced.Y_dict = {
        key: _filter_field(key, values, keep_mask, keep_bool, keep_idx)
        for key, values in dataset.Y_dict.items()
    }
    return EmmentalDataLoader(
        dataloader.task_to_label_dict,
        sliced,
        dataloader.split,
        emmental_collate_fn,
        # Assuming EmmentalDataLoader forwards extra kwargs to torch's DataLoader;
        # without them it would default to batch_size=1 and no shuffling.
        batch_size=dataloader.batch_size,
        shuffle=isinstance(dataloader.sampler, RandomSampler),
    )


# One sliced dataloader per (slicing function, split), in the original loop order.
slice_dataloaders = [
    _slice_dataloader(dataloader, slice_fn)
    for slice_fn in slicing_functions
    for dataloader in dataloaders
]

In [None]:
# Train on the slice-only dataloaders (every split, for each slicing function).
# NOTE(review): `model` is the checkpoint restored from PKL_PATH, so this
# continues training an already-trained full-data model rather than training
# a fresh model on the slice alone; confirm that is the intended comparison.
learner = EmmentalLearner()
learner.learn(model, slice_dataloaders)

# 6. Full Slicing Model

Train a full slice-aware model on all the data (including the slice) and check whether the gain persists.