# MNIST Classification with Vision Transformer 
In this exercise, we explore Transformers by implementing multi-head attention. This hands-on experiment follows naturally from your previous theoretical exercise.
We will use the **MNIST dataset**, a small dataset of 28×28 grayscale images of handwritten digits.

### **Stages**

1. **Import Necessary Libraries** – Load required Python and deep learning packages.  
2. **Download and Load the MNIST Dataset** – Explore its structure and format.  
3. **Preprocess and Prepare DataLoaders** – Normalize the images and split the data into training, validation, and test sets. 
4. **Define and Train the Model** – Implement and train a standard single-layer Transformer encoder network. 

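Complete the code blocks marked with `# YOUR CODE HERE` and `raise NotImplementedError()`, replacing the `raise` statement with your implementation.
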
In [None]:
# Do not delete this cell

**Deliverables:** <br>

Submit the completed notebook (`Exercise1.ipynb`) and your trained model (`best_model.pth`) to Moodle. 
Do not rename the notebook file; doing so may result in 0 points for the exercise.
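The notebook does not show how `best_model.pth` is produced. Below is a minimal sketch, assuming the grader expects a `state_dict` saved with `torch.save` (the expected file contents are not stated in this exercise). The small `nn.Sequential` is a hypothetical stand-in for the trained `MiniViT`, and a temporary directory is used so the example leaves no files behind.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained model (MiniViT in this notebook).
model = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best_model.pth")
    # Save only the learned parameters (the state_dict), not the pickled module.
    torch.save(model.state_dict(), path)

    # Reload into a freshly constructed model with the same architecture.
    restored = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 2))
    restored.load_state_dict(torch.load(path, map_location="cpu"))

x = torch.randn(3, 4)
same = torch.allclose(model(x), restored(x))
print(same)  # True: the reloaded weights reproduce the original outputs
```

In the notebook itself, the corresponding call after training would be `torch.save(model.state_dict(), 'best_model.pth')`.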

## 1. Import Necessary Libraries

In this section, we import the essential Python libraries required for building, training, and evaluating our Vision Transformer.

We will use:
- **PyTorch** for model definition, training, and evaluation.  
- **Torchvision** for loading and transforming the MNIST dataset.  
- **NumPy** for numerical operations.  
- Python's **random** module for setting the random seed. 
- **scikit-learn** for computing the F1 score.
- **tqdm** for tracking training progress.

Make sure all required packages are installed before proceeding.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import random
import tqdm
from sklearn.metrics import f1_score
# Set seeds
torch.cuda.manual_seed_all(42)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

## 2. Download and Load the MNIST Dataset

In this section, we will download the **MNIST** dataset, which contains **60,000 training** and **10,000 test** grayscale images of handwritten digits (0–9), each of size **28×28 pixels**.

We will:
- Use `torchvision.datasets.MNIST` to download and load the data.  
- Apply image transformations such as **tensor conversion** and **normalization** to ensure consistent model training.  
- Create **DataLoaders** for both training and testing sets, enabling efficient mini-batch processing.

### Normalization Details
The images are normalized using the mean and standard deviation of the MNIST dataset:
- Mean = 0.1307  
- Std  = 0.3081
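To see what these two numbers do, the sketch below applies the same `(x - mean) / std` transform that `transforms.Normalize` performs after `transforms.ToTensor` has scaled raw pixel values from 0..255 to [0, 1]. The random batch at the end is a hypothetical stand-in; the actual MNIST statistics come from the same reduction over all training pixels.

```python
import torch

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

# ToTensor() maps uint8 pixels 0..255 to floats in [0, 1];
# Normalize then computes (x - mean) / std per channel.
raw = torch.tensor([0, 128, 255], dtype=torch.uint8)
scaled = raw.float() / 255.0
normalized = (scaled - MNIST_MEAN) / MNIST_STD
print(normalized)  # black pixels map to about -0.4242, white pixels to about 2.8215

# The statistics themselves are the mean/std over every pixel of the training set.
# Sketch on a random stand-in batch shaped like MNIST: (N, 1, 28, 28).
fake_batch = torch.rand(100, 1, 28, 28)
mean, std = fake_batch.mean().item(), fake_batch.std().item()
print(f"mean={mean:.4f}, std={std:.4f}")
```

After normalization, every MNIST pixel therefore lies roughly between -0.4242 and 2.8215, which is what the value-range check in Section 3 should report.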

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download and load training and test sets
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Define DataLoaders for batch processing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Print dataset statistics
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 3. Preprocess and Prepare DataLoaders

Before training, let’s verify that our MNIST data is correctly loaded and preprocessed.

In this stage, we will:
- Check the **shape** and **value range** of sample tensors.  
- Optionally create a **validation set** (useful for model tuning).  
- Confirm the number of samples and batches in our DataLoaders.

This ensures that the data pipeline is ready before we define the model.
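The sample and batch counts can be predicted before running the pipeline: with a 10% split of the 60,000 training images, and DataLoaders that keep the last partial batch (the default `drop_last=False`), each loader yields `ceil(N / batch_size)` batches. The toy `TensorDataset` below is a hypothetical stand-in used only to confirm the rule.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

n_train_full = 60_000
val_size = int(0.1 * n_train_full)          # validation images
train_size = n_train_full - val_size        # remaining training images

# A DataLoader without drop_last yields ceil(N / batch_size) batches.
print(math.ceil(train_size / 64))    # 844 training batches at batch_size=64
print(math.ceil(val_size / 1000))    # 6 validation batches at batch_size=1000
print(math.ceil(10_000 / 1000))      # 10 test batches at batch_size=1000

# Check the same rule on a small stand-in dataset with a seeded split.
toy = TensorDataset(torch.arange(100).float())
toy_train, toy_val = random_split(toy, [90, 10], generator=torch.Generator().manual_seed(42))
print(len(toy_train), len(toy_val), len(DataLoader(toy_train, batch_size=64)))  # 90 10 2
```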

In [None]:
# Inspect one sample
example_data, example_label = train_dataset[0]
print(f"Sample image shape: {example_data.shape}")
print(f"Label: {example_label}")
print(f"Tensor value range: {example_data.min():.4f} to {example_data.max():.4f}")

# Optional: create a validation split (e.g., 10% of training set)
val_size = int(0.1 * len(train_dataset))
train_size = len(train_dataset) - val_size

train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# Updated DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")


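## 4. Model Definition and Training
In this section, we define and train the Transformer encoder. The helper below re-initializes all model parameters with fixed seeds so that training runs are reproducible.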

In [None]:
def reset_all_parameters(model):
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    random.seed(42)
    np.random.seed(42)
    
    for module in model.modules():
        
        # Linear layers
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        # LayerNorm
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

        # MultiheadAttention (Q/K/V and output weights)
        elif isinstance(module, nn.MultiheadAttention):
            for name, param in module.named_parameters():
                if param.dim() > 1:
                    nn.init.xavier_uniform_(param)
                else:
                    nn.init.zeros_(param)

    # CLS token & positional embeddings
    if hasattr(model, "cls_token"):
        model.cls_token.data = torch.randn_like(model.cls_token)

    if hasattr(model, "pos_embed"):
        model.pos_embed.data = torch.randn_like(model.pos_embed)

Follow the image below to implement multi-head attention.

The formula for scaled dot-product attention is described here: https://www.geeksforgeeks.org/nlp/transformer-attention-mechanism-in-nlp/
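For reference, scaled dot-product attention is computed independently for each head; the per-head outputs are concatenated and passed through an output projection (`self.o` in the class below). The softmax is taken over the key (token) dimension. In the standard formulation, $d_k$ is the per-head key dimension (`dim_per_head`); note that the provided scaffold defines `self.sqrt_d = self.feats**0.5`, which uses the full feature size instead.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

$$
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(Q_i, K_i, V_i)
$$

Here $Q_i, K_i, V_i$ are the slices of the query, key, and value projections that belong to head $i$.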

Before computing attention, `q`, `k`, and `v` have shape `(batch, heads, tokens, dim_per_head)`.

After computing attention, reshape the result back to `(batch, tokens, heads*dim_per_head)`.
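The head split already done in `forward` (a `view` followed by `transpose(1, 2)`) can be checked on a random tensor. The sketch below uses sizes matching the `MiniViT` defaults (17 tokens, `embed_dim=64`, 4 heads) and shows one way to undo the split; it only tracks shapes and does not compute attention.

```python
import torch

b, n, feats, heads = 2, 17, 64, 4   # batch, tokens (16 patches + CLS), embed_dim, num_heads
d_head = feats // heads              # the view below requires feats divisible by heads

x = torch.randn(b, n, feats)

# Split: (b, n, feats) -> (b, n, heads, d_head) -> (b, heads, n, d_head)
split = x.view(b, n, heads, d_head).transpose(1, 2)
print(split.shape)                   # torch.Size([2, 4, 17, 16])

# Merge: undo the transpose, then flatten the heads back into the feature dimension.
# transpose() returns a non-contiguous view, so call contiguous() before view()
# (reshape() would also work, copying only when needed).
merged = split.transpose(1, 2).contiguous().view(b, n, heads * d_head)
print(merged.shape)                  # torch.Size([2, 17, 64])
print(torch.equal(merged, x))        # True
```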

![mha](mha.png)

In [None]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, feats:int, head:int=8, dropout:float=0.):
        super(MultiHeadSelfAttention, self).__init__()
        self.head = head
        self.feats = feats
        self.sqrt_d = self.feats**0.5

        self.q = nn.Linear(feats, feats)
        self.k = nn.Linear(feats, feats)
        self.v = nn.Linear(feats, feats)

        self.o = nn.Linear(feats, feats)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, n, f = x.size()
        
        q = self.q(x).view(b, n, self.head, self.feats//self.head).transpose(1,2)
        k = self.k(x).view(b, n, self.head, self.feats//self.head).transpose(1,2)
        v = self.v(x).view(b, n, self.head, self.feats//self.head).transpose(1,2)
        # q, k, v: (batch, heads, tokens, dim_per_head)
        ''' 
        To Do: IMPLEMENT MULTI-HEAD SELF-ATTENTION HERE 
        The scaled dot-product attention formula is described at: https://www.geeksforgeeks.org/nlp/transformer-attention-mechanism-in-nlp/
        Before computing attention, q, k and v have shape (batch, heads, tokens, dim_per_head). 
        After computing attention, reshape the result back to (batch, tokens, heads*dim_per_head).
        You can use torch.matmul for matrix multiplication and F.softmax for the softmax operation.
        Assign the result to the variable 'attn' below.
        '''
        # YOUR CODE HERE
        raise NotImplementedError()
        o = self.dropout(self.o(attn))
        return o

In the following section, we define and implement the Transformer encoder shown in the image below.

![encoder](encoder.png)

Follow the instructions in the To Do sections to initialize the layers and define the forward pass of the Vision Transformer.
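The patch extraction in `MiniViT.forward` is already provided; the sketch below traces its shapes on a random batch to show why the positional embedding has `num_patches + 1 = 17` rows. The `nn.Linear` and the zero CLS token are hypothetical stand-ins for `self.patch_embed` and `self.cls_token`, and `view(B, -1, 49)` is equivalent to the notebook's `view(B, 1, -1, 49).squeeze(1)`.

```python
import torch
import torch.nn as nn

B, patch_size, embed_dim = 8, 7, 64
x = torch.randn(B, 1, 28, 28)        # a batch shaped like MNIST images

# unfold(dim, size, step) slides a non-overlapping 7-pixel window along rows, then columns.
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
print(patches.shape)                 # torch.Size([8, 1, 4, 4, 7, 7])

# Flatten the 4x4 grid into 16 tokens of 7*7 = 49 pixels each (row-major patch order).
patches = patches.contiguous().view(B, -1, patch_size * patch_size)
print(patches.shape)                 # torch.Size([8, 16, 49])

tokens = nn.Linear(patch_size * patch_size, embed_dim)(patches)  # (8, 16, 64)
cls = torch.zeros(1, 1, embed_dim).expand(B, -1, -1)              # stand-in CLS token
seq = torch.cat([cls, tokens], dim=1)                             # CLS sits at index 0
print(seq.shape)                     # torch.Size([8, 17, 64])
```

The first token after flattening is the top-left 7×7 block of each image, and the CLS token occupies position 0 of the sequence.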

In [None]:
# Vision Transformer definition
class MiniViT(nn.Module):
    def __init__(self, patch_size=7, embed_dim=64, num_heads=4, num_classes=10, dropout=0.0):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # image 28x28 -> (28/7)^2 = 16 patches
        self.num_patches = (28 // patch_size) ** 2

        self.patch_embed = nn.Linear(patch_size * patch_size, embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

        self.mha = None
        self.ln1 = None
        self.ln2 = None
        self.ff = None
        self.fc = None
        '''
        To Do: Initialize the following layers:
        1. MultiHeadSelfAttention layer (self.mha) with MultiHeadSelfAttention class and embed_dim and num_heads
        2. LayerNorm layer (self.ln1) with nn.LayerNorm and embed_dim
        3. LayerNorm layer (self.ln2) with nn.LayerNorm and embed_dim
        4. Feed-Forward network (self.ff) as a sequential model with:
            - Linear layer from embed_dim to embed_dim*2
            - GELU activation
            - Linear layer from embed_dim*2 to embed_dim
        5. Final classification layer (self.fc) with nn.Linear from embed_dim to num_classes
        
        '''

        # YOUR CODE HERE
        raise NotImplementedError()

    def forward(self, x):
        B = x.shape[0]
        

        # divide into patches
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(B, 1, -1, self.patch_size*self.patch_size).squeeze(1)

        # linear embedding
        x = self.patch_embed(patches)

        # prepend cls token
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1) # (B, num_patches+1, embed_dim)
        
        # add positional embedding
        x = x + self.pos_embed
        '''
         To Do: Implement the Transformer encoder block here:
         1. Apply multi-head self-attention (self.mha) to x
         2. Apply layer normalization (self.ln1), then add the residual connection from before attention
         3. Apply the feed-forward network (self.ff)
         4. Apply another layer normalization (self.ln2), then add the residual connection from before the feed-forward network

         Finally, classify using the CLS token (position 0) with the final linear layer (self.fc).
         The result should be assigned to the variable 'out' below.
        ''' 
        # YOUR CODE HERE
        raise NotImplementedError()
        return out

## 4.1 Training Setup

The model is trained for one epoch (`range(1)` in the training loop) and then evaluated on the test set.

We will use:
- **Loss:** CrossEntropyLoss  
- **Optimizer:** Adam (learning rate = 1e-3)  
- **Metrics:** Accuracy and macro F1 score
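Two details of this setup are easy to get wrong. `nn.CrossEntropyLoss` applies log-softmax internally, so the classifier should output raw logits, and a uniform prediction over 10 classes gives a loss of $\ln 10 \approx 2.3026$, a useful chance-level reference for the first batches. The toy labels below are made up purely to show how macro F1 differs from accuracy.

```python
import math

import torch
import torch.nn as nn
from sklearn.metrics import f1_score

# CrossEntropyLoss expects raw logits (no softmax) and integer class labels.
criterion = nn.CrossEntropyLoss()
logits = torch.zeros(4, 10)               # uniform prediction over 10 classes
labels = torch.tensor([0, 3, 7, 9])
loss = criterion(logits, labels)
print(loss.item(), math.log(10))          # both about 2.3026

# Macro F1 averages per-class F1 scores, so every class counts equally.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]
acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
macro_f1 = f1_score(y_true, y_pred, average="macro")  # argument order: (y_true, y_pred)
print(f"accuracy={acc:.4f}, macro F1={macro_f1:.4f}")  # accuracy=0.8333, macro F1=0.8222
```

Here class 0 has F1 = 1, class 1 has F1 = 2/3 (one missed sample), and class 2 has F1 = 0.8 (one false positive), so the macro average is 37/45 ≈ 0.8222 while accuracy is 5/6.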

In [None]:
# Model, optimizer and loss setup
model = MiniViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
print(model)


In [None]:
reset_all_parameters(model)
for epoch in range(1):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch}, Loss: {loss.item():.4f}')  # loss of the last mini-batch

In [None]:
# Evaluate on the full test set
model.eval()
correct = 0
total = 0
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # Collect every batch so F1 covers the whole test set, not only the last batch
        all_preds.append(predicted.cpu())
        all_labels.append(labels.cpu())
y_pred = torch.cat(all_preds).numpy()
y_true = torch.cat(all_labels).numpy()
# scikit-learn's signature is f1_score(y_true, y_pred, ...)
print('F1 Score:', f1_score(y_true, y_pred, average='macro'))
print('Accuracy:', correct / total)

In [None]:
# This cell contains hidden test cases that will be evaluated after submission

In [None]:
# This cell contains hidden test cases that will be evaluated after submission