<h1> Imports </h1>

In [1]:
import os

# Debug-only: make CUDA kernel launches synchronous so device-side errors are
# reported at the call that caused them. The variable must be in the environment
# before the CUDA context is created, so it is set ahead of every import below
# (a local module could touch CUDA at import time). Synchronous launches slow
# training down — set to '0' or remove this line for full training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
from torch.utils.data import random_split, DataLoader

from models import CNN, ViT
from trainers import SimCLRTrainer, DINOTrainer
from datasets import get_mnist_loaders, get_cifar10_loaders, get_imagenet_loaders

<h1> Initialize Model </h1>

In [2]:
# Seed torch's RNG before building the model so weight initialisation (and any
# later torch-RNG-driven shuffling/augmentation) repeats across Restart & Run All.
# NOTE: this does not seed Python's `random` or NumPy, if the data pipeline uses them.
SEED = 42
torch.manual_seed(SEED)

# Hyper-parameters. batch_size, n_blocks and n_heads are also used by later cells.
batch_size = 32
n_patches = 8              # forwarded to ViT(n_patches=...) — presumably patches per image side; TODO confirm
n_blocks = 4               # number of transformer blocks (iterated again when visualising attention)
n_heads = 6                # attention heads per block (iterated again when visualising attention)
hidden_d = 768             # width passed to ViT(hidden_d=...)
image_chw = (3, 128, 128)  # model input shape as (channels, height, width)
num_classes = 1000         # size of the classifier output; the ImageNet loaders are used below
dropout = 0.2

model = ViT(
    chw=image_chw,
    n_patches=n_patches,
    n_blocks=n_blocks,
    hidden_d=hidden_d,
    n_heads=n_heads,
    num_classes=num_classes,
    dropout=dropout,
).to('cuda')

<h1> Load Data and Fine-tune Model with DINO </h1>

In [None]:
train_loader, test_loader = get_imagenet_loaders(batch_size=batch_size)

# Loader used for fine-tuning. It is defined explicitly in this cell so the
# notebook runs on a fresh kernel. To fine-tune on only part of the data, build a
# subset loader here instead (e.g. with random_split + DataLoader).
finetune_loader = train_loader

dino_trainer = DINOTrainer(model)

dino_history = dino_trainer.finetune(
    train_loader=finetune_loader,
    test_loader=test_loader,
    epochs=20,
    lr=0.0001,
    patience=10,       # presumably early-stopping patience in epochs — TODO confirm in trainers.DINOTrainer
    evaluate_every=2,  # presumably epochs between evaluations — TODO confirm in trainers.DINOTrainer
)

<h1> Visualize Attention Map </h1>

In [None]:
# Draw the attention map of every (block, head) pair for the first few test images.
N_VIS_IMAGES = 5

test_dataset = test_loader.dataset
# Class names come from the dataset itself when it exposes a torchvision-style
# `classes` attribute (for ImageFolder these are the folder names). Otherwise the
# raw label index is printed. A hard-coded CIFAR-10 list must not be used here:
# the data is 1000-class ImageNet, so labels >= 10 would raise IndexError.
classes = getattr(test_dataset, 'classes', None)
# Send inputs to whichever device the model lives on instead of hard-coding it.
device = next(model.parameters()).device

model.eval()  # the model was built with dropout; disable it so the maps are deterministic
for i in range(N_VIS_IMAGES):
    # Fetch each sample once. Indexing the dataset re-runs its transform, so
    # re-fetching per head wastes I/O and, with a random transform, would show
    # each head a different version of the image.
    image, label = test_dataset[i]
    label = int(label)
    label_name = classes[label] if classes is not None else label
    print(f"Image {i} class: {label_name}")
    image_batch = image.unsqueeze(0).to(device)  # add a leading batch dimension of 1
    for layer in range(n_blocks):
        for head in range(n_heads):
            model.visualize_attention(
                images=image_batch,
                layer_idx=layer,  # transformer block index
                head_idx=head,    # attention head index within that block
                # save_path='attention_map.png'  # optional: write the figure to disk
            )
    