In [2]:
cfg = {
    'dev_train_len': 25*10**3,
    'dev_validation_len': 5*10**3,
    'learning_rate': 0.001,
    'epochs': 100,
    'char_embedding_dim': 32,
    'batch_size': 32,
    'dropout': 0.1,
    'optimizer': 'Adam',
    'num_layers': 2,
    'word_emb_dim': 200,
    'max_word_len': 20,
    'hidden_dim': 100,
    'char_out_channels': 64,
}

Hyperparameters

In [3]:
DEV_TRAIN_LEN = cfg['dev_train_len']
DEV_VALIDATION_LEN = cfg['dev_validation_len']
LEARNING_RATE = cfg['learning_rate']
EPOCHS = cfg['epochs']
CHAR_EMBEDDING_DIM = cfg['char_embedding_dim']
BATCH_SIZE = cfg['batch_size']
DROPOUT = cfg['dropout']
OPTIMIZER = cfg['optimizer']
NUM_LAYERS = cfg['num_layers']
HIDDEN_DIM = cfg['hidden_dim']
WORD_EMB_DIM = cfg['word_emb_dim']
MAX_WORD_LEN = cfg['max_word_len']
CHAR_OUT_CHANNELS = cfg['char_out_channels']

DIR = '/scratch/shu7bh/RES/4'

Create Dir

In [4]:
import os
if not os.path.exists(DIR):
    os.makedirs(DIR)

Set Device

In [5]:
import torch
import os

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    torch.backends.cudnn.benchmark = True
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

cuda


Prepare Data

In [6]:
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
import pandas as pd
import unicodedata
import re

def normalize_unicode(text: str) -> str:
    return unicodedata.normalize('NFD', text)

unique_chars = set()

def clean_data(text: str) -> str:
    text = normalize_unicode(text.lower().strip())
    text = re.sub(r"([.!?])", r" \1", text)
    text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)
    for char in text:
        unique_chars.add(char)
    return text

lemmatizer = WordNetLemmatizer()
freq_words = dict()
def tokenize_data(text: str, create_unique_words: bool) -> list:
    global freq_words
    tokens = [lemmatizer.lemmatize(token) for token in word_tokenize(text)]

    if create_unique_words:
        for token in tokens:
            if token not in freq_words:
                freq_words[token] = 1
            else:
                freq_words[token] += 1
    return tokens

def replace_words(tokens: list, filter_rare_words: bool) -> list:
    new_tokens = []
    for i in range(len(tokens)):
        if tokens[i] not in freq_words:
            new_tokens.append('<unk>')
        else:
            if filter_rare_words:
                if freq_words[tokens[i]] < 4:
                    new_tokens.append('<unk>')
                else:
                    new_tokens.append(tokens[i])
    return new_tokens

def read_data(path: str, create_unique_words, filter_rare_words) -> pd.DataFrame:
    df = pd.read_csv(path)
    df = df.sample(frac=1, random_state=0).reset_index(drop=True)
    df['Description'] = df['Description'].apply(clean_data)
    df['Description'] = df['Description'].apply(tokenize_data, create_unique_words=create_unique_words)
    df['Class Index'] = df['Class Index'].apply(lambda x: x-1)
    pred_df = df.copy(deep=True)
    pred_df['Description'] = pred_df['Description'].apply(replace_words, filter_rare_words=filter_rare_words)
    return df, pred_df

In [7]:
freq_words = dict()
actual_df, pred_df = read_data(
    'data/train.csv', 
    create_unique_words=True, 
    filter_rare_words=True
)

unique_words = set()
for tokens in pred_df['Description']:
    unique_words.update(tokens)
print(len(unique_words))

23765


In [8]:
NUM_CLASSES = len(set(actual_df['Class Index']))
NUM_CLASSES

4

Char -> Idx and Word -> Idx

In [9]:
# Create a dictionary of all characters
char_to_idx = {char: idx + 1 for idx, char in enumerate(unique_chars)}

# Add special tokens
char_to_idx['<pad>'] = 0
char_to_idx['<sos>'] = len(char_to_idx)
char_to_idx['<eos>'] = len(char_to_idx)

# Create a dictionary of all characters
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

# print the character to index mapping
print(char_to_idx)
print(idx_to_char)

# Create a dictionary of all words
word_to_idx = {word: idx + 1 for idx, word in enumerate(unique_words)}

# Add special tokens
word_to_idx['<pad>'] = 0
word_to_idx['<sos>'] = len(word_to_idx)
word_to_idx['<eos>'] = len(word_to_idx)

# Create a dictionary of all words
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

# print the length of the word to index mapping
print(len(word_to_idx))

{'b': 1, 'e': 2, '!': 3, 'q': 4, 'c': 5, '.': 6, '?': 7, 'w': 8, ' ': 9, 'm': 10, 'y': 11, 'i': 12, 'a': 13, 'l': 14, 'p': 15, 'j': 16, 'v': 17, 'r': 18, 'g': 19, 'z': 20, 'u': 21, 'o': 22, 'k': 23, 'n': 24, 't': 25, 's': 26, 'x': 27, 'h': 28, 'f': 29, 'd': 30, '<pad>': 0, '<sos>': 31, '<eos>': 32}
{1: 'b', 2: 'e', 3: '!', 4: 'q', 5: 'c', 6: '.', 7: '?', 8: 'w', 9: ' ', 10: 'm', 11: 'y', 12: 'i', 13: 'a', 14: 'l', 15: 'p', 16: 'j', 17: 'v', 18: 'r', 19: 'g', 20: 'z', 21: 'u', 22: 'o', 23: 'k', 24: 'n', 25: 't', 26: 's', 27: 'x', 28: 'h', 29: 'f', 30: 'd', 0: '<pad>', 31: '<sos>', 32: '<eos>'}
23768


In [10]:
dev_train_raw_a = actual_df[:DEV_TRAIN_LEN]
dev_train_raw_p = pred_df[:DEV_TRAIN_LEN]

dev_validation_raw_a = actual_df[DEV_TRAIN_LEN:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]
dev_validation_raw_p = pred_df[DEV_TRAIN_LEN:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]

Dataset

In [11]:
from torch.utils.data import Dataset

class Sentences(Dataset):
    def __init__(
            self, 
            adf: pd.DataFrame, 
            pdf: pd.DataFrame, 
            char_to_idx: dict, 
            word_to_idx: dict
        ) -> None:

        self.X = []
        self.Y_ = []

        for sentence in adf['Description']:
            sent = []
            for w in sentence:
                sent += [[char_to_idx[w[i]] for i in range(min(MAX_WORD_LEN, len(w)))] + [char_to_idx['<pad>']] * (MAX_WORD_LEN - len(w))]

            sent += [[char_to_idx['<eos>']] * MAX_WORD_LEN]
            sent = torch.cat([torch.tensor(word) for word in sent])
            self.X += [sent]

        for sentence in pdf['Description']:
            self.Y_ += [torch.tensor([word_to_idx['<sos>']] + [word_to_idx[w] for w in sentence] + [word_to_idx['<eos>']])]

        self.Y = torch.tensor(adf['Class Index'].tolist())

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int) -> tuple:
        return self.X[idx], self.Y[idx], torch.tensor(len(self.X[idx])), self.Y_[idx]

Create Dataset

In [12]:
dev_train_dataset = Sentences(dev_train_raw_a, dev_train_raw_p, char_to_idx, word_to_idx)
dev_validation_dataset = Sentences(dev_validation_raw_a, dev_validation_raw_p, char_to_idx, word_to_idx)

Collate

In [13]:
def collate_fn(batch: list) -> tuple:
    x, y, l, y_ = zip(*batch)

    x = torch.nn.utils.rnn.pad_sequence(x, padding_value=char_to_idx['<pad>'], batch_first=True)
    y_ = torch.nn.utils.rnn.pad_sequence(y_, padding_value=word_to_idx['<pad>'], batch_first=True)
    y_ = torch.cat([y_, torch.zeros(y_.shape[0], 1, dtype=torch.long)], dim=1)
    l = [i/20 for i in l]
    return x, torch.stack(y), torch.stack(l), y_[..., 2:], y_[..., :-2]

Create DataLoader

In [14]:
from torch.utils.data import DataLoader

dev_train_loader = DataLoader(dev_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
dev_validation_loader = DataLoader(dev_validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

CharCNN

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CharCNN(nn.Module):
    def __init__(
            self, 
            char_vocab: int,
            char_embed_dim: int,
            char_out_channels: list,
            char_kernel_sizes: list,
            dropout: float,
            word_embed_dim: int
        ) -> None:

        super(CharCNN, self).__init__()

        self.char_embed = nn.Embedding(char_vocab, char_embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.conv_max_pools = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(char_embed_dim, char_out_channels[i], char_kernel_sizes[i]),
                nn.ReLU(),
                nn.AdaptiveAvgPool1d(1),
                nn.Flatten()
            )
            for i in range(len(char_out_channels))
        ])

        self.fc = nn.Linear(sum(char_out_channels), word_embed_dim) # the fully connected layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.char_embed(x)
        x = x.transpose(1, 2)
        x = [conv_max_pool(x) for conv_max_pool in self.conv_max_pools]
        x = torch.cat(x, dim=1)
        x = self.dropout(x)
        x = self.fc(x)
        return x

ELMo

In [16]:
class ELMo(nn.Module):
    def __init__(
            self, 
            char_vocab: int, 
            char_embed_dim: int, 
            char_out_channels: list, 
            char_kernel_sizes: list, 
            dropout: float, 
            num_layers: int, 
            hidden_dim: int, 
            word_embed_dim: int,
            filename: str = None
        ) -> None:

        super(ELMo, self).__init__()

        self.char_cnn = CharCNN(
            char_vocab=char_vocab, 
            char_embed_dim=char_embed_dim, 
            char_out_channels=char_out_channels, 
            char_kernel_sizes=char_kernel_sizes, 
            dropout=dropout,
            word_embed_dim=word_embed_dim
        )

        self.lstmf = nn.LSTM(
            input_size=word_embed_dim, 
            hidden_size=hidden_dim, 
            num_layers=num_layers, 
            dropout=dropout,
            batch_first=True
        )

        self.lstmb = nn.LSTM(
            input_size=word_embed_dim, 
            hidden_size=hidden_dim, 
            num_layers=num_layers, 
            dropout=dropout,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        if filename:
            self.load_state_dict(torch.load(filename))

    def forward(self, x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        bz = x.shape[0]
        x = x.view(-1, MAX_WORD_LEN)
        x = self.char_cnn(x)
        x = x.view(bz, -1, x.shape[1])
        xf = x
        xb = x.flip([1])
        input = x.detach().clone()
        xf, (hsf, csf) = self.lstmf(xf)
        xb, (hsb, csb) = self.lstmb(xb)
        xb = xb.flip([1])
        xf = self.dropout(xf)
        xb = self.dropout(xb)
        return xf, xb, input, (hsf, csf), (hsb, csb)

Early Stopping

In [17]:
import numpy as np

class EarlyStopping:
    def __init__(self, patience:int = 3, delta:float = 0.001):
        self.patience = patience
        self.counter = 0
        self.best_loss:float = np.inf
        self.best_model_pth = 0
        self.delta = delta

    def __call__(self, loss, epoch: int):
        should_stop = False

        if loss >= self.best_loss - self.delta:
            self.counter += 1
            if self.counter > self.patience:
                should_stop = True
        else:
            self.best_loss = loss
            self.counter = 0
            self.best_model_pth = epoch
        return should_stop

LM

In [19]:
from tqdm import tqdm

class LM(nn.Module):
    def __init__(self, 
            char_vocab: int,
            hidden_dim: int, 
            vocab_size: int, 
            filename: str = None
        ) -> None:

        super(LM, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.elmo = ELMo(
            char_vocab=char_vocab, 
            char_embed_dim=CHAR_EMBEDDING_DIM, 
            char_out_channels=[CHAR_OUT_CHANNELS] * 5,
            char_kernel_sizes=[2, 3, 4, 5, 6], 
            dropout=DROPOUT, 
            num_layers=NUM_LAYERS, 
            hidden_dim=HIDDEN_DIM,
            word_embed_dim=WORD_EMB_DIM
        )
        self.linear_forward = nn.Linear(hidden_dim, vocab_size)
        self.linear_backward = nn.Linear(hidden_dim, vocab_size)

        if filename:
            self.load_state_dict(torch.load(filename))

    def forward(self, x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        xf, xb, _, _, _ = self.elmo(x, l)
        yf = self.linear_forward(xf)
        yb = self.linear_backward(xb)
        return yf, yb

    def fit(self, train_loader: DataLoader, validation_loader: DataLoader, epochs: int, learning_rate: float, filename: str) -> None:
        self.es = EarlyStopping()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])

        for epoch in range(epochs):
            print('----------------------------------------')
            self._train(train_loader)
            loss = self._evaluate(validation_loader)
            print(f'Epoch: {epoch + 1} | Loss: {loss:7.4f}')
            if self.es(loss, epoch):
                break
            if self.es.counter == 0:
                torch.save(self.state_dict(), os.path.join(DIR, f'{filename}_lm.pth'))
                torch.save(self.elmo.state_dict(), os.path.join(DIR, f'{filename}_elmo.pth'))
                torch.save(self.elmo.char_cnn.state_dict(), os.path.join(DIR, f'{filename}_char_cnn.pth'))

    def _call(self, x: torch.Tensor, y: torch.Tensor, l: torch.Tensor, yf: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
        x, y, yf, yb = x.to(DEVICE), y.to(DEVICE), yf.to(DEVICE), yb.to(DEVICE)
        # print(x.shape, y.shape, l.shape, yf.shape, yb.shape)
        yf_hat, yb_hat = self(x, l)

        yf_hat = yf_hat.view(-1, self.vocab_size)
        yb_hat = yb_hat.view(-1, self.vocab_size)

        yf = yf.view(-1)
        yb = yb.view(-1)

        # print(yf_hat.shape, yb_hat.shape, yf.shape, yb.shape)

        loss1 = self.criterion(yf_hat, yf)
        loss2 = self.criterion(yb_hat, yb)

        loss = (loss1 + loss2) / 2

        # print(loss1.item(), loss2.item(), loss.item())
        return loss, loss1, loss2

    def _train(self, train_loader: DataLoader) -> None:
        self.train()
        epoch_loss = []
        epoch_loss1 = []
        epoch_loss2 = []

        pbar = tqdm(train_loader)
        for x, y, l, yf, yb in pbar:

            loss, loss1, loss2 = self._call(x, y, l, yf, yb)
            epoch_loss.append(loss.item())
            epoch_loss1.append(loss1.item())
            epoch_loss2.append(loss2.item())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            pbar.set_description(f'T Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Avg Loss1: {np.mean(epoch_loss1):7.4f}, Avg Loss2: {np.mean(epoch_loss2):7.4f}')

        # run.log({'upstream_train_loss': np.mean(epoch_loss)})

    def _evaluate(self, validation_loader: DataLoader) -> float:
        self.eval()
        epoch_loss = []
        epoch_loss1 = []
        epoch_loss2 = []
        pbar = tqdm(validation_loader)
        with torch.no_grad():
            for x, y, l, yf, yb in pbar:
                loss, loss1, loss2 = self._call(x, y, l, yf, yb)
                epoch_loss.append(loss.item())
                epoch_loss1.append(loss1.item())
                epoch_loss2.append(loss2.item())
                pbar.set_description(f'V Loss: {epoch_loss[-1]:7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Avg Loss1: {np.mean(epoch_loss1):7.4f}, Avg Loss2: {np.mean(epoch_loss2):7.4f}, Counter: {self.es.counter}, Best Loss: {self.es.best_loss:7.4f}')

        # run.log({'upstream_validation_loss': np.mean(epoch_loss)})
        return np.mean(epoch_loss)

Initialize Model

In [20]:
lm = LM(char_vocab=len(char_to_idx), hidden_dim=HIDDEN_DIM, vocab_size=len(word_to_idx)).to(DEVICE)
print(lm)

LM(
  (elmo): ELMo(
    (char_cnn): CharCNN(
      (char_embed): Embedding(33, 32)
      (dropout): Dropout(p=0.1, inplace=False)
      (conv_max_pools): ModuleList(
        (0): Sequential(
          (0): Conv1d(32, 64, kernel_size=(2,), stride=(1,))
          (1): ReLU()
          (2): AdaptiveAvgPool1d(output_size=1)
          (3): Flatten(start_dim=1, end_dim=-1)
        )
        (1): Sequential(
          (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
          (1): ReLU()
          (2): AdaptiveAvgPool1d(output_size=1)
          (3): Flatten(start_dim=1, end_dim=-1)
        )
        (2): Sequential(
          (0): Conv1d(32, 64, kernel_size=(4,), stride=(1,))
          (1): ReLU()
          (2): AdaptiveAvgPool1d(output_size=1)
          (3): Flatten(start_dim=1, end_dim=-1)
        )
        (3): Sequential(
          (0): Conv1d(32, 64, kernel_size=(5,), stride=(1,))
          (1): ReLU()
          (2): AdaptiveAvgPool1d(output_size=1)
          (3): Flatten(start_dim=1, 

In [21]:
from torchinfo import summary

summary(lm, device=DEVICE)

Layer (type:depth-idx)                             Param #
LM                                                 --
├─ELMo: 1-1                                        --
│    └─CharCNN: 2-1                                --
│    │    └─Embedding: 3-1                         1,056
│    │    └─Dropout: 3-2                           --
│    │    └─ModuleList: 3-3                        41,280
│    │    └─Linear: 3-4                            64,200
│    └─LSTM: 2-2                                   201,600
│    └─LSTM: 2-3                                   201,600
│    └─Dropout: 2-4                                --
├─Linear: 1-2                                      2,400,568
├─Linear: 1-3                                      2,400,568
Total params: 5,310,872
Trainable params: 5,310,872
Non-trainable params: 0

In [22]:
lm.fit(dev_train_loader, dev_validation_loader, epochs=EPOCHS, learning_rate=LEARNING_RATE, filename='best')

----------------------------------------


T Loss:  7.0533, Avg Loss:  7.0733, Avg Loss1:  7.1113, Avg Loss2:  7.0354: 100%|██████████| 782/782 [01:15<00:00, 10.31it/s]
V Loss:  6.8613, Avg Loss:  6.9205, Avg Loss1:  6.9612, Avg Loss2:  6.8797, Counter: 0, Best Loss:     inf: 100%|██████████| 157/157 [00:02<00:00, 60.19it/s]


Epoch: 1 | Loss:  6.9205
----------------------------------------


T Loss:  7.0013, Avg Loss:  6.8521, Avg Loss1:  6.9383, Avg Loss2:  6.7659: 100%|██████████| 782/782 [00:25<00:00, 30.44it/s]
V Loss:  6.9507, Avg Loss:  6.8327, Avg Loss1:  6.9441, Avg Loss2:  6.7212, Counter: 0, Best Loss:  6.9205: 100%|██████████| 157/157 [00:02<00:00, 68.81it/s]


Epoch: 2 | Loss:  6.8327
----------------------------------------


T Loss:  6.6718, Avg Loss:  6.8010, Avg Loss1:  6.9115, Avg Loss2:  6.6904: 100%|██████████| 782/782 [00:25<00:00, 30.62it/s]
V Loss:  6.7583, Avg Loss:  6.8078, Avg Loss1:  6.9297, Avg Loss2:  6.6859, Counter: 0, Best Loss:  6.8327: 100%|██████████| 157/157 [00:02<00:00, 66.74it/s]


Epoch: 3 | Loss:  6.8078
----------------------------------------


T Loss:  6.8191, Avg Loss:  6.7790, Avg Loss1:  6.8954, Avg Loss2:  6.6626: 100%|██████████| 782/782 [00:25<00:00, 30.19it/s]
V Loss:  6.9024, Avg Loss:  6.7941, Avg Loss1:  6.9150, Avg Loss2:  6.6733, Counter: 0, Best Loss:  6.8078: 100%|██████████| 157/157 [00:02<00:00, 65.91it/s]


Epoch: 4 | Loss:  6.7941
----------------------------------------


T Loss:  6.2997, Avg Loss:  6.5955, Avg Loss1:  6.6506, Avg Loss2:  6.5405: 100%|██████████| 782/782 [00:25<00:00, 30.12it/s]
V Loss:  6.2063, Avg Loss:  6.4784, Avg Loss1:  6.5227, Avg Loss2:  6.4340, Counter: 0, Best Loss:  6.7941: 100%|██████████| 157/157 [00:02<00:00, 67.79it/s]


Epoch: 5 | Loss:  6.4784
----------------------------------------


T Loss:  6.4998, Avg Loss:  6.3253, Avg Loss1:  6.3498, Avg Loss2:  6.3009: 100%|██████████| 782/782 [00:25<00:00, 30.43it/s]
V Loss:  6.3635, Avg Loss:  6.2568, Avg Loss1:  6.2710, Avg Loss2:  6.2425, Counter: 0, Best Loss:  6.4784: 100%|██████████| 157/157 [00:02<00:00, 66.21it/s]


Epoch: 6 | Loss:  6.2568
----------------------------------------


T Loss:  6.3647, Avg Loss:  6.1317, Avg Loss1:  6.1296, Avg Loss2:  6.1338: 100%|██████████| 782/782 [00:25<00:00, 30.90it/s]
V Loss:  6.1479, Avg Loss:  6.1018, Avg Loss1:  6.1047, Avg Loss2:  6.0990, Counter: 0, Best Loss:  6.2568: 100%|██████████| 157/157 [00:02<00:00, 66.24it/s]


Epoch: 7 | Loss:  6.1018
----------------------------------------


T Loss:  5.6913, Avg Loss:  5.9859, Avg Loss1:  5.9737, Avg Loss2:  5.9981: 100%|██████████| 782/782 [00:25<00:00, 30.58it/s]
V Loss:  6.2949, Avg Loss:  5.9927, Avg Loss1:  5.9894, Avg Loss2:  5.9959, Counter: 0, Best Loss:  6.1018: 100%|██████████| 157/157 [00:02<00:00, 66.47it/s]


Epoch: 8 | Loss:  5.9927
----------------------------------------


T Loss:  6.0377, Avg Loss:  5.8671, Avg Loss1:  5.8487, Avg Loss2:  5.8856: 100%|██████████| 782/782 [00:25<00:00, 30.88it/s]
V Loss:  6.0362, Avg Loss:  5.8980, Avg Loss1:  5.8934, Avg Loss2:  5.9025, Counter: 0, Best Loss:  5.9927: 100%|██████████| 157/157 [00:02<00:00, 66.84it/s]


Epoch: 9 | Loss:  5.8980
----------------------------------------


T Loss:  5.8567, Avg Loss:  5.7675, Avg Loss1:  5.7434, Avg Loss2:  5.7917: 100%|██████████| 782/782 [00:25<00:00, 30.76it/s]
V Loss:  6.1588, Avg Loss:  5.8260, Avg Loss1:  5.8190, Avg Loss2:  5.8329, Counter: 0, Best Loss:  5.8980: 100%|██████████| 157/157 [00:02<00:00, 66.88it/s]


Epoch: 10 | Loss:  5.8260
----------------------------------------


T Loss:  5.5194, Avg Loss:  5.6797, Avg Loss1:  5.6497, Avg Loss2:  5.7096: 100%|██████████| 782/782 [00:25<00:00, 30.70it/s]
V Loss:  5.9024, Avg Loss:  5.7685, Avg Loss1:  5.7582, Avg Loss2:  5.7788, Counter: 0, Best Loss:  5.8260: 100%|██████████| 157/157 [00:02<00:00, 67.62it/s]


Epoch: 11 | Loss:  5.7685
----------------------------------------


T Loss:  5.7706, Avg Loss:  5.6027, Avg Loss1:  5.5705, Avg Loss2:  5.6350: 100%|██████████| 782/782 [00:25<00:00, 30.73it/s]
V Loss:  5.5917, Avg Loss:  5.7080, Avg Loss1:  5.7015, Avg Loss2:  5.7144, Counter: 0, Best Loss:  5.7685: 100%|██████████| 157/157 [00:02<00:00, 69.45it/s]


Epoch: 12 | Loss:  5.7080
----------------------------------------


T Loss:  4.6026, Avg Loss:  5.5298, Avg Loss1:  5.4983, Avg Loss2:  5.5612: 100%|██████████| 782/782 [00:25<00:00, 30.63it/s]
V Loss:  5.9076, Avg Loss:  5.6662, Avg Loss1:  5.6613, Avg Loss2:  5.6712, Counter: 0, Best Loss:  5.7080: 100%|██████████| 157/157 [00:02<00:00, 70.01it/s]


Epoch: 13 | Loss:  5.6662
----------------------------------------


T Loss:  5.4133, Avg Loss:  5.4656, Avg Loss1:  5.4360, Avg Loss2:  5.4952: 100%|██████████| 782/782 [00:25<00:00, 30.90it/s]
V Loss:  5.7122, Avg Loss:  5.6254, Avg Loss1:  5.6246, Avg Loss2:  5.6261, Counter: 0, Best Loss:  5.6662: 100%|██████████| 157/157 [00:02<00:00, 67.69it/s]


Epoch: 14 | Loss:  5.6254
----------------------------------------


T Loss:  5.3186, Avg Loss:  5.4053, Avg Loss1:  5.3764, Avg Loss2:  5.4342: 100%|██████████| 782/782 [00:25<00:00, 30.55it/s]
V Loss:  5.5492, Avg Loss:  5.5975, Avg Loss1:  5.5948, Avg Loss2:  5.6002, Counter: 0, Best Loss:  5.6254: 100%|██████████| 157/157 [00:02<00:00, 67.29it/s]


Epoch: 15 | Loss:  5.5975
----------------------------------------


T Loss:  5.2540, Avg Loss:  5.3521, Avg Loss1:  5.3254, Avg Loss2:  5.3788: 100%|██████████| 782/782 [00:25<00:00, 30.44it/s]
V Loss:  6.1230, Avg Loss:  5.5672, Avg Loss1:  5.5706, Avg Loss2:  5.5637, Counter: 0, Best Loss:  5.5975: 100%|██████████| 157/157 [00:02<00:00, 67.06it/s]


Epoch: 16 | Loss:  5.5672
----------------------------------------


T Loss:  5.7274, Avg Loss:  5.3041, Avg Loss1:  5.2789, Avg Loss2:  5.3293: 100%|██████████| 782/782 [00:25<00:00, 30.56it/s]
V Loss:  5.4733, Avg Loss:  5.5407, Avg Loss1:  5.5481, Avg Loss2:  5.5332, Counter: 0, Best Loss:  5.5672: 100%|██████████| 157/157 [00:02<00:00, 69.93it/s]


Epoch: 17 | Loss:  5.5407
----------------------------------------


T Loss:  5.1112, Avg Loss:  5.2598, Avg Loss1:  5.2351, Avg Loss2:  5.2846: 100%|██████████| 782/782 [00:25<00:00, 30.41it/s]
V Loss:  5.7267, Avg Loss:  5.5178, Avg Loss1:  5.5275, Avg Loss2:  5.5081, Counter: 0, Best Loss:  5.5407: 100%|██████████| 157/157 [00:02<00:00, 69.35it/s]


Epoch: 18 | Loss:  5.5178
----------------------------------------


T Loss:  5.2569, Avg Loss:  5.2179, Avg Loss1:  5.1946, Avg Loss2:  5.2411: 100%|██████████| 782/782 [00:24<00:00, 31.47it/s]
V Loss:  5.2099, Avg Loss:  5.4997, Avg Loss1:  5.5114, Avg Loss2:  5.4880, Counter: 0, Best Loss:  5.5178: 100%|██████████| 157/157 [00:02<00:00, 69.06it/s]


Epoch: 19 | Loss:  5.4997
----------------------------------------


T Loss:  5.1919, Avg Loss:  5.1809, Avg Loss1:  5.1589, Avg Loss2:  5.2030: 100%|██████████| 782/782 [00:25<00:00, 31.12it/s]
V Loss:  5.4726, Avg Loss:  5.4887, Avg Loss1:  5.5015, Avg Loss2:  5.4759, Counter: 0, Best Loss:  5.4997: 100%|██████████| 157/157 [00:02<00:00, 73.10it/s]


Epoch: 20 | Loss:  5.4887
----------------------------------------


T Loss:  4.9272, Avg Loss:  5.1445, Avg Loss1:  5.1230, Avg Loss2:  5.1659: 100%|██████████| 782/782 [00:25<00:00, 31.25it/s]
V Loss:  6.0672, Avg Loss:  5.4880, Avg Loss1:  5.5017, Avg Loss2:  5.4742, Counter: 0, Best Loss:  5.4887: 100%|██████████| 157/157 [00:02<00:00, 70.50it/s]


Epoch: 21 | Loss:  5.4880
----------------------------------------


T Loss:  5.3632, Avg Loss:  5.1144, Avg Loss1:  5.0937, Avg Loss2:  5.1350: 100%|██████████| 782/782 [00:25<00:00, 30.74it/s]
V Loss:  5.9468, Avg Loss:  5.4712, Avg Loss1:  5.4865, Avg Loss2:  5.4558, Counter: 1, Best Loss:  5.4887: 100%|██████████| 157/157 [00:02<00:00, 68.78it/s]


Epoch: 22 | Loss:  5.4712
----------------------------------------


T Loss:  5.2343, Avg Loss:  5.0835, Avg Loss1:  5.0633, Avg Loss2:  5.1037: 100%|██████████| 782/782 [00:25<00:00, 30.75it/s]
V Loss:  4.1086, Avg Loss:  5.4506, Avg Loss1:  5.4678, Avg Loss2:  5.4334, Counter: 0, Best Loss:  5.4712: 100%|██████████| 157/157 [00:02<00:00, 67.98it/s]


Epoch: 23 | Loss:  5.4506
----------------------------------------


T Loss:  5.5144, Avg Loss:  5.0557, Avg Loss1:  5.0358, Avg Loss2:  5.0757: 100%|██████████| 782/782 [00:25<00:00, 30.48it/s]
V Loss:  5.8351, Avg Loss:  5.4512, Avg Loss1:  5.4670, Avg Loss2:  5.4353, Counter: 0, Best Loss:  5.4506: 100%|██████████| 157/157 [00:02<00:00, 66.42it/s]


Epoch: 24 | Loss:  5.4512
----------------------------------------


T Loss:  5.2531, Avg Loss:  5.0294, Avg Loss1:  5.0112, Avg Loss2:  5.0476: 100%|██████████| 782/782 [00:25<00:00, 30.54it/s]
V Loss:  5.3601, Avg Loss:  5.4414, Avg Loss1:  5.4581, Avg Loss2:  5.4247, Counter: 1, Best Loss:  5.4506: 100%|██████████| 157/157 [00:02<00:00, 68.66it/s]


Epoch: 25 | Loss:  5.4414
----------------------------------------


T Loss:  4.9070, Avg Loss:  5.0043, Avg Loss1:  4.9855, Avg Loss2:  5.0231: 100%|██████████| 782/782 [00:25<00:00, 30.78it/s]
V Loss:  5.6066, Avg Loss:  5.4422, Avg Loss1:  5.4571, Avg Loss2:  5.4273, Counter: 0, Best Loss:  5.4414: 100%|██████████| 157/157 [00:02<00:00, 69.08it/s]


Epoch: 26 | Loss:  5.4422
----------------------------------------


T Loss:  5.2434, Avg Loss:  4.9814, Avg Loss1:  4.9645, Avg Loss2:  4.9983: 100%|██████████| 782/782 [00:25<00:00, 30.39it/s]
V Loss:  4.7730, Avg Loss:  5.4338, Avg Loss1:  5.4504, Avg Loss2:  5.4172, Counter: 1, Best Loss:  5.4414: 100%|██████████| 157/157 [00:02<00:00, 68.60it/s]


Epoch: 27 | Loss:  5.4338
----------------------------------------


T Loss:  4.5018, Avg Loss:  4.9601, Avg Loss1:  4.9423, Avg Loss2:  4.9780: 100%|██████████| 782/782 [00:25<00:00, 30.94it/s]
V Loss:  6.3275, Avg Loss:  5.4379, Avg Loss1:  5.4560, Avg Loss2:  5.4199, Counter: 0, Best Loss:  5.4338: 100%|██████████| 157/157 [00:02<00:00, 69.52it/s]


Epoch: 28 | Loss:  5.4379
----------------------------------------


T Loss:  5.1042, Avg Loss:  4.9399, Avg Loss1:  4.9228, Avg Loss2:  4.9570: 100%|██████████| 782/782 [00:25<00:00, 30.69it/s]
V Loss:  4.8384, Avg Loss:  5.4302, Avg Loss1:  5.4461, Avg Loss2:  5.4143, Counter: 1, Best Loss:  5.4338: 100%|██████████| 157/157 [00:02<00:00, 66.62it/s]


Epoch: 29 | Loss:  5.4302
----------------------------------------


T Loss:  5.3182, Avg Loss:  4.9207, Avg Loss1:  4.9037, Avg Loss2:  4.9377: 100%|██████████| 782/782 [00:25<00:00, 30.77it/s]
V Loss:  5.7474, Avg Loss:  5.4283, Avg Loss1:  5.4457, Avg Loss2:  5.4109, Counter: 0, Best Loss:  5.4302: 100%|██████████| 157/157 [00:02<00:00, 68.43it/s]


Epoch: 30 | Loss:  5.4283
----------------------------------------


T Loss:  5.0406, Avg Loss:  4.9010, Avg Loss1:  4.8839, Avg Loss2:  4.9180: 100%|██████████| 782/782 [00:25<00:00, 31.06it/s]
V Loss:  5.1248, Avg Loss:  5.4241, Avg Loss1:  5.4423, Avg Loss2:  5.4060, Counter: 0, Best Loss:  5.4283: 100%|██████████| 157/157 [00:02<00:00, 70.58it/s]


Epoch: 31 | Loss:  5.4241
----------------------------------------


T Loss:  5.0918, Avg Loss:  4.8858, Avg Loss1:  4.8694, Avg Loss2:  4.9022: 100%|██████████| 782/782 [00:25<00:00, 30.58it/s]
V Loss:  5.4332, Avg Loss:  5.4263, Avg Loss1:  5.4450, Avg Loss2:  5.4076, Counter: 0, Best Loss:  5.4241: 100%|██████████| 157/157 [00:02<00:00, 69.08it/s]


Epoch: 32 | Loss:  5.4263
----------------------------------------


T Loss:  5.0259, Avg Loss:  4.8684, Avg Loss1:  4.8511, Avg Loss2:  4.8856: 100%|██████████| 782/782 [00:25<00:00, 30.75it/s]
V Loss:  5.4509, Avg Loss:  5.4252, Avg Loss1:  5.4440, Avg Loss2:  5.4064, Counter: 1, Best Loss:  5.4241: 100%|██████████| 157/157 [00:02<00:00, 67.47it/s]


Epoch: 33 | Loss:  5.4252
----------------------------------------


T Loss:  4.6874, Avg Loss:  4.8523, Avg Loss1:  4.8363, Avg Loss2:  4.8682: 100%|██████████| 782/782 [00:25<00:00, 30.53it/s]
V Loss:  5.0181, Avg Loss:  5.4255, Avg Loss1:  5.4438, Avg Loss2:  5.4072, Counter: 2, Best Loss:  5.4241: 100%|██████████| 157/157 [00:02<00:00, 69.70it/s]


Epoch: 34 | Loss:  5.4255
----------------------------------------


T Loss:  5.0684, Avg Loss:  4.8379, Avg Loss1:  4.8228, Avg Loss2:  4.8530: 100%|██████████| 782/782 [00:25<00:00, 30.67it/s]
V Loss:  5.5330, Avg Loss:  5.4266, Avg Loss1:  5.4429, Avg Loss2:  5.4103, Counter: 3, Best Loss:  5.4241: 100%|██████████| 157/157 [00:02<00:00, 65.58it/s]

Epoch: 35 | Loss:  5.4266





In [23]:
downstream_train_raw_a = actual_df[DEV_TRAIN_LEN+DEV_VALIDATION_LEN:]
downstream_train_raw_p = pred_df[DEV_TRAIN_LEN+DEV_VALIDATION_LEN:]

downstream_validation_raw_a = actual_df[:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]
downstream_validation_raw_p = pred_df[:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]

In [24]:
print(len(downstream_train_raw_a))
print(len(downstream_validation_raw_a))

90000
30000


In [25]:
downstream_train_dataset = Sentences(downstream_train_raw_a, downstream_train_raw_p, char_to_idx, word_to_idx)
downstream_validation_dataset = Sentences(downstream_validation_raw_a, downstream_validation_raw_p, char_to_idx, word_to_idx)

downstream_train_loader = DataLoader(downstream_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
downstream_validation_loader = DataLoader(downstream_validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

## Downstream Task

In [41]:
from sklearn.metrics import classification_report, confusion_matrix

class NewsClassification(nn.Module):
    def __init__(self, 
            char_vocab: int,
            hidden_dim: int, 
            vocab_size: int, 
            num_classes: int,
            filename: str = None
        ) -> None:

        super(NewsClassification, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.elmo = ELMo(
            char_vocab=char_vocab, 
            char_embed_dim=CHAR_EMBEDDING_DIM, 
            char_out_channels=[CHAR_OUT_CHANNELS] * 5,
            char_kernel_sizes=[2, 3, 4, 5, 6], 
            dropout=DROPOUT, 
            num_layers=NUM_LAYERS, 
            hidden_dim=HIDDEN_DIM,
            word_embed_dim=WORD_EMB_DIM,
            filename=filename
        )

        for param in self.elmo.parameters():
            param.requires_grad = False

        self.delta = nn.Parameter(torch.randn(1, 3))
        self.linear = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        _, _, input, (hsf, csf), (hsb, csb) = self.elmo(x, l)
        hsf = hsf.permute(1, 0, 2)
        csf = csf.permute(1, 0, 2)
        hsb = hsb.permute(1, 0, 2)
        csb = csb.permute(1, 0, 2)

        hs = torch.cat([hsf, hsb], dim=2)
        cs = torch.cat([csf, csb], dim=2)

        val = (hs + cs) / 2

        input = torch.mean(input, dim=1).unsqueeze(1)

        val = torch.cat([input, val], dim=1)
        val = (self.delta / (torch.sum(self.delta))) @ val
        val = val.squeeze()

        x = self.linear(val)
        return x

    def fit(self, 
            train_loader: DataLoader, 
            validation_loader: DataLoader, 
            epochs: int, 
            learning_rate: float
        ) -> None:

        self.es = EarlyStopping()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            print('----------------------------------------')
            self._train(train_loader)
            loss = self._evaluate(validation_loader)
            print(f'Epoch: {epoch + 1} | Loss: {loss:7.4f}')
            if self.es(loss, epoch):
                break
            if self.es.counter == 0:
                torch.save(self.state_dict(), os.path.join(DIR, 'best.pth'))

    def _call(self, x: torch.Tensor, y: torch.Tensor, l: torch.Tensor, yf: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
        x, y, yf, yb = x.to(DEVICE), y.to(DEVICE), yf.to(DEVICE), yb.to(DEVICE)

        y_hat = self(x, l)
        y_hat = y_hat.view(-1, self.num_classes)
        y = y.view(-1)
        loss = self.criterion(y_hat, y)
        return loss

    def _train(self, train_loader: DataLoader) -> None:
        self.train()
        epoch_loss = []
        pbar = tqdm(train_loader)
        for x, y, l, yf, yb in pbar:
            loss = self._call(x, y, l, yf, yb)
            epoch_loss.append(loss.item())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            pbar.set_description(f'T Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}')

    def _evaluate(self, validation_loader: DataLoader) -> float:
        self.eval()
        epoch_loss = []
        pbar = tqdm(validation_loader)
        with torch.no_grad():
            for x, y, l, yf, yb in pbar:
                loss = self._call(x, y, l, yf, yb)
                epoch_loss.append(loss.item())
                pbar.set_description(f'V Loss: {epoch_loss[-1]:7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Counter: {self.es.counter}, Best Loss: {self.es.best_loss:7.4f}')
        return np.mean(epoch_loss)

    def _metrics(self, test_loader: DataLoader) -> None:
        self.eval()
        self.criterion = nn.CrossEntropyLoss()
        pbar = tqdm(test_loader)
        y_pred = []
        y_true = []
        epoch_loss = []

        with torch.no_grad():
            for x, y, l, yf, yb in pbar:
                x, y, yf, yb = x.to(DEVICE), y.to(DEVICE), yf.to(DEVICE), yb.to(DEVICE)
                y_hat = self(x, l)
                y_hat = y_hat.view(-1, self.num_classes)
                y = y.view(-1)
                loss = self.criterion(y_hat, y)

                epoch_loss.append(loss.item())
                y_hat = torch.argmax(y_hat, dim=1)
                y_pred += y_hat.tolist()
                y_true += y.tolist()

        print(f'Test Loss: {np.mean(epoch_loss):7.4f}')

        cr = classification_report(y_true, y_pred, digits=4)
        print('Classification Report:', cr)

        cm = confusion_matrix(y_true, y_pred)
        print('Confusion Matrix:', cm)

In [42]:
nc = NewsClassification(char_vocab=len(char_to_idx), 
                        hidden_dim=HIDDEN_DIM, 
                        vocab_size=len(word_to_idx), 
                        num_classes=NUM_CLASSES,
                        filename=os.path.join(DIR, 'best_elmo.pth')
                    ).to(DEVICE)

In [29]:
summary(nc, device=DEVICE)

Layer (type:depth-idx)                             Param #
NewsClassification                                 3
├─ELMo: 1-1                                        --
│    └─CharCNN: 2-1                                --
│    │    └─Embedding: 3-1                         (1,056)
│    │    └─Dropout: 3-2                           --
│    │    └─ModuleList: 3-3                        (41,280)
│    │    └─Linear: 3-4                            (64,200)
│    └─LSTM: 2-2                                   (201,600)
│    └─LSTM: 2-3                                   (201,600)
│    └─Dropout: 2-4                                --
├─Sequential: 1-2                                  --
│    └─Linear: 2-5                                 20,100
│    └─ReLU: 2-6                                   --
│    └─Linear: 2-7                                 404
Total params: 530,243
Trainable params: 20,507
Non-trainable params: 509,736

In [30]:
nc.fit(downstream_train_loader, downstream_validation_loader, epochs=EPOCHS, learning_rate=LEARNING_RATE)

----------------------------------------


T Loss:  0.7373, Avg Loss:  0.6860: 100%|██████████| 2813/2813 [00:27<00:00, 101.47it/s]
V Loss:  0.8952, Avg Loss:  0.6811, Counter: 0, Best Loss:     inf: 100%|██████████| 938/938 [00:06<00:00, 138.48it/s]


Epoch: 1 | Loss:  0.6811
----------------------------------------


T Loss:  0.4789, Avg Loss:  0.6124: 100%|██████████| 2813/2813 [00:25<00:00, 109.14it/s]
V Loss:  0.5435, Avg Loss:  0.6229, Counter: 0, Best Loss:  0.6811: 100%|██████████| 938/938 [00:06<00:00, 134.75it/s]


Epoch: 2 | Loss:  0.6229
----------------------------------------


T Loss:  0.3010, Avg Loss:  0.6001: 100%|██████████| 2813/2813 [00:25<00:00, 110.08it/s]
V Loss:  0.6057, Avg Loss:  0.5675, Counter: 0, Best Loss:  0.6229: 100%|██████████| 938/938 [00:06<00:00, 137.02it/s]


Epoch: 3 | Loss:  0.5675
----------------------------------------


T Loss:  0.4059, Avg Loss:  0.5863: 100%|██████████| 2813/2813 [00:25<00:00, 108.64it/s]
V Loss:  0.7303, Avg Loss:  0.5405, Counter: 0, Best Loss:  0.5675: 100%|██████████| 938/938 [00:06<00:00, 136.29it/s]


Epoch: 4 | Loss:  0.5405
----------------------------------------


T Loss:  0.3019, Avg Loss:  0.5801: 100%|██████████| 2813/2813 [00:25<00:00, 110.14it/s]
V Loss:  0.2502, Avg Loss:  0.5444, Counter: 0, Best Loss:  0.5405: 100%|██████████| 938/938 [00:06<00:00, 136.28it/s]


Epoch: 5 | Loss:  0.5444
----------------------------------------


T Loss:  0.5700, Avg Loss:  0.5726: 100%|██████████| 2813/2813 [00:25<00:00, 109.87it/s]
V Loss:  0.7550, Avg Loss:  0.5699, Counter: 1, Best Loss:  0.5405: 100%|██████████| 938/938 [00:06<00:00, 137.27it/s]


Epoch: 6 | Loss:  0.5699
----------------------------------------


T Loss:  0.5998, Avg Loss:  0.5726: 100%|██████████| 2813/2813 [00:25<00:00, 110.17it/s]
V Loss:  0.5259, Avg Loss:  0.5297, Counter: 2, Best Loss:  0.5405: 100%|██████████| 938/938 [00:06<00:00, 135.23it/s]


Epoch: 7 | Loss:  0.5297
----------------------------------------


T Loss:  0.7394, Avg Loss:  0.5678: 100%|██████████| 2813/2813 [00:25<00:00, 109.29it/s]
V Loss:  0.7083, Avg Loss:  0.5312, Counter: 0, Best Loss:  0.5297: 100%|██████████| 938/938 [00:07<00:00, 133.43it/s]


Epoch: 8 | Loss:  0.5312
----------------------------------------


T Loss:  0.5228, Avg Loss:  0.5638: 100%|██████████| 2813/2813 [00:25<00:00, 110.50it/s]
V Loss:  0.5998, Avg Loss:  0.5604, Counter: 1, Best Loss:  0.5297: 100%|██████████| 938/938 [00:06<00:00, 136.20it/s]


Epoch: 9 | Loss:  0.5604
----------------------------------------


T Loss:  0.5188, Avg Loss:  0.5640: 100%|██████████| 2813/2813 [00:25<00:00, 110.06it/s]
V Loss:  0.7865, Avg Loss:  0.5128, Counter: 2, Best Loss:  0.5297: 100%|██████████| 938/938 [00:06<00:00, 136.16it/s]


Epoch: 10 | Loss:  0.5128
----------------------------------------


T Loss:  0.5591, Avg Loss:  0.5598: 100%|██████████| 2813/2813 [00:25<00:00, 109.48it/s]
V Loss:  0.4082, Avg Loss:  0.5271, Counter: 0, Best Loss:  0.5128: 100%|██████████| 938/938 [00:06<00:00, 136.40it/s]


Epoch: 11 | Loss:  0.5271
----------------------------------------


T Loss:  0.3950, Avg Loss:  0.5586: 100%|██████████| 2813/2813 [00:25<00:00, 112.10it/s]
V Loss:  0.7758, Avg Loss:  0.5335, Counter: 1, Best Loss:  0.5128: 100%|██████████| 938/938 [00:06<00:00, 137.42it/s]


Epoch: 12 | Loss:  0.5335
----------------------------------------


T Loss:  0.5361, Avg Loss:  0.5553: 100%|██████████| 2813/2813 [00:25<00:00, 110.61it/s]
V Loss:  0.3414, Avg Loss:  0.5346, Counter: 2, Best Loss:  0.5128: 100%|██████████| 938/938 [00:06<00:00, 135.77it/s]


Epoch: 13 | Loss:  0.5346
----------------------------------------


T Loss:  0.7613, Avg Loss:  0.5536: 100%|██████████| 2813/2813 [00:25<00:00, 110.67it/s]
V Loss:  0.3369, Avg Loss:  0.5249, Counter: 3, Best Loss:  0.5128: 100%|██████████| 938/938 [00:06<00:00, 145.80it/s]

Epoch: 14 | Loss:  0.5249





In [43]:
nc.load_state_dict(torch.load(os.path.join(DIR, 'best.pth')))

<All keys matched successfully>

In [31]:
test_adf, test_pdf = read_data('data/test.csv', create_unique_words=False, filter_rare_words=False)

In [32]:
downstream_test_dataset = Sentences(test_adf, test_pdf, char_to_idx, word_to_idx)
downstream_test_loader = DataLoader(downstream_test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [44]:
nc._metrics(downstream_test_loader)

100%|██████████| 238/238 [00:01<00:00, 213.47it/s]


Test Loss:  0.5309
Classification Report:               precision    recall  f1-score   support

           0     0.8253    0.8253    0.8253      1900
           1     0.8318    0.9289    0.8777      1900
           2     0.7963    0.6953    0.7423      1900
           3     0.7275    0.7347    0.7311      1900

    accuracy                         0.7961      7600
   macro avg     0.7952    0.7961    0.7941      7600
weighted avg     0.7952    0.7961    0.7941      7600

Confusion Matrix: [[1568  140   94   98]
 [  64 1765   23   48]
 [ 127   75 1321  377]
 [ 141  142  221 1396]]
