In [None]:
import sys

# Make the parent of the working directory (normally the notebook's folder) importable,
# so the `src.*` project imports below resolve.
# Guarded so re-running this cell does not keep appending duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import cv2
import timm
import torch
import random
import numpy as np
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from typing import List, Dict
import matplotlib.pyplot as plt
from src.io.io import load_config
from src.dataset import SSLDataset
from src.transform import Transform
from torch.utils.data import DataLoader
from src.model.vit import compute_attentions
from src.model.utils import create_model, load_state_dict_ssl

### **Checkpoints + Config**

In [None]:
# Location of the trained checkpoint and the config it was trained with.
ckpt_path, config_path = "model/model.ckpt", "model/config.yaml"

### **Loading Model**

In [None]:
# Parse the YAML config and display its "model" section for a quick sanity check.
config = load_config(path=config_path)
model_section = config["model"]
model_section

In [None]:
# Settings shared by the model and the input transform (read left to right from the config).
BACKBONE, IMG_SIZE = config["model"]["backbone"], config["transform"]["img_size"]


In [None]:
# Build the backbone architecture only (pretrained=False); its weights are
# filled in from the SSL checkpoint in the following cell.
model = create_model(img_size=IMG_SIZE, backbone=BACKBONE, pretrained=False)

In [None]:
# Load the checkpoint on CPU and copy its SSL weights into the freshly built backbone.
# NOTE(review): on torch >= 2.6 `torch.load` defaults to weights_only=True; if the checkpoint
# stores non-tensor objects (e.g. hyper-parameters) this may need weights_only=False —
# only do that for trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location="cpu")
model = load_state_dict_ssl(
    model=model,
    ssl_state_dict=checkpoint["state_dict"]
)
# Inference only: eval mode disables dropout / drop-path so attention maps are deterministic.
model.eval()
# The full checkpoint dict is no longer needed; release it.
del checkpoint

### **Setting up SSLTransform + ODIN Dataset**

In [None]:
# DINO-style transform with train=False (presumably no training-time augmentation),
# plus the validation split of the folder-organised image set we sample from.
transform = Transform(framework="dino", train=False, img_size=IMG_SIZE)
dataset = SSLDataset(
    root_dir="images", split="val", with_folders=True, transform=transform,
)

In [None]:
# Pick a random validation image (unseeded: a different sample on each run).
i = random.randint(a=0, b=len(dataset) - 1)
img_path = dataset.img_paths[i]

# Force 3-channel RGB: grayscale / RGBA / palette files would otherwise break the
# 3-channel mask overlay later on. The `with` block closes the file handle;
# convert() returns an independent, fully loaded image.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))

# Augmentation (`views` is not used below)
x, views = transform(img=img)

# Input tensor with a leading batch dimension of 1
# (assumes `x` is a numpy array, as torch.from_numpy requires)
x = torch.from_numpy(x).unsqueeze(dim=0)
print(f"dataset index {i} - path: {img_path}")

In [None]:
# Show the resized input image; plt.show() avoids echoing the AxesImage repr.
plt.imshow(img)
plt.show()

### **Attentions**

In [None]:
# Patch size passed to compute_attentions.
# NOTE(review): hard-coded; presumably must match the patch size of BACKBONE
# (e.g. a *patch16* ViT) — confirm, or derive it from the model/config.
PATCH_SIZE = 16

attentions = compute_attentions(
    model=model,
    x=x,
    patch_size=PATCH_SIZE,
)

Prepare the images to display: sum the attention maps into a single normalised mask and overlay it on the input image.

In [None]:
# Force 3 channels so the image always matches the 3-channel mask built below
# (grayscale or RGBA inputs would otherwise fail to broadcast).
np_img = np.array(img.convert("RGB"))

# Collapse the stacked attention maps (summed over axis 0) into one 2-D map.
# NOTE(review): assumes `attentions` is a numpy array of shape (n_maps, IMG_SIZE, IMG_SIZE),
# i.e. already at image resolution — confirm against compute_attentions.
mask = np.sum(attentions, axis=0)
# 10x10-pixel box blur to smooth the attention map.
mask = cv2.blur(mask, (10, 10))
# Replicate to 3 channels so it scales every RGB channel and displays as a grey image.
mask = np.stack([mask, mask, mask], axis=-1)

# Normalise to [0, 1]; skip when the map is all zeros to avoid dividing by zero.
mask_max = mask.max()
if mask_max > 0:
    mask = mask / mask_max

# Dim each pixel in proportion to its attention (uses the numpy image, not the PIL object).
result = (mask * np_img).astype("uint8")

In [None]:
# Side-by-side: input image, attention mask, and the image weighted by the mask.
fig, ax = plt.subplots(figsize=(15, 5), nrows=1, ncols=3)

panels = [
    (img, "Original"),
    (mask, "Attention mask"),
    (result, f"{BACKBONE} - Attention on image"),
]
for panel_ax, (image, title) in zip(ax, panels):
    panel_ax.imshow(image)
    panel_ax.set_title(title)
    panel_ax.axis("off")

fig.tight_layout()
# plt.show() renders the figure without echoing the last call's return value.
plt.show()