In [3]:
cfg = {
    'dev_train_len': 25*10**3,
    'dev_validation_len': 5*10**3,
    'learning_rate': 0.001,
    'epochs': 100,
    'embedding_dim': 50,
    'batch_size': 32,
    'dropout': 0,
    'optimizer': 'Adam',
    'num_layers': 2
}

cfg['hidden_dim'] = cfg['embedding_dim']

In [4]:
from nltk import word_tokenize
from gensim.models import KeyedVectors
import gensim.downloader as api
import unicodedata
import random
import torch
import re

def normalize_unicode(text: str) -> str:
    return unicodedata.normalize('NFD', text)

def tokenize_corpus(s: str) -> str:
    s = normalize_unicode(s)
    s = s.lower()
    s = re.sub(r"""[^a-zA-Z0-9?.,;'"]+""", " ", s)
    s = re.sub(r'(.)\1{3,}',r'\1', s)
    s = s.rstrip().strip()
    return s

def get_word_tokenized_corpus(s: str) -> list:
    return word_tokenize(s)

glove_dict = {
    '50': 'glove-wiki-gigaword-50',
    '100': 'glove-wiki-gigaword-100',
    '200': 'glove-wiki-gigaword-200'
}

glove_dict['50'] = api.load(glove_dict['50'])
# glove_dict['100'] = api.load(glove_dict['100'])
# glove_dict['200'] = api.load(glove_dict['200'])

def create_vocab(sentences: list, embedding_dim: int):
    glove = glove_dict[str(embedding_dim)]

    Emb = KeyedVectors(vector_size=glove.vector_size)
    vocab = []

    for sentence in sentences:
        vocab.extend(sentence)

    vocab = set(vocab)

    vectors, keys = [], []
    for token in vocab:
        if token in glove:
            vectors.append(torch.tensor(glove[token]))
            keys.append(token)

    keys.extend(['<unk>', '<pad>', '<sos>', '<eos>'])
    vectors.append(torch.mean(torch.stack(vectors), dim=0).numpy())
    vectors.append([0 for _ in range(embedding_dim)])
    vectors.append([random.random() for _ in range(embedding_dim)])
    vectors.append([random.random() for _ in range(embedding_dim)])
    Emb.add_vectors(keys, vectors)

    return Emb

def get_sentence_index(sentence: list, Emb: KeyedVectors):
    word_vec = []

    word_vec.append(Emb.key_to_index['<sos>'])
    for word in sentence:
        word_vec.append(get_vocab_index(word, Emb))
    word_vec.append(Emb.key_to_index['<eos>'])

    return torch.tensor(word_vec)

def get_vocab_index(word: str, Emb: KeyedVectors):
    if word in Emb:
        return Emb.key_to_index[word]
    return Emb.key_to_index['<unk>']


In [5]:
import numpy as np
import torch
import pandas as pd
import os

In [6]:
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" # ":4096:2"
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

cuda


In [7]:
DEV_TRAIN_LEN = cfg['dev_train_len']
DEV_VALIDATION_LEN = cfg['dev_validation_len']

DIR = '/scratch/shu7bh/RES/PRE'

In [8]:
import os
if not os.path.exists(DIR):
    os.makedirs(DIR)

In [9]:
print(DEV_TRAIN_LEN)
print(DEV_VALIDATION_LEN)

25000
5000


In [10]:
df = pd.read_csv('data/train.csv')
df = df.sample(frac=1, random_state=0).reset_index(drop=True)
df['Description'] = df['Description'].apply(tokenize_corpus)
df['Description'] = df['Description'].apply(get_word_tokenized_corpus)

In [11]:
df['Class Index']

0         4
1         4
2         4
3         1
4         3
         ..
119995    4
119996    3
119997    1
119998    4
119999    4
Name: Class Index, Length: 120000, dtype: int64

In [12]:
dev_train = df[:DEV_TRAIN_LEN]['Description']
dev_validation = df[DEV_TRAIN_LEN:DEV_TRAIN_LEN + DEV_VALIDATION_LEN]['Description']

In [13]:
dev_validation

25000    [linux, is, gaining, ground, with, companies, ...
25001    [google, ,, the, internet, search, engine, ,, ...
25002    [ap, defending, champion, tiger, woods, said, ...
25003    [advanced, micro, devices, will, add, quot, ;,...
25004    [reuters, about, 100, tibetan, exiles, chanted...
                               ...                        
29995    [they, were, 4, 12, last, year, ., they, have,...
29996    [luke, donald, and, paul, casey, are, unlikely...
29997    [yahoo, is, eager, to, improve, its, web, base...
29998    [ankara, strasbourg, ,, france, reuters, prime...
29999    [leipzig, game, convention, in, germany, ,, th...
Name: Description, Length: 5000, dtype: object

In [14]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [15]:
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class ELMO(nn.Module):
    def __init__(self, Emb, hidden_dim, dropout, num_layers):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(Emb.vectors), padding_idx=Emb.key_to_index['<pad>'])
        self.lstm = nn.LSTM(Emb.vectors.shape[1], hidden_dim, batch_first=True, bidirectional=True, num_layers=num_layers, dropout=dropout)

    def forward(self, X, X_lengths):
        X = self.embedding(X)
        X = pack_padded_sequence(X, X_lengths, batch_first=True, enforce_sorted=False)
        X, (h_n, c_n) = self.lstm(X, None)
        X, _ = pad_packed_sequence(X, batch_first=True)
        return X, h_n, c_n

In [16]:
from torch import nn

class LM(nn.Module):
    def __init__(self, Emb, hidden_dim, dropout, num_layers):
        super().__init__()

        self.elmo = ELMO(Emb, hidden_dim, dropout, num_layers)
        self.fc = nn.Linear(hidden_dim * 2, Emb.vectors.shape[0])

    def forward(self, X, X_lengths):
        X, _, _ = self.elmo(X, X_lengths)
        X = self.fc(X)
        return X

In [17]:
from torch.utils.data import Dataset
import torch

class SentencesDataset(Dataset):
    def __init__(self, sentences: list, Emb):
        super().__init__()

        self.data = []
        for sentence in sentences:
            self.data.append(get_sentence_index(sentence, Emb))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], torch.tensor(len(self.data[idx]))

In [18]:
class Collator:
    def __init__(self, Emb):
        self.pad_index = Emb.key_to_index['<pad>']

    def __call__(self, batch):
        X, X_lengths = zip(*batch)
        X = pad_sequence(X, batch_first=True, padding_value=self.pad_index)
        return X[:, :-1], X[:, 1:], torch.stack(X_lengths) - 1

In [19]:
import tqdm

def fit(model, dataloader, train, es, loss_fn, optimizer):
    model.train() if train else model.eval()
    epoch_loss = []

    pbar = tqdm.tqdm(dataloader)

    for X, Y, X_lengths in pbar:
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)

        Y_pred = model(X, X_lengths)
        Y_pred = Y_pred.reshape(-1, Y_pred.shape[2])

        Y = Y.reshape(-1)

        loss = loss_fn(Y_pred, Y)
        epoch_loss.append(loss.item())

        X.detach()
        Y_pred.detach()
        Y.detach()

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f'{"T" if train else "V"} Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Best Loss: {es.best_loss:7.4f}, Counter: {es.counter}')

    return np.mean(epoch_loss)

In [20]:
import numpy as np

class EarlyStopping:
    def __init__(self, patience:int = 3, delta:float = 0.001):
        self.patience = patience
        self.counter = 0
        self.best_loss:float = np.inf
        self.best_model_pth = 0
        self.delta = delta

    def __call__(self, loss, epoch: int):
        should_stop = False

        if loss >= self.best_loss - self.delta:
            self.counter += 1
            if self.counter > self.patience:
                should_stop = True
        else:
            self.best_loss = loss
            self.counter = 0
            self.best_model_pth = epoch
        return should_stop

In [21]:
def train(EPOCHS, model, training_dataloader, validation_dataloader, loss_fn, optimizer):
    es = EarlyStopping(patience=3, delta=0.001)

    for epoch in range(EPOCHS):
        print(f'\nEpoch {epoch+1}')

        epoch_loss = fit(model, training_dataloader, True, es, loss_fn, optimizer)

        with torch.no_grad():
            epoch_loss = fit(model, validation_dataloader, False, es, loss_fn, optimizer)
            if es(epoch_loss, epoch):
                break
            if es.counter == 0:
                torch.save(model.state_dict(), os.path.join(DIR, f'best_model.pth'))
                torch.save(model.elmo.state_dict(), os.path.join(DIR, f'best_model_elmo.pth'))

In [22]:
def run(config=None):
    BATCH_SIZE = cfg['batch_size']

    HIDDEN_DIM = config['hidden_dim']
    DROP_OUT = config['dropout']
    OPTIMIZER = config['optimizer']
    LEARNING_RATE = config['learning_rate']
    EPOCHS = config['epochs']
    EMBEDDITNG_DIM = config['embedding_dim']
    NUM_LAYERS = config['num_layers']

    Emb = create_vocab(df['Description'], EMBEDDITNG_DIM)

    dev_train_dataset = SentencesDataset(dev_train, Emb)
    dev_validation_dataset = SentencesDataset(dev_validation, Emb)

    collate_fn = Collator(Emb)

    training_dataloader = DataLoader(dev_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=4)
    validation_dataloader = DataLoader(dev_validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=4)

    torch.cuda.empty_cache()

    model = LM(Emb, HIDDEN_DIM, DROP_OUT, NUM_LAYERS).to(DEVICE)

    optimizer = getattr(torch.optim, OPTIMIZER)(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss(ignore_index=Emb.key_to_index['<pad>'])

    train(EPOCHS, model, training_dataloader, validation_dataloader, loss_fn, optimizer)

run(cfg)


Epoch 1


T Loss:  5.6669, Avg Loss:  6.7665, Best Loss:     inf, Counter: 0: 100%|██████████| 782/782 [00:25<00:00, 30.71it/s]
V Loss:  5.8964, Avg Loss:  5.5419, Best Loss:     inf, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 70.91it/s]



Epoch 2


T Loss:  3.2303, Avg Loss:  4.5110, Best Loss:  5.5419, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.70it/s]
V Loss:  3.8009, Avg Loss:  3.7043, Best Loss:  5.5419, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 68.08it/s]



Epoch 3


T Loss:  2.6638, Avg Loss:  3.0581, Best Loss:  3.7043, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.81it/s]
V Loss:  2.3682, Avg Loss:  2.6370, Best Loss:  3.7043, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 66.78it/s]



Epoch 4


T Loss:  2.0449, Avg Loss:  2.2056, Best Loss:  2.6370, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 32.10it/s]
V Loss:  2.0542, Avg Loss:  1.9963, Best Loss:  2.6370, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 66.07it/s]



Epoch 5


T Loss:  1.6944, Avg Loss:  1.6509, Best Loss:  1.9963, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 32.14it/s]
V Loss:  1.4452, Avg Loss:  1.5670, Best Loss:  1.9963, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.92it/s]



Epoch 6


T Loss:  1.2775, Avg Loss:  1.2626, Best Loss:  1.5670, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.66it/s]
V Loss:  1.4089, Avg Loss:  1.2646, Best Loss:  1.5670, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.88it/s]



Epoch 7


T Loss:  0.8300, Avg Loss:  0.9761, Best Loss:  1.2646, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 32.04it/s]
V Loss:  1.1094, Avg Loss:  1.0395, Best Loss:  1.2646, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.41it/s]



Epoch 8


T Loss:  0.8049, Avg Loss:  0.7577, Best Loss:  1.0395, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.67it/s]
V Loss:  1.2037, Avg Loss:  0.8733, Best Loss:  1.0395, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 65.33it/s]



Epoch 9


T Loss:  0.5960, Avg Loss:  0.5874, Best Loss:  0.8733, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.57it/s]
V Loss:  0.5510, Avg Loss:  0.7411, Best Loss:  0.8733, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 68.10it/s]



Epoch 10


T Loss:  0.4723, Avg Loss:  0.4525, Best Loss:  0.7411, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.79it/s]
V Loss:  0.4759, Avg Loss:  0.6429, Best Loss:  0.7411, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 69.01it/s]



Epoch 11


T Loss:  0.2181, Avg Loss:  0.3441, Best Loss:  0.6429, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.79it/s]
V Loss:  0.5793, Avg Loss:  0.5673, Best Loss:  0.6429, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 69.40it/s]



Epoch 12


T Loss:  0.2058, Avg Loss:  0.2600, Best Loss:  0.5673, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 32.13it/s]
V Loss:  0.5778, Avg Loss:  0.5146, Best Loss:  0.5673, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.53it/s]



Epoch 13


T Loss:  0.1582, Avg Loss:  0.1950, Best Loss:  0.5146, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.64it/s]
V Loss:  0.3460, Avg Loss:  0.4760, Best Loss:  0.5146, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.63it/s]



Epoch 14


T Loss:  0.1365, Avg Loss:  0.1473, Best Loss:  0.4760, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.91it/s]
V Loss:  0.4337, Avg Loss:  0.4505, Best Loss:  0.4760, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 70.52it/s]



Epoch 15


T Loss:  0.0784, Avg Loss:  0.1106, Best Loss:  0.4505, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.74it/s]
V Loss:  0.4139, Avg Loss:  0.4324, Best Loss:  0.4505, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 68.61it/s]



Epoch 16


T Loss:  0.0591, Avg Loss:  0.0834, Best Loss:  0.4324, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.41it/s]
V Loss:  0.3869, Avg Loss:  0.4235, Best Loss:  0.4324, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 69.61it/s]



Epoch 17


T Loss:  0.0706, Avg Loss:  0.0642, Best Loss:  0.4235, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.58it/s]
V Loss:  0.6491, Avg Loss:  0.4135, Best Loss:  0.4235, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 68.42it/s]



Epoch 18


T Loss:  0.0470, Avg Loss:  0.0503, Best Loss:  0.4135, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.87it/s]
V Loss:  0.4522, Avg Loss:  0.4021, Best Loss:  0.4135, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 66.91it/s]



Epoch 19


T Loss:  0.0210, Avg Loss:  0.0405, Best Loss:  0.4021, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.49it/s]
V Loss:  0.3338, Avg Loss:  0.3978, Best Loss:  0.4021, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 70.56it/s]



Epoch 20


T Loss:  0.0405, Avg Loss:  0.0314, Best Loss:  0.3978, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.85it/s]
V Loss:  0.6362, Avg Loss:  0.3940, Best Loss:  0.3978, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.06it/s]



Epoch 21


T Loss:  0.0182, Avg Loss:  0.0256, Best Loss:  0.3940, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.36it/s]
V Loss:  0.4339, Avg Loss:  0.3928, Best Loss:  0.3940, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.68it/s]



Epoch 22


T Loss:  0.0179, Avg Loss:  0.0212, Best Loss:  0.3928, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.65it/s]
V Loss:  0.5643, Avg Loss:  0.3883, Best Loss:  0.3928, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.98it/s]



Epoch 23


T Loss:  0.0249, Avg Loss:  0.0181, Best Loss:  0.3883, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.45it/s]
V Loss:  0.0299, Avg Loss:  0.3859, Best Loss:  0.3883, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 70.13it/s]



Epoch 24


T Loss:  0.0157, Avg Loss:  0.0154, Best Loss:  0.3859, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.76it/s]
V Loss:  0.2539, Avg Loss:  0.3907, Best Loss:  0.3859, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 67.24it/s]



Epoch 25


T Loss:  0.0216, Avg Loss:  0.0118, Best Loss:  0.3859, Counter: 1: 100%|██████████| 782/782 [00:24<00:00, 31.82it/s]
V Loss:  0.2075, Avg Loss:  0.3846, Best Loss:  0.3859, Counter: 1: 100%|██████████| 157/157 [00:02<00:00, 68.30it/s]



Epoch 26


T Loss:  0.0070, Avg Loss:  0.0106, Best Loss:  0.3846, Counter: 0: 100%|██████████| 782/782 [00:24<00:00, 31.59it/s]
V Loss:  0.3104, Avg Loss:  0.3862, Best Loss:  0.3846, Counter: 0: 100%|██████████| 157/157 [00:02<00:00, 68.35it/s]



Epoch 27


T Loss:  0.0090, Avg Loss:  0.0085, Best Loss:  0.3846, Counter: 1: 100%|██████████| 782/782 [00:24<00:00, 31.85it/s]
V Loss:  0.5817, Avg Loss:  0.3875, Best Loss:  0.3846, Counter: 1: 100%|██████████| 157/157 [00:02<00:00, 67.62it/s]



Epoch 28


T Loss:  0.0071, Avg Loss:  0.0090, Best Loss:  0.3846, Counter: 2: 100%|██████████| 782/782 [00:24<00:00, 31.91it/s]
V Loss:  0.3678, Avg Loss:  0.3871, Best Loss:  0.3846, Counter: 2: 100%|██████████| 157/157 [00:02<00:00, 68.34it/s]



Epoch 29


T Loss:  0.0069, Avg Loss:  0.0060, Best Loss:  0.3846, Counter: 3: 100%|██████████| 782/782 [00:24<00:00, 31.88it/s]
V Loss:  0.3085, Avg Loss:  0.3859, Best Loss:  0.3846, Counter: 3: 100%|██████████| 157/157 [00:02<00:00, 68.14it/s]


In [22]:
import tqdm

def run_epoch(model, dataloader, loss_fn):
    epoch_loss = []

    pbar = tqdm.tqdm(dataloader)

    for X, Y, X_lengths in pbar:
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)

        Y_pred = model(X, X_lengths)
        Y_pred = Y_pred.reshape(-1, Y_pred.shape[2])

        Y = Y.reshape(-1)

        loss = loss_fn(Y_pred, Y)
        epoch_loss.append(loss.item())

        pbar.set_description(f'Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}')

    return np.mean(epoch_loss)

In [23]:
def validate(elmo, validation_dataloader, loss_fn):
    with torch.no_grad():
        elmo.eval()
        epoch_loss = run_epoch(elmo, validation_dataloader, loss_fn)
        print(f'Validation Loss: {epoch_loss:7.4f}')

In [24]:
Emb = create_vocab(df['Description'], cfg['embedding_dim'])

dev_validation_dataset = SentencesDataset(dev_validation, Emb)

collate_fn = Collator(Emb)

validation_dataloader = DataLoader(dev_validation_dataset, batch_size=cfg['batch_size'], shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=4)
model = LM(Emb, cfg['hidden_dim'], cfg['dropout'], cfg['num_layers']).to(DEVICE)

model.load_state_dict(torch.load(os.path.join(DIR, 'best_model.pth')))

validate(model, validation_dataloader, nn.CrossEntropyLoss(ignore_index=Emb.key_to_index['<pad>']))

Loss:  0.5460, Avg Loss:  0.3803: 100%|██████████| 313/313 [00:02<00:00, 108.19it/s]

Validation Loss:  0.3803





In [48]:
elmo = ELMO(Emb, cfg['hidden_dim'], cfg['dropout'], cfg['num_layers']).to(DEVICE)

# load only elmo from the best model
elmo.load_state_dict(torch.load(os.path.join(DIR, 'best_model_elmo.pth')))

<All keys matched successfully>

In [49]:
from torch.utils.data import DataLoader, Dataset

class DownStreamDataset(Dataset):
    def __init__(self, df: pd.DataFrame, Emb: KeyedVectors):
        self.descriptions = df['Description']
        self.descriptions = [get_sentence_index(description, Emb) for description in self.descriptions]
        # self.df.loc[: 'Description'] = df['Description'].apply(get_sentence_index, Emb=Emb)
        self.class_index = list(c - 1 for c in df['Class Index'])
        self.Emb = Emb

    def __len__(self):
        return len(self.class_index)

    def __getitem__(self, idx):
        l = torch.tensor(len(self.descriptions[idx]))
        return self.descriptions[idx], l, torch.tensor(self.class_index[idx])

In [50]:
class DownStreamCollator:
    def __init__(self, Emb):
        self.pad_index = Emb.key_to_index['<pad>']

    def __call__(self, batch):
        X, X_lengths, Y = zip(*batch)
        X = pad_sequence(X, batch_first=True, padding_value=self.pad_index)
        return X, torch.stack(X_lengths), torch.stack(Y)

In [51]:
downstream_train = df[DEV_TRAIN_LEN + DEV_VALIDATION_LEN:3 * DEV_TRAIN_LEN]
downstream_validation = df[:3 * DEV_VALIDATION_LEN]

downstream_train_dataset = DownStreamDataset(downstream_train, Emb)
downstream_validation_dataset = DownStreamDataset(downstream_validation, Emb)

downstream_collate_fn = DownStreamCollator(Emb)

downstream_training_dataloader = DataLoader(downstream_train_dataset, batch_size=64, shuffle=True, collate_fn=downstream_collate_fn, pin_memory=True, num_workers=4)
downstream_validation_dataloader = DataLoader(downstream_validation_dataset, batch_size=64, shuffle=True, collate_fn=downstream_collate_fn, pin_memory=True, num_workers=4)

In [73]:
class DownStream(nn.Module):
    def __init__(self, elmo, dropout, delta = None):
        super().__init__()
        self.elmo = elmo
        # freeze the ELMO parameters
        for param in self.elmo.parameters():
            param.requires_grad = False

        self.hidden_dim = self.elmo.hidden_dim
        self.num_layers = self.elmo.num_layers
        self.dropout = dropout

        if delta is None:
            self.delta = nn.Parameter(torch.randn(1, self.num_layers + 1))
        else:
            self.delta = delta.to(DEVICE)

        self.linear = nn.Linear(self.hidden_dim * 2, 4)

    def forward(self, X, X_lengths):
        _, Y1, Y2 = self.elmo(X, X_lengths)

        Y = torch.mean(torch.stack([Y1, Y2]), dim=0)
        Y = Y.permute(1, 0, 2).reshape(Y.shape[1], self.num_layers, self.hidden_dim * 2)

        X = self.elmo.embedding(X)
        X = torch.mean(X, dim=1, dtype=torch.float32)
        X = torch.cat([X, X], dim=1)
        X = X.unsqueeze(1)

        Y = torch.cat([X, Y], dim=1)

        Y = (self.delta / torch.sum(self.delta) ) @ Y

        Y = torch.sum(Y, dim=1)

        Y = self.linear(Y)

        return Y

In [87]:
def downstream_fit(model, dataloader, train, loss_fn, optimizer, es):
    model.train() if train else model.eval()
    epoch_loss = []

    pbar = tqdm.tqdm(dataloader)

    for X, X_lengths, Y in pbar:
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)

        Y_pred = model(X, X_lengths)
        Y_pred = Y_pred.reshape(-1, Y_pred.shape[-1])

        loss = loss_fn(Y_pred, Y)
        epoch_loss.append(loss.item())

        X.detach()
        Y_pred.detach()
        Y.detach()

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f'{"T" if train else "V"} Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Best Loss: {es.best_loss:7.4f}, Counter: {es.counter}')

    return np.mean(epoch_loss)

In [88]:
dloss_fn = nn.CrossEntropyLoss()

def downstream_train_fn(EPOCHS, model, training_dataloader, validation_dataloader, loss_fn, optimizer):
    es = EarlyStopping(patience=3, delta=0.001)
    for epoch in range(EPOCHS):
        print(f'\nEpoch {epoch+1}')

        epoch_loss = downstream_fit(model, training_dataloader, True, loss_fn, optimizer, es)
        print(f'Train Loss: {epoch_loss:7.4f}')
        # wandb.log({'downstream_train_loss': epoch_loss})

        with torch.no_grad():
            epoch_loss = downstream_fit(model, validation_dataloader, False, loss_fn, optimizer, es)
            print(f'Validation Loss: {epoch_loss:7.4f}')
            if es(epoch_loss, epoch):
                break
            if es.counter == 0:
                torch.save(model.state_dict(), os.path.join(DIR, f'downstream_best_model.pth'))
            # wandb.log({'downstream_validation_loss': epoch_loss})

In [76]:
tdf = pd.read_csv('data/test.csv')
tdf = tdf.sample(frac=1, random_state=0).reset_index(drop=True)
tdf['Description'] = tdf['Description'].apply(tokenize_corpus)
tdf['Description'] = tdf['Description'].apply(get_word_tokenized_corpus)

In [89]:
downstream_train_dataset = DownStreamDataset(tdf, Emb)
downstream_test_dataloader = DataLoader(downstream_train_dataset, batch_size=64, shuffle=True, collate_fn=downstream_collate_fn, pin_memory=True, num_workers=4)

In [90]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def downstream_test_metrics(model, dataloader, loss_fn):
    epoch_loss = []

    pbar = tqdm.tqdm(dataloader)

    Y_pred_all = []
    Y_all = []
    Y_len = 0

    for X, X_lengths, Y in pbar:
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)

        Y_pred = model(X, X_lengths)
        Y_pred = Y_pred.reshape(-1, Y_pred.shape[-1])

        Y_all.extend(Y.tolist())
        Y_pred_all.extend(torch.argmax(Y_pred, dim=1).tolist())
        Y_len += len(Y)

        loss = loss_fn(Y_pred, Y)
        epoch_loss.append(loss.item())

        X.detach()
        Y_pred.detach()
        Y.detach()

        pbar.set_description(f'{"T" if train else "V"} Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}')

    accuracy = accuracy_score(Y_all, Y_pred_all)
    print(f'Accuracy: {accuracy:7.4f}')
    print(confusion_matrix(Y_all, Y_pred_all))
    print(classification_report(Y_all, Y_pred_all, digits=4))

    return np.mean(epoch_loss)

In [93]:
delta = [None, [3, 0, 0], [1, 2, 3], [0, 0, 3], [1, 1, 1]] 
for i in range(1, len(delta)):
    delta[i] = torch.tensor(delta[i], dtype=torch.float32).reshape(1, -1).to(DEVICE)
    print(delta[i])
    print(delta[i].shape)

for d in delta:
    dmodel = DownStream(elmo, cfg['dropout'], d).to(DEVICE)
    doptimizer = getattr(torch.optim, cfg['optimizer'])(dmodel.parameters(), lr=cfg['learning_rate'])
    print(f'\n\nDelta: {d}')
    downstream_train_fn(cfg['epochs'], dmodel, downstream_training_dataloader, downstream_validation_dataloader, dloss_fn, doptimizer)
    dmodel.load_state_dict(torch.load(os.path.join(DIR, 'downstream_best_model.pth')))
    with torch.no_grad():
        loss = downstream_test_metrics(dmodel, downstream_test_dataloader, dloss_fn)
        # calculate accuracy and f1 score
        
        print(f'Test loss: {loss}')

tensor([[3., 0., 0.]], device='cuda:0')
torch.Size([1, 3])
tensor([[1., 2., 3.]], device='cuda:0')
torch.Size([1, 3])
tensor([[0., 0., 3.]], device='cuda:0')
torch.Size([1, 3])
tensor([[1., 1., 1.]], device='cuda:0')
torch.Size([1, 3])


Delta: None

Epoch 1


  0%|          | 0/704 [00:00<?, ?it/s]

T Loss:  0.8251, Avg Loss:  1.1084, Best Loss:     inf, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 104.15it/s]


Train Loss:  1.1084


V Loss:  0.8209, Avg Loss:  0.7772, Best Loss:     inf, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.51it/s]


Validation Loss:  0.7772

Epoch 2


T Loss:  0.5477, Avg Loss:  0.6912, Best Loss:  0.7772, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 102.98it/s]


Train Loss:  0.6912


V Loss:  0.8941, Avg Loss:  0.6576, Best Loss:  0.7772, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.64it/s]


Validation Loss:  0.6576

Epoch 3


T Loss:  0.9262, Avg Loss:  0.6191, Best Loss:  0.6576, Counter: 0: 100%|██████████| 704/704 [00:07<00:00, 100.51it/s]


Train Loss:  0.6191


V Loss:  0.5515, Avg Loss:  0.6392, Best Loss:  0.6576, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 109.10it/s]


Validation Loss:  0.6392

Epoch 4


T Loss:  0.2271, Avg Loss:  0.5899, Best Loss:  0.6392, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 102.75it/s]


Train Loss:  0.5899


V Loss:  0.8867, Avg Loss:  0.5952, Best Loss:  0.6392, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 110.53it/s]


Validation Loss:  0.5952

Epoch 5


T Loss:  0.5369, Avg Loss:  0.5761, Best Loss:  0.5952, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 105.17it/s]


Train Loss:  0.5761


V Loss:  0.6852, Avg Loss:  0.5808, Best Loss:  0.5952, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 113.43it/s]


Validation Loss:  0.5808

Epoch 6


T Loss:  0.8671, Avg Loss:  0.5610, Best Loss:  0.5808, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 104.23it/s]


Train Loss:  0.5610


V Loss:  0.6705, Avg Loss:  0.5548, Best Loss:  0.5808, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 105.21it/s]


Validation Loss:  0.5548

Epoch 7


T Loss:  0.4320, Avg Loss:  0.5575, Best Loss:  0.5548, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 100.60it/s]


Train Loss:  0.5575


V Loss:  0.4963, Avg Loss:  0.5712, Best Loss:  0.5548, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.83it/s]


Validation Loss:  0.5712

Epoch 8


T Loss:  1.1587, Avg Loss:  0.5525, Best Loss:  0.5548, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 102.39it/s]


Train Loss:  0.5525


V Loss:  0.5061, Avg Loss:  0.5670, Best Loss:  0.5548, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 112.12it/s]


Validation Loss:  0.5670

Epoch 9


T Loss:  0.6398, Avg Loss:  0.5447, Best Loss:  0.5548, Counter: 2: 100%|██████████| 704/704 [00:07<00:00, 100.36it/s]


Train Loss:  0.5447


V Loss:  0.6152, Avg Loss:  0.5374, Best Loss:  0.5548, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 107.21it/s]


Validation Loss:  0.5374

Epoch 10


T Loss:  0.5517, Avg Loss:  0.5435, Best Loss:  0.5374, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 103.91it/s]


Train Loss:  0.5435


V Loss:  0.5472, Avg Loss:  0.5322, Best Loss:  0.5374, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 113.56it/s]


Validation Loss:  0.5322

Epoch 11


T Loss:  0.4378, Avg Loss:  0.5396, Best Loss:  0.5322, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 104.99it/s]


Train Loss:  0.5396


V Loss:  0.4075, Avg Loss:  0.5280, Best Loss:  0.5322, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 115.14it/s]


Validation Loss:  0.5280

Epoch 12


T Loss:  0.2697, Avg Loss:  0.5370, Best Loss:  0.5280, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.14it/s]


Train Loss:  0.5370


V Loss:  0.3712, Avg Loss:  0.5775, Best Loss:  0.5280, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 110.46it/s]


Validation Loss:  0.5775

Epoch 13


T Loss:  0.5610, Avg Loss:  0.5337, Best Loss:  0.5280, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 105.42it/s]


Train Loss:  0.5337


V Loss:  0.4241, Avg Loss:  0.5320, Best Loss:  0.5280, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 116.07it/s]


Validation Loss:  0.5320

Epoch 14


T Loss:  0.6998, Avg Loss:  0.5312, Best Loss:  0.5280, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 105.78it/s]


Train Loss:  0.5312


V Loss:  0.2885, Avg Loss:  0.5201, Best Loss:  0.5280, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 115.64it/s]


Validation Loss:  0.5201

Epoch 15


T Loss:  0.3909, Avg Loss:  0.5274, Best Loss:  0.5201, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 104.12it/s]


Train Loss:  0.5274


V Loss:  0.5586, Avg Loss:  0.5156, Best Loss:  0.5201, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 106.89it/s]


Validation Loss:  0.5156

Epoch 16


T Loss:  0.2447, Avg Loss:  0.5250, Best Loss:  0.5156, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 100.98it/s]


Train Loss:  0.5250


V Loss:  0.4777, Avg Loss:  0.5149, Best Loss:  0.5156, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 116.45it/s]


Validation Loss:  0.5149

Epoch 17


T Loss:  1.1556, Avg Loss:  0.5232, Best Loss:  0.5156, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 104.54it/s]


Train Loss:  0.5232


V Loss:  0.4040, Avg Loss:  0.5760, Best Loss:  0.5156, Counter: 1: 100%|██████████| 235/235 [00:01<00:00, 119.22it/s]


Validation Loss:  0.5760

Epoch 18


T Loss:  0.8488, Avg Loss:  0.5252, Best Loss:  0.5156, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 103.01it/s]


Train Loss:  0.5252


V Loss:  0.2594, Avg Loss:  0.5327, Best Loss:  0.5156, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 111.00it/s]


Validation Loss:  0.5327

Epoch 19


T Loss:  0.1909, Avg Loss:  0.5171, Best Loss:  0.5156, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 102.03it/s]


Train Loss:  0.5171


V Loss:  0.4510, Avg Loss:  0.5879, Best Loss:  0.5156, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 117.24it/s]


Validation Loss:  0.5879


T Loss:  0.7342, Avg Loss:  0.5202: 100%|██████████| 119/119 [00:01<00:00, 101.48it/s]


Accuracy:  0.8145
[[1588   93  108  111]
 [  98 1707   31   64]
 [ 116   49 1417  318]
 [ 124   47  251 1478]]
              precision    recall  f1-score   support

           0     0.8245    0.8358    0.8301      1900
           1     0.9003    0.8984    0.8994      1900
           2     0.7842    0.7458    0.7645      1900
           3     0.7499    0.7779    0.7636      1900

    accuracy                         0.8145      7600
   macro avg     0.8147    0.8145    0.8144      7600
weighted avg     0.8147    0.8145    0.8144      7600

Test loss: 0.5201968483063353


Delta: tensor([[3., 0., 0.]], device='cuda:0')

Epoch 1


T Loss:  0.7613, Avg Loss:  0.9472, Best Loss:     inf, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.08it/s]


Train Loss:  0.9472


V Loss:  0.3383, Avg Loss:  0.7000, Best Loss:     inf, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 113.88it/s]


Validation Loss:  0.7000

Epoch 2


T Loss:  0.8442, Avg Loss:  0.6247, Best Loss:  0.7000, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 108.96it/s]


Train Loss:  0.6247


V Loss:  0.4368, Avg Loss:  0.5628, Best Loss:  0.7000, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 116.42it/s]


Validation Loss:  0.5628

Epoch 3


T Loss:  0.0559, Avg Loss:  0.5342, Best Loss:  0.5628, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 110.61it/s]


Train Loss:  0.5342


V Loss:  0.9376, Avg Loss:  0.5055, Best Loss:  0.5628, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 108.87it/s]


Validation Loss:  0.5055

Epoch 4


T Loss:  0.8322, Avg Loss:  0.4934, Best Loss:  0.5055, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 111.89it/s]


Train Loss:  0.4934


V Loss:  0.3223, Avg Loss:  0.4753, Best Loss:  0.5055, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 108.12it/s]


Validation Loss:  0.4753

Epoch 5


T Loss:  0.1599, Avg Loss:  0.4704, Best Loss:  0.4753, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.88it/s]


Train Loss:  0.4704


V Loss:  0.3761, Avg Loss:  0.4608, Best Loss:  0.4753, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 116.20it/s]


Validation Loss:  0.4608

Epoch 6


T Loss:  0.2595, Avg Loss:  0.4559, Best Loss:  0.4608, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 108.00it/s]


Train Loss:  0.4559


V Loss:  0.5569, Avg Loss:  0.4434, Best Loss:  0.4608, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 112.19it/s]


Validation Loss:  0.4434

Epoch 7


T Loss:  0.2214, Avg Loss:  0.4473, Best Loss:  0.4434, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 110.52it/s]


Train Loss:  0.4473


V Loss:  0.4786, Avg Loss:  0.4418, Best Loss:  0.4434, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 107.86it/s]


Validation Loss:  0.4418

Epoch 8


T Loss:  0.2109, Avg Loss:  0.4377, Best Loss:  0.4418, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.53it/s]


Train Loss:  0.4377


V Loss:  0.7661, Avg Loss:  0.4338, Best Loss:  0.4418, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.28it/s]


Validation Loss:  0.4338

Epoch 9


T Loss:  0.5354, Avg Loss:  0.4341, Best Loss:  0.4338, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 110.47it/s]


Train Loss:  0.4341


V Loss:  0.8101, Avg Loss:  0.4324, Best Loss:  0.4338, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 116.82it/s]


Validation Loss:  0.4324

Epoch 10


T Loss:  0.3207, Avg Loss:  0.4310, Best Loss:  0.4324, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 108.07it/s]


Train Loss:  0.4310


V Loss:  0.7677, Avg Loss:  0.4271, Best Loss:  0.4324, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.74it/s]


Validation Loss:  0.4271

Epoch 11


T Loss:  0.2960, Avg Loss:  0.4284, Best Loss:  0.4271, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 107.77it/s]


Train Loss:  0.4284


V Loss:  0.6674, Avg Loss:  0.4245, Best Loss:  0.4271, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.35it/s]


Validation Loss:  0.4245

Epoch 12


T Loss:  1.1925, Avg Loss:  0.4270, Best Loss:  0.4245, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 107.32it/s]


Train Loss:  0.4270


V Loss:  0.2520, Avg Loss:  0.4230, Best Loss:  0.4245, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 114.15it/s]


Validation Loss:  0.4230

Epoch 13


T Loss:  0.1433, Avg Loss:  0.4227, Best Loss:  0.4230, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 108.85it/s]


Train Loss:  0.4227


V Loss:  0.2717, Avg Loss:  0.4224, Best Loss:  0.4230, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 106.57it/s]


Validation Loss:  0.4224

Epoch 14


T Loss:  0.0266, Avg Loss:  0.4224, Best Loss:  0.4230, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 106.96it/s]


Train Loss:  0.4224


V Loss:  0.3459, Avg Loss:  0.4197, Best Loss:  0.4230, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 114.15it/s]


Validation Loss:  0.4197

Epoch 15


T Loss:  0.5819, Avg Loss:  0.4201, Best Loss:  0.4197, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 111.81it/s]


Train Loss:  0.4201


V Loss:  0.3195, Avg Loss:  0.4200, Best Loss:  0.4197, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 113.17it/s]


Validation Loss:  0.4200

Epoch 16


T Loss:  0.0102, Avg Loss:  0.4192, Best Loss:  0.4197, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 110.93it/s]


Train Loss:  0.4192


V Loss:  0.4820, Avg Loss:  0.4202, Best Loss:  0.4197, Counter: 1: 100%|██████████| 235/235 [00:01<00:00, 119.02it/s]


Validation Loss:  0.4202

Epoch 17


T Loss:  0.3467, Avg Loss:  0.4200, Best Loss:  0.4197, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 106.03it/s]


Train Loss:  0.4200


V Loss:  0.3858, Avg Loss:  0.4194, Best Loss:  0.4197, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 112.58it/s]


Validation Loss:  0.4194

Epoch 18


T Loss:  0.3486, Avg Loss:  0.4178, Best Loss:  0.4197, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 113.78it/s]


Train Loss:  0.4178


V Loss:  0.2958, Avg Loss:  0.4136, Best Loss:  0.4197, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 106.35it/s]


Validation Loss:  0.4136

Epoch 19


T Loss:  0.2199, Avg Loss:  0.4170, Best Loss:  0.4136, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 108.37it/s]


Train Loss:  0.4170


V Loss:  0.1475, Avg Loss:  0.4155, Best Loss:  0.4136, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 114.78it/s]


Validation Loss:  0.4155

Epoch 20


T Loss:  0.0865, Avg Loss:  0.4173, Best Loss:  0.4136, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 107.57it/s]


Train Loss:  0.4173


V Loss:  0.4292, Avg Loss:  0.4167, Best Loss:  0.4136, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 112.59it/s]


Validation Loss:  0.4167

Epoch 21


T Loss:  0.2537, Avg Loss:  0.4173, Best Loss:  0.4136, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 112.56it/s]


Train Loss:  0.4173


V Loss:  0.2683, Avg Loss:  0.4148, Best Loss:  0.4136, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 110.42it/s]


Validation Loss:  0.4148

Epoch 22


T Loss:  0.5986, Avg Loss:  0.4165, Best Loss:  0.4136, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 112.71it/s]


Train Loss:  0.4165


V Loss:  0.5086, Avg Loss:  0.4140, Best Loss:  0.4136, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 111.56it/s]


Validation Loss:  0.4140


T Loss:  0.7200, Avg Loss:  0.4371: 100%|██████████| 119/119 [00:01<00:00, 102.27it/s]


Accuracy:  0.8613
[[1636   68  115   81]
 [  37 1806   22   35]
 [ 108   17 1537  238]
 [  89   20  224 1567]]
              precision    recall  f1-score   support

           0     0.8749    0.8611    0.8679      1900
           1     0.9451    0.9505    0.9478      1900
           2     0.8098    0.8089    0.8094      1900
           3     0.8157    0.8247    0.8202      1900

    accuracy                         0.8613      7600
   macro avg     0.8614    0.8613    0.8613      7600
weighted avg     0.8614    0.8613    0.8613      7600

Test loss: 0.4370632403287567


Delta: tensor([[1., 2., 3.]], device='cuda:0')

Epoch 1


T Loss:  0.5435, Avg Loss:  1.0378, Best Loss:     inf, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 110.58it/s]


Train Loss:  1.0378


V Loss:  0.7507, Avg Loss:  0.8276, Best Loss:     inf, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 117.12it/s]


Validation Loss:  0.8276

Epoch 2


T Loss:  0.3898, Avg Loss:  0.7774, Best Loss:  0.8276, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.34it/s]


Train Loss:  0.7774


V Loss:  0.6978, Avg Loss:  0.7434, Best Loss:  0.8276, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 105.96it/s]


Validation Loss:  0.7434

Epoch 3


T Loss:  0.5786, Avg Loss:  0.7137, Best Loss:  0.7434, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 110.80it/s]


Train Loss:  0.7137


V Loss:  0.4863, Avg Loss:  0.6942, Best Loss:  0.7434, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 110.62it/s]


Validation Loss:  0.6942

Epoch 4


T Loss:  0.5547, Avg Loss:  0.6772, Best Loss:  0.6942, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 107.34it/s]


Train Loss:  0.6772


V Loss:  0.4017, Avg Loss:  0.6769, Best Loss:  0.6942, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.86it/s]


Validation Loss:  0.6769

Epoch 5


T Loss:  0.4474, Avg Loss:  0.6572, Best Loss:  0.6769, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 107.74it/s]


Train Loss:  0.6572


V Loss:  0.9674, Avg Loss:  0.6627, Best Loss:  0.6769, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.13it/s]


Validation Loss:  0.6627

Epoch 6


T Loss:  1.6741, Avg Loss:  0.6485, Best Loss:  0.6627, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.94it/s]


Train Loss:  0.6485


V Loss:  0.7506, Avg Loss:  0.6493, Best Loss:  0.6627, Counter: 0: 100%|██████████| 235/235 [00:01<00:00, 119.22it/s]


Validation Loss:  0.6493

Epoch 7


T Loss:  0.7077, Avg Loss:  0.6356, Best Loss:  0.6493, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.60it/s]


Train Loss:  0.6356


V Loss:  0.7553, Avg Loss:  0.6246, Best Loss:  0.6493, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 108.48it/s]


Validation Loss:  0.6246

Epoch 8


T Loss:  0.8174, Avg Loss:  0.6296, Best Loss:  0.6246, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 111.44it/s]


Train Loss:  0.6296


V Loss:  0.6669, Avg Loss:  0.6373, Best Loss:  0.6246, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.71it/s]


Validation Loss:  0.6373

Epoch 9


T Loss:  0.1707, Avg Loss:  0.6221, Best Loss:  0.6246, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 115.76it/s]


Train Loss:  0.6221


V Loss:  0.5436, Avg Loss:  0.6270, Best Loss:  0.6246, Counter: 1: 100%|██████████| 235/235 [00:01<00:00, 122.45it/s]


Validation Loss:  0.6270

Epoch 10


T Loss:  0.7099, Avg Loss:  0.6175, Best Loss:  0.6246, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 106.93it/s]


Train Loss:  0.6175


V Loss:  0.6457, Avg Loss:  0.6552, Best Loss:  0.6246, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 108.82it/s]


Validation Loss:  0.6552

Epoch 11


T Loss:  0.7420, Avg Loss:  0.6141, Best Loss:  0.6246, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 108.09it/s]


Train Loss:  0.6141


V Loss:  0.5823, Avg Loss:  0.6051, Best Loss:  0.6246, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 113.45it/s]


Validation Loss:  0.6051

Epoch 12


T Loss:  0.8810, Avg Loss:  0.6057, Best Loss:  0.6051, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.18it/s]


Train Loss:  0.6057


V Loss:  0.7678, Avg Loss:  0.6013, Best Loss:  0.6051, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.30it/s]


Validation Loss:  0.6013

Epoch 13


T Loss:  0.7535, Avg Loss:  0.6057, Best Loss:  0.6013, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 105.94it/s]


Train Loss:  0.6057


V Loss:  0.7501, Avg Loss:  0.6227, Best Loss:  0.6013, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 106.90it/s]


Validation Loss:  0.6227

Epoch 14


T Loss:  0.4458, Avg Loss:  0.6047, Best Loss:  0.6013, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 108.27it/s]


Train Loss:  0.6047


V Loss:  0.3753, Avg Loss:  0.6234, Best Loss:  0.6013, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 112.44it/s]


Validation Loss:  0.6234

Epoch 15


T Loss:  0.1999, Avg Loss:  0.5999, Best Loss:  0.6013, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 116.95it/s]


Train Loss:  0.5999


V Loss:  0.3381, Avg Loss:  0.6040, Best Loss:  0.6013, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 115.22it/s]


Validation Loss:  0.6040

Epoch 16


T Loss:  0.6461, Avg Loss:  0.5991, Best Loss:  0.6013, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 104.02it/s]


Train Loss:  0.5991


V Loss:  0.6194, Avg Loss:  0.5976, Best Loss:  0.6013, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 111.84it/s]


Validation Loss:  0.5976

Epoch 17


T Loss:  0.5420, Avg Loss:  0.5939, Best Loss:  0.5976, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 105.89it/s]


Train Loss:  0.5939


V Loss:  0.6214, Avg Loss:  0.6181, Best Loss:  0.5976, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 106.93it/s]


Validation Loss:  0.6181

Epoch 18


T Loss:  0.6494, Avg Loss:  0.5943, Best Loss:  0.5976, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 111.84it/s]


Train Loss:  0.5943


V Loss:  0.7217, Avg Loss:  0.5919, Best Loss:  0.5976, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 113.05it/s]


Validation Loss:  0.5919

Epoch 19


T Loss:  0.5497, Avg Loss:  0.5894, Best Loss:  0.5919, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 105.57it/s]


Train Loss:  0.5894


V Loss:  0.5919, Avg Loss:  0.5955, Best Loss:  0.5919, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 103.29it/s]


Validation Loss:  0.5955

Epoch 20


T Loss:  0.9286, Avg Loss:  0.5893, Best Loss:  0.5919, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 112.30it/s]


Train Loss:  0.5893


V Loss:  0.6634, Avg Loss:  0.5833, Best Loss:  0.5919, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 111.04it/s]


Validation Loss:  0.5833

Epoch 21


T Loss:  0.2469, Avg Loss:  0.5880, Best Loss:  0.5833, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 111.67it/s]


Train Loss:  0.5880


V Loss:  0.5431, Avg Loss:  0.6038, Best Loss:  0.5833, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 112.45it/s]


Validation Loss:  0.6038

Epoch 22


T Loss:  0.4190, Avg Loss:  0.5868, Best Loss:  0.5833, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 108.25it/s]


Train Loss:  0.5868


V Loss:  0.6987, Avg Loss:  0.5885, Best Loss:  0.5833, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 114.27it/s]


Validation Loss:  0.5885

Epoch 23


T Loss:  0.1456, Avg Loss:  0.5846, Best Loss:  0.5833, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 107.36it/s]


Train Loss:  0.5846


V Loss:  1.4023, Avg Loss:  0.5964, Best Loss:  0.5833, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 113.39it/s]


Validation Loss:  0.5964

Epoch 24


T Loss:  0.5444, Avg Loss:  0.5871, Best Loss:  0.5833, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 106.09it/s]


Train Loss:  0.5871


V Loss:  0.6530, Avg Loss:  0.5905, Best Loss:  0.5833, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 110.90it/s]


Validation Loss:  0.5905


T Loss:  0.2983, Avg Loss:  0.5832: 100%|██████████| 119/119 [00:01<00:00, 100.70it/s]


Accuracy:  0.7822
[[1485  165  145  105]
 [  76 1699   43   82]
 [  88   52 1402  358]
 [ 128  101  312 1359]]
              precision    recall  f1-score   support

           0     0.8357    0.7816    0.8077      1900
           1     0.8423    0.8942    0.8675      1900
           2     0.7371    0.7379    0.7375      1900
           3     0.7138    0.7153    0.7145      1900

    accuracy                         0.7822      7600
   macro avg     0.7822    0.7822    0.7818      7600
weighted avg     0.7822    0.7822    0.7818      7600

Test loss: 0.5831848365419051


Delta: tensor([[0., 0., 3.]], device='cuda:0')

Epoch 1


T Loss:  0.7073, Avg Loss:  1.1169, Best Loss:     inf, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 110.46it/s]


Train Loss:  1.1169


V Loss:  0.8766, Avg Loss:  0.9771, Best Loss:     inf, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 112.80it/s]


Validation Loss:  0.9771

Epoch 2


T Loss:  0.6882, Avg Loss:  0.9078, Best Loss:  0.9771, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.87it/s]


Train Loss:  0.9078


V Loss:  0.7066, Avg Loss:  0.8818, Best Loss:  0.9771, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 113.11it/s]


Validation Loss:  0.8818

Epoch 3


T Loss:  0.6899, Avg Loss:  0.8715, Best Loss:  0.8818, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 112.60it/s]


Train Loss:  0.8715


V Loss:  1.0041, Avg Loss:  0.8652, Best Loss:  0.8818, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 110.92it/s]


Validation Loss:  0.8652

Epoch 4


T Loss:  0.7157, Avg Loss:  0.8546, Best Loss:  0.8652, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 107.60it/s]


Train Loss:  0.8546


V Loss:  0.7285, Avg Loss:  0.8368, Best Loss:  0.8652, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 112.67it/s]


Validation Loss:  0.8368

Epoch 5


T Loss:  0.5853, Avg Loss:  0.8502, Best Loss:  0.8368, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 108.63it/s]


Train Loss:  0.8502


V Loss:  0.6897, Avg Loss:  0.9692, Best Loss:  0.8368, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 112.88it/s]


Validation Loss:  0.9692

Epoch 6


T Loss:  1.5027, Avg Loss:  0.8540, Best Loss:  0.8368, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 109.66it/s]


Train Loss:  0.8540


V Loss:  1.1694, Avg Loss:  0.8574, Best Loss:  0.8368, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 108.98it/s]


Validation Loss:  0.8574

Epoch 7


T Loss:  1.2905, Avg Loss:  0.8389, Best Loss:  0.8368, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 111.82it/s]


Train Loss:  0.8389


V Loss:  0.8417, Avg Loss:  0.8411, Best Loss:  0.8368, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 109.55it/s]


Validation Loss:  0.8411

Epoch 8


T Loss:  0.9630, Avg Loss:  0.8362, Best Loss:  0.8368, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 109.95it/s]


Train Loss:  0.8362


V Loss:  0.8088, Avg Loss:  0.8183, Best Loss:  0.8368, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 112.61it/s]


Validation Loss:  0.8183

Epoch 9


T Loss:  0.6048, Avg Loss:  0.8415, Best Loss:  0.8183, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 110.99it/s]


Train Loss:  0.8415


V Loss:  0.6912, Avg Loss:  0.8506, Best Loss:  0.8183, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 109.02it/s]


Validation Loss:  0.8506

Epoch 10


T Loss:  1.0435, Avg Loss:  0.8396, Best Loss:  0.8183, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 108.78it/s]


Train Loss:  0.8396


V Loss:  0.7804, Avg Loss:  0.8585, Best Loss:  0.8183, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 113.43it/s]


Validation Loss:  0.8585

Epoch 11


T Loss:  1.4876, Avg Loss:  0.8402, Best Loss:  0.8183, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 111.77it/s]


Train Loss:  0.8402


V Loss:  0.6570, Avg Loss:  0.8165, Best Loss:  0.8183, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 107.87it/s]


Validation Loss:  0.8165

Epoch 12


T Loss:  0.6562, Avg Loss:  0.8315, Best Loss:  0.8165, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 105.51it/s]


Train Loss:  0.8315


V Loss:  1.1478, Avg Loss:  0.8444, Best Loss:  0.8165, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 112.87it/s]


Validation Loss:  0.8444

Epoch 13


T Loss:  1.0691, Avg Loss:  0.8365, Best Loss:  0.8165, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 112.48it/s]


Train Loss:  0.8365


V Loss:  1.1656, Avg Loss:  0.8165, Best Loss:  0.8165, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 108.99it/s]


Validation Loss:  0.8165

Epoch 14


T Loss:  0.5583, Avg Loss:  0.8337, Best Loss:  0.8165, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 109.90it/s]


Train Loss:  0.8337


V Loss:  1.0299, Avg Loss:  0.8109, Best Loss:  0.8165, Counter: 2: 100%|██████████| 235/235 [00:01<00:00, 122.51it/s]


Validation Loss:  0.8109

Epoch 15


T Loss:  0.6504, Avg Loss:  0.8273, Best Loss:  0.8109, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 110.30it/s]


Train Loss:  0.8273


V Loss:  1.0382, Avg Loss:  0.8341, Best Loss:  0.8109, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 108.76it/s]


Validation Loss:  0.8341

Epoch 16


T Loss:  0.5509, Avg Loss:  0.8298, Best Loss:  0.8109, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 107.65it/s]


Train Loss:  0.8298


V Loss:  0.7070, Avg Loss:  0.8569, Best Loss:  0.8109, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 108.36it/s]


Validation Loss:  0.8569

Epoch 17


T Loss:  0.7233, Avg Loss:  0.8297, Best Loss:  0.8109, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 110.41it/s]


Train Loss:  0.8297


V Loss:  0.8316, Avg Loss:  0.8353, Best Loss:  0.8109, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 106.79it/s]


Validation Loss:  0.8353

Epoch 18


T Loss:  0.8926, Avg Loss:  0.8330, Best Loss:  0.8109, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 110.14it/s]


Train Loss:  0.8330


V Loss:  1.1273, Avg Loss:  0.8100, Best Loss:  0.8109, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 112.78it/s]


Validation Loss:  0.8100


T Loss:  0.7303, Avg Loss:  0.7946: 100%|██████████| 119/119 [00:01<00:00, 97.16it/s] 


Accuracy:  0.6732
[[1359  232  132  177]
 [ 216 1535   32  117]
 [ 288   73  973  566]
 [ 248  135  268 1249]]
              precision    recall  f1-score   support

           0     0.6438    0.7153    0.6776      1900
           1     0.7772    0.8079    0.7923      1900
           2     0.6925    0.5121    0.5888      1900
           3     0.5922    0.6574    0.6231      1900

    accuracy                         0.6732      7600
   macro avg     0.6764    0.6732    0.6704      7600
weighted avg     0.6764    0.6732    0.6704      7600

Test loss: 0.7946344824398265


Delta: tensor([[1., 1., 1.]], device='cuda:0')

Epoch 1


T Loss:  0.7226, Avg Loss:  1.0826, Best Loss:     inf, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.55it/s]


Train Loss:  1.0826


V Loss:  0.9366, Avg Loss:  0.8589, Best Loss:     inf, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 106.65it/s]


Validation Loss:  0.8589

Epoch 2


T Loss:  0.7548, Avg Loss:  0.7647, Best Loss:  0.8589, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 105.33it/s]


Train Loss:  0.7647


V Loss:  0.5910, Avg Loss:  0.7092, Best Loss:  0.8589, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 113.57it/s]


Validation Loss:  0.7092

Epoch 3


T Loss:  1.2352, Avg Loss:  0.6839, Best Loss:  0.7092, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.94it/s]


Train Loss:  0.6839


V Loss:  0.4986, Avg Loss:  0.6670, Best Loss:  0.7092, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 108.59it/s]


Validation Loss:  0.6670

Epoch 4


T Loss:  0.5883, Avg Loss:  0.6473, Best Loss:  0.6670, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 107.71it/s]


Train Loss:  0.6473


V Loss:  0.6097, Avg Loss:  0.6309, Best Loss:  0.6670, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 112.49it/s]


Validation Loss:  0.6309

Epoch 5


T Loss:  0.5833, Avg Loss:  0.6261, Best Loss:  0.6309, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.64it/s]


Train Loss:  0.6261


V Loss:  0.6902, Avg Loss:  0.6242, Best Loss:  0.6309, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 112.48it/s]


Validation Loss:  0.6242

Epoch 6


T Loss:  0.7572, Avg Loss:  0.6128, Best Loss:  0.6242, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.28it/s]


Train Loss:  0.6128


V Loss:  0.5637, Avg Loss:  0.6071, Best Loss:  0.6242, Counter: 0: 100%|██████████| 235/235 [00:01<00:00, 119.42it/s]


Validation Loss:  0.6071

Epoch 7


T Loss:  0.5653, Avg Loss:  0.6008, Best Loss:  0.6071, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.83it/s]


Train Loss:  0.6008


V Loss:  0.4938, Avg Loss:  0.6176, Best Loss:  0.6071, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 108.62it/s]


Validation Loss:  0.6176

Epoch 8


T Loss:  0.5625, Avg Loss:  0.5941, Best Loss:  0.6071, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 110.14it/s]


Train Loss:  0.5941


V Loss:  0.5386, Avg Loss:  0.5888, Best Loss:  0.6071, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 105.99it/s]


Validation Loss:  0.5888

Epoch 9


T Loss:  0.3909, Avg Loss:  0.5862, Best Loss:  0.5888, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.90it/s]


Train Loss:  0.5862


V Loss:  0.5578, Avg Loss:  0.5913, Best Loss:  0.5888, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 111.40it/s]


Validation Loss:  0.5913

Epoch 10


T Loss:  0.3103, Avg Loss:  0.5817, Best Loss:  0.5888, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 116.77it/s]


Train Loss:  0.5817


V Loss:  0.6361, Avg Loss:  0.5941, Best Loss:  0.5888, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 110.76it/s]


Validation Loss:  0.5941

Epoch 11


T Loss:  1.0211, Avg Loss:  0.5773, Best Loss:  0.5888, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 107.44it/s]


Train Loss:  0.5773


V Loss:  0.7116, Avg Loss:  0.5770, Best Loss:  0.5888, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 110.43it/s]


Validation Loss:  0.5770

Epoch 12


T Loss:  0.2970, Avg Loss:  0.5747, Best Loss:  0.5770, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 109.45it/s]


Train Loss:  0.5747


V Loss:  0.7824, Avg Loss:  0.5790, Best Loss:  0.5770, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 117.40it/s]


Validation Loss:  0.5790

Epoch 13


T Loss:  0.1304, Avg Loss:  0.5677, Best Loss:  0.5770, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 112.52it/s]


Train Loss:  0.5677


V Loss:  0.8232, Avg Loss:  0.5731, Best Loss:  0.5770, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 106.13it/s]


Validation Loss:  0.5731

Epoch 14


T Loss:  0.9528, Avg Loss:  0.5667, Best Loss:  0.5731, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.13it/s]


Train Loss:  0.5667


V Loss:  0.4070, Avg Loss:  0.5874, Best Loss:  0.5731, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 116.88it/s]


Validation Loss:  0.5874

Epoch 15


T Loss:  0.3646, Avg Loss:  0.5643, Best Loss:  0.5731, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 109.80it/s]


Train Loss:  0.5643


V Loss:  0.8712, Avg Loss:  0.5991, Best Loss:  0.5731, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 109.92it/s]


Validation Loss:  0.5991

Epoch 16


T Loss:  0.4977, Avg Loss:  0.5614, Best Loss:  0.5731, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 111.95it/s]


Train Loss:  0.5614


V Loss:  0.4972, Avg Loss:  0.5648, Best Loss:  0.5731, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 107.74it/s]


Validation Loss:  0.5648

Epoch 17


T Loss:  0.2774, Avg Loss:  0.5599, Best Loss:  0.5648, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 112.16it/s]


Train Loss:  0.5599


V Loss:  0.4819, Avg Loss:  0.5696, Best Loss:  0.5648, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 107.66it/s]


Validation Loss:  0.5696

Epoch 18


T Loss:  1.1200, Avg Loss:  0.5580, Best Loss:  0.5648, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 114.38it/s]


Train Loss:  0.5580


V Loss:  0.4582, Avg Loss:  0.5579, Best Loss:  0.5648, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 109.25it/s]


Validation Loss:  0.5579

Epoch 19


T Loss:  0.9004, Avg Loss:  0.5556, Best Loss:  0.5579, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 107.09it/s]


Train Loss:  0.5556


V Loss:  0.5306, Avg Loss:  0.5649, Best Loss:  0.5579, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 108.58it/s]


Validation Loss:  0.5649

Epoch 20


T Loss:  0.4219, Avg Loss:  0.5543, Best Loss:  0.5579, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 108.84it/s]


Train Loss:  0.5543


V Loss:  0.4129, Avg Loss:  0.5527, Best Loss:  0.5579, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 108.76it/s]


Validation Loss:  0.5527

Epoch 21


T Loss:  0.3750, Avg Loss:  0.5533, Best Loss:  0.5527, Counter: 0: 100%|██████████| 704/704 [00:06<00:00, 106.31it/s]


Train Loss:  0.5533


V Loss:  0.9548, Avg Loss:  0.5575, Best Loss:  0.5527, Counter: 0: 100%|██████████| 235/235 [00:02<00:00, 109.93it/s]


Validation Loss:  0.5575

Epoch 22


T Loss:  0.5509, Avg Loss:  0.5512, Best Loss:  0.5527, Counter: 1: 100%|██████████| 704/704 [00:06<00:00, 105.66it/s]


Train Loss:  0.5512


V Loss:  0.6166, Avg Loss:  0.5665, Best Loss:  0.5527, Counter: 1: 100%|██████████| 235/235 [00:02<00:00, 108.29it/s]


Validation Loss:  0.5665

Epoch 23


T Loss:  0.3722, Avg Loss:  0.5515, Best Loss:  0.5527, Counter: 2: 100%|██████████| 704/704 [00:06<00:00, 110.06it/s]


Train Loss:  0.5515


V Loss:  0.2746, Avg Loss:  0.5612, Best Loss:  0.5527, Counter: 2: 100%|██████████| 235/235 [00:02<00:00, 112.01it/s]


Validation Loss:  0.5612

Epoch 24


T Loss:  0.3955, Avg Loss:  0.5471, Best Loss:  0.5527, Counter: 3: 100%|██████████| 704/704 [00:06<00:00, 112.74it/s]


Train Loss:  0.5471


V Loss:  0.4170, Avg Loss:  0.5544, Best Loss:  0.5527, Counter: 3: 100%|██████████| 235/235 [00:02<00:00, 109.27it/s]


Validation Loss:  0.5544


T Loss:  0.5831, Avg Loss:  0.5560: 100%|██████████| 119/119 [00:01<00:00, 94.85it/s] 


Accuracy:  0.7953
[[1574  126  125   75]
 [ 103 1705   42   50]
 [ 131   55 1459  255]
 [ 180   90  324 1306]]
              precision    recall  f1-score   support

           0     0.7918    0.8284    0.8097      1900
           1     0.8629    0.8974    0.8798      1900
           2     0.7482    0.7679    0.7579      1900
           3     0.7746    0.6874    0.7284      1900

    accuracy                         0.7953      7600
   macro avg     0.7944    0.7953    0.7939      7600
weighted avg     0.7944    0.7953    0.7939      7600

Test loss: 0.5559738839373869


In [2]:
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class CharCNNELMO(nn.Module):
    def __init__(self, Emb, hidden_dim, dropout, num_layers):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.char_embedding = nn.Embedding.from_pretrained(torch.FloatTensor(Emb.vectors), padding_idx=Emb.key_to_index['<pad>'], freeze=False)

        self.conv1 = nn.Conv1d(Emb.vectors.shape[1], hidden_dim, kernel_size=3)

        self.lstm = nn.LSTM(Emb.vectors.shape[1], hidden_dim, batch_first=True, bidirectional=True, num_layers=num_layers, dropout=dropout)

    def forward(self, X, X_lengths):
        X = self.char_embedding(X)
        X = pack_padded_sequence(X, X_lengths, batch_first=True, enforce_sorted=False)
        X, (h_n, c_n) = self.lstm(X, None)
        X, _ = pad_packed_sequence(X, batch_first=True)
        return X, h_n, c_n

In [None]:
# creat co