In [None]:
import glob
import importlib
import numpy as np
import os
import pickle
import sys
import time
import torch

from transformers import (DistilBertForSequenceClassification, 
                          DistilBertTokenizer)

# Our code imports
# Put ../src (resolved against the current working directory) at the front of
# the module search path so the project modules below can be imported.
sys.path.insert(0, os.path.join(os.getcwd(), '..', 'src'))
import train_eval
import synonym

# Reload both project modules so that re-running this cell picks up edits made
# to them on disk without restarting the kernel.
importlib.reload(synonym)
importlib.reload(train_eval)

## Model Inputs 

In [None]:
# Name of the pretrained DistilBERT vocabulary/weights to pull the tokenizer from.
pretrained_weights = 'distilbert-base-cased'
# Sequence-length setting for model inputs (presumably counted in tokens --
# confirm against train_eval).
max_seq = 256
# Path of the saved classifier checkpoint.
model_path = "imdb_distil.model"

tokenizer = DistilBertTokenizer.from_pretrained(pretrained_weights)

## Set up data for model

In [None]:
# Build the project's review dataset wrapper for the "imdb" source and request
# the sentences and labels selected by test_train="test".
imdb = train_eval.ReviewDataset(source="imdb")
test_sentences, test_labels = imdb.reviewsAndLabels(test_train="test")

## Generate Hotwords

In [None]:
# Build the attack helper with the sequence length from the "Model Inputs"
# cell (max_seq) instead of repeating the literal 256, so changing the
# configuration in one place is enough.
s_attacks = synonym.SynonymAttacks(model_path, tokenizer, max_seq)

hws = []
softmax_changes = []

# generateHotWords is iterated as a stream of (hot_words, softmax_change)
# pairs; appending as they arrive keeps the partial results in `hws` /
# `softmax_changes` if the (long) run is interrupted.
for hw, sm in s_attacks.generateHotWords(test_sentences, test_labels, train_no=25000,
                                         method="blank"):
    hws.append(hw)
    softmax_changes.append(sm)


In [None]:
# # Save output
# with open("distil_hw_sm.p", "wb") as fh:
#     pickle.dump([hws, softmax_changes], fh)

# Load output
# NOTE: unpickling can execute arbitrary code -- only load files that this
# notebook produced itself.
# The context manager closes the file handle deterministically instead of
# leaving it to the garbage collector.
with open('distil_hw_sm.p', 'rb') as fh:
    hw_sm = pickle.load(fh)
hws = hw_sm[0]
softmax_changes = hw_sm[1]

In [None]:
# Flatten the per-review lists into one sequence of hot words and one parallel
# sequence of softmax changes.
hot_words = np.array([item for hw in hws for item in hw])
diffs = [item for sm in softmax_changes for item in sm]
# Unique hot words. Sorting makes the vocabulary order identical on every run;
# plain set iteration order for strings changes between kernel restarts because
# of hash randomisation, which made downstream rankings/pickles non-reproducible.
vocab = sorted(set(hot_words))

In [None]:
word_weights = []
# Encode the words to ids. The first and last ids are sliced off below --
# presumably the special tokens that encode() adds around the sequence; confirm
# against the tokenizer version in use.
hot_words_encoding = np.array(tokenizer.encode(list(hot_words)))
vocab_encoding = tokenizer.encode(vocab)

# For every vocabulary id, average the softmax changes of all occurrences of
# that id among the hot words (non-matching positions are masked with NaN).
for word in vocab_encoding[1:-1]:
    word_weights.append(np.nanmean(np.where(hot_words_encoding[1:-1] == word, diffs, np.nan)))

# Positive weight means it was made more positive - negative weight means it was made more negative
word_weights = np.array(word_weights)
# Standardise with NaN-aware statistics for both the centre and the spread.
# The previous mix of np.mean with np.nanstd turned every z-score into NaN as
# soon as a single weight was NaN.
word_weights_z = (word_weights - np.nanmean(word_weights)) / np.nanstd(word_weights)

In [None]:
# Indices that order the z-scored weights ascending.
word_sorts = word_weights_z.argsort()
# Vocabulary rearranged into that same order.
ranked_words = np.asarray(vocab)[word_sorts]

In [None]:
# Save output (uncomment to write the file that the "Hot Words - Gamma Search" section loads)
# pickle.dump((vocab, word_weights_z, word_sorts), open("distil-vcb_wt_sort.p", "wb"))

## Hot Words - Gamma Search

In [None]:
# From our previously calculated hot words
# Per the (commented-out) save cell, the file holds
# (vocab, word_weights_z, word_sorts), so `word_weights` here receives the
# z-scored weights.
# NOTE: unpickling can execute arbitrary code -- only load files that this
# notebook produced itself.
# The context manager closes the file handle instead of leaking it.
with open("distil-vcb_wt_sort.p", "rb") as fh:
    vocab, word_weights, word_sorts = pickle.load(fh)
# Map each vocabulary word to its weight.
word_dist = dict(zip(vocab, word_weights))

In [None]:
datasets = ['imdb']
# Values passed as `replacements` to the synonym attack, one evaluation each.
gammas = [0, 10, 20, 30, 40, 50, 60, 80, 100, 120]
# dataset name -> list of accuracies, in the same order as `gammas`.
overall_acc = dict()

# The tokenizer does not depend on the dataset, so load it once rather than
# on every loop iteration.
pretrained_weights = 'distilbert-base-cased'
tokenizer = DistilBertTokenizer.from_pretrained(pretrained_weights)

for dataset in datasets:
    model_path = "{}_distil.model".format(dataset)

    model = torch.load(model_path)

    imdb = train_eval.ReviewDataset(source=dataset)
    test_sentences, test_labels = imdb.reviewsAndLabels(test_train="test")

    # Instantiate attack class with the configured sequence length (max_seq
    # from the "Model Inputs" cell) instead of a repeated literal 256.
    s_attacks = synonym.SynonymAttacks(model_path, tokenizer, max_seq)

    accuracies = []
    # Create adversarial examples
    for gamma in gammas:
        adv_data, adv_label = s_attacks.generateSynonymReviews(test_sentences, test_labels,
                                                               replacements=gamma,
                                                               hot_word_distribution=word_dist,
                                                               method="random")

        # Second return value of setUpData is not needed here.
        evaluation_data, _ = train_eval.ReviewDataset.setUpData(adv_data,
                                                                adv_label,
                                                                tokenizer, max_seq)

        # 128 is forwarded as the third argument of evaluate (presumably the
        # batch size -- confirm in train_eval); the first element of its
        # result is averaged to obtain a single accuracy figure.
        acc = train_eval.evaluate(model, evaluation_data, 128)
        acc = np.mean(acc[0])
        print("for gamma = {}, accuracy: {}".format(gamma, acc))
        accuracies.append(acc)

    overall_acc[dataset] = accuracies

In [None]:
# Display the accuracy list (one entry per gamma) left over from the last
# dataset iterated in the cell above; per-dataset results are in overall_acc.
accuracies