In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
from functools import partial

import os
import torch.nn.functional as F
from torch import nn

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader, EmmentalDataset
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.bert_module import BertModule
from modules.classification_module import ClassificationModule
from preprocessor import preprocessor
from task_config import LABEL_MAPPING

In [None]:
# Logger named after this module (`__name__`).
# NOTE(review): not referenced by any other visible cell.
logger = logging.getLogger(__name__)

In [None]:
# NOTE(review): TASK_NAME is reassigned to "RTE" in a later cell before any
# visible cell reads this "MNLI" value.
TASK_NAME = "MNLI"
# Raises KeyError when the GLUEDATA environment variable is not set.
# NOTE(review): DATA_DIR is overwritten from SUPERGLUEDATA before its first
# visible use -- confirm this lookup is still required.
DATA_DIR = os.environ["GLUEDATA"]
# Pretrained BERT variant; later passed to BertModule, used as the tokenizer
# name, and inspected to choose the prediction head's input width.
BERT_MODEL_NAME = "bert-large-uncased"
# Batch size handed to the dataloader builder.
BATCH_SIZE = 32

# Initialize Emmental

In [None]:
# Initialize Emmental with "logs/RTE_finetune" as its first argument
# (presumably the log directory -- confirm) and a config override dict.
# NOTE(review): "RTE" is hardcoded in the path and in the checkpoint metric
# because TASK_NAME still holds "MNLI" when this cell runs.
emmental.init("logs/RTE_finetune",
    config={
        # Fixed seed for the run.
        "meta_config": {"seed": 1},
        # Device index 0 with data-parallel execution switched on.
        "model_config": {"device": 0, "dataparallel": True},
        "learner_config": {
            "n_epochs": 20,
            # Must match a split name produced by the dataloader builder ("val").
            "valid_split": "val",
            "optimizer_config": {"optimizer": "adam", "lr": 1e-5},
            "lr_scheduler_config": {
                "lr_scheduler": "linear",
                "min_lr": 1e-7,
            },
        },
        "logging_config": {
            # evaluation_freq / checkpoint_freq are expressed in this unit --
            # presumably "evaluate every 0.2 epochs"; confirm against Emmental docs.
            "counter_unit": "epoch",
            "evaluation_freq": 0.2,
            "checkpointing": True,
            "checkpointer_config": {
                # Metric key looks like task/dataset/split/metric; "max" presumably
                # means higher is better -- confirm the dataset name matches the
                # one assigned by get_RTE_dataloaders.
                "checkpoint_metric": {"RTE/SuperGLUE/val/accuracy":"max"},
                "checkpoint_freq": 1,
            },
        },
    }
)

# Build Emmental task

In [None]:
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss for the task's prediction head.

    Args:
        task_name: Task name; the head is looked up under the key
            ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's
            outputs; element 0 of the head's entry is used as the logits.
        Y: Label tensor. It is flattened and shifted down by one before use
            (presumably labels are stored 1-indexed -- confirm with the
            data pipeline).
        active: Index/mask applied to both logits and labels to select the
            examples that contribute to the loss.

    Returns:
        The value of ``F.cross_entropy`` over the selected examples.
    """
    head_outputs = immediate_ouput_dict[f"{task_name}_pred_head"]
    selected_logits = head_outputs[0][active]
    # Shift labels so the smallest label value 1 maps to class index 0.
    zero_based_labels = Y.view(-1) - 1
    selected_targets = zero_based_labels[active]
    return F.cross_entropy(selected_logits, selected_targets)

In [None]:
def output(task_name, immediate_ouput_dict):
    """Turn the prediction head's logits into class probabilities.

    Args:
        task_name: Task name; the head is looked up under the key
            ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's
            outputs; element 0 of the head's entry is used as the logits.

    Returns:
        Softmax of the logits taken along dimension 1.
    """
    head_outputs = immediate_ouput_dict[f"{task_name}_pred_head"]
    logits = head_outputs[0]
    return F.softmax(logits, dim=1)

In [None]:
# Create the model container; no tasks are attached here -- the pretrained
# checkpoint is loaded into it in the next section.
mtl_model = EmmentalModel(name="GLUE_single_task")

## Load Pretrained MNLI Model

In [None]:
# Load the pretrained MNLI checkpoint into mtl_model.
# The location can be overridden by setting the MNLI_CHECKPOINT_PATH
# environment variable; when it is unset, the original hardcoded path is
# used, so existing setups behave exactly as before.
mtl_model.load(
    os.environ.get(
        "MNLI_CHECKPOINT_PATH",
        "/home/hazymturk/vincent/emmental-tutorials/glue/logs/2019_05_30/07_16_23/best_model_MNLI_GLUE_dev_accuracy.pth",
    )
)

In [None]:
# Switch the active task to RTE. Later cells use this value to pick the label
# mapping and metrics, name the prediction head, and build the dataloaders.
TASK_NAME = "RTE"

In [None]:
from superglue.parse_RTE import get_RTE_dataloaders

In [None]:
from superglue.task_config import SuperGLUE_LABEL_MAPPING, SuperGLUE_TASK_METRIC_MAPPING

In [None]:
# Input width of the prediction head: 768 when the model name contains
# "base", 1024 otherwise.
if "base" in BERT_MODEL_NAME:
    BERT_OUTPUT_DIM = 768
else:
    BERT_OUTPUT_DIM = 1024

# Number of outputs of the prediction head: one per label in the task's
# label mapping, or a single output when the task has no mapping.
if SuperGLUE_LABEL_MAPPING[TASK_NAME] is None:
    TASK_CARDINALITY = 1
else:
    TASK_CARDINALITY = len(SuperGLUE_LABEL_MAPPING[TASK_NAME].keys())

# Single name shared by the head module, its task-flow step, and the key
# that ce_loss / output read from the intermediate output dict.
head_name = f"{TASK_NAME}_pred_head"

# Modules owned by the task: the BERT encoder and a linear classifier head.
task_modules = nn.ModuleDict(
    {
        "bert_module": BertModule(BERT_MODEL_NAME),
        head_name: nn.Linear(BERT_OUTPUT_DIM, TASK_CARDINALITY),
    }
)

# Execution order: encode the token ids / segments, then classify.
task_flow = [
    {
        "name": "input",
        "module": "bert_module",
        "inputs": [("_input_", "token_ids"), ("_input_", "token_segments")],
    },
    {
        "name": head_name,
        "module": head_name,
        # Reads entry 1 of the "input" step's output (presumably BERT's
        # pooled representation -- confirm against BertModule).
        "inputs": [("input", 1)],
    },
]

emmental_task = EmmentalTask(
    name=TASK_NAME,
    module_pool=task_modules,
    task_flow=task_flow,
    loss_func=partial(ce_loss, TASK_NAME),
    output_func=partial(output, TASK_NAME),
    scorer=Scorer(metrics=SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]),
)

### Replace MNLI with RTE task

In [None]:
# Detach the "MNLI" task (presumably restored by the checkpoint load above)
# so the model can be fine-tuned on RTE instead.
mtl_model.remove_task("MNLI")

In [None]:
# Attach the RTE task defined in `emmental_task` to the loaded model.
mtl_model.add_task(emmental_task)

In [None]:
# Data root for SuperGLUE; raises KeyError when SUPERGLUEDATA is not set.
# This replaces the DATA_DIR value assigned in the configuration cell.
DATA_DIR = os.environ["SUPERGLUEDATA"]
# Build the RTE dataloaders for the train and val splits. The result is
# indexed by split name and iterated with .values() in later cells.
dataloaders = get_RTE_dataloaders(
    data_dir=DATA_DIR,
    task_name=TASK_NAME,
    splits=["train", "val"],
    max_sequence_length=200,
    # None presumably means "use every example" -- confirm in parse_RTE.
    max_data_samples=None,
    # The tokenizer is selected by the same name as the BERT model.
    tokenizer_name=BERT_MODEL_NAME,
    batch_size=BATCH_SIZE,
)

In [None]:
# Fine-tune the model on all dataloaders built above, using the learner
# settings passed to emmental.init.
# NOTE(review): `dataloaders.values()` is a dict view, not a list -- confirm
# that EmmentalLearner.learn accepts any iterable of dataloaders.
emmental_learner = EmmentalLearner()
emmental_learner.learn(mtl_model, dataloaders.values())

In [None]:
# Score the fine-tuned model on the validation dataloader; as the last
# expression of the cell, the returned value is displayed as the cell output.
mtl_model.score(dataloaders["val"])