In [None]:
# Enable IPython's autoreload extension in mode 2 so imported modules are
# re-imported before each cell executes; edits to local packages are then
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import torch
import pdb

from pathlib import Path
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [None]:
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers import ProgressBar

In [None]:
from yelp.dataset import ProjectDataset
from yelp.trainer import YelpTrainer
from yelp.model import Classifier
from yelp.args import args

In [None]:
# Display the imported configuration object so the settings used for this run
# are recorded in the notebook output.
args

In [None]:
# Root directory that holds the Yelp data and the working directory.
path = Path('./data/yelp')

# Input CSV with the review samples (file name comes from the run config).
review_csv = path / args.sample_file

# Scratch/working directory and the vectorizer file stored inside it.
scratch = path / args.workdir_name
vectorizer_path = scratch / args.vectorizer_fname

# Point the shared config's save directory at the scratch directory.
args.save_dir = scratch

# Load the full review table; later cells select rows by the 'split' column.
df = pd.read_csv(review_csv)

In [None]:
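# One-time setup (left disabled): the next cell loads a vectorizer from
# `vectorizer_path`, so that file must already exist. These two lines appear to
# create it from the training split and save it there -- uncomment and run them
# once if the file is missing (TODO confirm against yelp.dataset.ProjectDataset).
# train_ds = ProjectDataset.load_data_and_create_vectorizer(df.loc[df['split'] == 'train'])
# train_ds.save_vectorizer(vectorizer_path)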

In [None]:
# Single source of truth for the mini-batch size used by both loaders.
BATCH_SIZE = 128

# --- Training split ---------------------------------------------------------
# Rows are selected by the 'split' column; the dataset is built with the
# vectorizer previously saved at `vectorizer_path`.
train_df = df.loc[df['split'] == 'train']
train_ds = ProjectDataset.load_data_and_vectorizer(train_df, vectorizer_path)
vectorizer = train_ds.get_vectorizer()
# Shuffle every epoch and drop the trailing partial batch so each training
# step sees a full batch.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# --- Validation split -------------------------------------------------------
val_df = df.loc[df['split'] == 'val']
val_ds = ProjectDataset.load_data_and_vectorizer(val_df, vectorizer_path)
# Evaluation must cover every validation sample, in a fixed order:
#   * shuffle=False  -> the same samples are scored on every evaluation pass,
#                       so metrics are comparable from epoch to epoch.
#   * drop_last=False -> the trailing partial batch is kept instead of being
#                       silently discarded (previously up to BATCH_SIZE - 1
#                       samples were dropped, and a split smaller than one
#                       batch produced no batches at all).
# NOTE(review): the final validation batch may now be smaller than BATCH_SIZE;
# confirm the model handles a short batch (including a single row).
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
# The classifier's input width is the number of entries in the review vocabulary.
classifier = Classifier(num_features=len(vectorizer.review_vocab))

# Adam over all classifier parameters, with the learning rate taken from the config.
optimizer = optim.Adam(
    classifier.parameters(),
    lr=args.learning_rate,
)

# Halve the learning rate when the monitored quantity (mode='min') has not
# improved for more than one epoch.
# NOTE(review): `scheduler` is not passed to YelpTrainer in the cells visible
# here -- confirm that something actually calls scheduler.step(...).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=1,
)

# BCEWithLogitsLoss applies the sigmoid internally, so it consumes raw logits.
loss_func = nn.BCEWithLogitsLoss()

In [None]:
def bce_logits_wrapper(output):
    """Turn raw logits into hard 0/1 predictions for metric computation.

    Args:
        output: A ``(y_pred, y)`` pair, where ``y_pred`` holds raw logits and
            ``y`` holds the targets.

    Returns:
        A ``(labels, y)`` tuple. ``labels`` is a long tensor that is 1 where
        ``sigmoid(y_pred) > 0.5`` and 0 elsewhere; ``y`` is returned unchanged.
    """
    logits, targets = output
    probabilities = torch.sigmoid(logits)
    # Strict comparison: a probability of exactly 0.5 maps to class 0.
    hard_labels = probabilities.gt(0.5).long()
    return hard_labels, targets

In [None]:
# Progress bar that stays on screen after it completes (persist=True).
pbar = ProgressBar(persist=True)

# Metrics handed to the trainer:
#   'accuracy' - thresholds the logits via bce_logits_wrapper before comparing
#                them with the targets;
#   'loss'     - evaluates the same criterion that is used for training.
metrics = {
    'accuracy': Accuracy(output_transform=bce_logits_wrapper),
    'loss': Loss(loss_func),
}

In [None]:
# Bundle the model, optimisation objects, data loaders, config, progress bar
# and metrics into the project's trainer (arguments are positional; the order
# must match YelpTrainer's signature).
yelp_trainer = YelpTrainer(
    classifier,
    optimizer,
    loss_func,
    train_dl,
    val_dl,
    args,
    pbar,
    metrics,
)

In [None]:
# Invoke the trainer's run() entry point, using the objects configured in the
# cells above (see yelp.trainer.YelpTrainer.run for what it executes).
yelp_trainer.run()