## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
from path_explain import utils
utils.set_up_environment(visible_devices='2')

In [4]:
import tensorflow as tf
import tensorflow_datasets
import numpy as np
import pandas as pd
import altair as alt
import scipy
from bert_explainer import BertExplainerTF
from transformers import *
import transformers

## Data and Model Loading

In [5]:
task = 'sst-2'
num_labels = len(glue_processors[task]().get_labels())

In [6]:
config = BertConfig.from_pretrained('.', num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = TFBertForSequenceClassification.from_pretrained('.', config=config)

In [7]:
data, info = tensorflow_datasets.load('glue/sst2', with_info=True)

INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset glue (/homes/gws/psturm/tensorflow_datasets/glue/sst2/0.0.2)
INFO:absl:Constructing tf.data.Dataset for split None, from /homes/gws/psturm/tensorflow_datasets/glue/sst2/0.0.2


In [8]:
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, max_length=128, task=task)
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, max_length=128, task=task)
valid_dataset = valid_dataset.batch(16)

## Model Evaluation

In [9]:
valid_pred = model.predict(valid_dataset)

In [10]:
valid_labels = []
for batch in valid_dataset:
    valid_labels.append(batch[1].numpy())
valid_labels = np.concatenate(valid_labels, axis=0)

In [11]:
valid_pred_max = np.argmax(valid_pred, axis=-1)
accuracy = np.sum(valid_pred_max == valid_labels) / len(valid_labels)

positive_mask = valid_labels == 1
positive_accuracy = np.sum(valid_pred_max[positive_mask] == valid_labels[positive_mask]) / np.sum(positive_mask)

negative_mask = valid_labels == 0
negative_accuracy = np.sum(valid_pred_max[negative_mask] == valid_labels[negative_mask]) / np.sum(negative_mask)

print('Validation Accuracy: {:.4f}'.format(accuracy))
print('Positive Sentiment Accuracy: {:.4f}'.format(positive_accuracy))
print('Negative Sentiment Accuracy: {:.4f}'.format(negative_accuracy))

Validation Accuracy: 0.8956
Positive Sentiment Accuracy: 0.9527
Negative Sentiment Accuracy: 0.8364


## Displaying Sentences

In [12]:
batch_conf = valid_pred[-8:]
batch_conf = scipy.special.softmax(batch_conf, axis=-1)
for i in range(batch[0]['input_ids'].shape[0]):
    encoded_sentence = batch[0]['input_ids'].numpy()[i]
    encoded_sentence = encoded_sentence[encoded_sentence != 0]
    label = batch[1][i].numpy()
    print(tokenizer.decode(encoded_sentence))
    print('This sentence is {} (predicted confidence: {:.4f})'.format('positive' if label == 1 else 'negative', batch_conf[i, label]))
    print('-------------------------')

[CLS] pumpkin takes an admirable look at the hypocrisy of political correctness, but it does so with such an uneven tone that you never know when humor ends and tragedy begins. [SEP]
This sentence is negative (predicted confidence: 0.5426)
-------------------------
[CLS] is the time really ripe for a warmed - over james bond adventure, with a village idiot as the 007 clone? [SEP]
This sentence is negative (predicted confidence: 0.8609)
-------------------------
[CLS] exquisitely nuanced in mood tics and dialogue, this chamber drama is superbly acted by the deeply appealing veteran bouquet and the chilling but quite human berling. [SEP]
This sentence is positive (predicted confidence: 0.9992)
-------------------------
[CLS] the movie's relatively simple plot and uncomplicated morality play well with the affable cast. [SEP]
This sentence is positive (predicted confidence: 0.9983)
-------------------------
[CLS] dense with characters and contains some thrilling moments. [SEP]
This sentenc

## Model Interpretation

In [13]:
explainer = BertExplainerTF(model)

In [52]:
batch_input_ids = tf.cast(batch[0]['input_ids'], tf.float64)
batch_labels   = batch[1]
batch_baseline = np.zeros((1, 128)).astype(np.float64)

In [51]:
index = 4
mask = batch_input_ids[index] > 0

In [50]:
for a in attributions[index][mask]:
    print('{:.2f}'.format(a), end=' ')

-140.07 -288.91 72.79 -62.19 76.10 -60.76 87.47 -163.04 50.96 -27.59 -0.30 7.17 

In [49]:
np.sum(attributions[index])

-448.3723403776483

In [45]:
attributions = explainer.attributions(inputs=batch_input_ids,
                                      baseline=batch_baseline,
                                      batch_size=10,
                                      num_samples=100,
                                      use_expectation=False,
                                      output_indices=batch_labels,
                                      verbose=True)




  0%|          | 0/8 [00:00<?, ?it/s][A[A[A


 12%|█▎        | 1/8 [00:04<00:28,  4.07s/it][A[A[A


 25%|██▌       | 2/8 [00:07<00:23,  3.87s/it][A[A[A


 38%|███▊      | 3/8 [00:10<00:18,  3.71s/it][A[A[A


 50%|█████     | 4/8 [00:14<00:14,  3.60s/it][A[A[A


 62%|██████▎   | 5/8 [00:17<00:10,  3.52s/it][A[A[A


 75%|███████▌  | 6/8 [00:20<00:06,  3.47s/it][A[A[A


 88%|████████▊ | 7/8 [00:24<00:03,  3.43s/it][A[A[A


100%|██████████| 8/8 [00:27<00:00,  3.44s/it][A[A[A


## Prediction on Custom Sentences

In [16]:
sentences = [(0, 'This movie was too long, but overall a good movie nonetheless'),
             (1, 'This movie was not bad')]
input_sen = [transformers.data.InputExample(guid=x[0],
                                            text_a=x[1],
                                            label='0') for x in sentences]
examples = glue_convert_examples_to_features(input_sen,
                                             tokenizer,
                                             max_length=128,
                                             task=task)
input_ids = np.array([example.input_ids for example in examples])
predicted_logits = model(input_ids)[0]
predicted_confidence = scipy.special.softmax(predicted_logits, axis=-1)
predicted_labels = np.argmax(predicted_confidence, axis=1)
for i in range(len(sentences)):
    print(sentences[i][1])
    print('The above sentence is {} with confidence: {:.4f}'.format(
        'positive' if predicted_logits[i, 1] > predicted_logits[i, 0] else 'negative',
        predicted_confidence[i, predicted_labels[i]]
    ))
    print('-------------------')

This movie was too long, but overall a good movie nonetheless
The above sentence is positive with confidence: 0.9545
-------------------
This movie was not bad
The above sentence is positive with confidence: 0.9485
-------------------


## Getting model output
Note that there is a difference between specifying and not specifying the attention mask.
The explanation code does specify the attention mask as the nonzero entries, but
I need to confirm that doing so makes sense

In [73]:
model(batch[0])

(<tf.Tensor: id=241221, shape=(8, 2), dtype=float32, numpy=
 array([[-0.08151681, -0.25236744],
        [ 0.7073199 , -1.1151597 ],
        [-3.7333772 ,  3.39218   ],
        [-3.397447  ,  2.9782803 ],
        [-3.4761655 ,  3.0213773 ],
        [-3.454523  ,  3.0430386 ],
        [-3.043965  ,  2.4792337 ],
        [-1.3244063 ,  0.5451266 ]], dtype=float32)>,)

In [67]:
model(batch_input_ids)[0]

<tf.Tensor: id=239074, shape=(8, 2), dtype=float32, numpy=
array([[ 0.24352185, -0.59544015],
       [ 0.70060503, -1.0981779 ],
       [-3.30811   ,  2.8154087 ],
       [-2.380083  ,  1.7778169 ],
       [-1.4645485 ,  0.79111874],
       [-2.7737343 ,  2.2790208 ],
       [-1.4757023 ,  0.7624265 ],
       [-0.8581508 ,  0.23136781]], dtype=float32)>

In [74]:
batch_ids = tf.cast(batch_input_ids, tf.int32)
batch_masks = tf.cast(tf.cast(input_ids, tf.bool), tf.int32)
batch_token_types = tf.zeros(batch_ids.shape)

In [76]:
model([batch_ids, batch_masks, batch_token_types])[0]

<tf.Tensor: id=245520, shape=(8, 2), dtype=float32, numpy=
array([[-0.08151681, -0.25236744],
       [ 0.7073199 , -1.1151597 ],
       [-3.7333772 ,  3.39218   ],
       [-3.397447  ,  2.9782803 ],
       [-3.4761655 ,  3.0213773 ],
       [-3.454523  ,  3.0430386 ],
       [-3.043965  ,  2.4792337 ],
       [-1.3244063 ,  0.5451266 ]], dtype=float32)>