In [1]:
cfg = {
    'dev_train_len': 5*10**3,
    'dev_validation_len': 1*10**3,
    'learning_rate': 0.001,
    'epochs': 100,
    'embedding_dim': 50,
    'hidden_dim': 50,
    'dropout': 0,
    'optimizer': 'Adam',
    'num_layers': 2
}

In [2]:
from nltk import word_tokenize
from gensim.models import KeyedVectors
import gensim.downloader as api
import unicodedata
import random
import torch
import re

def normalize_unicode(text: str) -> str:
    return unicodedata.normalize('NFD', text)

def tokenize_corpus(s: str) -> str:
    s = normalize_unicode(s)
    s = s.lower()
    s = re.sub(r"""[^a-zA-Z0-9?.,;'"]+""", " ", s)
    s = re.sub(r'(.)\1{3,}',r'\1', s)
    s = s.rstrip().strip()
    return s

def get_word_tokenized_corpus(s: str) -> list:
    return word_tokenize(s)

glove_dict = {
    '50': 'glove-wiki-gigaword-50',
    '100': 'glove-wiki-gigaword-100',
    '200': 'glove-wiki-gigaword-200'
}

glove_dict['50'] = api.load(glove_dict['50'])
# glove_dict['100'] = api.load(glove_dict['100'])
# glove_dict['200'] = api.load(glove_dict['200'])

def create_vocab(sentences: list, embedding_dim: int):
    glove = glove_dict[str(embedding_dim)]

    Emb = KeyedVectors(vector_size=glove.vector_size)
    vocab = []

    for sentence in sentences:
        vocab.extend(sentence)

    vocab = set(vocab)

    vectors, keys = [], []
    for token in vocab:
        if token in glove:
            vectors.append(torch.tensor(glove[token]))
            keys.append(token)

    keys.extend(['<unk>', '<pad>', '<sos>', '<eos>'])
    vectors.append(torch.mean(torch.stack(vectors), dim=0).numpy())
    vectors.append([0 for _ in range(embedding_dim)])
    vectors.append([random.random() for _ in range(embedding_dim)])
    vectors.append([random.random() for _ in range(embedding_dim)])
    Emb.add_vectors(keys, vectors)

    return Emb

def get_sentence_index(sentence: list, Emb: KeyedVectors):
    word_vec = []

    word_vec.append(Emb.key_to_index['<sos>'])
    for word in sentence:
        word_vec.append(get_vocab_index(word, Emb))
    word_vec.append(Emb.key_to_index['<eos>'])

    return torch.tensor(word_vec)

def get_vocab_index(word: str, Emb: KeyedVectors):
    if word in Emb:
        return Emb.key_to_index[word]
    return Emb.key_to_index['<unk>']


In [5]:
import numpy as np
import torch
import pandas as pd

In [6]:
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

cuda


In [7]:
# DEV_TRAIN_LEN = cfg['parameters']['dev_train_len']['value']
# DEV_VALIDATION_LEN = cfg['parameters']['dev_validation_len']['value']

DEV_TRAIN_LEN = cfg['dev_train_len']
DEV_VALIDATION_LEN = cfg['dev_validation_len']

DIR = '/scratch/shu7bh/RES/PRE'

In [8]:
import os
if not os.path.exists(DIR):
    os.makedirs(DIR)

In [9]:
print(DEV_TRAIN_LEN)
print(DEV_VALIDATION_LEN)

5000
1000


In [10]:
df = pd.read_csv('data/train.csv')
df = df.sample(frac=1, random_state=0).reset_index(drop=True)
df['Description'] = df['Description'].apply(tokenize_corpus)
df['Description'] = df['Description'].apply(get_word_tokenized_corpus)

In [11]:
df['Class Index']

0         4
1         4
2         4
3         1
4         3
         ..
119995    4
119996    3
119997    1
119998    4
119999    4
Name: Class Index, Length: 120000, dtype: int64

In [12]:
dev_train = df[:DEV_TRAIN_LEN]['Description']
dev_validation = df[DEV_TRAIN_LEN:DEV_TRAIN_LEN + DEV_VALIDATION_LEN]['Description']

In [13]:
dev_validation

5000    [video, game, publisher, electronic, arts, on,...
5001    [by, amanda, gardner, ,, healthday, reporter, ...
5002    [percival, became, the, tigers, 39, ;, new, cl...
5003    [it, 's, getting, harder, to, shrink, chips, ,...
5004    [canadian, press, halifax, cp, they, have, bec...
                              ...                        
5995    [massive, database, holds, info, on, millions,...
5996    [supreme, court, justices, on, tuesday, uncork...
5997    [the, yankees, should, soon, clinch, their, se...
5998    [the, un, nuclear, agency, agreed, yesterday, ...
5999    [ap, how, do, you, explain, a, quarterback, sn...
Name: Description, Length: 1000, dtype: object

In [17]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [16]:
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class ELMO(nn.Module):
    def __init__(self, Emb, hidden_dim, dropout, num_layers):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(Emb.vectors), padding_idx=Emb.key_to_index['<pad>'])
        self.lstm = nn.LSTM(Emb.vectors.shape[1], hidden_dim, batch_first=True, bidirectional=True, num_layers=num_layers, dropout=dropout)

    def forward(self, X, X_lengths):
        X = self.embedding(X)
        X = pack_padded_sequence(X, X_lengths, batch_first=True, enforce_sorted=False)
        # X, _ = self.lstm(X, None)
        X, (h_n, _) = self.lstm(X, None)
        X, _ = pad_packed_sequence(X, batch_first=True)
        return X, h_n

In [22]:
from torch import nn

class LM(nn.Module):
    def __init__(self, Emb, hidden_dim, dropout, num_layers):
        super().__init__()

        self.elmo = ELMO(Emb, hidden_dim, dropout, num_layers)
        self.fc = nn.Linear(hidden_dim * 2, Emb.vectors.shape[0])

    def forward(self, X, X_lengths):
        X, _ = self.elmo(X, X_lengths)
        X = self.fc(X)
        return X

In [15]:
from torch.utils.data import Dataset
import torch

class SentencesDataset(Dataset):
    def __init__(self, sentences: list, Emb):
        super().__init__()

        self.data = []
        for sentence in sentences:
            self.data.append(get_sentence_index(sentence, Emb))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], torch.tensor(len(self.data[idx]))

In [18]:
class Collator:
    def __init__(self, Emb):
        self.pad_index = Emb.key_to_index['<pad>']

    def __call__(self, batch):
        X, X_lengths = zip(*batch)
        X = pad_sequence(X, batch_first=True, padding_value=self.pad_index)
        return X[:, :-1], X[:, 1:], torch.stack(X_lengths) - 1

In [19]:
import tqdm

def fit(model, dataloader, train, es, loss_fn, optimizer):
    model.train() if train else model.eval()
    epoch_loss = []

    pbar = tqdm.tqdm(dataloader)

    for X, Y, X_lengths in pbar:
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)

        Y_pred = model(X, X_lengths)
        Y_pred = Y_pred.reshape(-1, Y_pred.shape[2])

        Y = Y.reshape(-1)

        loss = loss_fn(Y_pred, Y)
        epoch_loss.append(loss.item())

        X.detach()
        Y_pred.detach()
        Y.detach()

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f'{"T" if train else "V"} Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Best Loss: {es.best_loss:7.4f}, Counter: {es.counter}')

    return np.mean(epoch_loss)

In [23]:
import numpy as np

class EarlyStopping:
    def __init__(self, patience:int = 3, delta:float = 0.001):
        self.patience = patience
        self.counter = 0
        self.best_loss:float = np.inf
        self.best_model_pth = 0
        self.delta = delta

    def __call__(self, loss, epoch: int):
        should_stop = False

        if loss >= self.best_loss - self.delta:
            self.counter += 1
            if self.counter > self.patience:
                should_stop = True
        else:
            self.best_loss = loss
            self.counter = 0
            self.best_model_pth = epoch
        return should_stop

In [20]:
def train(EPOCHS, model, training_dataloader, validation_dataloader, loss_fn, optimizer):
    es = EarlyStopping(patience=1, delta=0.1)

    for epoch in range(EPOCHS):
        print(f'\nEpoch {epoch+1}')

        epoch_loss = fit(model, training_dataloader, True, es, loss_fn, optimizer)
        # wandb.log({'train_loss': epoch_loss})

        with torch.no_grad():
            epoch_loss = fit(model, validation_dataloader, False, es, loss_fn, optimizer)
            # wandb.log({'validation_loss': epoch_loss})
            if es(epoch_loss, epoch):
                break
            if es.counter == 0:
                torch.save(model.state_dict(), os.path.join(DIR, f'best_model.pth'))
                torch.save(model.elmo.state_dict(), os.path.join(DIR, f'best_model_elmo.pth'))

In [21]:
def run(config=None):
    # with wandb.init(config=cfg):
        # config = wandb.config
    BATCH_SIZE = 16
    # if config.hidden_dim in [300, 500]:
    #     BATCH_SIZE = 32
    if config['hidden_dim'] in [300, 500]:
        BATCH_SIZE = 32

    # wandb.log({'batch_size': BATCH_SIZE})
    HIDDEN_DIM = config['hidden_dim']
    DROP_OUT = config['dropout']
    OPTIMIZER = config['optimizer']
    LEARNING_RATE = config['learning_rate']
    EPOCHS = config['epochs']
    EMBEDDITNG_DIM = config['embedding_dim']
    NUM_LAYERS = config['num_layers']

    Emb = create_vocab(df['Description'], EMBEDDITNG_DIM)

    dev_train_dataset = SentencesDataset(dev_train, Emb)
    dev_validation_dataset = SentencesDataset(dev_validation, Emb)

    collate_fn = Collator(Emb)

    training_dataloader = DataLoader(dev_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=4)
    validation_dataloader = DataLoader(dev_validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=4)

    torch.cuda.empty_cache()

    model = LM(Emb, HIDDEN_DIM, DROP_OUT, NUM_LAYERS).to(DEVICE)

    optimizer = getattr(torch.optim, OPTIMIZER)(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss(ignore_index=Emb.key_to_index['<pad>'])

    train(EPOCHS, model, training_dataloader, validation_dataloader, loss_fn, optimizer)

run(cfg)


Epoch 1


T Loss:  6.9441, Avg Loss:  7.5440, Best Loss:     inf, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 45.14it/s]
V Loss:  7.0764, Avg Loss:  7.1183, Best Loss:     inf, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 70.86it/s]



Epoch 2


T Loss:  6.7168, Avg Loss:  6.8660, Best Loss:  7.1183, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.23it/s]
V Loss:  6.7011, Avg Loss:  6.8075, Best Loss:  7.1183, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 79.66it/s] 



Epoch 3


T Loss:  5.8657, Avg Loss:  6.1721, Best Loss:  6.8075, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.54it/s]
V Loss:  6.2183, Avg Loss:  5.9884, Best Loss:  6.8075, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 71.29it/s]



Epoch 4


T Loss:  4.6929, Avg Loss:  5.4382, Best Loss:  5.9884, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.83it/s]
V Loss:  5.2009, Avg Loss:  5.3672, Best Loss:  5.9884, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 72.70it/s]



Epoch 5


T Loss:  4.5941, Avg Loss:  4.8246, Best Loss:  5.3672, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.66it/s]
V Loss:  5.3075, Avg Loss:  4.8565, Best Loss:  5.3672, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 79.24it/s] 



Epoch 6


T Loss:  4.4263, Avg Loss:  4.3121, Best Loss:  4.8565, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.62it/s]
V Loss:  4.4044, Avg Loss:  4.4156, Best Loss:  4.8565, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 80.05it/s] 



Epoch 7


T Loss:  3.7449, Avg Loss:  3.8598, Best Loss:  4.4156, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.32it/s]
V Loss:  3.8946, Avg Loss:  4.0462, Best Loss:  4.4156, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 74.30it/s]



Epoch 8


T Loss:  2.9055, Avg Loss:  3.4715, Best Loss:  4.0462, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.98it/s]
V Loss:  3.2890, Avg Loss:  3.7275, Best Loss:  4.0462, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 78.42it/s]



Epoch 9


T Loss:  2.7592, Avg Loss:  3.1337, Best Loss:  3.7275, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 49.33it/s]
V Loss:  3.4785, Avg Loss:  3.4716, Best Loss:  3.7275, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 78.67it/s] 



Epoch 10


T Loss:  2.5439, Avg Loss:  2.8317, Best Loss:  3.4716, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.70it/s]
V Loss:  3.3278, Avg Loss:  3.2395, Best Loss:  3.4716, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 74.84it/s]



Epoch 11


T Loss:  2.6278, Avg Loss:  2.5573, Best Loss:  3.2395, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.09it/s]
V Loss:  3.0383, Avg Loss:  3.0320, Best Loss:  3.2395, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 73.35it/s]



Epoch 12


T Loss:  2.4314, Avg Loss:  2.3056, Best Loss:  3.0320, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 49.25it/s]
V Loss:  2.7488, Avg Loss:  2.8391, Best Loss:  3.0320, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 74.16it/s] 



Epoch 13


T Loss:  2.3073, Avg Loss:  2.0767, Best Loss:  2.8391, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 47.98it/s]
V Loss:  2.6691, Avg Loss:  2.6829, Best Loss:  2.8391, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 74.74it/s] 



Epoch 14


T Loss:  1.7840, Avg Loss:  1.8694, Best Loss:  2.6829, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 49.49it/s]
V Loss:  2.7172, Avg Loss:  2.5355, Best Loss:  2.6829, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 73.90it/s]



Epoch 15


T Loss:  1.6992, Avg Loss:  1.6790, Best Loss:  2.5355, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.22it/s]
V Loss:  1.9621, Avg Loss:  2.4023, Best Loss:  2.5355, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 75.06it/s] 



Epoch 16


T Loss:  1.6972, Avg Loss:  1.5075, Best Loss:  2.4023, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 49.47it/s]
V Loss:  2.1930, Avg Loss:  2.2933, Best Loss:  2.4023, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 73.16it/s]



Epoch 17


T Loss:  1.2564, Avg Loss:  1.3498, Best Loss:  2.2933, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 49.53it/s]
V Loss:  2.0299, Avg Loss:  2.1947, Best Loss:  2.2933, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 74.34it/s]



Epoch 18


T Loss:  1.2489, Avg Loss:  1.2071, Best Loss:  2.2933, Counter: 1: 100%|██████████| 313/313 [00:06<00:00, 48.57it/s]
V Loss:  1.8103, Avg Loss:  2.1081, Best Loss:  2.2933, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 76.89it/s] 



Epoch 19


T Loss:  1.0416, Avg Loss:  1.0746, Best Loss:  2.1081, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.59it/s]
V Loss:  1.8913, Avg Loss:  2.0236, Best Loss:  2.1081, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 70.21it/s]



Epoch 20


T Loss:  0.8096, Avg Loss:  0.9540, Best Loss:  2.1081, Counter: 1: 100%|██████████| 313/313 [00:06<00:00, 49.22it/s]
V Loss:  1.8867, Avg Loss:  1.9514, Best Loss:  2.1081, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 72.42it/s]



Epoch 21


T Loss:  0.7585, Avg Loss:  0.8428, Best Loss:  1.9514, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 49.10it/s]
V Loss:  1.9638, Avg Loss:  1.8796, Best Loss:  1.9514, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 75.20it/s] 



Epoch 22


T Loss:  0.7479, Avg Loss:  0.7405, Best Loss:  1.9514, Counter: 1: 100%|██████████| 313/313 [00:06<00:00, 49.00it/s]
V Loss:  1.9668, Avg Loss:  1.8235, Best Loss:  1.9514, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 75.43it/s] 



Epoch 23


T Loss:  0.7127, Avg Loss:  0.6486, Best Loss:  1.8235, Counter: 0: 100%|██████████| 313/313 [00:06<00:00, 48.45it/s]
V Loss:  1.9967, Avg Loss:  1.7759, Best Loss:  1.8235, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 74.69it/s]



Epoch 24


T Loss:  0.5664, Avg Loss:  0.5654, Best Loss:  1.8235, Counter: 1: 100%|██████████| 313/313 [00:06<00:00, 49.95it/s]
V Loss:  1.5640, Avg Loss:  1.7245, Best Loss:  1.8235, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 73.70it/s]


In [24]:
import tqdm

def run_epoch(model, dataloader, loss_fn):
    epoch_loss = []

    pbar = tqdm.tqdm(dataloader)

    for X, Y, X_lengths in pbar:
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)

        Y_pred = model(X, X_lengths)
        Y_pred = Y_pred.reshape(-1, Y_pred.shape[2])

        Y = Y.reshape(-1)

        loss = loss_fn(Y_pred, Y)
        epoch_loss.append(loss.item())

        pbar.set_description(f'Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}')

    return np.mean(epoch_loss)

In [25]:
def validate(elmo, validation_dataloader, loss_fn):
    with torch.no_grad():
        elmo.eval()
        epoch_loss = run_epoch(elmo, validation_dataloader, loss_fn)
        print(f'Validation Loss: {epoch_loss:7.4f}')

In [26]:
Emb = create_vocab(df['Description'], cfg['embedding_dim'])

# dev_train_dataset = SentencesDataset(dev_train, Emb)
dev_validation_dataset = SentencesDataset(dev_validation, Emb)

collate_fn = Collator(Emb)

# training_dataloader = DataLoader(dev_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=4)
validation_dataloader = DataLoader(dev_validation_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=4)
model = LM(Emb, cfg['hidden_dim'], cfg['dropout'], cfg['num_layers']).to(DEVICE)

model.load_state_dict(torch.load(os.path.join(DIR, 'best_model.pth')))

validate(model, validation_dataloader, nn.CrossEntropyLoss(ignore_index=Emb.key_to_index['<pad>']))

Loss:  2.0416, Avg Loss:  1.8229: 100%|██████████| 63/63 [00:00<00:00, 80.82it/s] 

Validation Loss:  1.8229





In [27]:
elmo = ELMO(Emb, cfg['hidden_dim'], cfg['dropout'], cfg['num_layers']).to(DEVICE)

# load only elmo from the best model
elmo.load_state_dict(torch.load(os.path.join(DIR, 'best_model_elmo.pth')), strict=False)

<All keys matched successfully>

In [28]:
class DownStream(nn.Module):
    def __init__(self, elmo, dropout):
        super().__init__()
        self.elmo = elmo
        # freeze the ELMO parameters
        for param in self.elmo.parameters():
            param.requires_grad = False

        self.hidden_dim = self.elmo.hidden_dim
        self.num_layers = self.elmo.num_layers
        self.dropout = dropout

        self.delta = nn.Parameter(torch.randn(1, self.num_layers + 1))
        self.linear = nn.Linear(self.hidden_dim * 2, 4)

    def forward(self, X, X_lengths):
        _, Y = self.elmo(X, X_lengths)
        # get the first hidden layer batch_first=True

        print(Y.shape)
        Y = Y.permute(1, 0, 2).reshape(Y.shape[0], self.num_layers, self.hidden_dim * 2)

        X = torch.mean(X, dim=1)

        # Y = Y.reshape(Y.shape[0], Y.shape[1], self.num_layers, 2 * self.hidden_dim)
        Y = torch.stack([X, Y], dim=1)

        # mean dimension 1 to get Y of shape (batch_size, 1, num_layers + 1, 2 * hidden_dim)
        # Y = torch.mean(Y, dim=1)

        # multiply by delta 
        Y = Y * (self.delta / torch.sum(self.delta))

        # sum over the num_layers dimension
        Y = torch.sum(Y, dim=1)

        # pass through linear layer
        Y = self.linear(Y)

        return Y

In [29]:
from torch.utils.data import DataLoader, Dataset

class DownStreamDataset(Dataset):
    def __init__(self, df: pd.DataFrame, Emb: KeyedVectors):
        self.descriptions = df['Description']
        self.descriptions = [get_sentence_index(description, Emb) for description in self.descriptions]
        # self.df.loc[: 'Description'] = df['Description'].apply(get_sentence_index, Emb=Emb)
        self.class_index = df['Class Index']
        self.Emb = Emb

    def __len__(self):
        return len(self.class_index)

    def __getitem__(self, idx):
        l = torch.tensor(len(self.descriptions[idx]))
        return self.descriptions[idx], l, torch.tensor(self.class_index[idx])

In [30]:
class DownStreamCollator:
    def __init__(self, Emb):
        self.pad_index = Emb.key_to_index['<pad>']

    def __call__(self, batch):
        X, X_lengths, Y = zip(*batch)
        X = pad_sequence(X, batch_first=True, padding_value=self.pad_index)
        print(len(X_lengths))
        print(len(Y))
        print('hello')
        return X, torch.stack(X_lengths), torch.stack(Y)

In [31]:
downstream_train = df[DEV_TRAIN_LEN + DEV_VALIDATION_LEN:]
downstream_validation = df[:DEV_TRAIN_LEN + DEV_VALIDATION_LEN]

downstream_train_dataset = DownStreamDataset(downstream_train, Emb)
downstream_validation_dataset = DownStreamDataset(downstream_validation, Emb)

downstream_collate_fn = DownStreamCollator(Emb)

downstream_training_dataloader = DataLoader(downstream_train_dataset, batch_size=16, shuffle=True, collate_fn=downstream_collate_fn, pin_memory=True)
downstream_validation_dataloader = DataLoader(downstream_validation_dataset, batch_size=16, shuffle=True, collate_fn=downstream_collate_fn, pin_memory=True)

In [32]:
def downstream_fit(model, dataloader, train, loss_fn, optimizer):
    model.train() if train else model.eval()
    epoch_loss = []

    pbar = tqdm.tqdm(dataloader)

    for X, X_lengths, Y in pbar:
        print('hello')
        print(X.shape)
        print(X_lengths.shape)
        print(Y.shape)
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)

        Y_pred = model(X, X_lengths)
        Y_pred = Y_pred.reshape(-1, Y_pred.shape[2])

        Y = Y.reshape(-1)

        loss = loss_fn(Y_pred, Y)
        epoch_loss.append(loss.item())

        X.detach()
        Y_pred.detach()
        Y.detach()

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f'{"T" if train else "V"} Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}')

    return np.mean(epoch_loss)

In [33]:
dmodel = DownStream(elmo, cfg['dropout']).to(DEVICE)
doptimizer = getattr(torch.optim, cfg['optimizer'])(dmodel.parameters(), lr=cfg['learning_rate'])
dloss_fn = nn.CrossEntropyLoss()

def downstream_train_fn(EPOCHS, model, training_dataloader, validation_dataloader, loss_fn, optimizer):
    es = EarlyStopping(patience=1, delta=0.1)
    for epoch in range(EPOCHS):
        print(f'\nEpoch {epoch+1}')

        epoch_loss = downstream_fit(model, training_dataloader, True, loss_fn, optimizer)
        print(f'Train Loss: {epoch_loss:7.4f}')
        # wandb.log({'downstream_train_loss': epoch_loss})

        with torch.no_grad():
            epoch_loss = downstream_fit(model, validation_dataloader, False, loss_fn, optimizer)
            print(f'Validation Loss: {epoch_loss:7.4f}')
            if es(epoch_loss, epoch):
                break
            if es.counter == 0:
                torch.save(model.state_dict(), os.path.join(DIR, f'downstream_best_model.pth'))
            # wandb.log({'downstream_validation_loss': epoch_loss})

In [34]:
downstream_train_fn(cfg['epochs'], dmodel, downstream_training_dataloader, downstream_validation_dataloader, dloss_fn, doptimizer)


Epoch 1


  0%|          | 0/7125 [00:00<?, ?it/s]


KeyError: 1920