In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# project .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook, at retina (high-DPI) resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
import seaborn as sns
# Print floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)
# NOTE(review): the star import hides where names come from. Helpers that are called
# in this notebook without a visible definition (e.g. filterbylength, plot_entropy)
# presumably come from Utils -- confirm.
from Utils import *

In [None]:
import sys
# Put preprocess/ first on the module search path so `vectorizer` can be imported.
# Presumably the import is also required for unpickling the vectorizer object
# loaded in the next cell -- TODO confirm.
sys.path.insert(0, 'preprocess/')
import vectorizer

import pickle

In [None]:
# Load the preprocessed IMDB vectorizer object from disk.
# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
with open('preprocess/IMDB/imdb_data.p', 'rb') as vec_file :
    vec = pickle.load(vec_file)
# Presumably augments `vec` with word-frequency information computed from the
# training sequences -- confirm against Utils.add_frequencies.
add_frequencies(vec, vec.seq_text['train'])

In [None]:
import model.Attn_Word_Pert as M
# Short alias for the model class instantiated by train() and load_model() below.
Model = M.Model

In [None]:
# Pull the sequences and labels for each split out of the vectorizer, one split per line.
X, y = vec.seq_text['train'], vec.label['train']
Xt, yt = vec.seq_text['test'], vec.label['test']
Xr, yr = vec.seq_text['rem'], vec.label['rem']

In [None]:
# Train split: only passed through the length filter (min_length=6).
X, y = filterbylength(X, y, min_length=6)

# Test split: length filter, then sortbylength.
Xt, yt = filterbylength(Xt, yt, min_length=6)
Xt, yt = sortbylength(Xt, yt)

# 'rem' split: same treatment as the test split.
Xr, yr = filterbylength(Xr, yr, min_length=6)
Xr, yr = sortbylength(Xr, yr)

In [None]:
# Positive-class weight, currently fixed at 1. The commented-out alternative equals the
# negatives-to-positives ratio, assuming y holds 0/1 labels -- TODO confirm.
# NOTE(review): pos_weight is not referenced by any other visible cell.
pos_weight = 1 #len(y)/sum(y) - 1

In [None]:
from sklearn.metrics import classification_report, f1_score

def train(name='', n_epochs=5, save_best_only=False) :
    """Train a fresh Model on (X, y) and evaluate it on (Xt, yt) after every epoch.

    Reads the notebook globals ``vec``, ``X``, ``y``, ``Xt`` and ``yt``.

    Parameters
    ----------
    name : str
        Forwarded to ``model.save_values`` as ``add_name``.
    n_epochs : int
        Number of ``model.train`` passes (default 5, the previously hard-coded value).
    save_best_only : bool
        If False (default, the previous behaviour) the model is saved after every
        epoch. If True, the model is saved only when the F1 score on the test split
        beats the best score seen so far in this call.

    Returns
    -------
    The Model instance in its state after the last epoch.
    """
    model = Model(vec.vocab_size, vec.word_dim, 64, dirname='imdb', hidden_size=128, pre_embed=vec.embeddings)
    best_f1 = 0.0
    for epoch in tqdm_notebook(range(n_epochs)) :
        loss = model.train(X, y)
        # Only the first value returned by evaluate is needed here.
        outputs, _ = model.evaluate(Xt)
        # Threshold once and reuse for both metrics.
        preds = np.array(outputs) > 0.5
        rep = classification_report(yt, preds)
        f1 = f1_score(yt, preds, pos_label=1)
        print(rep)

        save_model = (not save_best_only) or f1 > best_f1
        best_f1 = max(best_f1, f1)
        # save_values is called either way because it yields the directory for the epoch log.
        dirname = model.save_values(add_name=name, save_model=save_model)
        print("Model Saved" if save_model else "Model not saved", f1)

        # Append this epoch's loss and report; the context manager closes the file
        # even if a write fails.
        with open(dirname + '/epoch.txt', 'a') as log_file :
            log_file.write('%s, %s' % (epoch, loss) + '\n')
            log_file.write(rep + '\n')

    return model

In [None]:
# Run training. The returned model is not bound to a name here; the evaluation
# section below loads a saved run from disk instead.
train(name='TEST_imdb')

# **EVALUATION**

In [None]:
def load_model(dirname) :
    """Construct a Model and restore its saved values from `dirname`.

    Reads the notebook global ``vec`` for the constructor arguments. The returned
    model has its ``dirname`` attribute set to the given directory.
    """
    # NOTE(review): the third positional argument is 32 here but 64 in train();
    # confirm the difference is intentional.
    model_kwargs = dict(dirname='imdb', hidden_size=128, pre_embed=vec.embeddings)
    restored = Model(vec.vocab_size, vec.word_dim, 32, **model_kwargs)
    restored.dirname = dirname
    restored.load_values(dirname)
    return restored

In [None]:
# Load a previously saved run.
# NOTE(review): hard-coded, timestamped output directory -- presumably written by an
# earlier model.save_values call; it must be updated to a directory that exists locally.
model = load_model('outputs/attn_word_imdb/MonOct1510:24:342018_TEST_imdb/')

In [None]:
# Evaluate the loaded model on the test split. yt_hat and attn_hat are reused by the
# analysis cells below (presumably predictions and attention weights -- confirm in Utils).
yt_hat, attn_hat = evaluate_and_print(model, Xt, yt)

In [None]:
# Plot produced from the test sequences and attn_hat (presumably attention entropy -- see Utils.plot_entropy).
plot_entropy(Xt, attn_hat)

# __SAMPLING__

In [None]:
# Attach the vectorizer to the model (presumably read inside sampling_top -- confirm),
# then run the sampling routine on the test split with sample_vocab=100.
model.vec = vec
sampled_output = model.sampling_top(Xt, sample_vocab=100)

In [None]:
import pickle
# Cache the sampling results inside the model's directory so they can be reloaded
# without re-running the sampling cell. The context manager guarantees the file is
# flushed and closed.
with open(model.dirname + '/sampled.p', 'wb') as sampled_file :
    pickle.dump(sampled_output, sampled_file)

In [None]:
# Reload the cached sampling results written by the previous cell.
# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
with open(model.dirname + '/sampled.p', 'rb') as sampled_file :
    sampled_output = pickle.load(sampled_file)

In [None]:
# Summarise the sampling results against the original attentions (see Utils for details).
generate_medians_from_sampling_top(sampled_output, attn_hat)

In [None]:
# Helper from Utils operating on the sampling results and attn_hat; its output is displayed as the cell result.
get_distractors(sampled_output, attn_hat)

In [None]:
# Inspect a single test example:
#   n   -- index into Xt / attn_hat / the per-example sampling arrays
#   w   -- index into words_sampled[n] and axis 1 of perts_attn[n]
#   pos -- index into best_attn_idxs[n] and axis 2 of perts_attn[n]
# NOTE(review): best_attn_idxs, words_sampled and perts_attn are not defined in any
# visible cell -- presumably unpacked from sampled_output or exported by Utils. Confirm,
# otherwise this cell fails on a fresh kernel.
n = 31
w = 0
pos = 0
plot_diff(vec.map2words(Xt[n]), 
          best_attn_idxs[n][pos], 
          vec.idx2word[int(words_sampled[n][w])], 
          attn_hat[n][:], 
          perts_attn[n][:, w, pos])

**Gradients**
=============

In [None]:
# Compute gradients on the test split, then post-process them. The return value of
# process_grads is discarded, so it presumably modifies `grads` in place -- confirm.
grads = model.gradient_mem(Xt)
process_grads(grads)

In [None]:
# Plot built from the test sequences, attn_hat and the gradients computed above.
plot_grads(Xt, attn_hat, grads)

**Permutation**
===========

In [None]:
# Run the model's attention-permutation routine on the test split.
perms = model.permute_attn(Xt)

In [None]:
# Plot the permutation results against the unpermuted test outputs (yt_hat, attn_hat).
plot_permutations(perms, Xt, yt_hat, attn_hat)

**Adversarial**
===========

In [None]:
# Run the model's adversarial routine on the test split with _type='uniform' and unpack
# its two results (presumably adversarial outputs and attentions -- confirm in the model).
adversarial_outputs = model.adversarial(Xt, _type='uniform')
ad_y, ad_attn = adversarial_outputs

In [None]:
# yt_hat, attn_hat and adversarial_outputs were all computed on the test split, so the
# sequences passed alongside them must be Xt (the original passed the training set X).
jds = plot_adversarial(Xt, yt_hat, attn_hat, adversarial_outputs)

In [None]:
# Indices of (at most) the first 30 examples whose jds value exceeds 0.1 and whose
# first output column in yt_hat exceeds 0.9.
idx = list(np.where(np.logical_and(np.array(jds) > 0.1, yt_hat[:, 0] > 0.9))[0])[:30]
idx

In [None]:
# Show one test example with its original and adversarial attentions. attn_hat, ad_attn,
# yt_hat and ad_y are indexed over the test split, so the words must come from Xt[n]
# (the original read the training set X[n]).
n = 248
print_adversarial_example(vec.map2words(Xt[n]), attn_hat[n], ad_attn[n])
print(yt_hat[n], ad_y[n])

In [None]:
# Run the model's remove_and_run routine on the test split.
remove_outputs = model.remove_and_run(Xt)

In [None]:
# Plot comparing the original test outputs (yt_hat) with remove_outputs.
plot_y_diff(Xt, attn_hat, yt_hat, remove_outputs)

In [None]:
# 100x100 coordinate grids over [0, 0.99] in steps of 0.01.
# NOTE(review): this rebinds X, which until here held the training sequences; any
# earlier cell that reads X will misbehave if re-run after this one. Consider
# renaming these grids. X and Y are not used by the cells visible below.
X, Y = np.meshgrid(np.arange(100)/100, np.arange(100)/100)

In [None]:
# Z[i, j] holds jsd_bern evaluated at (i/100, j/100);
# W[i, j] holds the absolute index difference |i - j|.
Z = np.zeros((100, 100))
W = np.zeros((100, 100))

# np.ndindex walks the grid in row-major order, matching a nested i/j loop.
for i, j in np.ndindex(100, 100) :
    Z[i, j] = jsd_bern(i/100, j/100)
    W[i, j] = abs(i-j)

In [None]:
# Display Z and W as matrix images, one figure each.
plt.matshow(Z)
plt.matshow(W)