In [10]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import torch
from torchvision import transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image

from CLIP import CLIP, build_transform
from clip_tokenizer import SimpleTokenizer
from visualize_attention import get_attention_maps

plt.rcParams['figure.figsize'] = (15, 15)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
MODEL_PATH = 'clip.pth'

model = CLIP(attention_probs_dropout_prob=0, hidden_dropout_prob=0)
model.load_state_dict(state_dict = torch.load(MODEL_PATH))

is_fp16 = False

# device = "cuda" if torch.cuda.is_available() else "cpu"
device = 'cpu'
if is_fp16:
    model.to(device=device).eval().half()
else:
    model.to(device=device).eval().float()

In [11]:
VOCAB_PATH = 'bpe_simple_vocab_16e6.txt.gz'
tokenizer = SimpleTokenizer(
        bpe_path=VOCAB_PATH,
        context_length=model.context_length.item())
transform = build_transform(model.input_resolution.item())
view_transform = transforms.Compose([
    transforms.Resize(224, interpolation=Image.BICUBIC),
    transforms.CenterCrop(224),
    lambda image: image.convert('RGB')])

In [94]:
with torch.no_grad():
    query = ["a baloon", "a human", "a labrador retriever", "eyes", "a dog", "a human and a tiger"]
    text = tokenizer.encode(query).to(device)
    text_features = model.encode_text(text)  # N_queries x 512

    image_path = "images/dog.jpg"
    image_vis = np.asarray(view_transform(Image.open(image_path)))
    image = transform(Image.open(image_path)).unsqueeze(0).to(device)
    image_features = model.encode_image(image) # 1 x 512

    text_attention = get_attention_maps(model, visual=False)
    visual_attention = get_attention_maps(model, visual=True).squeeze(0)

    vis = visual_attention[0, 0, 1:].reshape(7,7).detach().numpy()
    vis -= vis.min()
    vis /= vis.max()
    vis = cv2.resize(vis, (224, 224))[...,np.newaxis]
    result = (vis * image_vis).astype(np.uint8)
    plt.imshow(result)
    logits_per_image, logits_per_text = model(image, text, return_loss=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

final = sorted(zip(query, list(probs)), key=lambda obj: obj[1], reverse=True)
for key, value in final:
    print(f'{key}: {value}')

TypeError: super(type, obj): obj must be an instance or subtype of type

In [44]:
print(text_attention.shape)

torch.Size([6, 12, 77, 77])


In [45]:
print(text_attention[0,-1])

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [9.9949e-01, 5.1279e-04, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [9.9913e-01, 3.8664e-04, 4.8334e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [9.9729e-01, 2.4863e-04, 2.3441e-04,  ..., 3.3182e-04, 0.0000e+00,
         0.0000e+00],
        [9.9722e-01, 2.5147e-04, 2.3545e-04,  ..., 9.1951e-05, 3.3444e-04,
         0.0000e+00],
        [9.9758e-01, 2.2153e-04, 2.1913e-04,  ..., 3.6957e-05, 3.1341e-05,
         2.8237e-04]])


In [73]:
for TEXT_LAYER in range(0, 12):
    print(TEXT_LAYER)
    print(text_attention[0,TEXT_LAYER,1]) # 3
    print(text_attention[3,TEXT_LAYER,1]) # 2

0
tensor([0.1623, 0.8377, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
tensor([0.1176, 0.8824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.000