# Captum Insights for Visual Question Answering with Added Evaluation of Models

This notebook provides a simple example for the [Captum Insights API](https://captum.ai/docs/captum_insights), which is an easy-to-use API built on top of Captum that provides a visualization widget.


As with the [model interpretation tutorial on VQA](https://captum.ai/tutorials/Multimodal_VQA_Interpret), you will need the following installed on your machine:

- Python Packages: torchvision, PIL, and matplotlib
- pytorch-vqa: https://github.com/Cyanogenoid/pytorch-vqa
- pytorch-resnet: https://github.com/Cyanogenoid/pytorch-resnet
- A pretrained pytorch-vqa model, which can be obtained from: https://github.com/Cyanogenoid/pytorch-vqa/releases/download/v1.0/2017-08-04_00.55.19.pth
- A CUDA environment created from `environment.yml`, so that all dependencies and versions are correct and working

Please modify the below section for your specific installation paths:

In [1]:
import sys, os

# Replace the placeholder strings with the associated 
# path for the root of pytorch-vqa and pytorch-resnet respectively
PYTORCH_VQA_DIR = os.path.realpath("C:\\Users\\saroa\\OneDrive\\Documentos\\XAI\\pytorch-vqa")
PYTORCH_RESNET_DIR = os.path.realpath("C:\\Users\\saroa\\OneDrive\\Documentos\\XAI\\pytorch-resnet")

# Please modify this path to where it is located on your machine
# you can download this model from: 
# https://github.com/Cyanogenoid/pytorch-vqa/releases/download/v1.0/2017-08-04_00.55.19.pth
VQA_MODEL_PATH = "models/2017-08-04_00.55.19.pth"

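assert os.path.exists(PYTORCH_VQA_DIR), "PYTORCH_VQA_DIR does not exist"
assert os.path.exists(PYTORCH_RESNET_DIR), "PYTORCH_RESNET_DIR does not exist"
assert os.path.exists(VQA_MODEL_PATH), "VQA_MODEL_PATH does not exist"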

sys.path.append(PYTORCH_VQA_DIR)
sys.path.append(PYTORCH_RESNET_DIR)

Now, we will import the necessary modules to run the code. Please make sure you have the [prerequisites to run Captum](https://captum.ai/docs/getting_started), along with the prerequisites to run this tutorial (as described in the first section).

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

try:
    import resnet  # from pytorch-resnet
except ImportError:
    print("please provide a valid path to pytorch-resnet")

try:
    from model import Net, apply_attention, tile_2d_over_nd  # from pytorch-vqa
    from utils import get_transform  # from pytorch-vqa
except ImportError:
    print("please provide a valid path to pytorch-vqa")
    
from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature, TextFeature
from captum.attr import TokenReferenceBase, configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

In [3]:
run_on = 'cuda'  # change to 'cpu' if no GPU is available
if run_on == 'cuda':
    # Let's set the device we will use for model inference
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")

# VQA Model Setup

Let's load the VQA model (again, please refer to the [model interpretation tutorial on VQA](https://captum.ai/tutorials/Multimodal_VQA_Interpret) if you want details).

In [4]:
saved_state = torch.load(VQA_MODEL_PATH, map_location=device)

# reading vocabulary from saved model
vocab = saved_state["vocab"]

# reading word tokens from saved model
token_to_index = vocab["question"]

# reading answers from saved model
answer_to_index = vocab["answer"]

num_tokens = len(token_to_index) + 1

# reading answer classes from the vocabulary
answer_words = ["unk"] * len(answer_to_index)
for w, idx in answer_to_index.items():
    answer_words[idx] = w
    
if run_on == 'cuda':
    vqa_net = torch.nn.DataParallel(Net(num_tokens), device_ids=[0])
    vqa_net.load_state_dict(saved_state["weights"])
    vqa_net = vqa_net.to(device)
else:
    vqa_net = Net(num_tokens)
    state_dict = saved_state["weights"]
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith("module.") else k  # remove `module.` if it exists
        new_state_dict[name] = v
    vqa_net.load_state_dict(new_state_dict)
    vqa_net = vqa_net.to(device)


In [5]:
# for visualization to convert indices to tokens for questions
question_words = ["unk"] * num_tokens
for w, idx in token_to_index.items():
    question_words[idx] = w

Let's modify the VQA model to use pytorch-resnet. Our model will be called `vqa_resnet`.

In [6]:
class ResNetLayer4(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.r_model = resnet.resnet152(pretrained=True)
        self.r_model.eval()
        self.r_model.to(device)

    def forward(self, x):
        x = self.r_model.conv1(x)
        x = self.r_model.bn1(x)
        x = self.r_model.relu(x)
        x = self.r_model.maxpool(x)
        x = self.r_model.layer1(x)
        x = self.r_model.layer2(x)
        x = self.r_model.layer3(x)
        return self.r_model.layer4(x)

class VQA_Resnet_Model(Net):
    def __init__(self, embedding_tokens):
        super().__init__(embedding_tokens)
        self.resnet_layer4 = ResNetLayer4()

    def forward(self, v, q, q_len):
        q = self.text(q, list(q_len.data))
        v = self.resnet_layer4(v)

        v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8)

        a = self.attention(v, q)
        v = apply_attention(v, a)

        combined = torch.cat([v, q], dim=1)
        answer = self.classifier(combined)
        return answer
    
if run_on == 'cuda':
    vqa_resnet = VQA_Resnet_Model(vqa_net.module.text.embedding.num_embeddings)
    # `device_ids` contains a list of GPU ids which are used for parallelization supported by `DataParallel`
    vqa_resnet = torch.nn.DataParallel(vqa_resnet, device_ids=[0])
else:
    vqa_resnet = VQA_Resnet_Model(vqa_net.text.embedding.num_embeddings)



# saved vqa model's parameters
partial_dict = vqa_net.state_dict()

state = vqa_resnet.state_dict()
state.update(partial_dict)
vqa_resnet.load_state_dict(state)

vqa_resnet.to(device)
vqa_resnet.eval()

# This is the original VQA model without resnet. Delete it, since we no longer need it
del vqa_net

# this is necessary for the backpropagation of RNN models in eval mode
torch.backends.cudnn.enabled = False

# Input Utilities

Now we will need some utility functions for the inputs of our model. 

Let's start off with our image input transform function. We will separate out the normalization step from the transform in order to view the original image.

In [7]:
image_size = 448  # scale image to given size and center
central_fraction = 1.0

transform = get_transform(image_size, central_fraction=central_fraction)
transform_normalize = transform.transforms.pop()

Now for the input question, we will need an encoding function (to go from words -> indices):

In [8]:
def encode_question(question):
    """ Turn a question into a vector of indices and a question length """
    question_arr = question.lower().split()
    vec = torch.zeros(len(question_arr), device=device).long()
    for i, token in enumerate(question_arr):
        index = token_to_index.get(token, 0)
        vec[i] = index
    return vec, torch.tensor(len(question_arr), device=device)
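
The encoding step can be tried in isolation with a toy vocabulary. The sketch below is only an illustration: `toy_token_to_index` and `encode_question_toy` are hypothetical names, the vocabulary is made up (the real one is read from the saved checkpoint), and plain lists are used instead of tensors.

```python
# Hypothetical toy vocabulary; the real one comes from the saved checkpoint.
toy_token_to_index = {"what": 1, "color": 2, "is": 3, "the": 4, "cat": 5}


def encode_question_toy(question, vocab):
    """Lowercase, split on whitespace, and map each token to its index (0 if unknown)."""
    tokens = question.lower().split()
    indices = [vocab.get(token, 0) for token in tokens]
    return indices, len(tokens)


indices, length = encode_question_toy("What color is the dog", toy_token_to_index)
print(indices, length)  # [1, 2, 3, 4, 0] 5
```

Because the split is on whitespace only, a token with attached punctuation such as `cat?` is not found in the vocabulary and falls back to index 0, just like the unknown word `dog` above.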

# Baseline Inputs 

The Insights API uses Captum's attribution API under the hood, so we need a baseline for our inputs. A baseline is (typically) a neutral reference input against which the attribution algorithm(s) measure which features are important in making a prediction (this is a very simplified explanation; 'Remark 1' in the [Integrated Gradients paper](https://arxiv.org/pdf/1703.01365.pdf) has an excellent explanation of why baselines are needed).

For images and for the purpose of this tutorial, we will let this baseline be the zero vector (a black image).

In [9]:
def baseline_image(x):
    return x * 0
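
To make the role of the baseline more concrete, here is a minimal sketch of integrated gradients on a toy two-variable function with a zero baseline. Everything in it is an assumption made for illustration: the function `f`, its hand-written gradient, and the midpoint Riemann sum are not taken from the VQA model or from Captum's implementation.

```python
def f(x):
    # Toy "model": one quadratic feature and one linear feature.
    return x[0] ** 2 + 3.0 * x[1]


def grad_f(x):
    # Analytical gradient of f.
    return [2.0 * x[0], 3.0]


def integrated_gradients(x, baseline, n_steps=50):
    """Approximate integrated gradients with a midpoint Riemann sum."""
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Point on the straight line from the baseline to the input.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_f(point)
        for i in range(len(x)):
            avg_grad[i] += g[i] / n_steps
    # Scale the averaged gradient by the distance from the baseline.
    return [(xi - b) * a for xi, b, a in zip(x, baseline, avg_grad)]


x = [2.0, 1.0]
baseline = [0.0, 0.0]
attr = integrated_gradients(x, baseline)
print(attr)
print(sum(attr), f(x) - f(baseline))
```

The attributions add up (approximately) to the difference between the output at the input and the output at the baseline, which is why the choice of baseline determines what the attributions are measured against.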

For sentences, as done in the multi-modal VQA tutorial, we will use a sentence composed of padding symbols.

We will also need to pass our model through the [`configure_interpretable_embedding_layer`](https://captum.ai/api/utilities.html?highlight=configure_interpretable_embedding_layer#captum.attr._models.base.configure_interpretable_embedding_layer) function, which separates the embedding layer and precomputes word embeddings. To put it simply, this function allows us to precompute the embedding vectors and give them directly to our model, which lets us reference the words associated with particular embeddings (for visualization purposes).

In [10]:
if run_on == 'cuda':
    interpretable_embedding = configure_interpretable_embedding_layer(
        vqa_resnet, "module.text.embedding")
else:
    interpretable_embedding = configure_interpretable_embedding_layer(
        vqa_resnet, "text.embedding")


PAD_IND = token_to_index["pad"]
token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

def baseline_text(x):
    seq_len = x.size(0)
    ref_indices = token_reference.generate_reference(seq_len, device=device).unsqueeze(
        0
    )
    return interpretable_embedding.indices_to_embeddings(ref_indices).squeeze(0)

def input_text_transform(x):
    return interpretable_embedding.indices_to_embeddings(x)



# Using the Insights API

Finally we have reached the relevant part of the tutorial.

First let's create a utility function to allow us to pass data into the insights API. This function will essentially produce `Batch` objects, which tell the insights API what your inputs, labels and any additional arguments are.

In [11]:
def vqa_dataset(image, questions, targets):
    img = Image.open(image).convert("RGB")
    img = transform(img).unsqueeze(0).to(device)

    for question, target in zip(questions, targets):
        q, q_len = encode_question(question)

        q = q.unsqueeze(0)
        q_len = q_len.unsqueeze(0)

        target_idx = answer_to_index[target]

        yield Batch(
            inputs=(img, q), labels=(target_idx,), additional_args=q_len
        )
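
Note that `vqa_dataset` is a generator, so the batches it yields can be consumed only once. The sketch below illustrates this with a hypothetical `ToyBatch` named tuple standing in for `Batch`; it does not use Captum and only mimics the shape of the data.

```python
from collections import namedtuple

# Hypothetical stand-in for captum.insights.Batch, used only for illustration.
ToyBatch = namedtuple("ToyBatch", ["inputs", "labels", "additional_args"])


def toy_dataset(questions, targets):
    """Yield one batch per (question, target) pair, like vqa_dataset does."""
    for question, target in zip(questions, targets):
        tokens = question.split()
        yield ToyBatch(inputs=(tokens,), labels=(target,), additional_args=len(tokens))


batches = toy_dataset(["what is this", "what color is it"], ["cat", "white"])
first_pass = [b.labels[0] for b in batches]
second_pass = [b.labels[0] for b in batches]  # the generator is already exhausted
print(first_pass)   # ['cat', 'white']
print(second_pass)  # []
```

This is presumably why the evaluation cell later in the notebook calls `vqa_dataset` again instead of reusing the `dataset` object that was handed to the visualizer.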
    

Let's create our `AttributionVisualizer`. To do this, we need the following:

- A score function, which tells us how to interpret the model's output vector
- Description of the input features given to the model
- The data to visualize (as described above)
- Description of the output (the class names), in our case this is our answer words

In our case, we want to produce a single answer output via softmax

In [12]:
def score_func(o):
    return F.softmax(o, dim=1)
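
As a quick reminder of what the score function does, here is a plain-Python softmax on made-up logits. The values in `toy_logits` are hypothetical; the notebook itself uses `F.softmax` on the model's output tensor.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.1]  # hypothetical scores for three answer classes
probs = softmax(toy_logits)
best = max(range(len(probs)), key=probs.__getitem__)
print(probs, best)
```

The ordering of the logits is preserved, so the class with the largest raw score is also the one with the highest probability.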

The following function will convert a sequence of question indices to the associated question words for visualization purposes. This will be provided to the `TextFeature` object to describe text features.

In [13]:
def itos(input):
    return [question_words[int(i)] for i in input.squeeze(0)]

Let's define some dummy data to visualize using the function we declared earlier.

In [14]:
dataset = vqa_dataset("./img/vqa/siamese.jpg", 
    ["what is on the picture",
    "what color is the cat",
    "where color are the cat eyes" ],
    ["cat", "white and brown", "blue"]
)    

Now let's describe our features. Each feature requires an input transformation function and a set of baselines. As described earlier, we will use the black image for the image baseline and a padded sequence for the text baseline.

The input image will be transformed via our normalization transform (`transform_normalize`).
Our input text will need to be transformed into embeddings, as it is a sequence of indices. Our model only accepts embeddings as input, as we modified the model with `configure_interpretable_embedding_layer` earlier.

We also need to provide how the input text should be transformed in order to be visualized, which will be accomplished through the `itos` function, as described earlier.

In [15]:
features = [
    ImageFeature(
        "Picture",
        input_transforms=[transform_normalize],
        baseline_transforms=[baseline_image],
    ),
    TextFeature(
        "Question",
        input_transforms=[input_text_transform],
        baseline_transforms=[baseline_text],
        visualization_transform=itos,
    ),
]

This version of the notebook adds functions that perturb both the text and the image inputs, so that the sensitivity of different explanation methods can be evaluated.

In [16]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import random
from nltk.corpus import wordnet
from itertools import chain
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')


def perturb_image(image, noise_level=0.1, device='cpu'):
    if isinstance(image, Image.Image):
        image = transforms.ToTensor()(image).to(device)
    
    noise = torch.randn(image.size(), device=device) * noise_level
    perturbed_image = image + noise
    perturbed_image = torch.clamp(perturbed_image, 0, 1)  # Ensure pixel values are within [0, 1]
    
    return perturbed_image

def get_synonyms(word):
    # Collect the lemma names of every WordNet synset that contains the word
    synsets = wordnet.synsets(word)
    return set(chain.from_iterable(synset.lemma_names() for synset in synsets))

def perturb_text(text, perturbation_rate=0.1):

    words = text.split()
    num_perturb = int(len(words) * perturbation_rate)
    indices = random.sample(range(len(words)), num_perturb)
    
    for i in indices:
        synonyms = get_synonyms(words[i])
        if synonyms:
            words[i] = random.choice(list(synonyms))
    
    perturbed_text = ' '.join(words)
    return perturbed_text


[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\saroa\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\saroa\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
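
One detail of `perturb_text` is worth checking by hand: the number of replaced words is `int(len(words) * perturbation_rate)`, which truncates towards zero. The sketch below reproduces that logic with a hypothetical synonym table (`TOY_SYNONYMS`) in place of WordNet, so it runs without NLTK; the function name and the table are assumptions made for illustration.

```python
import random

# Hypothetical synonym table standing in for WordNet lookups.
TOY_SYNONYMS = {"picture": ["image", "photo"], "cat": ["feline"]}


def perturb_text_toy(text, perturbation_rate, rng):
    """Replace a fraction of the words with a synonym, as perturb_text does."""
    words = text.split()
    num_perturb = int(len(words) * perturbation_rate)  # truncates towards zero
    for i in rng.sample(range(len(words)), num_perturb):
        options = TOY_SYNONYMS.get(words[i])
        if options:
            words[i] = rng.choice(options)
    return " ".join(words), num_perturb


rng = random.Random(0)
question = "what is on the picture"
for rate in (0.1, 0.2, 0.5, 1.0):
    perturbed, n = perturb_text_toy(question, rate, rng)
    print(rate, n, perturbed)
```

For a five-word question, a rate of 0.1 selects zero words, so the text is left unchanged and only the image noise has an effect at that setting. With the real WordNet lookup, lemma names may also contain underscores or capital letters, in which case the replacement would probably be encoded as an unknown token by `encode_question`.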


Here, the functions for evaluating models on the multimodal dataset are defined.

In [17]:
from scipy.stats import spearmanr
from textstat import flesch_kincaid_grade
import torch
from captum.attr import IntegratedGradients


# Evaluation functions
def feature_importance_consistency(attributions, model_weights):
    # Summarize attributions to match the dimensionality of model weights
    summarized_attributions = torch.mean(torch.tensor(attributions), dim=0).detach().cpu().numpy()
    weights = model_weights.detach().cpu().numpy()
    return spearmanr(summarized_attributions, weights[:len(summarized_attributions)]).correlation

def perturbation_test(attributions, inputs, model, additional_args, perturbation_factor=0.1):
    perturbed_inputs = []
    for input in inputs:
        perturbed_input = []
        for element in input:
            if isinstance(element, torch.Tensor):
                perturbed_input.append(element.clone())
            else:
                perturbed_input.append(element)
        perturbed_inputs.append(perturbed_input)

    # Get the device of the inputs
    device = inputs[0][0].device

    # Convert attributions to tensors and move them to the same device as inputs
    image_attributions = torch.tensor([attr[0] for attr in attributions]).to(device)
    text_attributions = torch.tensor([attr[1] for attr in attributions]).to(device)

    # Debug prints
    print(f"Image attributions shape: {image_attributions.shape}")
    print(f"Text attributions shape: {text_attributions.shape}")

    # Ensure image_attributions matches the shape of the image input
    for i in range(len(perturbed_inputs)):
        batch_size, channels, height, width = perturbed_inputs[i][0].shape
        print(f"Perturbed input shape: {perturbed_inputs[i][0].shape}")
        image_perturbation = image_attributions.view(batch_size, channels, height, width) * perturbation_factor
        perturbed_inputs[i][0] += image_perturbation

        # Ensure text_attributions matches the shape of the question input
        sequence_length = perturbed_inputs[i][1].shape[1]
        text_perturbation = text_attributions.view(batch_size, sequence_length) * perturbation_factor
        perturbed_inputs[i][1] += text_perturbation

    original_output = model(*inputs, additional_args)
    perturbed_output = model(*[tuple(input) for input in perturbed_inputs], additional_args)
    return torch.norm(original_output - perturbed_output).item()


def readability_score(text):
    return flesch_kincaid_grade(text)

def stability_test(attributions, inputs, model, additional_args, noise_level=0.1):
    noisy_inputs = [input + noise_level * torch.randn_like(input) for input in inputs]
    noisy_attributions = visualizer.attribution_calculation.calculate_attribution(
        baselines=None,
        data=noisy_inputs,
        additional_forward_args=additional_args,
        label=None,
        attribution_method_name="IntegratedGradients",
        attribution_arguments={'n_steps': 25},
        model=model
    )
    return torch.norm(attributions - noisy_attributions).item()
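
`feature_importance_consistency` relies on Spearman's rank correlation, which compares the ordering of two lists rather than their raw values. The following sketch implements the textbook formula for the case without ties, using made-up numbers; `scipy.stats.spearmanr` also handles ties, which this simplified version does not.

```python
def ranks(values):
    """Rank the values from 1 (smallest) to n (largest); assumes no ties."""
    order = sorted(range(len(values)), key=values.__getitem__)
    result = [0] * len(values)
    for rank, idx in enumerate(order, start=1):
        result[idx] = rank
    return result


def spearman_no_ties(a, b):
    """Spearman correlation via 1 - 6 * sum(d^2) / (n * (n^2 - 1))."""
    n = len(a)
    d2 = sum((x - y) ** 2 for x, y in zip(ranks(a), ranks(b)))
    return 1.0 - 6.0 * d2 / (n * (n * n - 1))


toy_attributions = [0.9, 0.1, 0.4, 0.7]   # hypothetical attribution scores
toy_weights = [2.0, -1.0, 0.5, 1.5]       # hypothetical weights, same ordering
print(spearman_no_ties(toy_attributions, toy_weights))
print(spearman_no_ties(toy_attributions, [-w for w in toy_weights]))
```

A value near 1 means the attributions rank the features in the same order as the weights, and a value near -1 means the order is reversed.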

Let's define our AttributionVisualizer object with the above parameters and our `vqa_resnet` model. 

In [18]:
visualizer = AttributionVisualizer(
    models=[vqa_resnet],
    score_func=score_func,
    features=features,
    dataset=dataset,
    classes=answer_words,
)

And now we can visualize the outputs produced by the model.

Insights allows [different attribution methods](https://captum.ai/docs/algorithms) to be chosen. By default, [integrated gradients](https://captum.ai/api/integrated_gradients) is selected.

In [19]:
visualizer.serve(debug=True)

 * Debugger is active!



Fetch data and view Captum Insights at http://localhost:53377/



 * Debugger PIN: 210-765-223


53377

 * Running on http://127.0.0.1:53377/ (Press CTRL+C to quit)


127.0.0.1 - - [25/Jul/2024 13:11:56] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [25/Jul/2024 13:11:56] "GET /static/css/main.fac91593.chunk.css HTTP/1.1" 200 -
127.0.0.1 - - [25/Jul/2024 13:11:56] "GET /static/js/2.c6c4604e.chunk.js HTTP/1.1" 200 -
127.0.0.1 - - [25/Jul/2024 13:11:56] "GET /static/js/main.835ab072.chunk.js HTTP/1.1" 200 -
127.0.0.1 - - [25/Jul/2024 13:11:56] "GET /init HTTP/1.1" 200 -
127.0.0.1 - - [25/Jul/2024 13:11:56] "GET /init HTTP/1.1" 200 -
127.0.0.1 - - [25/Jul/2024 13:11:56] "GET /favicon.ico HTTP/1.1" 404 -
 

Now that the visualizer is being served on localhost, once an explanation has been run we have the required information to evaluate these models.

In [20]:
from captum.attr import (Deconvolution,DeepLift,FeatureAblation,GuidedBackprop,InputXGradient,IntegratedGradients,Occlusion,Saliency,)
import pandas as pd
from tabulate import tabulate

SUPPORTED_ATTRIBUTION_METHODS = [Deconvolution,DeepLift,GuidedBackprop,InputXGradient,IntegratedGradients,Saliency,FeatureAblation,Occlusion,]
ATTRIBUTION_NAMES_TO_METHODS = {
    cls.get_name(): cls  # type: ignore
    for cls in SUPPORTED_ATTRIBUTION_METHODS
}

def calculate_attributions(model, inputs, target, baselines, additional_args, xai_model, selected_arguments):
    xai = ATTRIBUTION_NAMES_TO_METHODS[xai_model](model)
    if xai_model in ['IntegratedGradients', 'FeatureAblation', 'Occlusion']:
        attributions = xai.attribute.__wrapped__(xai, inputs=inputs, additional_forward_args=additional_args, target=target, **selected_arguments)
    else:
        attributions = xai.attribute.__wrapped__(xai, inputs=inputs, additional_forward_args=additional_args, target=target)
    return attributions

modality_attributions = visualizer.get_attributions()
 
selected_arguments= visualizer.get_insights_config()['selected_arguments']    
xai_model = visualizer.get_insights_config()['selected_method']
print(xai_model) 

    
dataset = vqa_dataset("./img/vqa/siamese.jpg", 
    ["what is on the picture",
    "what color is the cat",
    "where color are the cat eyes" ],
    ["cat", "white and brown", "blue"]
)    

results = []

for batch in dataset:
    original_inputs = batch.inputs
    original_additional_args = batch.additional_args
    target = batch.labels
    
    noise_level = 0.1
    perturbation_rate = 0.1
    
    # Calculate original attributions
    (original_predicted_scores, original_baselines, original_transformed_inputs,) = visualizer.attribution_calculation.calculate_predicted_scores(original_inputs, original_additional_args, vqa_resnet)
    original_attributions = calculate_attributions(vqa_resnet, original_transformed_inputs, target, None, original_additional_args, xai_model, selected_arguments)
    original_net_contrib = visualizer.attribution_calculation.calculate_net_contrib(original_attributions)

    # Clear unused variables and cache to free GPU memory
    del original_baselines, original_transformed_inputs, original_attributions
    torch.cuda.empty_cache()
    
    original_label = original_predicted_scores[0].label
    
    while True:
        # Perturb image
        perturbed_image = perturb_image(original_inputs[0][0], noise_level=noise_level, device=device)
        perturbed_image = perturbed_image.unsqueeze(0).to(device)

        # Perturb text
        original_question = ' '.join(itos(original_inputs[1][0]))
        perturbed_question = perturb_text(original_question, perturbation_rate=perturbation_rate)
        perturbed_question_vec, perturbed_question_len = encode_question(perturbed_question)
        perturbed_question_vec = perturbed_question_vec.unsqueeze(0)
        perturbed_question_len = perturbed_question_len.unsqueeze(0)

        perturbed_inputs = (perturbed_image, perturbed_question_vec)
        perturbed_additional_args = perturbed_question_len


        # Calculate new attributions with perturbed inputs
        (perturbed_predicted_scores, perturbed_baselines, perturbed_transformed_inputs,) = visualizer.attribution_calculation.calculate_predicted_scores(perturbed_inputs, perturbed_additional_args, vqa_resnet)
        perturbed_attributions = calculate_attributions(vqa_resnet, perturbed_transformed_inputs, target, None, perturbed_additional_args, xai_model, selected_arguments)
        perturbed_net_contrib = visualizer.attribution_calculation.calculate_net_contrib(perturbed_attributions)

        # Clear unused variables and cache to free GPU memory
        del perturbed_baselines, perturbed_transformed_inputs, perturbed_attributions
        torch.cuda.empty_cache()

        # Compare original and perturbed attributions
        perturbed_label = perturbed_predicted_scores[0].label
        prediction_consistency = original_label == perturbed_label
        comparison = np.array(original_net_contrib) - np.array(perturbed_net_contrib)

        # Store the results
        results.append({
            'noise_level': noise_level,
            'perturbation_rate': perturbation_rate,
            'original_label': original_label,
            'perturbed_label': perturbed_label,
            'prediction_consistency': prediction_consistency,
            'original_net_contrib': str(original_net_contrib),
            'perturbed_net_contrib': str(perturbed_net_contrib),
            'comparison': str(comparison)
        })

        # Check for consistency
        if not prediction_consistency:
            break
        else:
            noise_level += 0.2
            perturbation_rate += 0.1

# Create a DataFrame from the results
results_df = pd.DataFrame(results)

print(tabulate(results_df, headers='keys', tablefmt='psql'))

Saliency



+----+---------------+---------------------+------------------+-------------------+--------------------------+-------------------------------------------+-------------------------------------------+---------------------------+
|    |   noise_level |   perturbation_rate | original_label   | perturbed_label   | prediction_consistency   | original_net_contrib                      | perturbed_net_contrib                     | comparison                |
|----+---------------+---------------------+------------------+-------------------+--------------------------+-------------------------------------------+-------------------------------------------+---------------------------|
|  0 |           0.1 |                 0.1 | cat              | cat               | True                     | [0.8570834398269653, 0.14291661977767944] | [0.8150659799575806, 0.18493403494358063] | [ 0.04201746 -0.04201742] |
|  1 |           0.3 |                 0.2 | cat              | cat               | True    
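
The `original_net_contrib` and `perturbed_net_contrib` columns hold one value per modality (image, then text). The sketch below shows one way such values could be derived, under the assumption that each modality's attributions are summed and the sums are then divided by their L1 norm; this is an assumption about what `calculate_net_contrib` does, not a copy of Captum's code, and the attribution numbers are made up.

```python
def net_contrib(per_feature_attributions):
    """Sum each feature's attributions, then normalise so absolute values sum to 1."""
    totals = [sum(values) for values in per_feature_attributions]
    norm = sum(abs(t) for t in totals)
    if norm == 0:
        return [0.0 for _ in totals]
    return [t / norm for t in totals]


# Hypothetical attributions: first list for the image, second for the question.
original = net_contrib([[0.5, 0.25, 0.25], [0.125, 0.125]])
perturbed = net_contrib([[0.5, 0.25], [0.125, 0.125]])
comparison = [o - p for o, p in zip(original, perturbed)]
print(original, perturbed, comparison)
```

When all net contributions are positive, each vector sums to 1, so the entries of `comparison` cancel out. That is consistent with the first row of the table above, where the two comparison values are nearly equal in size and opposite in sign.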

In [None]:
# Evaluate explanations
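# Example for accessing model weights; unwrap DataParallel when it is in use
base_model = vqa_resnet.module if isinstance(vqa_resnet, torch.nn.DataParallel) else vqa_resnet
model_weights = base_model.classifier.lin2.weight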
faithfulness_score = feature_importance_consistency(modality_attributions, model_weights)
#perturbation_score = perturbation_test(attributions, inputs, vqa_resnet, additional_args)
readability = readability_score("Generated Explanation Text")  # Replace with actual text
#sensitivity_score = stability_test(attributions, inputs, vqa_resnet)

# Print or log evaluation results
print(f"Faithfulness Score: {faithfulness_score}")
#print(f"Perturbation Score: {perturbation_score}")
print(f"Readability Score: {readability}")
#print(f"Sensitivity Score: {sensitivity_score}")

In [None]:
# show a screenshot if using notebook non-interactively
import IPython.display
IPython.display.Image(filename='img/captum_insights_vqa.png')

Finally, since we are done with visualization, we will revert the change to the model we made with `configure_interpretable_embedding_layer`. To do this, we will invoke the `remove_interpretable_embedding_layer` function. Uncomment the line below to execute the cell.

In [None]:
# remove_interpretable_embedding_layer(vqa_resnet, interpretable_embedding)