In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
import torch
import pdb

from pathlib import Path
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [3]:
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers import ProgressBar

In [4]:
from yelp.dataset import ProjectDataset

In [8]:
def set_all_seed(seed, cuda):
  """Seed every random number generator this notebook can draw from.

  Args:
    seed: integer seed applied to Python's `random`, NumPy and PyTorch.
    cuda: when truthy, also seed the CUDA generators explicitly.
  """
  # Function-scope import: `random` is not among the notebook's top-level imports.
  import random

  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  if cuda:
    # manual_seed_all covers every visible GPU; torch.cuda.manual_seed only
    # seeds the current device, which need not be the one used for training.
    torch.cuda.manual_seed_all(seed)

In [5]:
# Locations of the Yelp review data and cached artifacts, relative to the
# directory the notebook is run from.
path = Path('./data/yelp')
review_csv = path.joinpath('reviews_with_splits_lite.csv')
scratch = path.joinpath('scratch')
vectorizer_path = scratch.joinpath('vectorizer.json')

# Full review table; later cells select rows through its 'split' column.
df = pd.read_csv(review_csv)

In [6]:
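# One-time setup, kept commented out. Judging by the method names, these two
# lines create the vectorizer from the training split and write it to
# vectorizer_path; uncomment them if that file does not exist yet, since the
# following cell loads the vectorizer from vectorizer_path.
# train_ds = ProjectDataset.load_data_and_create_vectorizer(df.loc[df['split'] == 'train'])
# train_ds.save_vectorizer(vectorizer_path)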

In [7]:
# Training split: the dataset is built from the rows tagged 'train' together
# with the vectorizer stored at vectorizer_path.
train_df = df.loc[df['split'] == 'train']
train_ds = ProjectDataset.load_data_and_vectorizer(train_df, vectorizer_path)
vectorizer = train_ds.get_vectorizer()
# Reshuffle every epoch; drop_last discards a trailing partial batch so every
# training batch has exactly 128 examples.
train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, drop_last=True)

# Validation split, loaded with the same stored vectorizer.
val_df = df.loc[df['split'] == 'val']
val_ds = ProjectDataset.load_data_and_vectorizer(val_df, vectorizer_path)
# Evaluate on every validation example in a fixed order: shuffling combined
# with drop_last would silently skip a different handful of rows on each
# evaluation, making the reported metrics incomplete and not comparable
# between epochs.
val_dl = DataLoader(val_ds, batch_size=128, shuffle=False, drop_last=False)

In [31]:
class ReviewClassifier(nn.Module):
  """Single linear layer that maps a feature vector to one logit."""

  def __init__(self, num_features):
    """Create the classifier.

    Args:
      num_features: width of each input feature vector.
    """
    super().__init__()
    self.fc1 = nn.Linear(in_features=num_features, out_features=1)

  def forward(self, x_in, apply_sigmoid=False):
    """Score a batch of inputs.

    Args:
      x_in: tensor whose last dimension is `num_features`; dimension 1 of the
        linear output is squeezed, so a 2-D (batch, num_features) input
        yields a 1-D result with one value per row.
      apply_sigmoid: when True, return sigmoid probabilities instead of raw
        logits. Leave False when the loss applies the sigmoid itself.

    Returns:
      Tensor of logits, or of probabilities when `apply_sigmoid` is True.
    """
    logits = self.fc1(x_in).squeeze(1)
    return torch.sigmoid(logits) if apply_sigmoid else logits

In [39]:
# One input feature per entry of the vectorizer's review vocabulary.
# NOTE(review): no call to set_all_seed is visible in this notebook, so the
# initial weights presumably differ from run to run - confirm whether seeding
# is wanted before this line.
classifier = ReviewClassifier(num_features=len(vectorizer.review_vocab))

# Loss operates on raw logits, matching the classifier's default output.
loss_func = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(classifier.parameters(), lr=0.001)
# Halves the learning rate when the monitored value stops decreasing for more
# than `patience` checks; it only acts when scheduler.step(<value>) is called.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  optimizer=optimizer,
  mode='min',
  factor=0.5,
  patience=1,
)

In [40]:
def bce_logits_wrapper(output):
    """Turn `(logits, targets)` into `(hard 0/1 predictions, targets)`.

    The logits are passed through a sigmoid and thresholded strictly above
    0.5, then cast to integer labels. The targets are returned untouched.
    """
    logits, targets = output
    predicted_labels = torch.sigmoid(logits).gt(0.5).long()
    return predicted_labels, targets

In [41]:
# Prefer the GPU this notebook was written for ('cuda:3'), but fall back to
# the default GPU or the CPU so the cell also runs on hosts with fewer devices.
if torch.cuda.device_count() > 3:
  device = 'cuda:3'
elif torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

trainer = create_supervised_trainer(classifier, optimizer, loss_func, device=device)

# The evaluator reports two metrics under the keys the logging handlers read:
#   'accuracy' - bce_logits_wrapper thresholds the logits into 0/1 labels first
#   'bce'      - the same loss function used for training
evaluator = create_supervised_evaluator(
  classifier,
  metrics={
    'accuracy': Accuracy(bce_logits_wrapper),
    'bce': Loss(loss_func),
  },
  device=device,
)

In [42]:
# Progress bar for the training engine; persist=True asks it to stay visible
# once it has finished.
pbar = ProgressBar(persist=True)
# Show each iteration's trainer output next to the bar under the label 'loss'.
pbar.attach(trainer, output_transform=lambda step_output: {'loss': step_output})

In [43]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
  """After each training epoch, evaluate on `train_dl` and log the metrics.

  Args:
    engine: the training engine that fired the event; its `state.epoch` is
      included in the message.
  """
  evaluator.run(train_dl)
  metrics = evaluator.state.metrics
  # Adjacent f-strings keep the message on one line; a backslash continuation
  # inside the literal would embed the source indentation in the output.
  pbar.log_message(
    f"Training Results - Epoch: {engine.state.epoch}  "
    f"Avg accuracy: {metrics['accuracy']:0.2f}  "
    f"Avg loss: {metrics['bce']:0.2f}"
  )
                   
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
  """After each training epoch, evaluate on `val_dl`, step the LR scheduler
  on the validation loss, and log the metrics.

  Args:
    engine: the training engine that fired the event; its `state.epoch` is
      included in the message.
  """
  evaluator.run(val_dl)
  metrics = evaluator.state.metrics
  # ReduceLROnPlateau (mode='min') only lowers the learning rate when it is
  # fed the monitored value; without this call the scheduler has no effect.
  scheduler.step(metrics['bce'])
  # Adjacent f-strings keep the message on one line; a backslash continuation
  # inside the literal would embed the source indentation in the output.
  pbar.log_message(
    f"Validation Results - Epoch: {engine.state.epoch}  "
    f"Avg accuracy: {metrics['accuracy']:0.2f}  "
    f"Avg loss: {metrics['bce']:0.2f}"
  )

In [44]:
@evaluator.on(Events.COMPLETED)
def test_eval(engine):
  """Debug hook: print a marker each time an evaluator run completes."""
  print("evaluator.on()")

In [45]:
# Train for two epochs over train_dl; the EPOCH_COMPLETED handlers registered
# on `trainer` run after each epoch.
trainer.run(train_dl, max_epochs=2)



[0/306]   0%|           [00:00<?][A[A

Epoch [1/2]: [0/306]   0%|           [00:00<?][A[A

Epoch [1/2]: [0/306]   0%|          , loss=6.88e-01 [00:00<?][A[A

Epoch [1/2]: [1/306]   0%|          , loss=6.88e-01 [00:00<00:07][A[A

Epoch [1/2]: [1/306]   0%|          , loss=6.88e-01 [00:00<00:08][A[A

Epoch [1/2]: [2/306]   1%|          , loss=6.88e-01 [00:00<00:07][A[A

Epoch [1/2]: [2/306]   1%|          , loss=6.85e-01 [00:00<00:07][A[A

Epoch [1/2]: [3/306]   1%|          , loss=6.85e-01 [00:00<00:07][A[A

Epoch [1/2]: [3/306]   1%|          , loss=6.86e-01 [00:00<00:07][A[A

Epoch [1/2]: [4/306]   1%|▏         , loss=6.86e-01 [00:00<00:07][A[A

Epoch [1/2]: [4/306]   1%|▏         , loss=6.79e-01 [00:00<00:07][A[A

Epoch [1/2]: [5/306]   2%|▏         , loss=6.79e-01 [00:00<00:07][A[A

Epoch [1/2]: [5/306]   2%|▏         , loss=6.71e-01 [00:00<00:07][A[A

Epoch [1/2]: [6/306]   2%|▏         , loss=6.71e-01 [00:00<00:06][A[A

Epoch [1/2]: [6/306]   2%|▏    

Epoch [1/2]: [50/306]  16%|█▋        , loss=5.73e-01 [00:01<00:06][A[A

Epoch [1/2]: [51/306]  17%|█▋        , loss=5.73e-01 [00:01<00:06][A[A

Epoch [1/2]: [51/306]  17%|█▋        , loss=5.73e-01 [00:01<00:06][A[A

Epoch [1/2]: [51/306]  17%|█▋        , loss=5.67e-01 [00:01<00:06][A[A

Epoch [1/2]: [52/306]  17%|█▋        , loss=5.67e-01 [00:01<00:06][A[A

Epoch [1/2]: [52/306]  17%|█▋        , loss=5.88e-01 [00:01<00:06][A[A

Epoch [1/2]: [53/306]  17%|█▋        , loss=5.88e-01 [00:01<00:06][A[A

Epoch [1/2]: [53/306]  17%|█▋        , loss=5.77e-01 [00:01<00:06][A[A

Epoch [1/2]: [54/306]  18%|█▊        , loss=5.77e-01 [00:01<00:06][A[A

Epoch [1/2]: [54/306]  18%|█▊        , loss=5.69e-01 [00:01<00:06][A[A

Epoch [1/2]: [55/306]  18%|█▊        , loss=5.69e-01 [00:01<00:06][A[A

Epoch [1/2]: [55/306]  18%|█▊        , loss=5.67e-01 [00:01<00:06][A[A

Epoch [1/2]: [56/306]  18%|█▊        , loss=5.67e-01 [00:01<00:05][A[A

Epoch [1/2]: [56/306]  18%|█▊        ,

Epoch [1/2]: [100/306]  33%|███▎      , loss=5.00e-01 [00:02<00:05][A[A

Epoch [1/2]: [101/306]  33%|███▎      , loss=5.00e-01 [00:02<00:04][A[A

Epoch [1/2]: [101/306]  33%|███▎      , loss=5.00e-01 [00:02<00:04][A[A

Epoch [1/2]: [101/306]  33%|███▎      , loss=4.82e-01 [00:02<00:04][A[A

Epoch [1/2]: [102/306]  33%|███▎      , loss=4.82e-01 [00:02<00:04][A[A

Epoch [1/2]: [102/306]  33%|███▎      , loss=5.32e-01 [00:02<00:04][A[A

Epoch [1/2]: [103/306]  34%|███▎      , loss=5.32e-01 [00:02<00:04][A[A

Epoch [1/2]: [103/306]  34%|███▎      , loss=4.90e-01 [00:02<00:04][A[A

Epoch [1/2]: [104/306]  34%|███▍      , loss=4.90e-01 [00:02<00:04][A[A

Epoch [1/2]: [104/306]  34%|███▍      , loss=4.77e-01 [00:02<00:04][A[A

Epoch [1/2]: [105/306]  34%|███▍      , loss=4.77e-01 [00:02<00:04][A[A

Epoch [1/2]: [105/306]  34%|███▍      , loss=4.85e-01 [00:02<00:04][A[A

Epoch [1/2]: [106/306]  35%|███▍      , loss=4.85e-01 [00:02<00:04][A[A

Epoch [1/2]: [106/306]  3

Epoch [1/2]: [150/306]  49%|████▉     , loss=4.20e-01 [00:03<00:03][A[A

Epoch [1/2]: [150/306]  49%|████▉     , loss=4.67e-01 [00:03<00:03][A[A

Epoch [1/2]: [151/306]  49%|████▉     , loss=4.67e-01 [00:03<00:03][A[A

Epoch [1/2]: [151/306]  49%|████▉     , loss=4.67e-01 [00:03<00:03][A[A

Epoch [1/2]: [151/306]  49%|████▉     , loss=4.62e-01 [00:03<00:03][A[A

Epoch [1/2]: [152/306]  50%|████▉     , loss=4.62e-01 [00:03<00:03][A[A

Epoch [1/2]: [152/306]  50%|████▉     , loss=4.96e-01 [00:03<00:03][A[A

Epoch [1/2]: [153/306]  50%|█████     , loss=4.96e-01 [00:03<00:03][A[A

Epoch [1/2]: [153/306]  50%|█████     , loss=4.28e-01 [00:03<00:03][A[A

Epoch [1/2]: [154/306]  50%|█████     , loss=4.28e-01 [00:03<00:03][A[A

Epoch [1/2]: [154/306]  50%|█████     , loss=4.50e-01 [00:03<00:03][A[A

Epoch [1/2]: [155/306]  51%|█████     , loss=4.50e-01 [00:03<00:03][A[A

Epoch [1/2]: [155/306]  51%|█████     , loss=4.28e-01 [00:03<00:03][A[A

Epoch [1/2]: [156/306]  5

Epoch [1/2]: [199/306]  65%|██████▌   , loss=4.51e-01 [00:04<00:02][A[A

Epoch [1/2]: [200/306]  65%|██████▌   , loss=4.51e-01 [00:04<00:02][A[A

Epoch [1/2]: [200/306]  65%|██████▌   , loss=4.18e-01 [00:04<00:02][A[A

Epoch [1/2]: [201/306]  66%|██████▌   , loss=4.18e-01 [00:04<00:02][A[A

Epoch [1/2]: [201/306]  66%|██████▌   , loss=4.18e-01 [00:04<00:02][A[A

Epoch [1/2]: [201/306]  66%|██████▌   , loss=4.37e-01 [00:04<00:02][A[A

Epoch [1/2]: [202/306]  66%|██████▌   , loss=4.37e-01 [00:04<00:02][A[A

Epoch [1/2]: [202/306]  66%|██████▌   , loss=4.35e-01 [00:04<00:02][A[A

Epoch [1/2]: [203/306]  66%|██████▋   , loss=4.35e-01 [00:04<00:02][A[A

Epoch [1/2]: [203/306]  66%|██████▋   , loss=4.18e-01 [00:04<00:02][A[A

Epoch [1/2]: [204/306]  67%|██████▋   , loss=4.18e-01 [00:04<00:02][A[A

Epoch [1/2]: [204/306]  67%|██████▋   , loss=3.87e-01 [00:04<00:02][A[A

Epoch [1/2]: [205/306]  67%|██████▋   , loss=3.87e-01 [00:04<00:02][A[A

Epoch [1/2]: [205/306]  6

Epoch [1/2]: [249/306]  81%|████████▏ , loss=4.38e-01 [00:06<00:01][A[A

Epoch [1/2]: [249/306]  81%|████████▏ , loss=4.38e-01 [00:06<00:01][A[A

Epoch [1/2]: [250/306]  82%|████████▏ , loss=4.38e-01 [00:06<00:01][A[A

Epoch [1/2]: [250/306]  82%|████████▏ , loss=4.10e-01 [00:06<00:01][A[A

Epoch [1/2]: [251/306]  82%|████████▏ , loss=4.10e-01 [00:06<00:01][A[A

Epoch [1/2]: [251/306]  82%|████████▏ , loss=4.10e-01 [00:06<00:01][A[A

Epoch [1/2]: [251/306]  82%|████████▏ , loss=4.19e-01 [00:06<00:01][A[A

Epoch [1/2]: [252/306]  82%|████████▏ , loss=4.19e-01 [00:06<00:01][A[A

Epoch [1/2]: [252/306]  82%|████████▏ , loss=3.79e-01 [00:06<00:01][A[A

Epoch [1/2]: [253/306]  83%|████████▎ , loss=3.79e-01 [00:06<00:01][A[A

Epoch [1/2]: [253/306]  83%|████████▎ , loss=4.08e-01 [00:06<00:01][A[A

Epoch [1/2]: [254/306]  83%|████████▎ , loss=4.08e-01 [00:06<00:01][A[A

Epoch [1/2]: [254/306]  83%|████████▎ , loss=3.74e-01 [00:06<00:01][A[A

Epoch [1/2]: [255/306]  8

Epoch [1/2]: [298/306]  97%|█████████▋, loss=3.99e-01 [00:07<00:00][A[A

Epoch [1/2]: [299/306]  98%|█████████▊, loss=3.99e-01 [00:07<00:00][A[A

Epoch [1/2]: [299/306]  98%|█████████▊, loss=4.02e-01 [00:07<00:00][A[A

Epoch [1/2]: [300/306]  98%|█████████▊, loss=4.02e-01 [00:07<00:00][A[A

Epoch [1/2]: [300/306]  98%|█████████▊, loss=3.85e-01 [00:07<00:00][A[A

Epoch [1/2]: [301/306]  98%|█████████▊, loss=3.85e-01 [00:07<00:00][A[A

Epoch [1/2]: [301/306]  98%|█████████▊, loss=3.85e-01 [00:07<00:00][A[A

Epoch [1/2]: [301/306]  98%|█████████▊, loss=3.79e-01 [00:07<00:00][A[A

Epoch [1/2]: [302/306]  99%|█████████▊, loss=3.79e-01 [00:07<00:00][A[A

Epoch [1/2]: [302/306]  99%|█████████▊, loss=3.49e-01 [00:07<00:00][A[A

Epoch [1/2]: [303/306]  99%|█████████▉, loss=3.49e-01 [00:07<00:00][A[A

Epoch [1/2]: [303/306]  99%|█████████▉, loss=3.77e-01 [00:07<00:00][A[A

Epoch [1/2]: [304/306]  99%|█████████▉, loss=3.77e-01 [00:07<00:00][A[A

Epoch [1/2]: [304/306]  9

KeyError: 'metrics not found in engine.state.metrics'