In [None]:
import sys
sys.path.append("..")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from transformers import GPT2Tokenizer, AutoConfig
from transformers import AdamW, get_linear_schedule_with_warmup
import json
from self_training.cococaption.pycocotools.coco import COCO
from self_training.cococaption.pycocoevalcap.eval import COCOEvalCap
from PIL import Image
from accelerate import Accelerator
from self_training import data_utils
from self_training.models.gpt import GPT2LMHeadModel
from self_training.eval_utils import top_filtering
import self_training.models.clip_x.clip as clip

### Data Path Definition

In [None]:
# All locations are relative paths starting with '../', so they resolve against
# the kernel's working directory.
# Results directory of the cococaption tooling (not read in the visible cells).
caption_save_path = '../self_training/cococaption/results/' 
# Annotation files for the VQA-X test split; presumably "exp" holds explanations
# only and "full" holds answer + explanation — TODO confirm.
annFileExp = '../self_training/cococaption/annotations/vqaX_test_annot_exp.json'
annFileFull = '../self_training/cococaption/annotations/vqaX_test_annot_full.json'
# VQA-X sample files per split; only the test file is loaded further down.
nle_data_test_path = '../self_training/nle_data/VQA-X/vqaX_test.json'
nle_data_val_path = '../self_training/nle_data/VQA-X/vqaX_val.json'
nle_data_train_path = '../self_training/nle_data/VQA-X/vqaX_train.json'

### Load Model

In [None]:
def load_checkpoint(ckpt_path, epoch):
    """Load the saved tokenizer and language model for one training epoch.

    Args:
        ckpt_path: Directory that contains the ``nle_gpt2_tokenizer_0`` folder
            and the ``nle_model_<epoch>`` folder. A trailing slash is optional.
        epoch: Epoch number used to select ``nle_model_<epoch>``.

    Returns:
        Tuple ``(tokenizer, model)``; the model is moved to the module-level
        ``device``.
    """
    import os  # local import so this cell does not depend on another cell's imports

    model_name = 'nle_model_{}'.format(str(epoch))
    tokenizer_name = 'nle_gpt2_tokenizer_0'

    # os.path.join gives the same result as plain concatenation when ckpt_path
    # ends with '/', and still works when it does not.
    tokenizer = GPT2Tokenizer.from_pretrained(os.path.join(ckpt_path, tokenizer_name))        # load tokenizer
    model = GPT2LMHeadModel.from_pretrained(os.path.join(ckpt_path, model_name)).to(device)   # load model with config

    return tokenizer, model

def change_requires_grad(model, req_grad):
    """Enable or disable gradient tracking for every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        req_grad: ``True`` to track gradients, ``False`` to freeze.
    """
    for param in model.parameters():
        param.requires_grad_(req_grad)

# The Accelerator instance is used in the visible cells only to obtain the compute device.
accelerator = Accelerator()
device = accelerator.device

# Replace the module's checkpoint table with these three ViT download URLs
# (presumably the table clip.load consults to resolve a model name — confirm in clip_x).
clip._MODELS = {
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
}

# Load the ViT-B/16 CLIP model (non-JIT) together with its preprocessing callable.
clip_model, preprocess = clip.load("ViT-B/16", device=device, jit=False)
print("Image resolution:", clip_model.visual.input_resolution)
# Aliases used by the later cells: the CLIP model is the image encoder ...
image_encoder = clip_model
# Freezing is intentionally left disabled here; NOTE(review): presumably because the
# relevance cell differentiates w.r.t. attention inside the encoder — confirm.
# change_requires_grad(image_encoder, False)
# ... and its preprocessing callable is the image transform.
img_transform = preprocess

### Data Processing

In [None]:
def get_elements(question_id):
    """Build the model inputs for a single sample.

    Relies on the module-level ``data`` dict, ``tokenizer`` and
    ``img_transform``.

    Args:
        question_id: Key into ``data``; must also be convertible with ``int()``.

    Returns:
        Tuple ``(img, qid, input_ids, segment_ids, img_path)``:
        ``img`` is the output of ``img_transform``, ``qid`` a one-element
        LongTensor, ``input_ids`` / ``segment_ids`` equally long 1-D
        LongTensors covering the question followed by the
        "<bos> the answer is" prompt, and ``img_path`` the image path string.
    """
    sample = data[question_id]
    img_name = sample['image_name']
    text_a = data_utils.proc_ques(sample['question'])    # question

    # tokenization process (only the question and answer segments occur in the prompt)
    q_segment_id, a_segment_id = tokenizer.convert_tokens_to_ids(['<question>', '<answer>'])
    tokens = tokenizer.tokenize(text_a)
    segment_ids = [q_segment_id] * len(tokens)

    # prompt that makes the model continue with the answer
    answer = [tokenizer.bos_token] + tokenizer.tokenize(" the answer is")
    tokens += answer
    segment_ids += [a_segment_id] * len(answer)

    input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens), dtype=torch.long)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long)

    folder = '../self_training/images/train2014/' if 'train' in img_name else '../self_training/images/val2014/'   # test and val are both in val2014
    img_path = folder + img_name
    # Close the file handle as soon as the transform has consumed the image.
    # NOTE(review): no explicit .convert('RGB') is applied (it was disabled in the
    # original); presumably img_transform handles the colour mode — confirm.
    with Image.open(img_path) as img_file:
        img = img_transform(img_file)
    qid = torch.LongTensor([int(question_id)])

    return (img, qid, input_ids, segment_ids, img_path)

# data list: load the test split and collect its keys (used as sample ids below)
with open(nle_data_test_path, "r") as data_file:   # context manager closes the handle
    data = json.load(data_file)
ids_list = list(data.keys())
len(ids_list)

### Transformer Interpretability

In [None]:
from PIL import Image
import numpy as np
import cv2, math
import matplotlib.pyplot as plt

from self_training.models.gpt import NLX_GPT


# Checkpoint to inspect: epoch number and the directory holding the saved
# tokenizer and model folders.
load_from_epoch = 11
ckpt_path = '../self_training/ckpts/VQAX_p/'

tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)

# Order matters: the generation cell indexes this list with [-2] (<answer>)
# and [-1] (<explanation>) and stops when any of these ids is produced.
SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', '<answer>', '<explanation>']
special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
# 'Ġ' is the GPT-2 BPE marker for a leading space, i.e. this is the token " because".
because_token_id = tokenizer.convert_tokens_to_ids('Ġbecause')


# Combine the CLIP image encoder and the loaded language model into one NLX_GPT module.
nlx_gpt = NLX_GPT(visual_backbone=image_encoder, lm_backbone=model)

In [None]:
def show_image_relevance(image_relevance, image, orig_image):
    """Plot the original image next to a relevance heatmap overlaid on the model input.

    Args:
        image_relevance: Tensor holding a single relevance map with one score
            per image patch; the size of its last dimension must be a perfect
            square (it is reshaped to ``(1, 1, hw, hw)``).
        image: Preprocessed image batch; ``image[0]`` is read as
            channels-first and its last two dimensions set the overlay size.
        orig_image: Anything ``Axes.imshow`` accepts; shown in the left panel.

    Returns:
        None. Draws a two-panel matplotlib figure.
    """
    def min_max_normalise(arr):
        # Scale to [0, 1]; a constant array maps to zeros instead of NaN.
        lo, hi = arr.min(), arr.max()
        if hi == lo:
            return np.zeros_like(arr)
        return (arr - lo) / (hi - lo)

    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        # OpenCV colour maps are BGR; convert so the heatmap matches the RGB
        # image and matplotlib's channel order.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(orig_image);
    axs[0].axis('off');

    # Upsample the patch-level map to the resolution of the model input.
    out_hw = tuple(image.shape[-2:])
    feat_hw = int(math.sqrt(image_relevance.shape[-1]))
    image_relevance = image_relevance.detach().float().reshape(1, 1, feat_hw, feat_hw)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=out_hw, mode='bilinear', align_corners=False)
    # Device-agnostic conversion (no hard-coded .cuda()).
    image_relevance = min_max_normalise(image_relevance[0, 0].cpu().numpy())

    image = min_max_normalise(image[0].permute(1, 2, 0).detach().float().cpu().numpy())

    vis = np.uint8(255 * show_cam_on_image(image, image_relevance))
    axs[1].imshow(vis);
    axs[1].axis('off');

In [None]:
from transformers import top_k_top_p_filtering

# --- decoding configuration ---
img_size = 224
max_seq_len = 40      # the generation loop runs for at most max_seq_len + 1 steps
do_sample = False     # False -> greedy (argmax) decoding
top_k = 0
top_p = 0.9
temperature = 1

# Visual transformer blocks with an index below this one are skipped when
# accumulating relevance.
start_layer = 11

# --- inputs for the first test sample ---
q_id = ids_list[0]
batch = get_elements(q_id)
# The last element is the image path (a string); everything before it is a tensor.
*sample_tensors, img_path = batch
# Add a leading batch dimension and move each tensor to the compute device.
batch = tuple(tensor.unsqueeze(0).to(device) for tensor in sample_tensors)
img, img_id, input_ids, segment_ids = batch
batch_size = img.shape[0]

# Child blocks of the visual transformer, in registration order. Each block is
# expected to expose the attention probabilities of its latest forward pass as
# `attn_probs`.
image_attn_blocks = list(dict(nlx_gpt.visual_encoder.visual.transformer.resblocks.named_children()).values())

img_relevance_maps = []   # one relevance map per decoding step
current_output = []       # generated token ids, one tensor per step
current_logits = []       # detached logits, one tensor per step
always_exp = False        # True once the " because" token has been generated
for step in range(max_seq_len+1):
    logits = nlx_gpt(image=img, input_ids=input_ids, segment_ids=segment_ids)

    # Read the token count only after a forward pass has run, so the cell does
    # not depend on attention state left over from an earlier execution.
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]

    # Scale by the temperature only when it is positive; otherwise keep the raw logits.
    if temperature > 0:
        logits = logits / temperature
    filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(filtered_logits, dim=-1)
    # torch.multinomial takes the number of samples, not a `dim` argument.
    prev = torch.multinomial(probs, num_samples=1) if do_sample else torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)

    # for explanation: scalar score of the chosen token, used as the
    # differentiation target (kept on the logits' device, no hard-coded .cuda())
    one_hot = F.one_hot(prev, num_classes=logits.shape[-1]).to(dtype=torch.float32, device=logits.device)
    one_hot = torch.sum(one_hot * logits, dim=-1)

    nlx_gpt.zero_grad()

    # image relevance: start from the identity and accumulate
    # gradient-weighted attention of every block from start_layer onwards
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(image_attn_blocks):
        if i < start_layer:
            continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        # regroup per sample, keep positive contributions, average over dim 1
        # (presumably the attention heads — confirm against clip_x)
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        R = R + torch.bmm(cam, R)
    # NOTE(review): this reads row 1 of R and drops column 0; confirm that row 1
    # (rather than row 0) is the intended query token.
    image_relevance = R[:,1,1:]
    img_relevance_maps.append(image_relevance)

    # stop as soon as any special token is produced
    if prev.item() in special_tokens_ids:
        break

    # take care of when to start the <explanation> token
    if not always_exp:
        if prev.item()!=because_token_id:
            new_segment = special_tokens_ids[-2] # answer segment
        else:
            new_segment = special_tokens_ids[-1] # explanation segment
            always_exp = True
    else:
        new_segment = special_tokens_ids[-1] # explanation segment

    new_segment = torch.LongTensor([new_segment]).to(device)
    current_output.append(prev)
    # detach so the stored logits do not keep each step's autograd graph alive
    current_logits.append(logits.detach().unsqueeze(1))
    input_ids = torch.cat((input_ids, prev), dim=1)
    segment_ids = torch.cat((segment_ids, new_segment.unsqueeze(0).expand(segment_ids.shape[0],-1)), dim=1)

# Decode the generated tokens. torch.cat cannot handle an empty list, which
# happens when the very first predicted token is already a special token.
if current_output:
    current_logits = torch.cat(current_logits, dim=1)
    generated_ids = torch.cat(current_output, dim=1).detach().cpu().numpy()
    # the keyword is `clean_up_tokenization_spaces` (plural)
    current_output = tokenizer.decode(generated_ids[0], clean_up_tokenization_spaces=True)
else:
    current_output = ''
print(current_output)
# Show the relevance map of the first generated token next to the original
# image; the context manager closes the image file afterwards.
with Image.open(img_path) as pil_image:
    show_image_relevance(img_relevance_maps[0], img, pil_image)

### Attention Map

In [None]:
# Heatmap of the full relevance matrix R left over from the last decoding step.
# (Optional rescaling before plotting: R = R / torch.max(R))
R = R.squeeze(0)   # removes the leading dimension when it has size 1
attention_map = R.detach().cpu().numpy()
plt.imshow(attention_map)
plt.colorbar()