## Import Necessary Libraries

In [1]:
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from timm import create_model



## Trainer Class 

The `Trainer` class manages the training and evaluation process. It encapsulates the following functionalities:

1. **Initialization**: 
   - Accepts the model, optimizer, loss function, and device (CPU/GPU), ensuring the model is moved to the correct device.
   - Optionally supports a learning rate scheduler.

2. **Training**: The `train_epoch` method:
   - Iterates through training batches using a `tqdm` progress bar.
   - Performs forward passes, computes loss, backpropagates, and updates model parameters.
   - Returns the average training loss and accuracy for the epoch.

3. **Evaluation**: The `evaluate` method:
   - Runs the model on a validation or test set in evaluation mode (`model.eval()`).
   - Computes accuracy and average loss, displaying progress with a `tqdm` bar.
   - Computes and stores softmax probabilities for each sample in the test set.

4. **Full Training Loop**: The `train` method:
   - Combines `train_epoch` and `evaluate` for multiple epochs.
   - Prints the training loss and accuracy and the test loss and accuracy after each epoch.
   - Returns the test-set softmax predictions from the final epoch together with the final test accuracy.
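Both `train_epoch` and `evaluate` accumulate `loss.item() * len(images)` and then divide by the dataset size. `nn.CrossEntropyLoss` averages over the batch by default (`reduction="mean"`), so multiplying by the batch size recovers the batch's summed loss. This gives a true per-sample average even when the final batch is smaller than the others, whereas simply averaging the batch means would over-weight that last batch. A minimal pure-Python sketch of the difference, using made-up numbers and hypothetical helper names:

```python
def dataset_mean(batch_means, batch_sizes):
    """Per-sample average: weight each batch mean by the number of samples in it."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


def naive_mean(batch_means):
    """Average of batch means: over-weights a small final batch."""
    return sum(batch_means) / len(batch_means)


# Two full batches of 4 samples and a final batch holding a single sample.
batch_means = [0.5, 0.5, 2.0]
batch_sizes = [4, 4, 1]

print(dataset_mean(batch_means, batch_sizes))  # (2.0 + 2.0 + 2.0) / 9
print(naive_mean(batch_means))                 # (0.5 + 0.5 + 2.0) / 3
```

Accuracy needs no such weighting, because `correct` already counts individual samples.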


In [8]:
class Trainer:
    def __init__(self, model, optimizer, loss_fn, device, scheduler=None):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device
        self.scheduler = scheduler

    def train_epoch(self, train_loader):
        """Train for one epoch."""
        self.model.train()
        total_loss, correct = 0, 0

        with tqdm(train_loader, desc="Training", unit="batch") as t:
            for images, labels in t:
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                logits = self.model(images)
                loss = self.loss_fn(logits, labels)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item() * len(images)
                correct += (logits.argmax(dim=1) == labels).sum().item()

                t.set_postfix(loss=loss.item())

        return total_loss / len(train_loader.dataset), correct / len(train_loader.dataset)

    @torch.no_grad()
    def evaluate(self, test_loader):
        """Evaluate model and return loss, accuracy, and softmax predictions."""
        self.model.eval()
        total_loss, correct = 0, 0
        all_softmax_preds = []

        with tqdm(test_loader, desc="Testing", unit="batch") as t:
            for images, labels in t:
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self.model(images)
                loss = self.loss_fn(logits, labels)

                total_loss += loss.item() * len(images)
                correct += (logits.argmax(dim=1) == labels).sum().item()

                # Compute softmax predictions
                softmax_preds = torch.nn.functional.softmax(logits, dim=1)
                all_softmax_preds.append(softmax_preds.cpu())

        avg_loss = total_loss / len(test_loader.dataset)
        accuracy = correct / len(test_loader.dataset)
        all_softmax_preds = torch.cat(all_softmax_preds, dim=0)

        return avg_loss, accuracy, all_softmax_preds

    def train(self, train_loader, test_loader, epochs):
        """Train model and return softmax predictions and final test accuracy."""
        softmax_preds, final_test_acc = None, 0.0
        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")
            train_loss, train_acc = self.train_epoch(train_loader)
            val_loss, val_acc, softmax_preds = self.evaluate(test_loader)

            # Step the optional scheduler once per epoch (assumes an epoch-based
            # scheduler such as StepLR; ReduceLROnPlateau would need step(val_loss))
            if self.scheduler is not None:
                self.scheduler.step()

            print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
            print(f"Test Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

            final_test_acc = val_acc
        return softmax_preds, final_test_acc
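In `evaluate`, accuracy is computed from `logits.argmax(dim=1)`, while the stored predictions are softmax probabilities. Softmax divides the monotonically increasing `exp(z_i)` values of a row by the same row sum, so it preserves the ordering within each row: the argmax of the probabilities equals the argmax of the logits, and the saved predictions imply the same predicted classes. A NumPy sketch of that relationship (NumPy stands in for PyTorch so it runs anywhere; the logits and labels are made up):

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Made-up logits for 3 samples and 4 classes, plus their true labels.
logits = np.array([[2.0, 0.5, -1.0, 0.0],
                   [0.1, 0.2, 3.0, -0.5],
                   [1.0, 1.5, 0.0, 0.2]])
labels = np.array([0, 2, 0])

probs = softmax(logits)

# Each row is a probability distribution over the classes.
assert np.allclose(probs.sum(axis=1), 1.0)

# Softmax preserves the ordering of logits, so the predicted class is unchanged.
assert (probs.argmax(axis=1) == logits.argmax(axis=1)).all()

accuracy = (probs.argmax(axis=1) == labels).mean()
print(accuracy)  # 2 of 3 predictions match the labels
```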



## Load Vision Transformer (ViT) Model

The `load_vit` function is responsible for loading a Vision Transformer (ViT) model, modifying its classification head, and optionally freezing the backbone layers.

1. **Model Creation**:
   - Loads a pretrained ViT model using the specified `model_name`.
   - Adjusts the model's classification head to match the required number of classes (`num_classes`).

2. **Freezing the Backbone**:
   - If `freeze_backbone=True`, all model parameters except for the classification head are frozen.

3. **Device Assignment**:
   - Moves the model to the specified `device` (CPU or GPU) before returning it.


In [3]:
def load_vit(model_name, num_classes, device, freeze_backbone=True):
    """
    Load a ViT model, modify its classification head, and optionally freeze the backbone.
    """
    model = create_model(model_name, pretrained=True, num_classes=num_classes)
    
    if freeze_backbone:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.head.parameters():
            param.requires_grad = True  # Only train the classification head

    return model.to(device)
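One way to confirm that freezing worked is to count the parameters that still have `requires_grad=True`. The sketch below (it requires PyTorch) applies the same two loops as `load_vit` to a small hypothetical `ToyViT` module rather than a timm model; it mimics only the structure that matters here, a backbone plus a linear `head` attribute. Assuming ViT-Tiny's 192-dimensional embedding, the real 10-class head would have 192 × 10 + 10 = 1,930 trainable parameters.

```python
import torch.nn as nn


class ToyViT(nn.Module):
    """Hypothetical stand-in for a timm ViT: a 'backbone' plus a linear 'head'."""

    def __init__(self, embed_dim=16, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, embed_dim), nn.GELU())
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))


def count_params(model):
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


model = ToyViT()

# Same freezing pattern as load_vit: freeze everything, then unfreeze the head.
for param in model.parameters():
    param.requires_grad = False
for param in model.head.parameters():
    param.requires_grad = True

trainable, total = count_params(model)
print(f"trainable: {trainable} / {total}")  # head: 16*10 + 10; backbone: 32*16 + 16
```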

## ViT Model Training Function

The `train_model` function is responsible for fine-tuning a Vision Transformer (ViT) model on a given dataset. It provides the following functionalities:

1. **Model Initialization**:
   - Loads a ViT model using `load_vit`, ensuring it is correctly configured with the specified `model_name`, `num_classes`, and `device`.

2. **Optimizer and Loss Function**:
   - Uses AdamW as the optimizer, applied only to the classification head's parameters (the backbone is frozen).
   - Defines cross-entropy loss for multi-class classification tasks.

3. **Training Process**:
   - Initializes a `Trainer` instance to manage training and evaluation.
   - Iterates through multiple epochs, logging performance metrics after each epoch.
   - Computes and returns softmax predictions from the trained model.

4. **Performance Logging**:
   - Measures and prints the total wall-clock training time.

In [4]:
def train_model(model_name, num_classes, train_loader, test_loader, epochs=10, lr=1e-3, weight_decay=1e-2):
    """
    Fine-tune a ViT model on a dataset.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = load_vit(model_name, num_classes, device)

    # Define optimizer and loss function
    optimizer = optim.AdamW(model.head.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()

    # Initialize trainer
    trainer = Trainer(model, optimizer, loss_fn, device)

    # Start training
    print(f"Training {model_name} for {epochs} epochs on {num_classes}-class dataset")
    start_time = time.time()
    # train() returns (softmax_preds, final_test_acc); only the predictions are returned here
    softmax_preds, _ = trainer.train(train_loader, test_loader, epochs)
    elapsed_time = time.time() - start_time

    print(f"Training completed in: {elapsed_time:.2f} seconds")
    return model, softmax_preds
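`train_model` measures elapsed time with `time.time()`. For durations, Python's `time.perf_counter()` is usually preferred because it is a high-resolution monotonic clock that is not affected by system clock adjustments. The sketch below is a small hypothetical `timed` context manager, not part of the notebook, that wraps the start/stop bookkeeping; `time.sleep` stands in for the actual training call:

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Time the enclosed block with a monotonic clock and store seconds in results[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        results[label] = elapsed
        print(f"{label} completed in: {elapsed:.2f} seconds")


timings = {}
with timed("Training", timings):
    time.sleep(0.05)  # stand-in for trainer.train(train_loader, test_loader, epochs)
```

Because the timing happens in a `finally` clause, the duration is recorded even if training is interrupted by an exception.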

## Softmax Predictions Saving Function

The `save_softmax_predictions` function is used to save the softmax predictions generated by a trained model into a compressed `.npz` file for later analysis. It provides the following functionalities:

1. **Data Handling**:
   - Accepts softmax predictions as a `torch.Tensor` or `numpy.ndarray`.
   - Converts `torch.Tensor` predictions to a NumPy array if necessary.

2. **File Saving**:
   - Saves the predictions in a compressed `.npz` format using `numpy.savez_compressed`.
   - Allows for specifying a custom filename.

3. **Logging**:
   - Prints a confirmation message indicating where the predictions were saved.

This function ensures that model outputs can be efficiently stored and retrieved for further evaluation, such as model comparison or uncertainty estimation.

In [5]:
import numpy as np

def save_softmax_predictions(predictions, filename="predictions.npz"):
    """
    Save softmax predictions to an .npz file.

    Args:
        predictions (torch.Tensor or np.ndarray): The softmax predictions.
        filename (str): Name of the file to save.
    """
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.cpu().numpy()
    
    np.savez_compressed(filename, predictions=predictions)
    print(f"Predictions saved to {filename}")
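A quick round-trip check confirms that the compressed file preserves the probabilities exactly, since `.npz` compression is lossless. This NumPy-only sketch uses a synthetic array in place of real model output and writes to a temporary directory:

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)

# Synthetic "softmax" output: 5 samples x 10 classes, each row normalized to sum to 1.
raw = rng.random((5, 10), dtype=np.float32)
predictions = raw / raw.sum(axis=1, keepdims=True)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "predictions.npz")
    # Same call as save_softmax_predictions: the array is stored under the key "predictions".
    np.savez_compressed(path, predictions=predictions)

    with np.load(path) as data:
        restored = data["predictions"]

print(restored.shape, restored.dtype)
assert np.array_equal(restored, predictions)
```

Note that `np.savez_compressed` appends the `.npz` extension when the given filename does not already end with it.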

## DataLoader Loading Function

The datasets are loaded from pickled `DataLoader` objects:

- The `load_dataloader` function opens a `.pkl` file and unpickles the `DataLoader` stored in it.
- `train_loader`, `val_loader`, and `test_loader` are loaded from the specified paths. The fine-tuning run below evaluates on `test_loader` after each epoch.

In [6]:
import pickle
import torchvision.transforms
from torch.utils.data import DataLoader

# Paths to the pickle files
train_loader_path = r"D:\Master\ComputerVision\rag-enhanced-image-classification\src\development\train_loader.pkl"
val_loader_path = r"D:\Master\ComputerVision\rag-enhanced-image-classification\src\development\val_loader.pkl"
test_loader_path = r"D:\Master\ComputerVision\rag-enhanced-image-classification\src\development\test_loader.pkl"

# Function to load pickled DataLoader objects
def load_dataloader(file_path):
    with open(file_path, "rb") as f:
        return pickle.load(f)

# Load the DataLoaders
train_loader = load_dataloader(train_loader_path)
val_loader = load_dataloader(val_loader_path)
test_loader = load_dataloader(test_loader_path)

# Print dataset sizes
print(f'Train images: {len(train_loader.dataset)}')
print(f'Validation images: {len(val_loader.dataset)}')
print(f'Test images: {len(test_loader.dataset)}')

Train images: 45000
Validation images: 5000
Test images: 10000
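Unpickling a `DataLoader` requires the same classes and transforms to be importable, may fail under different library versions, and `pickle.load` can execute arbitrary code, so such files should only be loaded from trusted sources. A lighter-weight alternative is to persist only the split indices and rebuild the loaders from the dataset. The notebook does not show how its 45,000/5,000 split was created; the sketch below assumes a seeded random permutation of CIFAR-10's 50,000 training images, and the file name `cifar10_split.npz` is hypothetical:

```python
import os
import tempfile

import numpy as np

NUM_TRAIN_IMAGES = 50_000  # size of the CIFAR-10 training set
VAL_SIZE = 5_000           # matches the 5,000 validation images reported above

# A seeded permutation makes the split reproducible without pickling any loader.
rng = np.random.default_rng(42)
perm = rng.permutation(NUM_TRAIN_IMAGES)
train_idx, val_idx = perm[VAL_SIZE:], perm[:VAL_SIZE]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "cifar10_split.npz")  # hypothetical file name
    np.savez(path, train_idx=train_idx, val_idx=val_idx)

    # Later, in another session: reload the indices instead of a pickled DataLoader.
    with np.load(path) as split:
        train_idx, val_idx = split["train_idx"], split["val_idx"]

print(len(train_idx), len(val_idx))  # 45000 5000
```

With PyTorch available, `torch.utils.data.Subset(train_dataset, train_idx.tolist())` can wrap the reloaded indices, and a fresh `DataLoader` can then be built around it; `train_dataset` here stands for a hypothetical `torchvision.datasets.CIFAR10` instance.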


## Fine-Tune the ViT-Tiny Classification Head on CIFAR-10

In [9]:
# Fine-tune the ViT-Tiny classification head on CIFAR-10 (test implementation)
num_classes = 10  # CIFAR-10
vit_tiny_model, vit_tiny_softmax_preds = train_model("vit_tiny_patch16_224", num_classes, train_loader, test_loader)
save_softmax_predictions(vit_tiny_softmax_preds, "vit_tiny_softmax_predictions_CIFAR10.npz")

Training vit_tiny_patch16_224 for 10 epochs on 10-class dataset
Epoch 1/10


Training: 100%|██████████| 704/704 [02:15<00:00,  5.19batch/s, loss=0.555]
Testing: 100%|██████████| 157/157 [00:33<00:00,  4.74batch/s]


Train Loss: 1.1197, Accuracy: 0.6310
Test Loss: 0.8388, Accuracy: 0.7194
Epoch 2/10


Training: 100%|██████████| 704/704 [01:58<00:00,  5.92batch/s, loss=1.31] 
Testing: 100%|██████████| 157/157 [00:32<00:00,  4.79batch/s]


Train Loss: 0.7890, Accuracy: 0.7313
Test Loss: 0.7555, Accuracy: 0.7437
Epoch 3/10


Training: 100%|██████████| 704/704 [02:31<00:00,  4.64batch/s, loss=1.73] 
Testing: 100%|██████████| 157/157 [00:28<00:00,  5.43batch/s]


Train Loss: 0.7350, Accuracy: 0.7493
Test Loss: 0.7205, Accuracy: 0.7560
Epoch 4/10


Training: 100%|██████████| 704/704 [01:55<00:00,  6.11batch/s, loss=0.62] 
Testing: 100%|██████████| 157/157 [00:24<00:00,  6.45batch/s]


Train Loss: 0.7098, Accuracy: 0.7550
Test Loss: 0.7026, Accuracy: 0.7563
Epoch 5/10


Training: 100%|██████████| 704/704 [02:03<00:00,  5.69batch/s, loss=0.16] 
Testing: 100%|██████████| 157/157 [00:30<00:00,  5.18batch/s]


Train Loss: 0.6936, Accuracy: 0.7606
Test Loss: 0.6929, Accuracy: 0.7618
Epoch 6/10


Training: 100%|██████████| 704/704 [02:02<00:00,  5.73batch/s, loss=0.777]
Testing: 100%|██████████| 157/157 [00:27<00:00,  5.70batch/s]


Train Loss: 0.6831, Accuracy: 0.7638
Test Loss: 0.6866, Accuracy: 0.7629
Epoch 7/10


Training: 100%|██████████| 704/704 [02:36<00:00,  4.51batch/s, loss=0.435]
Testing: 100%|██████████| 157/157 [00:33<00:00,  4.69batch/s]


Train Loss: 0.6750, Accuracy: 0.7658
Test Loss: 0.6815, Accuracy: 0.7650
Epoch 8/10


Training: 100%|██████████| 704/704 [01:56<00:00,  6.04batch/s, loss=0.731]
Testing: 100%|██████████| 157/157 [00:26<00:00,  5.97batch/s]


Train Loss: 0.6689, Accuracy: 0.7689
Test Loss: 0.6822, Accuracy: 0.7675
Epoch 9/10


Training: 100%|██████████| 704/704 [02:09<00:00,  5.44batch/s, loss=0.504]
Testing: 100%|██████████| 157/157 [00:27<00:00,  5.63batch/s]


Train Loss: 0.6646, Accuracy: 0.7705
Test Loss: 0.6741, Accuracy: 0.7690
Epoch 10/10


Training: 100%|██████████| 704/704 [02:08<00:00,  5.48batch/s, loss=0.725]
Testing: 100%|██████████| 157/157 [00:26<00:00,  5.93batch/s]

Train Loss: 0.6603, Accuracy: 0.7721
Test Loss: 0.6772, Accuracy: 0.7657
Training completed in: 1590.16 seconds
Predictions saved to vit_tiny_softmax_predictions_CIFAR10.npz
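The progress bars are consistent with a batch size of 64, although the notebook never states it. A `DataLoader` with `drop_last=False` yields ⌈N / batch_size⌉ batches, and 64 is the only batch size that produces both 704 training batches for 45,000 images and 157 test batches for 10,000 images. The arithmetic:

```python
import math

train_images, test_images = 45_000, 10_000
batch_size = 64  # inferred from the logs, not stated in the notebook

train_batches = math.ceil(train_images / batch_size)
test_batches = math.ceil(test_images / batch_size)
last_train_batch = train_images - (train_batches - 1) * batch_size

print(train_batches, test_batches, last_train_batch)  # 704 157 8

# Search all batch sizes up to 1024 that reproduce the observed batch counts.
candidates = [b for b in range(1, 1025)
              if math.ceil(train_images / b) == 704 and math.ceil(test_images / b) == 157]
print(candidates)  # [64]
```

This also explains why the `loss=` value at the end of each training bar jumps around (for example 0.16 in epoch 5 and 1.73 in epoch 3): `set_postfix` shows only the most recent batch's loss, and the last training batch holds just 8 images, so its mean loss is very noisy. The epoch-level `Train Loss` is the reliable figure.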





## Softmax Predictions Loading and Inspection

The softmax predictions saved in the `.npz` file can be loaded and inspected as follows:

- Define the path to the `.npz` file containing the softmax predictions.
- Open the compressed file with `numpy.load()`.
- Check that the key `"predictions"` exists and print the shape of the stored array.

In [20]:
import numpy as np

file_path = r"D:\Master\ComputerVision\rag-enhanced-image-classification\src\development\vit_tiny_softmax_predictions_CIFAR10.npz"

# Load the .npz file
data = np.load(file_path)

# Check the keys in the file
print("Keys in the file:", data.files)

# Assuming the softmax predictions are stored under the key "predictions"
if "predictions" in data:
    print("Shape of softmax predictions:", data["predictions"].shape)
else:
    print("Key 'predictions' not found. Available keys:", data.files)

Keys in the file: ['predictions']
Shape of softmax predictions: (10000, 10)
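Beyond the shape check, each row can be summarized by its predicted class, its maximum probability (confidence), and its entropy, a common per-sample uncertainty score. Computing accuracy additionally requires the test labels in the same order as the rows, which holds only if `test_loader` does not shuffle; the notebook does not show how the loader was configured. A NumPy sketch on a small synthetic array standing in for `data["predictions"]`, with hypothetical labels:

```python
import numpy as np

# Synthetic stand-in for data["predictions"]: 3 samples x 4 classes.
probs = np.array([[0.97, 0.01, 0.01, 0.01],
                  [0.25, 0.25, 0.25, 0.25],
                  [0.10, 0.60, 0.20, 0.10]])
labels = np.array([0, 3, 1])  # hypothetical ground-truth labels in the same order

pred_class = probs.argmax(axis=1)  # ties resolve to the first index
confidence = probs.max(axis=1)
# Entropy in nats; the small epsilon guards against log(0).
entropy = -(probs * np.log(probs + 1e-12)).sum(axis=1)

accuracy = (pred_class == labels).mean()

print(pred_class)  # [0 0 1]
print(confidence)
print(np.round(entropy, 3))
print(accuracy)
```

The uniform second row reaches the maximum entropy for 4 classes, ln 4, while the confident first row is close to 0.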
