
# Example usage of VI with a PyTorch model.
This is an example of VI usage on the DINO ViT model:
`https://github.com/facebookresearch/dino`
>This notebook comes as part of the `https://github.com/stamatiad/vit_inspect_tensorboard_plugin` repo and it assumes that you
 > run it in a linux/unix environment from within the `vit_inspect` folder.
This notebook supports both running on the Colab cloud and locally in ipython. 
* In case you run locally, it takes care to clone any repositories, in the parent folder of `vit_inspect` repo. 
* In case it runs on Colab cloud, it should clone both `dino` and the `vit_inspect` inside the runtimes storage.

To avoid restarting the Colab cloud runtime, and messing with cell order
execution, we based our code on the preinstalled package versions of Colab.
So no need to install any requirements.

In [3]:
import os, sys
from pathlib import Path

In [4]:
# Determine if you run in Colab cloud.
if 'google.colab' in str(get_ipython()):
    on_colab = True
else:
  on_colab = False

In [6]:
# Get out from the vit_inspect folder, into the parent directory of your git repositories:
if not on_colab:
  %cd ../../

C:\Users\stama\Documents\GitHub


Now clone and install the VI plugin, if not done already. It is important to use setuptools for the TB to register the plugin.

In [21]:
if not Path(os.getcwd()).joinpath('vit_inspect_tensorboard_plugin').exists():
  !git clone https://stamatiad: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@github.com/stamatiad/vit_inspect_tensorboard_plugin.git
%cd vit_inspect_tensorboard_plugin
# Make sure repo is up to date:
!git pull
# Install using setup tools:
!pip install --use-feature=in-tree-build .
%cd ../

C:\Users\stama\Documents\GitHub\vit_inspect_tensorboard_plugin
Already up to date.
Processing c:\users\stama\documents\github\vit_inspect_tensorboard_plugin
Building wheels for collected packages: vit-inspect
  Building wheel for vit-inspect (setup.py): started
  Building wheel for vit-inspect (setup.py): finished with status 'done'
  Created wheel for vit-inspect: filename=vit_inspect-0.1.0-py3-none-any.whl size=19195 sha256=646067297fdf66f5e5b96e89892a6fefc71c83c6071ba4e501e0c684a5d4356a
  Stored in directory: C:\Users\stama\AppData\Local\Temp\pip-ephem-wheel-cache-xtvypzwq\wheels\57\1b\42\af0f5c7f58855909e5c471759a70a9aa418413499741d9dcb3
Successfully built vit-inspect
Installing collected packages: vit-inspect
  Attempting uninstall: vit-inspect
    Found existing installation: vit-inspect 0.1.0
    Uninstalling vit-inspect-0.1.0:
      Successfully uninstalled vit-inspect-0.1.0
Successfully installed vit-inspect-0.1.0
C:\Users\stama\Documents\GitHub


Now clone the repo of the PyTorch example model and checkout our customized branch to see
the changes required to run it along with the VI.

In [22]:
if not Path(os.getcwd()).joinpath('dino').exists():
  !git clone https://github.com/stamatiad/dino.git
%cd dino
# Make sure repo is up to date:
!git pull
# Checkout our custom branch, that integrates VI:
!git checkout stamatiad

C:\Users\stama\Documents\GitHub\dino
Already up to date.
D	requirements_colab.txt
D	vi_example.ipynb
M	vi_example.py
D	vi_logs/events.out.tfevents.1672074953.stamatiad-laptop.9563.0.v2
D	vi_logs/events.out.tfevents.1674386022.stamatiad-laptop.5141.0.v2
M	vision_transformer.py
Your branch is up to date with 'origin/stamatiad'.


Already on 'stamatiad'


As you can see, we have created a wrapper function (save_attn_weights) that
wraps the forward method of Attention. Now each time the forward method is
called and we have VI recording enabled (with vi.enable_vi() context manager)
, we will save the TB summaries in the directory ./vi_logs.

In [23]:
!git diff main..stamatiad -- vision_transformer.py

diff --git a/vision_transformer.py b/vision_transformer.py
index f69a7ad..102e06d 100644
--- a/vision_transformer.py
+++ b/vision_transformer.py
@@ -23,6 +23,70 @@ import torch.nn as nn
 
 from utils import trunc_normal_
 
+# VI imports:
+from functools import wraps
+import numpy as np
+import json
+import tensorflow as tf
+from vit_inspect import vit_inspector as vi
+from vit_inspect.summary_v2 import vi_summary
+
+def save_attn_weights():
+    # Zero indexed layer counter:
+    layer_counter = 0
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kw):
+            # Get access to the variable:
+            nonlocal layer_counter
+            try:
+                # Evaluate the attention function first:
+                x, w = func(*args, **kw)
+                # If we are recording with VI:
+                if vi._summary_is_active:
+                    # Number of tokens:
+                    nt = vi.params["num_tokens"]
+                    # Patch size
+

Now lets run our example

In [24]:
# Work on the original DINO with PyTorch:

# VI imports:
import tensorflow as tf
import json
from vit_inspect import vit_inspector as vi
from vit_inspect.summary_v2 import vi_summary

In [25]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms as pth_transforms
from vision_transformer import VisionTransformer
from PIL import Image
import requests
from io import BytesIO
from pathlib import Path

Load the pre-trained model and transfer its parameters to a new instance of
our modified model, that VI listens to.

In [26]:
# Load the pre-trained model:
model_cached = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
# Create a version of the model that holds our attention VI modifications:
# Match model params, before load:
num_features = model_cached.embed_dim
model = VisionTransformer(embed_dim=num_features)
model.load_state_dict(model_cached.state_dict(), strict=False)

# Enable evaluation mode:
device = torch.device("cpu")
for p in model.parameters():
    p.requires_grad = False
model.eval()
model.to(device)

Using cache found in C:\Users\stama/.cache\torch\hub\facebookresearch_dino_main


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): A

Initialize the VI and inform it about our model parameters:

In [27]:
# Get some model params, required for VI:
vi.params["num_layers"] = len(model.blocks)
vi.params["num_heads"] = model.blocks[0].attn.num_heads
# The number of tokens when the attention dot product happens.
# Here tokens are the patches. Any other feature (e.g. class) is removed.
patch_size = model.patch_embed.patch_size
crop_size = 480
img_size_in_patches = crop_size // patch_size
vi.params["len_in_patches"] = img_size_in_patches
# Total patches in the image:
vi.params["num_tokens"] = img_size_in_patches ** 2


Load a sample image to calculate attention maps uppon.

In [28]:
# Load sample images:
response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
img = Image.open(BytesIO(response.content))
img = img.convert('RGB')

# Perform the original transformations that the authors did.
transform = pth_transforms.Compose([
    pth_transforms.Resize(img.size),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = transform(img)
# make the image divisible by the patch size
w, h = img.shape[1] - img.shape[1] % patch_size, \
       img.shape[2] - img.shape[2] % patch_size
img = img[:, :w, :h].unsqueeze(0)

w_featmap = img.shape[-2] // patch_size
h_featmap = img.shape[-1] // patch_size


Save a copy of the input image for the VI to display it as preview, making it
 easier to visualize the attention maps.

In [29]:
# Save the input image into the summary:
flat_arr_rgb = tf.convert_to_tensor(
    # Make sure image's channels is the last dim:
    np.moveaxis(np.asarray(img), 1, -1)
)
with vi.writer.as_default():
    step = 0
    batch_id = 0
    vi.params["step"] = 0
    vi.params["batch_id"] = batch_id
    vi_summary(
        f"b{batch_id}",
        flat_arr_rgb,
        step=step,
        description=json.dumps(vi.params)
    )
    vi.writer.flush()


Finally, perform inference with VI enabled:

In [30]:
# Use the VI context manager to get attention maps of each layer and head:
with vi.enable_vi():
    attentions = model.get_last_selfattention(img.to(device))


In [8]:
%load_ext tensorboard


In [None]:
%tensorboard --logdir vi_logs