In [1]:
# Re-import edited modules (e.g. the metal package) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import copy

from metal.mmtl.glue.glue_tasks import create_glue_tasks_payloads
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.slicing.slice_model import SliceModel
from metal.mmtl.slicing.slicing_tasks import convert_to_slicing_tasks
from metal.mmtl.trainer import MultitaskTrainer

In [3]:
# Seed passed to both models (MetalModel, SliceModel) and to both MultitaskTrainer runs.
# NOTE(review): create_glue_tasks_payloads logs a different "Using random seed" on
# each call, so SEED does not appear to control payload creation -- confirm whether
# that function accepts a seed argument.
SEED = 1

### Initialize baseline tasks and payloads (no slices)

In [4]:
# Keyword arguments shared by every create_glue_tasks_payloads call in this
# notebook: BERT encoder settings plus loader options (presumably forwarded to
# the DataLoader via dl_kwargs -- confirm).
task_kwargs = dict(
    dl_kwargs=dict(batch_size=8),
    freeze_bert=False,
    bert_model="bert-base-cased",
    max_len=200,
    attention=False,
)
task_names = ["RTE"]

In [5]:
%%time

# Build the RTE task(s) and one payload per split (train/valid/test) from the
# shared task_kwargs.
tasks, payloads = create_glue_tasks_payloads(task_names, **task_kwargs)

Using random seed: 107448
Loading RTE Dataset


HBox(children=(IntProgress(value=0, max=2490), HTML(value='')))




HBox(children=(IntProgress(value=0, max=277), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))


CPU times: user 13.4 s, sys: 1.07 s, total: 14.5 s
Wall time: 15.4 s


In [6]:
# Inspect the baseline tasks and their per-split payloads.
(tasks, payloads)

([ClassificationTask(name=RTE, loss_multiplier=1.00)],
 [Payload(RTE_train: labels_to_tasks=[{'RTE_gold': 'RTE'}], split=train),
  Payload(RTE_valid: labels_to_tasks=[{'RTE_gold': 'RTE'}], split=valid),
  Payload(RTE_test: labels_to_tasks=[{'RTE_gold': 'RTE'}], split=test)])

### Initialize tasks and payloads with slice label sets

In [7]:
# Build a *separate* kwargs dict for the slice run instead of mutating
# task_kwargs in place, so re-running the baseline cell (or this one) always
# sees the same configuration.
slice_task_kwargs = {
    **task_kwargs,
    # Slice label sets attached to each RTE payload; each becomes its own task
    # named "RTE_slice:<name>". BASE covers every example (see the counts below).
    "slice_dict": {"RTE": ["dash_semicolon", "more_people", "BASE"]},
    # NOTE(review): the baseline used attention=False; confirm None is intended here.
    "attention": None,
}

tasks_slice, payloads_slice = create_glue_tasks_payloads(
    task_names, **slice_task_kwargs
)

Using random seed: 709265
Loading RTE Dataset


HBox(children=(IntProgress(value=0, max=2490), HTML(value='')))




HBox(children=(IntProgress(value=0, max=277), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))


Added label_set with 1003/2490 labels for task RTE_slice:dash_semicolon to payload RTE_train.
Added label_set with 64/2490 labels for task RTE_slice:more_people to payload RTE_train.
Added label_set with 2490/2490 labels for task RTE_slice:BASE to payload RTE_train.
Added label_set with 116/277 labels for task RTE_slice:dash_semicolon to payload RTE_valid.
Added label_set with 12/277 labels for task RTE_slice:more_people to payload RTE_valid.
Added label_set with 277/277 labels for task RTE_slice:BASE to payload RTE_valid.
Added label_set with 1103/3000 labels for task RTE_slice:dash_semicolon to payload RTE_test.
Added label_set with 67/3000 labels for task RTE_slice:more_people to payload RTE_test.
Added label_set with 3000/3000 labels for task RTE_slice:BASE to payload RTE_test.


In [8]:
# Inspect the slice tasks and confirm each payload maps every slice label set to its task.
(tasks_slice, payloads_slice)

([ClassificationTask(name=RTE, loss_multiplier=1.00),
  ClassificationTask(name=RTE_slice:dash_semicolon, loss_multiplier=1.00),
  ClassificationTask(name=RTE_slice:more_people, loss_multiplier=1.00),
  ClassificationTask(name=RTE_slice:BASE, loss_multiplier=1.00)],
 [Payload(RTE_train: labels_to_tasks=[{'RTE_gold': 'RTE', 'RTE_slice:dash_semicolon': 'RTE_slice:dash_semicolon', 'RTE_slice:more_people': 'RTE_slice:more_people', 'RTE_slice:BASE': 'RTE_slice:BASE'}], split=train),
  Payload(RTE_valid: labels_to_tasks=[{'RTE_gold': 'RTE', 'RTE_slice:dash_semicolon': 'RTE_slice:dash_semicolon', 'RTE_slice:more_people': 'RTE_slice:more_people', 'RTE_slice:BASE': 'RTE_slice:BASE'}], split=valid),
  Payload(RTE_test: labels_to_tasks=[{'RTE_gold': 'RTE', 'RTE_slice:dash_semicolon': 'RTE_slice:dash_semicolon', 'RTE_slice:more_people': 'RTE_slice:more_people', 'RTE_slice:BASE': 'RTE_slice:BASE'}], split=test)])

### Initialize and train the baseline model

In [9]:
# Baseline: a plain MetalModel over the non-slice RTE task.
model = MetalModel(tasks, verbose=False, seed=SEED)

In [10]:
%%time
trainer = MultitaskTrainer(seed=SEED)
trainer.train_model(
    model,
    payloads,
    checkpoint_metric="RTE/RTE_valid/RTE_gold/accuracy",
    checkpoint_metric_mode="max",
    checkoint_best=True,
    writer="tensorboard",
    optimizer="adamax",
    lr=5e-5,
    l2=1e-3,
    log_every=0.1, 
    score_every=0.1,
    n_epochs=10,
    progress_bar=True,
    checkpoint_best=True,
    checkpoint_cleanup=False,
)

Beginning train loop.
Expecting a total of approximately 2496 examples and 312 batches per epoch from 1 payload(s) in the train split.
Writing config to /dfs/scratch0/vschen/metal-mmtl/logs/2019_04_15/21_38_47/config.json


HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[0.10 epo]: RTE:[RTE_train/RTE_gold/loss=6.88e-01, RTE_valid/RTE_gold/accuracy=5.38e-01] model:[train/all/loss=6.88e-01, train/all/lr=5.00e-05]
Saving model at iteration 0.10 with best (max) score RTE/RTE_valid/RTE_gold/accuracy=0.538
[0.21 epo]: RTE:[RTE_train/RTE_gold/loss=7.16e-01, RTE_valid/RTE_gold/accuracy=4.73e-01] model:[train/all/loss=7.16e-01, train/all/lr=5.00e-05]
[0.31 epo]: RTE:[RTE_train/RTE_gold/loss=6.82e-01, RTE_valid/RTE_gold/accuracy=5.52e-01] model:[train/all/loss=6.82e-01, train/all/lr=5.00e-05]
Saving model at iteration 0.31 with best (max) score RTE/RTE_valid/RTE_gold/accuracy=0.552
[0.41 epo]: RTE:[RTE_train/RTE_gold/loss=6.86e-01, RTE_valid/RTE_gold/accuracy=5.70e-01] model:[train/all/loss=6.86e-01, train/all/lr=5.00e-05]
Saving model at iteration 0.41 with best (max) score RTE/RTE_valid/RTE_gold/accuracy=0.570
[0.51 epo]: RTE:[RTE_train/RTE_gold/loss=6.71e-01, RTE_valid/RTE_gold/accuracy=5.88e-01] model:[train/all/loss=6.71e-01, train/all/lr=5.00e-05]
Saving 

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[1.03 epo]: RTE:[RTE_train/RTE_gold/loss=6.25e-01, RTE_valid/RTE_gold/accuracy=6.75e-01] model:[train/all/loss=6.25e-01, train/all/lr=5.00e-05]
Saving model at iteration 1.03 with best (max) score RTE/RTE_valid/RTE_gold/accuracy=0.675
[1.13 epo]: RTE:[RTE_train/RTE_gold/loss=5.73e-01, RTE_valid/RTE_gold/accuracy=6.61e-01] model:[train/all/loss=5.73e-01, train/all/lr=5.00e-05]
[1.23 epo]: RTE:[RTE_train/RTE_gold/loss=5.73e-01, RTE_valid/RTE_gold/accuracy=6.43e-01] model:[train/all/loss=5.73e-01, train/all/lr=5.00e-05]
[1.33 epo]: RTE:[RTE_train/RTE_gold/loss=5.94e-01, RTE_valid/RTE_gold/accuracy=6.28e-01] model:[train/all/loss=5.94e-01, train/all/lr=5.00e-05]
[1.44 epo]: RTE:[RTE_train/RTE_gold/loss=5.76e-01, RTE_valid/RTE_gold/accuracy=6.71e-01] model:[train/all/loss=5.76e-01, train/all/lr=5.00e-05]
[1.54 epo]: RTE:[RTE_train/RTE_gold/loss=5.68e-01, RTE_valid/RTE_gold/accuracy=6.46e-01] model:[train/all/loss=5.68e-01, train/all/lr=5.00e-05]
[1.64 epo]: RTE:[RTE_train/RTE_gold/loss=5.51

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[2.05 epo]: RTE:[RTE_train/RTE_gold/loss=5.46e-01, RTE_valid/RTE_gold/accuracy=6.53e-01] model:[train/all/loss=5.46e-01, train/all/lr=5.00e-05]
[2.15 epo]: RTE:[RTE_train/RTE_gold/loss=4.09e-01, RTE_valid/RTE_gold/accuracy=6.82e-01] model:[train/all/loss=4.09e-01, train/all/lr=5.00e-05]
Saving model at iteration 2.15 with best (max) score RTE/RTE_valid/RTE_gold/accuracy=0.682
[2.26 epo]: RTE:[RTE_train/RTE_gold/loss=3.71e-01, RTE_valid/RTE_gold/accuracy=6.68e-01] model:[train/all/loss=3.71e-01, train/all/lr=5.00e-05]
[2.36 epo]: RTE:[RTE_train/RTE_gold/loss=4.73e-01, RTE_valid/RTE_gold/accuracy=6.97e-01] model:[train/all/loss=4.73e-01, train/all/lr=5.00e-05]
Saving model at iteration 2.36 with best (max) score RTE/RTE_valid/RTE_gold/accuracy=0.697
[2.46 epo]: RTE:[RTE_train/RTE_gold/loss=4.82e-01, RTE_valid/RTE_gold/accuracy=6.68e-01] model:[train/all/loss=4.82e-01, train/all/lr=5.00e-05]
[2.56 epo]: RTE:[RTE_train/RTE_gold/loss=4.55e-01, RTE_valid/RTE_gold/accuracy=6.68e-01] model:[tr

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[3.08 epo]: RTE:[RTE_train/RTE_gold/loss=3.17e-01, RTE_valid/RTE_gold/accuracy=6.86e-01] model:[train/all/loss=3.17e-01, train/all/lr=5.00e-05]
[3.18 epo]: RTE:[RTE_train/RTE_gold/loss=2.95e-01, RTE_valid/RTE_gold/accuracy=6.61e-01] model:[train/all/loss=2.95e-01, train/all/lr=5.00e-05]
[3.28 epo]: RTE:[RTE_train/RTE_gold/loss=4.13e-01, RTE_valid/RTE_gold/accuracy=6.64e-01] model:[train/all/loss=4.13e-01, train/all/lr=5.00e-05]
[3.38 epo]: RTE:[RTE_train/RTE_gold/loss=3.21e-01, RTE_valid/RTE_gold/accuracy=6.82e-01] model:[train/all/loss=3.21e-01, train/all/lr=5.00e-05]
[3.49 epo]: RTE:[RTE_train/RTE_gold/loss=2.72e-01, RTE_valid/RTE_gold/accuracy=6.71e-01] model:[train/all/loss=2.72e-01, train/all/lr=5.00e-05]
[3.59 epo]: RTE:[RTE_train/RTE_gold/loss=4.57e-01, RTE_valid/RTE_gold/accuracy=6.43e-01] model:[train/all/loss=4.57e-01, train/all/lr=5.00e-05]
[3.69 epo]: RTE:[RTE_train/RTE_gold/loss=3.86e-01, RTE_valid/RTE_gold/accuracy=6.71e-01] model:[train/all/loss=3.86e-01, train/all/lr=5.

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[4.10 epo]: RTE:[RTE_train/RTE_gold/loss=2.73e-01, RTE_valid/RTE_gold/accuracy=6.71e-01] model:[train/all/loss=2.73e-01, train/all/lr=5.00e-05]
[4.21 epo]: RTE:[RTE_train/RTE_gold/loss=2.95e-01, RTE_valid/RTE_gold/accuracy=6.79e-01] model:[train/all/loss=2.95e-01, train/all/lr=5.00e-05]
[4.31 epo]: RTE:[RTE_train/RTE_gold/loss=3.41e-01, RTE_valid/RTE_gold/accuracy=6.53e-01] model:[train/all/loss=3.41e-01, train/all/lr=5.00e-05]
[4.41 epo]: RTE:[RTE_train/RTE_gold/loss=3.56e-01, RTE_valid/RTE_gold/accuracy=6.64e-01] model:[train/all/loss=3.56e-01, train/all/lr=5.00e-05]
[4.51 epo]: RTE:[RTE_train/RTE_gold/loss=3.45e-01, RTE_valid/RTE_gold/accuracy=6.97e-01] model:[train/all/loss=3.45e-01, train/all/lr=5.00e-05]
[4.62 epo]: RTE:[RTE_train/RTE_gold/loss=2.78e-01, RTE_valid/RTE_gold/accuracy=7.00e-01] model:[train/all/loss=2.78e-01, train/all/lr=5.00e-05]
[4.72 epo]: RTE:[RTE_train/RTE_gold/loss=3.52e-01, RTE_valid/RTE_gold/accuracy=6.75e-01] model:[train/all/loss=3.52e-01, train/all/lr=5.

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[5.03 epo]: RTE:[RTE_train/RTE_gold/loss=3.52e-01, RTE_valid/RTE_gold/accuracy=6.75e-01] model:[train/all/loss=3.52e-01, train/all/lr=5.00e-05]
[5.13 epo]: RTE:[RTE_train/RTE_gold/loss=2.02e-01, RTE_valid/RTE_gold/accuracy=6.68e-01] model:[train/all/loss=2.02e-01, train/all/lr=5.00e-05]
[5.23 epo]: RTE:[RTE_train/RTE_gold/loss=2.96e-01, RTE_valid/RTE_gold/accuracy=6.43e-01] model:[train/all/loss=2.96e-01, train/all/lr=5.00e-05]
[5.33 epo]: RTE:[RTE_train/RTE_gold/loss=4.26e-01, RTE_valid/RTE_gold/accuracy=6.86e-01] model:[train/all/loss=4.26e-01, train/all/lr=5.00e-05]
[5.44 epo]: RTE:[RTE_train/RTE_gold/loss=2.31e-01, RTE_valid/RTE_gold/accuracy=6.82e-01] model:[train/all/loss=2.31e-01, train/all/lr=5.00e-05]
[5.54 epo]: RTE:[RTE_train/RTE_gold/loss=3.65e-01, RTE_valid/RTE_gold/accuracy=6.57e-01] model:[train/all/loss=3.65e-01, train/all/lr=5.00e-05]
[5.64 epo]: RTE:[RTE_train/RTE_gold/loss=3.11e-01, RTE_valid/RTE_gold/accuracy=6.64e-01] model:[train/all/loss=3.11e-01, train/all/lr=5.

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[6.05 epo]: RTE:[RTE_train/RTE_gold/loss=1.63e-01, RTE_valid/RTE_gold/accuracy=6.86e-01] model:[train/all/loss=1.63e-01, train/all/lr=5.00e-05]
[6.15 epo]: RTE:[RTE_train/RTE_gold/loss=2.66e-01, RTE_valid/RTE_gold/accuracy=6.90e-01] model:[train/all/loss=2.66e-01, train/all/lr=5.00e-05]
[6.26 epo]: RTE:[RTE_train/RTE_gold/loss=2.57e-01, RTE_valid/RTE_gold/accuracy=6.64e-01] model:[train/all/loss=2.57e-01, train/all/lr=5.00e-05]
[6.36 epo]: RTE:[RTE_train/RTE_gold/loss=2.58e-01, RTE_valid/RTE_gold/accuracy=6.46e-01] model:[train/all/loss=2.58e-01, train/all/lr=5.00e-05]
[6.46 epo]: RTE:[RTE_train/RTE_gold/loss=3.55e-01, RTE_valid/RTE_gold/accuracy=6.57e-01] model:[train/all/loss=3.55e-01, train/all/lr=5.00e-05]
[6.56 epo]: RTE:[RTE_train/RTE_gold/loss=3.78e-01, RTE_valid/RTE_gold/accuracy=6.35e-01] model:[train/all/loss=3.78e-01, train/all/lr=5.00e-05]
[6.67 epo]: RTE:[RTE_train/RTE_gold/loss=4.28e-01, RTE_valid/RTE_gold/accuracy=6.28e-01] model:[train/all/loss=4.28e-01, train/all/lr=5.

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[7.08 epo]: RTE:[RTE_train/RTE_gold/loss=2.05e-01, RTE_valid/RTE_gold/accuracy=6.68e-01] model:[train/all/loss=2.05e-01, train/all/lr=5.00e-05]
[7.18 epo]: RTE:[RTE_train/RTE_gold/loss=3.03e-01, RTE_valid/RTE_gold/accuracy=6.75e-01] model:[train/all/loss=3.03e-01, train/all/lr=5.00e-05]
[7.28 epo]: RTE:[RTE_train/RTE_gold/loss=3.56e-01, RTE_valid/RTE_gold/accuracy=6.75e-01] model:[train/all/loss=3.56e-01, train/all/lr=5.00e-05]
[7.38 epo]: RTE:[RTE_train/RTE_gold/loss=3.37e-01, RTE_valid/RTE_gold/accuracy=6.79e-01] model:[train/all/loss=3.37e-01, train/all/lr=5.00e-05]
[7.49 epo]: RTE:[RTE_train/RTE_gold/loss=3.98e-01, RTE_valid/RTE_gold/accuracy=6.79e-01] model:[train/all/loss=3.98e-01, train/all/lr=5.00e-05]
[7.59 epo]: RTE:[RTE_train/RTE_gold/loss=4.29e-01, RTE_valid/RTE_gold/accuracy=6.82e-01] model:[train/all/loss=4.29e-01, train/all/lr=5.00e-05]
[7.69 epo]: RTE:[RTE_train/RTE_gold/loss=2.43e-01, RTE_valid/RTE_gold/accuracy=6.53e-01] model:[train/all/loss=2.43e-01, train/all/lr=5.

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[8.10 epo]: RTE:[RTE_train/RTE_gold/loss=2.16e-01, RTE_valid/RTE_gold/accuracy=6.32e-01] model:[train/all/loss=2.16e-01, train/all/lr=5.00e-05]
[8.21 epo]: RTE:[RTE_train/RTE_gold/loss=2.79e-01, RTE_valid/RTE_gold/accuracy=6.39e-01] model:[train/all/loss=2.79e-01, train/all/lr=5.00e-05]
[8.31 epo]: RTE:[RTE_train/RTE_gold/loss=3.04e-01, RTE_valid/RTE_gold/accuracy=6.53e-01] model:[train/all/loss=3.04e-01, train/all/lr=5.00e-05]
[8.41 epo]: RTE:[RTE_train/RTE_gold/loss=3.81e-01, RTE_valid/RTE_gold/accuracy=5.96e-01] model:[train/all/loss=3.81e-01, train/all/lr=5.00e-05]
[8.51 epo]: RTE:[RTE_train/RTE_gold/loss=3.00e-01, RTE_valid/RTE_gold/accuracy=6.43e-01] model:[train/all/loss=3.00e-01, train/all/lr=5.00e-05]
[8.62 epo]: RTE:[RTE_train/RTE_gold/loss=4.18e-01, RTE_valid/RTE_gold/accuracy=6.39e-01] model:[train/all/loss=4.18e-01, train/all/lr=5.00e-05]
[8.72 epo]: RTE:[RTE_train/RTE_gold/loss=3.83e-01, RTE_valid/RTE_gold/accuracy=6.17e-01] model:[train/all/loss=3.83e-01, train/all/lr=5.

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[9.03 epo]: RTE:[RTE_train/RTE_gold/loss=3.81e-01, RTE_valid/RTE_gold/accuracy=6.43e-01] model:[train/all/loss=3.81e-01, train/all/lr=5.00e-05]
[9.13 epo]: RTE:[RTE_train/RTE_gold/loss=2.13e-01, RTE_valid/RTE_gold/accuracy=6.25e-01] model:[train/all/loss=2.13e-01, train/all/lr=5.00e-05]
[9.23 epo]: RTE:[RTE_train/RTE_gold/loss=2.79e-01, RTE_valid/RTE_gold/accuracy=6.32e-01] model:[train/all/loss=2.79e-01, train/all/lr=5.00e-05]
[9.33 epo]: RTE:[RTE_train/RTE_gold/loss=3.56e-01, RTE_valid/RTE_gold/accuracy=6.14e-01] model:[train/all/loss=3.56e-01, train/all/lr=5.00e-05]
[9.44 epo]: RTE:[RTE_train/RTE_gold/loss=3.90e-01, RTE_valid/RTE_gold/accuracy=6.28e-01] model:[train/all/loss=3.90e-01, train/all/lr=5.00e-05]
[9.54 epo]: RTE:[RTE_train/RTE_gold/loss=2.69e-01, RTE_valid/RTE_gold/accuracy=6.61e-01] model:[train/all/loss=2.69e-01, train/all/lr=5.00e-05]
[9.64 epo]: RTE:[RTE_train/RTE_gold/loss=3.93e-01, RTE_valid/RTE_gold/accuracy=5.85e-01] model:[train/all/loss=3.93e-01, train/all/lr=5.

### Evaluate the baseline model on each slice

In [11]:
# Copy the slice-labeled validation payload so that retargeting below does not
# modify payloads_slice, which the slice model is trained on later.
# Select it by split rather than by list position.
eval_payload = copy.deepcopy(
    next(p for p in payloads_slice if p.split == "valid")
)

# The baseline model only has an "RTE" head, so point every slice label set at
# it; scoring then reports baseline accuracy restricted to each slice.
for label_name in (
    "RTE_slice:dash_semicolon",
    "RTE_slice:more_people",
    "RTE_slice:BASE",
):
    eval_payload.retarget_labelset(label_name, "RTE")

label_set RTE_slice:dash_semicolon now points to task RTE (originally, RTE_slice:dash_semicolon).
label_set RTE_slice:more_people now points to task RTE (originally, RTE_slice:more_people).
label_set RTE_slice:BASE now points to task RTE (originally, RTE_slice:BASE).


In [12]:
# Confirm every slice label set now maps to the RTE head.
eval_payload

Payload(RTE_valid: labels_to_tasks=[{'RTE_gold': 'RTE', 'RTE_slice:dash_semicolon': 'RTE', 'RTE_slice:more_people': 'RTE', 'RTE_slice:BASE': 'RTE'}], split=valid)

In [13]:
# Baseline accuracy on the full validation set (RTE_gold) and on each slice.
# RTE_slice:BASE labels all 277 validation examples, so its accuracy should match RTE_gold.
model.score(eval_payload)

{'RTE/RTE_valid/RTE_gold/accuracy': 0.703971119133574,
 'RTE/RTE_valid/RTE_slice:dash_semicolon/accuracy': 0.6293103448275862,
 'RTE/RTE_valid/RTE_slice:more_people/accuracy': 0.5,
 'RTE/RTE_valid/RTE_slice:BASE/accuracy': 0.703971119133574}

### Initialize and train slice model

In [14]:
# tasks_slice

In [15]:
# Convert every task (the RTE head and each slice task) into a
# BinaryClassificationTask with a single output unit (out_features 2 -> 1),
# presumably the form SliceModel expects.
# NOTE(review): the log suggests the task heads are modified; confirm whether
# tasks_slice itself is mutated by this call.
# (convert_to_slicing_tasks is now imported in the top import cell.)
slicing_tasks = convert_to_slicing_tasks(tasks_slice)
slicing_tasks

Modifying RTE out_features from 2 -> 1
Modifying RTE_slice:dash_semicolon out_features from 2 -> 1
Modifying RTE_slice:more_people out_features from 2 -> 1
Modifying RTE_slice:BASE out_features from 2 -> 1


[BinaryClassificationTask(name=RTE, loss_multiplier=1.00, is_slice=False),
 BinaryClassificationTask(name=RTE_slice:dash_semicolon, loss_multiplier=1.00, is_slice=True),
 BinaryClassificationTask(name=RTE_slice:more_people, loss_multiplier=1.00, is_slice=True),
 BinaryClassificationTask(name=RTE_slice:BASE, loss_multiplier=1.00, is_slice=True)]

In [20]:
# SliceModel over the converted slicing tasks, seeded like the baseline.
# (Removed a dead, commented-out MetalModel line.)
slice_model = SliceModel(slicing_tasks, seed=SEED, verbose=False)
slice_model

SliceModel(
  (input_modules): ModuleDict(
    (RTE): DataParallel(
      (module): BertRaw(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(28996, 768)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_

In [21]:
## %%time
trainer = MultitaskTrainer(seed=SEED)
trainer.train_model(
    slice_model,
    payloads_slice,
    task_metrics=[
        "RTE/RTE_train/RTE_gold/loss", 
        "RTE/RTE_train/RTE_slice:dash_semicolon/loss", 
        "RTE/RTE_train/RTE_slice:more_people/loss",
        "RTE/RTE_valid/RTE_gold/accuracy",
        "RTE/RTE_valid/RTE_slice:dash_semicolon/accuracy", 
        "RTE/RTE_valid/RTE_slice:more_people/accuracy",
    ],
    checkpoint_metric="RTE/RTE_valid/RTE_gold/accuracy",
    checkpoint_metric_mode="max",
    checkoint_best=True,
    writer="tensorboard",
    optimizer="adamax",
    lr=1e-5,
#     l2=1e-3,
    l2=1e-3,
    log_every=0.1, 
    score_every=0.1,
    n_epochs=20,
    progress_bar=True,
    checkpoint_best=True,
    checkpoint_cleanup=False
)

Beginning train loop.
Expecting a total of approximately 2496 examples and 312 batches per epoch from 1 payload(s) in the train split.
Writing config to /dfs/scratch0/vschen/metal-mmtl/logs/2019_04_15/22_03_58/config.json


HBox(children=(IntProgress(value=0, max=312), HTML(value='')))



[0.10 epo]: RTE:[RTE_train/RTE_gold/loss=6.91e-01, RTE_valid/RTE_gold/accuracy=5.27e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.52e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.12e-01] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.13e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.79e-01, RTE_valid/RTE_slice:BASE/accuracy=0] model:[train/all/loss=3.87e-01, train/all/lr=1.00e-05]
Saving model at iteration 0.10 with best (max) score RTE/RTE_valid/RTE_gold/accuracy=0.527
[0.21 epo]: RTE:[RTE_train/RTE_gold/loss=6.93e-01, RTE_valid/RTE_gold/accuracy=5.27e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.95e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.29e-01] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.94e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=2.08e-01, RTE_valid/RTE_slice:BASE/

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[1.03 epo]: RTE:[RTE_train/RTE_gold/loss=6.91e-01, RTE_valid/RTE_gold/accuracy=5.27e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.52e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.71e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.34e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.81e-01, train/all/lr=1.00e-05]
[1.13 epo]: RTE:[RTE_train/RTE_gold/loss=6.90e-01, RTE_valid/RTE_gold/accuracy=5.27e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.22e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.21e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.52e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.78e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.70e-01, train/all/lr=1.00e-05]
[1.23 epo]: RTE:[RTE_tra

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[2.05 epo]: RTE:[RTE_train/RTE_gold/loss=6.90e-01, RTE_valid/RTE_gold/accuracy=5.45e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.43e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.68e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.79e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.86e-01, train/all/lr=1.00e-05]
Saving model at iteration 2.05 with best (max) score RTE/RTE_valid/RTE_gold/accuracy=0.545
[2.15 epo]: RTE:[RTE_train/RTE_gold/loss=6.89e-01, RTE_valid/RTE_gold/accuracy=5.96e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.06e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.21e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.35e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.97e-01, RTE_valid/RTE_slice:more_people/accurac

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[3.08 epo]: RTE:[RTE_train/RTE_gold/loss=6.87e-01, RTE_valid/RTE_gold/accuracy=6.06e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.17e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.12e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.28e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.88e-01, RTE_valid/RTE_slice:more_people/accuracy=5.00e-01] model:[train/all/loss=3.53e-01, train/all/lr=1.00e-05]
[3.18 epo]: RTE:[RTE_train/RTE_gold/loss=6.88e-01, RTE_valid/RTE_gold/accuracy=6.61e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.20e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.29e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.35e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.55e-01, RTE_valid/RTE_slice:more_people/accuracy=5.00e-01] model:[train/all/loss=3.63e-01, train/all/lr=1.00e-05]
Saving model at iteratio

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[4.10 epo]: RTE:[RTE_train/RTE_gold/loss=6.86e-01, RTE_valid/RTE_gold/accuracy=6.50e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=8.91e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.21e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.60e-01, RTE_valid/RTE_slice:more_people/accuracy=5.00e-01] model:[train/all/loss=3.49e-01, train/all/lr=1.00e-05]
[4.21 epo]: RTE:[RTE_train/RTE_gold/loss=6.85e-01, RTE_valid/RTE_gold/accuracy=6.25e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.08e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.26e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.41e-01, RTE_valid/RTE_slice:more_people/accuracy=5.00e-01] model:[train/all/loss=3.52e-01, train/all/lr=1.00e-05]
[4.31 epo]: RTE:[RTE_tra

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[5.03 epo]: RTE:[RTE_train/RTE_gold/loss=6.85e-01, RTE_valid/RTE_gold/accuracy=6.43e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=7.58e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.95e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.27e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.55e-01, RTE_valid/RTE_slice:more_people/accuracy=5.00e-01] model:[train/all/loss=3.50e-01, train/all/lr=1.00e-05]
[5.13 epo]: RTE:[RTE_train/RTE_gold/loss=6.84e-01, RTE_valid/RTE_gold/accuracy=6.86e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.04e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.12e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.24e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.73e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.54e-01, train/all/lr=1.00e-05]
Saving model at iteratio

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[6.05 epo]: RTE:[RTE_train/RTE_gold/loss=6.83e-01, RTE_valid/RTE_gold/accuracy=6.50e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=8.31e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.03e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=3.04e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.43e-01, train/all/lr=1.00e-05]
[6.15 epo]: RTE:[RTE_train/RTE_gold/loss=6.82e-01, RTE_valid/RTE_gold/accuracy=6.64e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=7.54e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.78e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=9.72e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.67e-01, RTE_valid/RTE_slice:more_people/accuracy=5.00e-01] model:[train/all/loss=3.41e-01, train/all/lr=1.00e-05]
[6.26 epo]: RTE:[RTE_tra

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[7.08 epo]: RTE:[RTE_train/RTE_gold/loss=6.81e-01, RTE_valid/RTE_gold/accuracy=6.75e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=6.87e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.09e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.28e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.43e-01, train/all/lr=1.00e-05]
[7.18 epo]: RTE:[RTE_train/RTE_gold/loss=6.79e-01, RTE_valid/RTE_gold/accuracy=6.79e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=4.76e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.86e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=9.20e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.56e-01, RTE_valid/RTE_slice:more_people/accuracy=5.00e-01] model:[train/all/loss=3.28e-01, train/all/lr=1.00e-05]
[7.28 epo]: RTE:[RTE_tra

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[8.10 epo]: RTE:[RTE_train/RTE_gold/loss=6.76e-01, RTE_valid/RTE_gold/accuracy=6.46e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=6.18e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.12e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=7.98e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.28e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.28e-01, train/all/lr=1.00e-05]
[8.21 epo]: RTE:[RTE_train/RTE_gold/loss=6.77e-01, RTE_valid/RTE_gold/accuracy=6.79e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=7.05e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.12e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=8.08e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=2.06e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.25e-01, train/all/lr=1.00e-05]
[8.31 epo]: RTE:[RTE_tra

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[9.03 epo]: RTE:[RTE_train/RTE_gold/loss=6.76e-01, RTE_valid/RTE_gold/accuracy=6.90e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=6.76e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.02e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.46e-01, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=3.38e-01, train/all/lr=1.00e-05]
[9.13 epo]: RTE:[RTE_train/RTE_gold/loss=6.74e-01, RTE_valid/RTE_gold/accuracy=6.71e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=6.30e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.86e-01] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.34e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=7.68e-02, RTE_valid/RTE_slice:BASE/accuracy=0] model:[train/all/loss=3.22e-01, train/all/lr=1.00e-05]
[9.23 epo]: RTE:[RTE_tra

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[10.05 epo]: RTE:[RTE_train/RTE_gold/loss=6.73e-01, RTE_valid/RTE_gold/accuracy=6.75e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=5.51e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.21e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=9.43e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=9.42e-02, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=3.25e-01, train/all/lr=1.00e-05]
[10.15 epo]: RTE:[RTE_train/RTE_gold/loss=6.71e-01, RTE_valid/RTE_gold/accuracy=6.75e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=3.25e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.95e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=6.87e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.94e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.17e-01, train/all/lr=1.00e-05]
[10.26 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[11.08 epo]: RTE:[RTE_train/RTE_gold/loss=6.68e-01, RTE_valid/RTE_gold/accuracy=6.79e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=6.21e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=6.20e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=9.57e-02, RTE_valid/RTE_slice:more_people/accuracy=6.67e-01] model:[train/all/loss=3.07e-01, train/all/lr=1.00e-05]
[11.18 epo]: RTE:[RTE_train/RTE_gold/loss=6.66e-01, RTE_valid/RTE_gold/accuracy=6.68e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=2.40e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=5.35e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.73e-01, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=3.03e-01, train/all/lr=1.00e-05]
[11.28 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[12.10 epo]: RTE:[RTE_train/RTE_gold/loss=6.63e-01, RTE_valid/RTE_gold/accuracy=6.86e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=4.21e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=6.21e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=8.12e-02, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=3.07e-01, train/all/lr=1.00e-05]
[12.21 epo]: RTE:[RTE_train/RTE_gold/loss=6.64e-01, RTE_valid/RTE_gold/accuracy=6.71e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=4.75e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.78e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=7.26e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=9.14e-02, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=3.12e-01, train/all/lr=1.00e-05]
[12.31 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[13.03 epo]: RTE:[RTE_train/RTE_gold/loss=6.60e-01, RTE_valid/RTE_gold/accuracy=6.64e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=4.57e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.86e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=5.96e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.82e-01, RTE_valid/RTE_slice:more_people/accuracy=7.50e-01] model:[train/all/loss=3.05e-01, train/all/lr=1.00e-05]
[13.13 epo]: RTE:[RTE_train/RTE_gold/loss=6.62e-01, RTE_valid/RTE_gold/accuracy=6.50e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=6.31e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.78e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=8.87e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=6.26e-02, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=3.27e-01, train/all/lr=1.00e-05]
[13.23 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[14.05 epo]: RTE:[RTE_train/RTE_gold/loss=6.62e-01, RTE_valid/RTE_gold/accuracy=6.64e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=8.81e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.95e-01] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.47e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.19e-01, RTE_valid/RTE_slice:BASE/accuracy=0] model:[train/all/loss=3.39e-01, train/all/lr=1.00e-05]
[14.15 epo]: RTE:[RTE_train/RTE_gold/loss=6.55e-01, RTE_valid/RTE_gold/accuracy=6.53e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=4.26e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.86e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=5.51e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.44e-01, RTE_valid/RTE_slice:more_people/accuracy=7.50e-01] model:[train/all/loss=2.99e-01, train/all/lr=1.00e-05]
[14.26 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[15.08 epo]: RTE:[RTE_train/RTE_gold/loss=6.49e-01, RTE_valid/RTE_gold/accuracy=6.46e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=3.70e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.86e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=5.28e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.90e-01, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=2.94e-01, train/all/lr=1.00e-05]
[15.18 epo]: RTE:[RTE_train/RTE_gold/loss=6.52e-01, RTE_valid/RTE_gold/accuracy=6.57e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=7.27e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.95e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=6.68e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=8.06e-02, RTE_valid/RTE_slice:more_people/accuracy=6.67e-01] model:[train/all/loss=3.12e-01, train/all/lr=1.00e-05]
[15.28 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[16.10 epo]: RTE:[RTE_train/RTE_gold/loss=6.46e-01, RTE_valid/RTE_gold/accuracy=6.57e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=5.33e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=5.38e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.57e-01, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=3.01e-01, train/all/lr=1.00e-05]
[16.21 epo]: RTE:[RTE_train/RTE_gold/loss=6.45e-01, RTE_valid/RTE_gold/accuracy=6.57e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=3.15e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.12e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=7.33e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=6.73e-02, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] model:[train/all/loss=3.06e-01, train/all/lr=1.00e-05]
[16.31 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[17.03 epo]: RTE:[RTE_train/RTE_gold/loss=6.46e-01, RTE_valid/RTE_gold/accuracy=6.28e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.29e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.52e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=1.09e-01, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=1.43e-01, RTE_valid/RTE_slice:more_people/accuracy=4.17e-01] model:[train/all/loss=3.38e-01, train/all/lr=1.00e-05]
[17.13 epo]: RTE:[RTE_train/RTE_gold/loss=6.42e-01, RTE_valid/RTE_gold/accuracy=6.61e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=6.71e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=4.57e-02, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=7.67e-02, RTE_valid/RTE_slice:BASE/accuracy=0] model:[train/all/loss=3.07e-01, train/all/lr=1.00e-05]
[17.23 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[18.05 epo]: RTE:[RTE_train/RTE_gold/loss=6.36e-01, RTE_valid/RTE_gold/accuracy=6.53e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=8.87e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=9.58e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=4.27e-02, RTE_valid/RTE_slice:more_people/accuracy=6.67e-01] model:[train/all/loss=3.20e-01, train/all/lr=1.00e-05]
[18.15 epo]: RTE:[RTE_train/RTE_gold/loss=6.31e-01, RTE_valid/RTE_gold/accuracy=6.53e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=3.72e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=5.86e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=5.25e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=4.42e-02, RTE_valid/RTE_slice:more_people/accuracy=6.67e-01] model:[train/all/loss=2.87e-01, train/all/lr=1.00e-05]
[18.26 epo]: RTE:[RTE_

HBox(children=(IntProgress(value=0, max=312), HTML(value='')))

[19.08 epo]: RTE:[RTE_train/RTE_gold/loss=6.29e-01, RTE_valid/RTE_gold/accuracy=6.57e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=1.13e-01, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.03e-01] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=4.52e-02, RTE_valid/RTE_slice:more_people/accuracy=5.83e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=8.92e-02, RTE_valid/RTE_slice:BASE/accuracy=0] model:[train/all/loss=3.24e-01, train/all/lr=1.00e-05]
[19.18 epo]: RTE:[RTE_train/RTE_gold/loss=6.25e-01, RTE_valid/RTE_gold/accuracy=6.61e-01] RTE_slice:dash_semicolon:[RTE_train/RTE_slice:dash_semicolon/loss=5.19e-02, RTE_valid/RTE_slice:dash_semicolon/accuracy=6.29e-01] RTE_slice:BASE:[RTE_train/RTE_slice:BASE/loss=7.45e-02, RTE_valid/RTE_slice:BASE/accuracy=0] RTE_slice:more_people:[RTE_train/RTE_slice:more_people/loss=4.07e-02, RTE_valid/RTE_slice:more_people/accuracy=5.00e-01] model:[train/all/loss=3.00e-01, train/all/lr=1.00e-05]
[19.28 epo]: RTE:[RTE_

{'RTE/RTE_train/RTE_gold/accuracy': 0.9325301204819277,
 'RTE_slice:dash_semicolon/RTE_train/RTE_slice:dash_semicolon/accuracy': 0.9700897308075773,
 'RTE_slice:more_people/RTE_train/RTE_slice:more_people/accuracy': 0.78125,
 'RTE_slice:BASE/RTE_train/RTE_slice:BASE/accuracy': 0,
 'RTE/RTE_valid/RTE_gold/accuracy': 0.6967509025270758,
 'RTE_slice:dash_semicolon/RTE_valid/RTE_slice:dash_semicolon/accuracy': 0.6293103448275862,
 'RTE_slice:more_people/RTE_valid/RTE_slice:more_people/accuracy': 0.4166666666666667,
 'RTE_slice:BASE/RTE_valid/RTE_slice:BASE/accuracy': 0,
 'RTE/RTE_test/RTE_gold/accuracy': 0.0,
 'RTE_slice:dash_semicolon/RTE_test/RTE_slice:dash_semicolon/accuracy': 0.0,
 'RTE_slice:more_people/RTE_test/RTE_slice:more_people/accuracy': 0.0,
 'RTE_slice:BASE/RTE_test/RTE_slice:BASE/accuracy': 0}

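#### Did we improve?

Compare two views of validation accuracy. The first cell scores each slice's own head on the slice payload. The second cell uses `eval_payload` to score the main `RTE` head on the same slice label sets.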

In [22]:
%%time
slice_model.score(payloads_slice[1])

CPU times: user 1.49 s, sys: 208 ms, total: 1.7 s
Wall time: 1.64 s


{'RTE/RTE_valid/RTE_gold/accuracy': 0.6967509025270758,
 'RTE_slice:dash_semicolon/RTE_valid/RTE_slice:dash_semicolon/accuracy': 0.6293103448275862,
 'RTE_slice:more_people/RTE_valid/RTE_slice:more_people/accuracy': 0.4166666666666667,
 'RTE_slice:BASE/RTE_valid/RTE_slice:BASE/accuracy': 0.6895306859205776}

In [23]:
# Evaluate the main RTE head on each slice's validation examples
# (eval_payload is built in an earlier cell) and display the metrics dict.
eval_payload_scores = slice_model.score(eval_payload)
eval_payload_scores

{'RTE/RTE_valid/RTE_gold/accuracy': 0.6967509025270758,
 'RTE/RTE_valid/RTE_slice:dash_semicolon/accuracy': 0.5862068965517241,
 'RTE/RTE_valid/RTE_slice:more_people/accuracy': 0.75,
 'RTE/RTE_valid/RTE_slice:BASE/accuracy': 0.6967509025270758}