<a href="https://colab.research.google.com/github/Warra07/ABSADatasets/blob/master/Copie_de_DeepLift_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DeepLift BERT

*Issue: DeepLift delta magnitude becomes large when softmax is explicitly initialised in the huggingface/transformers BERT model. Raised as https://github.com/pytorch/captum/issues/519*

Using captum's DeepLift implementation to explain predictions from a huggingface/transformers BERT model.

As per the discussion on captum issue 347 (https://github.com/pytorch/captum/issues/347#issuecomment-616864035), the softmax non-linearity needs to be explicitly initialised in the model's `__init__`.

I have forked the transformers repository and made this change in https://github.com/lannelin/transformers/commit/1fd1e4a59628a731b24eb3514ef586dc0b075b5f

This change results in a very large delta: on the same text, its magnitude increases from 1.9306 to 12386754.

The attributions from DeepLift (in contrast to IntegratedGradients) are also misleading.
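The delta referred to above is captum's convergence delta: the sum of all attributions minus the change in the target output between input and baseline. It should be close to zero when the attributions fully account for that change. A minimal pure-Python sketch, assuming a toy linear model; `convergence_delta`, `toy_model` and the weights are illustrative names, not captum API:

```python
def convergence_delta(attributions, output_input, output_baseline):
    """Sum of attributions minus the change in model output.

    This is the quantity reported with return_convergence_delta=True:
    close to zero when the attributions account for F(x) - F(x').
    """
    return sum(attributions) - (output_input - output_baseline)


# Toy linear model F(x) = sum_i w_i * x_i (hypothetical weights and inputs).
weights = [0.5, -1.0, 2.0]
x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]


def toy_model(v):
    return sum(w * vi for w, vi in zip(weights, v))


# For a linear model, w_i * (x_i - x'_i) is an exact attribution per feature.
attributions = [w * (xi - bi) for w, xi, bi in zip(weights, x, baseline)]
delta = convergence_delta(attributions, toy_model(x), toy_model(baseline))
```

Here the delta is exactly zero; a delta in the millions means the attributions no longer come close to accounting for the change in the model's output.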


----

edit: 23rd Nov
updated the transformers fork to explicitly instantiate `torch.nn.GELU` (https://github.com/lannelin/transformers/commit/c731d9b621fc349513b447d564ef2972cf683242)

use a captum fork with `torch.nn.GELU` added to DeepLift as a supported non-linearity (https://github.com/lannelin/captum/commit/4dd7eae56e507af26e956d43bf7989e176d6dbe9)

note: the Colab installation appears to error, but the package still seems to be installed; the change to DeepLift is verified at the end of the notebook.
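Registering `nn.GELU` in DeepLift means it gets the same `nonlinear` handler as ReLU, Sigmoid, Tanh and the rest (see the `SUPPORTED_NON_LINEAR` dict at the end of the notebook): the gradient through the unit is replaced by the rescale multiplier Δy/Δx, falling back to the plain gradient when Δx is tiny. A pure-Python sketch of that rule for a single scalar GELU unit; the function names and the `eps` value are illustrative assumptions:

```python
import math


def gelu(x):
    # exact (erf-based) GELU: x * Phi(x)
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_grad(x):
    # derivative of x * Phi(x): Phi(x) + x * phi(x)
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf


def rescale_multiplier(x, x_ref, eps=1e-10):
    """DeepLift rescale rule: delta_out / delta_in, or the gradient if delta_in ~ 0."""
    delta_in = x - x_ref
    if abs(delta_in) < eps:
        return gelu_grad(x)
    return (gelu(x) - gelu(x_ref)) / delta_in
```

Multiplying the rescale multiplier by Δx reproduces Δy exactly, so each supported unit preserves summation-to-delta locally, whereas the plain gradient times Δx generally does not.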

In [None]:
# install transformers from fork. deeplift_bert branch contains change to
# explicitly initialise softmax as per advice
# https://github.com/pytorch/captum/issues/347#issuecomment-616864035
! pip install git+https://github.com/lannelin/transformers.git@deeplift_bert

# install captum from fork. deeplift_gelu branch contains change to
# add nn.GELU as a supported nonlinearity
! pip install git+https://github.com/lannelin/captum.git@deeplift_gelu

Collecting git+https://github.com/lannelin/transformers.git@deeplift_bert
  Cloning https://github.com/lannelin/transformers.git (to revision deeplift_bert) to /tmp/pip-req-build-z9h23k5y
  Running command git clone -q https://github.com/lannelin/transformers.git /tmp/pip-req-build-z9h23k5y
  Running command git checkout -b deeplift_bert --track origin/deeplift_bert
  Switched to a new branch 'deeplift_bert'
  Branch 'deeplift_bert' set up to track remote branch 'deeplift_bert' from 'origin'.
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting sentencepiece==0.1.91
  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
Collecting sacremoses
  Downloading https://files.pyth

In [None]:
from contextlib import contextmanager

from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
from captum.attr import IntegratedGradients, DeepLift
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification

In [None]:
# use publicly available pretrained sequence classification model
PRETRAINED = "textattack/bert-base-uncased-imdb"

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# obviously positive text to classify and explain predictions on
TEXT = "I went to see this movie last night. I've got to say... it's great! This is an amazing movie!"

TARGET = 1 # positive

## Helpers

In [None]:
# taken from https://captum.ai/tutorials/Bert_SQUAD_Interpret
# modified for sequence classification

def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]

    # construct reference token ids
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=DEVICE), torch.tensor([ref_input_ids], device=DEVICE)


def construct_input_ref_token_type_pair(input_ids):
    token_type_ids = torch.zeros_like(input_ids, device=DEVICE)
    ref_token_type_ids = token_type_ids.clone()
    return token_type_ids, ref_token_type_ids


def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=DEVICE)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=DEVICE)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=DEVICE)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids


def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)


def construct_whole_bert_embeddings(input_ids, ref_input_ids, token_type_ids=None, ref_token_type_ids=None,
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids, token_type_ids=token_type_ids,
                                                                     position_ids=position_ids)
    # the reference embeddings must use the reference token type and position ids
    ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids, token_type_ids=ref_token_type_ids,
                                                                         position_ids=ref_position_ids)

    return input_embeddings, ref_input_embeddings


def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions
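The reference built by `construct_input_ref_pair` keeps `[CLS]` and `[SEP]` in place and replaces every text token with the reference token (here the pad token), while `construct_input_ref_pos_id_pair` gives the reference position 0 everywhere. A pure-Python sketch of the same layout on plain lists, assuming the bert-base-uncased ids 101 for `[CLS]`, 102 for `[SEP]` and 0 for `[PAD]`; the text token ids and helper names are hypothetical:

```python
def build_input_and_reference(text_ids, ref_token_id, sep_token_id, cls_token_id):
    # input: [CLS] text tokens... [SEP]
    input_ids = [cls_token_id] + list(text_ids) + [sep_token_id]
    # reference: same length, special tokens kept, every text token replaced
    ref_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]
    return input_ids, ref_ids


def position_ids_pair(seq_length):
    # input positions count up; the reference uses position 0 everywhere
    return list(range(seq_length)), [0] * seq_length


text_ids = [1045, 2253, 2000]  # hypothetical word-piece ids
inp, ref = build_input_and_reference(text_ids, ref_token_id=0, sep_token_id=102, cls_token_id=101)
pos, ref_pos = position_ids_pair(len(inp))
```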

In [None]:
# captum expects forward to return a tensor, so return only the first element (the logits) of the output tuple
# DeepLift needs the model itself (not just a forward function), so we can't wrap forward as the SQuAD tutorial does
class ModelWithUnpackedForward(BertForSequenceClassification):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs)
        out = out[0]  # no longer a tuple
        return out

    def orig_forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


In [None]:
# context manager to ensure removal of embedding layer (useful if errors)
@contextmanager
def managed_interpretable_embedding_layer(model: torch.nn.Module, embedding_layer_name: str = "embedding"):
    interpretable_embedding = configure_interpretable_embedding_layer(model=model,
                                                                      embedding_layer_name=embedding_layer_name)
    try:
        yield interpretable_embedding
    finally:
        remove_interpretable_embedding_layer(model=model, interpretable_emb=interpretable_embedding)
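The `try`/`finally` in the context manager is what guarantees that the interpretable embedding layer is removed even if attribution raises. A stdlib-only sketch of the same pattern, with hypothetical `configure`/`remove` stand-ins that only record calls instead of modifying a model:

```python
from contextlib import contextmanager

calls = []


def configure():
    # stand-in for configure_interpretable_embedding_layer
    calls.append("configure")
    return "interpretable-embedding"


def remove(handle):
    # stand-in for remove_interpretable_embedding_layer
    calls.append(f"remove {handle}")


@contextmanager
def managed_layer():
    handle = configure()
    try:
        yield handle
    finally:
        # runs both on normal exit and when the body raises
        remove(handle)


try:
    with managed_layer() as handle:
        raise RuntimeError("attribution failed")
except RuntimeError:
    pass
```

After the failing block, `calls` still ends with the removal, so a crashed attribution run does not leave the model with its embedding layer swapped out.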

## Load Model and Inputs

In [None]:
# load model (wrapped for compatibility with captum) and move it to the same device as the inputs
model_wrapper = ModelWithUnpackedForward.from_pretrained(PRETRAINED).to(DEVICE)

# load tokenizer
tokenizer = BertTokenizer.from_pretrained(PRETRAINED)


# construct inputs
input_ids, ref_input_ids = construct_input_ref_pair(TEXT,
                                                    ref_token_id=tokenizer.pad_token_id,
                                                    sep_token_id=tokenizer.sep_token_id,
                                                    cls_token_id=tokenizer.cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

# get embeddings
input_embeddings = model_wrapper.base_model.embeddings(input_ids, token_type_ids=token_type_ids,
                                                       position_ids=position_ids)
ref_input_embeddings = model_wrapper.base_model.embeddings(ref_input_ids,
                                                           token_type_ids=ref_token_type_ids,
                                                           position_ids=ref_position_ids)


# for display
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU
Using explicity instantiated GELU



## Sanity Checks

- check model output for text
- check model output for reference
- check attention output for text
- Integrated Gradients


The model output for the text should be much more positive than the output for the reference.

We expect attention and IG to skew towards `["great", "amazing"]`.

In [None]:
results, attentions = model_wrapper.orig_forward(input_ids=input_ids,
                                                 attention_mask=attention_mask,
                                                 token_type_ids=token_type_ids,
                                                 position_ids=position_ids,
                                                 output_attentions=True)

results = torch.squeeze(results[0])
print("Model output for text:", results)

Model output for text: tensor([-1.4382,  2.4062], grad_fn=<SqueezeBackward0>)


In [None]:
ref_results = model_wrapper.forward(input_ids=ref_input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=ref_token_type_ids,
                                    position_ids=ref_position_ids,
                                    output_attentions=False)
ref_results = torch.squeeze(ref_results)
print("Model output for ref:", ref_results)

Model output for ref: tensor([-0.9225,  0.7754], grad_fn=<SqueezeBackward0>)


This is more positive than expected, but still much lower than the score for the original text.
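The printed outputs are logits; a softmax turns them into class probabilities, which makes "more positive than expected" concrete. A pure-Python sketch using the two printed logit pairs (the `softmax` helper is written here for illustration):

```python
import math


def softmax(logits):
    # subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


TARGET = 1  # positive class, as in the notebook

text_logits = [-1.4382, 2.4062]  # "Model output for text" above
ref_logits = [-0.9225, 0.7754]   # "Model output for ref" above

p_text = softmax(text_logits)[TARGET]
p_ref = softmax(ref_logits)[TARGET]
```

This gives roughly 0.98 for the text and 0.85 for the reference, so the all-pad reference is still classified as fairly positive.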

In [None]:
def summarize_attention(attention, layers):
    assert attention.ndim == 5

    if layers is not None:
        # filter to layers
        attention = attention[layers]

    aggd_attention = attention.mean(dim=(0, 1, 2, 3))
    return aggd_attention / torch.norm(aggd_attention)


attn = summarize_attention(torch.stack(attentions), layers=range(8, 12)).detach().cpu().numpy()

# ordered words with attention values

ordered_attn = [(all_tokens[i], attn[i]) for i in np.argsort(attn)][::-1]
for item in ordered_attn:
  print(item)

('[SEP]', 0.8570724)
("'", 0.45161295)
('great', 0.122145966)
('.', 0.08723522)
('this', 0.06891007)
('[CLS]', 0.063985914)
('this', 0.05368692)
('amazing', 0.05098502)
('.', 0.047463313)
('!', 0.04521247)
('say', 0.04496376)
('went', 0.04355322)
('last', 0.041117698)
('i', 0.037362393)
('!', 0.036605727)
('it', 0.036225673)
('see', 0.03572213)
('night', 0.035285752)
('s', 0.033716016)
('an', 0.031193623)
('is', 0.030344974)
('movie', 0.03017318)
("'", 0.028420145)
('movie', 0.02681826)
('.', 0.024656259)
('.', 0.023904601)
('got', 0.02022178)
('ve', 0.019560454)
('to', 0.01792848)
('i', 0.014291423)
('to', 0.013351872)
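`summarize_attention` averages the stacked attention tensor (layers, batch, heads, query, key) over everything except the key axis, so each score is the average attention a token *receives*, then L2-normalises the result. A pure-Python sketch of that reduction on small hand-written attention maps; the function name and inputs are illustrative:

```python
import math


def attention_received(attention_maps):
    """Average attention each key position receives, L2-normalised.

    attention_maps: list of seq x seq matrices (rows = query positions,
    each row summing to 1), one per selected (layer, head) combination.
    """
    seq_len = len(attention_maps[0])
    totals = [0.0] * seq_len
    count = 0
    for matrix in attention_maps:
        for row in matrix:
            for key, weight in enumerate(row):
                totals[key] += weight
            count += 1
    means = [t / count for t in totals]
    norm = math.sqrt(sum(v * v for v in means))
    return [v / norm for v in means]


# two 2-token maps: one head attends only to token 0, the other is uniform
example = attention_received([[[1.0, 0.0], [1.0, 0.0]], [[0.5, 0.5], [0.5, 0.5]]])
```

Because every query row contributes, a token that many heads attend to (such as `[SEP]` above) can dominate the ranking regardless of its sentiment.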


"amazing" does not score particularly highly, but overall this output seems sane.

Use Integrated Gradients to generate attributions

In [None]:
ig = IntegratedGradients(model_wrapper)

with managed_interpretable_embedding_layer(model_wrapper.base_model, 'embeddings') as interpretable_embedding:

  ig_attributions, ig_delta = ig.attribute(inputs=input_embeddings,
                                          baselines=ref_input_embeddings,
                                          additional_forward_args=(
                                              attention_mask, token_type_ids, position_ids),
                                          return_convergence_delta=True,
                                          target=TARGET,
                                          n_steps=100
                                          )

ig_summarized_attributions = summarize_attributions(attributions=ig_attributions).detach().cpu().numpy()
print("ig delta:", ig_delta)
print()

ordered_ig = [(all_tokens[i], ig_summarized_attributions[i]) for i in np.argsort(ig_summarized_attributions)][::-1]
for item in ordered_ig:
  print(item)



ig delta: tensor([-0.1639], dtype=torch.float64)

('amazing', 0.5622139748741481)
('great', 0.30483500764809546)
('it', 0.25907722666045846)
('last', 0.2585275642131233)
('see', 0.2230709011144637)
('night', 0.21436294495263788)
('this', 0.1574227807921535)
('s', 0.14397478843900857)
('.', 0.11823730135804614)
("'", 0.10387450558353814)
('.', 0.10232595039317278)
('got', 0.10120857263529436)
('i', 0.07585758998569704)
('this', 0.06509549674756436)
('is', 0.050789809806672984)
('!', 0.0186688158604917)
('.', 0.014072702069863587)
('to', 0.012939010866128678)
('[CLS]', 0.0)
('ve', -0.007128044029508743)
("'", -0.00940010498423295)
('an', -0.041612562201641834)
('.', -0.05129002584882091)
('to', -0.06675189200845233)
('!', -0.06876116121220933)
('[SEP]', -0.07518066348057816)
('say', -0.1313145861355041)
('i', -0.16918752714294683)
('movie', -0.22656130530256544)
('movie', -0.23318458257571606)
('went', -0.2934810849502168)


IG is working well: `amazing` and `great` receive the two highest attributions.
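IG's small non-zero delta (-0.1639 with `n_steps=100`) comes from approximating the path integral numerically. A pure-Python sketch on a toy function f(x1, x2) = x1 * x2 with a zero baseline shows the delta shrinking as the number of steps grows. It uses a left Riemann sum for simplicity; captum's default approximation method is Gauss-Legendre, so this is not its exact scheme:

```python
def f(x1, x2):
    return x1 * x2


def grad_f(x1, x2):
    # analytic gradient of f
    return (x2, x1)


def integrated_gradients(x, baseline, n_steps):
    """Left Riemann-sum approximation of integrated gradients."""
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = k / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(*point)):
            avg_grad[i] += g / n_steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


def ig_delta(x, baseline, n_steps):
    # convergence delta: sum of attributions minus the change in output
    attributions = integrated_gradients(x, baseline, n_steps)
    return sum(attributions) - (f(*x) - f(*baseline))


x, baseline = (2.0, 3.0), (0.0, 0.0)
deltas = {n: ig_delta(x, baseline, n) for n in (10, 100, 1000)}
```

For this function the delta is exactly -6/n, so it falls tenfold with each tenfold increase in steps; DeepLift has no such step count, which is why a large DeepLift delta points at the backward rules rather than at approximation error.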

## DeepLift



In [None]:
dl = DeepLift(model_wrapper)

with managed_interpretable_embedding_layer(model_wrapper.base_model, 'embeddings') as interpretable_embedding:

  dl_attributions, dl_delta = dl.attribute(inputs=input_embeddings,
                                          baselines=ref_input_embeddings,
                                          additional_forward_args=(
                                              attention_mask, token_type_ids, position_ids),
                                          return_convergence_delta=True,
                                          target=TARGET)


dl_summarized_attributions = summarize_attributions(attributions=dl_attributions).detach().cpu().numpy()

print("dl delta:", dl_delta)
print()
# without the softmax change (transformers master): dl delta: tensor([-1.9306])

ordered_dl = [(all_tokens[i], dl_summarized_attributions[i]) for i in np.argsort(dl_summarized_attributions)][::-1]

for item in ordered_dl:
  print(item)



dl delta: tensor([-753467.2500])

('say', 0.43542343)
('.', 0.29138362)
('last', 0.2016859)
('to', 0.11776617)
('amazing', 0.10944588)
('.', 0.045121446)
('movie', 0.017274547)
('.', 0.0061797984)
('!', 0.0045581986)
('s', 0.0013995223)
('[CLS]', 0.0)
('.', -0.0022967476)
('got', -0.0023316701)
('movie', -0.004586594)
('ve', -0.016162675)
("'", -0.025916887)
('to', -0.026769772)
('i', -0.03517566)
('is', -0.041742746)
("'", -0.041766968)
('great', -0.052930668)
('this', -0.055721063)
('see', -0.08397544)
('night', -0.11885013)
('this', -0.12899446)
('an', -0.14705154)
('[SEP]', -0.18361235)
('went', -0.21900855)
('!', -0.22808671)
('i', -0.23096457)
('it', -0.63070023)


Very high delta!

A delta of -1.9306 is obtained without the softmax change (this can be verified by installing the transformers repo at the master branch).

We also see some unexpected, strongly negative attributions: `it` scores -0.63, and `great`, which IG ranks second, is slightly negative.
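One way to put the two deltas on the same scale is to divide each by the change in the target logit between input and reference, using the logits printed in the sanity checks. This assumes the baseline forward pass inside the attribution calls matches the reference forward pass above; `relative_delta` is a hypothetical helper:

```python
def relative_delta(delta, target_logit_input, target_logit_ref):
    """|delta| as a fraction of the output change the attributions should explain."""
    return abs(delta) / abs(target_logit_input - target_logit_ref)


# target-class (TARGET = 1) logits printed in the sanity checks above
logit_text, logit_ref = 2.4062, 0.7754

ig_rel = relative_delta(-0.1639, logit_text, logit_ref)
dl_rel = relative_delta(-753467.25, logit_text, logit_ref)
```

IG's delta is about 10% of the 1.6308 change in the target logit, whereas DeepLift's is several hundred thousand times larger than the change it is meant to explain.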

<hr/>
Quick check that the change to captum was installed correctly after all: we should see `nn.GELU` among the `SUPPORTED_NON_LINEAR` dict keys.

In [None]:
import captum
captum.__path__

['/usr/local/lib/python3.6/dist-packages/captum']

In [None]:
! tail -n 20 /usr/local/lib/python3.6/dist-packages/captum/attr/_core/deep_lift.py

    output, output_ref = outputs.chunk(2)
    delta_in = input - input_ref
    delta_out = output - output_ref

    return torch.cat(2 * [delta_in]), torch.cat(2 * [delta_out])


SUPPORTED_NON_LINEAR = {
    nn.ReLU: nonlinear,
    nn.ELU: nonlinear,
    nn.LeakyReLU: nonlinear,
    nn.GELU: nonlinear,
    nn.Sigmoid: nonlinear,
    nn.Tanh: nonlinear,
    nn.Softplus: nonlinear,
    nn.MaxPool1d: maxpool1d,
    nn.MaxPool2d: maxpool2d,
    nn.MaxPool3d: maxpool3d,
    nn.Softmax: softmax,
}
