In [1]:
from config import config as cfg
from preprocess import *
import numpy as np
import wandb
import torch
import pandas as pd
import os

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cuda


In [3]:
# Sizes of the development train / validation slices, read from the sweep config.
DEV_TRAIN_LEN = cfg['parameters']['dev_train_len']['value']
DEV_VALIDATION_LEN = cfg['parameters']['dev_validation_len']['value']

# Output directory for checkpoints. The default is machine-specific; set the
# ELMO_RES_DIR environment variable to use another location without editing the notebook.
DIR = os.environ.get('ELMO_RES_DIR', '/scratch/shu7bh/RES/PRE')

In [4]:
import os
# Create the checkpoint directory if needed. exist_ok avoids the check-then-create race
# when several processes (e.g. sweep agents) share DIR.
os.makedirs(DIR, exist_ok=True)

In [5]:
# Show the configured split sizes, one per line.
print(DEV_TRAIN_LEN, DEV_VALIDATION_LEN, sep='\n')

5000
1000


In [6]:
# Load the corpus and shuffle it reproducibly (fixed random_state), renumbering rows 0..n-1.
df = (
    pd.read_csv('data/train.csv')
    .sample(frac=1, random_state=0)
    .reset_index(drop=True)
)
# Tokenize each description in two stages using the preprocess helpers.
df['Description'] = (
    df['Description']
    .apply(tokenize_corpus)
    .apply(get_word_tokenized_corpus)
)

In [7]:
# Contiguous, non-overlapping slices of the shuffled corpus: the training sentences first,
# then the validation sentences immediately after them.
dev_train = df['Description'].iloc[:DEV_TRAIN_LEN]
dev_validation = df['Description'].iloc[DEV_TRAIN_LEN:DEV_TRAIN_LEN + DEV_VALIDATION_LEN]

In [8]:
# Display the validation split; each entry is the tokenized description (a list of tokens).
dev_validation

5000    [video, game, publisher, electronic, arts, on,...
5001    [by, amanda, gardner, ,, healthday, reporter, ...
5002    [percival, became, the, tigers, 39, ;, new, cl...
5003    [it, 's, getting, harder, to, shrink, chips, ,...
5004    [canadian, press, halifax, cp, they, have, bec...
                              ...                        
5995    [massive, database, holds, info, on, millions,...
5996    [supreme, court, justices, on, tuesday, uncork...
5997    [the, yankees, should, soon, clinch, their, se...
5998    [the, un, nuclear, agency, agreed, yesterday, ...
5999    [ap, how, do, you, explain, a, quarterback, sn...
Name: Description, Length: 1000, dtype: object

In [9]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from dataset import SentencesDataset
from elmo import ELMO
from torch import nn

In [10]:
class Collator:
    """DataLoader collate function that pads token sequences into shifted LM pairs.

    Each batch item is a ``(token_ids, length)`` pair of tensors. Sequences are padded
    to the batch maximum with the vocabulary's ``'<pad>'`` index. The model input drops
    the last position and the target drops the first, so ``target[:, t] == padded[:, t + 1]``.
    """

    def __init__(self, Emb):
        # Pad with the vocabulary's own '<pad>' index so a loss using the same
        # ignore_index skips padded positions.
        self.pad_index = Emb.key_to_index['<pad>']

    def __call__(self, batch):
        sequences, lengths = (list(column) for column in zip(*batch))
        padded = pad_sequence(sequences, batch_first=True, padding_value=self.pad_index)
        inputs = padded[:, :-1]
        targets = padded[:, 1:]
        # Input and target are each one token shorter than the original sequence.
        shifted_lengths = torch.stack(lengths) - 1
        return inputs, targets, shifted_lengths

In [11]:
import tqdm

def fit(model, dataloader, train, es, loss_fn, optimizer):
    """Run one epoch over ``dataloader`` in training or evaluation mode.

    Args:
        model: module called as ``model(X, X_lengths)``. Its output is flattened with
            ``reshape(-1, Y_pred.shape[2])``, so it is assumed to be
            (batch, seq, vocab) — TODO confirm against ``ELMO.forward``.
        dataloader: iterable of ``(X, Y, X_lengths)`` batches (as produced by ``Collator``).
        train: if True, put the model in train mode and step ``optimizer`` after every
            batch; otherwise put it in eval mode and only compute the loss.
        es: early-stopping tracker; only ``es.best_loss`` and ``es.counter`` are read,
            for the progress-bar text.
        loss_fn: criterion applied to the flattened predictions and targets.
        optimizer: stepped only when ``train`` is True.

    Returns:
        The mean of the per-batch losses (``nan`` for an empty dataloader, as ``np.mean([])``).
    """
    if train:
        model.train()
    else:
        model.eval()

    epoch_loss = []
    pbar = tqdm.tqdm(dataloader)

    # Only build the autograd graph when we will back-propagate. Evaluation then skips it
    # even if the caller forgets torch.no_grad().
    with torch.set_grad_enabled(train):
        for X, Y, X_lengths in pbar:
            X = X.to(DEVICE)
            Y = Y.to(DEVICE)

            Y_pred = model(X, X_lengths)
            Y_pred = Y_pred.reshape(-1, Y_pred.shape[2])
            Y = Y.reshape(-1)

            loss = loss_fn(Y_pred, Y)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Read the scalar once per batch: each .item() synchronises with the device.
            batch_loss = loss.item()
            epoch_loss.append(batch_loss)

            pbar.set_description(f'{"T" if train else "V"} Loss: {batch_loss:7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Best Loss: {es.best_loss:7.4f}, Counter: {es.counter}')

    return np.mean(epoch_loss)

In [12]:
import EarlyStopping as ES

def train(EPOCHS, elmo, training_dataloader, validation_dataloader, loss_fn, optimizer, patience=2, delta=0.01):
    """Train ``elmo`` with early stopping and upload the best checkpoint to wandb.

    Each epoch runs a training pass and a validation pass (the latter under
    ``torch.no_grad()``), logging ``train_loss`` and ``validation_loss`` to wandb.
    After each validation pass the early-stopping tracker is updated. When its
    ``counter`` is 0, the model's ``state_dict`` is saved to ``DIR/best_model.pth``.
    A zero counter means the tracker just reset, presumably on a new best loss (see
    EarlyStopping). Finally the best validation loss is logged as ``loss``. If this run
    wrote a checkpoint, it is uploaded as the model artifact ``best_model_<run id>``.

    Args:
        EPOCHS: maximum number of epochs.
        elmo: model to train.
        training_dataloader: batches for the training pass.
        validation_dataloader: batches for the validation pass.
        loss_fn: criterion passed through to ``fit``.
        optimizer: optimizer passed through to ``fit``.
        patience: forwarded to ``EarlyStopping`` (default matches the previous hard-coded value).
        delta: forwarded to ``EarlyStopping`` (default matches the previous hard-coded value).
    """
    es = ES.EarlyStopping(patience=patience, delta=delta)
    checkpoint_path = os.path.join(DIR, 'best_model.pth')
    saved_checkpoint = False

    for epoch in range(EPOCHS):
        print(f'\nEpoch {epoch+1}')

        epoch_loss = fit(elmo, training_dataloader, True, es, loss_fn, optimizer)
        wandb.log({'train_loss': epoch_loss})

        with torch.no_grad():
            epoch_loss = fit(elmo, validation_dataloader, False, es, loss_fn, optimizer)
        wandb.log({'validation_loss': epoch_loss})

        should_stop = es(epoch_loss, epoch)
        # Checkpoint before honouring the stop signal, so an epoch that both resets
        # the counter and requests a stop is never lost.
        if es.counter == 0:
            torch.save(elmo.state_dict(), checkpoint_path)
            saved_checkpoint = True
        if should_stop:
            break

    wandb.log({'loss': es.best_loss})

    # DIR/best_model.pth is shared across runs. Upload it only if *this* run wrote it;
    # otherwise a stale checkpoint from an earlier run would be attached to this run.
    if not saved_checkpoint:
        print('No checkpoint was saved in this run; skipping model artifact upload.')
        return

    best_model = wandb.Artifact(f'best_model_{wandb.run.id}', type='model')
    best_model.add_file(checkpoint_path)
    wandb.run.log_artifact(best_model)

In [13]:
def run(cfg=None):
    """Train one ELMO model for a single wandb sweep trial.

    Hyper-parameters are read from ``wandb.config``. The sweep agent populates it, or
    ``cfg`` does when this is called directly. The function builds the vocabulary,
    datasets and loaders, then delegates training, checkpointing and artifact upload
    to ``train``.

    Args:
        cfg: optional config passed to ``wandb.init``; ``None`` when invoked by ``wandb.agent``.
    """
    with wandb.init(config=cfg):
        config = wandb.config

        # Larger hidden sizes use a bigger batch.
        BATCH_SIZE = 32 if config['hidden_dim'] in [300, 500] else 16
        wandb.log({'batch_size': BATCH_SIZE})

        HIDDEN_DIM = config['hidden_dim']
        DROP_OUT = config['dropout']
        OPTIMIZER = config['optimizer']
        LEARNING_RATE = config['learning_rate']
        EPOCHS = config['epochs']
        EMBEDDING_DIM = config['embedding_dim']
        NUM_LAYERS = config['num_layers']

        # NOTE(review): the vocabulary is built from the whole corpus (not only dev_train),
        # so validation text presumably contributes to it; confirm this is intended.
        Emb = create_vocab(df['Description'], EMBEDDING_DIM)

        dev_train_dataset = SentencesDataset(dev_train, Emb)
        dev_validation_dataset = SentencesDataset(dev_validation, Emb)

        collate_fn = Collator(Emb)
        loader_kwargs = dict(
            batch_size=BATCH_SIZE,
            collate_fn=collate_fn,
            pin_memory=DEVICE.type == 'cuda',  # pinned memory only helps host->GPU copies
            num_workers=4,
        )
        training_dataloader = DataLoader(dev_train_dataset, shuffle=True, **loader_kwargs)
        # No shuffling for validation. The reported loss is a mean of per-batch means, so
        # reshuffling changes the batching and adds noise to the early-stopping signal.
        validation_dataloader = DataLoader(dev_validation_dataset, shuffle=False, **loader_kwargs)

        torch.cuda.empty_cache()

        elmo = ELMO(Emb, HIDDEN_DIM, DROP_OUT, NUM_LAYERS).to(DEVICE)

        # Optimizer class looked up by name, e.g. 'Adam' -> torch.optim.Adam.
        optimizer = getattr(torch.optim, OPTIMIZER)(elmo.parameters(), lr=LEARNING_RATE)
        # Padded target positions do not contribute to the loss.
        loss_fn = nn.CrossEntropyLoss(ignore_index=Emb.key_to_index['<pad>'])

        train(EPOCHS, elmo, training_dataloader, validation_dataloader, loss_fn, optimizer)

# NOTE(review): `cfg` must still be the sweep config imported from `config` here;
# the artifact-download cell later rebinds `cfg` to a run's config.
sweep_id = wandb.sweep(cfg, project='ELMO')
wandb.agent(sweep_id, run, count=1)  # raise `count` (e.g. 20) to run more sweep trials

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: vtpegz9o
Sweep URL: https://wandb.ai/shu7bh/ELMO/sweeps/vtpegz9o


[34m[1mwandb[0m: Agent Starting Run: q9tp7d56 with config:
[34m[1mwandb[0m: 	dev_train_len: 5000
[34m[1mwandb[0m: 	dev_validation_len: 1000
[34m[1mwandb[0m: 	dropout: 0
[34m[1mwandb[0m: 	embedding_dim: 50
[34m[1mwandb[0m: 	epochs: 100
[34m[1mwandb[0m: 	hidden_dim: 100
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	num_layers: 2
[34m[1mwandb[0m: 	optimizer: Adam
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshu7bh[0m. Use [1m`wandb login --relogin`[0m to force relogin



Epoch 1


T Loss:  6.7581, Avg Loss:  7.4066, Best Loss:     inf, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 39.36it/s]
V Loss:  6.9564, Avg Loss:  7.0115, Best Loss:     inf, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 68.21it/s]



Epoch 2


T Loss:  5.7395, Avg Loss:  6.2551, Best Loss:  7.0115, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.37it/s]
V Loss:  5.9191, Avg Loss:  5.9212, Best Loss:  7.0115, Counter: 0: 100%|██████████| 63/63 [00:01<00:00, 60.68it/s]



Epoch 3


T Loss:  5.1992, Avg Loss:  5.2450, Best Loss:  5.9212, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.78it/s]
V Loss:  5.0033, Avg Loss:  5.0460, Best Loss:  5.9212, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 68.50it/s]



Epoch 4


T Loss:  4.2835, Avg Loss:  4.3918, Best Loss:  5.0460, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.88it/s]
V Loss:  4.7000, Avg Loss:  4.3501, Best Loss:  5.0460, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 67.05it/s]



Epoch 5


T Loss:  3.6056, Avg Loss:  3.6868, Best Loss:  4.3501, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.51it/s]
V Loss:  4.3410, Avg Loss:  3.8124, Best Loss:  4.3501, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 67.82it/s]



Epoch 6


T Loss:  3.0700, Avg Loss:  3.1155, Best Loss:  3.8124, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 40.98it/s]
V Loss:  3.2996, Avg Loss:  3.3724, Best Loss:  3.8124, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.76it/s]



Epoch 7


T Loss:  2.5856, Avg Loss:  2.6525, Best Loss:  3.3724, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.90it/s]
V Loss:  2.9962, Avg Loss:  3.0383, Best Loss:  3.3724, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 64.67it/s]



Epoch 8


T Loss:  1.9532, Avg Loss:  2.2590, Best Loss:  3.0383, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.49it/s]
V Loss:  2.8662, Avg Loss:  2.7613, Best Loss:  3.0383, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.59it/s]



Epoch 9


T Loss:  1.8429, Avg Loss:  1.9211, Best Loss:  2.7613, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.78it/s]
V Loss:  2.4534, Avg Loss:  2.5285, Best Loss:  2.7613, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 69.95it/s] 



Epoch 10


T Loss:  1.6146, Avg Loss:  1.6253, Best Loss:  2.5285, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 42.14it/s]
V Loss:  2.4612, Avg Loss:  2.3457, Best Loss:  2.5285, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 70.93it/s]



Epoch 11


T Loss:  1.4413, Avg Loss:  1.3690, Best Loss:  2.3457, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.76it/s]
V Loss:  3.1584, Avg Loss:  2.1916, Best Loss:  2.3457, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 64.86it/s]



Epoch 12


T Loss:  1.1570, Avg Loss:  1.1463, Best Loss:  2.1916, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.17it/s]
V Loss:  2.1242, Avg Loss:  2.0472, Best Loss:  2.1916, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 68.52it/s]



Epoch 13


T Loss:  0.9717, Avg Loss:  0.9520, Best Loss:  2.0472, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.28it/s]
V Loss:  1.4896, Avg Loss:  1.9357, Best Loss:  2.0472, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 68.28it/s]



Epoch 14


T Loss:  0.9510, Avg Loss:  0.7834, Best Loss:  1.9357, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.36it/s]
V Loss:  1.6285, Avg Loss:  1.8426, Best Loss:  1.9357, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 64.49it/s]



Epoch 15


T Loss:  0.6702, Avg Loss:  0.6373, Best Loss:  1.8426, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.38it/s]
V Loss:  1.6279, Avg Loss:  1.7734, Best Loss:  1.8426, Counter: 0: 100%|██████████| 63/63 [00:01<00:00, 60.83it/s]



Epoch 16


T Loss:  0.3987, Avg Loss:  0.5140, Best Loss:  1.7734, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.85it/s]
V Loss:  1.9499, Avg Loss:  1.7059, Best Loss:  1.7734, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 63.84it/s]



Epoch 17


T Loss:  0.4260, Avg Loss:  0.4107, Best Loss:  1.7059, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.51it/s]
V Loss:  1.7573, Avg Loss:  1.6552, Best Loss:  1.7059, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 63.53it/s]



Epoch 18


T Loss:  0.2938, Avg Loss:  0.3272, Best Loss:  1.6552, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.91it/s]
V Loss:  1.5524, Avg Loss:  1.6092, Best Loss:  1.6552, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 67.67it/s]



Epoch 19


T Loss:  0.2490, Avg Loss:  0.2601, Best Loss:  1.6092, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.60it/s]
V Loss:  1.6494, Avg Loss:  1.5827, Best Loss:  1.6092, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 66.81it/s]



Epoch 20


T Loss:  0.1680, Avg Loss:  0.2085, Best Loss:  1.5827, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 40.95it/s]
V Loss:  1.3169, Avg Loss:  1.5682, Best Loss:  1.5827, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 66.64it/s]



Epoch 21


T Loss:  0.1940, Avg Loss:  0.1669, Best Loss:  1.5682, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.17it/s]
V Loss:  1.9826, Avg Loss:  1.5447, Best Loss:  1.5682, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.40it/s]



Epoch 22


T Loss:  0.1381, Avg Loss:  0.1356, Best Loss:  1.5447, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.12it/s]
V Loss:  1.3955, Avg Loss:  1.5189, Best Loss:  1.5447, Counter: 0: 100%|██████████| 63/63 [00:01<00:00, 62.33it/s]



Epoch 23


T Loss:  0.1120, Avg Loss:  0.1110, Best Loss:  1.5189, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.32it/s]
V Loss:  1.0447, Avg Loss:  1.5013, Best Loss:  1.5189, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 68.79it/s]



Epoch 24


T Loss:  0.0880, Avg Loss:  0.0915, Best Loss:  1.5013, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.96it/s]
V Loss:  1.1320, Avg Loss:  1.4900, Best Loss:  1.5013, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.33it/s]



Epoch 25


T Loss:  0.0942, Avg Loss:  0.0760, Best Loss:  1.4900, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.69it/s]
V Loss:  1.9246, Avg Loss:  1.4921, Best Loss:  1.4900, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.85it/s]



Epoch 26


T Loss:  0.0583, Avg Loss:  0.0636, Best Loss:  1.4900, Counter: 1: 100%|██████████| 313/313 [00:07<00:00, 41.17it/s]
V Loss:  1.4201, Avg Loss:  1.4821, Best Loss:  1.4900, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 68.47it/s]



Epoch 27


T Loss:  0.0432, Avg Loss:  0.0535, Best Loss:  1.4821, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 42.07it/s]
V Loss:  2.0941, Avg Loss:  1.4833, Best Loss:  1.4821, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 64.44it/s]



Epoch 28


T Loss:  0.0554, Avg Loss:  0.0456, Best Loss:  1.4821, Counter: 1: 100%|██████████| 313/313 [00:07<00:00, 41.82it/s]
V Loss:  1.7835, Avg Loss:  1.4741, Best Loss:  1.4821, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 66.33it/s]



Epoch 29


T Loss:  0.0421, Avg Loss:  0.0382, Best Loss:  1.4741, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.16it/s]
V Loss:  1.6333, Avg Loss:  1.4712, Best Loss:  1.4741, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.84it/s]



Epoch 30


T Loss:  0.0275, Avg Loss:  0.0325, Best Loss:  1.4712, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 42.10it/s]
V Loss:  2.2682, Avg Loss:  1.4728, Best Loss:  1.4712, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 66.41it/s]



Epoch 31


T Loss:  0.0223, Avg Loss:  0.0278, Best Loss:  1.4712, Counter: 1: 100%|██████████| 313/313 [00:07<00:00, 41.09it/s]
V Loss:  1.7743, Avg Loss:  1.4673, Best Loss:  1.4712, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 66.23it/s]



Epoch 32


T Loss:  0.0241, Avg Loss:  0.0238, Best Loss:  1.4673, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.32it/s]
V Loss:  1.4427, Avg Loss:  1.4636, Best Loss:  1.4673, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.84it/s]



Epoch 33


T Loss:  0.0180, Avg Loss:  0.0212, Best Loss:  1.4636, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.48it/s]
V Loss:  1.1268, Avg Loss:  1.4585, Best Loss:  1.4636, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 66.52it/s]



Epoch 34


T Loss:  0.0179, Avg Loss:  0.0178, Best Loss:  1.4585, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.48it/s]
V Loss:  1.5209, Avg Loss:  1.4550, Best Loss:  1.4585, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 64.23it/s]



Epoch 35


T Loss:  0.0162, Avg Loss:  0.0153, Best Loss:  1.4550, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.18it/s]
V Loss:  1.6929, Avg Loss:  1.4596, Best Loss:  1.4550, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.13it/s]



Epoch 36


T Loss:  0.0170, Avg Loss:  0.0134, Best Loss:  1.4550, Counter: 1: 100%|██████████| 313/313 [00:07<00:00, 41.45it/s]
V Loss:  1.9855, Avg Loss:  1.4529, Best Loss:  1.4550, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 66.61it/s]



Epoch 37


T Loss:  0.0112, Avg Loss:  0.0115, Best Loss:  1.4529, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 40.92it/s]
V Loss:  1.4470, Avg Loss:  1.4547, Best Loss:  1.4529, Counter: 0: 100%|██████████| 63/63 [00:01<00:00, 62.73it/s]



Epoch 38


T Loss:  0.0223, Avg Loss:  0.0104, Best Loss:  1.4529, Counter: 1: 100%|██████████| 313/313 [00:07<00:00, 40.96it/s]
V Loss:  1.8850, Avg Loss:  1.4748, Best Loss:  1.4529, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 66.55it/s]



Epoch 39


T Loss:  0.0101, Avg Loss:  0.0118, Best Loss:  1.4529, Counter: 2: 100%|██████████| 313/313 [00:07<00:00, 41.02it/s]
V Loss:  1.6061, Avg Loss:  1.4560, Best Loss:  1.4529, Counter: 2: 100%|██████████| 63/63 [00:00<00:00, 68.64it/s]



Epoch 40


T Loss:  0.0059, Avg Loss:  0.0074, Best Loss:  1.4529, Counter: 3: 100%|██████████| 313/313 [00:07<00:00, 41.42it/s]
V Loss:  0.4766, Avg Loss:  1.4452, Best Loss:  1.4529, Counter: 3: 100%|██████████| 63/63 [00:00<00:00, 65.16it/s]



Epoch 41


T Loss:  0.0082, Avg Loss:  0.0064, Best Loss:  1.4452, Counter: 0: 100%|██████████| 313/313 [00:07<00:00, 41.74it/s]
V Loss:  0.9831, Avg Loss:  1.4531, Best Loss:  1.4452, Counter: 0: 100%|██████████| 63/63 [00:00<00:00, 65.45it/s]



Epoch 42


T Loss:  0.0051, Avg Loss:  0.0057, Best Loss:  1.4452, Counter: 1: 100%|██████████| 313/313 [00:07<00:00, 41.33it/s]
V Loss:  1.5235, Avg Loss:  1.4584, Best Loss:  1.4452, Counter: 1: 100%|██████████| 63/63 [00:00<00:00, 64.39it/s]



Epoch 43


T Loss:  0.0068, Avg Loss:  0.0053, Best Loss:  1.4452, Counter: 2: 100%|██████████| 313/313 [00:07<00:00, 41.19it/s]
V Loss:  1.5720, Avg Loss:  1.4635, Best Loss:  1.4452, Counter: 2: 100%|██████████| 63/63 [00:00<00:00, 67.17it/s]



Epoch 44


T Loss:  0.0065, Avg Loss:  0.0101, Best Loss:  1.4452, Counter: 3: 100%|██████████| 313/313 [00:07<00:00, 41.37it/s]
V Loss:  1.1863, Avg Loss:  1.4571, Best Loss:  1.4452, Counter: 3: 100%|██████████| 63/63 [00:00<00:00, 68.57it/s]


0,1
batch_size,▁
loss,▁
train_loss,█▇▆▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▇▆▅▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch_size,16.0
loss,1.44521
train_loss,0.01009
validation_loss,1.45711


In [22]:
# Fetch the best checkpoint and the hyper-parameters of a finished sweep run.
# NOTE(review): this rebinds `cfg` (previously the sweep config imported from `config`)
# to the run's config. Re-running the sweep cell afterwards would hand the wrong dict
# to `wandb.sweep`.
WANDB_PROJECT_PATH = 'shu7bh/ELMO'
EVAL_RUN_ID = 'q9tp7d56'

api = wandb.Api()
model_artifact = api.artifact(f'{WANDB_PROJECT_PATH}/best_model_{EVAL_RUN_ID}:v0', type='model')
artifact_dir = model_artifact.download()
cfg = api.run(f'{WANDB_PROJECT_PATH}/{EVAL_RUN_ID}').config

[34m[1mwandb[0m: Downloading large artifact best_model_q9tp7d56:v0, 54.45MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.2


In [23]:
import tqdm

def run_epoch(model, dataloader, loss_fn):
    """Compute the mean per-batch loss of ``model`` over ``dataloader``.

    No parameters are updated. Train/eval mode and gradient tracking are left as the
    caller set them (``validate`` handles both).

    Returns:
        ``np.mean`` of the per-batch losses.
    """
    batch_losses = []
    progress = tqdm.tqdm(dataloader)

    for inputs, targets, lengths in progress:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        logits = model(inputs, lengths)
        # Flatten to (N, vocab) logits vs (N,) targets for the criterion.
        flat_logits = logits.reshape(-1, logits.shape[2])
        flat_targets = targets.reshape(-1)

        current = loss_fn(flat_logits, flat_targets).item()
        batch_losses.append(current)

        progress.set_description(f'Loss: {current:7.4f}, Avg Loss: {np.mean(batch_losses):7.4f}')

    return np.mean(batch_losses)

In [24]:
def validate(elmo, validation_dataloader, loss_fn):
    """Evaluate ``elmo`` on ``validation_dataloader``, print the mean loss and return it.

    The model is switched to eval mode and gradient tracking is disabled for the pass.

    Returns:
        The mean per-batch validation loss computed by ``run_epoch``.
    """
    elmo.eval()
    with torch.no_grad():
        epoch_loss = run_epoch(elmo, validation_dataloader, loss_fn)
    print(f'Validation Loss: {epoch_loss:7.4f}')
    return epoch_loss

In [25]:
# Rebuild the vocabulary the same way `run` does so embedding/output sizes match the checkpoint.
Emb = create_vocab(df['Description'], cfg['embedding_dim'])

dev_validation_dataset = SentencesDataset(dev_validation, Emb)
collate_fn = Collator(Emb)

# Use the same batch-size rule as `run`: the loss is a mean of per-batch means, so batching
# matters. shuffle=False makes this evaluation repeatable.
EVAL_BATCH_SIZE = 32 if cfg['hidden_dim'] in [300, 500] else 16
validation_dataloader = DataLoader(
    dev_validation_dataset,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=DEVICE.type == 'cuda',
    num_workers=4,
)

elmo = ELMO(Emb, cfg['hidden_dim'], cfg['dropout'], cfg['num_layers']).to(DEVICE)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
elmo.load_state_dict(torch.load(os.path.join(artifact_dir, 'best_model.pth'), map_location=DEVICE))

validate(elmo, validation_dataloader, nn.CrossEntropyLoss(ignore_index=Emb.key_to_index['<pad>']))

Loss:  1.8123, Avg Loss:  1.4573: 100%|██████████| 63/63 [00:00<00:00, 70.66it/s]

Validation Loss:  1.4573



