<h1> Setup Environment </h1>
Necessary imports and environment variables.

In [None]:
import ast
import os
import random

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from torch.utils.data import random_split, DataLoader
from tqdm.notebook import tqdm

from datasets import get_mnist_loaders, get_cifar10_loaders, get_imagenet_loaders, get_imagenette_loaders
from models import CNN, ViT
from trainers import SimCLRTrainer, DINOTrainer


# Debugging aid: CUDA_LAUNCH_BLOCKING=1 forces synchronous kernel launches so a
# CUDA error is reported at the call that caused it. It slows training, so
# consider removing it once the notebook runs cleanly.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# NOTE(review): TORCH_USE_CUDA_DSA appears to be a build-time option for
# device-side assertions; confirm that setting it at runtime has any effect.
os.environ['TORCH_USE_CUDA_DSA'] = '1'

# Ask cuDNN for deterministic algorithm selection (no autotuning).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# The determinism flags above only help if the random generators are seeded too.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Querying the device name without a CUDA device raises, so check first and
# report the problem instead of crashing the setup cell.
cuda_available = torch.cuda.is_available()
print(cuda_available)  # Should print True
if cuda_available:
    print(torch.cuda.get_device_name(0))  # Should print your GPU name
else:
    print('No CUDA device detected; the cells below move the model to "cuda" and will fail.')


<h1> Parameters for Model and Training </h1>
Architecture hyperparameters for the ViT model, plus settings for the training loop and the data loaders.

In [None]:
# --- Model parameters: all six are forwarded to the ViT constructor ---
chw = (3, 224, 224)   # passed as ViT(chw=...)
n_patches = 14        # passed as ViT(n_patches=...)
hidden_d = 512        # passed as ViT(hidden_d=...)
n_heads = 8           # passed as ViT(n_heads=...); also bounds the heads plotted later
n_blocks = 4          # passed as ViT(n_blocks=...); also the number of layers plotted later
n_classes = 1000      # passed as ViT(num_classes=...)

# --- Training parameters: data loading ---
batch_size = 16       # passed to get_imagenet_loaders
num_workers = 2       # passed to get_imagenet_loaders

# --- Training parameters: forwarded to trainer.finetune ---
n_epochs = 1          # epochs=
patience = 100        # patience=
eval_every = 1        # evaluate_every=
visualize_every = 1   # visualize_every=

<h1> Train Model</h1>
Model and trainer initialization, data loading, and training.

In [None]:
# Collect the architecture hyperparameters in one place, build the ViT from
# them, and move it to the GPU.
vit_kwargs = dict(
    chw=chw,
    n_patches=n_patches,
    n_blocks=n_blocks,
    hidden_d=hidden_d,
    n_heads=n_heads,
    num_classes=n_classes,
)
model = ViT(**vit_kwargs).to('cuda')

# The trainer wraps the model; it is used below to run finetuning.
trainer = DINOTrainer(model)

# ImageNet train/test loaders built with the batch size and worker count set above.
train_loader, test_loader = get_imagenet_loaders(
    batch_size=batch_size,
    num_workers=num_workers,
)


In [None]:
def show_random_images(dataloader, num_images=5, seed=42):
    """Show a seeded random selection of dataset images in a single row.

    Distinct indices are drawn from ``dataloader.dataset``; each image is passed
    through the inverse of the normalization defined below, clamped to [0, 1],
    converted to a PIL image and plotted with its label as the title.

    Args:
        dataloader: Object exposing ``.dataset``; the dataset must support
            ``len()`` and integer indexing that yields ``(image, label)``.
            The image is assumed to be a 3-channel tensor normalized with the
            mean/std hard-coded below (TODO confirm against the loader's
            transforms).
        num_images: Number of images requested. Capped at the dataset size so
            a small dataset shows everything instead of raising.
        seed: Seed used for the index selection.

    Raises:
        ValueError: If ``num_images`` is smaller than 1 or the dataset is empty.

    Side effects:
        Re-seeds Python's global ``random`` generator with ``seed``, so code
        that draws from ``random`` afterwards is affected by this call.
    """
    random.seed(seed)

    dataset = dataloader.dataset
    count = min(num_images, len(dataset))
    if count < 1:
        raise ValueError(
            f'Nothing to show: num_images={num_images}, dataset size={len(dataset)}'
        )
    random_indices = random.sample(range(len(dataset)), count)

    # squeeze=False always returns a 2-D array of axes, so a single image
    # (count == 1) is handled the same way as several.
    fig, axes = plt.subplots(1, count, figsize=(15, 3), squeeze=False)
    axes = axes[0]

    # Inverse of Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
    denorm = T.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )
    to_pil = T.ToPILImage()

    for ax, sample_idx in zip(axes, random_indices):
        img, label = dataset[sample_idx]
        # Clamp before the 8-bit conversion: values slightly outside [0, 1]
        # would otherwise wrap around and show up as speckle artifacts.
        img = denorm(img).clamp(0, 1)

        ax.imshow(to_pil(img))
        ax.axis('off')
        ax.set_title(f'Label: {label}')

    plt.tight_layout()
    plt.show()

# Preview five test-set images (the function's default count), selected with seed 69.
show_random_images(test_loader, seed=69)

In [None]:
# Run finetuning with the settings from the parameters cell and keep whatever
# the trainer returns for later inspection.
train_history = trainer.finetune(
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=n_epochs,
    patience=patience,
    evaluate_every=eval_every,
    visualize_every=visualize_every,
)

<h1> Visualize Attention Map </h1>

In [None]:
# Load class labels from file. The file is read as a Python literal (indexed
# by class index below); ast.literal_eval only parses literals, so unlike
# eval() the file contents can never execute code.
with open('data/imagenet/imagenet1000_clsidx_to_labels.txt', 'r') as f:
    class_labels = ast.literal_eval(f.read())

heads_shown = min(n_heads, 4)  # Show up to 4 heads

random_indices = random.sample(range(len(test_loader.dataset)), 2)

for i in random_indices:
    # Fetch the sample once: every layer/head is then drawn from the same
    # tensor, and the dataset transform is not re-run for each subplot.
    image, class_idx = test_loader.dataset[i]
    class_name = class_labels[class_idx]
    images = image.unsqueeze(0).to('cuda')  # add a leading batch dimension

    # Print both index and class name
    print(f"\nImage {i} class: {class_name} (index: {class_idx})")

    # Create a figure for each layer
    for layer in range(n_blocks):
        # squeeze=False always yields a 2-D axes array, so a single head
        # needs no special-casing.
        fig, axes = plt.subplots(1, heads_shown, figsize=(20, 5), squeeze=False)

        for head, ax in enumerate(axes[0]):
            model.visualize_attention(
                images=images,
                layer_idx=layer,
                head_idx=head,
                alpha=0.7,
                ax=ax
            )

        plt.suptitle(f'Layer {layer} Attention Maps\nClass: {class_name}', size=16)
        plt.tight_layout()
        plt.show()
        # Release the figure; one is created per layer per image.
        plt.close(fig)