In [1]:
import os
import torch
from models import CNN, ViT
from trainers import SimCLRTrainer, DINOTrainer
from datasets import get_mnist_loaders, get_cifar10_loaders, get_imagenet_loaders
from torch.utils.data import random_split, DataLoader
# Debugging aid: CUDA_LAUNCH_BLOCKING=1 asks the CUDA runtime to launch kernels
# synchronously, so a device-side error is reported at the call that caused it.
# It must be set before the first CUDA call and it slows GPU execution noticeably,
# so remove it for real training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [2]:
# Mini-batch size; reused by every DataLoader built in this notebook.
batch_size = 64

# Vision Transformer hyperparameters, gathered in one place so they are easy to tune.
# NOTE(review): chw looks like (channels, height, width) of the input images, and
# num_classes=1000 matches the ImageNet loaders requested further down — confirm
# both against models.ViT.
vit_kwargs = dict(
    chw=(3, 128, 128),
    n_patches=16,
    n_blocks=4,
    hidden_d=512,
    n_heads=8,
    num_classes=1000,
    dropout=0.1,
)

# Build the network and move it to the GPU (a CUDA device is required).
model = ViT(**vit_kwargs).to('cuda')

In [None]:
# Get data loaders
print("Loading data...")
train_loader, test_loader = get_imagenet_loaders(batch_size=batch_size, num_workers=4)

# Split the training set into a pretraining part and a finetuning part.
print("Splitting dataset...")
# M sizes the PRETRAINING split in both modes:
#   M < 1  -> fraction of the training set (fractional lengths need torch >= 1.13)
#   M >= 1 -> absolute number of samples
# Previously the absolute-count branch produced [len - M, M], which silently handed
# the M samples to the finetuning split instead of the pretraining split.
M = 0.000001
n_train = len(train_loader.dataset)
if M < 1:
    split = [M, 1 - M]
else:
    n_pretrain = int(M)
    split = [n_pretrain, n_train - n_pretrain]
pretrain_dataset, finetune_dataset = random_split(train_loader.dataset, split)

In [4]:
# The two loaders differ only in the dataset they wrap, so the batching,
# shuffling and worker settings are declared once and shared.
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=4)

# Loader over the pretraining split.
pretrain_loader = DataLoader(pretrain_dataset, **loader_kwargs)
# Loader over the finetuning split.
finetune_loader = DataLoader(finetune_dataset, **loader_kwargs)

# Initialize DINO trainer
# NOTE(review): the option names suggest teacher/centre EMA momenta, student and
# teacher softmax temperatures, and the number of local crops per image —
# confirm against trainers.DINOTrainer before tuning.
dino_kwargs = dict(
    lr=0.00005,
    momentum_teacher=0.996,
    momentum_center=0.9,
    temp_student=0.1,
    temp_teacher=0.04,
    n_local_views=6,
)
# The trainer wraps the ViT built earlier in the notebook.
dino_trainer = DINOTrainer(model, **dino_kwargs)

In [None]:
# Finetune DINO model
print("\nFinetuning DINO model...")

# Settings for the finetuning run: train on the finetuning split and evaluate on
# the test loader. Note that this lr is passed per call and differs from the one
# given to the trainer's constructor.
finetune_kwargs = dict(
    train_loader=finetune_loader,
    test_loader=test_loader,
    epochs=20,
    lr=0.0001,
    patience=10,
    evaluate_every=2,
)
# Keep the returned value so the run can be inspected afterwards.
dino_history = dino_trainer.finetune(**finetune_kwargs)

In [None]:
# Visualize attention maps for the first three test images, sweeping transformer
# blocks 0-3 and attention heads 0-3 (the model has 8 heads; only the first four
# are shown).
#
# Label names come from the dataset itself. The previous hard-coded CIFAR-10 list
# had only 10 entries, while this notebook uses ImageNet loaders and a 1000-class
# model, so most labels raised IndexError and the rest printed the wrong name.
# NOTE(review): assumes the dataset exposes a torchvision-style `classes`
# sequence; if it does not, the numeric label is printed instead.
test_dataset = test_loader.dataset
classes = getattr(test_dataset, 'classes', None)

for i in range(3):
    # Fetch the sample once per image instead of re-indexing the dataset (and
    # re-running its loading/transforms) for every (block, head) pair.
    sample = test_dataset[i]
    image_batch = sample[0].unsqueeze(0).to('cuda')  # add a batch dimension
    label = int(sample[1])

    if classes is not None and 0 <= label < len(classes):
        label_name = classes[label]
    else:
        label_name = f"class {label}"

    for l in range(4):        # transformer block index
        for h in range(4):    # attention head index
            model.visualize_attention(
                images=image_batch,
                layer_idx=l,
                head_idx=h,
                # save_path='attention_map.png'
            )
            print(label_name)