configure_interpretable_embedding_layer changes the final predictions of my Bert model #770

Closed · gonconist opened this issue Sep 20, 2021 · 1 comment

gonconist commented Sep 20, 2021

Hi,

I am currently applying Saliency to interpret the results of my custom Bert model, which outputs a single non-normalised value per example (no softmax or sigmoid applied) for a binary classification task.

While visualising the attributions, I observed that the model's predictions (with Saliency) do not match the original ones (without Saliency), i.e., the logits obtained on the respective forward passes are different.

I found out that this behaviour is caused by configure_interpretable_embedding_layer, which I use to create the input for computing the attributions.

More specifically, it occurs when I call input_ = self.interpretable_embedding.indices_to_embeddings(input_, att_mask, segment_ids) inside the interpret_sentence function.

Is this behaviour expected? I would expect the predictions to be the same regardless of the explanation method used.

Any help would be appreciated! Thanks!
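
To illustrate, here is a rough sketch of the check I am doing (variable names follow the code below; batch dimensions simplified):

# Before wrapping the embedding layer: logits from the ordinary forward
# pass on token ids.
logits_ids = model(input_ids.unsqueeze(0), attention_mask=att_mask,
                   token_type_ids=segment_ids)

interpretable_embedding = configure_interpretable_embedding_layer(
    model, 'transformer.bert.embeddings')

# After wrapping: logits from the precomputed embeddings.
embeds = interpretable_embedding.indices_to_embeddings(
    input_ids.unsqueeze(0), att_mask, segment_ids)
logits_emb = model(embeds, attention_mask=att_mask,
                   token_type_ids=segment_ids)

print(torch.allclose(logits_ids, logits_emb))  # False in my case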

Below is my code for reference (currently on transformers==3.5.1):

from pathlib import Path
from transformers import BertTokenizer
from BertForAnalogy import BertForAnalogy, WiCTSVDatasetEncodingOptions, WiCTSVDataset, WiCTSVDataLoader, read_wic_tsv

import argparse
import random
import numpy as np
import torch
from captum.attr import (Saliency, configure_interpretable_embedding_layer,
                         remove_interpretable_embedding_layer, visualization)
from tqdm import tqdm

def remove_padding(token_list):
    return [token for token in token_list if token != '[PAD]']

class BertWrapper(torch.nn.Module):
    def __init__(self, transformer):
        super(BertWrapper, self).__init__()
        self.transformer = transformer

    def forward(self, input,
                target_start_len=None, descr_start_len=None, def_start_len=None, hyps_start_len=None,
                analogy=None, attention_mask=None, token_type_ids=None):
        out = self.transformer(input,
                               target_start_len=target_start_len,
                               descr_start_len=descr_start_len,
                               def_start_len=def_start_len,
                               hyps_start_len=hyps_start_len,
                               analogy=analogy,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids)[0]  # ['logits']
        return out

def get_model(model_args, device):
    transformer_model = BertForAnalogy.from_pretrained(
        model_args.model_path, permute=model_args.permute).to(device)

    return BertWrapper(transformer_model)

class GradientBasedVisualizer:
    """
    Source: https://captum.ai/tutorials/IMDB_TorchText_Interpret
    """

    def __init__(self, tokenizer, interpretable_embedding, ablator):
        self.vis_data_records_ig = []
        self.label_idx = {0: 'negative', 1: 'positive'}
        self.tokenizer = tokenizer
        self.interpretable_embedding = interpretable_embedding
        self.ablator = ablator

    def add_attributions_to_visualizer(self, attributions, text, pred, pred_ind,
                                       label, delta, target):
        attributions = attributions.sum(dim=2).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        attributions = attributions.cpu().detach().numpy()

        # storing a couple of samples in an array for visualization purposes
        self.vis_data_records_ig.append(visualization.VisualizationDataRecord(
            attributions,
            pred,
            self.label_idx[pred_ind],
            self.label_idx[label],
            self.label_idx[target],
            attributions.sum(),
            text,
            delta))

    def interpret_sentence(self, model, sentence, target=None):
        model.zero_grad()
        target_start_len = sentence[1].unsqueeze(0)
        descr_start_len = sentence[2].unsqueeze(0)
        def_start_len = sentence[3].unsqueeze(0)
        hyps_start_len = sentence[4].unsqueeze(0)
        analogy = sentence[5]
        att_mask = sentence[7].unsqueeze(0)
        segment_ids = sentence[8].unsqueeze(0)

        label = sentence[6].detach().cpu().numpy().tolist()
        text = remove_padding(self.tokenizer.convert_ids_to_tokens(
            sentence[0].squeeze().detach().cpu().numpy().tolist()))

        input_ = sentence[0].unsqueeze(0)
        input_ = self.interpretable_embedding.indices_to_embeddings(input_, att_mask, segment_ids)

        # predict
        pred = model(input_,
                     target_start_len, descr_start_len, def_start_len, hyps_start_len, analogy,
                     att_mask, segment_ids).detach().cpu().numpy().tolist()[0][0]
        pred_ind = 1 if pred > 0 else 0

        # compute attributions with saliency (gradients w.r.t. the input embeddings)
        attributions_ig = self.ablator.attribute(input_,
                                                 additional_forward_args=(
                                                     target_start_len,
                                                     descr_start_len,
                                                     def_start_len,
                                                     hyps_start_len,
                                                     analogy,
                                                     att_mask,
                                                     segment_ids),
                                                 target=target, abs=False)

        self.add_attributions_to_visualizer(attributions_ig,
                                            text,
                                            pred,
                                            pred_ind,
                                            label,
                                            None,
                                            target=1)

    def visualize(self):
        img = visualization.visualize_text(self.vis_data_records_ig)
        self.vis_data_records_ig = []
        return img

if __name__ == '__main__':

    args = ALL_ARGUMENTS
    model_args = argparse.Namespace(**args)

    # Set the seed value all over the place to make this reproducible
    seed_val = model_args.seed
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Dataset
    contexts, target_ses, hypernyms, definitions, labels = read_wic_tsv(Path('data/Development'))
    dev_ds = WiCTSVDataset(contexts, target_ses, hypernyms, definitions,
                           tokenizer=tokenizer,
                           focus_token='$',
                           labels=labels,
                           encoding_type=model_args.encoding)
    dev_dataloader = WiCTSVDataLoader(dev_ds, 'Development', batch_size=8)

    # Device configuration
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # model.cuda()
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0))
    else:
        print('No GPU available, using CPU instead.')
        device = torch.device("cpu")

    model = get_model(model_args, device)
    model.eval()
    model.zero_grad()

    embedding_name = 'transformer.bert.embeddings'
    interpretable_embedding = configure_interpretable_embedding_layer(model, embedding_name)

    for batch_idx, features in enumerate(tqdm(dev_dataloader)):
        input_ids = features["input_ids"].to(device)
        input_mask = features["attention_mask"].to(device)
        segment_ids = features["token_type_ids"].to(device)
        label_ids = features["labels"].to(device)
        target_start_len = features["target_start_len"].to(device)
        descr_start_len = features["descr_start_len"].to(device)
        def_start_len = features["def_start_len"].to(device)
        hyps_start_len = features["hyps_start_len"].to(device)

        ablator = Saliency(model)
        visualizer = GradientBasedVisualizer(tokenizer, interpretable_embedding, ablator)

        for i in range(len(label_ids)):
            visualizer.interpret_sentence(model, (input_ids[i],
                                                  target_start_len[i],
                                                  descr_start_len[i],
                                                  def_start_len[i],
                                                  hyps_start_len[i],
                                                  model_args.analogy,
                                                  label_ids[i],
                                                  input_mask[i],
                                                  segment_ids[i]))

        img = visualizer.visualize()
gonconist (Author) commented

I managed to fix my problem by replacing

input_ = model.transformer.bert.embeddings(input_)

with

input_ = model.transformer.bert.embeddings.word_embeddings(input_)

Essentially, the first call runs the whole BertEmbeddings module, which applies a series of transformations to input_ (adding position and token_type embeddings, then LayerNorm and dropout), whereas the second only performs the raw word-embedding lookup (a plain nn.Embedding) without altering it further.
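
For anyone hitting the same issue: I believe the equivalent fix on the Captum side is to point configure_interpretable_embedding_layer at the word-embedding sublayer instead of the whole embeddings module, as Captum's own BERT tutorials do, so that position/token_type embeddings, LayerNorm and dropout still run inside the model's forward pass. A rough sketch (the layer name assumes the BertWrapper above):

# Wrap only the raw word-embedding lookup, not the whole BertEmbeddings module.
interpretable_embedding = configure_interpretable_embedding_layer(
    model, 'transformer.bert.embeddings.word_embeddings')

# indices_to_embeddings is now a plain nn.Embedding lookup; the rest of
# BertEmbeddings (positions, token types, LayerNorm, dropout) still runs
# inside the model, so the logits match the original forward pass.
input_ = interpretable_embedding.indices_to_embeddings(input_ids.unsqueeze(0))
pred = model(input_, attention_mask=att_mask, token_type_ids=segment_ids)

# Restore the original embedding layer when finished.
remove_interpretable_embedding_layer(model, interpretable_embedding)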
