In [25]:
import torch
import torchvision

In [26]:
!nvidia-smi

Thu Jan  1 01:23:45 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 581.80                 Driver Version: 581.80         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4050 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   52C    P5              6W /   60W |     993MiB /   6141MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# gpu configuration failed locally, so using cpu

cpu


In [28]:
# python uses pillow to operate on images
# convert images to tensors before itself
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

In [29]:
# data already available in torchvision as splits
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

In [30]:
num_classes = 10
batch_size = 64
num_channels = 1 #black and white image so 1
img_size = 28 #28x28 each mnist image
patch_size = 7 #paper it was 16, but as here already image size is small, so 7
num_patches = (img_size // patch_size) ** 2
embedding_dim = 64
attention_heads = 4
transformer_blocks = 4
learning_rate = 0.001
epochs = 5
mlp_hidden_nodes = 128 #2x

In [31]:
print(num_patches)

16


In [32]:
import torch.utils.data as dataloader
train_loader = dataloader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
val_loader = dataloader.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [33]:
import torch.nn as nn

In [34]:
class PatchEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # 1,64,4,4
        self.patch_embed = nn.Conv2d(num_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # patch -> flatten
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        return x

In [35]:
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        # batch first is true because we want the input to be in the shape of (batch_size, sequence_length, embedding_dim) else accuracy worsens
        self.multihead_attention = nn.MultiheadAttention(embedding_dim, attention_heads, batch_first=True)
        # embed -> hiddenlayer -> original form (mlp)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_hidden_nodes),
            nn.GELU(),
            nn.Linear(mlp_hidden_nodes, embedding_dim),
        )

        # residual connections
    def forward(self, x):
        residual1 = x
        x = self.layer_norm1(x)
        # thrice to make x key,value,query
        x = self.multihead_attention(x, x, x)[0]
        x = x + residual1
        residual2 = x
        x = self.layer_norm2(x)
        x = self.mlp(x)
        x = x + residual2
        return x

In [36]:
# MLP HEAD (after transformer encoder in the architecture)
class MLPHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.layer_norm1(x)
        x = self.mlp_head(x)
        return x

In [37]:
class ViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedding = PatchEmbedding()
        # cls token -> pos embedding ->cls token passed to mlp head
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
        self.transformer_blocks = nn.Sequential(*[TransformerEncoder() for _ in range(transformer_blocks)]) #will be 4
        self.mlp_head = MLPHead()

    def forward(self, x):
        x = self.patch_embedding(x)
        #classtoken for every patch, so we expand
        class_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([class_token, x], dim=1)
        x = x + self.pos_embedding
        x = self.transformer_blocks(x)
        # only class token is passed to mlp head
        x = self.mlp_head(x[:, 0])
        return x


In [38]:
model = ViT()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [39]:
# training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct_epoch = 0
    total_epoch = 0
    print(f'Epoch {epoch+1}/{epochs}')

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct = (preds == labels).sum().item()
        accuracy = 100.0 * correct / labels.size(0)
        correct_epoch += correct
        total_epoch += labels.size(0)

        if (batch_idx) % 100 == 0:
            print(f'Batch {batch_idx}/{len(train_loader)} - Loss: {loss.item():.4f} - Accuracy: {accuracy:.2f}%')

    epoch_accuracy = 100.0 * correct_epoch / total_epoch
    print(f'epoch {epoch+1} summary: loss {total_loss:.4f} accuracy {epoch_accuracy:.2f}%')

Epoch 1/5
Batch 0/938 - Loss: 2.5629 - Accuracy: 3.12%
Batch 100/938 - Loss: 0.5765 - Accuracy: 82.81%
Batch 200/938 - Loss: 0.3744 - Accuracy: 87.50%
Batch 300/938 - Loss: 0.1495 - Accuracy: 95.31%
Batch 400/938 - Loss: 0.2234 - Accuracy: 89.06%
Batch 500/938 - Loss: 0.2351 - Accuracy: 93.75%
Batch 600/938 - Loss: 0.4059 - Accuracy: 90.62%
Batch 700/938 - Loss: 0.1922 - Accuracy: 92.19%
Batch 800/938 - Loss: 0.2358 - Accuracy: 93.75%
Batch 900/938 - Loss: 0.1541 - Accuracy: 96.88%
epoch 1 summary: loss 367.8988 accuracy 87.52%
Epoch 2/5
Batch 0/938 - Loss: 0.1123 - Accuracy: 98.44%
Batch 100/938 - Loss: 0.1460 - Accuracy: 95.31%
Batch 200/938 - Loss: 0.1005 - Accuracy: 95.31%
Batch 300/938 - Loss: 0.0397 - Accuracy: 98.44%
Batch 400/938 - Loss: 0.1431 - Accuracy: 95.31%
Batch 500/938 - Loss: 0.0222 - Accuracy: 100.00%
Batch 600/938 - Loss: 0.1010 - Accuracy: 95.31%
Batch 700/938 - Loss: 0.2170 - Accuracy: 95.31%
Batch 800/938 - Loss: 0.0198 - Accuracy: 100.00%
Batch 900/938 - Loss: 0.

In [40]:
# validation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_accuracy = 100.0 * correct / total
print(f'Test accuracy: {test_accuracy:.2f}%')

Test accuracy: 97.25%


In [None]:
# dropout not used, also mlp expansion is 2x and not 4x in this implementaion