In [None]:
import torch
from torch.utils.data import DataLoader
from review_dataset import ReviewDataset
from review_classifier import ReviewClassifier

In [None]:
def generate_batches(dataset, batch_size, shuffle=True, drop_last=True, device='cuda:0'):
    """
    Generator that wraps a PyTorch DataLoader and yields device-resident batches.

    Each batch produced by the DataLoader is treated as a mapping from field
    name to tensor; every tensor is moved to ``device`` before the batch is
    yielded.

    Args:
        dataset: the dataset handed to ``DataLoader``.
        batch_size: number of samples per batch.
        shuffle: forwarded to ``DataLoader``.
        drop_last: forwarded to ``DataLoader``.
        device: target passed to ``Tensor.to`` for every tensor in the batch.

    Yields:
        dict: a new dict with the same keys as the collated batch and the
        tensors relocated to ``device``.
    """
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    for batch in loader:
        yield {field: values.to(device) for field, values in batch.items()}

In [None]:
from argparse import Namespace

# Central configuration for the notebook; later cells read (and the device
# cell extends) this namespace.
args = Namespace(
    # Data and path information
    frequency_cutoff=25,
    model_state_file='yelp_clf.pth',
    review_csv='../Data/reviews_with_splits_lite.csv',
    save_dir='./model/',
    # Fixed: the file name previously read 'vectorizer,json' (comma instead of
    # the extension dot).
    vectorizer_file='vectorizer.json',
    # No model hyperparameters
    # Training hyperparameters
    batch_size=18,
    early_stopping_criteria=5,
    learning_rate=0.01,
    seed=1337,
    num_epochs=50,
    cuda=True,
)

In [None]:
def make_train_state(args):
    """
    Build a fresh bookkeeping dict for a training run.

    Args:
        args: accepted for interface compatibility; not read here.

    Returns:
        dict: ``epoch_index`` starts at 0, the per-epoch train/val loss and
        accuracy histories start as new empty lists, and the test metrics
        start at the placeholder value -1.
    """
    state = {'epoch_index': 0}
    # A new list per key so that separate states never share history.
    for history_key in ('train_loss', 'train_acc', 'val_loss', 'val_acc'):
        state[history_key] = []
    state['test_loss'] = -1
    state['test_acc'] = -1
    return state

In [None]:
def compute_accuracy(y_true, y_pred):
    """
    Percentage of predictions that equal their targets.

    Args:
        y_true: tensor of target labels.
        y_pred: tensor of predicted labels (already thresholded/rounded).

    Returns:
        float: ``100 * correct / len(y_pred)``.

    Raises:
        ZeroDivisionError: if ``y_pred`` is empty.
    """
    # torch.eq broadcasts its operands: comparing an (N, 1) prediction with an
    # (N,) target would silently produce an (N, N) matrix and an accuracy that
    # can exceed 100%. When both tensors hold the same number of elements but
    # are laid out differently, flatten them so the comparison is elementwise.
    if y_pred.shape != y_true.shape and y_pred.numel() == y_true.numel():
        y_true = y_true.reshape(-1)
        y_pred = y_pred.reshape(-1)
    correct = torch.eq(y_true, y_pred).sum().item()
    return (correct / len(y_pred)) * 100

In [None]:
# Bookkeeping dict for this run; the training, validation and test cells
# below write their losses and accuracies into it.
train_state = make_train_state(args)

In [None]:
# Fall back to the CPU when CUDA was requested but is not available.
if not torch.cuda.is_available():
    args.cuda = False
args.device = torch.device("cuda:0" if args.cuda else "cpu")

# Apply the configured seed before any model is built or any DataLoader
# shuffles, so that weight initialisation and batch order are repeatable on
# a fresh kernel. (args.seed was previously defined but never used.)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)

In [None]:
# Dataset and vectorizer: build the dataset from the review CSV configured in
# `args`, then fetch the vectorizer it exposes. The vectorizer's
# `review_vocab` is used below for the vocabulary size and index lookups.
dataset = ReviewDataset.load_dataset_and_make_vectorizer(args.review_csv)
vectorizer = dataset.get_vectorizer()

In [None]:
# Size of the review vocabulary; the next cell passes this same value to the
# classifier as `num_features`.
len(vectorizer.review_vocab)

In [None]:
# Build the classifier with `num_features` set to the vocabulary size and move
# it to the device selected earlier.
classifier = ReviewClassifier(num_features=len(vectorizer.review_vocab)).to(args.device)
print(classifier)

# Cell output: the model's parameters before any training step has run.
classifier.state_dict()

In [None]:
# BCEWithLogitsLoss applies the sigmoid internally, so the training cells feed
# it raw logits. Adam updates all classifier parameters at the configured
# learning rate.
loss_func = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=args.learning_rate)

In [None]:
# Train for `args.num_epochs` epochs. Each epoch runs one pass over the
# 'train' split (with parameter updates) followed by one pass over the 'val'
# split (no updates), and appends the epoch's mean loss / accuracy to
# `train_state`.
# NOTE(review): args.early_stopping_criteria and args.model_state_file are not
# consulted by this loop - no early stopping or checkpointing happens here.
for epoch in range(args.num_epochs):
    train_state['epoch_index'] = epoch

    # ---- training pass -------------------------------------------------
    dataset.set_split('train')
    batch_generator = generate_batches(dataset, batch_size=args.batch_size, device=args.device)

    running_loss = 0.0
    running_acc = 0.0

    classifier.train()

    for batch_index, batch_dict in enumerate(batch_generator):
        # clear gradients left over from the previous step
        optimizer.zero_grad()

        # compute logits
        y_logits = classifier(x=batch_dict['x_data'].float())

        # compute the loss and fold it into the running (incremental) mean
        loss = loss_func(y_logits, batch_dict['y_target'].float())
        loss_batch = loss.item()
        running_loss += (loss_batch - running_loss) / (batch_index + 1)

        # hard 0/1 predictions from the logits
        y_preds = torch.round(torch.sigmoid(y_logits))

        # use the loss to compute the gradients
        loss.backward()

        # optimizer takes a gradient step
        optimizer.step()

        # batch accuracy, folded into the running mean
        batch_acc = compute_accuracy(y_true=batch_dict['y_target'], y_pred=y_preds)
        running_acc += (batch_acc - running_acc) / (batch_index + 1)

    train_state['train_loss'].append(running_loss)
    train_state['train_acc'].append(running_acc)

    # ---- validation pass -----------------------------------------------
    dataset.set_split('val')
    batch_generator = generate_batches(dataset, batch_size=args.batch_size, device=args.device)

    running_loss = 0.
    running_acc = 0.
    classifier.eval()

    # No parameters are updated here, so skip autograd bookkeeping: this
    # avoids building a graph per batch and leaves the metrics unchanged.
    with torch.no_grad():
        for batch_index, batch_dict in enumerate(batch_generator):
            # compute the logits
            y_logits = classifier(x=batch_dict['x_data'].float())

            loss = loss_func(y_logits, batch_dict['y_target'].float())
            loss_batch = loss.item()

            y_preds = torch.round(torch.sigmoid(y_logits))

            running_loss += (loss_batch - running_loss) / (batch_index + 1)

            # compute acc
            batch_acc = compute_accuracy(y_true=batch_dict['y_target'], y_pred=y_preds)
            running_acc += (batch_acc - running_acc) / (batch_index + 1)

    train_state['val_loss'].append(running_loss)
    train_state['val_acc'].append(running_acc)

    # report every tenth epoch
    if epoch % 10 == 0:
        print(f"Epoch: {epoch} \n Train Loss:{train_state['train_loss'][-1]:.3f} | Train acc: {train_state['train_acc'][-1]:.3f} | Val loss: {train_state['val_loss'][-1]:.3f} | Val acc: {train_state['val_acc'][-1]:.3f}")
    

In [None]:
# Evaluate the trained classifier on the 'test' split and record the mean
# loss / accuracy in `train_state`.
dataset.set_split('test')
batch_generator = generate_batches(dataset, batch_size=args.batch_size, device=args.device)
running_loss = 0.
running_acc = 0.
classifier.eval()

# Pure evaluation: disable autograd so no graph is built for these forward
# passes. The reported numbers are unaffected.
with torch.no_grad():
    for batch_index, batch_dict in enumerate(batch_generator):
        # compute the output
        y_logits = classifier(x=batch_dict['x_data'].float())
        # compute the loss
        loss = loss_func(y_logits, batch_dict['y_target'].float())
        loss_batch = loss.item()

        # hard 0/1 predictions from the logits
        y_preds = torch.round(torch.sigmoid(y_logits))

        # incremental mean of the batch losses
        running_loss += (loss_batch - running_loss) / (batch_index + 1)
        # compute the accuracy
        acc_batch = compute_accuracy(y_true=batch_dict['y_target'], y_pred=y_preds)
        running_acc += (acc_batch - running_acc) / (batch_index + 1)

train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc

print(f"Test loss: {train_state['test_loss']:.3f} | Test acc: {train_state['test_acc']:.3f}")

In [None]:
# Sort weights: rank vocabulary indices by the first row of fc1's weight
# matrix, largest weight first.
fc1_weights = classifier.fc1.weight.detach()[0]
_, indices = torch.sort(fc1_weights.to('cpu'), dim=0, descending=True)
indices = indices.numpy().tolist()

# How many words to list per side. The loops slice instead of indexing
# 0..19, so a vocabulary with fewer than `top_k` tokens prints what it has
# rather than raising IndexError.
top_k = 20

# Top words by largest weight
print("Influential words in Positive Reviews:")
print("--------------------------------------")
for vocab_index in indices[:top_k]:
    print(vectorizer.review_vocab.lookup_index(vocab_index))

# Top words by smallest weight (list reversed in place, as before)
print("Influential words in Negative Reviews:")
print("--------------------------------------")
indices.reverse()
for vocab_index in indices[:top_k]:
    print(vectorizer.review_vocab.lookup_index(vocab_index))