In [1]:
import ipywidgets
import matplotlib.pyplot as plt
import timm
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import torch.nn.functional as F

# Helpers

In [2]:
def get_last_attention(backbone, x):
    """Get the attention weights of CLS from the last self-attention layer.

    Very hacky! A temporary forward pre-hook captures the token sequence fed
    into the last block's attention module during a full forward pass. The
    query/key projection and the softmax are then recomputed from that
    captured input using the module's ``qkv``, ``num_heads`` and ``scale``
    attributes.

    NOTE(review): this mirrors the classic timm ``Attention`` forward; newer
    timm versions (q/k normalisation, fused attention) may differ — verify
    when upgrading timm.

    Parameters
    ----------
    backbone : timm.models.vision_transformer.VisionTransformer
        Instantiated Vision Transformer. It is not modified: the temporary
        hook is always removed, even if the forward pass raises.

    x : torch.Tensor
        Batch of images of shape `(n_samples, 3, size, size)`.

    Returns
    -------
    torch.Tensor
        Attention weights `(n_samples, n_heads, n_patches)` of the CLS token
        (token 0) over every other token. The CLS-to-CLS weight is dropped,
        so each row sums to slightly less than 1.

    Raises
    ------
    RuntimeError
        If the last attention module was never called during the forward pass.
    """
    attn_module = backbone.blocks[-1].attn
    n_heads = attn_module.num_heads

    # Filled in by the hook with the input of the last attention module.
    inp = None

    def fprehook(module, inputs):
        nonlocal inp
        inp = inputs[0]

    handle = attn_module.register_forward_pre_hook(fprehook)
    try:
        # Only the side effect (capturing `inp`) matters; the output is unused.
        backbone(x)
    finally:
        # Never leave the hook registered, even if the forward pass fails.
        handle.remove()

    if inp is None:
        raise RuntimeError(
            "The last attention module was not called during the forward pass."
        )

    B, N, C = inp.shape
    head_dim = C // n_heads
    # (B, N, 3*C) -> (3, B, n_heads, N, head_dim)
    qkv = attn_module.qkv(inp).reshape(B, N, 3, n_heads, head_dim).permute(2, 0, 3, 1, 4)
    # The value projection is not needed to obtain the attention weights.
    q, k = qkv[0], qkv[1]

    attn = (q @ k.transpose(-2, -1)) * attn_module.scale
    attn = attn.softmax(dim=-1)

    # Row 0 = CLS query; drop column 0 (CLS attending to itself).
    return attn[:, :, 0, 1:]

In [3]:
def threshold(attn, k=30):
    """Keep the `k` largest weights of each row and renormalise them to sum to 1.

    Parameters
    ----------
    attn : torch.Tensor
        Attention weights of shape `(n_heads, n_patches)`. Selection happens
        along the last axis, so any leading dimensions are accepted.

    k : int
        Number of entries kept per row. Values larger than `n_patches` keep
        every entry.

    Returns
    -------
    torch.Tensor
        New tensor of the same shape as `attn`. All but the top-`k` entries
        of each row are zero and the kept entries sum to 1. Rows whose kept
        entries sum to zero (e.g. ``k == 0``) stay all-zero instead of
        becoming NaN. The input tensor is not modified.

    Raises
    ------
    ValueError
        If `k` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        # Nothing is kept; avoid the 0/0 division that would produce NaN.
        return torch.zeros_like(attn)
    k = min(k, attn.shape[-1])

    top_idx = attn.topk(k, dim=-1).indices
    kept = torch.zeros_like(attn).scatter(-1, top_idx, attn.gather(-1, top_idx))

    totals = kept.sum(dim=-1, keepdim=True)
    # Divide all-zero rows by 1 (leaving them zero) instead of by 0.
    totals = torch.where(totals > 0, totals, torch.ones_like(totals))
    return kept / totals

In [4]:
def visualize_attention(img, backbone, k=30, size=224):
    """Create per-head CLS attention maps upsampled to the input resolution.

    Parameters
    ----------
    img : PIL.Image
        RGB image.

    backbone : timm.models.vision_transformer.VisionTransformer
        The vision transformer. Its patch size is read from
        ``backbone.patch_embed.proj.kernel_size``.

    k : int
        Number of highest-attention patches kept per head (see `threshold`).

    size : int
        Side length the image is resized to before the forward pass. It must
        be a multiple of the patch size and presumably must match the input
        size the backbone was built for (224 for the models in this notebook).

    Returns
    -------
    attn : torch.Tensor
        Attention maps of shape ``(n_heads, size, size)``.

    Raises
    ------
    ValueError
        If `size` is not a multiple of the patch size, or if the number of
        patch tokens does not form a ``size // patch_size`` square grid.
    """
    # Infer the patch geometry from the model instead of hard-coding 14x14.
    patch_size = backbone.patch_embed.proj.kernel_size[0]
    if size % patch_size:
        raise ValueError(
            f"size={size} is not a multiple of the patch size {patch_size}."
        )
    grid = size // patch_size

    # ImageNet mean/std normalisation.
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    device = next(backbone.parameters()).device
    x = transform(img)[None, ...].to(device)
    attn = get_last_attention(backbone, x)[0]  # (n_heads, n_patches)

    n_patches = attn.shape[-1]
    if n_patches != grid * grid:
        raise ValueError(
            f"Expected {grid * grid} patch tokens for a {size}x{size} input "
            f"with patch size {patch_size}, got {n_patches}."
        )

    attn = attn / attn.sum(dim=1, keepdim=True)  # (n_heads, n_patches)
    attn = threshold(attn, k)
    attn = attn.reshape(-1, grid, grid)  # (n_heads, grid, grid)
    # Blow every patch up to patch_size x patch_size pixels.
    attn = F.interpolate(
        attn.unsqueeze(0),
        scale_factor=patch_size,
        mode="nearest",
    )[0]  # (n_heads, size, size)

    return attn

# Preparation

In [12]:
# Checkpoint / dataset locations. These are absolute paths on the author's
# machine; adjust them before running the notebook elsewhere.
DINO_CHECKPOINT = "/home/masud/Desktop/Thesis/mildlyoverfitted/dino/logs/best_model.pth"
# Alternative checkpoint:
# "/home/masud/Desktop/Thesis/Results/22-06-03_jitter-blur-gray-non/checkpoint.pth"
DATASET_DIR = "/home/masud/Desktop/Thesis/dataset/ImageNet-20-Val"

# NOTE: torch.load unpickles the whole saved object, which can execute
# arbitrary code -- only load checkpoints you trust. Recent PyTorch releases
# default to weights_only=True, which may reject this full-object checkpoint.
# The loaded object is expected to expose a `.backbone` attribute.
models = {
    "supervised": timm.create_model("vit_deit_small_patch16_224", pretrained=True),
    "selfsupervised": torch.load(DINO_CHECKPOINT, map_location="cpu").backbone,
}
# Inference only: eval() turns off any dropout / drop-path layers so the
# attention maps do not depend on training-mode randomness.
for m in models.values():
    m.eval()

dataset = ImageFolder(DATASET_DIR)

colors = ["yellow", "red", "green", "blue"]

In [13]:
@ipywidgets.interact
def _(
    i=ipywidgets.IntSlider(min=0, max=len(dataset) - 1, continuous_update=False),
    k=ipywidgets.IntSlider(min=0, max=195, value=10, continuous_update=False),
    model=ipywidgets.Dropdown(options=["supervised", "selfsupervised"]),
):
    """Show dataset image `i` and the CLS attention map of every head.

    The arguments are bound to widgets:
    i : index into `dataset`.
    k : number of top patches kept per head (see `threshold`).
    model : key into `models`.
    """
    img = dataset[i][0]
    # (n_heads, H, W) -> (H, W, n_heads). .cpu() so .numpy() also works when
    # the model lives on a GPU.
    attns = (
        visualize_attention(img, models[model], k=k)
        .detach()
        .cpu()
        .permute(1, 2, 0)
        .numpy()
    )

    # Original image, resized like the model input so it lines up with the maps.
    tform = transforms.Compose([
        transforms.Resize((224, 224)),
    ])
    plt.imshow(tform(img))
    plt.axis("off")
    plt.show()

    # Shared colour scale so the panels of different heads are comparable.
    kwargs = {"vmin": 0, "vmax": 0.24}

    # Derive the panel layout from the data instead of assuming 6 heads.
    n_heads = attns.shape[-1]
    n_cols = 3
    n_rows = -(-n_heads // n_cols)  # ceil division

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 3.5 * n_rows), squeeze=False)

    for head, ax in enumerate(axs.flat):
        ax.axis("off")
        if head < n_heads:
            ax.imshow(attns[..., head], **kwargs)
            ax.set_title(f"head {head}")

    plt.tight_layout()
    plt.show()

interactive(children=(IntSlider(value=0, continuous_update=False, description='i', max=999), IntSlider(value=1…

In [None]:
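# Dataset indices noted for later inspection: 3244, 1942, 3482, 688, 1509, 3709
# NOTE: most of these exceed the image-slider range in the saved output above
# (max=999), so they presumably refer to a different, larger dataset -- verify before use.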