In [None]:
import math
from dataclasses import dataclass
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

# Colab-only helper used to expose Google Drive inside the runtime.
from google.colab import drive

# Mount Drive at /content/drive; BASE_DIR and every dataset/result path configured
# further down in this notebook live under this mount point.
drive.mount("/content/drive")
import pandas as pd
import os

Mounted at /content/drive


In [None]:
# Root folder on the mounted Google Drive that holds the datasets and the results.
BASE_DIR = "/content/drive/MyDrive/Colab Notebooks/Data_Mining"

# Dataset roots handed to the torchvision dataset constructors.
CIFAR10_PATH = BASE_DIR + "/cifar10"
MNIST_PATH = BASE_DIR + "/mnist"

# Timestamp captured once when this cell runs, so every row logged afterwards
# (until the cell is re-executed) lands in the same CSV file.
TODAY = pd.Timestamp.today().strftime("%Y-%m-%d_%H-%M-%S")

RESULTS_FILE = "{}/Results/results_2_noRPE_{}.csv".format(BASE_DIR, TODAY)

In [None]:
# ================================================================
# 1. Patch embedding + CLS + simple learned absolute positions
# ================================================================


class PatchEmbedding(nn.Module):
    """
    Turn an image into a sequence of patch embeddings.

    A single Conv2d whose kernel size and stride both equal ``patch_size`` projects
    every non-overlapping patch to an ``embed_dim``-dimensional vector.

    Args:
        img_size: side length of the (square) input image
        patch_size: side length of each (square) patch; must divide ``img_size``
        in_channels: number of input channels (e.g., 3 for RGB)
        embed_dim: dimension of each output patch embedding
    """

    def __init__(
        self, img_size: int, patch_size: int, in_channels: int, embed_dim: int
    ):
        super().__init__()
        patches_per_side, leftover = divmod(img_size, patch_size)
        assert leftover == 0, "img_size must be divisible by patch_size"

        self.img_size = img_size
        self.patch_size = patch_size
        # Square grid of patches.
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, C, H, W), a batch of images.
        Returns:
            Tensor of shape (B, N, D) with N patches of dimension D.
        """
        # (B, C, H, W) -> (B, D, H/p, W/p) -> (B, D, N) -> (B, N, D)
        return self.proj(x).flatten(2).transpose(1, 2)


class ViTInputLayer(nn.Module):
    """
    PatchEmbedding + optional [CLS] token + learned absolute positional embedding.

    This module converts an input image into a sequence of patch embeddings,
    optionally prepends a trainable [CLS] token, and adds a learned positional
    embedding to all tokens (including the [CLS] token when present).

    Without the positional embedding the downstream self-attention blocks have no
    way to tell patches apart by location, i.e. the model is invariant to any
    permutation of the patches.

    Args:
        img_size: side length of the (square) input image
        patch_size: side length of each (square) patch
        in_channels: number of input channels
        embed_dim: embedding dimension D of every token
        cls_token: if True, prepend a trainable [CLS] token (initialised to zeros)
        use_pos_embed: if True (default), add a learned positional embedding of
            shape (1, num_tokens, D). Pass False to get a position-free input
            layer. Note that enabling it adds a ``pos_embed`` entry to the
            ``state_dict``.
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_channels,
        embed_dim,
        cls_token=True,
        use_pos_embed=True,
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        self.num_patches = self.patch_embed.num_patches

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if cls_token else None
        )

        if use_pos_embed:
            # One position per patch, plus one for the [CLS] token when it is used.
            num_tokens = self.num_patches + (1 if cls_token else 0)
            self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
            # Small truncated-normal init so positions start distinct but near zero.
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.pos_embed = None

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, C, H, W)
        Returns:
            Tensor of shape (B, N, D), or (B, N + 1, D) when a [CLS] token is used.
        """
        x = self.patch_embed(x)

        if self.cls_token is not None:
            # Share the single learned [CLS] vector across the batch.
            cls = self.cls_token.expand(x.size(0), -1, -1)
            x = torch.cat([cls, x], dim=1)

        if self.pos_embed is not None:
            # Broadcast over the batch dimension.
            x = x + self.pos_embed
        return x

In [None]:
# ================================================================
# 2. Multi-Head Attention variants
#    - Full softmax attention (baseline)
#    - Performer FAVOR+ (softmax approx with positive random features)
#    - Performer-ReLU (ReLU feature map)
# ================================================================


class MultiHeadSelfAttentionFull(nn.Module):
    """
    Exact multi-head softmax self-attention (the baseline used in the original ViT).

    No kernel approximation is involved: the full (N x N) attention matrix is
    materialised for every head.

    Args:
        embed_dim: token dimension D; must be divisible by ``num_heads``
        num_heads: number of attention heads
        attn_drop: dropout probability applied to the attention weights
        proj_drop: dropout probability applied after the output projection
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # A single fused projection produces queries, keys and values at once.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        x: (B, N, D)
        Returns: (B, N, D)
        """
        batch, tokens, width = x.shape

        # (B, N, 3D) -> (B, N, 3, H, d) -> (3, B, H, N, d)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # Scaled dot-product scores, (B, H, N, N).
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = (queries @ keys.transpose(-2, -1)) * scale
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back: (B, H, N, d) -> (B, N, D).
        merged = (weights @ values).transpose(1, 2).reshape(batch, tokens, width)

        return self.proj_drop(self.out_proj(merged))


class PerformerAttentionFavorPlus(nn.Module):
    """
    Performer FAVOR+ : approximates exp(q k^T) using positive random features.
    Complexity O(N * m) instead of O(N^2).

    The random projection ``W`` of shape (num_heads, nb_features, head_dim) is drawn
    once from a standard normal at construction time and registered as a buffer, so
    it is fixed during training and saved with the module.

    Args:
        embed_dim: token dimension D; must be divisible by ``num_heads``
        num_heads: number of attention heads
        nb_features: number m of random features per head
        attn_drop: accepted for signature parity with the softmax attention; the
            Dropout module is created but not applied in ``forward`` (there is no
            explicit attention matrix to drop)
        proj_drop: dropout probability applied after the output projection
        eps: constant added to the normaliser to avoid division by zero
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        nb_features: int = 64,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.nb_features = nb_features
        self.eps = eps

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.register_buffer("W", torch.randn(num_heads, nb_features, self.head_dim))

    def _favor_feature_map(self, x, is_query: bool = True):
        """
        x: (B, H, N, d_k)
        -> phi(x): (B, H, N, m)
        phi(x) = exp(Wx - ||x||^2 / 2 - c) / sqrt(m)

        ``c`` is a numerical stabiliser (the maximum exponent) that keeps the
        largest feature at exp(0) so the features cannot all underflow to zero:
          * queries (``is_query=True``): one ``c`` per token (max over features);
            it rescales a whole query row and cancels between the numerator and
            the normaliser of the attention output.
          * keys (``is_query=False``): one ``c`` per (batch, head) (max over tokens
            and features); it is a common factor of all keys and cancels likewise.
        """
        x_proj = torch.einsum("b h n d, h m d -> b h n m", x, self.W)

        sq_norm = (x**2).sum(dim=-1, keepdim=True) / 2.0
        exponent = x_proj - sq_norm

        if is_query:
            stabilizer = exponent.amax(dim=-1, keepdim=True)
        else:
            stabilizer = exponent.amax(dim=(-2, -1), keepdim=True)

        # The stabiliser is a pure rescaling constant, so no gradient flows through it.
        x_feat = torch.exp(exponent - stabilizer.detach()) / math.sqrt(
            self.nb_features
        )
        return x_feat

    def forward(self, x):
        """
        x: (B, N, D)
        Returns: (B, N, D)
        """
        B, N, D = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # (B, N, D) -> (B, H, N, d_k)
        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        q_feat = self._favor_feature_map(q, is_query=True)
        k_feat = self._favor_feature_map(k, is_query=False)

        # Linear attention: aggregate keys/values first, (B, H, m, d_k).
        kv = torch.einsum("b h n m, b h n d -> b h m d", k_feat, v)

        # Normaliser phi(q_i) . sum_j phi(k_j), one scalar per query token.
        k_sum = k_feat.sum(dim=2)

        denom = torch.einsum("b h n m, b h m -> b h n", q_feat, k_sum)
        denom = denom.unsqueeze(-1) + self.eps

        out = torch.einsum("b h n m, b h m d -> b h n d", q_feat, kv)
        out = out / denom
        # Merge heads: (B, H, N, d_k) -> (B, N, D)
        out = out.permute(0, 2, 1, 3).reshape(B, N, D)

        out = self.out_proj(out)
        out = self.proj_drop(out)
        return out


class PerformerAttentionReLU(nn.Module):
    """
    Performer-ReLU : linear attention with the deterministic kernel phi(x) = ReLU(x).

    Uses the same "aggregate keys/values first" scheme as FAVOR+, but the feature
    map is applied directly to the per-head queries and keys (no random features).

    Args:
        embed_dim: token dimension D; must be divisible by ``num_heads``
        num_heads: number of attention heads
        attn_drop: stored as a Dropout module; it is not applied in ``forward``
        proj_drop: dropout probability applied after the output projection
        eps: constant added to the normaliser to avoid division by zero
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.eps = eps

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def _relu_feature_map(self, x):
        """
        x: (B, H, N, d_k)
        Returns ReLU(x); the non-negative features keep the normaliser non-negative.
        """
        return F.relu(x)

    def forward(self, x):
        """
        x: (B, N, D)
        Returns: (B, N, D)
        """
        batch, tokens, width = x.shape

        def to_heads(t):
            # (B, N, D) -> (B, H, N, d_k)
            return t.reshape(batch, tokens, self.num_heads, self.head_dim).permute(
                0, 2, 1, 3
            )

        q_feat = self._relu_feature_map(to_heads(self.q_proj(x)))
        k_feat = self._relu_feature_map(to_heads(self.k_proj(x)))
        values = to_heads(self.v_proj(x))

        # Key/value summary shared by every query, (B, H, d_k, d_k).
        context = torch.einsum("b h n d, b h n e -> b h d e", k_feat, values)

        # Per-query normaliser phi(q_i) . sum_j phi(k_j), kept away from zero by eps.
        normaliser = torch.einsum(
            "b h n d, b h d -> b h n", q_feat, k_feat.sum(dim=2)
        )
        normaliser = normaliser.unsqueeze(-1) + self.eps

        attended = torch.einsum("b h n d, b h d e -> b h n e", q_feat, context)
        attended = attended / normaliser

        # Merge heads: (B, H, N, d_k) -> (B, N, D)
        merged = attended.permute(0, 2, 1, 3).reshape(batch, tokens, width)

        return self.proj_drop(self.out_proj(merged))


In [None]:
# ================================================================
# 3. Transformer block + ViT backbone
# ================================================================


class MLPBlock(nn.Module):
    """
    Position-wise feed-forward network of a Transformer block.

    Each token is expanded to ``int(embed_dim * mlp_ratio)`` features, passed
    through GELU, and projected back to ``embed_dim``. The same Dropout module is
    applied after the activation and after the second linear layer.

    Args:
        embed_dim: input/output token dimension
        mlp_ratio: hidden width as a multiple of ``embed_dim``
        drop: dropout probability
    """

    def __init__(self, embed_dim: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        expanded_dim = int(embed_dim * mlp_ratio)

        self.fc1 = nn.Linear(embed_dim, expanded_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(expanded_dim, embed_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Expand -> non-linearity -> dropout
        hidden = self.drop(self.act(self.fc1(x)))
        # Compress back -> dropout
        return self.drop(self.fc2(hidden))


class TransformerEncoderBlock(nn.Module):
    """
    Pre-LayerNorm Transformer encoder block with a selectable attention mechanism.

    Layout: ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``.

    Args:
        embed_dim: token dimension
        num_heads: number of attention heads
        mlp_ratio: hidden width of the MLP as a multiple of ``embed_dim``
        attn_type: "full" (softmax), "favor+" (Performer FAVOR+) or "relu"
            (Performer-ReLU)
        nb_features: number of random features; only forwarded to "favor+"
        drop: dropout probability passed to the attention output projection and
            to the MLP
        attn_drop: attention dropout probability passed to the attention module

    Raises:
        ValueError: if ``attn_type`` is not one of the supported names.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_type: Literal["full", "favor+", "relu"] = "full",
        nb_features: int = 64,
        drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn = self._build_attention(
            attn_type, embed_dim, num_heads, nb_features, attn_drop, drop
        )
        self.mlp = MLPBlock(embed_dim, mlp_ratio, drop)

    @staticmethod
    def _build_attention(
        attn_type, embed_dim, num_heads, nb_features, attn_drop, proj_drop
    ):
        """Instantiate the attention module that matches ``attn_type``."""
        if attn_type == "full":
            return MultiHeadSelfAttentionFull(
                embed_dim, num_heads, attn_drop, proj_drop
            )
        if attn_type == "favor+":
            return PerformerAttentionFavorPlus(
                embed_dim, num_heads, nb_features, attn_drop, proj_drop
            )
        if attn_type == "relu":
            return PerformerAttentionReLU(embed_dim, num_heads, attn_drop, proj_drop)
        raise ValueError(f"Unknown attn_type: {attn_type}")

    def forward(self, x):
        # Residual attention sub-layer.
        after_attn = x + self.attn(self.norm1(x))
        # Residual feed-forward sub-layer.
        return after_attn + self.mlp(self.norm2(after_attn))


class ViTClassifier(nn.Module):
    """
    Vision Transformer / Performer-ViT classifier for MNIST or CIFAR-10.

    Pipeline: input layer (patches + [CLS]) -> ``depth`` encoder blocks ->
    final LayerNorm -> linear head applied to the [CLS] token.

    Args:
        img_size: side length of the (square) input image
        patch_size: side length of each (square) patch
        in_channels: number of image channels
        num_classes: number of output logits
        embed_dim: token dimension
        depth: number of encoder blocks
        num_heads: attention heads per block
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``
        attn_type: "full", "favor+" or "relu"
        nb_features: number of random features (used by "favor+")
        drop: dropout probability for projections and MLPs
        attn_drop: attention dropout probability
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_channels: int,
        num_classes: int,
        embed_dim: int = 64,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        attn_type: Literal["full", "favor+", "relu"] = "full",
        nb_features: int = 64,
        drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        super().__init__()

        self.input_layer = ViTInputLayer(
            img_size, patch_size, in_channels, embed_dim, cls_token=True
        )

        # Every block shares the same hyper-parameters.
        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            attn_type=attn_type,
            nb_features=nb_features,
            drop=drop,
            attn_drop=attn_drop,
        )
        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(TransformerEncoderBlock(**block_kwargs))
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """
        x: (B, C, H, W)
        Returns: logits of shape (B, num_classes)
        """
        tokens = self.input_layer(x)

        for layer in self.blocks:
            tokens = layer(tokens)

        # Classify from the first token, i.e. the [CLS] token prepended by the input layer.
        cls_repr = self.norm(tokens)[:, 0]
        return self.head(cls_repr)

In [None]:
# ================================================================
# 4. Helper functions to build model from config
# ================================================================


@dataclass
class ViTConfig:
    """
    Configuration container for building a ViT/Performer-ViT model.

    Every field is forwarded unchanged, by keyword, to ``ViTClassifier`` by
    ``build_model``. The defaults describe a small model for 32x32 RGB images
    with 10 classes.
    """

    # Side length of the (square) input image; must be divisible by patch_size.
    img_size: int = 32
    # Side length of each (square) patch.
    patch_size: int = 4
    # Number of image channels (the training helper uses 1 for MNIST, 3 for CIFAR-10).
    in_channels: int = 3
    # Number of output logits.
    num_classes: int = 10
    # Token embedding dimension; must be divisible by num_heads.
    embed_dim: int = 64
    # Number of Transformer encoder blocks.
    depth: int = 4
    # Number of attention heads per block.
    num_heads: int = 4
    # MLP hidden width as a multiple of embed_dim.
    mlp_ratio: float = 4.0
    # Attention variant: "full", "favor+" or "relu" (any other value raises
    # ValueError when the encoder block is built).
    attn_type: str = "full"
    # Number of random features; only forwarded to the "favor+" attention.
    nb_features: int = 64
    # Dropout probability for the attention output projection and the MLP.
    drop: float = 0.0
    # Dropout probability passed to the attention modules as attn_drop.
    attn_drop: float = 0.0


def build_model(cfg: ViTConfig) -> nn.Module:
    """
    Instantiate a ``ViTClassifier`` from a configuration object.

    Args:
        cfg: configuration whose attributes are forwarded one-to-one, by keyword,
            to the ``ViTClassifier`` constructor.
    Returns:
        The freshly constructed model.
    """
    constructor_fields = (
        "img_size",
        "patch_size",
        "in_channels",
        "num_classes",
        "embed_dim",
        "depth",
        "num_heads",
        "mlp_ratio",
        "attn_type",
        "nb_features",
        "drop",
        "attn_drop",
    )
    kwargs = {field: getattr(cfg, field) for field in constructor_fields}
    return ViTClassifier(**kwargs)

In [None]:
# ================================================================
# 5. Training loop for MNIST / CIFAR-10 (simple baseline)
# ================================================================
import csv
import time

def get_mnist_loaders(batch_size=128):
    """
    Build the MNIST train/test DataLoaders.

    Images are resized to 32x32 and converted to tensors. The dataset is read from
    ``MNIST_PATH`` and is never downloaded (``download=False``), so the files must
    already be present there.

    Args:
        batch_size: batch size used by both loaders.
    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    preprocess = T.Compose([T.Resize(32), T.ToTensor()])

    def load_split(is_train):
        return torchvision.datasets.MNIST(
            root=MNIST_PATH, train=is_train, download=False, transform=preprocess
        )

    train_set, test_set = load_split(True), load_split(False)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )


def get_cifar10_loaders(batch_size=128):
    """
    Build the CIFAR-10 train/test DataLoaders.

    Training images get a random horizontal flip before tensor conversion; test
    images are only converted to tensors. The dataset is read from ``CIFAR10_PATH``
    and is never downloaded (``download=False``).

    Args:
        batch_size: batch size used by both loaders.
    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    augment_and_convert = T.Compose([T.RandomHorizontalFlip(), T.ToTensor()])
    convert_only = T.Compose([T.ToTensor()])

    def load_split(is_train, transform):
        return torchvision.datasets.CIFAR10(
            root=CIFAR10_PATH, train=is_train, download=False, transform=transform
        )

    train_set = load_split(True, augment_and_convert)
    test_set = load_split(False, convert_only)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )


def train_one_epoch(model, loader, optimizer, device):
    """
    Run one optimisation pass over ``loader`` and report epoch-level metrics.

    Args:
        model: module mapping a batch of inputs to class logits.
        loader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer already bound to ``model``'s parameters.
        device: device the batches are moved to before the forward pass.
    Returns:
        (mean_loss, accuracy, elapsed_seconds), where the mean loss is the
        per-sample average cross-entropy (each batch weighted by its size).
    Raises:
        ValueError: if ``loader`` yields no samples (the averages are undefined).
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    # perf_counter is monotonic and high-resolution; time.time() follows the wall
    # clock and can jump when the system clock is adjusted.
    start_time = time.perf_counter()

    for x, y in loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        # cross_entropy returns the batch mean; re-weight by batch size so a
        # smaller final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=-1) == y).sum().item()
        total_samples += batch_size

    elapsed_time = time.perf_counter() - start_time

    if total_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch metrics")

    return total_loss / total_samples, total_correct / total_samples, elapsed_time


def evaluate(model, loader, device):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: module mapping a batch of inputs to class logits.
        loader: iterable yielding ``(inputs, targets)`` batches.
        device: device the batches are moved to before the forward pass.
    Returns:
        (mean_loss, accuracy, elapsed_seconds), where the mean loss is the
        per-sample average cross-entropy (each batch weighted by its size).
    Raises:
        ValueError: if ``loader`` yields no samples (the averages are undefined).
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    # perf_counter is monotonic and high-resolution; time.time() follows the wall
    # clock and can jump when the system clock is adjusted.
    start_time = time.perf_counter()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = F.cross_entropy(logits, y)

            batch_size = x.size(0)
            # Re-weight the batch-mean loss by batch size for an exact per-sample mean.
            total_loss += loss.item() * batch_size
            total_correct += (logits.argmax(dim=-1) == y).sum().item()
            total_samples += batch_size

    elapsed_time = time.perf_counter() - start_time

    if total_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute evaluation metrics")

    return total_loss / total_samples, total_correct / total_samples, elapsed_time


# ================================================================
# CSV logging helper (appends without overwriting existing data)
# ================================================================


def log_results(filepath, row_dict):
    """
    Append one row to a CSV file, creating the file (and its folder) on first use.

    Args:
        filepath: destination CSV path. Missing parent directories are created.
        row_dict: mapping of column name -> value. Its keys define the column
            order and are written as the header when the file is new or empty.
    """
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        # Without this the very first write fails when e.g. "Results/" is missing.
        os.makedirs(parent_dir, exist_ok=True)

    # An existing but empty file (e.g. left over from an interrupted run) still
    # needs its header row.
    needs_header = not os.path.isfile(filepath) or os.path.getsize(filepath) == 0

    with open(filepath, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row_dict))
        if needs_header:
            writer.writeheader()
        writer.writerow(row_dict)


# ================================================================
# RUN EXAMPLE
# ================================================================
def run_training_noRPE(
    dataset: Literal["mnist", "cifar10"] = "mnist",
    attn_type: Literal["full", "favor+", "relu"] = "favor+",
    epochs: int = 5,
    lr: float = 1e-3,
    batch_size: int = 128,
):
    """
    Train and evaluate one ViT/Performer-ViT configuration, logging every epoch.

    The model is a fixed small architecture (32x32 input, 4x4 patches, 64-dim
    tokens, 4 blocks, 4 heads); only the number of input channels depends on the
    dataset. After each epoch the metrics are printed and appended to
    ``RESULTS_FILE``.

    Args:
        dataset: "mnist" or "cifar10".
        attn_type: attention variant passed to the model configuration.
        epochs: number of training epochs.
        lr: Adam learning rate.
        batch_size: batch size for both data loaders.
    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # The two datasets differ only in their loaders and channel count.
    if dataset == "mnist":
        in_channels = 1
        train_loader, test_loader = get_mnist_loaders(batch_size)
    elif dataset == "cifar10":
        in_channels = 3
        train_loader, test_loader = get_cifar10_loaders(batch_size)
    else:
        raise ValueError("dataset must be 'mnist' or 'cifar10'")

    cfg = ViTConfig(
        img_size=32,
        patch_size=4,
        in_channels=in_channels,
        num_classes=10,
        embed_dim=64,
        depth=4,
        num_heads=4,
        mlp_ratio=4.0,
        attn_type=attn_type,
        nb_features=64,
    )

    model = build_model(cfg).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        train_loss, train_acc, train_time = train_one_epoch(
            model, train_loader, optimizer, device
        )
        test_loss, test_acc, eval_time = evaluate(model, test_loader, device)

        train_summary = f"train loss={train_loss:.4f}, train acc={train_acc:.4f}, train time={train_time:.2f}s"
        test_summary = f"test loss={test_loss:.4f}, test acc={test_acc:.4f}, eval time={eval_time:.2f}s"
        print(
            f"[{dataset}][{attn_type}] Epoch {epoch:02d}: {train_summary}, {test_summary}"
        )

        # Key order defines the CSV column order.
        epoch_record = {
            "dataset": dataset,
            "attn_type": attn_type,
            "epoch": epoch,
            "epochs_total": epochs,
            "lr": lr,
            "batch_size": batch_size,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "train_time_sec": train_time,
            "test_loss": test_loss,
            "test_acc": test_acc,
            "eval_time_sec": eval_time,
        }
        log_results(RESULTS_FILE, epoch_record)

In [None]:
!nvidia-smi

Mon Dec 15 17:33:07 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   35C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Epoch budget per dataset for the sweep below.
EPOCHS_MNIST = 5
EPOCHS_CIFAR = 20


if __name__ == "__main__":
    # Every dataset is paired with every attention variant; all MNIST runs come
    # first, then all CIFAR-10 runs.
    sweep = [
        (dataset, attn)
        for dataset in ["mnist", "cifar10"]
        for attn in ["full", "favor+", "relu"]
    ]

    for dataset, attn in sweep:
        if dataset == "mnist":
            epochs = EPOCHS_MNIST
        else:
            epochs = EPOCHS_CIFAR

        banner = f" RUN (no RPE): dataset={dataset} | attn={attn} | epochs={epochs}"
        print("\n==============================================")
        print(banner)
        print("==============================================\n")

        run_training_noRPE(dataset=dataset, attn_type=attn, epochs=epochs)


 RUN (no RPE): dataset=mnist | attn=full | epochs=5

Using device: cuda
[mnist][full] Epoch 01: train loss=1.5229, train acc=0.4350, train time=17.84s, test loss=1.0315, test acc=0.6315, eval time=2.05s
[mnist][full] Epoch 02: train loss=0.8520, train acc=0.7037, train time=16.65s, test loss=0.6763, test acc=0.7653, eval time=1.81s
[mnist][full] Epoch 03: train loss=0.5927, train acc=0.7986, train time=15.85s, test loss=0.5264, test acc=0.8216, eval time=1.78s
[mnist][full] Epoch 04: train loss=0.4753, train acc=0.8399, train time=16.05s, test loss=0.4444, test acc=0.8485, eval time=2.60s
[mnist][full] Epoch 05: train loss=0.4072, train acc=0.8618, train time=15.73s, test loss=0.3715, test acc=0.8756, eval time=1.78s

 RUN (no RPE): dataset=mnist | attn=favor+ | epochs=5

Using device: cuda
[mnist][favor+] Epoch 01: train loss=1.3274, train acc=0.5193, train time=19.55s, test loss=0.9296, test acc=0.6704, eval time=2.61s
[mnist][favor+] Epoch 02: train loss=0.7456, train acc=0.7440, t