In [None]:
pip install -q transformers torch torchvision pillow


In [None]:
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ViT-Base checkpoint with 16x16 patches and 224x224 inputs (per the checkpoint name).
MODEL_NAME = "google/vit-base-patch16-224"

# The processor resizes and normalises images the way this checkpoint expects.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
# Downloads (or loads from the local cache) the weights, including the classification head.
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

# Inference only for now, so switch off train-time behaviour such as dropout.
# The trailing semicolon stops Jupyter from printing the full module repr that eval() returns.
model.eval();


In [None]:
# Fetch a sample image. The timeout keeps the cell from hanging on a stalled connection,
# and raise_for_status() fails loudly on an HTTP error instead of handing PIL an error page.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() forces PIL to decode the whole image before the response is closed.
    image = Image.open(response.raw).convert("RGB")

# Resize + normalise the image into the tensor format ViT expects.
inputs = processor(images=image, return_tensors="pt")

# Re-assert eval mode so this cell gives the same answer even if a training cell ran earlier.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

# Index of the largest logit is the predicted class id; .item() turns it into a Python int.
pred_id = logits.argmax(dim=-1).item()
# Human-readable class name from the id -> label map stored in the model config.
label = model.config.id2label[pred_id]

print("Prediction:", label)


In [None]:
# Show the sample image with the model's top-1 prediction as its title.
plt.imshow(image)
plt.title(f"Prediction: {label}")
plt.axis("off")
plt.show()

In [None]:
# visualizing attention maps
model.set_attn_implementation("eager") # to ensure attention tensor is stored in memory for extraction

outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

attentions = outputs.attentions      # attention maps. tuple: [layers, B, heads, tokens, tokens]
hidden_states = outputs.hidden_states # feature maps. tuple: [layers, B, tokens, dim]


In [None]:
# Last-layer attention from the [CLS] token (index 0) to every image patch, averaged over heads.
attn = attentions[-1].mean(dim=1)      # [B, N, N], N = 1 (CLS) + num_patches
cls_attn = attn[0, 0, 1:]              # first image: CLS row, patch columns -> [num_patches]
cls_attn = cls_attn / cls_attn.max()   # scale so the strongest patch is 1.0

# Patches lie on a square grid (14x14 for this 224px / 16px-patch checkpoint);
# derive the side length instead of hard-coding it.
grid_side = int(round(cls_attn.numel() ** 0.5))
assert grid_side * grid_side == cls_attn.numel(), "patch count is not a perfect square"
heatmap = cls_attn.reshape(grid_side, grid_side)

# Detach (the forward pass tracked gradients) and move to CPU before handing to matplotlib.
plt.imshow(heatmap.detach().cpu().numpy(), cmap="jet")
plt.colorbar()
plt.title("Last-layer [CLS] attention (mean over heads)")
plt.show()

In [None]:
# getting attention rollout

"""
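How attention rollout works (Abnar & Zuidema, 2020, "Quantifying Attention Flow in Transformers"):

For each layer, from first to last:

- Combine the attention heads into one matrix (average them).
- Add the identity matrix to account for the residual connection.
- Re-normalise every row so it sums to 1.
- Left-multiply the running product by this layer's matrix.

Mathematically, with Âᵢ = rownorm(Āᵢ + I), where Āᵢ is layer i's head-averaged attention matrix:

Rollout = Âₙ · … · Â₂ · Â₁

The last layer sits on the left, matching `rollout = attn @ rollout` in the code.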
"""
def attention_rollout(attentions, head_fusion="mean"):
    """Propagate attention through all layers to get token-to-input-token relevance.

    Each layer's heads are fused into one matrix, the identity is added to model the
    residual connection, and rows are re-normalised. The per-layer matrices are then
    multiplied with later layers on the left.

    Args:
        attentions: non-empty sequence of per-layer attention tensors, each of shape
            [B, heads, N, N] (e.g. ``outputs.attentions`` from ``output_attentions=True``).
        head_fusion: how to combine heads within a layer: ``"mean"`` (default),
            ``"max"`` or ``"min"``.

    Returns:
        Tensor of shape [B, N, N]. Row i gives how much output token i draws on each
        input token after all layers.

    Raises:
        ValueError: if ``attentions`` is empty or ``head_fusion`` is not recognised.
    """
    if len(attentions) == 0:
        raise ValueError("attentions must contain at least one layer")

    fusers = {
        "mean": lambda a: a.mean(dim=1),
        "max": lambda a: a.amax(dim=1),
        "min": lambda a: a.amin(dim=1),
    }
    if head_fusion not in fusers:
        raise ValueError(
            f"unknown head_fusion {head_fusion!r}; expected one of {sorted(fusers)}"
        )
    fuse = fusers[head_fusion]

    # The identity (residual connection) is the same for every layer, so build it once.
    num_tokens = attentions[0].size(-1)
    eye = torch.eye(num_tokens, device=attentions[0].device)

    rollout = eye
    for layer_attn in attentions:
        attn = fuse(layer_attn) + eye                 # [B, N, N]
        attn = attn / attn.sum(dim=-1, keepdim=True)  # rows sum to 1 again
        rollout = attn @ rollout                      # later layers go on the left
    return rollout

rollout = attention_rollout(attentions)   # [B, N, N]
cls_rollout = rollout[0, 0, 1:]           # first image: [CLS] row -> patch columns
cls_rollout = cls_rollout / cls_rollout.max()

# Derive the square patch-grid side (14 for this checkpoint) instead of hard-coding it.
grid_side = int(round(cls_rollout.numel() ** 0.5))
assert grid_side * grid_side == cls_rollout.numel(), "patch count is not a perfect square"
heatmap = cls_rollout.reshape(grid_side, grid_side)

# Detach (the attentions carry autograd history) and move to CPU for matplotlib.
plt.imshow(heatmap.detach().cpu().numpy(), cmap="jet")
plt.colorbar()
plt.title("Attention rollout: [CLS] -> patches")
plt.show()

In [None]:
def overlay_heatmap_on_image(image, heatmap, alpha=0.5, cmap="jet"):
    """Draw ``heatmap`` semi-transparently over ``image`` and show the figure.

    Args:
        image: PIL.Image, numpy array of shape (H, W, C) or (H, W), or torch tensor
            of shape (H, W, C) or (3, H, W).
        heatmap: 2-D tensor/array. If its shape differs from the image's (H, W), it is
            bilinearly resized to match.
        alpha: opacity of the heatmap layer (0 = invisible, 1 = opaque).
        cmap: matplotlib colormap name for the heatmap.

    Raises:
        ValueError: if ``heatmap`` is not 2-D.
    """
    import numpy as np
    import torch
    import torch.nn.functional as F
    import matplotlib.pyplot as plt

    # ---- Convert image to numpy (H, W[, C]) ----
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
        if image.ndim == 3 and image.shape[0] == 3:  # CHW -> HWC
            image = image.transpose(1, 2, 0)
    image = np.asarray(image)  # handles PIL images and array-likes alike

    h, w = image.shape[:2]

    # ---- Convert heatmap to a float32 CPU tensor ----
    if isinstance(heatmap, torch.Tensor):
        heatmap_t = heatmap.detach().cpu().float()
    else:
        heatmap_t = torch.as_tensor(np.asarray(heatmap), dtype=torch.float32)
    if heatmap_t.ndim != 2:
        raise ValueError(f"heatmap must be 2-D, got shape {tuple(heatmap_t.shape)}")

    # ---- Resize heatmap if needed ----
    # torch is used instead of OpenCV so the notebook needs no extra dependency.
    # align_corners=False uses half-pixel centres, like cv2.resize's default bilinear mode.
    if tuple(heatmap_t.shape) != (h, w):
        heatmap_t = F.interpolate(
            heatmap_t[None, None], size=(h, w), mode="bilinear", align_corners=False
        )[0, 0]

    # ---- Normalize heatmap to [0, 1] ----
    heatmap_np = heatmap_t.numpy()
    heatmap_np = heatmap_np - heatmap_np.min()
    heatmap_np = heatmap_np / (heatmap_np.max() + 1e-8)  # epsilon guards a constant map

    # ---- Visualize ----
    plt.imshow(image)
    plt.imshow(heatmap_np, cmap=cmap, alpha=alpha)
    plt.axis("off")
    plt.show()


overlay_heatmap_on_image(image, heatmap)


In [None]:
# One illustrative fine-tuning step on the sample image: optimizer, loss, backward, update.
# NOTE: class id 0 is an arbitrary placeholder target, not the image's true label; this
# only demonstrates the mechanics, and it WILL nudge the pretrained weights.
TARGET_LABEL_ID = 0

model.train()  # enable training-mode behaviour before the forward pass
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fn = torch.nn.CrossEntropyLoss()

optimizer.zero_grad()

# Fresh forward pass in training mode. Logits from earlier cells were computed in eval
# mode (and the inference cell ran under no_grad), so reusing them is hidden-state fragile.
train_outputs = model(**inputs)
logits = train_outputs.logits

# Ground-truth class ids: one per image in the batch, on the same device as the logits.
target_labels = torch.full(
    (logits.size(0),), TARGET_LABEL_ID, dtype=torch.long, device=logits.device
)

loss = loss_fn(logits, target_labels)

# Backpropagate and apply the update.
loss.backward()
optimizer.step()
optimizer.zero_grad()

print(f"training loss: {loss.item():.4f}")