In [None]:
import importlib
import numpy as np
import os
import pickle
import sys
import torch

from transformers import (DistilBertForSequenceClassification, 
                          DistilBertTokenizer)

# Our code imports
# Put '../src' (resolved against the current working directory) at the front
# of the module search path so the project modules below can be imported.
sys.path.insert(0, os.path.join(os.getcwd(), '..', 'src'))
import train_eval
import synonym

# Re-execute the project modules so that edits made to their source files are
# picked up when this cell is re-run, without restarting the kernel.
importlib.reload(synonym)
importlib.reload(train_eval)

## Set Up Data for Training

In [None]:
# Name of the review corpus to train on; passed as `source=` to
# train_eval.ReviewDataset in the data-loading cell below.
DATASET = "imdb"

In [None]:
# ---- Training data ----
# Read the raw training reviews and their labels for the chosen corpus.
imdb = train_eval.ReviewDataset(source=DATASET)
train_sentences, train_labels = imdb.reviewsAndLabels(test_train="train")

# ---- Tokenizer and classifier ----
# Both are initialised from the same pretrained cased DistilBERT checkpoint.
pretrained_weights = 'distilbert-base-cased'
model_class = DistilBertForSequenceClassification
tokenizer = DistilBertTokenizer.from_pretrained(pretrained_weights)
model = model_class.from_pretrained(
    pretrained_weights,
    num_labels=2,                # two output classes
    output_attentions=True,      # model outputs include attention weights
    output_hidden_states=False,  # per-layer hidden states are not returned
)

# Build the training and validation sets from the raw sentences and labels.
# NOTE(review): 256 and 0.2 look like a maximum token length and a validation
# fraction -- confirm against ReviewDataset.setUpData.
train_data, validation_data = train_eval.ReviewDataset.setUpData(
    train_sentences, train_labels, tokenizer, 256, 0.2
)

# Sanity check: print the first and last training examples.
# Indexing with a list assumes both containers support it (e.g. NumPy arrays)
# -- TODO confirm the types returned by reviewsAndLabels.
first_and_last = [0, -1]
for label, sent in zip(train_labels[first_and_last], train_sentences[first_and_last]):
    print("label: {}".format(label))
    print(sent, '\n')

## Training Loop

In [None]:
# Fine-tune the classifier. `train_eval.train` is given the training and
# validation sets and returns a (losses, model) pair; `model` is rebound to
# the returned model.
losses, model = train_eval.train(model, 
                                 train_data, 
                                 validation_data, 
                                 batch_size=8, 
                                 epochs=1, 
                                 lr=3e-5, # from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py
                                 adam_eps=1e-8)

# Persist the fine-tuned model under the file name the evaluation cell loads
# ('<dataset>_distil.model' in the working directory). Without this, a fresh
# Restart & Run All has no model file to evaluate. The evaluation cell expects
# one file per entry in its `datasets` list, so re-run the training cells with
# DATASET set to each of those names to produce all of them.
torch.save(model, '{}_distil.model'.format(DATASET))

## Evaluation Pipeline

In [None]:
# Corpora to evaluate. Each name is passed as `source=` to ReviewDataset and is
# also used to locate the saved model file '<name>_distil.model'.
datasets = ['imdb', 'yelp']

In [None]:
# Mean accuracy per corpus, collected in the same order as `datasets`.
accuracies = []

# Tokenizer from the same pretrained checkpoint name used for training setup.
pretrained_weights = 'distilbert-base-cased'
tokenizer = DistilBertTokenizer.from_pretrained(pretrained_weights)

for dataset_name in datasets:
    # Test-split reviews and labels for this corpus.
    # NOTE(review): the name `imdb` is rebound for every corpus in the loop,
    # so after the loop it refers to the last dataset, not necessarily IMDB.
    imdb = train_eval.ReviewDataset(source=dataset_name)
    test_sentences, test_labels = imdb.reviewsAndLabels(test_train="test")

    # Only the first value returned by setUpData is used; the second is discarded.
    evaluation_data, _ = train_eval.ReviewDataset.setUpData(
        test_sentences, test_labels, tokenizer, 256
    )

    # Load the saved model for this corpus from the working directory.
    # NOTE: torch.load unpickles the file -- only load model files you trust.
    model_path = '{}_distil.model'.format(dataset_name)
    model = torch.load(model_path)

    # Evaluate and record the mean of the first element of the result.
    # NOTE(review): acc[0] presumably holds per-batch accuracies and 128 is the
    # evaluation batch size -- confirm against train_eval.evaluate.
    acc = train_eval.evaluate(model, evaluation_data, 128)
    mean_accuracy = np.mean(acc[0])
    print("{} accuracy: {}".format(dataset_name, mean_accuracy))
    accuracies.append(mean_accuracy)
    

In [None]:
# Display the collected mean accuracies (one per entry in `datasets`, same order).
accuracies