# **Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers (Transformer Multi Modal Explainability)**

In [None]:
# !rm -r ../Transformer-MM-Explainability

In [1]:
import os

In [2]:
# os.chdir('../')

In [3]:
# !git clone https://github.com/shiv2110/Transformer-MM-Explainability

In [4]:
# os.chdir(f'./Transformer-MM-Explainability')

In [5]:
!pip install -r requirements.txt



In [6]:
# import cv2

In [7]:
!ls

'ls' is not recognized as an internal or external command,
operable program or batch file.


# **LXMERT**

**Examples from paper**

In [8]:
from lxmert.lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.lxmert.src.vqa_utils as utils
from lxmert.lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer

## problem lies here
from lxmert.lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP

from lxmert.lxmert.src.lxmert_lrp import LxmertAttention


from tqdm import tqdm
from lxmert.lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import visualization
import requests

In [9]:
OBJ_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt"
ATTR_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt"
VQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json"

In [10]:
class ModelUsage:
    def __init__(self, use_lrp=False):
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cuda"

        self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config = self.frcnn_cfg)

        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")

        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        # self.vqa_dataset = vqa_data.VQADataset(splits="valid")

    def forward(self, item):
        URL, question = item

        self.image_file_path = URL

        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(URL)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections= self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )
        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")
        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids.to("cuda"),
            attention_mask=inputs.attention_mask.to("cuda"),
            visual_feats=features.to("cuda"),
            visual_pos=normalized_boxes.to("cuda"),
            token_type_ids=inputs.token_type_ids.to("cuda"),
            return_dict=True,
            output_attentions=False,
        )
        return self.output

In [11]:
def save_image_vis(image_file_path, bbox_scores):
    bbox_scores = image_scores
    _, top_bboxes_indices = bbox_scores.topk(k=1, dim=-1)
    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index in range(len(bbox_scores)):
        [x, y, w, h] = model_lrp.bboxes[0][index]
        curr_score_tensor = mask[int(y):int(h), int(x):int(w)]
        new_score_tensor = torch.ones_like(curr_score_tensor)*bbox_scores[index].item()
        mask[int(y):int(h), int(x):int(w)] = torch.max(new_score_tensor,mask[int(y):int(h), int(x):int(w)])
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.unsqueeze_(-1)
    mask = mask.expand(img.shape)
    img = img * mask.cpu().data.numpy()
    cv2.imwrite(
            'lxmert/lxmert/experiments/paper/new.jpg', img)

In [12]:
# def save_image_vis_dummy(image_file_path, bbox_scores):
#     # bbox_scores = image_scores
#     # _, top_bboxes_indices = bbox_scores.topk(k=1, dim=-1)
#     img = cv2.imread(image_file_path)
#     mask = torch.zeros(img.shape[0], img.shape[1])
#     for index in bbox_scores:
#         [x, y, w, h] = model_lrp.bboxes[0][index]
#         curr_score_tensor = mask[int(y):int(h), int(x):int(w)]
#         new_score_tensor = torch.ones_like(curr_score_tensor)*bbox_scores[index].item()
#         mask[int(y):int(h), int(x):int(w)] = torch.max(new_score_tensor,mask[int(y):int(h), int(x):int(w)])
#     mask = (mask - mask.min()) / (mask.max() - mask.min())
#     mask = mask.unsqueeze_(-1)
#     mask = mask.expand(img.shape)
#     img = img * mask.cpu().data.numpy()
#     cv2.imwrite(
#             'lxmert/lxmert/experiments/paper/new.jpg', img)

In [14]:
model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
vqa_answers = utils.get_data(VQA_URL)

# baselines.generate_transformer_attr(None)
# baselines.generate_attn_gradcam(None)
# baselines.generate_partial_lrp(None)
# baselines.generate_raw_attn(None)
# baselines.generate_rollout(None)

image_ids = [
    # giraffe
    'COCO_val2014_000000185590',
    # baseball
    'COCO_val2014_000000127510',
    # bath
    'COCO_val2014_000000324266',
    # frisbee
    'COCO_val2014_000000200717'
]

test_questions_for_images = [
    ################## paper samples
    # giraffe
    "is the animal eating?",
    # baseball
    "did this man catch the ball?",
    # bath
    "is the tub white ?",
    # frisbee
    "did the man just catch the frisbee?"
    ################## paper samples
]

loading configuration file cache
loading weights file https://cdn.huggingface.co/unc-nlp/frcnn-vg-finetuned/pytorch_model.bin from cache at C:\Users\shiva/.cache\torch\transformers\57f6df6abe353be2773f2700159c65615babf39ab5b48114d2b49267672ae10f.77b59256a4cf8343ae0f923246a81489fc8d82f98d082edc2d2037c977c0d9d0
All model checkpoint weights were used when initializing GeneralizedRCNN.

All the weights of GeneralizedRCNN were initialized from the model checkpoint at unc-nlp/frcnn-vg-finetuned.
If your task is similar to the task the model of the checkpoint was trained on, you can already use GeneralizedRCNN for predictions without further training.




ConnectionError: HTTPSConnectionPool(host='cdn-lfs.huggingface.co', port=443): Read timed out.

In [None]:
################# Test block ###################3

# model_lrp.lxmert_vqa.lxmert.encoder.r_layers[0-4].attention.output

In [None]:
# top_indices = 0
def test_save_image_vis(image_file_path, bbox_scores):
    # print(bbox_scores)
    bbox_scores = image_scores
    _, top_bboxes_indices = bbox_scores.topk(k=5, dim=-1)

    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index in top_bboxes_indices:
        img = cv2.imread(image_file_path)
        [x, y, w, h] = model_lrp.bboxes[0][index]
        cv2.rectangle(img, (int(x), int(y)), (int(w), int(h)), (0, 0, 255), 2)
        cv2.imwrite('{}.jpg'.format(index), img)

    count = 1
    plt.figure(figsize = (15, 10))

    for idx in top_bboxes_indices:
      idx = idx.item()
      plt.subplot(1, len(top_bboxes_indices), count)
      plt.title(str(idx))
      plt.axis('off')
      plt.imshow(cv2.imread('{}.jpg'.format(idx)))
      count += 1

In [None]:
# top_indices = 0
def test_save_image_vis_dummy(image_file_path, top_bboxes_indices):
    # print(bbox_scores)
    # bbox_scores = image_scores
    # _, top_bboxes_indices = bbox_scores.topk(k=5, dim=-1)

    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index in top_bboxes_indices:
        img = cv2.imread(image_file_path)
        [x, y, w, h] = model_lrp.bboxes[0][index]
        cv2.rectangle(img, (int(x), int(y)), (int(w), int(h)), (0, 0, 255), 2)
        cv2.imwrite('{}.jpg'.format(index), img)

    count = 1
    plt.figure(figsize = (15, 10))

    for idx in top_bboxes_indices:
      idx = idx.item()
      plt.subplot(1, len(top_bboxes_indices), count)
      plt.title(str(idx))
      plt.axis('off')
      plt.imshow(cv2.imread('{}.jpg'.format(idx)))
      count += 1

In [None]:
# def get_largest_cc_box(mask: np.array):
#     from skimage.measure import label as measure_label
#     labels = measure_label(mask)  # get connected components
#     largest_cc_index = np.argmax(np.bincount(labels.flat)[1:]) + 1
#     mask = np.where(labels == largest_cc_index)
#     ymin, ymax = min(mask[0]), max(mask[0]) + 1
#     xmin, xmax = min(mask[1]), max(mask[1]) + 1
#     return [xmin, ymin, xmax, ymax]

In [None]:
# def get_bbox_from_patch_mask(patch_mask, init_image_size):

#     # Sizing
#     H, W = init_image_size[:-1]
#     T = patch_mask.numel()
#     if (H // 8) * (W // 8) == T:
#         P, H_lr, W_lr = (8, H // 8, W // 8)
#     elif (H // 16) * (W // 16) == T:
#         P, H_lr, W_lr = (16, H // 16, W // 16)
#     elif 4 * (H // 16) * (W // 16) == T:
#         P, H_lr, W_lr = (8, 2 * (H // 16), 2 * (W // 16))
#     elif 16 * (H // 32) * (W // 32) == T:
#         P, H_lr, W_lr = (8, 4 * (H // 32), 4 * (W // 32))
#     else:
#         raise ValueError(f'{init_image_size=}, {patch_mask.shape=}')

#     # Create patch mask
#     patch_mask = patch_mask.reshape(H_lr, W_lr).cpu().numpy()

#     # Possibly reverse mask
#     # print(np.mean(patch_mask).item())
#     if 0.5 < np.mean(patch_mask).item() < 1.0:
#         patch_mask = (1 - patch_mask).astype(np.uint8)
#     elif np.sum(patch_mask).item() == 0:  # nothing detected at all, so cover the entire image
#         patch_mask = (1 - patch_mask).astype(np.uint8)

#     # Get the box corresponding to the largest connected component of the first eigenvector
#     xmin, ymin, xmax, ymax = get_largest_cc_box(patch_mask)
#     # pred = [xmin, ymin, xmax, ymax]

#     # Rescale to image size
#     r_xmin, r_xmax = P * xmin, P * xmax
#     r_ymin, r_ymax = P * ymin, P * ymax

#     # Prediction bounding box
#     pred = [r_xmin, r_ymin, r_xmax, r_ymax]

#     # Check not out of image size (used when padding)
#     pred[2] = min(pred[2], W)
#     pred[3] = min(pred[3], H)

#     return np.asarray(pred)

In [None]:
# reimg = cv2.resize( cv2.imread('lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[0])),
#                     (96, 96), interpolation = cv2.INTER_AREA)

# cv2.imwrite('giraffe.jpg', reimg)

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[0])
# URL = 'giraffe.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[0]), use_lrp=True, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
print(image_scores)

test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
# init_image_size = cv2.imread(URL).shape

In [None]:
plt.imshow(lrp.self_attn_image_agg[-1].cpu())

In [None]:
# URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[0])
# R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[0]), use_lrp=False, normalize_self_attention=True, method_name="ours")
# image_scores = lrp.attn_t_i[-1][0]
# text_scores = lrp.self_attn_lang_agg[-1][0]


# test_save_image_vis(URL, image_scores)


# save_image_vis(URL, image_scores)
# orig_image = Image.open(model_lrp.image_file_path)

# fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
# axs[0].imshow(orig_image);
# axs[0].axis('off');
# axs[0].set_title('original');

# masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
# axs[1].imshow(masked_image);
# axs[1].axis('off');
# axs[1].set_title('masked');

# text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
# vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
# visualization.visualize_text(vis_data_records)
# print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
plt.imshow(R_t_i.cpu(), cmap = 'magma')
plt.colorbar(shrink = 0.5)

In [None]:
plt.imshow(lrp.attn_t_i[-1].cpu(), cmap = 'magma')
plt.colorbar(shrink = 0.5)

In [None]:
### padding t_i
final_attn_map = lrp.attn_t_i[-1].cpu()
W = torch.cat( (final_attn_map, torch.zeros( final_attn_map.shape[1] - final_attn_map.shape[0], final_attn_map.shape[1])), dim = 0 )
W = torch.where( W > 5e-5, 1, 0 )
### co_attn_image_agg
# W = lrp.co_attn_image_agg[0].cpu()


### i_t * t_i
# W = torch.matmul( lrp.attn_i_t[-1].cpu(), lrp.attn_t_i[-1].cpu() )


In [None]:
D = torch.zeros(W.shape[0], W.shape[1])
for i in range(D.shape[0]):
  D[i, i] = torch.sum(D[i])

In [None]:
L = D - W

In [None]:
eig_vals, eig_vecs = torch.linalg.eig(L)
eig_vals = eig_vals.real
eig_vecs = eig_vecs.real

In [None]:
result, indices = torch.sort(eig_vals)

In [None]:
eig_vecs[0]

In [None]:
indices

In [None]:
# bboxes = get_bbox_from_patch_mask(eig_vecs[0], init_image_size)

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[0])
# R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[0]), use_lrp=False, normalize_self_attention=True, method_name="ours")
# image_scores = torch.where(eig_vecs[0] < 0, 1 - eig_vecs[0], eig_vecs[0])
image_scores = eig_vecs[0]
# text_scores = lrp.self_attn_lang_agg[-1][0]


test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off')
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[0])
# R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[0]), use_lrp=True, normalize_self_attention=True, method_name="ours")
image_scores = indices[:5]
# text_scores = R_t_t[0]


test_save_image_vis_dummy(URL, image_scores)


# save_image_vis(URL, image_scores)
# orig_image = Image.open(model_lrp.image_file_path)

# fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
# axs[0].imshow(orig_image);
# axs[0].axis('off');
# axs[0].set_title('original');

# masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
# axs[1].imshow(masked_image);
# axs[1].axis('off');
# axs[1].set_title('masked');

# text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
# vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
# visualization.visualize_text(vis_data_records)
# print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[0])
text_scores = R_t_t[0]

# R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[0]), use_lrp=False, normalize_self_attention=True, method_name="ours")
for idx in range(len(crit[1:])):
  image_scores = crit[idx][0]


  test_save_image_vis(URL, image_scores)


  save_image_vis(URL, image_scores)
  orig_image = Image.open(model_lrp.image_file_path)

  fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
  axs[0].imshow(orig_image);
  axs[0].axis('off');
  axs[0].set_title('original');

  masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
  axs[1].imshow(masked_image);
  axs[1].axis('off');
  axs[1].set_title('masked');

  # text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
  # vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
  # visualization.visualize_text(vis_data_records)
  # print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

# ----------------- BREAK ---------------------

In [None]:
cv2.imread(URL).shape

In [None]:
URL = '../nii_depressed.jpg'
R_t_t, R_t_i = lrp.generate_ours((URL, 'is the boy holding a phone?'), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]


test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
plt.imshow(R_t_i.cpu(), cmap = 'magma')
plt.colorbar()

In [None]:
plt.imshow(R_t_t.cpu(), cmap = 'magma')

In [None]:
URL = '../nii_depressed.jpg'
R_t_t, R_t_i = lrp.generate_ours((URL, 'does he have earphones plugged in?'), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]


test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
plt.imshow(R_t_i.cpu(), cmap = 'magma')
plt.colorbar()

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[1])
R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[1]), use_lrp=True, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]

test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[1])
R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[1]), use_lrp=True, normalize_self_attention=True, method_name="ours")
image_scores = lrp.attn_t_i[-1][0]
text_scores = lrp.self_attn_lang_agg[-1][0]


test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
lrp.co_attn_lang_agg[0].shape

In [None]:
### padding t_i
final_attn_map = lrp.attn_t_i[-1].cpu()
W = torch.cat( (final_attn_map, torch.zeros( final_attn_map.shape[1] - final_attn_map.shape[0], final_attn_map.shape[1])), dim = 0 )
W = torch.where( W > 5e-5, 1, 0 )
### co_attn_image_agg
# W = lrp.co_attn_image_agg[-1].cpu()

# W = torch.matmul( lrp.attn_i_t[-1].cpu(), lrp.attn_t_i[-1].cpu() )

D = torch.zeros(W.shape[0], W.shape[1])
for i in range(D.shape[0]):
  D[i, i] = torch.sum(D[i])

L = D - W
# L = W

eig_vals, eig_vecs = torch.linalg.eig(L)
eig_vals = eig_vals.real
eig_vecs = eig_vecs.real

result, indices = torch.sort(eig_vals, descending=False)

print(result, indices)

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[1])
# R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[0]), use_lrp=False, normalize_self_attention=True, method_name="ours")
# image_scores = eig_vecs[0]
image_scores = torch.where(eig_vecs[0] < 0, 1 - eig_vecs[0], eig_vecs[0])
# text_scores = lrp.self_attn_lang_agg[-1][0]


test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off')
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[1])
# R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[0]), use_lrp=True, normalize_self_attention=True, method_name="ours")
image_scores = indices[:5]
# text_scores = R_t_t[0]


test_save_image_vis_dummy(URL, image_scores)

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[2])
R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[2]), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]


test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[3])
R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[3]), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]


test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
### padding t_i
final_attn_map = lrp.attn_t_i[-1].cpu()
W = torch.cat( (final_attn_map, torch.zeros( final_attn_map.shape[1] - final_attn_map.shape[0], final_attn_map.shape[1])), dim = 0 )
# W = torch.where( W > 5e-5, 1, 0 )
### co_attn_image_agg
# W = lrp.co_attn_image_agg[-1].cpu()

# W = torch.matmul( lrp.attn_i_t[-1].cpu(), lrp.attn_t_i[-1].cpu() )

D = torch.zeros(W.shape[0], W.shape[1])
for i in range(D.shape[0]):
  D[i, i] = torch.sum(D[i])

L = D - W
# L = W

eig_vals, eig_vecs = torch.linalg.eig(L)
eig_vals = eig_vals.real
eig_vecs = eig_vecs.real

result, indices = torch.sort(eig_vals, descending=False)

print(result, indices)

In [None]:
URL = 'lxmert/lxmert/experiments/paper/{0}/{0}.jpg'.format(image_ids[3])
# R_t_t, R_t_i = lrp.generate_ours((URL, test_questions_for_images[0]), use_lrp=False, normalize_self_attention=True, method_name="ours")
# image_scores = eig_vecs[0]
image_scores = torch.where(eig_vecs[0] < 0, 1 - eig_vecs[0], eig_vecs[0])
# text_scores = lrp.self_attn_lang_agg[-1][0]


test_save_image_vis(URL, image_scores)


save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off')
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

# **Online Examples**

To upload your own example, simply modify the URL to your image url, and the question to your question.

In [None]:
URL = "https://vqa.cloudcv.org/media/val2014/COCO_val2014_000000549112.jpg"

question =  'where is the knife?'

# save image to experiments folder
im = Image.open(requests.get(URL, stream=True).raw)
im.save('lxmert/lxmert/experiments/paper/online_image.jpg', 'JPEG')
URL = 'lxmert/lxmert/experiments/paper/online_image.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, question), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = "https://vqa.cloudcv.org/media/val2014/COCO_val2014_000000549112.jpg"

question =  'how many sandwiches?'

# save image to experiments folder
im = Image.open(requests.get(URL, stream=True).raw)
im.save('lxmert/lxmert/experiments/paper/online_image.jpg', 'JPEG')
URL = 'lxmert/lxmert/experiments/paper/online_image.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, question), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = "https://vqa.cloudcv.org/media/val2014/COCO_val2014_000000549112.jpg"

question =  'what is the surface made of?'

# save image to experiments folder
im = Image.open(requests.get(URL, stream=True).raw)
im.save('lxmert/lxmert/experiments/paper/online_image.jpg', 'JPEG')
URL = 'lxmert/lxmert/experiments/paper/online_image.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, question), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = "https://vqa.cloudcv.org/media/val2014/COCO_val2014_000000549112.jpg"

question =  'are there any cups?'

# save image to experiments folder
im = Image.open(requests.get(URL, stream=True).raw)
im.save('lxmert/lxmert/experiments/paper/online_image.jpg', 'JPEG')
URL = 'lxmert/lxmert/experiments/paper/online_image.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, question), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = "https://vqa.cloudcv.org/media/val2014/COCO_val2014_000000253263.jpg"

question =  'how many computers?'

# save image to experiments folder
im = Image.open(requests.get(URL, stream=True).raw)
im.save('lxmert/lxmert/experiments/paper/online_image.jpg', 'JPEG')
URL = 'lxmert/lxmert/experiments/paper/online_image.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, question), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = "https://vqa.cloudcv.org/media/val2014/COCO_val2014_000000253263.jpg"

question =  'what is the computer on?'

# save image to experiments folder
im = Image.open(requests.get(URL, stream=True).raw)
im.save('lxmert/lxmert/experiments/paper/online_image.jpg', 'JPEG')
URL = 'lxmert/lxmert/experiments/paper/online_image.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, question), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = "https://vqa.cloudcv.org/media/val2014/COCO_val2014_000000253263.jpg"

question =  'where is the calander?'

# save image to experiments folder
im = Image.open(requests.get(URL, stream=True).raw)
im.save('lxmert/lxmert/experiments/paper/online_image.jpg', 'JPEG')
URL = 'lxmert/lxmert/experiments/paper/online_image.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, question), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

In [None]:
URL = "https://vqa.cloudcv.org/media/val2014/COCO_val2014_000000253263.jpg"

question =  'what is on the desk?'

# save image to experiments folder
im = Image.open(requests.get(URL, stream=True).raw)
im.save('lxmert/lxmert/experiments/paper/online_image.jpg', 'JPEG')
URL = 'lxmert/lxmert/experiments/paper/online_image.jpg'

R_t_t, R_t_i = lrp.generate_ours((URL, question), use_lrp=False, normalize_self_attention=True, method_name="ours")
image_scores = R_t_i[0]
text_scores = R_t_t[0]
save_image_vis(URL, image_scores)
orig_image = Image.open(model_lrp.image_file_path)

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 5))
axs[0].imshow(orig_image);
axs[0].axis('off');
axs[0].set_title('original');

masked_image = Image.open('lxmert/lxmert/experiments/paper/new.jpg')
axs[1].imshow(masked_image);
axs[1].axis('off');
axs[1].set_title('masked');

text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,model_lrp.question_tokens,1)]
visualization.visualize_text(vis_data_records)
print("ANSWER:", vqa_answers[model_lrp.output.question_answering_score.argmax()])

# **DETR**

In [None]:
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np

# from DETR.datasets.coco import *
import torch
import torchvision.transforms as T
import os
import random
import cv2
import DETR.util.misc as utils
from DETR.models import build_model
from DETR.modules.ExplanationGenerator import Generator
import argparse

Auxilary functions

In [None]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [None]:
# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]


# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

def plot_results(pil_img, prob, boxes):
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()

In [None]:
device = 'cuda'
args = Namespace(aux_loss=True, backbone='resnet50', batch_size=2, bbox_loss_coef=5, clip_max_norm=0.1, coco_panoptic_path=None, coco_path=None, dataset_file='coco', dec_layers=6, device='cuda', dice_loss_coef=1, dilation=False, dim_feedforward=2048, dist_url='env://', distributed=False, dropout=0.1, enc_layers=6, eos_coef=0.1, epochs=300, eval=False, frozen_weights=None, giou_loss_coef=2, hidden_dim=256, lr=0.0001, lr_backbone=1e-05, lr_drop=200, mask_loss_coef=1, masks=False, nheads=8, num_queries=100, num_workers=2, output_dir='', position_embedding='sine', pre_norm=False, remove_difficult=False, resume='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth', seed=42, set_cost_bbox=5, set_cost_class=1, set_cost_giou=2, start_epoch=0, weight_decay=0.0001, world_size=1)
model, criterion, postprocessors = build_model(args)
model.to(device)
checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location='cpu', check_hash=True)
model.load_state_dict(checkpoint['model'], strict=False)

In [None]:
def evaluate(model, gen, im, device, image_id = None):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0).to(device)

    # propagate through the model
    outputs = model(img)

    # keep only predictions with 0.7+ confidence
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9

    if keep.nonzero().shape[0] <= 1:
        return

    outputs['pred_boxes'] = outputs['pred_boxes'].cpu()

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

    # use lists to store the outputs via up-values
    conv_features, enc_attn_weights, dec_attn_weights = [], [], []

    hooks = [
        model.backbone[-2].register_forward_hook(
            lambda self, input, output: conv_features.append(output)
        ),
        # model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
        #     lambda self, input, output: enc_attn_weights.append(output[1])
        # ),
        model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
            lambda self, input, output: dec_attn_weights.append(output[1])
        ),
    ]

    for layer in model.transformer.encoder.layers:
        hook = layer.self_attn.register_forward_hook(
            lambda self, input, output: enc_attn_weights.append(output[1])
        )
        hooks.append(hook)

    model(img)

    for hook in hooks:
        hook.remove()

    # don't need the list anymore
    conv_features = conv_features[0]
    enc_attn_weights = enc_attn_weights[-1]
    dec_attn_weights = dec_attn_weights[0]

    # get the feature map shape
    h, w = conv_features['0'].tensors.shape[-2:]
    img_np = np.array(im).astype(np.float)


    fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7))
    for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(keep.nonzero(), axs.T, bboxes_scaled):
        ax = ax_i[0]
        cam = gen.generate_ours(img, idx, use_lrp=False)
        cam = (cam - cam.min()) / (cam.max() - cam.min())
        cmap = plt.cm.get_cmap('Blues').reversed()
        ax.imshow(cam.view(h, w).data.cpu().numpy(), cmap=cmap)
        ax.axis('off')
        ax.set_title(f'query id: {idx.item()}')
        ax = ax_i[1]
        ax.imshow(im)
        ax.add_patch(plt.Rectangle((xmin.detach(), ymin.detach()), xmax.detach() - xmin.detach(), ymax.detach() - ymin.detach(),
                                   fill=False, color='blue', linewidth=3))
        ax.axis('off')
        ax.set_title(CLASSES[probas[idx].argmax()])
    id_str = '' if image_id == None else image_id
    fig.tight_layout()
    plt.show()

**Paper examples**

In [None]:
gen = Generator(model)

In [None]:
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)

evaluate(model, gen, im, 'cuda')

In [None]:
url = 'http://images.cocodataset.org/val2017/000000216516.jpg'
im = Image.open(requests.get(url, stream=True).raw)

evaluate(model, gen, im, 'cuda')

In [None]:

url = 'http://images.cocodataset.org/val2017/000000359937.jpg'
im = Image.open(requests.get(url, stream=True).raw)

evaluate(model, gen, im, 'cuda')

In [None]:
url = 'http://images.cocodataset.org/val2017/000000192191.jpg'
im = Image.open(requests.get(url, stream=True).raw)

evaluate(model, gen, im, 'cuda')