In [1]:
import os
import csv
import json
import torch
import numpy as np
from argparse import ArgumentParser
import mlearn.modeling.multitask as mtl
from mlearn.utils.metrics import Metrics
from mlearn.data.batching import TorchtextExtractor
from mlearn.data.clean import Cleaner, Preprocessors
from mlearn.utils.train import run_mtl_model, train_mtl_model
from torchtext.data import TabularDataset, Field, LabelField, BucketIterator

In [2]:
# Data inputs and outputs
main = 'davidson'
auxi = ['waseem']
datadir = '../data/json/'
results = '../results/'
save_model = '../results/'

# Cleaning and metrics
cleaners = ['lower', 'username', 'url']
metrics = ['f1-score', 'precision', 'recall', 'accuracy']
display_metric = stop_metric = 'f1-score'
dev_metrics = Metrics(metrics, display_metric, stop_metric)
metrics = Metrics(metrics, display_metric, stop_metric)

# Experiment
experiment = 'word'
tokenizer = 'bpe'
seed = 42

# Modelling
# All models
model = 'lstm'
patience = 1
encoding = 'embedding'
loss = 'nlll'
optimizer = 'adam'
shuffle = True
gpu = False
batch_first = True
clip = 1.0

# LSTM
layers = 1

# CNN
window_sizes = "2,3,4"
filters = 128

# Hyper Parameters
embedding = 64
hidden = ["64,64"]
shared = 64
epochs = 10
batch_size = 64
learning_rate = 0.02
dropout = 0.2
nonlinearity = 'tanh'

# MTL specific
batches_epoch = 50
loss_weights = [1.0, 0.5]

In [3]:
# Initialise random seeds
torch.random.manual_seed(seed)
np.random.seed(seed)

# Set up experiment and cleaner
c = Cleaner(processes = cleaners)
exp = Preprocessors('data/').select_experiment(experiment)
onehot = True if encoding == 'onehot' else False

# Load tokenizers
if tokenizer == 'spacy':
    selected_tok  = c.tokenize
elif tokenizer == 'bpe':
    selected_tok = c.bpe_tokenize
elif tokenizer == 'ekphrasis' and args.experiment == 'word':
    selected_tok = c.ekphrasis_tokenize
    annotate = {'elongated', 'emphasis'}
    flters = [f"<{filtr}>" for filtr in annotate]
    c._load_ekphrasis(annotate, flters)
elif tokenizer == 'ekphrasis' and args.experiment == 'liwc':
    ekphr = c.ekphrasis_tokenize
    annotate = {'elongated', 'emphasis'}
    flters = [f"<{filtr}>" for filtr in annotate]
    c._load_ekphrasis(annotate, flters)

    def liwc_toks(doc):
        tokens = ekphr(doc)
        tokens = exp(tokens)
        return tokens
    selected_tok = liwc_toks
tokenizer = selected_tok

In [4]:
# Set up fields
text = Field(tokenize = tokenizer, lower = True, batch_first = True)
label = LabelField()
fields = {'text': ('text', text), 'label': ('label', label)}  # Because we load from json we just need this.

# Load main task training data
if main == 'davidson':
    train, dev, test = TabularDataset.splits(datadir, train = 'davidson_binary_train.json',
                                             validation = 'davidson_binary_dev.json',
                                             test = 'davidson_binary_test.json',
                                             format = 'json', skip_header = True, fields = fields)
text.build_vocab(train)
label.build_vocab(train)
main = {'train': train, 'dev': dev, 'test': test, 'text': text, 'labels': label, 'name': main}

# Load aux tasks
auxillary = []
for aux in auxi:
    # Set up fields
    text = Field(tokenize = tokenizer, lower = True, batch_first = True)
    label = LabelField()
    fields = {'text': ('text', text), 'label': ('label', label)}  # Because we load from json we just need this.

    if aux == 'davidson':
        train, dev, test = TabularDataset.splits(datadir, train = 'davidson_binary_train.json',
                                                 validation = 'davidson_binary_dev.json',
                                                 test = 'davidson_binary_test.json',
                                                 format = 'json', skip_header = True, fields = fields)
    elif aux == 'hoover':
        train, dev, test = TabularDataset.splits(datadir, train = 'hoover_train.json',
                                                 validation = 'hoover_dev.json',
                                                 test = 'hoover_test.json',
                                                 format = 'json', skip_header = True, fields = fields)
    elif aux == 'oraby_factfeel':
        train, dev, test = TabularDataset.splits(datadir, train = 'oraby_fact_feel_train.json',
                                                 validation = 'oraby_fact_feel_dev.json',
                                                 test = 'oraby_fact_feel_test.json',
                                                 format = 'json', skip_header = True, fields = fields)
    elif aux == 'oraby_sarcasm':
        train, dev, test = TabularDataset.splits(datadir, train = 'oraby_sarcasm_train.json',
                                                 validation = 'oraby_sarcasm_dev.json',
                                                 test = 'oraby_sarcasm_test.json',
                                                 format = 'json', skip_header = True, fields = fields)
    elif aux == 'waseem':
        train, dev, test = TabularDataset.splits(datadir, train = 'waseem_train.json',
                                                 validation = 'waseem_dev.json',
                                                 test = 'waseem_test.json',
                                                 format = 'json', skip_header = True, fields = fields)
    elif aux == 'waseem_hovy':
        train, dev, test = TabularDataset.splits(datadir, train = 'waseem_hovy_train.json',
                                                 validation = 'waseem_hovy_dev.json',
                                                 test = 'waseem_hovy_test.json',
                                                 format = 'json', skip_header = True, fields = fields)
    text.build_vocab(train)
    label.build_vocab(train)
    auxillary.append({'train': train, 'dev': dev, 'test': test, 'text': text, 'labels': label, 'name': aux})
    if len(auxillary) == len(auxi): break

In [5]:
# Hyper parameters
dropout = dropout
nonlinearity = nonlinearity
learning_rate = learning_rate
epochs = epochs
batch_size = batch_size
batch_count = batches_epoch
loss_weights = loss_weights

params = dict(shared_dim = shared,
              batch_first = True,
              hidden_dims = [int(hidden) for hidden in hidden[0].split(',')],
              input_dims = [len(main['text'].vocab.stoi)] + [len(aux['text'].vocab.stoi) for aux in auxillary],
              output_dims = [len(main['labels'].vocab.stoi)] + [len(aux['labels'].vocab.stoi) for aux in auxillary],
              )
print(params)

{'shared_dim': 64, 'batch_first': True, 'hidden_dims': [64, 64], 'input_dims': [18176, 9826], 'output_dims': [2, 4]}


In [6]:
if not onehot:
    params.update({'embedding_dims': embedding})
if model == 'lstm':
    params.update({'no_layers': layers})
    model = mtl.OnehotLSTMClassifier if onehot else mtl.EmbeddingLSTMClassifier
else:
    params.update({'non-linearity': nonlinearity})

    if model == 'cnn':
        params.update({'window_sizes': [int(win) for win in window_sizes[0].split(',')],
                       'num_filters': filters})
        model = mtl.OnehotCNNClassifier if onehot else mtl.EmbeddingCNNClassifier
    elif model == 'mlp':
        model = mtl.OnehotMLPClassifier if onehot else mtl.EmbeddingMLPClassifier

model = model(**params)
print(model)

EmbeddingLSTMClassifier(
  (all_parameters): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 18176x64]
      (1): Parameter containing: [torch.FloatTensor of size 9826x64]
      (2): Parameter containing: [torch.FloatTensor of size 64x64]
      (3): Parameter containing: [torch.FloatTensor of size 64]
      (4): Parameter containing: [torch.FloatTensor of size 256x64]
      (5): Parameter containing: [torch.FloatTensor of size 256x64]
      (6): Parameter containing: [torch.FloatTensor of size 256]
      (7): Parameter containing: [torch.FloatTensor of size 256]
      (8): Parameter containing: [torch.FloatTensor of size 256x64]
      (9): Parameter containing: [torch.FloatTensor of size 256x64]
      (10): Parameter containing: [torch.FloatTensor of size 256]
      (11): Parameter containing: [torch.FloatTensor of size 256]
      (12): Parameter containing: [torch.FloatTensor of size 2x64]
      (13): Parameter containing: [torch.FloatTensor of size 2]
     

In [7]:
# Info about losses: https://bit.ly/3irxvYK
if loss == 'nlll':
    loss = torch.nn.NLLLoss()
elif loss == 'crossentropy':
    loss = torch.nn.CrossEntropyLoss()

# Set optimizer
if optimizer == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)
elif optimizer == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), learning_rate)
elif optimizer == 'asgd':
    optimizer = torch.optim.ASGD(model.parameters(), learning_rate)
elif optimizer == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), learning_rate)

In [8]:
# Batch data
batchers = []
test_batchers = []
if not onehot:
    train_buckets = BucketIterator(dataset = main['train'], batch_size = batch_size, sort_key = lambda x: len(x))
    main_train = TorchtextExtractor('text', 'label', main['name'], train_buckets)
    batchers.append(main_train)

    dev_buckets = BucketIterator(dataset = main['dev'], batch_size = 64, sort_key = lambda x: len(x))
    dev = TorchtextExtractor('text', 'label', main['name'], dev_buckets)

    test_buckets = BucketIterator(dataset = main['test'], batch_size = 64, sort_key = lambda x: len(x))
    test = TorchtextExtractor('text', 'label', main['name'], test_buckets)
    test_batchers.append(test)

    for aux in auxillary:
        train_buckets = BucketIterator(dataset = aux['train'], batch_size = batch_size, sort_key = lambda x: len(x))
        train = TorchtextExtractor('text', 'label', aux['name'], train_buckets)
        batchers.append(train)

        test_buckets = BucketIterator(dataset = aux['test'], batch_size = 64, sort_key = lambda x: len(x))
        test = TorchtextExtractor('text', 'label', aux['name'], test_buckets)
        test_batchers.append(test)
else:
    train_buckets = BucketIterator(dataset = main['train'], batch_size = batch_size, sort_key = lambda x: len(x))
    train = TorchtextExtractor('text', 'label', main['name'], train_buckets, len(main['text'].vocab.stoi))
    batchers.append(train)

    dev_buckets = BucketIterator(dataset = main['dev'], batch_size = 64, sort_key = lambda x: len(x))
    dev = TorchtextExtractor('text', 'label', main['name'], dev_buckets, len(main['text'].vocab.stoi))

    test_buckets = BucketIterator(dataset = main['test'], batch_size = 64, sort_key = lambda x: len(x))
    test = TorchtextExtractor('text', 'label', main['name'], test_buckets)
    test_batchers.append(test)

    for aux in auxillary:
        train_buckets = BucketIterator(dataset = aux['train'], batch_size = batch_size, sort_key = lambda x: len(x))
        train = TorchtextExtractor('text', 'label', aux['name'], train_buckets, len(aux['text'].vocab.stoi))
        batchers.append(train)

        test_buckets = BucketIterator(dataset = aux['test'], batch_size = 64, sort_key = lambda x: len(x))
        test = TorchtextExtractor('text', 'label', aux['name'], test_buckets)
        test_batchers.append(test)
next(iter(batchers[0]))

(tensor([[    2, 12800,    11,  ...,     1,     1,     1],
         [   96,   114,     4,  ...,     1,     1,     1],
         [  238,   411,     6,  ...,     1,     1,     1],
         ...,
         [   32,  7108,   127,  ...,     1,     1,     1],
         [   63,  1910,  8765,  ...,     1,     1,     1],
         [   12,     2,    32,  ...,     1,     1,     1]]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1,
         0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0,
         0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0]))

In [19]:
batch_writing = dict(mdl_hdr = ['model', 'scores'], batch_writer = csv.writer(open('test', 'w')),
                     model_hdr = ['name'], main_name = 'davidson', hyper_info = ['embedding_dim'],
                     metric_hdr = ['f1-score'], writer = csv.writer(open('test2', 'w')), data_name = 'davidson')
modelling_vars = dict(model = model, batchers = batchers, optimizer = optimizer, loss = loss,
                      metrics = metrics, batch_size = batch_size, epochs = epochs, clip = clip,
                      early_stopping = patience, save_model = results, dev = dev_buckets, dev_metrics = dev_metrics,
                      dev_task_id = 0, batches_per_epoch = batches_epoch, low = False, shuffle = False, 
                      dataset_weights = None, loss_weights = loss_weights, gpu = False, hyperopt = False)

In [10]:
train_mtl_model(**modelling_vars, **batch_writing)

Training model:   0%|          | 0/10 [00:00<?, ?it/s]
Batch:   0%|          | 0/50 [00:00<?, ?it/s][A
Batch:   0%|          | 0/50 [00:00<?, ?it/s, batch_loss=0.0325, epoch_loss=0.0325, task=1, task_score=0.0312][A
Batch:   2%|▏         | 1/50 [00:00<00:10,  4.63it/s, batch_loss=0.0325, epoch_loss=0.0325, task=1, task_score=0.0312][A
Batch:   2%|▏         | 1/50 [00:00<00:10,  4.63it/s, batch_loss=0.0650, epoch_loss=0.0487, task=0, task_score=0.1302][A
Batch:   4%|▍         | 2/50 [00:00<00:11,  4.29it/s, batch_loss=0.0650, epoch_loss=0.0487, task=0, task_score=0.1302][A
Batch:   4%|▍         | 2/50 [00:00<00:11,  4.29it/s, batch_loss=0.0327, epoch_loss=0.0434, task=1, task_score=0.2193][A
Batch:   6%|▌         | 3/50 [00:00<00:09,  4.76it/s, batch_loss=0.0327, epoch_loss=0.0434, task=1, task_score=0.2193][A
Batch:   6%|▌         | 3/50 [00:00<00:09,  4.76it/s, batch_loss=0.0325, epoch_loss=0.0407, task=1, task_score=0.1793][A
Batch:   8%|▊         | 4/50 [00:00<00:09,  5.09it

Batch:  66%|██████▌   | 33/50 [00:06<00:03,  4.95it/s, batch_loss=0.0653, epoch_loss=0.0542, task=0, task_score=0.1111][A
Batch:  66%|██████▌   | 33/50 [00:06<00:03,  4.95it/s, batch_loss=0.0324, epoch_loss=0.0536, task=1, task_score=0.1348][A
Batch:  68%|██████▊   | 34/50 [00:06<00:03,  5.26it/s, batch_loss=0.0324, epoch_loss=0.0536, task=1, task_score=0.1348][A
Batch:  68%|██████▊   | 34/50 [00:06<00:03,  5.26it/s, batch_loss=0.0322, epoch_loss=0.0529, task=1, task_score=0.2500][A
Batch:  70%|███████   | 35/50 [00:06<00:02,  5.67it/s, batch_loss=0.0322, epoch_loss=0.0529, task=1, task_score=0.2500][A
Batch:  70%|███████   | 35/50 [00:06<00:02,  5.67it/s, batch_loss=0.0649, epoch_loss=0.0533, task=0, task_score=0.2123][A
Batch:  72%|███████▏  | 36/50 [00:06<00:02,  5.88it/s, batch_loss=0.0649, epoch_loss=0.0533, task=0, task_score=0.2123][A
Batch:  72%|███████▏  | 36/50 [00:06<00:02,  5.88it/s, batch_loss=0.0323, epoch_loss=0.0527, task=1, task_score=0.2347][A
Batch:  74%|████

Batch:  26%|██▌       | 13/50 [00:02<00:08,  4.55it/s, batch_loss=0.0650, epoch_loss=0.0499, task=0, task_score=0.1429][A
Batch:  26%|██▌       | 13/50 [00:02<00:08,  4.55it/s, batch_loss=0.0320, epoch_loss=0.0486, task=1, task_score=0.2534][A
Batch:  28%|██▊       | 14/50 [00:02<00:07,  4.82it/s, batch_loss=0.0320, epoch_loss=0.0486, task=1, task_score=0.2534][A
Batch:  28%|██▊       | 14/50 [00:03<00:07,  4.82it/s, batch_loss=0.0650, epoch_loss=0.0497, task=0, task_score=0.1789][A
Batch:  30%|███       | 15/50 [00:03<00:07,  4.68it/s, batch_loss=0.0650, epoch_loss=0.0497, task=0, task_score=0.1789][A
Batch:  30%|███       | 15/50 [00:03<00:07,  4.68it/s, batch_loss=0.0650, epoch_loss=0.0506, task=0, task_score=0.1039][A
Batch:  32%|███▏      | 16/50 [00:03<00:07,  4.54it/s, batch_loss=0.0650, epoch_loss=0.0506, task=0, task_score=0.1039][A
Batch:  32%|███▏      | 16/50 [00:03<00:07,  4.54it/s, batch_loss=0.0324, epoch_loss=0.0496, task=1, task_score=0.1922][A
Batch:  34%|███▍

Batch:  92%|█████████▏| 46/50 [00:09<00:00,  4.71it/s, batch_loss=0.0313, epoch_loss=0.0483, task=1, task_score=0.3217][A
Batch:  92%|█████████▏| 46/50 [00:09<00:00,  4.71it/s, batch_loss=0.0316, epoch_loss=0.0479, task=1, task_score=0.4236][A
Batch:  94%|█████████▍| 47/50 [00:09<00:00,  4.97it/s, batch_loss=0.0316, epoch_loss=0.0479, task=1, task_score=0.4236][A
Batch:  94%|█████████▍| 47/50 [00:10<00:00,  4.97it/s, batch_loss=0.0317, epoch_loss=0.0476, task=1, task_score=0.3692][A
Batch:  96%|█████████▌| 48/50 [00:10<00:00,  5.07it/s, batch_loss=0.0317, epoch_loss=0.0476, task=1, task_score=0.3692][A
Batch:  96%|█████████▌| 48/50 [00:10<00:00,  5.07it/s, batch_loss=0.0650, epoch_loss=0.0479, task=0, task_score=0.1353][A
Batch:  98%|█████████▊| 49/50 [00:10<00:00,  4.44it/s, batch_loss=0.0650, epoch_loss=0.0479, task=0, task_score=0.1353][A
Batch:  98%|█████████▊| 49/50 [00:10<00:00,  4.44it/s, batch_loss=0.0650, epoch_loss=0.0483, task=0, task_score=0.1672][A
Batch: 100%|████

Batch:  52%|█████▏    | 26/50 [00:05<00:06,  3.61it/s, batch_loss=0.0650, epoch_loss=0.0483, task=0, task_score=0.1429][A
Batch:  52%|█████▏    | 26/50 [00:05<00:06,  3.61it/s, batch_loss=0.0320, epoch_loss=0.0477, task=1, task_score=0.3170][A
Batch:  54%|█████▍    | 27/50 [00:05<00:05,  3.89it/s, batch_loss=0.0320, epoch_loss=0.0477, task=1, task_score=0.3170][A
Batch:  54%|█████▍    | 27/50 [00:05<00:05,  3.89it/s, batch_loss=0.0319, epoch_loss=0.0471, task=1, task_score=0.3247][A
Batch:  56%|█████▌    | 28/50 [00:05<00:05,  4.05it/s, batch_loss=0.0319, epoch_loss=0.0471, task=1, task_score=0.3247][A
Batch:  56%|█████▌    | 28/50 [00:06<00:05,  4.05it/s, batch_loss=0.0311, epoch_loss=0.0466, task=1, task_score=0.4237][A
Batch:  58%|█████▊    | 29/50 [00:06<00:05,  3.69it/s, batch_loss=0.0311, epoch_loss=0.0466, task=1, task_score=0.4237][A
Batch:  58%|█████▊    | 29/50 [00:06<00:05,  3.69it/s, batch_loss=0.0650, epoch_loss=0.0472, task=0, task_score=0.1552][A
Batch:  60%|████

Batch:  12%|█▏        | 6/50 [00:01<00:10,  4.00it/s, batch_loss=0.0650, epoch_loss=0.0539, task=0, task_score=0.1552][A
Batch:  12%|█▏        | 6/50 [00:01<00:10,  4.00it/s, batch_loss=0.0314, epoch_loss=0.0507, task=1, task_score=0.4829][A
Batch:  14%|█▍        | 7/50 [00:01<00:09,  4.31it/s, batch_loss=0.0314, epoch_loss=0.0507, task=1, task_score=0.4829][A
Batch:  14%|█▍        | 7/50 [00:01<00:09,  4.31it/s, batch_loss=0.0650, epoch_loss=0.0525, task=0, task_score=0.2014][A
Batch:  16%|█▌        | 8/50 [00:01<00:09,  4.37it/s, batch_loss=0.0650, epoch_loss=0.0525, task=0, task_score=0.2014][A
Batch:  16%|█▌        | 8/50 [00:02<00:09,  4.37it/s, batch_loss=0.0312, epoch_loss=0.0501, task=1, task_score=0.3744][A
Batch:  18%|█▊        | 9/50 [00:02<00:09,  4.55it/s, batch_loss=0.0312, epoch_loss=0.0501, task=1, task_score=0.3744][A
Batch:  18%|█▊        | 9/50 [00:02<00:09,  4.55it/s, batch_loss=0.0650, epoch_loss=0.0516, task=0, task_score=0.1903][A
Batch:  20%|██        | 

Batch:  78%|███████▊  | 39/50 [00:09<00:02,  4.52it/s, batch_loss=0.0654, epoch_loss=0.0487, task=0, task_score=0.1467][A
Batch:  78%|███████▊  | 39/50 [00:09<00:02,  4.52it/s, batch_loss=0.0650, epoch_loss=0.0491, task=0, task_score=0.1672][A
Batch:  80%|████████  | 40/50 [00:09<00:02,  4.44it/s, batch_loss=0.0650, epoch_loss=0.0491, task=0, task_score=0.1672][A
Batch:  80%|████████  | 40/50 [00:09<00:02,  4.44it/s, batch_loss=0.0320, epoch_loss=0.0487, task=1, task_score=0.2675][A
Batch:  82%|████████▏ | 41/50 [00:09<00:01,  4.74it/s, batch_loss=0.0320, epoch_loss=0.0487, task=1, task_score=0.2675][A
Batch:  82%|████████▏ | 41/50 [00:09<00:01,  4.74it/s, batch_loss=0.0317, epoch_loss=0.0483, task=1, task_score=0.2662][A
Batch:  84%|████████▍ | 42/50 [00:09<00:01,  4.69it/s, batch_loss=0.0317, epoch_loss=0.0483, task=1, task_score=0.2662][A
Batch:  84%|████████▍ | 42/50 [00:10<00:01,  4.69it/s, batch_loss=0.0650, epoch_loss=0.0487, task=0, task_score=0.1302][A
Batch:  86%|████

Batch:  38%|███▊      | 19/50 [00:04<00:06,  4.81it/s, batch_loss=0.0650, epoch_loss=0.0546, task=0, task_score=0.1302][A
Batch:  38%|███▊      | 19/50 [00:04<00:06,  4.81it/s, batch_loss=0.0650, epoch_loss=0.0551, task=0, task_score=0.4336][A
Batch:  40%|████      | 20/50 [00:04<00:07,  4.19it/s, batch_loss=0.0650, epoch_loss=0.0551, task=0, task_score=0.4336][A
Batch:  40%|████      | 20/50 [00:04<00:07,  4.19it/s, batch_loss=0.0313, epoch_loss=0.0540, task=1, task_score=0.3578][A
Batch:  42%|████▏     | 21/50 [00:04<00:06,  4.65it/s, batch_loss=0.0313, epoch_loss=0.0540, task=1, task_score=0.3578][A
Batch:  42%|████▏     | 21/50 [00:04<00:06,  4.65it/s, batch_loss=0.0322, epoch_loss=0.0530, task=1, task_score=0.2409][A
Batch:  44%|████▍     | 22/50 [00:04<00:05,  4.96it/s, batch_loss=0.0322, epoch_loss=0.0530, task=1, task_score=0.2409][A
Batch:  44%|████▍     | 22/50 [00:05<00:05,  4.96it/s, batch_loss=0.0652, epoch_loss=0.0535, task=0, task_score=0.4336][A
Batch:  46%|████

Evaluating model:  46%|████▌     | 18/39 [00:00<00:00, 38.24it/s][A
Evaluating model:  59%|█████▉    | 23/39 [00:00<00:00, 41.07it/s][A
Evaluating model:  74%|███████▍  | 29/39 [00:00<00:00, 45.15it/s][A
Evaluating model:  87%|████████▋ | 34/39 [00:00<00:00, 43.51it/s][A
Training model:  50%|█████     | 5/10 [00:59<00:58, 11.71s/it, dev_loss=0.0653, dev_score=0.4551, loss=0.0323]
Batch:   0%|          | 0/50 [00:00<?, ?it/s][A
Batch:   0%|          | 0/50 [00:00<?, ?it/s, batch_loss=0.0647, epoch_loss=0.0647, task=0, task_score=0.4571][A
Batch:   2%|▏         | 1/50 [00:00<00:13,  3.72it/s, batch_loss=0.0647, epoch_loss=0.0647, task=0, task_score=0.4571][A
Batch:   2%|▏         | 1/50 [00:00<00:13,  3.72it/s, batch_loss=0.0646, epoch_loss=0.0647, task=0, task_score=0.4667][A
Batch:   4%|▍         | 2/50 [00:00<00:12,  3.82it/s, batch_loss=0.0646, epoch_loss=0.0647, task=0, task_score=0.4667][A
Batch:   4%|▍         | 2/50 [00:00<00:12,  3.82it/s, batch_loss=0.0320, epoch_loss=

Batch:  64%|██████▍   | 32/50 [00:07<00:03,  5.16it/s, batch_loss=0.0314, epoch_loss=0.0483, task=1, task_score=0.3423][A
Batch:  64%|██████▍   | 32/50 [00:07<00:03,  5.16it/s, batch_loss=0.0652, epoch_loss=0.0488, task=0, task_score=0.4602][A
Batch:  66%|██████▌   | 33/50 [00:07<00:03,  4.85it/s, batch_loss=0.0652, epoch_loss=0.0488, task=0, task_score=0.4602][A
Batch:  66%|██████▌   | 33/50 [00:08<00:03,  4.85it/s, batch_loss=0.0649, epoch_loss=0.0493, task=0, task_score=0.5495][A
Batch:  68%|██████▊   | 34/50 [00:08<00:03,  4.36it/s, batch_loss=0.0649, epoch_loss=0.0493, task=0, task_score=0.5495][A
Batch:  68%|██████▊   | 34/50 [00:08<00:03,  4.36it/s, batch_loss=0.0645, epoch_loss=0.0497, task=0, task_score=0.6522][A
Batch:  70%|███████   | 35/50 [00:08<00:03,  4.07it/s, batch_loss=0.0645, epoch_loss=0.0497, task=0, task_score=0.6522][A
Batch:  70%|███████   | 35/50 [00:08<00:03,  4.07it/s, batch_loss=0.0651, epoch_loss=0.0501, task=0, task_score=0.5211][A
Batch:  72%|████

Batch:  22%|██▏       | 11/50 [00:02<00:08,  4.86it/s, batch_loss=0.0648, epoch_loss=0.0457, task=0, task_score=0.6008][A
Batch:  24%|██▍       | 12/50 [00:02<00:08,  4.54it/s, batch_loss=0.0648, epoch_loss=0.0457, task=0, task_score=0.6008][A
Batch:  24%|██▍       | 12/50 [00:02<00:08,  4.54it/s, batch_loss=0.0323, epoch_loss=0.0447, task=1, task_score=0.2275][A
Batch:  26%|██▌       | 13/50 [00:02<00:07,  4.83it/s, batch_loss=0.0323, epoch_loss=0.0447, task=1, task_score=0.2275][A
Batch:  26%|██▌       | 13/50 [00:03<00:07,  4.83it/s, batch_loss=0.0321, epoch_loss=0.0438, task=1, task_score=0.1786][A
Batch:  28%|██▊       | 14/50 [00:03<00:07,  4.61it/s, batch_loss=0.0321, epoch_loss=0.0438, task=1, task_score=0.1786][A
Batch:  28%|██▊       | 14/50 [00:03<00:07,  4.61it/s, batch_loss=0.0648, epoch_loss=0.0452, task=0, task_score=0.5272][A
Batch:  30%|███       | 15/50 [00:03<00:07,  4.46it/s, batch_loss=0.0648, epoch_loss=0.0452, task=0, task_score=0.5272][A
Batch:  30%|███ 

Batch:  88%|████████▊ | 44/50 [00:09<00:01,  4.70it/s, batch_loss=0.0649, epoch_loss=0.0482, task=0, task_score=0.4545][A
Batch:  90%|█████████ | 45/50 [00:09<00:01,  4.45it/s, batch_loss=0.0649, epoch_loss=0.0482, task=0, task_score=0.4545][A
Batch:  90%|█████████ | 45/50 [00:10<00:01,  4.45it/s, batch_loss=0.0653, epoch_loss=0.0486, task=0, task_score=0.4922][A
Batch:  92%|█████████▏| 46/50 [00:10<00:00,  4.36it/s, batch_loss=0.0653, epoch_loss=0.0486, task=0, task_score=0.4922][A
Batch:  92%|█████████▏| 46/50 [00:10<00:00,  4.36it/s, batch_loss=0.0324, epoch_loss=0.0482, task=1, task_score=0.0811][A
Batch:  94%|█████████▍| 47/50 [00:10<00:00,  4.45it/s, batch_loss=0.0324, epoch_loss=0.0482, task=1, task_score=0.0811][A
Batch:  94%|█████████▍| 47/50 [00:10<00:00,  4.45it/s, batch_loss=0.0649, epoch_loss=0.0486, task=0, task_score=0.5748][A
Batch:  96%|█████████▌| 48/50 [00:10<00:00,  4.37it/s, batch_loss=0.0649, epoch_loss=0.0486, task=0, task_score=0.5748][A
Batch:  96%|████

Batch:  46%|████▌     | 23/50 [00:05<00:05,  5.25it/s, batch_loss=0.0320, epoch_loss=0.0430, task=1, task_score=0.2712][A
Batch:  46%|████▌     | 23/50 [00:05<00:05,  5.25it/s, batch_loss=0.0647, epoch_loss=0.0439, task=0, task_score=0.5393][A
Batch:  48%|████▊     | 24/50 [00:05<00:05,  4.61it/s, batch_loss=0.0647, epoch_loss=0.0439, task=0, task_score=0.5393][A
Batch:  48%|████▊     | 24/50 [00:05<00:05,  4.61it/s, batch_loss=0.0650, epoch_loss=0.0448, task=0, task_score=0.4684][A
Batch:  50%|█████     | 25/50 [00:05<00:05,  4.42it/s, batch_loss=0.0650, epoch_loss=0.0448, task=0, task_score=0.4684][A
Batch:  50%|█████     | 25/50 [00:05<00:05,  4.42it/s, batch_loss=0.0319, epoch_loss=0.0443, task=1, task_score=0.3082][A
Batch:  52%|█████▏    | 26/50 [00:05<00:05,  4.53it/s, batch_loss=0.0319, epoch_loss=0.0443, task=1, task_score=0.3082][A
Batch:  52%|█████▏    | 26/50 [00:05<00:05,  4.53it/s, batch_loss=0.0320, epoch_loss=0.0438, task=1, task_score=0.3183][A
Batch:  54%|████

Batch:   4%|▍         | 2/50 [00:00<00:09,  5.08it/s, batch_loss=0.0651, epoch_loss=0.0487, task=0, task_score=0.3883][A
Batch:   4%|▍         | 2/50 [00:00<00:09,  5.08it/s, batch_loss=0.0649, epoch_loss=0.0541, task=0, task_score=0.4789][A
Batch:   6%|▌         | 3/50 [00:00<00:10,  4.50it/s, batch_loss=0.0649, epoch_loss=0.0541, task=0, task_score=0.4789][A
Batch:   6%|▌         | 3/50 [00:00<00:10,  4.50it/s, batch_loss=0.0648, epoch_loss=0.0568, task=0, task_score=0.4818][A
Batch:   8%|▊         | 4/50 [00:00<00:10,  4.28it/s, batch_loss=0.0648, epoch_loss=0.0568, task=0, task_score=0.4818][A
Batch:   8%|▊         | 4/50 [00:01<00:10,  4.28it/s, batch_loss=0.0649, epoch_loss=0.0584, task=0, task_score=0.4910][A
Batch:  10%|█         | 5/50 [00:01<00:10,  4.17it/s, batch_loss=0.0649, epoch_loss=0.0584, task=0, task_score=0.4910][A
Batch:  10%|█         | 5/50 [00:01<00:10,  4.17it/s, batch_loss=0.0327, epoch_loss=0.0541, task=1, task_score=0.2286][A
Batch:  12%|█▏        | 

Batch:  70%|███████   | 35/50 [00:08<00:03,  3.91it/s, batch_loss=0.0650, epoch_loss=0.0557, task=0, task_score=0.2632][A
Batch:  70%|███████   | 35/50 [00:08<00:03,  3.91it/s, batch_loss=0.0323, epoch_loss=0.0550, task=1, task_score=0.2419][A
Batch:  72%|███████▏  | 36/50 [00:08<00:03,  4.08it/s, batch_loss=0.0323, epoch_loss=0.0550, task=1, task_score=0.2419][A
Batch:  72%|███████▏  | 36/50 [00:08<00:03,  4.08it/s, batch_loss=0.0324, epoch_loss=0.0544, task=1, task_score=0.2005][A
Batch:  74%|███████▍  | 37/50 [00:08<00:02,  4.36it/s, batch_loss=0.0324, epoch_loss=0.0544, task=1, task_score=0.2005][A
Batch:  74%|███████▍  | 37/50 [00:08<00:02,  4.36it/s, batch_loss=0.0650, epoch_loss=0.0547, task=0, task_score=0.1172][A
Batch:  76%|███████▌  | 38/50 [00:08<00:02,  4.03it/s, batch_loss=0.0650, epoch_loss=0.0547, task=0, task_score=0.1172][A
Batch:  76%|███████▌  | 38/50 [00:09<00:02,  4.03it/s, batch_loss=0.0326, epoch_loss=0.0541, task=1, task_score=0.2038][A
Batch:  78%|████

Early stopping: Terminate
Loading weights from epoch 7




In [11]:
print(dev_metrics.scores)

{'f1-score': [0.1656626310434239, 0.16474607084212997, 0.3628525784808495, 0.2057478353597029, 0.45509745543704383, 0.4970382299397844, 0.5228508795354752, 0.4760422276980435, 0.38235553670499867], 'precision': [0.5017581480139763, 0.5034403374896287, 0.49785796921926306, 0.5047894779895528, 0.5520911969897047, 0.5572674792282613, 0.5826826218550243, 0.550113937807486, 0.47901324015752633], 'recall': [0.5112941971323474, 0.5242604037906881, 0.498619957537155, 0.5108661070782805, 0.5292692792072291, 0.5325882110415087, 0.5479663420732219, 0.5280007574125222, 0.4876298388207742], 'accuracy': [0.18328623334679048, 0.1828825191764231, 0.37706903512313283, 0.21396851029471134, 0.5010092854259185, 0.5785224061364553, 0.6112232539362131, 0.540976988292289, 0.4113847396043601], 'loss': [0.06545948106336844, 0.06538571202567281, 0.06533544969308458, 0.065346574321311, 0.06526330964046474, 0.06538201244694598, 0.06518948852991836, 0.06531957334503802, 0.06541401819206795]}


In [20]:
run_mtl_model(train = True, **modelling_vars, **batch_writing)

Training model:   0%|          | 0/10 [00:00<?, ?it/s]
Batch:   0%|          | 0/50 [00:00<?, ?it/s][A
Batch:   0%|          | 0/50 [00:00<?, ?it/s, batch_loss=0.0646, epoch_loss=0.0646, task=0, task_score=0.5995][A
Batch:   2%|▏         | 1/50 [00:00<00:12,  3.78it/s, batch_loss=0.0646, epoch_loss=0.0646, task=0, task_score=0.5995][A
Batch:   2%|▏         | 1/50 [00:00<00:12,  3.78it/s, batch_loss=0.0647, epoch_loss=0.0647, task=0, task_score=0.5594][A
Batch:   4%|▍         | 2/50 [00:00<00:12,  3.81it/s, batch_loss=0.0647, epoch_loss=0.0647, task=0, task_score=0.5594][A
Batch:   4%|▍         | 2/50 [00:01<00:12,  3.81it/s, batch_loss=0.0647, epoch_loss=0.0647, task=0, task_score=0.5319][A
Batch:   6%|▌         | 3/50 [00:01<00:16,  2.83it/s, batch_loss=0.0647, epoch_loss=0.0647, task=0, task_score=0.5319][A
Batch:   6%|▌         | 3/50 [00:01<00:16,  2.83it/s, batch_loss=0.0322, epoch_loss=0.0566, task=1, task_score=0.3101][A
Batch:   8%|▊         | 4/50 [00:01<00:14,  3.14it

Batch:  66%|██████▌   | 33/50 [00:08<00:04,  4.16it/s, batch_loss=0.0650, epoch_loss=0.0490, task=0, task_score=0.5022][A
Batch:  66%|██████▌   | 33/50 [00:08<00:04,  4.16it/s, batch_loss=0.0655, epoch_loss=0.0495, task=0, task_score=0.3905][A
Batch:  68%|██████▊   | 34/50 [00:08<00:03,  4.17it/s, batch_loss=0.0655, epoch_loss=0.0495, task=0, task_score=0.3905][A
Batch:  68%|██████▊   | 34/50 [00:09<00:03,  4.17it/s, batch_loss=0.0324, epoch_loss=0.0490, task=1, task_score=0.2407][A
Batch:  70%|███████   | 35/50 [00:09<00:03,  4.43it/s, batch_loss=0.0324, epoch_loss=0.0490, task=1, task_score=0.2407][A
Batch:  70%|███████   | 35/50 [00:09<00:03,  4.43it/s, batch_loss=0.0648, epoch_loss=0.0494, task=0, task_score=0.5514][A
Batch:  72%|███████▏  | 36/50 [00:09<00:03,  4.43it/s, batch_loss=0.0648, epoch_loss=0.0494, task=0, task_score=0.5514][A
Batch:  72%|███████▏  | 36/50 [00:09<00:03,  4.43it/s, batch_loss=0.0651, epoch_loss=0.0498, task=0, task_score=0.4589][A
Batch:  74%|████

Batch:  20%|██        | 10/50 [00:02<00:09,  4.24it/s, batch_loss=0.0647, epoch_loss=0.0491, task=0, task_score=0.5995][A
Batch:  22%|██▏       | 11/50 [00:02<00:09,  4.13it/s, batch_loss=0.0647, epoch_loss=0.0491, task=0, task_score=0.5995][A
Batch:  22%|██▏       | 11/50 [00:02<00:09,  4.13it/s, batch_loss=0.0328, epoch_loss=0.0477, task=1, task_score=0.2253][A
Batch:  24%|██▍       | 12/50 [00:02<00:08,  4.29it/s, batch_loss=0.0328, epoch_loss=0.0477, task=1, task_score=0.2253][A
Batch:  24%|██▍       | 12/50 [00:03<00:08,  4.29it/s, batch_loss=0.0317, epoch_loss=0.0464, task=1, task_score=0.3261][A
Batch:  26%|██▌       | 13/50 [00:03<00:08,  4.31it/s, batch_loss=0.0317, epoch_loss=0.0464, task=1, task_score=0.3261][A
Batch:  26%|██▌       | 13/50 [00:03<00:08,  4.31it/s, batch_loss=0.0321, epoch_loss=0.0453, task=1, task_score=0.3000][A
Batch:  28%|██▊       | 14/50 [00:03<00:08,  4.44it/s, batch_loss=0.0321, epoch_loss=0.0453, task=1, task_score=0.3000][A
Batch:  28%|██▊ 

Batch:  86%|████████▌ | 43/50 [00:11<00:01,  3.63it/s, batch_loss=0.0652, epoch_loss=0.0492, task=0, task_score=0.4231][A
Batch:  88%|████████▊ | 44/50 [00:11<00:02,  2.82it/s, batch_loss=0.0652, epoch_loss=0.0492, task=0, task_score=0.4231][A
Batch:  88%|████████▊ | 44/50 [00:11<00:02,  2.82it/s, batch_loss=0.0323, epoch_loss=0.0488, task=1, task_score=0.2773][A
Batch:  90%|█████████ | 45/50 [00:11<00:01,  3.22it/s, batch_loss=0.0323, epoch_loss=0.0488, task=1, task_score=0.2773][A
Batch:  90%|█████████ | 45/50 [00:11<00:01,  3.22it/s, batch_loss=0.0322, epoch_loss=0.0484, task=1, task_score=0.2962][A
Batch:  92%|█████████▏| 46/50 [00:11<00:01,  3.68it/s, batch_loss=0.0322, epoch_loss=0.0484, task=1, task_score=0.2962][A
Batch:  92%|█████████▏| 46/50 [00:11<00:01,  3.68it/s, batch_loss=0.0650, epoch_loss=0.0488, task=0, task_score=0.5509][A
Batch:  94%|█████████▍| 47/50 [00:11<00:00,  3.76it/s, batch_loss=0.0650, epoch_loss=0.0488, task=0, task_score=0.5509][A
Batch:  94%|████

Batch:  42%|████▏     | 21/50 [00:04<00:07,  4.04it/s, batch_loss=0.0322, epoch_loss=0.0474, task=1, task_score=0.2718][A
Batch:  44%|████▍     | 22/50 [00:04<00:06,  4.36it/s, batch_loss=0.0322, epoch_loss=0.0474, task=1, task_score=0.2718][A
Batch:  44%|████▍     | 22/50 [00:04<00:06,  4.36it/s, batch_loss=0.0650, epoch_loss=0.0482, task=0, task_score=0.4488][A
Batch:  46%|████▌     | 23/50 [00:04<00:06,  4.36it/s, batch_loss=0.0650, epoch_loss=0.0482, task=0, task_score=0.4488][A
Batch:  46%|████▌     | 23/50 [00:05<00:06,  4.36it/s, batch_loss=0.0650, epoch_loss=0.0489, task=0, task_score=0.5039][A
Batch:  48%|████▊     | 24/50 [00:05<00:06,  4.07it/s, batch_loss=0.0650, epoch_loss=0.0489, task=0, task_score=0.5039][A
Batch:  48%|████▊     | 24/50 [00:05<00:06,  4.07it/s, batch_loss=0.0318, epoch_loss=0.0482, task=1, task_score=0.3451][A
Batch:  50%|█████     | 25/50 [00:05<00:05,  4.51it/s, batch_loss=0.0318, epoch_loss=0.0482, task=1, task_score=0.3451][A
Batch:  50%|████

Batch:   0%|          | 0/50 [00:00<?, ?it/s][A
Batch:   0%|          | 0/50 [00:00<?, ?it/s, batch_loss=0.0321, epoch_loss=0.0321, task=1, task_score=0.2528][A
Batch:   2%|▏         | 1/50 [00:00<00:07,  6.62it/s, batch_loss=0.0321, epoch_loss=0.0321, task=1, task_score=0.2528][A
Batch:   2%|▏         | 1/50 [00:00<00:07,  6.62it/s, batch_loss=0.0646, epoch_loss=0.0483, task=0, task_score=0.6382][A
Batch:   4%|▍         | 2/50 [00:00<00:08,  5.67it/s, batch_loss=0.0646, epoch_loss=0.0483, task=0, task_score=0.6382][A
Batch:   4%|▍         | 2/50 [00:00<00:08,  5.67it/s, batch_loss=0.0652, epoch_loss=0.0539, task=0, task_score=0.5656][A
Batch:   6%|▌         | 3/50 [00:00<00:10,  4.65it/s, batch_loss=0.0652, epoch_loss=0.0539, task=0, task_score=0.5656][A
Batch:   6%|▌         | 3/50 [00:00<00:10,  4.65it/s, batch_loss=0.0323, epoch_loss=0.0485, task=1, task_score=0.1465][A
Batch:   8%|▊         | 4/50 [00:00<00:09,  4.88it/s, batch_loss=0.0323, epoch_loss=0.0485, task=1, task_

Batch:  66%|██████▌   | 33/50 [00:07<00:04,  3.71it/s, batch_loss=0.0320, epoch_loss=0.0486, task=1, task_score=0.2397][A
Batch:  68%|██████▊   | 34/50 [00:07<00:03,  4.03it/s, batch_loss=0.0320, epoch_loss=0.0486, task=1, task_score=0.2397][A
Batch:  68%|██████▊   | 34/50 [00:07<00:03,  4.03it/s, batch_loss=0.0324, epoch_loss=0.0481, task=1, task_score=0.1934][A
Batch:  70%|███████   | 35/50 [00:07<00:03,  4.10it/s, batch_loss=0.0324, epoch_loss=0.0481, task=1, task_score=0.1934][A
Batch:  70%|███████   | 35/50 [00:08<00:03,  4.10it/s, batch_loss=0.0319, epoch_loss=0.0476, task=1, task_score=0.2475][A
Batch:  72%|███████▏  | 36/50 [00:08<00:03,  3.99it/s, batch_loss=0.0319, epoch_loss=0.0476, task=1, task_score=0.2475][A
Batch:  72%|███████▏  | 36/50 [00:08<00:03,  3.99it/s, batch_loss=0.0655, epoch_loss=0.0481, task=0, task_score=0.4253][A
Batch:  74%|███████▍  | 37/50 [00:08<00:03,  3.58it/s, batch_loss=0.0655, epoch_loss=0.0481, task=0, task_score=0.4253][A
Batch:  74%|████

Batch:  24%|██▍       | 12/50 [00:02<00:08,  4.27it/s, batch_loss=0.0321, epoch_loss=0.0402, task=1, task_score=0.2264][A
Batch:  24%|██▍       | 12/50 [00:03<00:08,  4.27it/s, batch_loss=0.0650, epoch_loss=0.0421, task=0, task_score=0.5105][A
Batch:  26%|██▌       | 13/50 [00:03<00:09,  3.92it/s, batch_loss=0.0650, epoch_loss=0.0421, task=0, task_score=0.5105][A
Batch:  26%|██▌       | 13/50 [00:03<00:09,  3.92it/s, batch_loss=0.0652, epoch_loss=0.0437, task=0, task_score=0.5278][A
Batch:  28%|██▊       | 14/50 [00:03<00:10,  3.50it/s, batch_loss=0.0652, epoch_loss=0.0437, task=0, task_score=0.5278][A
Batch:  28%|██▊       | 14/50 [00:03<00:10,  3.50it/s, batch_loss=0.0322, epoch_loss=0.0430, task=1, task_score=0.1842][A
Batch:  30%|███       | 15/50 [00:03<00:08,  3.90it/s, batch_loss=0.0322, epoch_loss=0.0430, task=1, task_score=0.1842][A
Batch:  30%|███       | 15/50 [00:03<00:08,  3.90it/s, batch_loss=0.0652, epoch_loss=0.0444, task=0, task_score=0.4667][A
Batch:  32%|███▏

Batch:  90%|█████████ | 45/50 [00:10<00:01,  4.97it/s, batch_loss=0.0324, epoch_loss=0.0482, task=1, task_score=0.2191][A
Batch:  90%|█████████ | 45/50 [00:10<00:01,  4.97it/s, batch_loss=0.0658, epoch_loss=0.0486, task=0, task_score=0.4781][A
Batch:  92%|█████████▏| 46/50 [00:10<00:00,  4.74it/s, batch_loss=0.0658, epoch_loss=0.0486, task=0, task_score=0.4781][A
Batch:  92%|█████████▏| 46/50 [00:10<00:00,  4.74it/s, batch_loss=0.0651, epoch_loss=0.0489, task=0, task_score=0.4583][A
Batch:  94%|█████████▍| 47/50 [00:10<00:00,  4.39it/s, batch_loss=0.0651, epoch_loss=0.0489, task=0, task_score=0.4583][A
Batch:  94%|█████████▍| 47/50 [00:11<00:00,  4.39it/s, batch_loss=0.0323, epoch_loss=0.0486, task=1, task_score=0.1999][A
Batch:  96%|█████████▌| 48/50 [00:11<00:00,  4.61it/s, batch_loss=0.0323, epoch_loss=0.0486, task=1, task_score=0.1999][A
Batch:  96%|█████████▌| 48/50 [00:11<00:00,  4.61it/s, batch_loss=0.0649, epoch_loss=0.0489, task=0, task_score=0.4393][A
Batch:  98%|████

Batch:  46%|████▌     | 23/50 [00:05<00:05,  4.92it/s, batch_loss=0.0319, epoch_loss=0.0459, task=1, task_score=0.3227][A
Batch:  48%|████▊     | 24/50 [00:05<00:05,  5.18it/s, batch_loss=0.0319, epoch_loss=0.0459, task=1, task_score=0.3227][A
Batch:  48%|████▊     | 24/50 [00:05<00:05,  5.18it/s, batch_loss=0.0322, epoch_loss=0.0454, task=1, task_score=0.3706][A
Batch:  50%|█████     | 25/50 [00:05<00:04,  5.35it/s, batch_loss=0.0322, epoch_loss=0.0454, task=1, task_score=0.3706][A
Batch:  50%|█████     | 25/50 [00:05<00:04,  5.35it/s, batch_loss=0.0321, epoch_loss=0.0448, task=1, task_score=0.2587][A
Batch:  52%|█████▏    | 26/50 [00:05<00:04,  5.29it/s, batch_loss=0.0321, epoch_loss=0.0448, task=1, task_score=0.2587][A
Batch:  52%|█████▏    | 26/50 [00:05<00:04,  5.29it/s, batch_loss=0.0321, epoch_loss=0.0444, task=1, task_score=0.1888][A
Batch:  54%|█████▍    | 27/50 [00:05<00:04,  5.37it/s, batch_loss=0.0321, epoch_loss=0.0444, task=1, task_score=0.1888][A
Batch:  54%|████

Early stopping: Terminate
Loading weights from epoch 4


