The fix is to replace input_ = model.transformer.bert.embeddings(input_) with input_ = model.transformer.bert.embeddings.word_embeddings(input_).
Essentially, the first call runs the whole BertEmbeddings module, which applies a series of transformations to input_ (adding position and token_type embeddings, then LayerNorm and dropout), whereas the second performs only the nn.Embedding lookup without altering the result further.
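To see the difference concretely, here is a minimal sketch using a toy stand-in for BertEmbeddings (not the real transformers module; the vocabulary and dimension sizes are arbitrary). The full module and the bare word-embedding lookup produce different tensors for the same token ids:

```python
import torch
from torch import nn

# Toy stand-in for BertEmbeddings: word + position embeddings followed by
# LayerNorm, mimicking the transformations the answer above describes.
class ToyEmbeddings(nn.Module):
    def __init__(self, vocab=10, dim=4, max_pos=8):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab, dim)
        self.position_embeddings = nn.Embedding(max_pos, dim)
        self.LayerNorm = nn.LayerNorm(dim)

    def forward(self, input_ids):
        pos = torch.arange(input_ids.size(1)).unsqueeze(0)
        return self.LayerNorm(self.word_embeddings(input_ids)
                              + self.position_embeddings(pos))

torch.manual_seed(0)
emb = ToyEmbeddings().eval()
ids = torch.tensor([[1, 2, 3]])

full = emb(ids)                        # whole module: positions + LayerNorm applied
words_only = emb.word_embeddings(ids)  # raw lookup, no further transformation

print(torch.allclose(full, words_only))  # False: the two tensors differ
```

Feeding `full` back through the model (which runs BertEmbeddings again) therefore double-applies the position embeddings and LayerNorm, which is why the logits diverge.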
Hi,
I am currently applying Saliency to interpret the results of my custom Bert model, which outputs a single non-normalised value per example (no softmax or sigmoid applied) for a binary classification task.
While visualising the attributions, I observed that the model's predictions (with Saliency) do not match the original ones (without Saliency), i.e., the logits obtained on the respective forward passes are different.
I found out that this behaviour is related to configure_interpretable_embedding_layer, which I use to create the input for computing the attributions.
More specifically, it happens when I call input_ = self.interpretable_embedding.indices_to_embeddings(input_, att_mask, segment_ids) inside the interpret_sentence function.
Is this behaviour expected? I assume the predictions should have been the same regardless of the explanation method used.
Any help would be appreciated! Thanks!
Below is my code for reference (currently on transformers==3.5.1):
from pathlib import Path
from transformers import BertTokenizer
from BertForAnalogy import BertForAnalogy, WiCTSVDatasetEncodingOptions, WiCTSVDataset, WiCTSVDataLoader, read_wic_tsv
import argparse
import random
import numpy as np
import torch
from captum.attr import (Saliency, configure_interpretable_embedding_layer,
                         remove_interpretable_embedding_layer, visualization)
from tqdm import tqdm
def remove_padding(token_list):
    # Drop [PAD] tokens before visualising attributions.
    return [token for token in token_list if token != '[PAD]']
class BertWrapper(torch.nn.Module):
    def __init__(self, transformer):
        super(BertWrapper, self).__init__()
        self.transformer = transformer
def get_model(model_args, device):
    transformer_model = BertForAnalogy.from_pretrained(
        model_args.model_path, permute=model_args.permute).to(device)
class GradientBasedVisualizer:
    """
    Source: https://captum.ai/tutorials/IMDB_TorchText_Interpret
    """
if __name__ == '__main__':