In [1]:
cfg = {
    'dev_train_len': 25*10**3,
    'dev_validation_len': 5*10**3,
    'learning_rate': 0.001,
    'epochs': 100,
    'embedding_dim': 100,
    'batch_size': 32,
    'dropout': 0.2,
    'optimizer': 'Adam',
    'num_layers': 2
}

cfg['hidden_dim'] = cfg['embedding_dim']

In [2]:
DEV_TRAIN_LEN = cfg['dev_train_len']
DEV_VALIDATION_LEN = cfg['dev_validation_len']
LEARNING_RATE = cfg['learning_rate']
EPOCHS = cfg['epochs']
BATCH_SIZE = cfg['batch_size']
DROPOUT = cfg['dropout']
OPTIMIZER = cfg['optimizer']
NUM_LAYERS = cfg['num_layers']
HIDDEN_DIM = cfg['hidden_dim']
EMBEDDING_DIM = cfg['embedding_dim']

DIR = '/scratch/shu7bh/RES/Word/1'

In [3]:
import wandb
wandb.init(project="ELMo", name="WordEmb", config=cfg)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshu7bh[0m. Use [1m`wandb login --relogin`[0m to force relogin


Create Dir

In [4]:
import os
if not os.path.exists(DIR):
    os.makedirs(DIR)

Set Device

In [5]:
import torch

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

cuda


Prepare Data

In [6]:
from gensim.downloader import load

glove_dict = {
    '50': 'glove-wiki-gigaword-50',
    '100': 'glove-wiki-gigaword-100',
    '200': 'glove-wiki-gigaword-200'
}

glove_dict[str(EMBEDDING_DIM)] = load(glove_dict[str(EMBEDDING_DIM)])
glove = glove_dict[str(EMBEDDING_DIM)]

In [7]:
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
import pandas as pd
import unicodedata
import re

def normalize_unicode(text: str) -> str:
    return unicodedata.normalize('NFD', text)

def clean_data(text: str) -> str:
    text = normalize_unicode(text.lower().strip())
    text = re.sub(r"([.!?])", r" \1", text)
    text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)
    return text

lemmatizer = WordNetLemmatizer()
freq_words = dict()

def tokenize_data(text: str, create_unique_words: bool) -> list:
    tokens = [lemmatizer.lemmatize(token) for token in word_tokenize(text)]
    tokens = [token if token in glove else '<unk>' for token in tokens]

    if '<unk>' in tokens:
        return tokens

    if create_unique_words:
        for token in tokens:
            if token not in freq_words:
                freq_words[token] = 1
            else:
                freq_words[token] += 1
    return tokens

def replace_words(tokens: list, filter_rare_words: bool) -> list:
    tokens = [token if token in freq_words else '<unk>' for token in tokens]
    if filter_rare_words:
        tokens = [token if freq_words[token] >= 4 else '<unk>' for token in tokens]
    return tokens

def read_data(path: str, create_unique_words, filter_rare_words) -> pd.DataFrame:
    df = pd.read_csv(path)
    df = df.sample(frac=1, random_state=0).reset_index(drop=True)
    df['Description'] = df['Description'].apply(clean_data)
    df['Description'] = df['Description'].apply(tokenize_data, create_unique_words=create_unique_words)

    df = df[df['Description'].apply(lambda x: '<unk>' not in x)]
    df = df.reset_index(drop=True)

    df['Class Index'] = df['Class Index'].apply(lambda x: x-1)
    ydf = df.copy(deep=True)
    ydf['Description'] = ydf['Description'].apply(replace_words, filter_rare_words=filter_rare_words)
    return df, ydf

In [8]:
freq_words = dict()
xdf, ydf = read_data(
    'data/train.csv', 
    create_unique_words=True, 
    filter_rare_words=True
)

unique_words = set()
for tokens in ydf['Description']:
    unique_words.update(tokens)
print(len(unique_words))

freq_words = dict()
for token in unique_words:
    freq_words[token] = 0

21489


In [9]:
xdf

Unnamed: 0,Class Index,Description
0,3,"[london, british, airline, magnate, richard, b..."
1,3,"[regardless, space, competition, are, poised, ..."
2,0,"[cbs, million, of, folded, paper, crane, flutt..."
3,3,"[in, the, time, it, take, you, to, read, this,..."
4,0,"[washington, a, highly, classified, u, intelli..."
...,...,...
101981,1,"[toronto, reuters, national, hockey, league, t..."
101982,3,"[com, september, am, pt, ., there, s, no, doub..."
101983,0,"[pakistani, security, force, have, arrested, m..."
101984,3,"[palmsource, finally, unveiled, it, new, o, ve..."


In [10]:
NUM_CLASSES = len(set(xdf['Class Index']))
NUM_CLASSES

4

In [11]:
# Create a dictionary of all words
word_to_idx = {word: idx + 1 for idx, word in enumerate(unique_words)}

# Add special tokens
word_to_idx['<pad>'] = 0
word_to_idx['<sos>'] = len(word_to_idx)
word_to_idx['<eos>'] = len(word_to_idx)

# Create a dictionary of all words
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

# print the length of the word to index mapping
print(len(word_to_idx))

21492


In [12]:
idx_to_vec = glove.vectors

In [13]:
glove_key_to_idx = glove.key_to_index

In [14]:
glove_key_to_idx['<pad>'] = len(glove_key_to_idx)
glove_key_to_idx['<sos>'] = len(glove_key_to_idx)
glove_key_to_idx['<eos>'] = len(glove_key_to_idx)
glove_key_to_idx['<unk>'] = len(glove_key_to_idx)

In [15]:
glove_idx_to_key = {idx: key for key, idx in glove_key_to_idx.items()}

In [16]:
import numpy as np

glove_idx_to_vec = glove.vectors

pad_vec = np.zeros((1, EMBEDDING_DIM))
sos_vec = np.random.rand(1, EMBEDDING_DIM)
eos_vec = np.random.rand(1, EMBEDDING_DIM)
unk_vec = np.mean(glove_idx_to_vec, axis=0, keepdims=True)

glove_idx_to_vec = np.concatenate((glove_idx_to_vec, pad_vec, sos_vec, eos_vec, unk_vec), axis=0)

In [17]:
dev_train_raw_x = xdf[:DEV_TRAIN_LEN]
dev_train_raw_y = ydf[:DEV_TRAIN_LEN]

dev_validation_raw_x = xdf[DEV_TRAIN_LEN:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]
dev_validation_raw_y = ydf[DEV_TRAIN_LEN:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]

Dataset

In [18]:
from torch.utils.data import Dataset

class Sentences(Dataset):
    def __init__(
            self, 
            adf: pd.DataFrame, 
            pdf: pd.DataFrame, 
            word_to_idx: dict,
            glove_key_to_idx: dict
        ) -> None:

        self.Xf = []
        self.Xb = []
        self.Yf = []
        self.Yb = []
        self.L = []

        for sentence in adf['Description']:
            self.Xf += [torch.tensor(
                [glove_key_to_idx[w] for w in sentence] + 
                [glove_key_to_idx['<eos>']]
            )]
            self.Xb += [torch.tensor(
                [glove_key_to_idx['<eos>']] + 
                [glove_key_to_idx[w] for w in reversed(sentence)]
            )]

            self.L += [torch.tensor(len(sentence) + 1)]

        for sentence in pdf['Description']:
            self.Yf += [torch.tensor(
                [word_to_idx[w] for w in sentence] + 
                [word_to_idx['<eos>']] + 
                [word_to_idx['<pad>']]
            )]

            self.Yf[-1] = self.Yf[-1][1:]

            self.Yb += [torch.tensor(
                [word_to_idx[w] for w in reversed(sentence)] +
                [word_to_idx['<sos>']]
            )]

        self.Y = torch.tensor(adf['Class Index'].tolist())

    def __len__(self) -> int:
        return len(self.Xf)

    def __getitem__(self, idx: int) -> tuple:
        return self.Xf[idx], self.Xb[idx], self.Y[idx], self.L[idx], self.Yf[idx], self.Yb[idx]

Create Dataset

In [19]:
dev_train_dataset = Sentences(dev_train_raw_x, dev_train_raw_y, word_to_idx, glove_key_to_idx)
dev_validation_dataset = Sentences(dev_validation_raw_x, dev_validation_raw_y, word_to_idx, glove_key_to_idx)

In [20]:
def collate_fn(batch: list) -> tuple:
    xf, xb, y, l, yf, yb = zip(*batch)

    xf = torch.nn.utils.rnn.pad_sequence(xf, padding_value=glove_key_to_idx['<pad>'], batch_first=True)
    xb = torch.nn.utils.rnn.pad_sequence(xb, padding_value=glove_key_to_idx['<pad>'], batch_first=True)
    yf = torch.nn.utils.rnn.pad_sequence(yf, padding_value=word_to_idx['<pad>'], batch_first=True)
    yb = torch.nn.utils.rnn.pad_sequence(yb, padding_value=word_to_idx['<pad>'], batch_first=True)
    return xf, xb, torch.stack(y), torch.stack(l), yf, yb

In [21]:
from torch.utils.data import DataLoader

dev_train_loader = DataLoader(dev_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
dev_validation_loader = DataLoader(dev_validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [22]:
type(idx_to_word)

dict

ELMo

In [23]:
from typing import Any, Mapping
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class ELMo(nn.Module):
    def __init__(
            self, 
            glove_idx_to_vec: np.ndarray,
            idx_to_word: dict,
            dropout: float, 
            num_layers: int, 
            hidden_dim: int, 
            word_embed_dim: int,
            filename: str = None
        ) -> None:

        super(ELMo, self).__init__()

        self.word_embed = nn.Embedding.from_pretrained(torch.from_numpy(glove_idx_to_vec).float(), padding_idx=glove_key_to_idx['<pad>'])

        self.lstmf = nn.LSTM(
            input_size=word_embed_dim, 
            hidden_size=hidden_dim, 
            num_layers=num_layers, 
            dropout=dropout,
            batch_first=True
        )

        self.lstmb = nn.LSTM(
            input_size=word_embed_dim, 
            hidden_size=hidden_dim, 
            num_layers=num_layers, 
            dropout=dropout,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        del self.state_dict()['word_embed.weight']

        if filename:
            self.load_state_dict(torch.load(filename), strict=False)

    def forward(self, xf: torch.Tensor, xb: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        xf = self.word_embed(xf)
        xb = self.word_embed(xb)

        input = xf.detach().clone()

        xf = pack_padded_sequence(xf, lengths=l, batch_first=True, enforce_sorted=False)

        xb = pack_padded_sequence(xb, lengths=l, batch_first=True, enforce_sorted=False)

        xf, (hsf, csf) = self.lstmf(xf)
        xb, (hsb, csb) = self.lstmb(xb)

        xf, _ = pad_packed_sequence(xf, batch_first=True)
        xb, _ = pad_packed_sequence(xb, batch_first=True)

        xf = self.dropout(xf)
        xb = self.dropout(xb)

        return xf, xb, input, (hsf, csf), (hsb, csb)

Early Stopping

In [24]:
import numpy as np

class EarlyStopping:
    def __init__(self, patience:int = 3, delta:float = 0.001):
        self.patience = patience
        self.counter = 0
        self.best_loss:float = np.inf
        self.best_model_pth = 0
        self.delta = delta

    def __call__(self, loss, epoch: int):
        should_stop = False

        if loss >= self.best_loss - self.delta:
            self.counter += 1
            if self.counter > self.patience:
                should_stop = True
        else:
            self.best_loss = loss
            self.counter = 0
            self.best_model_pth = epoch
        return should_stop

LM

In [26]:
from tqdm import tqdm

class LM(nn.Module):
    def __init__(self, 
            hidden_dim: int, 
            vocab_size: int, 
            filename: str = None
        ) -> None:

        super(LM, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.elmo = ELMo(
            glove_idx_to_vec=glove_idx_to_vec,
            idx_to_word=idx_to_word,
            dropout=DROPOUT, 
            num_layers=NUM_LAYERS, 
            hidden_dim=HIDDEN_DIM,
            word_embed_dim=EMBEDDING_DIM
        )
        self.linear_forward = nn.Linear(hidden_dim, vocab_size)
        self.linear_backward = nn.Linear(hidden_dim, vocab_size)

        if filename:
            self.load_state_dict(torch.load(filename))

    def forward(self, xf: torch.Tensor, xb: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        xf, xb, _, _, _ = self.elmo(xf, xb, l)
        yf = self.linear_forward(xf)
        yb = self.linear_backward(xb)
        return yf, yb

    def fit(self, train_loader: DataLoader, validation_loader: DataLoader, epochs: int, learning_rate: float, filename: str) -> None:
        self.es = EarlyStopping()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])

        for epoch in range(epochs):
            print('----------------------------------------')
            self._train(train_loader)
            loss = self._evaluate(validation_loader)
            print(f'Epoch: {epoch + 1} | Loss: {loss:7.4f}')
            if self.es(loss, epoch):
                break
            if self.es.counter == 0:
                torch.save(self.elmo.state_dict(), os.path.join(DIR, f'{filename}_elmo.pth'))

    def _call(self, xf: torch.Tensor, xb: torch.Tensor, y: torch.Tensor, l: torch.Tensor, yf: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
        xf, xb, y, yf, yb = xf.to(DEVICE), xb.to(DEVICE), y.to(DEVICE), yf.to(DEVICE), yb.to(DEVICE)

        yf_hat, yb_hat = self(xf, xb, l)

        yf_hat = yf_hat.view(-1, self.vocab_size)
        yb_hat = yb_hat.view(-1, self.vocab_size)

        yf = yf.view(-1)
        yb = yb.view(-1)

        loss1 = self.criterion(yf_hat, yf)
        loss2 = self.criterion(yb_hat, yb)

        loss = (loss1 + loss2) / 2

        return loss, loss1, loss2

    def _train(self, train_loader: DataLoader) -> None:
        self.train()
        epoch_loss = []
        epoch_loss1 = []
        epoch_loss2 = []

        pbar = tqdm(train_loader)
        for xf, xb, y, l, yf, yb in pbar:

            loss, loss1, loss2 = self._call(xf, xb, y, l, yf, yb)
            epoch_loss.append(loss.item())
            epoch_loss1.append(loss1.item())
            epoch_loss2.append(loss2.item())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            pbar.set_description(f'T Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Avg Loss1: {np.mean(epoch_loss1):7.4f}, Avg Loss2: {np.mean(epoch_loss2):7.4f}')

        wandb.log({'upstream_train_loss': np.mean(epoch_loss), 'upstream_train_lossf': np.mean(epoch_loss1), 'upstream_train_lossb': np.mean(epoch_loss2)})

    def _evaluate(self, validation_loader: DataLoader) -> float:
        self.eval()
        epoch_loss = []
        epoch_loss1 = []
        epoch_loss2 = []
        pbar = tqdm(validation_loader)
        with torch.no_grad():
            for xf, xb, y, l, yf, yb in pbar:
                loss, loss1, loss2 = self._call(xf, xb, y, l, yf, yb)
                epoch_loss.append(loss.item())
                epoch_loss1.append(loss1.item())
                epoch_loss2.append(loss2.item())
                pbar.set_description(f'V Loss: {epoch_loss[-1]:7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Avg Loss1: {np.mean(epoch_loss1):7.4f}, Avg Loss2: {np.mean(epoch_loss2):7.4f}, Counter: {self.es.counter}, Best Loss: {self.es.best_loss:7.4f}')

        wandb.log({'upstream_validation_loss': np.mean(epoch_loss), 'upstream_validation_lossf': np.mean(epoch_loss1), 'upstream_validation_lossb': np.mean(epoch_loss2)})
        return np.mean(epoch_loss)

Initialize Model

In [27]:
lm = LM(
    hidden_dim=HIDDEN_DIM, 
    vocab_size=len(word_to_idx), 
    filename=None
).to(DEVICE)
print(lm)

LM(
  (elmo): ELMo(
    (word_embed): Embedding(400004, 100, padding_idx=400000)
    (lstmf): LSTM(100, 100, num_layers=2, batch_first=True, dropout=0.2)
    (lstmb): LSTM(100, 100, num_layers=2, batch_first=True, dropout=0.2)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (linear_forward): Linear(in_features=100, out_features=21492, bias=True)
  (linear_backward): Linear(in_features=100, out_features=21492, bias=True)
)


In [28]:
from torchinfo import summary

summary(lm, device=DEVICE)

Layer (type:depth-idx)                   Param #
LM                                       --
├─ELMo: 1-1                              --
│    └─Embedding: 2-1                    (40,000,400)
│    └─LSTM: 2-2                         161,600
│    └─LSTM: 2-3                         161,600
│    └─Dropout: 2-4                      --
├─Linear: 1-2                            2,170,692
├─Linear: 1-3                            2,170,692
Total params: 44,664,984
Trainable params: 4,664,584
Non-trainable params: 40,000,400

In [29]:
lm.fit(dev_train_loader, dev_validation_loader, epochs=EPOCHS, learning_rate=LEARNING_RATE, filename='best')

----------------------------------------


T Loss:  6.7043, Avg Loss:  7.0560, Avg Loss1:  7.0875, Avg Loss2:  7.0245: 100%|██████████| 782/782 [00:24<00:00, 31.37it/s]
V Loss:  6.7269, Avg Loss:  6.7380, Avg Loss1:  6.7929, Avg Loss2:  6.6832, Counter: 0, Best Loss:     inf: 100%|██████████| 157/157 [00:01<00:00, 87.35it/s]


Epoch: 1 | Loss:  6.7380
----------------------------------------


T Loss:  6.3778, Avg Loss:  6.5488, Avg Loss1:  6.6124, Avg Loss2:  6.4851: 100%|██████████| 782/782 [00:24<00:00, 31.39it/s]
V Loss:  6.5450, Avg Loss:  6.3578, Avg Loss1:  6.4315, Avg Loss2:  6.2841, Counter: 0, Best Loss:  6.7380: 100%|██████████| 157/157 [00:01<00:00, 93.18it/s]


Epoch: 2 | Loss:  6.3578
----------------------------------------


T Loss:  6.3790, Avg Loss:  6.2302, Avg Loss1:  6.3075, Avg Loss2:  6.1529: 100%|██████████| 782/782 [00:24<00:00, 31.81it/s]
V Loss:  6.3081, Avg Loss:  6.0838, Avg Loss1:  6.1755, Avg Loss2:  5.9922, Counter: 0, Best Loss:  6.3578: 100%|██████████| 157/157 [00:01<00:00, 90.84it/s]


Epoch: 3 | Loss:  6.0838
----------------------------------------


T Loss:  6.0170, Avg Loss:  5.9797, Avg Loss1:  6.0539, Avg Loss2:  5.9056: 100%|██████████| 782/782 [00:24<00:00, 32.14it/s]
V Loss:  5.7735, Avg Loss:  5.8731, Avg Loss1:  5.9345, Avg Loss2:  5.8118, Counter: 0, Best Loss:  6.0838: 100%|██████████| 157/157 [00:01<00:00, 87.21it/s]


Epoch: 4 | Loss:  5.8731
----------------------------------------


T Loss:  5.8665, Avg Loss:  5.7933, Avg Loss1:  5.8425, Avg Loss2:  5.7440: 100%|██████████| 782/782 [00:24<00:00, 31.66it/s]
V Loss:  5.6645, Avg Loss:  5.7347, Avg Loss1:  5.7757, Avg Loss2:  5.6938, Counter: 0, Best Loss:  5.8731: 100%|██████████| 157/157 [00:01<00:00, 89.27it/s]


Epoch: 5 | Loss:  5.7347
----------------------------------------


T Loss:  5.6988, Avg Loss:  5.6558, Avg Loss1:  5.6912, Avg Loss2:  5.6205: 100%|██████████| 782/782 [00:24<00:00, 32.10it/s]
V Loss:  5.8578, Avg Loss:  5.6301, Avg Loss1:  5.6601, Avg Loss2:  5.6001, Counter: 0, Best Loss:  5.7347: 100%|██████████| 157/157 [00:01<00:00, 93.11it/s]


Epoch: 6 | Loss:  5.6301
----------------------------------------


T Loss:  5.6286, Avg Loss:  5.5482, Avg Loss1:  5.5759, Avg Loss2:  5.5205: 100%|██████████| 782/782 [00:24<00:00, 31.94it/s]
V Loss:  5.3921, Avg Loss:  5.5481, Avg Loss1:  5.5719, Avg Loss2:  5.5242, Counter: 0, Best Loss:  5.6301: 100%|██████████| 157/157 [00:01<00:00, 94.01it/s] 


Epoch: 7 | Loss:  5.5481
----------------------------------------


T Loss:  5.4339, Avg Loss:  5.4596, Avg Loss1:  5.4820, Avg Loss2:  5.4372: 100%|██████████| 782/782 [00:24<00:00, 31.56it/s]
V Loss:  5.4422, Avg Loss:  5.4882, Avg Loss1:  5.5076, Avg Loss2:  5.4688, Counter: 0, Best Loss:  5.5481: 100%|██████████| 157/157 [00:01<00:00, 94.46it/s]


Epoch: 8 | Loss:  5.4882
----------------------------------------


T Loss:  5.6024, Avg Loss:  5.3843, Avg Loss1:  5.4033, Avg Loss2:  5.3652: 100%|██████████| 782/782 [00:24<00:00, 32.01it/s]
V Loss:  5.1706, Avg Loss:  5.4362, Avg Loss1:  5.4516, Avg Loss2:  5.4208, Counter: 0, Best Loss:  5.4882: 100%|██████████| 157/157 [00:01<00:00, 89.30it/s]


Epoch: 9 | Loss:  5.4362
----------------------------------------


T Loss:  5.6579, Avg Loss:  5.3187, Avg Loss1:  5.3346, Avg Loss2:  5.3029: 100%|██████████| 782/782 [00:24<00:00, 31.92it/s]
V Loss:  5.8318, Avg Loss:  5.3961, Avg Loss1:  5.4090, Avg Loss2:  5.3833, Counter: 0, Best Loss:  5.4362: 100%|██████████| 157/157 [00:01<00:00, 89.30it/s]


Epoch: 10 | Loss:  5.3961
----------------------------------------


T Loss:  4.9186, Avg Loss:  5.2613, Avg Loss1:  5.2752, Avg Loss2:  5.2474: 100%|██████████| 782/782 [00:24<00:00, 31.88it/s]
V Loss:  5.0762, Avg Loss:  5.3592, Avg Loss1:  5.3695, Avg Loss2:  5.3489, Counter: 0, Best Loss:  5.3961: 100%|██████████| 157/157 [00:01<00:00, 89.96it/s]


Epoch: 11 | Loss:  5.3592
----------------------------------------


T Loss:  5.3650, Avg Loss:  5.2103, Avg Loss1:  5.2214, Avg Loss2:  5.1991: 100%|██████████| 782/782 [00:24<00:00, 31.85it/s]
V Loss:  5.2506, Avg Loss:  5.3349, Avg Loss1:  5.3430, Avg Loss2:  5.3268, Counter: 0, Best Loss:  5.3592: 100%|██████████| 157/157 [00:01<00:00, 88.20it/s]


Epoch: 12 | Loss:  5.3349
----------------------------------------


T Loss:  5.1501, Avg Loss:  5.1650, Avg Loss1:  5.1747, Avg Loss2:  5.1553: 100%|██████████| 782/782 [00:24<00:00, 32.09it/s]
V Loss:  4.7909, Avg Loss:  5.3072, Avg Loss1:  5.3133, Avg Loss2:  5.3010, Counter: 0, Best Loss:  5.3349: 100%|██████████| 157/157 [00:01<00:00, 87.59it/s]


Epoch: 13 | Loss:  5.3072
----------------------------------------


T Loss:  5.1352, Avg Loss:  5.1241, Avg Loss1:  5.1325, Avg Loss2:  5.1156: 100%|██████████| 782/782 [00:24<00:00, 31.78it/s]
V Loss:  5.5345, Avg Loss:  5.2907, Avg Loss1:  5.2946, Avg Loss2:  5.2869, Counter: 0, Best Loss:  5.3072: 100%|██████████| 157/157 [00:01<00:00, 85.68it/s]


Epoch: 14 | Loss:  5.2907
----------------------------------------


T Loss:  4.9601, Avg Loss:  5.0871, Avg Loss1:  5.0939, Avg Loss2:  5.0802: 100%|██████████| 782/782 [00:24<00:00, 31.44it/s]
V Loss:  4.8806, Avg Loss:  5.2722, Avg Loss1:  5.2779, Avg Loss2:  5.2665, Counter: 0, Best Loss:  5.2907: 100%|██████████| 157/157 [00:01<00:00, 93.99it/s]


Epoch: 15 | Loss:  5.2722
----------------------------------------


T Loss:  5.0032, Avg Loss:  5.0535, Avg Loss1:  5.0608, Avg Loss2:  5.0463: 100%|██████████| 782/782 [00:24<00:00, 31.87it/s]
V Loss:  5.6417, Avg Loss:  5.2615, Avg Loss1:  5.2632, Avg Loss2:  5.2597, Counter: 0, Best Loss:  5.2722: 100%|██████████| 157/157 [00:01<00:00, 90.82it/s]


Epoch: 16 | Loss:  5.2615
----------------------------------------


T Loss:  5.3579, Avg Loss:  5.0232, Avg Loss1:  5.0291, Avg Loss2:  5.0172: 100%|██████████| 782/782 [00:24<00:00, 31.68it/s]
V Loss:  5.4697, Avg Loss:  5.2516, Avg Loss1:  5.2535, Avg Loss2:  5.2496, Counter: 0, Best Loss:  5.2615: 100%|██████████| 157/157 [00:01<00:00, 88.34it/s]


Epoch: 17 | Loss:  5.2516
----------------------------------------


T Loss:  4.9876, Avg Loss:  4.9937, Avg Loss1:  4.9988, Avg Loss2:  4.9885: 100%|██████████| 782/782 [00:24<00:00, 31.96it/s]
V Loss:  5.8891, Avg Loss:  5.2420, Avg Loss1:  5.2438, Avg Loss2:  5.2401, Counter: 0, Best Loss:  5.2516: 100%|██████████| 157/157 [00:01<00:00, 89.53it/s]


Epoch: 18 | Loss:  5.2420
----------------------------------------


T Loss:  5.0850, Avg Loss:  4.9682, Avg Loss1:  4.9727, Avg Loss2:  4.9636: 100%|██████████| 782/782 [00:24<00:00, 31.61it/s]
V Loss:  5.2533, Avg Loss:  5.2332, Avg Loss1:  5.2351, Avg Loss2:  5.2313, Counter: 0, Best Loss:  5.2420: 100%|██████████| 157/157 [00:01<00:00, 89.73it/s]


Epoch: 19 | Loss:  5.2332
----------------------------------------


T Loss:  4.6605, Avg Loss:  4.9436, Avg Loss1:  4.9475, Avg Loss2:  4.9396: 100%|██████████| 782/782 [00:24<00:00, 32.30it/s]
V Loss:  4.7124, Avg Loss:  5.2233, Avg Loss1:  5.2233, Avg Loss2:  5.2233, Counter: 0, Best Loss:  5.2332: 100%|██████████| 157/157 [00:01<00:00, 89.66it/s]


Epoch: 20 | Loss:  5.2233
----------------------------------------


T Loss:  4.9845, Avg Loss:  4.9216, Avg Loss1:  4.9244, Avg Loss2:  4.9189: 100%|██████████| 782/782 [00:24<00:00, 31.74it/s]
V Loss:  5.5309, Avg Loss:  5.2220, Avg Loss1:  5.2218, Avg Loss2:  5.2222, Counter: 0, Best Loss:  5.2233: 100%|██████████| 157/157 [00:01<00:00, 87.67it/s]


Epoch: 21 | Loss:  5.2220
----------------------------------------


T Loss:  4.8437, Avg Loss:  4.9001, Avg Loss1:  4.9026, Avg Loss2:  4.8976: 100%|██████████| 782/782 [00:24<00:00, 31.85it/s]
V Loss:  5.2809, Avg Loss:  5.2154, Avg Loss1:  5.2137, Avg Loss2:  5.2172, Counter: 0, Best Loss:  5.2220: 100%|██████████| 157/157 [00:01<00:00, 84.42it/s]


Epoch: 22 | Loss:  5.2154
----------------------------------------


T Loss:  4.8001, Avg Loss:  4.8809, Avg Loss1:  4.8835, Avg Loss2:  4.8783: 100%|██████████| 782/782 [00:24<00:00, 31.56it/s]
V Loss:  4.7511, Avg Loss:  5.2089, Avg Loss1:  5.2078, Avg Loss2:  5.2100, Counter: 0, Best Loss:  5.2154: 100%|██████████| 157/157 [00:01<00:00, 89.49it/s]


Epoch: 23 | Loss:  5.2089
----------------------------------------


T Loss:  5.0675, Avg Loss:  4.8619, Avg Loss1:  4.8642, Avg Loss2:  4.8596: 100%|██████████| 782/782 [00:24<00:00, 31.73it/s]
V Loss:  5.0860, Avg Loss:  5.2092, Avg Loss1:  5.2074, Avg Loss2:  5.2109, Counter: 0, Best Loss:  5.2089: 100%|██████████| 157/157 [00:01<00:00, 89.17it/s]


Epoch: 24 | Loss:  5.2092
----------------------------------------


T Loss:  5.0145, Avg Loss:  4.8450, Avg Loss1:  4.8466, Avg Loss2:  4.8434: 100%|██████████| 782/782 [00:24<00:00, 31.46it/s]
V Loss:  5.2289, Avg Loss:  5.2036, Avg Loss1:  5.2004, Avg Loss2:  5.2067, Counter: 1, Best Loss:  5.2089: 100%|██████████| 157/157 [00:01<00:00, 88.32it/s]


Epoch: 25 | Loss:  5.2036
----------------------------------------


T Loss:  5.0047, Avg Loss:  4.8289, Avg Loss1:  4.8287, Avg Loss2:  4.8291: 100%|██████████| 782/782 [00:24<00:00, 32.53it/s]
V Loss:  5.0256, Avg Loss:  5.2031, Avg Loss1:  5.2001, Avg Loss2:  5.2060, Counter: 0, Best Loss:  5.2036: 100%|██████████| 157/157 [00:01<00:00, 90.79it/s]


Epoch: 26 | Loss:  5.2031
----------------------------------------


T Loss:  4.9389, Avg Loss:  4.8123, Avg Loss1:  4.8130, Avg Loss2:  4.8115: 100%|██████████| 782/782 [00:24<00:00, 32.20it/s]
V Loss:  5.1432, Avg Loss:  5.2020, Avg Loss1:  5.1972, Avg Loss2:  5.2068, Counter: 1, Best Loss:  5.2036: 100%|██████████| 157/157 [00:01<00:00, 86.95it/s]


Epoch: 27 | Loss:  5.2020
----------------------------------------


T Loss:  4.9711, Avg Loss:  4.7987, Avg Loss1:  4.7979, Avg Loss2:  4.7995: 100%|██████████| 782/782 [00:24<00:00, 31.62it/s]
V Loss:  5.0468, Avg Loss:  5.1983, Avg Loss1:  5.1931, Avg Loss2:  5.2035, Counter: 0, Best Loss:  5.2020: 100%|██████████| 157/157 [00:01<00:00, 90.54it/s]


Epoch: 28 | Loss:  5.1983
----------------------------------------


T Loss:  4.8068, Avg Loss:  4.7847, Avg Loss1:  4.7837, Avg Loss2:  4.7857: 100%|██████████| 782/782 [00:24<00:00, 32.07it/s]
V Loss:  5.2312, Avg Loss:  5.2004, Avg Loss1:  5.1932, Avg Loss2:  5.2075, Counter: 0, Best Loss:  5.1983: 100%|██████████| 157/157 [00:01<00:00, 92.19it/s]


Epoch: 29 | Loss:  5.2004
----------------------------------------


T Loss:  4.9903, Avg Loss:  4.7713, Avg Loss1:  4.7712, Avg Loss2:  4.7715: 100%|██████████| 782/782 [00:24<00:00, 32.38it/s]
V Loss:  4.9253, Avg Loss:  5.1973, Avg Loss1:  5.1887, Avg Loss2:  5.2059, Counter: 1, Best Loss:  5.1983: 100%|██████████| 157/157 [00:01<00:00, 90.65it/s]


Epoch: 30 | Loss:  5.1973
----------------------------------------


T Loss:  4.5905, Avg Loss:  4.7581, Avg Loss1:  4.7560, Avg Loss2:  4.7603: 100%|██████████| 782/782 [00:24<00:00, 31.77it/s]
V Loss:  4.7175, Avg Loss:  5.1964, Avg Loss1:  5.1910, Avg Loss2:  5.2017, Counter: 2, Best Loss:  5.1983: 100%|██████████| 157/157 [00:01<00:00, 89.35it/s]


Epoch: 31 | Loss:  5.1964
----------------------------------------


T Loss:  4.9124, Avg Loss:  4.7478, Avg Loss1:  4.7456, Avg Loss2:  4.7500: 100%|██████████| 782/782 [00:24<00:00, 31.75it/s]
V Loss:  4.9001, Avg Loss:  5.1957, Avg Loss1:  5.1865, Avg Loss2:  5.2050, Counter: 0, Best Loss:  5.1964: 100%|██████████| 157/157 [00:01<00:00, 91.25it/s]


Epoch: 32 | Loss:  5.1957
----------------------------------------


T Loss:  4.7021, Avg Loss:  4.7350, Avg Loss1:  4.7323, Avg Loss2:  4.7377: 100%|██████████| 782/782 [00:24<00:00, 31.79it/s]
V Loss:  5.7860, Avg Loss:  5.1996, Avg Loss1:  5.1890, Avg Loss2:  5.2101, Counter: 1, Best Loss:  5.1964: 100%|██████████| 157/157 [00:01<00:00, 91.50it/s]


Epoch: 33 | Loss:  5.1996
----------------------------------------


T Loss:  4.8940, Avg Loss:  4.7244, Avg Loss1:  4.7213, Avg Loss2:  4.7274: 100%|██████████| 782/782 [00:24<00:00, 31.82it/s]
V Loss:  4.8236, Avg Loss:  5.1930, Avg Loss1:  5.1843, Avg Loss2:  5.2017, Counter: 2, Best Loss:  5.1964: 100%|██████████| 157/157 [00:01<00:00, 88.34it/s]


Epoch: 34 | Loss:  5.1930
----------------------------------------


T Loss:  4.6980, Avg Loss:  4.7146, Avg Loss1:  4.7114, Avg Loss2:  4.7179: 100%|██████████| 782/782 [00:23<00:00, 32.66it/s]
V Loss:  5.0347, Avg Loss:  5.1950, Avg Loss1:  5.1855, Avg Loss2:  5.2045, Counter: 0, Best Loss:  5.1930: 100%|██████████| 157/157 [00:01<00:00, 91.60it/s]


Epoch: 35 | Loss:  5.1950
----------------------------------------


T Loss:  4.8793, Avg Loss:  4.7043, Avg Loss1:  4.6995, Avg Loss2:  4.7091: 100%|██████████| 782/782 [00:23<00:00, 33.97it/s]
V Loss:  5.0984, Avg Loss:  5.2000, Avg Loss1:  5.1903, Avg Loss2:  5.2097, Counter: 1, Best Loss:  5.1930: 100%|██████████| 157/157 [00:01<00:00, 89.80it/s]


Epoch: 36 | Loss:  5.2000
----------------------------------------


T Loss:  4.7168, Avg Loss:  4.6955, Avg Loss1:  4.6907, Avg Loss2:  4.7004: 100%|██████████| 782/782 [00:24<00:00, 31.53it/s]
V Loss:  5.0962, Avg Loss:  5.2007, Avg Loss1:  5.1932, Avg Loss2:  5.2082, Counter: 2, Best Loss:  5.1930: 100%|██████████| 157/157 [00:01<00:00, 87.68it/s]


Epoch: 37 | Loss:  5.2007
----------------------------------------


T Loss:  4.9639, Avg Loss:  4.6855, Avg Loss1:  4.6796, Avg Loss2:  4.6915: 100%|██████████| 782/782 [00:24<00:00, 32.56it/s]
V Loss:  5.3314, Avg Loss:  5.2033, Avg Loss1:  5.1921, Avg Loss2:  5.2145, Counter: 3, Best Loss:  5.1930: 100%|██████████| 157/157 [00:01<00:00, 88.26it/s]

Epoch: 38 | Loss:  5.2033





## Downstream Task

In [30]:
downstream_train_raw_x = xdf[DEV_TRAIN_LEN+DEV_VALIDATION_LEN:]
downstream_train_raw_y = ydf[DEV_TRAIN_LEN+DEV_VALIDATION_LEN:]

downstream_validation_raw_x = xdf[:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]
downstream_validation_raw_y = ydf[:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]

In [31]:
print(len(downstream_train_raw_x))
print(len(downstream_validation_raw_x))

71986
30000


In [32]:
downstream_train_dataset = Sentences(downstream_train_raw_x, downstream_train_raw_y, word_to_idx, glove_key_to_idx)
downstream_validation_dataset = Sentences(downstream_validation_raw_x, downstream_validation_raw_y, word_to_idx, glove_key_to_idx)

downstream_train_loader = DataLoader(downstream_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
downstream_validation_loader = DataLoader(downstream_validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [33]:
from sklearn.metrics import classification_report, confusion_matrix

class NewsClassification(nn.Module):
    def __init__(self, 
            hidden_dim: int, 
            vocab_size: int, 
            num_classes: int,
            filename: str = None
        ) -> None:

        super(NewsClassification, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.elmo = ELMo(
            glove_idx_to_vec=glove_idx_to_vec,
            idx_to_word=idx_to_word,
            dropout=DROPOUT, 
            num_layers=NUM_LAYERS, 
            hidden_dim=HIDDEN_DIM,
            word_embed_dim=EMBEDDING_DIM,
            filename=filename
        )

        for param in self.elmo.parameters():
            param.requires_grad = False

        self.delta = nn.Parameter(torch.randn(1, 3))
        self.linear = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, xf: torch.Tensor, xb: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        _, _, input, (hsf, csf), (hsb, csb) = self.elmo(xf, xb, l)
        hsf = hsf.permute(1, 0, 2)
        csf = csf.permute(1, 0, 2)
        hsb = hsb.permute(1, 0, 2)
        csb = csb.permute(1, 0, 2)

        hs = torch.cat([hsf, hsb], dim=2)
        cs = torch.cat([csf, csb], dim=2)

        val = (hs + cs) / 2

        input = torch.mean(input, dim=1)
        input = torch.cat([input] * val.shape[1], dim=1).unsqueeze(1)

        val = torch.cat([input, val], dim=1)
        val = (self.delta / (torch.sum(self.delta))) @ val
        val = val.squeeze()

        x = self.linear(val)
        return x

    def fit(self, 
            train_loader: DataLoader, 
            validation_loader: DataLoader, 
            epochs: int, 
            learning_rate: float
        ) -> None:

        self.es = EarlyStopping()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            print('----------------------------------------')
            self._train(train_loader)
            loss = self._evaluate(validation_loader)
            print(f'Epoch: {epoch + 1} | Loss: {loss:7.4f}')
            if self.es(loss, epoch):
                break
            if self.es.counter == 0:
                torch.save(self.state_dict(), os.path.join(DIR, 'best.pth'))

    def _call(self, xf: torch.Tensor, xb: torch.Tensor, y: torch.Tensor, l: torch.Tensor, yf: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
        xf, xb, y, yf, yb = xf.to(DEVICE), xb.to(DEVICE), y.to(DEVICE), yf.to(DEVICE), yb.to(DEVICE)

        y_hat = self(xf, xb, l)
        y_hat = y_hat.view(-1, self.num_classes)
        y = y.view(-1)
        loss = self.criterion(y_hat, y)

        return loss

    def _train(self, train_loader: DataLoader) -> None:
        self.train()
        epoch_loss = []
        pbar = tqdm(train_loader)
        for xf, xb, y, l, yf, yb in pbar:
            loss = self._call(xf, xb, y, l, yf, yb)
            epoch_loss.append(loss.item())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            pbar.set_description(f'T Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}')

        wandb.log({'downstream_train_loss': np.mean(epoch_loss)})

    def _evaluate(self, validation_loader: DataLoader) -> float:
        self.eval()
        epoch_loss = []
        pbar = tqdm(validation_loader)
        with torch.no_grad():
            for xf, xb, y, l, yf, yb in pbar:
                loss = self._call(xf, xb, y, l, yf, yb)
                epoch_loss.append(loss.item())
                pbar.set_description(f'V Loss: {epoch_loss[-1]:7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Counter: {self.es.counter}, Best Loss: {self.es.best_loss:7.4f}')

        wandb.log({'downstream_validation_loss': np.mean(epoch_loss)})
        return np.mean(epoch_loss)

    def _metrics(self, test_loader: DataLoader) -> None:
        self.eval()
        self.criterion = nn.CrossEntropyLoss()
        pbar = tqdm(test_loader)
        y_pred = []
        y_true = []
        epoch_loss = []

        with torch.no_grad():
            for xf, xb, y, l, yf, yb in pbar:
                xf, xb, y, yf, yb = xf.to(DEVICE), xb.to(DEVICE), y.to(DEVICE), yf.to(DEVICE), yb.to(DEVICE)

                y_hat = self(xf, xb, l)
                y_hat = y_hat.view(-1, self.num_classes)
                y = y.view(-1)
                loss = self.criterion(y_hat, y)

                epoch_loss.append(loss.item())
                y_hat = torch.argmax(y_hat, dim=1)
                y_pred += y_hat.tolist()
                y_true += y.tolist()

        wandb.log({'downstrea_delta': self.delta.tolist()})

        wandb.log({'downstream_test_loss': np.mean(epoch_loss)})
        print(f'Test Loss: {np.mean(epoch_loss):7.4f}')

        cr = classification_report(y_true, y_pred, digits=4)
        wandb.log({'classification_report': cr})
        print('Classification Report:', cr)

        cm = confusion_matrix(y_true, y_pred)
        wandb.log({'confusion_matrix': cm})
        print('Confusion Matrix:', cm)


In [34]:
test_xdf, test_ydf = read_data('data/test.csv', create_unique_words=False, filter_rare_words=False)

In [35]:
downstream_test_dataset = Sentences(test_xdf, test_ydf, word_to_idx, glove_key_to_idx)
downstream_test_loader = DataLoader(downstream_test_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)

In [36]:
delta = [None, [3, 0, 0], [1, 2, 3], [0, 0, 3], [1, 1, 1]] 
for i in range(1, len(delta)):
    delta[i] = torch.tensor(delta[i], dtype=torch.float32).reshape(1, -1).to(DEVICE)
    print(delta[i])

for d in delta:
    nc = NewsClassification(hidden_dim=HIDDEN_DIM, 
                            vocab_size=len(word_to_idx), 
                            num_classes=NUM_CLASSES,
                            filename=os.path.join(DIR, 'best_elmo.pth')
                        ).to(DEVICE)

    if d is not None:
        print(d)
        nc.delta = nn.Parameter(d, requires_grad=False)

    nc.fit(downstream_train_loader, downstream_validation_loader, epochs=EPOCHS, learning_rate=LEARNING_RATE)

    nc.load_state_dict(torch.load(os.path.join(DIR, 'best.pth')))
    nc._metrics(downstream_test_loader)

tensor([[3., 0., 0.]], device='cuda:0')
tensor([[1., 2., 3.]], device='cuda:0')
tensor([[0., 0., 3.]], device='cuda:0')
tensor([[1., 1., 1.]], device='cuda:0')
----------------------------------------


T Loss:  0.2665, Avg Loss:  0.3733: 100%|██████████| 2250/2250 [00:20<00:00, 108.92it/s]
V Loss:  0.0803, Avg Loss:  0.3333, Counter: 0, Best Loss:     inf: 100%|██████████| 938/938 [00:06<00:00, 137.83it/s]


Epoch: 1 | Loss:  0.3333
----------------------------------------


T Loss:  0.1081, Avg Loss:  0.3335: 100%|██████████| 2250/2250 [00:20<00:00, 110.69it/s]
V Loss:  0.2866, Avg Loss:  0.3120, Counter: 0, Best Loss:  0.3333: 100%|██████████| 938/938 [00:06<00:00, 136.38it/s]


Epoch: 2 | Loss:  0.3120
----------------------------------------


T Loss:  0.4290, Avg Loss:  0.3209: 100%|██████████| 2250/2250 [00:20<00:00, 110.45it/s]
V Loss:  0.1119, Avg Loss:  0.3059, Counter: 0, Best Loss:  0.3120: 100%|██████████| 938/938 [00:06<00:00, 138.14it/s]


Epoch: 3 | Loss:  0.3059
----------------------------------------


T Loss:  0.2507, Avg Loss:  0.3111: 100%|██████████| 2250/2250 [00:20<00:00, 110.35it/s]
V Loss:  0.1609, Avg Loss:  0.2999, Counter: 0, Best Loss:  0.3059: 100%|██████████| 938/938 [00:06<00:00, 136.78it/s]


Epoch: 4 | Loss:  0.2999
----------------------------------------


T Loss:  0.4171, Avg Loss:  0.3038: 100%|██████████| 2250/2250 [00:19<00:00, 113.79it/s]
V Loss:  0.2269, Avg Loss:  0.2983, Counter: 0, Best Loss:  0.2999: 100%|██████████| 938/938 [00:06<00:00, 137.96it/s]


Epoch: 5 | Loss:  0.2983
----------------------------------------


T Loss:  0.1024, Avg Loss:  0.2979: 100%|██████████| 2250/2250 [00:19<00:00, 117.54it/s]
V Loss:  0.2450, Avg Loss:  0.3019, Counter: 0, Best Loss:  0.2983: 100%|██████████| 938/938 [00:07<00:00, 133.17it/s]


Epoch: 6 | Loss:  0.3019
----------------------------------------


T Loss:  0.3455, Avg Loss:  0.2936: 100%|██████████| 2250/2250 [00:20<00:00, 112.24it/s]
V Loss:  0.1383, Avg Loss:  0.2965, Counter: 1, Best Loss:  0.2983: 100%|██████████| 938/938 [00:06<00:00, 143.83it/s]


Epoch: 7 | Loss:  0.2965
----------------------------------------


T Loss:  0.2269, Avg Loss:  0.2907: 100%|██████████| 2250/2250 [00:19<00:00, 113.97it/s]
V Loss:  0.3272, Avg Loss:  0.2944, Counter: 0, Best Loss:  0.2965: 100%|██████████| 938/938 [00:06<00:00, 141.97it/s]


Epoch: 8 | Loss:  0.2944
----------------------------------------


T Loss:  0.2767, Avg Loss:  0.2875: 100%|██████████| 2250/2250 [00:19<00:00, 114.66it/s]
V Loss:  0.1233, Avg Loss:  0.2973, Counter: 0, Best Loss:  0.2944: 100%|██████████| 938/938 [00:06<00:00, 137.78it/s]


Epoch: 9 | Loss:  0.2973
----------------------------------------


T Loss:  0.0713, Avg Loss:  0.2827: 100%|██████████| 2250/2250 [00:19<00:00, 115.81it/s]
V Loss:  0.1741, Avg Loss:  0.2904, Counter: 1, Best Loss:  0.2944: 100%|██████████| 938/938 [00:06<00:00, 140.29it/s]


Epoch: 10 | Loss:  0.2904
----------------------------------------


T Loss:  0.4371, Avg Loss:  0.2819: 100%|██████████| 2250/2250 [00:19<00:00, 113.37it/s]
V Loss:  0.1353, Avg Loss:  0.3022, Counter: 0, Best Loss:  0.2904: 100%|██████████| 938/938 [00:06<00:00, 142.95it/s]


Epoch: 11 | Loss:  0.3022
----------------------------------------


T Loss:  0.5335, Avg Loss:  0.2788: 100%|██████████| 2250/2250 [00:20<00:00, 111.31it/s]
V Loss:  0.8151, Avg Loss:  0.2884, Counter: 1, Best Loss:  0.2904: 100%|██████████| 938/938 [00:06<00:00, 140.32it/s]


Epoch: 12 | Loss:  0.2884
----------------------------------------


T Loss:  0.3163, Avg Loss:  0.2774: 100%|██████████| 2250/2250 [00:20<00:00, 112.08it/s]
V Loss:  0.3092, Avg Loss:  0.2934, Counter: 0, Best Loss:  0.2884: 100%|██████████| 938/938 [00:06<00:00, 138.12it/s]


Epoch: 13 | Loss:  0.2934
----------------------------------------


T Loss:  0.0863, Avg Loss:  0.2754: 100%|██████████| 2250/2250 [00:19<00:00, 112.78it/s]
V Loss:  0.6298, Avg Loss:  0.2891, Counter: 1, Best Loss:  0.2884: 100%|██████████| 938/938 [00:06<00:00, 134.73it/s]


Epoch: 14 | Loss:  0.2891
----------------------------------------


T Loss:  0.1686, Avg Loss:  0.2730: 100%|██████████| 2250/2250 [00:19<00:00, 114.32it/s]
V Loss:  0.0478, Avg Loss:  0.2909, Counter: 2, Best Loss:  0.2884: 100%|██████████| 938/938 [00:06<00:00, 134.44it/s]


Epoch: 15 | Loss:  0.2909
----------------------------------------


T Loss:  0.2169, Avg Loss:  0.2712: 100%|██████████| 2250/2250 [00:19<00:00, 112.70it/s]
V Loss:  0.2064, Avg Loss:  0.2899, Counter: 3, Best Loss:  0.2884: 100%|██████████| 938/938 [00:06<00:00, 146.00it/s]


Epoch: 16 | Loss:  0.2899


100%|██████████| 51/51 [00:00<00:00, 111.76it/s]


Test Loss:  0.2954
Classification Report:               precision    recall  f1-score   support

           0     0.9319    0.8893    0.9101      1708
           1     0.9539    0.9722    0.9630      1765
           2     0.8504    0.8130    0.8313      1503
           3     0.8235    0.8837    0.8526      1505

    accuracy                         0.8929      6481
   macro avg     0.8899    0.8896    0.8892      6481
weighted avg     0.8938    0.8929    0.8929      6481

Confusion Matrix: [[1519   41   78   70]
 [  13 1716   24   12]
 [  54   24 1222  203]
 [  44   18  113 1330]]
tensor([[3., 0., 0.]], device='cuda:0')
----------------------------------------


T Loss:  0.1856, Avg Loss:  0.4428: 100%|██████████| 2250/2250 [00:18<00:00, 120.02it/s]
V Loss:  0.1813, Avg Loss:  0.3699, Counter: 0, Best Loss:     inf: 100%|██████████| 938/938 [00:06<00:00, 138.81it/s]


Epoch: 1 | Loss:  0.3699
----------------------------------------


T Loss:  0.0723, Avg Loss:  0.3649: 100%|██████████| 2250/2250 [00:18<00:00, 120.44it/s]
V Loss:  0.1982, Avg Loss:  0.3530, Counter: 0, Best Loss:  0.3699: 100%|██████████| 938/938 [00:06<00:00, 137.26it/s]


Epoch: 2 | Loss:  0.3530
----------------------------------------


T Loss:  0.1020, Avg Loss:  0.3479: 100%|██████████| 2250/2250 [00:19<00:00, 116.65it/s]
V Loss:  0.2015, Avg Loss:  0.3392, Counter: 0, Best Loss:  0.3530: 100%|██████████| 938/938 [00:06<00:00, 137.78it/s]


Epoch: 3 | Loss:  0.3392
----------------------------------------


T Loss:  0.2695, Avg Loss:  0.3364: 100%|██████████| 2250/2250 [00:19<00:00, 117.93it/s]
V Loss:  0.0889, Avg Loss:  0.3458, Counter: 0, Best Loss:  0.3392: 100%|██████████| 938/938 [00:06<00:00, 135.57it/s]


Epoch: 4 | Loss:  0.3458
----------------------------------------


T Loss:  0.3373, Avg Loss:  0.3304: 100%|██████████| 2250/2250 [00:18<00:00, 120.35it/s]
V Loss:  0.1782, Avg Loss:  0.3345, Counter: 1, Best Loss:  0.3392: 100%|██████████| 938/938 [00:06<00:00, 138.77it/s]


Epoch: 5 | Loss:  0.3345
----------------------------------------


T Loss:  0.3592, Avg Loss:  0.3255: 100%|██████████| 2250/2250 [00:19<00:00, 115.38it/s]
V Loss:  0.2995, Avg Loss:  0.3304, Counter: 0, Best Loss:  0.3345: 100%|██████████| 938/938 [00:06<00:00, 140.64it/s]


Epoch: 6 | Loss:  0.3304
----------------------------------------


T Loss:  0.8005, Avg Loss:  0.3200: 100%|██████████| 2250/2250 [00:18<00:00, 119.48it/s]
V Loss:  0.4166, Avg Loss:  0.3214, Counter: 0, Best Loss:  0.3304: 100%|██████████| 938/938 [00:06<00:00, 142.86it/s]


Epoch: 7 | Loss:  0.3214
----------------------------------------


T Loss:  0.3390, Avg Loss:  0.3154: 100%|██████████| 2250/2250 [00:19<00:00, 117.98it/s]
V Loss:  0.6730, Avg Loss:  0.3228, Counter: 0, Best Loss:  0.3214: 100%|██████████| 938/938 [00:06<00:00, 138.87it/s]


Epoch: 8 | Loss:  0.3228
----------------------------------------


T Loss:  0.5072, Avg Loss:  0.3119: 100%|██████████| 2250/2250 [00:18<00:00, 118.45it/s]
V Loss:  0.3352, Avg Loss:  0.3208, Counter: 1, Best Loss:  0.3214: 100%|██████████| 938/938 [00:06<00:00, 138.51it/s]


Epoch: 9 | Loss:  0.3208
----------------------------------------


T Loss:  0.3299, Avg Loss:  0.3080: 100%|██████████| 2250/2250 [00:18<00:00, 119.59it/s]
V Loss:  0.5071, Avg Loss:  0.3139, Counter: 2, Best Loss:  0.3214: 100%|██████████| 938/938 [00:06<00:00, 141.04it/s]


Epoch: 10 | Loss:  0.3139
----------------------------------------


T Loss:  0.1393, Avg Loss:  0.3052: 100%|██████████| 2250/2250 [00:18<00:00, 120.69it/s]
V Loss:  0.2308, Avg Loss:  0.3208, Counter: 0, Best Loss:  0.3139: 100%|██████████| 938/938 [00:06<00:00, 139.97it/s]


Epoch: 11 | Loss:  0.3208
----------------------------------------


T Loss:  0.3751, Avg Loss:  0.3017: 100%|██████████| 2250/2250 [00:19<00:00, 114.24it/s]
V Loss:  0.1337, Avg Loss:  0.3183, Counter: 1, Best Loss:  0.3139: 100%|██████████| 938/938 [00:06<00:00, 142.59it/s]


Epoch: 12 | Loss:  0.3183
----------------------------------------


T Loss:  0.0723, Avg Loss:  0.2987: 100%|██████████| 2250/2250 [00:19<00:00, 117.11it/s]
V Loss:  0.0956, Avg Loss:  0.3134, Counter: 2, Best Loss:  0.3139: 100%|██████████| 938/938 [00:06<00:00, 134.67it/s]


Epoch: 13 | Loss:  0.3134
----------------------------------------


T Loss:  0.1435, Avg Loss:  0.2960: 100%|██████████| 2250/2250 [00:18<00:00, 120.38it/s]
V Loss:  0.2016, Avg Loss:  0.3126, Counter: 3, Best Loss:  0.3139: 100%|██████████| 938/938 [00:06<00:00, 136.09it/s]


Epoch: 14 | Loss:  0.3126
----------------------------------------


T Loss:  0.1448, Avg Loss:  0.2937: 100%|██████████| 2250/2250 [00:19<00:00, 117.91it/s]
V Loss:  0.4174, Avg Loss:  0.3176, Counter: 0, Best Loss:  0.3126: 100%|██████████| 938/938 [00:06<00:00, 136.59it/s]


Epoch: 15 | Loss:  0.3176
----------------------------------------


T Loss:  0.2320, Avg Loss:  0.2916: 100%|██████████| 2250/2250 [00:18<00:00, 119.60it/s]
V Loss:  0.0583, Avg Loss:  0.3099, Counter: 1, Best Loss:  0.3126: 100%|██████████| 938/938 [00:06<00:00, 144.82it/s]


Epoch: 16 | Loss:  0.3099
----------------------------------------


T Loss:  0.2387, Avg Loss:  0.2887: 100%|██████████| 2250/2250 [00:19<00:00, 118.26it/s]
V Loss:  0.6730, Avg Loss:  0.3095, Counter: 0, Best Loss:  0.3099: 100%|██████████| 938/938 [00:06<00:00, 136.86it/s]


Epoch: 17 | Loss:  0.3095
----------------------------------------


T Loss:  0.2238, Avg Loss:  0.2863: 100%|██████████| 2250/2250 [00:18<00:00, 119.29it/s]
V Loss:  0.7894, Avg Loss:  0.3141, Counter: 1, Best Loss:  0.3099: 100%|██████████| 938/938 [00:06<00:00, 141.99it/s]


Epoch: 18 | Loss:  0.3141
----------------------------------------


T Loss:  0.2673, Avg Loss:  0.2848: 100%|██████████| 2250/2250 [00:19<00:00, 116.21it/s]
V Loss:  0.1500, Avg Loss:  0.3171, Counter: 2, Best Loss:  0.3099: 100%|██████████| 938/938 [00:06<00:00, 138.81it/s]


Epoch: 19 | Loss:  0.3171
----------------------------------------


T Loss:  0.1701, Avg Loss:  0.2834: 100%|██████████| 2250/2250 [00:19<00:00, 117.27it/s]
V Loss:  0.0774, Avg Loss:  0.3160, Counter: 3, Best Loss:  0.3099: 100%|██████████| 938/938 [00:06<00:00, 138.00it/s]


Epoch: 20 | Loss:  0.3160


100%|██████████| 51/51 [00:00<00:00, 110.86it/s]


Test Loss:  0.3215
Classification Report:               precision    recall  f1-score   support

           0     0.9203    0.8653    0.8920      1708
           1     0.9403    0.9734    0.9566      1765
           2     0.8460    0.8150    0.8302      1503
           3     0.8269    0.8791    0.8522      1505

    accuracy                         0.8863      6481
   macro avg     0.8834    0.8832    0.8827      6481
weighted avg     0.8868    0.8863    0.8860      6481

Confusion Matrix: [[1478   66   97   67]
 [  21 1718   13   13]
 [  58   23 1225  197]
 [  49   20  113 1323]]
tensor([[1., 2., 3.]], device='cuda:0')
----------------------------------------


T Loss:  0.2193, Avg Loss:  0.3761: 100%|██████████| 2250/2250 [00:19<00:00, 116.88it/s]
V Loss:  0.4140, Avg Loss:  0.3307, Counter: 0, Best Loss:     inf: 100%|██████████| 938/938 [00:06<00:00, 138.20it/s]


Epoch: 1 | Loss:  0.3307
----------------------------------------


T Loss:  0.1121, Avg Loss:  0.3367: 100%|██████████| 2250/2250 [00:19<00:00, 117.79it/s]
V Loss:  0.0443, Avg Loss:  0.3162, Counter: 0, Best Loss:  0.3307: 100%|██████████| 938/938 [00:07<00:00, 133.83it/s]


Epoch: 2 | Loss:  0.3162
----------------------------------------


T Loss:  0.5165, Avg Loss:  0.3227: 100%|██████████| 2250/2250 [00:19<00:00, 116.97it/s]
V Loss:  0.4048, Avg Loss:  0.3149, Counter: 0, Best Loss:  0.3162: 100%|██████████| 938/938 [00:06<00:00, 137.81it/s]


Epoch: 3 | Loss:  0.3149
----------------------------------------


T Loss:  0.4816, Avg Loss:  0.3140: 100%|██████████| 2250/2250 [00:19<00:00, 114.35it/s]
V Loss:  0.5013, Avg Loss:  0.3122, Counter: 0, Best Loss:  0.3149: 100%|██████████| 938/938 [00:06<00:00, 144.77it/s]


Epoch: 4 | Loss:  0.3122
----------------------------------------


T Loss:  0.1545, Avg Loss:  0.3064: 100%|██████████| 2250/2250 [00:19<00:00, 117.18it/s]
V Loss:  0.6626, Avg Loss:  0.3007, Counter: 0, Best Loss:  0.3122: 100%|██████████| 938/938 [00:06<00:00, 138.81it/s]


Epoch: 5 | Loss:  0.3007
----------------------------------------


T Loss:  0.1476, Avg Loss:  0.3019: 100%|██████████| 2250/2250 [00:19<00:00, 117.30it/s]
V Loss:  0.1424, Avg Loss:  0.2973, Counter: 0, Best Loss:  0.3007: 100%|██████████| 938/938 [00:06<00:00, 140.96it/s]


Epoch: 6 | Loss:  0.2973
----------------------------------------


T Loss:  0.2887, Avg Loss:  0.2971: 100%|██████████| 2250/2250 [00:18<00:00, 118.47it/s]
V Loss:  0.3895, Avg Loss:  0.2947, Counter: 0, Best Loss:  0.2973: 100%|██████████| 938/938 [00:06<00:00, 138.19it/s]


Epoch: 7 | Loss:  0.2947
----------------------------------------


T Loss:  0.2406, Avg Loss:  0.2946: 100%|██████████| 2250/2250 [00:19<00:00, 116.82it/s]
V Loss:  0.3381, Avg Loss:  0.2949, Counter: 0, Best Loss:  0.2947: 100%|██████████| 938/938 [00:07<00:00, 132.59it/s]


Epoch: 8 | Loss:  0.2949
----------------------------------------


T Loss:  0.3200, Avg Loss:  0.2914: 100%|██████████| 2250/2250 [00:19<00:00, 117.11it/s]
V Loss:  0.1083, Avg Loss:  0.2892, Counter: 1, Best Loss:  0.2947: 100%|██████████| 938/938 [00:06<00:00, 137.75it/s]


Epoch: 9 | Loss:  0.2892
----------------------------------------


T Loss:  0.2122, Avg Loss:  0.2896: 100%|██████████| 2250/2250 [00:19<00:00, 115.46it/s]
V Loss:  0.2403, Avg Loss:  0.2943, Counter: 0, Best Loss:  0.2892: 100%|██████████| 938/938 [00:06<00:00, 139.56it/s]


Epoch: 10 | Loss:  0.2943
----------------------------------------


T Loss:  0.3508, Avg Loss:  0.2852: 100%|██████████| 2250/2250 [00:19<00:00, 116.75it/s]
V Loss:  0.1084, Avg Loss:  0.2946, Counter: 1, Best Loss:  0.2892: 100%|██████████| 938/938 [00:06<00:00, 135.51it/s]


Epoch: 11 | Loss:  0.2946
----------------------------------------


T Loss:  0.2750, Avg Loss:  0.2839: 100%|██████████| 2250/2250 [00:19<00:00, 117.10it/s]
V Loss:  0.1251, Avg Loss:  0.2887, Counter: 2, Best Loss:  0.2892: 100%|██████████| 938/938 [00:06<00:00, 138.62it/s]


Epoch: 12 | Loss:  0.2887
----------------------------------------


T Loss:  0.2158, Avg Loss:  0.2817: 100%|██████████| 2250/2250 [00:19<00:00, 115.14it/s]
V Loss:  0.4684, Avg Loss:  0.2984, Counter: 3, Best Loss:  0.2892: 100%|██████████| 938/938 [00:06<00:00, 139.85it/s]


Epoch: 13 | Loss:  0.2984


100%|██████████| 51/51 [00:00<00:00, 110.49it/s]


Test Loss:  0.2973
Classification Report:               precision    recall  f1-score   support

           0     0.9214    0.8993    0.9102      1708
           1     0.9697    0.9598    0.9647      1765
           2     0.8547    0.8257    0.8399      1503
           3     0.8260    0.8864    0.8551      1505

    accuracy                         0.8957      6481
   macro avg     0.8929    0.8928    0.8925      6481
weighted avg     0.8969    0.8957    0.8960      6481

Confusion Matrix: [[1536   27   77   68]
 [  27 1694   20   24]
 [  56   17 1241  189]
 [  48    9  114 1334]]
tensor([[0., 0., 3.]], device='cuda:0')
----------------------------------------


T Loss:  0.3086, Avg Loss:  0.3742: 100%|██████████| 2250/2250 [00:19<00:00, 116.69it/s]
V Loss:  0.5008, Avg Loss:  0.3241, Counter: 0, Best Loss:     inf: 100%|██████████| 938/938 [00:06<00:00, 140.43it/s]


Epoch: 1 | Loss:  0.3241
----------------------------------------


T Loss:  0.5946, Avg Loss:  0.3344: 100%|██████████| 2250/2250 [00:19<00:00, 117.03it/s]
V Loss:  0.3442, Avg Loss:  0.3324, Counter: 0, Best Loss:  0.3241: 100%|██████████| 938/938 [00:06<00:00, 139.19it/s]


Epoch: 2 | Loss:  0.3324
----------------------------------------


T Loss:  0.0698, Avg Loss:  0.3243: 100%|██████████| 2250/2250 [00:19<00:00, 116.47it/s]
V Loss:  0.2839, Avg Loss:  0.3167, Counter: 1, Best Loss:  0.3241: 100%|██████████| 938/938 [00:06<00:00, 134.21it/s]


Epoch: 3 | Loss:  0.3167
----------------------------------------


T Loss:  0.1590, Avg Loss:  0.3168: 100%|██████████| 2250/2250 [00:19<00:00, 114.73it/s]
V Loss:  0.0597, Avg Loss:  0.3041, Counter: 0, Best Loss:  0.3167: 100%|██████████| 938/938 [00:06<00:00, 138.02it/s]


Epoch: 4 | Loss:  0.3041
----------------------------------------


T Loss:  0.4286, Avg Loss:  0.3114: 100%|██████████| 2250/2250 [00:19<00:00, 114.10it/s]
V Loss:  0.2825, Avg Loss:  0.3008, Counter: 0, Best Loss:  0.3041: 100%|██████████| 938/938 [00:07<00:00, 132.43it/s]


Epoch: 5 | Loss:  0.3008
----------------------------------------


T Loss:  0.1859, Avg Loss:  0.3064: 100%|██████████| 2250/2250 [00:19<00:00, 117.95it/s]
V Loss:  0.3542, Avg Loss:  0.3058, Counter: 0, Best Loss:  0.3008: 100%|██████████| 938/938 [00:06<00:00, 148.13it/s]


Epoch: 6 | Loss:  0.3058
----------------------------------------


T Loss:  0.5076, Avg Loss:  0.3048: 100%|██████████| 2250/2250 [00:19<00:00, 117.91it/s]
V Loss:  0.7846, Avg Loss:  0.3002, Counter: 1, Best Loss:  0.3008: 100%|██████████| 938/938 [00:06<00:00, 143.38it/s]


Epoch: 7 | Loss:  0.3002
----------------------------------------


T Loss:  0.1447, Avg Loss:  0.3022: 100%|██████████| 2250/2250 [00:19<00:00, 116.21it/s]
V Loss:  0.1105, Avg Loss:  0.3046, Counter: 2, Best Loss:  0.3008: 100%|██████████| 938/938 [00:06<00:00, 138.62it/s]


Epoch: 8 | Loss:  0.3046
----------------------------------------


T Loss:  0.8399, Avg Loss:  0.2980: 100%|██████████| 2250/2250 [00:19<00:00, 116.95it/s]
V Loss:  0.2355, Avg Loss:  0.3049, Counter: 3, Best Loss:  0.3008: 100%|██████████| 938/938 [00:06<00:00, 142.12it/s]


Epoch: 9 | Loss:  0.3049


100%|██████████| 51/51 [00:00<00:00, 105.68it/s]


Test Loss:  0.3090
Classification Report:               precision    recall  f1-score   support

           0     0.9145    0.8888    0.9014      1708
           1     0.9535    0.9637    0.9586      1765
           2     0.8118    0.8523    0.8315      1503
           3     0.8615    0.8352    0.8482      1505

    accuracy                         0.8883      6481
   macro avg     0.8853    0.8850    0.8849      6481
weighted avg     0.8890    0.8883    0.8884      6481

Confusion Matrix: [[1518   49   93   48]
 [  21 1701   30   13]
 [  62   19 1281  141]
 [  59   15  174 1257]]
tensor([[1., 1., 1.]], device='cuda:0')
----------------------------------------


T Loss:  0.3168, Avg Loss:  0.3756: 100%|██████████| 2250/2250 [00:18<00:00, 119.38it/s]
V Loss:  0.1661, Avg Loss:  0.3420, Counter: 0, Best Loss:     inf: 100%|██████████| 938/938 [00:07<00:00, 133.27it/s]


Epoch: 1 | Loss:  0.3420
----------------------------------------


T Loss:  1.0549, Avg Loss:  0.3326: 100%|██████████| 2250/2250 [00:19<00:00, 114.32it/s]
V Loss:  0.2407, Avg Loss:  0.3168, Counter: 0, Best Loss:  0.3420: 100%|██████████| 938/938 [00:06<00:00, 139.60it/s]


Epoch: 2 | Loss:  0.3168
----------------------------------------


T Loss:  0.2883, Avg Loss:  0.3209: 100%|██████████| 2250/2250 [00:19<00:00, 116.16it/s]
V Loss:  0.1241, Avg Loss:  0.3067, Counter: 0, Best Loss:  0.3168: 100%|██████████| 938/938 [00:06<00:00, 152.11it/s]


Epoch: 3 | Loss:  0.3067
----------------------------------------


T Loss:  0.2214, Avg Loss:  0.3104: 100%|██████████| 2250/2250 [00:18<00:00, 118.57it/s]
V Loss:  0.7316, Avg Loss:  0.3078, Counter: 0, Best Loss:  0.3067: 100%|██████████| 938/938 [00:06<00:00, 137.27it/s]


Epoch: 4 | Loss:  0.3078
----------------------------------------


T Loss:  0.3674, Avg Loss:  0.3044: 100%|██████████| 2250/2250 [00:18<00:00, 119.14it/s]
V Loss:  0.3861, Avg Loss:  0.3363, Counter: 1, Best Loss:  0.3067: 100%|██████████| 938/938 [00:06<00:00, 143.27it/s]


Epoch: 5 | Loss:  0.3363
----------------------------------------


T Loss:  0.1560, Avg Loss:  0.3008: 100%|██████████| 2250/2250 [00:18<00:00, 121.77it/s]
V Loss:  0.0931, Avg Loss:  0.2913, Counter: 2, Best Loss:  0.3067: 100%|██████████| 938/938 [00:06<00:00, 138.38it/s]


Epoch: 6 | Loss:  0.2913
----------------------------------------


T Loss:  0.6670, Avg Loss:  0.2955: 100%|██████████| 2250/2250 [00:19<00:00, 114.91it/s]
V Loss:  0.2253, Avg Loss:  0.3020, Counter: 0, Best Loss:  0.2913: 100%|██████████| 938/938 [00:06<00:00, 137.89it/s]


Epoch: 7 | Loss:  0.3020
----------------------------------------


T Loss:  0.2431, Avg Loss:  0.2931: 100%|██████████| 2250/2250 [00:18<00:00, 123.17it/s]
V Loss:  0.1303, Avg Loss:  0.3029, Counter: 1, Best Loss:  0.2913: 100%|██████████| 938/938 [00:06<00:00, 136.80it/s]


Epoch: 8 | Loss:  0.3029
----------------------------------------


T Loss:  0.1699, Avg Loss:  0.2888: 100%|██████████| 2250/2250 [00:19<00:00, 117.36it/s]
V Loss:  0.1065, Avg Loss:  0.2875, Counter: 2, Best Loss:  0.2913: 100%|██████████| 938/938 [00:06<00:00, 134.95it/s]


Epoch: 9 | Loss:  0.2875
----------------------------------------


T Loss:  0.0615, Avg Loss:  0.2876: 100%|██████████| 2250/2250 [00:19<00:00, 116.73it/s]
V Loss:  0.2089, Avg Loss:  0.2918, Counter: 0, Best Loss:  0.2875: 100%|██████████| 938/938 [00:06<00:00, 137.02it/s]


Epoch: 10 | Loss:  0.2918
----------------------------------------


T Loss:  0.3332, Avg Loss:  0.2853: 100%|██████████| 2250/2250 [00:18<00:00, 120.60it/s]
V Loss:  0.1832, Avg Loss:  0.2846, Counter: 1, Best Loss:  0.2875: 100%|██████████| 938/938 [00:06<00:00, 134.91it/s]


Epoch: 11 | Loss:  0.2846
----------------------------------------


T Loss:  0.2762, Avg Loss:  0.2811: 100%|██████████| 2250/2250 [00:19<00:00, 113.65it/s]
V Loss:  0.0933, Avg Loss:  0.2963, Counter: 0, Best Loss:  0.2846: 100%|██████████| 938/938 [00:06<00:00, 147.06it/s]


Epoch: 12 | Loss:  0.2963
----------------------------------------


T Loss:  0.0774, Avg Loss:  0.2802: 100%|██████████| 2250/2250 [00:19<00:00, 118.21it/s]
V Loss:  0.4503, Avg Loss:  0.2975, Counter: 1, Best Loss:  0.2846: 100%|██████████| 938/938 [00:06<00:00, 141.42it/s]


Epoch: 13 | Loss:  0.2975
----------------------------------------


T Loss:  0.0741, Avg Loss:  0.2776: 100%|██████████| 2250/2250 [00:19<00:00, 118.40it/s]
V Loss:  0.0856, Avg Loss:  0.2862, Counter: 2, Best Loss:  0.2846: 100%|██████████| 938/938 [00:06<00:00, 136.47it/s]


Epoch: 14 | Loss:  0.2862
----------------------------------------


T Loss:  0.1716, Avg Loss:  0.2759: 100%|██████████| 2250/2250 [00:19<00:00, 117.91it/s]
V Loss:  0.2279, Avg Loss:  0.2960, Counter: 3, Best Loss:  0.2846: 100%|██████████| 938/938 [00:06<00:00, 140.35it/s]


Epoch: 15 | Loss:  0.2960


100%|██████████| 51/51 [00:00<00:00, 109.13it/s]

Test Loss:  0.2945
Classification Report:               precision    recall  f1-score   support

           0     0.9171    0.9063    0.9117      1708
           1     0.9649    0.9660    0.9655      1765
           2     0.8515    0.8277    0.8394      1503
           3     0.8415    0.8751    0.8580      1505

    accuracy                         0.8971      6481
   macro avg     0.8937    0.8938    0.8936      6481
weighted avg     0.8973    0.8971    0.8971      6481

Confusion Matrix: [[1548   28   81   51]
 [  25 1705   20   15]
 [  57   20 1244  182]
 [  58   14  116 1317]]





In [37]:
wandb.finish()



0,1
downstream_test_loss,▁█▂▅▁
downstream_train_loss,▅▃▃▂▂▁▁▁▁█▄▃▃▃▂▂▂▂▂▁▅▃▂▂▂▂▁▅▃▃▂▂▃▃▂▂▂▂▁▁
downstream_validation_loss,▅▃▂▂▂▁▁▂▂█▅▅▄▄▃▄▃▃▃▄▅▃▂▂▁▂▂▄▄▂▂▃▄▃▅▂▁▁▂▂
upstream_train_loss,█▇▆▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
upstream_train_lossb,█▆▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
upstream_train_lossf,█▇▆▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
upstream_validation_loss,█▆▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
upstream_validation_lossb,█▆▅▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
upstream_validation_lossf,█▆▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
downstream_test_loss,0.29448
downstream_train_loss,0.27594
downstream_validation_loss,0.29602
upstream_train_loss,4.68555
upstream_train_lossb,4.69145
upstream_train_lossf,4.67964
upstream_validation_loss,5.20333
upstream_validation_lossb,5.21451
upstream_validation_lossf,5.19214
