# Set-up<br>
This set-up assumes that the working directory (`os.curdir`) is where the notebook is.

In [125]:
import os
import sys
this_notebook_dir = os.curdir
# all imports will be relative to the root directory of the project
# we thus add it here:
project_root_dir = os.path.relpath(os.path.join('..', '..'), this_notebook_dir)
if project_root_dir not in sys.path:
    sys.path += [project_root_dir]

# Loading data and model

We will now load a dataset

In [126]:
from src.data.dataload import *
data = load_sst()
#data = load_agnews()
print(f'loaded dataset {data.NAME}')
train, dev, test = data.train_val_test
dev

loaded dataset sst


Unnamed: 0,sentence,label
0,It 's a lovely film with lovely performances b...,3
1,"No one goes unindicted here , which is probabl...",2
2,And if you 're not nearly moved to tears by a ...,3
3,"A warm , funny , engaging film .",4
4,Uses sharp humor and insight into human nature...,4
...,...,...
1096,it seems to me the film is about the art of ri...,1
1097,It 's just disappointingly superficial -- a mo...,1
1098,The title not only describes its main characte...,1
1099,Sometimes it feels as if it might have been ma...,2


Loading a model for the dataset

In [127]:
from src.models.bcnmodel import *
from src.models.bertmodel import *
#model = BertModel()
model = BCNModel(device=torch.device('cpu'))
print(f'expecting location for the model file at '
      f'"{model._get_model_filepath_for_dataset(data)}"')
model.load_model(data)
print(f'loaded model {type(model)} of kind {model.MODELTYPE} for {data.NAME}')

expecting location for the model file at "../../models/bcn-sst_output/model.tar.gz"
loaded model <class 'src.models.bcnmodel.BCNModel'> of kind allennlp for sst


# Perturbations

We will now load all the necessary tools to perturb data

In [128]:
from src.data.perturbations import add_perturbations
from src.data.perturbations import \
    remove_commas, \
    remove_all_punctuation, \
    switch_gender, \
    strip_trailing_punct, \
    add_typo, \
    change_first_name, \
    change_last_name, \
    change_location, \
    contraction, \
    swap_adjectives

perturbation_reduction = ['strip_punct', 'remove_commas', 'remove_all_punct']
perturbations_named = ['change_first_name','change_last_name', 'change_location']
perturbations_other = ['contraction', 'add_typo', 'switch_gender', 'swap_adj']
perturbations_list = perturbation_reduction + perturbations_named + ['contraction', 'add_typo', 'switch_gender', 'swap_adj']

perturbations_named_f = [change_first_name, change_last_name, change_location]
perturbations_unnamed_f = [remove_commas, remove_all_punctuation, switch_gender,
                           strip_trailing_punct, add_typo, contraction, swap_adjectives]
perturbations_all_f = perturbations_named_f + perturbations_unnamed_f

def bert_detokenize(token_list):
    """
    Reverse BERT tokenization accounting for ##s representing subwords. Based on https://github.com/huggingface/transformers/issues/36
    """
    text = ' '.join(token_list)
    text = text.replace(' ##', '')
    text = text.replace("` `", "``")
    text = text.replace("' '", "''")
    return text

def run_detokenizer_on_perturbations(df):
    for perturbation in perturbations_list:
        if f'{perturbation}_concat' not in df.columns:
            continue
        df[f'{perturbation}_concat'] = df[f'{perturbation}_tokens'].apply(bert_detokenize)
        df[f'{perturbation}_tokens'] = df[f'{perturbation}_concat'].apply(lambda x: model.tokenizer.tokenize(x))
    return df

# merge cased and uncased perturbations
def merge_perturbations(df, df_cased):
    for perturbation in perturbations_named:
        for suffix in ['tokens', 'pert_ind', 'success', 'concat']:
            df[f'{perturbation}_{suffix}'] = df_cased[f'{perturbation}_{suffix}'].tolist()
    return df

Now we are ready to generate a dataframe with the perturbations

In [129]:
inds = [49, 114, 363]
if type(model) == BCNModel:
    df_perturbations = add_perturbations(
        df=dev.iloc[inds],
        tokenizer=model.tokenizer,
        sentence_col_name=data.SENTENCE,
        perturbation_functions=perturbations_all_f,
    )
elif type(model) == BertModel:
    # uncased perturbations
    df_perturbations = add_perturbations(
        df=dev.iloc[inds],
        tokenizer=model.tokenizer,
        sentence_col_name=data.SENTENCE,
        perturbation_functions=perturbations_unnamed_f,
    )
    # with cased tokenizer
    df_perturbations_named = add_perturbations(
        df=dev.iloc[inds],
        tokenizer=BertTokenizer.from_pretrained('bert-base-cased'),
        sentence_col_name=data.SENTENCE,
        perturbation_functions=perturbations_named_f,
    )
    # merge the two together
    df_perturbations = merge_perturbations(df_perturbations, df_perturbations_named)
    # remove ## in the reconstructed sentences
    df_perturbations = run_detokenizer_on_perturbations(df_perturbations)
df_perturbations.head()

Unnamed: 0,sentence,label,tokens_orig,change_first_name_concat,change_first_name_tokens,change_first_name_success,change_first_name_pert_ind,change_last_name_concat,change_last_name_tokens,change_last_name_success,...,add_typo_success,add_typo_pert_ind,contraction_concat,contraction_tokens,contraction_success,contraction_pert_ind,swap_adj_concat,swap_adj_tokens,swap_adj_success,swap_adj_pert_ind
49,"A smart , witty follow-up .",4,"[A, smart, ,, witty, follow, -, up, .]","A smart , witty follow-up .","[A, smart, ,, witty, follow, -, up, .]",0,,"A smart , witty follow-up .","[A, smart, ,, witty, follow, -, up, .]",0,...,1,"[6, 7]","A smart , witty follow-up .","[A, smart, ,, witty, follow, -, up, .]",0,,"A smart , witty follow-up .","[A, smart, ,, witty, follow, -, up, .]",0,
114,A warm but realistic meditation on friendship ...,4,"[A, warm, but, realistic, meditation, on, frie...",A warm but realistic meditation on friendship ...,"[A, warm, but, realistic, meditation, on, frie...",0,,A warm but realistic meditation on friendship ...,"[A, warm, but, realistic, meditation, on, frie...",0,...,1,"[8, 9]",A warm but realistic meditation on friendship ...,"[A, warm, but, realistic, meditation, on, frie...",0,,A warm but realistic meditation on friendship ...,"[A, warm, but, realistic, meditation, on, frie...",0,
363,"A gorgeous , high-spirited musical from India ...",4,"[A, gorgeous, ,, high, -, spirited, musical, f...","A gorgeous , high-spirited musical from India ...","[A, gorgeous, ,, high, -, spirited, musical, f...",0,,"A gorgeous , high-spirited musical from India ...","[A, gorgeous, ,, high, -, spirited, musical, f...",0,...,1,"[9, 10]","A gorgeous , high-spirited musical from India ...","[A, gorgeous, ,, high, -, spirited, musical, f...",0,,"A gorgeous , high-spirited musical from India ...","[A, gorgeous, ,, high, -, spirited, musical, f...",0,


# Explainers and Evaluation

Explainers are run in an evaluation loop, where metrics are applied to each suitable perturbation of each sentence:

In [133]:
from src.explainers.explainers import *

import tqdm
import scipy as sp

df_perturbations.reset_index(drop=True, inplace=True)
preds = model.predict_label_batch(df_perturbations.sentence)


# helper function
def get_sorted_tokens(scores, tokens):
    N = len(scores)
    sorted_inds = sorted(range(N), key=lambda i: abs(scores[i]), reverse=True)
    return [str(tokens[i]).lower() for i in sorted_inds]

def metric_top_k(k, original_scores, original_tokens, perturbed_scores, perturbed_tokens, **kwargs):
    N = len(original_tokens)
    k = min(k, N)
    original_top_k = get_sorted_tokens(scores=original_scores, tokens=original_tokens)[:k]
    perturbed_top_k = get_sorted_tokens(scores=perturbed_scores, tokens=perturbed_tokens)[:k]
    num_common = 0
    for word in set([str(t) for t in original_top_k]):
        num_common += min(original_top_k.count(word), perturbed_top_k.count(word))
    print(original_top_k)
    print(perturbed_top_k)
    return num_common / k

def metric_spearman(original_scores, perturbed_scores, **kwargs):
    return sp.stats.spearmanr(original_scores, perturbed_scores).correlation


def run_evaluation(explainer, show_f, metric_proxy):
    for i in tqdm.trange(len(df_perturbations)):
        prediction = preds[i]
        original_sentence = df_perturbations.sentence[i]
        original_tokens = model.tokenizer.tokenize(original_sentence)
        original_explanation = None
        for p in perturbations_list:
            perturbed_sentence = df_perturbations[f'{p}_concat'][i]
            perturbed_tokens = df_perturbations[f'{p}_tokens'][i]
            # perturbation wasn't successful:
            if df_perturbations[f'{p}_success'][i] == 0:
                # print('skipped:', p, 'unsuccessful')
                continue
            # different numbers of tokens
            elif len(original_tokens) != len(perturbed_tokens):
                # print('skipped:', p, 'length changed')
                continue
            # or didn't reach the same prediction
            elif model.predict_label(perturbed_sentence) != preds[i]:
                # print('skipped', p, 'prediction changed')
                continue
            # we will run the prediction on the perturbed sentence now
            if original_explanation is None:
                original_explanation = explainer.explain_instances([original_sentence])
                original_tokens, original_scores = metric_proxy(sentence=original_sentence,
                                                                explanation=original_explanation,
                                                                prediction=prediction)
                original = dict(
                    original_sentence=original_sentence,
                    original_tokens=original_tokens,
                    original_explanation=original_explanation,
                    original_scores=original_scores,
                    prediction=preds[i],
                )
            perturbed_explanation = explainer.explain_instances([perturbed_sentence])
            perturbed_tokens, perturbed_scores = metric_proxy(sentence=perturbed_sentence,
                                                              explanation=perturbed_explanation,
                                                              prediction=prediction)
            perturbed = dict(
                perturbed_sentence=perturbed_sentence,
                perturbed_tokens=perturbed_tokens,
                perturbed_explanation=perturbed_explanation,
                perturbed_scores=perturbed_scores,
            )
            if len(perturbed_tokens) != len(original_tokens):
                # print('skipped:', p, 'length changed')
                continue
            #print('running explainer with ', p)
            print('-'*100)
            # show metrics
            K = 5
            print('PERTURBATION', p)
            print(f'TOP-{K} score:', metric_top_k(k=K, **original, **perturbed))
            print('SPEARMAN:', metric_spearman(original_scores=original_scores, perturbed_scores=perturbed_scores))
            # show perturbed explanation
            print('ORIGINAL:', original_sentence)
            show_f(sentence=original_sentence,
                   explanation=original_explanation,
                   **original, **perturbed)
            print('PERTURBED:', perturbed_sentence)
            show_f(sentence=perturbed_sentence,
                   explanation=perturbed_explanation,
                   **original, **perturbed)
            print('-'*100)
        original_explanation = None

### LIME explainer

Construct LIME explainer:

In [134]:
lime_explainer = LimeExplainer(model, num_samples=1000)
print(f'using explainer {type(lime_explainer)} with model {lime_explainer.model} and dataset {lime_explainer.model.dataset_finetune.NAME}')

using explainer <class 'src.explainers.explainers.LimeExplainer'> with model <src.models.bcnmodel.BCNModel object at 0x7c7f38bb62b0> and dataset sst


Analyse:

In [135]:
def metric_proxy_lime(sentence, explanation, **kwargs):
    scores, pred, tokens, inds = explanation
    inds = np.argsort(inds[0])
    tokens = [tokens[0][i] for i in inds]
    scores = [scores[0][i] for i in inds]
    return tokens, scores

def show_lime_sentence(sentence, explanation, **kwargs):
    tokens, scores = metric_proxy_lime(sentence=sentence, explanation=explanation, **kwargs)
    print('tokens', tokens)
    print('scores', ['%.3f' % s for s in scores])

run_evaluation(explainer=lime_explainer,
               show_f=show_lime_sentence,
               metric_proxy=metric_proxy_lime)

 33%|███▎      | 1/3 [00:06<00:13,  6.61s/it]

----------------------------------------------------------------------------------------------------
PERTURBATION strip_punct
['witty', 'smart', 'follow', 'up', 'a']
['witty', 'smart', 'a', 'follow', 'up']
TOP-5 score: 1.0
SPEARMAN: 0.6
ORIGINAL: A smart , witty follow-up .
tokens ['A', 'smart', 'witty', 'follow', 'up']
scores ['0.026', '0.185', '0.212', '0.033', '0.028']
PERTURBED: A smart , witty follow - up 
tokens ['A', 'smart', 'witty', 'follow', 'up']
scores ['0.037', '0.188', '0.221', '-0.025', '-0.009']
----------------------------------------------------------------------------------------------------


 67%|██████▋   | 2/3 [00:13<00:06,  6.61s/it]

----------------------------------------------------------------------------------------------------
PERTURBATION strip_punct
['but', 'meditation', 'realistic', 'and', 'warm']
['but', 'meditation', 'realistic', 'warm', 'friendship']
TOP-5 score: 0.8
SPEARMAN: 0.9878787878787878
ORIGINAL: A warm but realistic meditation on friendship , family and affection .
tokens ['A', 'warm', 'but', 'realistic', 'meditation', 'on', 'friendship', 'family', 'and', 'affection']
scores ['-0.021', '-0.058', '0.132', '-0.065', '0.120', '-0.010', '0.050', '-0.023', '-0.061', '-0.027']
PERTURBED: A warm but realistic meditation on friendship , family and affection 
tokens ['A', 'warm', 'but', 'realistic', 'meditation', 'on', 'friendship', 'family', 'and', 'affection']
scores ['-0.020', '-0.060', '0.146', '-0.070', '0.132', '-0.013', '0.058', '-0.027', '-0.058', '-0.032']
----------------------------------------------------------------------------------------------------


100%|██████████| 3/3 [00:22<00:00,  7.44s/it]

----------------------------------------------------------------------------------------------------
PERTURBATION strip_punct
['gorgeous', 'exquisitely', 'blends', 'high', 'spirited']
['gorgeous', 'spirited', 'exquisitely', 'blends', 'a']
TOP-5 score: 0.8
SPEARMAN: 0.8749999999999999
ORIGINAL: A gorgeous , high-spirited musical from India that exquisitely blends music , dance , song , and high drama .
tokens ['A', 'gorgeous', 'high', 'spirited', 'musical', 'from', 'India', 'that', 'exquisitely', 'blends', 'music', 'dance', 'song', 'and', 'drama']
scores ['0.012', '0.083', '0.057', '0.045', '0.003', '-0.025', '-0.028', '-0.011', '0.070', '0.062', '0.007', '-0.015', '-0.004', '0.009', '0.009']
PERTURBED: A gorgeous , high - spirited musical from India that exquisitely blends music , dance , song , and high drama 
tokens ['A', 'gorgeous', 'high', 'spirited', 'musical', 'from', 'India', 'that', 'exquisitely', 'blends', 'music', 'dance', 'song', 'and', 'drama']
scores ['0.016', '0.061', '-0




### Alternative explainer

This depends on a model, so we will first prepare show_f and metric_proxy for each explainer:

**AllenNLP explainer helpers**

In [136]:
import html
from IPython.core.display import display, HTML

def visualise_weights(tokens, gradients, max_alpha=.3):
    max_alpha = max_alpha 
    highlighted_text = []
    for i in range(len(tokens)):
        weight = gradients[i]
        highlighted_text.append('<span style="background-color:rgba(135,206,250,' + str(weight / max_alpha) + ');">' + html.escape(tokens[i]) + '</span>')
    highlighted_text = ' '.join(highlighted_text)
    print(display(HTML(highlighted_text)))

def metric_proxy_allennlp(sentence, explanation, **kwargs):
    grads, labels = explanation
    tokenized = model.tokenizer.tokenize(sentence)
    return tokenized, grads[0]

def show_allennlp_sentence(sentence, explanation, **kwargs):
    tokens, scores = metric_proxy_allennlp(sentence=sentence, explanation=explanation, **kwargs)
    visualise_weights([str(t) for t in tokens], scores, max_alpha=.5)

**SHAP explainer helpers**

In [137]:
def metric_proxy_shap(sentence, explanation, prediction, **kwargs):
    tokens = explanation.data[0, 1:-1]
    scores = explanation.values[0][1:-1]
    return tokens, scores[:, prediction]

def show_shap_sentence(sentence, explanation, **kwargs):
    tokens, scores = metric_proxy_shap(sentence=sentence, explanation=explanation, **kwargs)
    print('tokens:', tokens)
    print('scores:', ['%.3f' % s for s in scores])

We will now run evaluation with the alternative explainer:

In [138]:
alternative_explainer = None
metric_proxy_func = None
show_func = None
if type(model) == BCNModel:
    alternative_explainer = AllenNLPExplainer(model)
    show_func = show_allennlp_sentence
    metric_proxy_func = metric_proxy_allennlp
elif type(model) == BertModel:
    alternative_explainer = SHAPExplainer(model)
    show_func = show_shap_sentence
    metric_proxy_func = metric_proxy_shap

print('alternative explainer', type(alternative_explainer))
run_evaluation(explainer=alternative_explainer,
               show_f=show_func,
               metric_proxy=metric_proxy_func)

  0%|          | 0/3 [00:00<?, ?it/s]

alternative explainer <class 'src.explainers.explainers.AllenNLPExplainer'>
----------------------------------------------------------------------------------------------------
PERTURBATION add_typo
['follow', 'smart', 'up', ',', 'a']
['follow', 'p.', 'smart', 'u', ',']
TOP-5 score: 0.6
SPEARMAN: 0.6428571428571429
ORIGINAL: A smart , witty follow-up .


None
PERTURBED: A smart , witty follow - u p.


 33%|███▎      | 1/3 [00:00<00:01,  1.76it/s]

None
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
PERTURBATION add_typo
['but', 'meditation', 'realistic', 'and', 'friendship']
['but', 'meditation', 'realistic', 'friendship', 'affection']
TOP-5 score: 0.8
SPEARMAN: 0.7552447552447553
ORIGINAL: A warm but realistic meditation on friendship , family and affection .


None
PERTURBED: A warm but realistic meditation on friendship , famil yand affection .


 67%|██████▋   | 2/3 [00:01<00:00,  1.81it/s]

None
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
PERTURBATION change_location
['-', 'spirited', 'india', 'from', 'gorgeous']
['angola', 'from', '-', 'and', 'high']
TOP-5 score: 0.4
SPEARMAN: 0.5347261434217957
ORIGINAL: A gorgeous , high-spirited musical from India that exquisitely blends music , dance , song , and high drama .


None
PERTURBED: A gorgeous , high-spirited musical from Angola that exquisitely blends music , dance , song , and high drama .


None
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
PERTURBATION add_typo
['-', 'spirited', 'india', 'from', 'gorgeous']
['-', 'texquisitely', 'blends', 'spirited', 'tha']
TOP-5 score: 0.4
SPEARMAN: 0.3890457368718239
ORIGINAL: A gorgeous , high-spirited musical from India that exquisitely blends music , dance , song , and high drama .


None
PERTURBED: A gorgeous , high - spirited musical from India tha texquisitely blends music , dance , song , and high drama .


100%|██████████| 3/3 [00:01<00:00,  1.52it/s]

None
----------------------------------------------------------------------------------------------------


100%|██████████| 3/3 [00:01<00:00,  1.51it/s]
