Captum for BERT #150
@felicitywang, thank you for the question. This is something that has high priority on the list. Yes, it will work in combination with downstream tasks. I have to look closer into this, but we will need to compute the gradients of any output that we choose with respect to those pre-trained embedding vectors. I'll hopefully have a tutorial out for this soon. We have another unmerged tutorial on seq2seq: |
Hi,

```python
import torch
import torch.nn as nn

from transformers import BertTokenizer
from transformers import BertForSequenceClassification, BertConfig

from captum.attr import IntegratedGradients
from captum.attr import InterpretableEmbeddingBase, TokenReferenceBase
from captum.attr import visualization
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')


# We need to split the forward pass into two parts:
# 1) embeddings computation
# 2) classification

def compute_bert_outputs(model_bert, embedding_output, attention_mask=None, head_mask=None):
    if attention_mask is None:
        attention_mask = torch.ones(embedding_output.shape[0], embedding_output.shape[1]).to(embedding_output)

    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    extended_attention_mask = extended_attention_mask.to(dtype=next(model_bert.parameters()).dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(model_bert.config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        head_mask = head_mask.to(dtype=next(model_bert.parameters()).dtype)  # switch to float if needed + fp16 compatibility
    else:
        head_mask = [None] * model_bert.config.num_hidden_layers

    encoder_outputs = model_bert.encoder(embedding_output,
                                         extended_attention_mask,
                                         head_mask=head_mask)
    sequence_output = encoder_outputs[0]
    pooled_output = model_bert.pooler(sequence_output)
    outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
    return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class BertModelWrapper(nn.Module):

    def __init__(self, model):
        super(BertModelWrapper, self).__init__()
        self.model = model

    def forward(self, embeddings):
        outputs = compute_bert_outputs(self.model.bert, embeddings)
        pooled_output = outputs[1]
        pooled_output = self.model.dropout(pooled_output)
        logits = self.model.classifier(pooled_output)
        return torch.softmax(logits, dim=1)[:, 1].unsqueeze(1)


bert_model_wrapper = BertModelWrapper(model)
ig = IntegratedGradients(bert_model_wrapper)

# accumulate a couple of samples in this array for visualization purposes
vis_data_records_ig = []


def interpret_sentence(model_wrapper, sentence, label=1):
    model_wrapper.eval()
    model_wrapper.zero_grad()

    input_ids = torch.tensor([tokenizer.encode(sentence, add_special_tokens=True)])
    input_embedding = model_wrapper.model.bert.embeddings(input_ids)

    # predict
    pred = model_wrapper(input_embedding).item()
    pred_ind = round(pred)

    # compute attributions and approximation delta using integrated gradients
    attributions_ig, delta = ig.attribute(input_embedding, n_steps=500, return_convergence_delta=True)

    print('pred: ', pred_ind, '(', '%.2f' % pred, ')', ', delta: ', abs(delta))

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].numpy().tolist())
    add_attributions_to_visualizer(attributions_ig, tokens, pred, pred_ind, label, delta, vis_data_records_ig)


def add_attributions_to_visualizer(attributions, tokens, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().numpy()

    # store a couple of samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
        attributions,
        pred,
        pred_ind,
        label,
        "label",
        attributions.sum(),
        tokens[:len(attributions)],
        delta))


interpret_sentence(bert_model_wrapper, sentence="text to classify", label=0)
visualization.visualize_text(vis_data_records_ig)
```

@NarineK it would be helpful if you could comment on whether the code is correct about the association between tokens and attributions. Thanks :) HTH |
Looks great, @vfdev-5 ! Thank you!
|
Thanks @vfdev-5 for the example and @NarineK for the feedback. A quick question for the first point in @NarineK 's post: in the tutorial, the reference is using the padding token. However, what would the reference token be in BERT related models? Do you mind to provide some instructions on how we could construct baseline reference using BERT? Thanks! |
@mralexisw a reference can probably be defined as

```python
input_ids = tokenizer.encode(sentence, add_special_tokens=True)
t_input_ids = torch.tensor([input_ids, ]).to(device)
t_ref_input_ids = t_input_ids.clone()
t_ref_input_ids[0, 1:-1] = 0
```

which should be something like
|
Thank you @vfdev-5 ! Yes, that's right. @mralexisw, we can choose reference/baseline tokens, for example |
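As a minimal sketch of one such choice (assuming the pad token as the reference token, and reusing the `tokenizer` and `input_ids` tensor of shape `[1, seq_len]` from the snippet above), captum's `TokenReferenceBase` can build a baseline sequence; one could also keep [CLS] and [SEP] in place and only replace the inner tokens, as in @vfdev-5's snippet:

```python
from captum.attr import TokenReferenceBase

# assumes `tokenizer` and a [1, seq_len] `input_ids` tensor as in the example above
seq_len = input_ids.shape[1]
token_reference = TokenReferenceBase(reference_token_idx=tokenizer.pad_token_id)
# a [1, seq_len] tensor filled with the pad token id
ref_input_ids = token_reference.generate_reference(seq_len, device=input_ids.device).unsqueeze(0)
```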
Nice! Sorry for the delay. I'll have clean tutorials out by next week at the latest. |
@NarineK Thanks! If there is no reference embedding, would that be equivalent to a vanilla gradient? |
@mralexisw , it will still compute the integral of gradients along the path from 0 to the given input, but the attribution might be a little off. It is known that the attribution depends on the choice of baseline, and the more carefully we choose it, the better the results we get. In the case of saliency, it is taking the gradient at the given input point. It won't be the same, but you can easily compare by calling: Also, in the example above are the weights for the |
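Since the exact call wasn't included above, here is a rough sketch of how such a comparison could look, reusing the `bert_model_wrapper`, `ig`, and `input_embedding` objects from the earlier example:

```python
from captum.attr import Saliency

# plain (vanilla) gradients at the given input point
saliency = Saliency(bert_model_wrapper)
attributions_sal = saliency.attribute(input_embedding)

# integrated gradients with the default zero baseline, for comparison
attributions_ig, delta = ig.attribute(input_embedding, n_steps=500,
                                      return_convergence_delta=True)
```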
@NarineK Thanks for pointing out the @vfdev-5 should have a better idea on |
@vfdev-5, @mralexisw, @felicitywang, we've published a tutorial here: https://github.com/pytorch/captum/blob/master/tutorials/Bert_SQUAD_Interpret.ipynb To be more flexible when working with multiple sub-embedding layers and to be able to interpret all of them simultaneously, I still pre-compute the embedding layers here, similar to the previous tutorials, but we'll ultimately also have a version that doesn't require doing that and allows attributing to |
@NarineK Thanks for the tutorial! Very easy to follow. I was able to easily replicate the process for a classification task using BERT (BertForSequenceClassification) with a few minor changes. One small issue I ran into is with the forward function of InterpretableEmbeddingBase (currently using an install from master). I believe it's because BertModel passes all arguments to the embeddings as keyword arguments (see here), whereas InterpretableEmbeddingBase's forward function expects one positional argument. I'm sure there's a cleaner solution, but for now I had to change the function as below to get it working.

```python
def forward(self, *inputs, **kwargs):
    """
    The forward function of a wrapper embedding layer that takes and returns
    an embedding layer. It allows embeddings to be created outside of the model
    and passes them seamlessly to the preceding layers of the model.

    Args:
        input (tensor): Embedding tensor generated using the `self.embedding`
                layer using `other_inputs` and `kwargs` if necessary.
        *other_inputs (Any, optional): A sequence of additional inputs that the
                forward function takes. Since forward functions can take any
                type and number of arguments, this will ensure that we can
                execute the forward pass using the interpretable embedding layer.
        **kwargs (Any, optional): Similar to `other_inputs`, we want to make sure
                that our forward pass supports an arbitrary number and type of
                key-value arguments.

    Returns:
        tensor:
            Returns an output tensor which is the same as the input tensor.
            It passes embedding tensors to lower layers without any
            modifications.
    """
    return kwargs["inputs_embeds"]
```

Separately, an issue I keep running into is GPU memory usage. I'm on a single Nvidia Tesla V100 (16GB), which has no problems fine-tuning the model (using a maximum sequence length of 128 and batch size of 32) and similarly has no issues with inference. For integrated gradients, once I pass in a larger text sample (e.g., 30-40 tokens), I immediately run into memory issues. The same happens if I run it on shorter text multiple times. Do you have any pointers to what's driving this increased memory usage and ideas on how to optimize? I'm hoping to run it on the entire training set, which seems infeasible right now. |
@jchoi92 thank you very much for the feedback! With respect to memory issues:
Let me know if this helps. Thank you! |
Hi, while trying to reproduce the SQuAD BERT tutorial I get the following error when running: start_scores, end_scores = predict(input_embeddings, |
Thank you for trying out the tutorial, @armheb ! Yeah, I think they made some changes in Bert and there were some inconsistencies. I have a new PR (#222) open where I made some updates. Do you mind trying the version in PR #222? Also, how did you fine-tune the Bert model? Did you use |
Thanks for your quick response, I used bert-large-uncased-whole-word-masking-finetuned-squad, which they have fine-tuned on SQuAD. I'll try the new PR and let you know the results. |
Thank you so much, it's fixed now. |
Thanks, @NarineK ! |
On a separate note, I found that captum is super (GPU) memory-intensive. I was not able to run the code using a 12GB mem GPU (the exact same code works on CPU and 32GB mem GPU). It would be super helpful to specify the minimum GPU memory required for large models like BERT/ResNet/etc. |
Thank you for the feedback @mralexisw ! It depends on which algorithms and what parametrization you use. The tutorial that we have on Bert runs on CPU in under 2-3 minutes. In general, IG can be memory intensive depending on the number of integral approximation steps. |
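For reference, a sketch of the two knobs that usually help most with memory (assuming the `ig` and `input_embedding` objects from the example above): fewer approximation steps, and `internal_batch_size`, which evaluates the expanded baseline-to-input batch in chunks rather than all at once:

```python
attributions_ig, delta = ig.attribute(
    input_embedding,
    n_steps=50,               # fewer integral approximation steps
    internal_batch_size=4,    # evaluate the expanded batch in chunks of 4
    return_convergence_delta=True,
)
```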
@NarineK Another couple of quick questions for the BERT visualization part:
|
e.g. if we predict that something in the image is a dog with a high probability then Does it make sense? |
That makes sense, @NarineK . A bit further on point 1: why should we use the Frobenius norm for the attribution scores? For point 2, a renaming/docstring would be super helpful. One more thing I noticed for the tutorial: I added
It makes sense for the first two. However, do you have any idea why at some point the shape is Thanks again for your help! |
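For reference, these are the shapes implied by the earlier snippet (a sketch only, assuming a `[1, seq_len]` `input_ids` tensor and a hidden size of 768; not a diagnosis of the specific mismatch mentioned above):

```python
input_embedding = bert_model_wrapper.model.bert.embeddings(input_ids)   # [1, seq_len, 768]
attributions_ig, delta = ig.attribute(input_embedding, n_steps=500,
                                      return_convergence_delta=True)    # [1, seq_len, 768]
per_token = attributions_ig.sum(dim=2).squeeze(0)                       # [seq_len]
per_token = per_token / torch.norm(per_token)                           # [seq_len], normalized
```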
@mralexisw, that's a good point!
|
Do you guys still want to keep this open, or can we close it? |
@NarineK It's OK with me to close the issue. |
Awesome! Thank you! Closing for now! Feel free to open a new issue if you have more questions. |
Can you please provide a full example of how to use Captum with BertForSequenceClassification? I've tried to run the code provided here. It runs without errors, but the results don't make sense to me. |
I'm also interested in doing that. Would you mind sharing a bit what your result looks like? |
If I'm not mistaken, the arguments for VisualizationDataRecord are (in order)
Regarding label being included in the colored word importance, what do you mean by this? Are you referring to [CLS] or [SEP]? |
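For what it's worth, based on the call in the example earlier in this thread, the positional arguments appear to map as annotated below (the exact parameter names are an assumption and may differ between captum versions):

```python
visualization.VisualizationDataRecord(
    attributions,                 # per-token attribution scores
    pred,                         # predicted probability
    pred_ind,                     # predicted class
    label,                        # true class
    "label",                      # class the attributions are computed for
    attributions.sum(),           # total attribution score
    tokens[:len(attributions)],   # raw input tokens shown in the visualization
    delta)                        # convergence delta
```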
Thanks @heytitle for the clarification. But when I do the visualization I find the word "bullying" between [SEP] tokens, i.e. [SEP] bullying [SEP], and I don't understand why. |
I see. I didn't pay attention to Could it be that |
I shared the Jupyter notebook on Google Colab: https://colab.research.google.com/drive/1wC6Z5eCs4SnZo6RFlTGUvYlIcuQa82WK?usp=sharing Thanks |
I don't see any content in the notebook. Is there anything I should do in order to see the content? |
There is something wrong with my Google Drive. Can you try this link: https://colab.research.google.com/drive/1gzOOKplSCAVTXagUwfv68ivwebYjKJFC?usp=sharing |
@vfdev-5 I tried your example, but when I used the baselines you suggested above,

```python
attributions_ig, delta = ig.attribute(input_embedding, baselines=t_ref_input_ids, n_steps=500, return_convergence_delta=True)
```

I got this error: "RuntimeError: The size of tensor a (768) must match the size of tensor b (6) at non-singleton dimension 2" |
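That error is consistent with passing token ids (shape `[1, seq_len]`) as `baselines` while the inputs being attributed are embeddings (shape `[1, seq_len, 768]`); inputs and baselines must have matching shapes. A sketch of one way to reconcile them (an assumption, not a verified fix) is to run the reference ids through the same embedding layer first:

```python
# embed both the input and the reference so that their shapes match
input_embedding = bert_model_wrapper.model.bert.embeddings(t_input_ids)    # [1, seq_len, 768]
ref_embedding = bert_model_wrapper.model.bert.embeddings(t_ref_input_ids)  # [1, seq_len, 768]

attributions_ig, delta = ig.attribute(input_embedding,
                                      baselines=ref_embedding,
                                      n_steps=500,
                                      return_convergence_delta=True)
```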
Hi all, I found a BERT binary classification example in this forum. I have not tried it yet, but it looks promising. |
@efatmae did you find a solution? I too am having the error: |
Hi,
Thanks for the great work. The LSTM tutorial looks very nice.
Are there any suggestions on how to use Captum for Transformer-based / BERT-like pre-trained contextualized word embeddings? If I want to see the attribution of each token in the word embedding layer, would I also need the FFN layer used for fine-tuning on downstream tasks in order to get the gradients? The current code is implemented with torchtext; I would really appreciate it if you could give some hints on how to integrate it with BERT models (e.g. huggingface/transformers).
Thank you.
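The SQuAD tutorial linked above covers this; as a rough sketch of the general approach (assuming a HuggingFace BERT model), captum can wrap the word-embedding sub-layer so that attributions are computed with respect to precomputed embeddings:

```python
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

# wrap BERT's word-embedding sub-layer so embeddings can be precomputed outside the model
interpretable_emb = configure_interpretable_embedding_layer(
    model, 'bert.embeddings.word_embeddings')

# map token ids to embeddings, run attribution on them, then restore the original model
input_emb = interpretable_emb.indices_to_embeddings(input_ids)
# ... attribute with IntegratedGradients on a forward function that accepts these embeddings ...
remove_interpretable_embedding_layer(model, interpretable_emb)
```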