In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
import torch
import pdb

from pathlib import Path
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [3]:
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers import ProgressBar

In [4]:
from yelp.dataset import ProjectDataset
from yelp.trainer import YelpTrainer
from yelp.model import Classifier
from yelp.args import args

In [5]:
path = Path('./data/yelp')
review_csv = path/args.sample_file
scratch = path/args.workdir_name
vectorizer_path = scratch/args.vectorizer_fname
args.save_dir = scratch

df = pd.read_csv(review_csv)

In [6]:
args

Namespace(batch_size=1024, checkpointer_name='classifier', checkpointer_prefix='yelp', device='cuda:3', early_stopping_criteria=5, frequency_cutoff=25, learning_rate=0.001, num_epochs=100, sample_file='reviews_with_splits_lite.csv', save_dir=PosixPath('data/yelp/scratch'), save_every=2, save_total=5, vectorizer_fname='vectorizer.json', workdir_name='scratch')

In [7]:
# train_ds = ProjectDataset.load_data_and_create_vectorizer(df.loc[df['split'] == 'train'])
# train_ds.save_vectorizer(vectorizer_path)

In [8]:
train_df = df.loc[df['split'] == 'train']
train_ds = ProjectDataset.load_data_and_vectorizer(train_df, vectorizer_path)
vectorizer = train_ds.get_vectorizer()
train_dl = DataLoader(train_ds, args.batch_size, shuffle=True, drop_last=True)

val_df = df.loc[df['split'] == 'val']
val_ds = ProjectDataset.load_data_and_vectorizer(val_df, vectorizer_path)
val_dl = DataLoader(val_ds, args.batch_size, shuffle=True, drop_last=True)

In [9]:
classifier = Classifier(num_features=len((vectorizer).review_vocab))
optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.5, patience=1)
loss_func = nn.BCEWithLogitsLoss()

In [10]:
def bce_logits_wrapper(output):
    y_pred, y = output
    y_pred = (torch.sigmoid(y_pred) > 0.5).long()
    return y_pred, y

In [11]:
pbar = ProgressBar(persist=True)
metrics = {'accuracy': Accuracy(bce_logits_wrapper), 'loss': Loss(loss_func)}

In [12]:
yelp_trainer = YelpTrainer(classifier, optimizer, loss_func, train_dl, val_dl, args, pbar, metrics)

In [13]:
yelp_trainer.run()

Epoch [1/100]: [38/38] 100%|██████████, loss=6.48e-01 [00:06<00:00]
Epoch [2/100]: [0/38]   0%|          , loss=5.57e-01 [00:00<?]

Epoch: 1
Training - Loss: 0.552, Accuracy: 0.839
Validation - Loss: 0.559, Accuracy: 0.829


Epoch [2/100]: [38/38] 100%|██████████, loss=5.31e-01 [00:06<00:00]
Epoch [3/100]: [0/38]   0%|          , loss=4.74e-01 [00:00<?]

Epoch: 2
Training - Loss: 0.477, Accuracy: 0.857
Validation - Loss: 0.487, Accuracy: 0.846


Epoch [3/100]: [38/38] 100%|██████████, loss=4.62e-01 [00:06<00:00]
Epoch [4/100]: [0/38]   0%|          , loss=4.34e-01 [00:00<?]

Epoch: 3
Training - Loss: 0.429, Accuracy: 0.874
Validation - Loss: 0.441, Accuracy: 0.860


Epoch [4/100]: [38/38] 100%|██████████, loss=4.21e-01 [00:06<00:00]
Epoch [5/100]: [0/38]   0%|          , loss=3.92e-01 [00:00<?]

Epoch: 4
Training - Loss: 0.394, Accuracy: 0.885
Validation - Loss: 0.408, Accuracy: 0.874


Epoch [5/100]: [38/38] 100%|██████████, loss=3.86e-01 [00:06<00:00]
Epoch [6/100]: [0/38]   0%|          , loss=3.55e-01 [00:00<?]

Epoch: 5
Training - Loss: 0.367, Accuracy: 0.892
Validation - Loss: 0.382, Accuracy: 0.883


Epoch [6/100]: [38/38] 100%|██████████, loss=3.55e-01 [00:06<00:00]


Epoch: 6
Training - Loss: 0.345, Accuracy: 0.901
Validation - Loss: 0.361, Accuracy: 0.890


Epoch [7/100]: [38/38] 100%|██████████, loss=3.39e-01 [00:06<00:00]
Epoch [8/100]: [0/38]   0%|          , loss=3.36e-01 [00:00<?]

Epoch: 7
Training - Loss: 0.327, Accuracy: 0.905
Validation - Loss: 0.345, Accuracy: 0.893


Epoch [8/100]: [38/38] 100%|██████████, loss=3.27e-01 [00:06<00:00]
Epoch [9/100]: [0/38]   0%|          , loss=3.23e-01 [00:00<?]

Epoch: 8
Training - Loss: 0.312, Accuracy: 0.910
Validation - Loss: 0.331, Accuracy: 0.898


Epoch [9/100]: [38/38] 100%|██████████, loss=3.14e-01 [00:06<00:00]
Epoch [10/100]: [0/38]   0%|          , loss=3.23e-01 [00:00<?]

Epoch: 9
Training - Loss: 0.299, Accuracy: 0.913
Validation - Loss: 0.319, Accuracy: 0.900


Epoch [10/100]: [38/38] 100%|██████████, loss=3.07e-01 [00:06<00:00]
Epoch [11/100]: [0/38]   0%|          , loss=2.88e-01 [00:00<?]

Epoch: 10
Training - Loss: 0.288, Accuracy: 0.916
Validation - Loss: 0.309, Accuracy: 0.902


Epoch [11/100]: [38/38] 100%|██████████, loss=2.85e-01 [00:06<00:00]
Epoch [12/100]: [0/38]   0%|          , loss=2.84e-01 [00:00<?]

Epoch: 11
Training - Loss: 0.278, Accuracy: 0.918
Validation - Loss: 0.301, Accuracy: 0.902


Epoch [12/100]: [38/38] 100%|██████████, loss=2.78e-01 [00:06<00:00]
Epoch [13/100]: [0/38]   0%|          , loss=2.72e-01 [00:00<?]

Epoch: 12
Training - Loss: 0.269, Accuracy: 0.921
Validation - Loss: 0.293, Accuracy: 0.906


Epoch [13/100]: [38/38] 100%|██████████, loss=2.69e-01 [00:06<00:00]
Epoch [14/100]: [0/38]   0%|          , loss=2.70e-01 [00:00<?]

Epoch: 13
Training - Loss: 0.261, Accuracy: 0.922
Validation - Loss: 0.285, Accuracy: 0.907


Epoch [14/100]: [38/38] 100%|██████████, loss=2.64e-01 [00:06<00:00]
Epoch [15/100]: [0/38]   0%|          , loss=2.73e-01 [00:00<?]

Epoch: 14
Training - Loss: 0.254, Accuracy: 0.924
Validation - Loss: 0.280, Accuracy: 0.908


Epoch [15/100]: [38/38] 100%|██████████, loss=2.62e-01 [00:06<00:00]
Epoch [16/100]: [0/38]   0%|          , loss=2.46e-01 [00:00<?]

Epoch: 15
Training - Loss: 0.248, Accuracy: 0.925
Validation - Loss: 0.275, Accuracy: 0.908


Epoch [16/100]: [38/38] 100%|██████████, loss=2.46e-01 [00:06<00:00]


Epoch: 16
Training - Loss: 0.242, Accuracy: 0.927
Validation - Loss: 0.270, Accuracy: 0.909


Epoch [17/100]: [38/38] 100%|██████████, loss=2.45e-01 [00:06<00:00]
Epoch [18/100]: [0/38]   0%|          , loss=2.42e-01 [00:00<?]

Epoch: 17
Training - Loss: 0.237, Accuracy: 0.928
Validation - Loss: 0.264, Accuracy: 0.910


Epoch [18/100]: [38/38] 100%|██████████, loss=2.38e-01 [00:06<00:00]
Epoch [19/100]: [0/38]   0%|          , loss=2.21e-01 [00:00<?]

Epoch: 18
Training - Loss: 0.232, Accuracy: 0.929
Validation - Loss: 0.261, Accuracy: 0.911


Epoch [19/100]: [38/38] 100%|██████████, loss=2.26e-01 [00:06<00:00]
Epoch [20/100]: [0/38]   0%|          , loss=2.34e-01 [00:00<?]

Epoch: 19
Training - Loss: 0.227, Accuracy: 0.931
Validation - Loss: 0.256, Accuracy: 0.912


Epoch [20/100]: [38/38] 100%|██████████, loss=2.29e-01 [00:06<00:00]
Epoch [21/100]: [0/38]   0%|          , loss=2.10e-01 [00:00<?]

Epoch: 20
Training - Loss: 0.223, Accuracy: 0.932
Validation - Loss: 0.254, Accuracy: 0.912


Epoch [21/100]: [38/38] 100%|██████████, loss=2.16e-01 [00:06<00:00]
Epoch [22/100]: [0/38]   0%|          , loss=2.17e-01 [00:00<?]

Epoch: 21
Training - Loss: 0.218, Accuracy: 0.933
Validation - Loss: 0.251, Accuracy: 0.912


Epoch [22/100]: [38/38] 100%|██████████, loss=2.17e-01 [00:06<00:00]
Epoch [23/100]: [0/38]   0%|          , loss=2.11e-01 [00:00<?]

Epoch: 22
Training - Loss: 0.215, Accuracy: 0.934
Validation - Loss: 0.247, Accuracy: 0.913


Epoch [23/100]: [38/38] 100%|██████████, loss=2.12e-01 [00:06<00:00]
Epoch [24/100]: [0/38]   0%|          , loss=1.97e-01 [00:00<?]

Epoch: 23
Training - Loss: 0.211, Accuracy: 0.935
Validation - Loss: 0.246, Accuracy: 0.912


Epoch [24/100]: [38/38] 100%|██████████, loss=2.04e-01 [00:06<00:00]
Epoch [25/100]: [0/38]   0%|          , loss=2.01e-01 [00:00<?]

Epoch: 24
Training - Loss: 0.208, Accuracy: 0.936
Validation - Loss: 0.243, Accuracy: 0.914


Epoch [25/100]: [38/38] 100%|██████████, loss=2.04e-01 [00:06<00:00]
Epoch [26/100]: [0/38]   0%|          , loss=2.00e-01 [00:00<?]

Epoch: 25
Training - Loss: 0.204, Accuracy: 0.937
Validation - Loss: 0.241, Accuracy: 0.915


Epoch [26/100]: [38/38] 100%|██████████, loss=2.02e-01 [00:06<00:00]
Epoch [27/100]: [0/38]   0%|          , loss=2.14e-01 [00:00<?]

Epoch: 26
Training - Loss: 0.201, Accuracy: 0.938
Validation - Loss: 0.238, Accuracy: 0.915


Epoch [27/100]: [38/38] 100%|██████████, loss=2.07e-01 [00:06<00:00]
Epoch [28/100]: [0/38]   0%|          , loss=1.91e-01 [00:00<?]

Epoch: 27
Training - Loss: 0.198, Accuracy: 0.939
Validation - Loss: 0.236, Accuracy: 0.916


Epoch [28/100]: [38/38] 100%|██████████, loss=1.94e-01 [00:06<00:00]
Epoch [29/100]: [0/38]   0%|          , loss=1.90e-01 [00:00<?]

Epoch: 28
Training - Loss: 0.196, Accuracy: 0.939
Validation - Loss: 0.235, Accuracy: 0.916


Epoch [29/100]: [38/38] 100%|██████████, loss=1.93e-01 [00:06<00:00]
Epoch [30/100]: [0/38]   0%|          , loss=1.92e-01 [00:00<?]

Epoch: 29
Training - Loss: 0.193, Accuracy: 0.940
Validation - Loss: 0.232, Accuracy: 0.917


Epoch [30/100]: [38/38] 100%|██████████, loss=1.92e-01 [00:06<00:00]
Epoch [31/100]: [0/38]   0%|          , loss=1.84e-01 [00:00<?]

Epoch: 30
Training - Loss: 0.190, Accuracy: 0.941
Validation - Loss: 0.229, Accuracy: 0.918


Epoch [31/100]: [38/38] 100%|██████████, loss=1.87e-01 [00:06<00:00]
Epoch [32/100]: [0/38]   0%|          , loss=1.83e-01 [00:00<?]

Epoch: 31
Training - Loss: 0.188, Accuracy: 0.941
Validation - Loss: 0.228, Accuracy: 0.919


Epoch [32/100]: [38/38] 100%|██████████, loss=1.86e-01 [00:06<00:00]
Epoch [33/100]: [0/38]   0%|          , loss=1.95e-01 [00:00<?]

Epoch: 32
Training - Loss: 0.186, Accuracy: 0.942
Validation - Loss: 0.226, Accuracy: 0.919


Epoch [33/100]: [38/38] 100%|██████████, loss=1.90e-01 [00:06<00:00]
Epoch [34/100]: [0/38]   0%|          , loss=1.70e-01 [00:00<?]

Epoch: 33
Training - Loss: 0.184, Accuracy: 0.942
Validation - Loss: 0.226, Accuracy: 0.920


Epoch [34/100]: [38/38] 100%|██████████, loss=1.77e-01 [00:06<00:00]
Epoch [35/100]: [0/38]   0%|          , loss=1.98e-01 [00:00<?]

Epoch: 34
Training - Loss: 0.182, Accuracy: 0.943
Validation - Loss: 0.225, Accuracy: 0.918


Epoch [35/100]: [38/38] 100%|██████████, loss=1.89e-01 [00:06<00:00]
Epoch [36/100]: [0/38]   0%|          , loss=1.95e-01 [00:00<?]

Epoch: 35
Training - Loss: 0.180, Accuracy: 0.944
Validation - Loss: 0.223, Accuracy: 0.920


Epoch [36/100]: [38/38] 100%|██████████, loss=1.87e-01 [00:06<00:00]
Epoch [37/100]: [0/38]   0%|          , loss=1.71e-01 [00:00<?]

Epoch: 36
Training - Loss: 0.178, Accuracy: 0.944
Validation - Loss: 0.223, Accuracy: 0.920


Epoch [37/100]: [38/38] 100%|██████████, loss=1.74e-01 [00:06<00:00]
Epoch [38/100]: [0/38]   0%|          , loss=1.78e-01 [00:00<?]

Epoch: 37
Training - Loss: 0.176, Accuracy: 0.945
Validation - Loss: 0.221, Accuracy: 0.920


Epoch [38/100]: [38/38] 100%|██████████, loss=1.76e-01 [00:06<00:00]
Epoch [39/100]: [0/38]   0%|          , loss=1.72e-01 [00:00<?]

Epoch: 38
Training - Loss: 0.174, Accuracy: 0.945
Validation - Loss: 0.220, Accuracy: 0.920


Epoch [39/100]: [38/38] 100%|██████████, loss=1.73e-01 [00:06<00:00]
Epoch [40/100]: [0/38]   0%|          , loss=1.67e-01 [00:00<?]

Epoch: 39
Training - Loss: 0.172, Accuracy: 0.945
Validation - Loss: 0.220, Accuracy: 0.920


Epoch [40/100]: [38/38] 100%|██████████, loss=1.70e-01 [00:06<00:00]
Epoch [41/100]: [0/38]   0%|          , loss=1.72e-01 [00:00<?]

Epoch: 40
Training - Loss: 0.170, Accuracy: 0.946
Validation - Loss: 0.219, Accuracy: 0.921


Epoch [41/100]: [38/38] 100%|██████████, loss=1.71e-01 [00:06<00:00]
Epoch [42/100]: [0/38]   0%|          , loss=1.64e-01 [00:00<?]

Epoch: 41
Training - Loss: 0.169, Accuracy: 0.946
Validation - Loss: 0.219, Accuracy: 0.921


Epoch [42/100]: [38/38] 100%|██████████, loss=1.66e-01 [00:06<00:00]
Epoch [43/100]: [0/38]   0%|          , loss=1.63e-01 [00:00<?]

Epoch: 42
Training - Loss: 0.167, Accuracy: 0.947
Validation - Loss: 0.216, Accuracy: 0.922


Epoch [43/100]: [38/38] 100%|██████████, loss=1.66e-01 [00:06<00:00]
Epoch [44/100]: [0/38]   0%|          , loss=1.61e-01 [00:00<?]

Epoch: 43
Training - Loss: 0.166, Accuracy: 0.947
Validation - Loss: 0.216, Accuracy: 0.920


Epoch [44/100]: [38/38] 100%|██████████, loss=1.63e-01 [00:06<00:00]
Epoch [45/100]: [0/38]   0%|          , loss=1.74e-01 [00:00<?]

Epoch: 44
Training - Loss: 0.164, Accuracy: 0.947
Validation - Loss: 0.216, Accuracy: 0.922


Epoch [45/100]: [38/38] 100%|██████████, loss=1.69e-01 [00:06<00:00]
Epoch [46/100]: [0/38]   0%|          , loss=1.56e-01 [00:00<?]

Epoch: 45
Training - Loss: 0.163, Accuracy: 0.948
Validation - Loss: 0.215, Accuracy: 0.921


Epoch [46/100]: [38/38] 100%|██████████, loss=1.60e-01 [00:06<00:00]
Epoch [47/100]: [0/38]   0%|          , loss=1.79e-01 [00:00<?]

Epoch: 46
Training - Loss: 0.161, Accuracy: 0.948
Validation - Loss: 0.213, Accuracy: 0.922


Epoch [47/100]: [38/38] 100%|██████████, loss=1.70e-01 [00:06<00:00]
Epoch [48/100]: [0/38]   0%|          , loss=1.49e-01 [00:00<?]

Epoch: 47
Training - Loss: 0.160, Accuracy: 0.949
Validation - Loss: 0.214, Accuracy: 0.922


Epoch [48/100]: [38/38] 100%|██████████, loss=1.55e-01 [00:06<00:00]
Epoch [49/100]: [0/38]   0%|          , loss=1.40e-01 [00:00<?]

Epoch: 48
Training - Loss: 0.159, Accuracy: 0.949
Validation - Loss: 0.213, Accuracy: 0.922


Epoch [49/100]: [38/38] 100%|██████████, loss=1.50e-01 [00:06<00:00]
Epoch [50/100]: [0/38]   0%|          , loss=1.59e-01 [00:00<?]

Epoch: 49
Training - Loss: 0.157, Accuracy: 0.949
Validation - Loss: 0.213, Accuracy: 0.921


Epoch [50/100]: [38/38] 100%|██████████, loss=1.58e-01 [00:06<00:00]
Epoch [51/100]: [0/38]   0%|          , loss=1.58e-01 [00:00<?]

Epoch: 50
Training - Loss: 0.156, Accuracy: 0.950
Validation - Loss: 0.210, Accuracy: 0.924


Epoch [51/100]: [38/38] 100%|██████████, loss=1.57e-01 [00:06<00:00]
Epoch [52/100]: [0/38]   0%|          , loss=1.83e-01 [00:00<?]

Epoch: 51
Training - Loss: 0.155, Accuracy: 0.950
Validation - Loss: 0.213, Accuracy: 0.921


Epoch [52/100]: [38/38] 100%|██████████, loss=1.68e-01 [00:06<00:00]
Epoch [53/100]: [0/38]   0%|          , loss=1.52e-01 [00:00<?]

Epoch: 52
Training - Loss: 0.154, Accuracy: 0.950
Validation - Loss: 0.210, Accuracy: 0.922


Epoch [53/100]: [38/38] 100%|██████████, loss=1.53e-01 [00:06<00:00]
Epoch [54/100]: [0/38]   0%|          , loss=1.57e-01 [00:00<?]

Epoch: 53
Training - Loss: 0.153, Accuracy: 0.951
Validation - Loss: 0.211, Accuracy: 0.921


Epoch [54/100]: [38/38] 100%|██████████, loss=1.55e-01 [00:06<00:00]
Epoch [55/100]: [0/38]   0%|          , loss=1.47e-01 [00:00<?]

Epoch: 54
Training - Loss: 0.152, Accuracy: 0.951
Validation - Loss: 0.211, Accuracy: 0.921


Epoch [55/100]: [38/38] 100%|██████████, loss=1.50e-01 [00:06<00:00]


Epoch: 55
Training - Loss: 0.150, Accuracy: 0.951
Validation - Loss: 0.210, Accuracy: 0.921
