# DisSent Books 5 Error Analysis

We analyze the full model because it performs the best.

In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import os
import torch
# Load the trained model checkpoint and keep it on the CPU for analysis.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
# map_location remaps every tensor storage to the CPU while unpickling, so the
# checkpoint can be loaded on a machine without a GPU even if it was saved
# from one; the callable form is accepted by old and new PyTorch releases.
dis_net = torch.load(os.path.join('./exp/books5_words_4096_sgd_01_d0_fcd0', 'dis-model-11'  + ".pickle"),
                     map_location=lambda storage, loc: storage)
dis_net.cpu()
# Both single-sentence flags are switched off for this analysis
# (presumably so the model uses both s1 and s2 -- TODO confirm in the model code).
dis_net.s1_only = False
dis_net.s2_only = False

In [5]:
"""
DATA
"""
import numpy as np
from data import get_dis, get_batch, build_vocab, get_filtered_dis

# Load the train/valid/test splits and build word_vec from every sentence in
# them together with the GloVe file.
train, valid, test = get_dis('/home/anie/DisExtract/data/books',
                             'discourse_EN_FIVE_and_but_because_if_when_2017dec12')
word_vec = build_vocab(train['s1'] + train['s2'] +
                       valid['s1'] + valid['s2'] +
                       test['s1'] + test['s2'], '/home/anie/glove/glove.840B.300d.txt')

# Tokenize every sentence on whitespace, wrap it in <s> ... </s>, and drop
# words that are not in word_vec outright (they are NOT mapped to <unk>).
# The split dicts are referenced directly instead of through eval(), and
# dtype=object is explicit so sentences of different lengths are stored as an
# array of token lists (recent NumPy refuses to build ragged arrays implicitly).
for split in ['s1', 's2']:
    for dataset in (train, valid, test):
        dataset[split] = np.array([['<s>'] +
                                   [word for word in sent.split() if word in word_vec] +
                                   ['</s>'] for sent in dataset[split]], dtype=object)

We can use the following function to get an example of a given marker

In [6]:
from util import get_labels
# Labels for the 'books_5' setting; get_example_by_idx uses a marker's position
# in this list (dis_labels.index(marker)) as its integer label.
dis_labels = get_labels('books_5')

In [7]:
def get_example_by_idx(marker, idx, data_type='valid'):
    """Return the idx-th example (0-based) whose label is `marker`.

    Args:
        marker: discourse marker string; it is looked up in the notebook-level
            `dis_labels` list to obtain the integer label to match.
        idx: how many matching examples to skip before returning one.
        data_type: which split to search: 'train', 'valid' or 'test'.

    Returns:
        A tuple (s1, s2, label) taken from the same position of the split's
        's1', 's2' and 'label' entries, or None when the split has fewer than
        idx + 1 examples with that label (or idx is negative).

    Raises:
        ValueError: if `marker` is not in `dis_labels`, or `data_type` is not
            one of the three split names.
    """
    marker_idx = dis_labels.index(marker)

    # Resolve the split through an explicit table instead of eval(): no
    # arbitrary string is executed and the lookup happens once, not on every
    # loop iteration.
    splits = {'train': train, 'valid': valid, 'test': test}
    if data_type not in splits:
        raise ValueError("data_type must be 'train', 'valid' or 'test', got %r" % (data_type,))
    data = splits[data_type]

    remaining = idx
    for s1, s2, label in zip(data['s1'], data['s2'], data['label']):
        if label != marker_idx:
            continue
        if remaining == 0:
            return s1, s2, label
        remaining -= 1

    # No such example; made explicit so callers can see the contract.
    return None

In [13]:
# Fetch the first (idx 0) example labelled 'but' from the default 'valid' split.
res = get_example_by_idx('but', 0)

In [14]:
# Echo the retrieved example: (s1 tokens, s2 tokens, label).
res

(['<s>',
  'They',
  'were',
  'less',
  'strict',
  'theyre',
  'good',
  'parents',
  '.',
  '</s>'],
 ['<s>', 'Thats', 'what', 'really', 'matters', '.', '</s>'],
 2)

## What is the interpretation method?

We use the [contextual decomposition](https://arxiv.org/pdf/1801.05453.pdf) technique. The core idea of this method is to decompose every hidden state into two parts:

$h_t = \mathrm{rel}_t + \mathrm{irrel}_t$ for all t

We break down each hidden state into a "relevant" vector and an "irrelevant" vector. We are working with a special case where we only look at each word's contribution to the final prediction, ignoring the interactions between words.

After breaking each hidden state down, we then use $\mathrm{rel}_t \times \frac{\partial \hat y_i}{\partial h_t}$, i.e. the relevant vector weighted by the gradient of the prediction with respect to the hidden state, as its influence on the final prediction.


In [None]:
from viz import MaxPoolingCDBiLSTM

# Build the visualizer (from viz) around the loaded model; its
# visualize_example method is what the later cells call.
bilstm = MaxPoolingCDBiLSTM(model=dis_net, glove_path="/home/anie/glove/glove.840B.300d.txt", bilstm=True)
# Overwrite word_vec on both the wrapper and the model's encoder with the
# word_vec built in the data cell (presumably so they agree with the
# vocabulary used to filter the sentences -- TODO confirm).
bilstm.word_vec = word_vec
bilstm.model.encoder.word_vec = word_vec

We first take a look at some more successful markers like `but`, `and`, and `if`

In [15]:
# Show 'but' examples 0-9 from the default split.
for i in range(0, 10):
    res = get_example_by_idx('but', i)
    rendered = bilstm.visualize_example(res[0], res[1], res[2], dis_labels)
    display(*rendered)

In [17]:
# Show only 'but' example 19 (the range holds a single index).
for i in range(19, 20):
    res = get_example_by_idx('but', i)
    rendered = bilstm.visualize_example(res[0], res[1], res[2], dis_labels)
    display(*rendered)

In [18]:
# Show 'but' examples 20-29 from the default split.
for i in range(20, 30):
    res = get_example_by_idx('but', i)
    rendered = bilstm.visualize_example(res[0], res[1], res[2], dis_labels)
    display(*rendered)

In [19]:
# Show 'if' examples 0-9 from the default split.
for i in range(0, 10):
    res = get_example_by_idx('if', i)
    rendered = bilstm.visualize_example(res[0], res[1], res[2], dis_labels)
    display(*rendered)

In [20]:
# Show 'if' examples 10-19 from the default split.
for i in range(10, 20):
    res = get_example_by_idx('if', i)
    rendered = bilstm.visualize_example(res[0], res[1], res[2], dis_labels)
    display(*rendered)

We can zoom in on the `because` marker

In [8]:
import json
# Load the pre-computed example lists. Each file is opened in a `with` block so
# its handle is closed as soon as it has been parsed (the bare open(...) calls
# left the handles open).
# The entries of `correct` are indexed as [0], [1], [2] and passed to
# visualize_example in the cells that follow; the other two lists presumably
# share that layout -- TODO confirm.
with open('./type_one_error_list.json', 'rb') as f:
    type_one = json.load(f)
with open('./type_two_error_list.json', 'rb') as f:
    type_two = json.load(f)
with open('./correct_list.json', 'rb') as f:
    correct = json.load(f)

In [11]:
# Show entries 0-9 of the `correct` list.
for i in range(0, 10):
    entry = correct[i]
    rendered = bilstm.visualize_example(entry[0], entry[1], entry[2], dis_labels)
    display(*rendered)

In [12]:
# Show entries 10-19 of the `correct` list.
for i in range(10, 20):
    entry = correct[i]
    rendered = bilstm.visualize_example(entry[0], entry[1], entry[2], dis_labels)
    display(*rendered)

In [None]:
# Show entries 20-29 of the `correct` list.
for i in range(20, 30):
    entry = correct[i]
    rendered = bilstm.visualize_example(entry[0], entry[1], entry[2], dis_labels)
    display(*rendered)