## Imports

In [1]:
# Automatically reload edited Python modules before each cell runs, so changes to
# the local helper packages imported below take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the parent directory importable so the local packages used below
# (path_explain, bert_explainer, plot) resolve without being installed.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
# Project helper that configures the runtime environment, restricted to device '0'.
# NOTE(review): presumably sets CUDA_VISIBLE_DEVICES / TF GPU options — confirm in path_explain.utils.
from path_explain import utils
utils.set_up_environment(visible_devices='0')

In [4]:
import tensorflow as tf
import tensorflow_datasets
import numpy as np
import pandas as pd
import altair as alt
import scipy
from bert_explainer import BertExplainerTF
from path_explain.path_explainer_tf import PathExplainerTF
from transformers import *
from plot.text import text_plot
import transformers
from tqdm import tqdm
from functools import reduce

## Data and Model Loading

In [5]:
# GLUE task key; it selects the label processor here and the feature conversion below.
task = 'sst-2'
# Number of output classes, taken from the label list of this task's GLUE processor.
num_labels = len(glue_processors[task]().get_labels())

In [6]:
# Fine-tuned checkpoint: config and weights are read from the current directory,
# with the classifier size set to num_labels.
# NOTE(review): the tokenizer is the stock 'bert-base-cased' vocabulary — confirm the
# checkpoint in '.' was fine-tuned with this same tokenizer.
config = BertConfig.from_pretrained('.', num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = TFBertForSequenceClassification.from_pretrained('.', config=config)

In [7]:
# Load GLUE SST-2 (sentence-level sentiment, labels 0/1) from TensorFlow Datasets.
# `info` (dataset metadata) is not used in the cells shown here.
data, info = tensorflow_datasets.load('glue/sst2', with_info=True)

INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset glue (/homes/gws/psturm/tensorflow_datasets/glue/sst2/0.0.2)
INFO:absl:Constructing tf.data.Dataset for split None, from /homes/gws/psturm/tensorflow_datasets/glue/sst2/0.0.2


In [8]:
# Convert raw examples into BERT input features (dicts with input_ids,
# attention_mask, token_type_ids) with max_length=128.
# NOTE(review): train_dataset is not used in the cells shown here; converting the
# full training split is costly — drop it if nothing downstream needs it.
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, max_length=128, task=task)
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, max_length=128, task=task)
# Batches of 16; the first validation batch is the one explained later on.
valid_dataset = valid_dataset.batch(16)

## Model Evaluation

In [9]:
# Model outputs for the whole validation set (raw scores; a softmax is applied
# later where class confidences are needed).
valid_pred = model.predict(valid_dataset)

In [10]:
# Materialise the batched validation set: keep one feature dict per batch (used
# for explanation later) and concatenate every batch's labels into one array.
valid_input = []
valid_labels = []
for features, labels in valid_dataset:
    valid_input.append(features)
    valid_labels.append(labels.numpy())
valid_labels_np = np.concatenate(valid_labels, axis=0)

In [11]:
# Overall and per-class validation accuracy (label 1 = positive, 0 = negative).
valid_pred_max = np.argmax(valid_pred, axis=-1)


def _accuracy(predicted, expected):
    """Fraction of entries where ``predicted`` equals ``expected``."""
    return np.sum(predicted == expected) / len(expected)


accuracy = _accuracy(valid_pred_max, valid_labels_np)

positive_mask = valid_labels_np == 1
positive_accuracy = _accuracy(valid_pred_max[positive_mask], valid_labels_np[positive_mask])

negative_mask = valid_labels_np == 0
negative_accuracy = _accuracy(valid_pred_max[negative_mask], valid_labels_np[negative_mask])

print('Validation Accuracy: {:.4f}'.format(accuracy))
print('Positive Sentiment Accuracy: {:.4f}'.format(positive_accuracy))
print('Negative Sentiment Accuracy: {:.4f}'.format(negative_accuracy))

Validation Accuracy: 0.8956
Positive Sentiment Accuracy: 0.9527
Negative Sentiment Accuracy: 0.8364


In [12]:
# Select the first validation batch for explanation. The batch size is taken
# from the batch itself rather than hard-coded, so changing the dataset's
# batch size cannot misalign predictions with the labels/inputs below.
batch_labels = valid_labels[0]
batch_input = valid_input[0]
batch_ids = batch_input['input_ids']
num_in_batch = len(batch_labels)

# The first `num_in_batch` predictions belong to this batch; softmax turns the
# raw model scores into class probabilities.
batch_conf = scipy.special.softmax(valid_pred[:num_in_batch], axis=-1)
batch_pred = np.argmax(batch_conf, axis=-1)

# All-zero token-id baseline (id 0 is treated as padding elsewhere in this
# notebook). Its width follows the tokenized sequence length instead of a
# hard-coded 128.
batch_baseline = np.zeros((1, batch_ids.shape[1]))

## Model Interpretation

In [13]:
# Explainer wrapping the fine-tuned model; used below for both attributions and interactions.
explainer = BertExplainerTF(model)

In [14]:
# Attributions for the explained batch (output index 1), computed once and cached on disk.
# NOTE(review): the cache is not keyed on the batch or model — delete the file
# after changing either.
attributions_path = 'attributions.npy'
try:
    attributions = np.load(attributions_path)
except FileNotFoundError:
    attributions = explainer.attributions(
        inputs=batch_ids,
        baseline=batch_baseline,
        batch_size=30,
        num_samples=1000,
        use_expectation=False,
        output_indices=1,
        verbose=True,
    )
    np.save(attributions_path, attributions)

In [15]:
def check_completeness(index):
    """Print a completeness check for the attributions of one batch example.

    Path attribution methods (e.g. Integrated Gradients) should satisfy
    completeness: the attributions of an input sum to the difference between
    the model output at the input and at the baseline. This prints both
    quantities for output index 1 of example ``index`` in the global
    ``batch_input``. The baseline has all-zero ``input_ids`` and keeps the
    example's real attention mask and token type ids.

    Also prints the decoded sentence, its true label and the softmax
    probability (from ``batch_conf``) assigned to that label.

    Args:
        index: position of the example within the explained batch.
    """
    current_input = {
        'input_ids': batch_input['input_ids'][index:index+1],
        'attention_mask': batch_input['attention_mask'][index:index+1],
        'token_type_ids': batch_input['token_type_ids'][index:index+1],
    }

    # Zero token ids shaped like the real input (not a hard-coded (1, 128)),
    # with the same integer dtype the baseline has always used.
    current_baseline = {
        'input_ids': np.zeros_like(current_input['input_ids'].numpy(), dtype=int),
        'attention_mask': current_input['attention_mask'],
        'token_type_ids': current_input['token_type_ids'],
    }

    # model(...)[0] is presumably the (1, num_labels) score tensor — TODO confirm.
    current_output = model(current_input)[0]
    baseline_output = model(current_baseline)[0]
    output_difference = current_output - baseline_output
    output_difference = output_difference[0, 1]
    sum_attr = np.sum(attributions[index, :])

    # Drop zero (padding) ids before decoding.
    encoded_sentence = batch_input['input_ids'].numpy()[index]
    encoded_sentence = encoded_sentence[encoded_sentence != 0]
    label = batch_labels[index]
    print(tokenizer.decode(encoded_sentence))
    print('This sentence is {} (predicted confidence: {:.4f})'.format('positive' if label == 1 else 'negative', batch_conf[index, label]))
    # All three values use the same fixed 4-decimal precision.
    print('Output difference:\t{:.4f} ({:.4f} - {:.4f})'.format(output_difference,
                                                                 current_output[0, 1],
                                                                 baseline_output[0, 1]))
    print('Sum of attributions:\t{:.4f}'.format(sum_attr))
    print('-------------------------')

In [16]:
# Run the completeness check on a handful of examples from the explained batch.
for i in [1,3,5,7,10]:
    check_completeness(i)

[CLS] too much of it feels unfocused and underdeveloped. [SEP]
This sentence is negative (predicted confidence: 0.9847)
Output difference:	-2.6409 (-2.3507 - 0.2902)
Sum of attributions:	-2.6379
-------------------------
[CLS] prurient playthings aside, there's little to love about this english trifle. [SEP]
This sentence is negative (predicted confidence: 0.9527)
Output difference:	-1.8265 (-1.8019 - 0.02459)
Sum of attributions:	-1.8654
-------------------------
[CLS] it proves quite compelling as an intense, brooding character study. [SEP]
This sentence is positive (predicted confidence: 0.9992)
Output difference:	3.0870 (3.4095 - 0.3225)
Sum of attributions:	3.2732
-------------------------
[CLS] looks and feels like a project better suited for the small screen. [SEP]
This sentence is negative (predicted confidence: 0.9552)
Output difference:	-2.1291 (-1.7871 - 0.342)
Sum of attributions:	-2.1694
-------------------------
[CLS] a painfully funny ode to bad behavior. [SEP]
This sent

In [17]:
# Token ids of the explained batch (the same value bound when the batch was
# selected); used by the plots below.
batch_ids = batch_input['input_ids']

In [18]:
# Token attributions for example 1 (non-zero token ids only), with a colour legend.
i = 1
sentence_ids = batch_ids[i].numpy()
text_plot(tokens=tokenizer.convert_ids_to_tokens(sentence_ids),
          attributions=attributions[i],
          non_zero_mask=sentence_ids != 0,
          include_legend=True).properties(title='Sentence {}'.format(i))

In [19]:
# Token attributions for example 3 (non-zero token ids only).
i = 3
sentence_ids = batch_ids[i].numpy()
text_plot(tokens=tokenizer.convert_ids_to_tokens(sentence_ids),
          non_zero_mask=sentence_ids != 0,
          attributions=attributions[i]).properties(title='Sentence {}'.format(i))

In [20]:
# Token attributions for example 5 (non-zero token ids only).
i = 5
sentence_ids = batch_ids[i].numpy()
text_plot(tokens=tokenizer.convert_ids_to_tokens(sentence_ids),
          non_zero_mask=sentence_ids != 0,
          attributions=attributions[i]).properties(title='Sentence {}'.format(i))

In [21]:
# Token attributions for example 7 (non-zero token ids only).
i = 7
sentence_ids = batch_ids[i].numpy()
text_plot(tokens=tokenizer.convert_ids_to_tokens(sentence_ids),
          non_zero_mask=sentence_ids != 0,
          attributions=attributions[i]).properties(title='Sentence {}'.format(i))

In [22]:
# Token attributions for example 10 (non-zero token ids only).
i = 10
sentence_ids = batch_ids[i].numpy()
text_plot(tokens=tokenizer.convert_ids_to_tokens(sentence_ids),
          non_zero_mask=sentence_ids != 0,
          attributions=attributions[i]).properties(title='Sentence {}'.format(i))

In [23]:
def sum_interactions_over_tokens(interactions, tokens):
    """Merge sub-token interactions into word-level interactions.

    Args:
        interactions: 2-D array; row ``r`` holds the interactions of the token
            at position ``r + 1`` of ``tokens`` (rows skip the leading [CLS])
            with every position of the full sequence. One row is expected per
            token strictly between position 0 and the first '.' token. The
            array is not modified.
        tokens: full token list for the sequence (e.g. from
            ``tokenizer.convert_ids_to_tokens``), containing a '.' token.

    Returns:
        ``(summed_interactions, token_array)``. ``summed_interactions`` is
        square, with shape ``(n_words, n_words)``. Entry ``[a, b]`` is the sum
        of the interactions between the sub-tokens of word ``a`` and those of
        word ``b``. ``token_array`` is the list of merged word strings.

    Merging rules: a token starting with '##' is appended to the previous
    word. '-' and "'" are appended to the previous word and also absorb the
    following token (e.g. ``there ' s`` -> ``there's``). A leading continuation
    token with no previous word starts a new word instead of raising.

    Raises:
        ValueError: if ``tokens`` contains no '.' token.
    """
    start_index = 1  # skip [CLS]
    end_index = tokens.index('.')
    tokens = tokens[start_index:end_index]

    interaction_array = []
    token_array = []
    zero_indices = []
    found_special_token = False
    for i in range(len(interactions)):
        token = tokens[i]
        interaction = interactions[i]
        if token.startswith('##') and interaction_array:
            interaction_array[-1] += interaction
            token_array[-1] += token[2:]
            zero_indices[-1] += [i]
            found_special_token = False
        elif found_special_token:
            interaction_array[-1] += interaction
            token_array[-1] += token
            zero_indices[-1] += [i]
            found_special_token = False
        elif token in ('-', "'") and interaction_array:
            interaction_array[-1] += interaction
            token_array[-1] += token
            zero_indices[-1] += [i]
            found_special_token = True
        else:
            # np.array copies the row: interactions[i] is a view into the
            # caller's array, and the in-place `+=` merges above would
            # otherwise write through to it.
            interaction_array.append(np.array(interaction))
            token_array.append(token)
            zero_indices.append([i])
    # Rows are now words; restrict columns to the same token span and sum the
    # columns belonging to each word.
    summed_interactions = np.stack(interaction_array, axis=0)
    summed_interactions = summed_interactions[:, start_index:end_index]
    summed_interactions = np.stack([np.sum(summed_interactions[:, zero_index], axis=1) for zero_index in zero_indices], axis=1)

    return summed_interactions, token_array

In [35]:
def interaction_plot(summed_interactions, summed_tokens):
    """Stack one text plot per word showing its interactions with every word.

    Args:
        summed_interactions: square word-by-word interaction matrix; row ``j``
            is plotted as the interactions of word ``j``. Not modified.
        summed_tokens: word strings labelling the rows/columns.

    Returns:
        An Altair chart with the panels chained by ``&`` (vertical
        concatenation). The view is transparent and each panel has an
        independent colour scale. In panel ``j`` the entry for word ``j``
        itself is set to 0 before plotting.
    """
    def _panel(word_index):
        # Work on a copy so the caller's matrix is left untouched.
        row = summed_interactions[word_index].copy()
        row[word_index] = 0.0
        title = 'Interactions with {}'.format(summed_tokens[word_index])
        return text_plot(tokens=summed_tokens,
                         attributions=row,
                         strip_special=False,
                         include_legend=False,
                         include_grid=True).properties(title=title)

    panels = [_panel(word_index) for word_index in range(len(summed_interactions))]
    stacked = reduce(lambda upper, lower: upper & lower, panels)
    return stacked.configure_view(opacity=0.0).resolve_scale(color='independent')

In [44]:
def get_interactions_for_index(i, batch_size=10, num_samples=1000):
    """Compute interactions for every token position of batch example ``i``.

    Makes one ``explainer.interactions`` call per token position from 1 up to
    (but excluding) the first '.' token, passing that position as
    ``interaction_index`` (output index 1, all-zero ``batch_baseline``). The
    per-position results are concatenated along axis 0, so row ``r`` belongs
    to token position ``r + 1``.

    Args:
        i: index of the example within the global ``batch_ids``.
        batch_size: forwarded to ``explainer.interactions`` (default 10, the
            previously hard-coded value).
        num_samples: forwarded to ``explainer.interactions`` (default 1000,
            the previously hard-coded value).

    Returns:
        The concatenated interaction array for the example.

    Raises:
        ValueError: if the example's tokens contain no '.' token.
    """
    start_index = 1  # skip [CLS]
    batch_tokens = tokenizer.convert_ids_to_tokens(batch_ids[i].numpy())
    end_index = batch_tokens.index('.')
    interaction_array = [
        explainer.interactions(inputs=batch_ids[i:i+1],
                               baseline=batch_baseline,
                               batch_size=batch_size,
                               num_samples=num_samples,
                               use_expectation=False,
                               output_indices=1,
                               verbose=False,
                               interaction_index=index_to_explain)
        for index_to_explain in tqdm(range(start_index, end_index))
    ]
    return np.concatenate(interaction_array, axis=0)

## Attributions of Attributions (Pairwise Token Interactions)

In [26]:
# All-zero token-id baseline for the interaction computations (same expression as
# when the batch was selected), redefined so this section can be re-run on its own.
batch_baseline = np.zeros((1, 128))

In [None]:
# Interactions for every example in the explained batch. Each uncached example
# takes tens of minutes (see the progress bars), so results are cached per example.
interactions_array = []
for i in range(len(attributions)):
    cache_file = 'interactions_{}.npy'.format(i)
    try:
        interactions = np.load(cache_file)
    except FileNotFoundError:
        interactions = get_interactions_for_index(i)
        np.save(cache_file, interactions)
    interactions_array.append(interactions)

100%|██████████| 25/25 [33:03<00:00, 79.35s/it]
 44%|████▍     | 8/18 [10:33<13:11, 79.19s/it]

In [36]:
# Word-level interaction plot for example 1. Its own interaction matrix is loaded
# from the per-example cache (computed and cached if missing), as for the other
# examples, rather than indexing into whichever matrix was computed last.
i = 1
try:
    interactions = np.load('interactions_{}.npy'.format(i))
except FileNotFoundError:
    interactions = get_interactions_for_index(i)
    np.save('interactions_{}.npy'.format(i), interactions)
batch_tokens = tokenizer.convert_ids_to_tokens(batch_ids[i].numpy())
summed_interactions, summed_tokens = sum_interactions_over_tokens(interactions, batch_tokens)
interaction_plot(summed_interactions, summed_tokens)

In [37]:
# Word-level interaction plot for example 3 (interactions cached on disk).
i = 3
cache_file = 'interactions_{}.npy'.format(i)
try:
    interactions = np.load(cache_file)
except FileNotFoundError:
    interactions = get_interactions_for_index(i)
    np.save(cache_file, interactions)
batch_tokens = tokenizer.convert_ids_to_tokens(batch_ids[i].numpy())
summed_interactions, summed_tokens = sum_interactions_over_tokens(interactions, batch_tokens)
interaction_plot(summed_interactions, summed_tokens)

In [38]:
# Word-level interaction plot for example 5 (interactions cached on disk).
i = 5
cache_file = 'interactions_{}.npy'.format(i)
try:
    interactions = np.load(cache_file)
except FileNotFoundError:
    interactions = get_interactions_for_index(i)
    np.save(cache_file, interactions)
batch_tokens = tokenizer.convert_ids_to_tokens(batch_ids[i].numpy())
summed_interactions, summed_tokens = sum_interactions_over_tokens(interactions, batch_tokens)
interaction_plot(summed_interactions, summed_tokens)

In [39]:
# Word-level interaction plot for example 7 (interactions cached on disk).
i = 7
cache_file = 'interactions_{}.npy'.format(i)
try:
    interactions = np.load(cache_file)
except FileNotFoundError:
    interactions = get_interactions_for_index(i)
    np.save(cache_file, interactions)
batch_tokens = tokenizer.convert_ids_to_tokens(batch_ids[i].numpy())
summed_interactions, summed_tokens = sum_interactions_over_tokens(interactions, batch_tokens)
interaction_plot(summed_interactions, summed_tokens)

In [40]:
# Word-level interaction plot for example 10 (interactions cached on disk).
i = 10
cache_file = 'interactions_{}.npy'.format(i)
try:
    interactions = np.load(cache_file)
except FileNotFoundError:
    interactions = get_interactions_for_index(i)
    np.save(cache_file, interactions)
batch_tokens = tokenizer.convert_ids_to_tokens(batch_ids[i].numpy())
summed_interactions, summed_tokens = sum_interactions_over_tokens(interactions, batch_tokens)
interaction_plot(summed_interactions, summed_tokens)