
# Example usage of VI with a PyTorch model.
This is an example of VI usage on the DINO ViT model:
`https://github.com/facebookresearch/dino`
> This notebook comes as part of the `https://github.com/stamatiad/vit_inspect_tensorboard_plugin` repo and it assumes that you
> run it in a Linux/Unix environment from within the `vit_inspect` folder.
This notebook supports both running on the Colab cloud and locally in ipython. 
* In case you run locally, it takes care to clone any repositories, in the parent folder of `vit_inspect` repo. 
* In case you run on the Colab cloud, it clones both the `dino` and the `vit_inspect` repos inside the runtime's storage.

To avoid restarting the Colab cloud runtime and messing with the cell execution
order, we based our code on the package versions that come preinstalled on Colab,
so there is no need to install any requirements.

In [None]:
import os, sys
from pathlib import Path

In [None]:
# Flag whether we run in the Colab cloud: True when the string form of the
# active IPython shell mentions 'google.colab', False otherwise.
on_colab = 'google.colab' in str(get_ipython())

In [None]:
# Get out from the vit_inspect folder, into the parent directory of your git repositories:
# (On Colab the working directory is left untouched.)
# NOTE: the path is relative, so re-running this cell locally moves one more
# level up each time; run it exactly once per kernel session.
if not on_colab:
  %cd ../

Now clone and install the VI plugin, if not done already. It is important to use setuptools for the TB to register the plugin.

In [None]:
if not Path(os.getcwd()).joinpath('vit_inspect_tensorboard_plugin').exists():
  !git clone https://stamatiad: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@github.com/stamatiad/vit_inspect_tensorboard_plugin.git
%cd vit_inspect_tensorboard_plugin
# Make sure repo is up to date:
!git pull
# Install using setup tools:
!pip install .
%cd ../

Now clone the repo of the PyTorch example model and checkout our customized branch to see
the changes required to run it along with the VI.

In [None]:
# Clone our fork of DINO only when no 'dino' folder exists in the current directory:
if not Path(os.getcwd()).joinpath('dino').exists():
  !git clone https://github.com/stamatiad/dino.git
# Enter the repo. The working directory is NOT restored afterwards, so the
# following cells (git diff, `from vision_transformer import ...`) run from
# inside 'dino'.
%cd dino
# Make sure repo is up to date:
!git pull
# Checkout our custom branch, that integrates VI:
!git checkout stamatiad

As you can see, we have created a wrapper function (`save_attn_weights`) that
wraps the forward method of `Attention`. Now, each time the forward method is
called while VI recording is enabled (with the `vi.enable_vi()` context manager),
we save the TB summaries in the directory `./vi_logs`.

In [None]:
# Show how vision_transformer.py differs between the main and the stamatiad branch:
!git diff main..stamatiad -- vision_transformer.py

Now let's run our example.

In [None]:
# Work on the original DINO with PyTorch:

# VI imports:
import tensorflow as tf
import json
from vit_inspect import vit_inspector as vi
from vit_inspect.summary_v2 import vi_summary

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms as pth_transforms
from vision_transformer import VisionTransformer
from PIL import Image
import requests
from io import BytesIO
from pathlib import Path

Load the pre-trained model and transfer its parameters to a new instance of
our modified model, that VI listens to.

In [None]:
# Load the pre-trained model:
model_cached = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

# Create a version of the model that holds our attention VI modifications.
# Mirror the architecture of the pre-trained model instead of matching only the
# embedding size, so that the whole checkpoint fits the new instance.
# NOTE(review): assumes the customized VisionTransformer keeps the upstream DINO
# constructor keywords (patch_size, depth, num_heads, qkv_bias) and that the
# hub model exposes `attn.qkv` as a linear layer -- confirm against the branch.
num_features = model_cached.embed_dim
cached_attn = model_cached.blocks[0].attn
model = VisionTransformer(
    patch_size=model_cached.patch_embed.patch_size,
    embed_dim=num_features,
    depth=len(model_cached.blocks),
    num_heads=cached_attn.num_heads,
    qkv_bias=cached_attn.qkv.bias is not None,
)

# strict=False tolerates key differences between the two model versions, but
# silently skipped weights would make the attention maps meaningless, so make
# any mismatch visible.
load_result = model.load_state_dict(model_cached.state_dict(), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: pre-trained weights only partially loaded. "
        f"Missing keys: {load_result.missing_keys}; "
        f"unexpected keys: {load_result.unexpected_keys}"
    )

# Inference only: freeze all parameters and enable evaluation mode on the CPU.
device = torch.device("cpu")
model.requires_grad_(False)
model.eval()
model.to(device)

Initialize the VI and inform it about our model parameters:

In [None]:
# Patch geometry. The tokens that take part in the attention dot product are
# the image patches; any other feature (e.g. class) is removed.
patch_size = model.patch_embed.patch_size
crop_size = 480
img_size_in_patches = crop_size // patch_size

# Model params required by VI:
vi.params["num_layers"] = len(model.blocks)
vi.params["num_heads"] = model.blocks[0].attn.num_heads
# Side length of the image, measured in patches:
vi.params["len_in_patches"] = img_size_in_patches
# Total patches in the image:
vi.params["num_tokens"] = img_size_in_patches ** 2


Load a sample image to calculate attention maps upon.

In [None]:
# Load sample images:
# Fail early (instead of handing an error page to PIL) if the download breaks,
# and do not hang forever on a stalled connection.
response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png", timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content))
img = img.convert('RGB')

# Perform the original transformations that the authors did.
# Resize expects (height, width), whereas PIL's img.size is (width, height);
# resize to the crop size that the VI params were derived from, so that the
# number of patches reported to VI matches the tensor fed to the model.
transform = pth_transforms.Compose([
    pth_transforms.Resize((crop_size, crop_size)),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = transform(img)
# make the image divisible by the patch size
# (img is channels-first here, so `w` bounds dim 1 and `h` bounds dim 2; the
# names are kept as they were.)
w, h = img.shape[1] - img.shape[1] % patch_size, \
       img.shape[2] - img.shape[2] % patch_size
# Crop and add a leading batch dimension:
img = img[:, :w, :h].unsqueeze(0)

# Size of the feature map, in patches, along the last two dimensions:
w_featmap = img.shape[-2] // patch_size
h_featmap = img.shape[-1] // patch_size


Save a copy of the input image for the VI to display it as preview, making it
 easier to visualize the attention maps.

In [None]:
# Store the input image in the summary so that VI can show it as a preview.
# The batched image tensor is channels-first; VI gets it with the channels as
# the last dim.
channels_last = np.transpose(np.asarray(img), (0, 2, 3, 1))
flat_arr_rgb = tf.convert_to_tensor(channels_last)

step = 0
batch_id = 0
with vi.writer.as_default():
    # Record which step/batch this preview belongs to; the params travel
    # along with the summary as its JSON description.
    vi.params["step"] = step
    vi.params["batch_id"] = batch_id
    vi_summary(
        f"b{batch_id}",
        flat_arr_rgb,
        step=step,
        description=json.dumps(vi.params),
    )
    vi.writer.flush()


Finally, perform inference with VI enabled:

In [None]:
# Run the model inside the VI context manager, so that the attention maps of
# each layer and head get recorded:
with vi.enable_vi():
    model_input = img.to(device)
    attentions = model.get_last_selfattention(model_input)


In [None]:
# Load the TensorBoard notebook extension and open TensorBoard on the vi_logs
# directory (relative to the current working directory):
%load_ext tensorboard
%tensorboard --logdir vi_logs