In [None]:
import numpy as np
from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.cifar import CIFAR10
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
from random import random
# Importing all the necessary libraries

In [None]:
def patchify(images, n_patches):
    """Split a batch of square images into flattened, non-overlapping patches.

    Every image is cut into an ``n_patches x n_patches`` grid. Patches are
    ordered row-major over that grid (patch index ``i * n_patches + j``) and
    each patch is flattened in (channel, row, column) order.

    Args:
        images: Tensor of shape ``(n, c, h, w)`` with ``h == w``.
        n_patches: Number of patches along each spatial side; must divide ``h``.

    Returns:
        Tensor of shape ``(n, n_patches ** 2, c * (h // n_patches) ** 2)`` with
        the same device and dtype as ``images``. It may share storage with
        ``images`` when the reshape does not need a copy.

    Raises:
        AssertionError: If the images are not square or their side is not a
            multiple of ``n_patches``.
    """
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"
    # Without this check a non-divisible side either fails with an opaque shape
    # error or silently broadcasts a too-small patch into a wider output row.
    assert h % n_patches == 0, "Image side must be divisible by the number of patches"

    patch_size = h // n_patches

    # (n, c, h, w) -> (n, c, i, ph, j, pw): split each spatial axis into
    # (grid position, offset inside the patch).
    grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
    # -> (n, i, j, c, ph, pw): bring the grid position forward so each patch is
    # a contiguous (c, ph, pw) block, matching a per-patch `flatten()`.
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    return grid.reshape(n, n_patches ** 2, c * patch_size * patch_size)
# patchify splits each image into an n_patches x n_patches grid and flattens every 3D (channel, row, column) patch into a 1D vector.

In [None]:
class MyMSA(nn.Module):
    """Multi-head self-attention with an independent Q/K/V projection per head.

    The token dimension ``d`` is split into ``n_heads`` contiguous slices of
    width ``d // n_heads``. Each head projects only its own slice, applies
    scaled dot-product attention over the sequence, and the head outputs are
    concatenated back along the feature axis.

    Args:
        d: Token dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.

    Raises:
        AssertionError: If ``d`` is not divisible by ``n_heads``.
    """

    def __init__(self, d, n_heads=8):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = d // n_heads
        # One Linear(d_head, d_head) per head for each of Q, K and V. They are
        # created in q, k, v order so parameter names and init order stay the same.
        self.q_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        )
        self.k_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        )
        self.v_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        )
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply self-attention to a batch of token sequences.

        Args:
            sequences: Tensor of shape ``(N, seq_length, d)``.

        Returns:
            Tensor of shape ``(N, seq_length, d)``: the per-head attention
            outputs concatenated along the last axis.
        """
        scale = self.d_head ** 0.5
        head_outputs = []
        for head in range(self.n_heads):
            start = head * self.d_head
            # Feature slice owned by this head, for the whole batch at once.
            seq = sequences[:, :, start: start + self.d_head]

            q = self.q_mappings[head](seq)
            k = self.k_mappings[head](seq)
            v = self.v_mappings[head](seq)

            # (N, seq, d_head) @ (N, d_head, seq) -> (N, seq, seq); the batched
            # matmul replaces the per-sample Python loop and also works for N == 0.
            attention = self.softmax(q @ k.transpose(-2, -1) / scale)
            head_outputs.append(attention @ v)

        # Concatenate heads back to (N, seq_length, d).
        return torch.cat(head_outputs, dim=-1)

In [None]:
class MyViTBlock(nn.Module):
    """Transformer encoder block: self-attention then MLP, each with a residual.

    Layer normalisation is applied before each sub-layer, and the sub-layer
    output is added back to its un-normalised input.

    Args:
        hidden_d: Token dimension.
        n_heads: Number of attention heads passed to ``MyMSA``.
        mlp_ratio: Width multiplier of the hidden layer in the MLP.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_d, self.n_heads = hidden_d, n_heads

        # Attention sub-layer.
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)

        # Feed-forward sub-layer: expand by mlp_ratio, GELU, project back.
        expanded_d = mlp_ratio * hidden_d
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, expanded_d),
            nn.GELU(),
            nn.Linear(expanded_d, hidden_d),
        )

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        after_attention = x + self.mhsa(self.norm1(x))
        return after_attention + self.mlp(self.norm2(after_attention))


In [None]:
def get_positional_embeddings(sequence_length, d):
    """Build fixed sinusoidal positional embeddings.

    Entry ``[i, j]`` is ``sin(i / 10000 ** (j / d))`` for even ``j`` and
    ``cos(i / 10000 ** ((j - 1) / d))`` for odd ``j``.

    Args:
        sequence_length: Number of positions (rows).
        d: Embedding dimension (columns).

    Returns:
        Tensor of shape ``(sequence_length, d)`` in torch's default dtype.
    """
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)
    dims = torch.arange(d)
    is_even = dims % 2 == 0
    # Odd columns reuse the frequency of the even column just before them.
    paired_dims = (dims - dims % 2).to(torch.float64)

    # Computed in float64 and cast once at the end, so the values match an
    # element-wise double-precision evaluation of the same formula.
    angles = positions / torch.pow(10000.0, paired_dims / d)
    result = torch.where(is_even, torch.sin(angles), torch.cos(angles))
    return result.to(torch.get_default_dtype())
# get_positional_embeddings builds the positional embeddings from trigonometric functions: sine for even feature indices, cosine for odd ones.

In [None]:
# Helper function to plot attention maps
def plot_attention_maps(attention_maps, title):
    """Draw a row of attention maps side by side in a single figure.

    Args:
        attention_maps: Sized iterable of tensors; each one is squeezed, moved
            to CPU and converted to a numpy array before being drawn.
        title: Text used as the figure-level title.
    """
    plt.figure(figsize=(12, 4))
    for position, single_map in enumerate(attention_maps, start=1):
        plt.subplot(1, len(attention_maps), position)
        as_array = single_map.squeeze().cpu().numpy()
        plt.imshow(as_array, cmap='viridis', interpolation='nearest')
        plt.title(f'Attention Map {position}')
        plt.axis('off')
    plt.suptitle(title)
    plt.show()


In [None]:
class MyViT(nn.Module):
    """Small Vision Transformer classifier.

    Images are split into patches, each patch is linearly embedded, a learnable
    class token is prepended, fixed positional embeddings are added, and the
    sequence is passed through ``n_blocks`` encoder blocks. The class token of
    the final sequence is mapped to ``out_d`` class probabilities.

    Args:
        chw: Input shape as ``(channels, height, width)``.
        n_patches: Number of patches along each spatial side.
        n_blocks: Number of ``MyViTBlock`` encoder blocks.
        hidden_d: Token dimension.
        n_heads: Number of attention heads per block.
        out_d: Number of output classes.

    Raises:
        AssertionError: If height or width is not divisible by ``n_patches``.
    """

    def __init__(self, chw, n_patches=8, n_blocks=4, hidden_d=8, n_heads=8, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # ( C , H , W )
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patches sizes
        assert (chw[1] % n_patches == 0), "Input shape not entirely divisible by number of patches"
        assert (chw[2] % n_patches == 0), "Input shape not entirely divisible by number of patches"
        # Floor division keeps the patch size an integer pixel count; the
        # assertions above guarantee the division is exact.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper: flattened patch -> hidden_d
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding: a non-persistent buffer, so it follows the
        #    module across devices but is not stored in the state dict.
        self.register_buffer(
            "positional_embeddings",
            get_positional_embeddings(n_patches**2 + 1, hidden_d),
            persistent=False,
        )

        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList(
            [MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)]
        )

        # 5) Classification MLP
        # NOTE: the head ends in Softmax, so forward() returns probabilities,
        # not logits. A loss that expects logits (e.g. nn.CrossEntropyLoss)
        # would apply a second softmax on top of this output.
        self.mlp = nn.Sequential(nn.Linear(self.hidden_d, out_d), nn.Softmax(dim=-1))
        # Not populated anywhere in this class; kept for compatibility.
        self.attention_maps = []

    def forward(self, images):
        """Classify a batch of images.

        Args:
            images: Tensor of shape ``(n, c, h, w)``.

        Returns:
            Tensor of shape ``(n, out_d)`` holding a probability distribution
            over classes for each image.
        """
        n = images.shape[0]

        # Dividing images into patches
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # Map the vector corresponding to each patch to the hidden size dimension
        tokens = self.linear_mapper(patches)

        # Prepend the classification token to every sequence in the batch
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Positional embeddings are (seq, hidden_d) and broadcast over the
        # batch axis, so no explicit per-sample copy is needed.
        out = tokens + self.positional_embeddings

        # Transformer Blocks
        for block in self.blocks:
            out = block(out)

        # Classify from the classification token only
        return self.mlp(out[:, 0])


In [None]:
def get_attention_maps(model, images, n_patches=None):
    """Compute class-token attention maps for every encoder block.

    For each block, the attention weights are recomputed from the block's own
    ``norm1`` and per-head ``q_mappings`` / ``k_mappings``, averaged over heads,
    and the row of the classification token (excluding its attention to itself)
    is reshaped onto the patch grid. The token sequence is then advanced through
    the full block so later blocks see the same input as in a normal forward pass.

    Side effect: the model is switched to evaluation mode and left there.

    Args:
        model: A ``MyViT`` instance.
        images: A single image ``(c, h, w)`` or a batch ``(n, c, h, w)``, as a
            tensor or anything ``torch.as_tensor`` accepts.
        n_patches: Patches per side; defaults to ``model.n_patches``.

    Returns:
        A list with one numpy array per input image, each of shape
        ``(n_blocks, n_patches, n_patches)``.
    """
    # Set the model to evaluation mode
    model.eval()

    if n_patches is None:
        n_patches = model.n_patches

    device = model.positional_embeddings.device

    images = torch.as_tensor(images)
    if images.dim() == 3:
        # Single image: add the batch axis. Batched input is used as-is.
        images = images.unsqueeze(0)
    images = images.to(device=device, dtype=model.linear_mapper.weight.dtype)
    n = images.shape[0]

    per_block = []
    with torch.no_grad():
        # Same tokenisation as MyViT.forward
        patches = patchify(images, n_patches).to(device)
        tokens = model.linear_mapper(patches)
        tokens = torch.cat((model.class_token.expand(n, 1, -1), tokens), dim=1)
        out = tokens + model.positional_embeddings

        for block in model.blocks:
            msa = block.mhsa
            normed = block.norm1(out)  # the input the attention layer actually sees

            head_weights = []
            for head in range(msa.n_heads):
                start = head * msa.d_head
                seq = normed[:, :, start: start + msa.d_head]
                q = msa.q_mappings[head](seq)
                k = msa.k_mappings[head](seq)
                scores = q @ k.transpose(-2, -1) / (msa.d_head ** 0.5)
                head_weights.append(torch.softmax(scores, dim=-1))  # (n, seq, seq)

            # Average over heads, then keep how the class token (row 0) attends
            # to each patch token (columns 1:).
            attention = torch.stack(head_weights, dim=1).mean(dim=1)
            class_to_patches = attention[:, 0, 1:]
            per_block.append(class_to_patches.reshape(n, n_patches, n_patches))

            # Advance through the complete block (attention + residuals + MLP).
            out = block(out)

    if per_block:
        maps = torch.stack(per_block, dim=1)  # (n, n_blocks, P, P)
    else:
        maps = torch.empty(n, 0, n_patches, n_patches)
    return list(maps.cpu().numpy())

In [None]:
import numpy as np
def main():
    """Train the ViT on CIFAR-10, report test metrics and plot attention maps.

    Downloads CIFAR-10 into ``./../datasets`` if needed, trains ``MyViT`` for
    ``N_EPOCHS`` epochs with Adam, prints the test loss and accuracy, and then
    plots attention maps for the first four test images.
    """
    # Loading data
    transform = ToTensor()

    train_set = CIFAR10(
        root="./../datasets", train=True, download=True, transform=transform
    )
    test_set = CIFAR10(
        root="./../datasets", train=False, download=True, transform=transform
    )

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(
        "Using device: ",
        device,
        f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
    )
    model = MyViT(
        (3, 32, 32), n_patches=8, n_blocks=4, hidden_d=8, n_heads=8, out_d=10
    ).to(device)
    N_EPOCHS = 2
    LR = 0.01

    optimizer = Adam(model.parameters(), lr=LR)

    # MyViT ends in a Softmax, so it outputs probabilities. CrossEntropyLoss
    # expects unnormalised logits and would apply a second softmax, flattening
    # the gradients. Take the log of the probabilities and use NLLLoss instead.
    criterion = nn.NLLLoss()
    eps = 1e-12  # keeps log() finite if a probability underflows to 0

    def loss_from_probabilities(probabilities, targets):
        return criterion(torch.log(probabilities.clamp_min(eps)), targets)

    # Training loop
    model.train()
    for epoch in range(N_EPOCHS):
        train_loss_sum = 0.0
        num_batches = len(train_loader)
        print()
        for batch_idx, batch in enumerate(train_loader):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = loss_from_probabilities(y_hat, y)

            train_loss_sum += loss.detach().cpu().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            completion_percentage = (batch_idx + 1) / num_batches * 100
            # Mean over the batches seen so far, so the number is meaningful
            # at any point of the epoch (not only at 100%).
            running_loss = train_loss_sum / (batch_idx + 1)
            print(f"\rEpoch {epoch + 1}/{N_EPOCHS} [{completion_percentage:.2f}%] - Loss: {running_loss:.2f}", end='')

    # End the in-place progress line before printing the test report.
    print()

    # Test loop
    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        test_loss_sum = 0.0
        for batch_idx, batch in enumerate(test_loader):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = loss_from_probabilities(y_hat, y)
            # Weight each batch by its size so a smaller final batch does not
            # skew the average.
            test_loss_sum += loss.detach().cpu().item() * len(x)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss_sum / total:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

    # Get attention maps for sample images
    with torch.no_grad():
        sample_images = torch.stack([test_set[i][0] for i in range(4)])
        attention_maps = get_attention_maps(model, sample_images)

        # One figure per sample; inside it, one 2D map per encoder block.
        for i, attention_map in enumerate(attention_maps):
            plt.figure(figsize=(15, 5))
            for j in range(len(attention_map)):
                plt.subplot(1, len(attention_map), j + 1)
                plt.imshow(attention_map[j], cmap='viridis', interpolation='nearest')
                plt.title(f'Block {j + 1}')
                plt.colorbar()
            plt.suptitle(f'Attention Maps - Sample {i + 1}')
            plt.show()


In [None]:
# Entry point: run the full train / evaluate / visualise pipeline.
if __name__ == "__main__":
    main()