In [15]:
import os
import pickle
import torch
import numpy as np
from model import SentimentGRUWithGlove, TransformerModel
from dataset import SentimentAnalysisDataset, LABEL_MAP

# Directory that holds the evaluation CSV files.
data_path = 'data'

# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device: {device}')

device: cpu


In [2]:
print('-> Loading datasets')
test_dataset = SentimentAnalysisDataset(os.path.join(data_path, 'testEmotions.csv'))

# Number of evaluation samples. This is the *test* split, so name it accordingly
# (it was previously stored as `train_size`, which was misleading and unused).
test_size = len(test_dataset)
# run_test divides by the number of samples, so an empty split is a setup error.
assert test_size > 0, 'test dataset is empty'
test_size

-> Loading datasets
-> Loading word embeddings


In [11]:
def run_test(model, dataset, device):
    """Compute the classification accuracy of ``model`` on ``dataset``.

    Args:
        model: a ``torch.nn.Module``. It is switched to eval mode (side effect)
            and called once per sample with that sample's token tensor.
            NOTE(review): the ``argmax(dim=0)`` below assumes the model returns
            a 1-D score vector for an unbatched sample -- confirm against model.py.
        dataset: an iterable of ``(tokens, label)`` tensor pairs. ``label`` is
            reduced with ``argmax(dim=0)`` to recover the class id (presumably a
            one-hot vector -- confirm against dataset.py).
        device: device the inputs are moved to; should match the model's device.

    Returns:
        Fraction (numpy float) of samples whose predicted class equals the
        true class.

    Raises:
        ValueError: if ``dataset`` yields no samples (accuracy is undefined).
    """
    y_true = []
    y_predict = []

    model.eval()
    with torch.no_grad():
        for tokens, label in dataset:
            tokens = tokens.to(device).float()
            label = label.to(device)
            y_true.append(int(label.argmax(dim=0)))

            if len(tokens) == 0:
                # Predict neutral if no token survived preprocessing,
                # e.g. the original text contained only stopwords.
                y_predict.append(int(LABEL_MAP['neutral']))
                continue

            # Forward pass. Softmax is monotonic, so the argmax of the raw
            # scores is the same class as the argmax of the probabilities.
            scores = model(tokens)
            y_predict.append(int(scores.argmax(dim=0)))

    if not y_true:
        raise ValueError('run_test: dataset yielded no samples; accuracy is undefined')

    # Vectorised element-wise comparison; the mean of a bool array is the hit rate.
    return np.mean(np.array(y_true) == np.array(y_predict))

In [13]:
# Checkpoint produced by the GRU + GloVe training run.
GRU_STATE_PATH = 'gru_A_06_03_16_34.state'

print('-> Initalizing GRU model')
# NOTE(review): 100 is passed positionally to the model constructor; its meaning
# (embedding/hidden size?) is defined in model.py -- confirm there.
model = SentimentGRUWithGlove(100)
model.to(device)

print('-> Loading model state')
# map_location remaps tensors that were saved on another device (e.g. a CUDA
# training machine) onto the current `device`, so the checkpoint also loads on a
# CPU-only kernel. torch.load unpickles the file: only load trusted checkpoints.
model.load_state_dict(torch.load(GRU_STATE_PATH, map_location=device))

print('-> Running accuracy test')
acc = run_test(model, test_dataset, device)
print('Test accuracy:', acc)

-> Initalizing GRU model
Test accuracy: 0.5209840810419681


In [17]:
# Checkpoint produced by the transformer classifier training run.
TRANSFORMER_STATE_PATH = 'transformer_classifier_C_06_03_19_55.state'

print('-> Initalizing transformer model')
# NOTE(review): 100 is passed positionally to the model constructor; its meaning
# is defined in model.py -- confirm there.
model = TransformerModel(100)
model.to(device)

print('-> Loading model state')
# map_location remaps tensors that were saved on another device (e.g. a CUDA
# training machine) onto the current `device`, so the checkpoint also loads on a
# CPU-only kernel. torch.load unpickles the file: only load trusted checkpoints.
model.load_state_dict(torch.load(TRANSFORMER_STATE_PATH, map_location=device))

print('-> Running accuracy test')
acc = run_test(model, test_dataset, device)
print('Test accuracy:', acc)

-> Initalizing transformer model
-> Loading model state
-> Running accuracy test
