In [1]:
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from timm import create_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class Trainer:
    """Minimal training / evaluation loop for a classification model.

    Attributes:
        model: The network, moved to ``device`` on construction.
        optimizer: Optimizer stepped once per training batch.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
            Running losses are weighted by batch size, which assumes the loss
            is a per-batch mean -- TODO confirm for custom loss functions.
        device: Device that every batch is moved to.
        scheduler: Optional learning-rate scheduler. When given, it is stepped
            once per epoch (after evaluation) by :meth:`train`.
    """

    def __init__(self, model, optimizer, loss_fn, device, scheduler=None):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device
        self.scheduler = scheduler

    def train_epoch(self, train_loader):
        """Train for one epoch.

        Args:
            train_loader: Iterable yielding ``(images, labels)`` batches.

        Returns:
            Tuple ``(mean_loss, accuracy)`` averaged over the samples that were
            actually iterated.

        Raises:
            ValueError: If ``train_loader`` yields no samples.
        """
        self.model.train()
        total_loss, correct, seen = 0.0, 0, 0

        with tqdm(train_loader, desc="Training", unit="batch") as t:
            for images, labels in t:
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                logits = self.model(images)
                loss = self.loss_fn(logits, labels)
                loss.backward()
                self.optimizer.step()

                # Read the scalar once; every .item() call forces a device sync.
                batch_loss = loss.item()
                batch_count = len(images)
                total_loss += batch_loss * batch_count
                correct += (logits.argmax(dim=1) == labels).sum().item()
                seen += batch_count

                t.set_postfix(loss=batch_loss)

        # Divide by the samples seen rather than len(loader.dataset) so the
        # averages stay correct when the loader uses a sampler or drop_last.
        if seen == 0:
            raise ValueError("train_loader yielded no samples")
        return total_loss / seen, correct / seen

    @torch.no_grad()
    def evaluate(self, test_loader):
        """Evaluate the model without tracking gradients.

        Args:
            test_loader: Iterable yielding ``(images, labels)`` batches.

        Returns:
            Tuple ``(mean_loss, accuracy, softmax_preds)``. ``softmax_preds`` is
            a CPU tensor holding the per-batch softmax outputs concatenated
            along dim 0, in the order the loader produced the batches.

        Raises:
            ValueError: If ``test_loader`` yields no samples.
        """
        self.model.eval()
        total_loss, correct, seen = 0.0, 0, 0
        all_softmax_preds = []

        with tqdm(test_loader, desc="Testing", unit="batch") as t:
            for images, labels in t:
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self.model(images)
                loss = self.loss_fn(logits, labels)

                batch_count = len(images)
                total_loss += loss.item() * batch_count
                correct += (logits.argmax(dim=1) == labels).sum().item()
                seen += batch_count

                # Softmax over dim 1 of the logits; moved to CPU so the
                # accumulated predictions do not stay on the device.
                softmax_preds = torch.nn.functional.softmax(logits, dim=1)
                all_softmax_preds.append(softmax_preds.cpu())

        if seen == 0:
            raise ValueError("test_loader yielded no samples")

        avg_loss = total_loss / seen
        accuracy = correct / seen
        all_softmax_preds = torch.cat(all_softmax_preds, dim=0)

        return avg_loss, accuracy, all_softmax_preds

    def _step_scheduler(self, val_loss):
        """Advance the optional LR scheduler by one epoch."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau needs the monitored metric passed to step().
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def train(self, train_loader, test_loader, epochs):
        """Train for ``epochs`` epochs, evaluating on ``test_loader`` after each.

        Args:
            train_loader: Batches used for optimization.
            test_loader: Batches used for the per-epoch evaluation.
            epochs: Number of epochs to run.

        Returns:
            Softmax predictions from the last evaluation. When ``epochs`` is 0
            no training happens and the current model is evaluated once, so a
            prediction tensor is still returned.
        """
        softmax_preds = None
        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")
            train_loss, train_acc = self.train_epoch(train_loader)
            val_loss, val_acc, softmax_preds = self.evaluate(test_loader)

            print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
            print(f"Test Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

            self._step_scheduler(val_loss)

        if softmax_preds is None:
            # The loop body never ran (epochs == 0): evaluate once instead of
            # failing on an unbound local.
            _, _, softmax_preds = self.evaluate(test_loader)

        return softmax_preds


In [3]:
def load_vit(model_name, num_classes, device, freeze_backbone=True):
    """
    Build a pretrained timm model with a ``num_classes``-way head and move it to ``device``.

    Args:
        model_name: Name passed to ``create_model``.
        num_classes: Output size requested for the classification head.
        device: Device the model is moved to before it is returned.
        freeze_backbone: When True, gradients are disabled for every parameter
            except those of ``model.head``.

    Returns:
        The model on ``device``.
    """
    vit = create_model(model_name, pretrained=True, num_classes=num_classes)

    if not freeze_backbone:
        return vit.to(device)

    # Turn gradients off everywhere, then re-enable them for the head only.
    vit.requires_grad_(False)
    vit.head.requires_grad_(True)

    return vit.to(device)

In [11]:
def train_model(model_name, num_classes, train_loader, test_loader, epochs=5, lr=2e-4, weight_decay=1e-2):
    """
    Load a model via ``load_vit`` and train its classification head.

    Args:
        model_name: Model name forwarded to ``load_vit``.
        num_classes: Number of output classes.
        train_loader: Batches used for training.
        test_loader: Batches evaluated after every epoch.
        epochs: Number of training epochs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        Tuple ``(model, softmax_preds)`` where ``softmax_preds`` is whatever
        ``Trainer.train`` returns for ``test_loader``.
    """
    # Prefer the GPU whenever CUDA is available.
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")

    model = load_vit(model_name, num_classes, device)

    # Only the head's parameters are handed to the optimizer.
    trainer = Trainer(
        model,
        optim.AdamW(model.head.parameters(), lr=lr, weight_decay=weight_decay),
        nn.CrossEntropyLoss(),
        device,
    )

    print(f"Training {model_name} for {epochs} epochs on {num_classes}-class dataset")
    started_at = time.time()
    softmax_preds = trainer.train(train_loader, test_loader, epochs)
    print(f"Training completed in: {time.time() - started_at:.2f} seconds")

    return model, softmax_preds


In [5]:
import numpy as np

def save_softmax_predictions(predictions, filename="predictions.npz"):
    """
    Save softmax predictions to a compressed .npz file under the key ``predictions``.

    Args:
        predictions (torch.Tensor or np.ndarray): The softmax predictions.
            Tensors are detached and copied to the CPU before conversion, so
            tensors that still require grad or live on another device work too.
        filename (str): Name of the file to save.
    """
    if isinstance(predictions, torch.Tensor):
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        predictions = predictions.detach().cpu().numpy()

    np.savez_compressed(filename, predictions=predictions)
    print(f"Predictions saved to {filename}")


In [6]:
dataset_directory = "../../../cifar-10-batches-py-for-pytorch"

# Convert each image to a tensor, then resize it to 224x224.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((224, 224)),
    ]
)

# Training split of CIFAR-10, downloaded into dataset_directory if needed.
cifar10_dataset = torchvision.datasets.CIFAR10(
    root=dataset_directory,
    train=True,
    download=True,
    transform=transform,
)


Files already downloaded and verified


In [7]:
from torch.utils.data import DataLoader, random_split

use_CIFAR10 = True

# Seed for the train / validation split so the same split is produced on every run.
SPLIT_SEED = 42

# Precomputed per-channel mean and std, needed to normalize the dataset.
# NOTE: to compute them from scratch one would
# 1. calculate the sum for each channel
# 2. apply the mean and variance formulas
if use_CIFAR10:
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
else:
    # Values used for the CIFAR-100 branch below.
    mean = (0.5070, 0.4865, 0.4408)
    std = (0.2613, 0.2503, 0.2703)

# dataset directory remains the same for both cases
dataset_directory = "../../../cifar-10-batches-py-for-pytorch"


# the transformation also remains the same
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Normalize(mean, std)
])

# download the dataset
if use_CIFAR10:
    cifar_dataset = torchvision.datasets.CIFAR10(root=dataset_directory, train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root=dataset_directory, train=False, download=True, transform=transform)
else:
    cifar_dataset = torchvision.datasets.CIFAR100(root=dataset_directory, train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR100(root=dataset_directory, train=False, download=True, transform=transform)

print(f'Dataset downloaded. Total images: {len(cifar_dataset)}')

# Split the dataset into train / validation sets (90% / 10%)
train_size = int(0.9 * len(cifar_dataset))
val_size = len(cifar_dataset) - train_size

# A seeded generator makes the random split reproducible.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(cifar_dataset, [train_size, val_size], generator=split_generator)

# set the batch size to 64
batch_size = 64

# Create the dataloaders.
# Only the training loader is shuffled. The evaluation loaders keep dataset
# order because Trainer.evaluate concatenates predictions in loader order, so
# the saved softmax rows line up with the dataset's samples.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f'Train images: {train_size}')
print(f'Validation images: {val_size}')
print(f'Test images: {len(test_dataset)}')

Files already downloaded and verified
Files already downloaded and verified
Dataset downloaded. Total images: 50000
Train images: 45000
Validation images: 5000
Test images: 10000


In [12]:
# Fine-tune ViT-Tiny on the dataset selected by `use_CIFAR10` (test implementation).
# The head size has to match the dataset's label count, so derive it from the
# same flag that picks the dataset: 10 for CIFAR-10, 100 for CIFAR-100.
num_classes = 10 if use_CIFAR10 else 100
vit_tiny_model, vit_tiny_softmax_preds = train_model("vit_tiny_patch16_224", num_classes, train_loader, test_loader)
save_softmax_predictions(vit_tiny_softmax_preds, "vit_tiny_predictions.npz")

Training vit_tiny_patch16_224 for 5 epochs on 10-class dataset
Epoch 1/5


Training: 100%|██████████| 704/704 [02:01<00:00,  5.82batch/s, loss=1.27] 
Testing: 100%|██████████| 157/157 [00:29<00:00,  5.39batch/s]


Train Loss: 1.1576, Accuracy: 0.6170
Test Loss: 0.8435, Accuracy: 0.7193
Epoch 2/5


Training: 100%|██████████| 704/704 [02:16<00:00,  5.16batch/s, loss=0.596]
Testing: 100%|██████████| 157/157 [00:30<00:00,  5.22batch/s]


Train Loss: 0.7927, Accuracy: 0.7300
Test Loss: 0.7535, Accuracy: 0.7432
Epoch 3/5


Training: 100%|██████████| 704/704 [02:06<00:00,  5.58batch/s, loss=0.615]
Testing: 100%|██████████| 157/157 [00:26<00:00,  5.90batch/s]


Train Loss: 0.7365, Accuracy: 0.7461
Test Loss: 0.7224, Accuracy: 0.7522
Epoch 4/5


Training: 100%|██████████| 704/704 [02:16<00:00,  5.16batch/s, loss=0.39] 
Testing: 100%|██████████| 157/157 [00:26<00:00,  6.02batch/s]


Train Loss: 0.7099, Accuracy: 0.7544
Test Loss: 0.7041, Accuracy: 0.7607
Epoch 5/5


Training: 100%|██████████| 704/704 [02:11<00:00,  5.35batch/s, loss=0.748]
Testing: 100%|██████████| 157/157 [00:32<00:00,  4.87batch/s]

Train Loss: 0.6940, Accuracy: 0.7595
Test Loss: 0.6937, Accuracy: 0.7605
Training completed in: 795.96 seconds
Predictions saved to vit_tiny_predictions.npz



