In [1]:
import math
import csv
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T

from dataclasses import dataclass
from typing import Literal

from google.colab import drive

# Mount Google Drive so the BASE_DIR paths under /content/drive are reachable (Colab only).
drive.mount("/content/drive")

Mounted at /content/drive


In [2]:
# Root folder on the mounted Google Drive holding datasets and results.
BASE_DIR = "/content/drive/MyDrive/Colab Notebooks/Data_Mining"

# torchvision download / cache roots for each dataset.
CIFAR10_PATH = f"{BASE_DIR}/cifar10"
MNIST_PATH = f"{BASE_DIR}/mnist"

# Session timestamp: gives every notebook session its own results file.
TODAY = pd.Timestamp.today().strftime("%Y-%m-%d_%H-%M-%S")

# CSV that run_training_RPE appends one row of metrics to per epoch.
RESULTS_FILE = f"{BASE_DIR}/Results/results_2_RPE_{TODAY}.csv"

In [3]:
# ================================================================
# 1. Patch embedding + CLS token (no absolute positional embedding)
# ================================================================

class PatchEmbedding(nn.Module):
    """
    Split an image into non-overlapping square patches and embed each one.

    A single Conv2d whose kernel size and stride both equal `patch_size`
    projects every patch to `embed_dim` channels; the spatial grid is then
    flattened into a token sequence.

    Args:
        img_size: side length of the (square) input image
        patch_size: side length of each (square) patch; must divide img_size
        in_channels: number of input channels (e.g. 1 for MNIST, 3 for RGB)
        embed_dim: dimension D of each output token
    """

    def __init__(
        self, img_size: int, patch_size: int, in_channels: int, embed_dim: int
    ):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """
        Map images (B, C, H, W) to patch tokens (B, N, D), where
        N = (H / patch_size) * (W / patch_size), in row-major patch order.
        """
        # (B, D, H/P, W/P) -> (B, D, N) -> (B, N, D)
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)


class ViTInputLayer(nn.Module):
    """
    PatchEmbedding + optional [CLS] token.

    Converts an input image into a sequence of patch embeddings and, when
    `cls_token=True`, prepends a trainable [CLS] token (initialised to zeros).
    No absolute positional embedding is added here; positional information
    can only come from the relative schemes in the attention layers.

    Args:
        img_size: side length of the (square) input image
        patch_size: side length of each (square) patch
        in_channels: number of input channels
        embed_dim: token dimension D
        cls_token: whether to prepend a learnable [CLS] token

    Output shape of forward(): (B, N, D) or (B, N + 1, D) with a [CLS] token.
    """

    def __init__(self, img_size, patch_size, in_channels, embed_dim, cls_token=True):
        super().__init__()

        # Convert image into sequence of patch embeddings
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        self.num_patches = self.patch_embed.num_patches

        # Learnable CLS token (added in front of patch tokens)
        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if cls_token else None
        )

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)  # (B, N, D)

        if self.cls_token is not None:
            cls = self.cls_token.expand(B, -1, -1)  # match batch size
            x = torch.cat([cls, x], dim=1)  # (B, N + 1, D), [CLS] at index 0
        return x

# ================================================================
# 2. Multi-Head Attention variants
#    - Full softmax attention (baseline)
#    - Performer FAVOR+ (softmax approx with positive random features)
#    - Performer-ReLU (ReLU feature map)
# ================================================================

class MultiHeadSelfAttentionFull(nn.Module):
    """
    Standard (quadratic) multi-head softmax self-attention.

    Args:
        embed_dim: token dimension D; must be divisible by num_heads
        num_heads: number of attention heads H
        attn_drop: dropout probability applied to the attention weights
        proj_drop: dropout probability applied after the output projection
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Softmax temperature 1/sqrt(d)
        self.scale = self.head_dim**-0.5

        # One fused projection producing q, k and v
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over tokens x of shape (B, N, D); returns (B, N, D)."""
        batch, tokens, width = x.shape

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, N, d)

        weights = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)  # (B, H, N, N)
        weights = self.attn_drop(weights)

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, width)
        return self.proj_drop(self.out_proj(mixed))

class PerformerAttentionFavorPlus(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        nb_features: int = 64,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
        rpe_type: Literal["none", "rope", "classic", "string"] = "none",
        seq_len: int | None = None,
    ):
        super().__init__()

        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.nb_features = nb_features
        self.eps = eps

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        # Random features
        W = torch.randn(num_heads, nb_features, self.head_dim)
        self.register_buffer("W", W)

        # RPE
        self.rotary = None
        self.rpe_classic = None
        self.rpe_string = None

        if rpe_type == "rope":
            if seq_len is None:
                raise ValueError("seq_len required for RoPE")
            self.rotary = RotaryEmbedding(self.head_dim, max_seq_len=seq_len)

        elif rpe_type == "classic":
            if seq_len is None:
                raise ValueError("seq_len required for classic RPE")
            self.rpe_classic = ClassicRPEPerformer(seq_len)

        elif rpe_type == "string":
            if seq_len is None:
                raise ValueError("seq_len required for STRING RPE")
            self.rpe_string = CirculantSTRING(seq_len, self.head_dim)

    def _favor_feature_map(self, x):
        B, H, N, d_k = x.shape
        x_proj = torch.einsum("b h n d, h m d -> b h n m", x, self.W)
        sq_norm = (x**2).sum(dim=-1, keepdim=True) / 2.0
        return torch.exp(x_proj - sq_norm) / math.sqrt(self.nb_features)

    def forward(self, x):
        B, N, D = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # RoPE
        if self.rotary is not None:
            q, k = self.rotary(q, k)

        q_feat = self._favor_feature_map(q)
        k_feat = self._favor_feature_map(k)

        kv = torch.einsum("b h n m, b h n d -> b h m d", k_feat, v)
        k_sum = k_feat.sum(dim=2)

        denom = torch.einsum("b h n m, b h m -> b h n", q_feat, k_sum)
        denom = denom.unsqueeze(-1) + self.eps

        out = torch.einsum("b h n m, b h m d -> b h n d", q_feat, kv)
        out = out / denom  # (B,H,N,d)

        if self.rpe_classic is not None:
            weight = self.rpe_classic(N, x.device)  # (N,N)
            out = torch.einsum("b h n d, n m -> b h m d", out, weight)

        if self.rpe_string is not None:
            weight = self.rpe_string(N, x.device)  # (N,N)
            out = torch.einsum("b h n d, n m -> b h m d", out, weight)

        out = out.permute(0, 2, 1, 3).reshape(B, N, D)  # (B,N,D)

        out = self.out_proj(out)
        out = self.proj_drop(out)
        return out


class PerformerAttentionReLU(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
        rpe_type: Literal["none", "rope", "classic", "string"] = "none",
        seq_len: int | None = None,
    ):
        super().__init__()

        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.eps = eps

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        # RPE
        self.rotary = None
        self.rpe_classic = None
        self.rpe_string = None

        if rpe_type == "rope":
            if seq_len is None:
                raise ValueError("seq_len required for RoPE")
            self.rotary = RotaryEmbedding(self.head_dim, max_seq_len=seq_len)

        elif rpe_type == "classic":
            if seq_len is None:
                raise ValueError("seq_len required for classic RPE")
            self.rpe_classic = ClassicRPEPerformer(seq_len)

        elif rpe_type == "string":
            if seq_len is None:
                raise ValueError("seq_len required for STRING RPE")
            self.rpe_string = CirculantSTRING(seq_len, self.head_dim)

    def _relu_feature_map(self, x):
        return F.relu(x)

    def forward(self, x):
        B, N, D = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # RoPE
        if self.rotary is not None:
            q, k = self.rotary(q, k)

        q_feat = self._relu_feature_map(q)
        k_feat = self._relu_feature_map(k)

        kv = torch.einsum("b h n d, b h n e -> b h d e", k_feat, v)
        k_sum = k_feat.sum(dim=2)

        denom = torch.einsum("b h n d, b h d -> b h n", q_feat, k_sum)
        denom = denom.unsqueeze(-1) + self.eps

        out = torch.einsum("b h n d, b h d e -> b h n e", q_feat, kv)
        out = out / denom  # (B,H,N,d)

        # Classic RPE multiplicative
        if self.rpe_classic is not None:
            weight = self.rpe_classic(N, x.device)  # (N,N)
            out = torch.einsum("b h n e, n m -> b h m e", out, weight)

        # STRING RPE
        if self.rpe_string is not None:
            weight = self.rpe_string(N, x.device)  # (N,N)
            out = torch.einsum("b h n e, n m -> b h m e", out, weight)

        out = out.permute(0, 2, 1, 3).reshape(B, N, D)

        out = self.out_proj(out)
        out = self.proj_drop(out)
        return out

# ================================================================
# 3. Transformer block + ViT backbone
# ================================================================

class MLPBlock(nn.Module):
    """
    Position-wise Transformer feed-forward network.

    Each token is expanded to int(embed_dim * mlp_ratio) features, passed
    through GELU, and projected back to embed_dim. The same dropout module is
    applied after the activation and after the output projection.

    Args:
        embed_dim: token dimension D
        mlp_ratio: hidden width as a multiple of D
        drop: dropout probability
    """

    def __init__(self, embed_dim: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        width = int(embed_dim * mlp_ratio)

        self.fc1 = nn.Linear(embed_dim, width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, embed_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """(B, N, D) -> (B, N, D)"""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class TransformerEncoderBlock(nn.Module):
  """
  Pre-LN Transformer block with selectable attention type
  """

  def __init__(
      self,
      embed_dim: int,
      num_heads: int,
      mlp_ratio: float = 4.0,
      attn_type: Literal["full", "favor+", "relu"] = "full",
      nb_features: int = 64,
      drop: float = 0.0,
      attn_drop: float = 0.0,
      rpe_type: Literal["none", "rope", "classic", "string"] = "none",
      seq_len: int | None = None,
  ):
      super().__init__()

      self.norm1 = nn.LayerNorm(embed_dim)
      self.norm2 = nn.LayerNorm(embed_dim)

      if attn_type == "full":
          self.attn = MultiHeadSelfAttentionFull(
              embed_dim,
              num_heads,
              attn_drop,
              drop,
          )

      elif attn_type == "favor+":
          self.attn = PerformerAttentionFavorPlus(
              embed_dim,
              num_heads,
              nb_features,
              attn_drop,
              drop,
              rpe_type=rpe_type,
              seq_len=seq_len,
          )
      elif attn_type == "relu":
          self.attn = PerformerAttentionReLU(
              embed_dim,
              num_heads,
              attn_drop=attn_drop,
              proj_drop=drop,
              rpe_type=rpe_type,
              seq_len=seq_len,
          )

      self.mlp = MLPBlock(embed_dim, mlp_ratio, drop)

  def forward(self, x):
      x = x + self.attn(self.norm1(x))
      x = x + self.mlp(self.norm2(x))
      return x

class ViTClassifier(nn.Module):
    """
    ViT / Performer-ViT image classifier (used here for MNIST and CIFAR-10).

    Pipeline: patch embedding + [CLS] token -> `depth` encoder blocks ->
    final LayerNorm -> linear head applied to the [CLS] token.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_channels: int,
        num_classes: int,
        embed_dim: int = 64,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        attn_type: Literal["full", "favor+", "relu"] = "full",
        nb_features: int = 64,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        rpe_type: Literal["none", "rope", "classic", "string"] = "none",
    ):
        super().__init__()

        self.input_layer = ViTInputLayer(
            img_size, patch_size, in_channels, embed_dim, cls_token=True
        )

        # Token count seen by attention: all patches plus the [CLS] token.
        seq_len = self.input_layer.num_patches + 1

        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            attn_type=attn_type,
            nb_features=nb_features,
            drop=drop,
            attn_drop=attn_drop,
            rpe_type=rpe_type,
            seq_len=seq_len,
        )
        self.blocks = nn.ModuleList(
            TransformerEncoderBlock(**block_kwargs) for _ in range(depth)
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Images (B, C, H, W) -> class logits (B, num_classes)."""
        tokens = self.input_layer(x)
        for block in self.blocks:
            tokens = block(tokens)
        # Classify from the [CLS] token at position 0.
        return self.head(self.norm(tokens)[:, 0])

In [4]:
# ---------------------------------------------------------
# RPE MODULES
# ---------------------------------------------------------

class ClassicRPEPerformer(nn.Module):
    """
    Learned multiplicative relative-position weights for the Performer layers.

    Intended as the classic RPE of Luo et al. (2021). It holds one learnable
    scalar per (circular) offset and returns the dense (N, N) matrix
        w[i, j] = exp(rpe[(j - i) mod N]),
    which is all ones at initialisation (rpe starts at zero).
    NOTE(review): the attention modules apply this matrix to the attention
    output, not inside the kernel; confirm that matches the intended method.

    Args:
        seq_len: maximum number of tokens N supported.
    """

    def __init__(self, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.rpe = nn.Parameter(torch.zeros(seq_len))  # 1D kernel, one entry per offset

    def forward(self, N, device):
        """
        Build the (N, N) weight matrix on `device`.

        Raises:
            ValueError: if N exceeds the configured seq_len (there would be no
                learned parameter for the larger offsets).
        """
        if N > self.seq_len:
            raise ValueError(
                f"ClassicRPEPerformer supports at most seq_len={self.seq_len} tokens, got N={N}"
            )

        idx = torch.arange(N, device=device)

        # directional: (j - i) mod N
        rel = (idx[None, :] - idx[:, None]) % N

        # kernel weighting
        weight = torch.exp(self.rpe[rel])  # (N,N)
        return weight


class CirculantSTRING(nn.Module):
    """
    Circulant "STRING" RPE: produces a directional (N, N) weighting matrix.

    Used only by the Performer attention layers. The matrix is circulant:
        out[i, j] = scale * kernel[(j - i) mod N],   scale = 1 / sqrt(head_dim),
    with one learnable kernel entry per circular offset (all zeros at init).

    Args:
        seq_len: the only sequence length N this module accepts.
        head_dim: per-head dimension, used for the 1/sqrt(head_dim) scale.
    """

    def __init__(self, seq_len: int, head_dim: int):
        super().__init__()
        self.seq_len = seq_len
        self.kernel = nn.Parameter(torch.zeros(seq_len))
        self.scale = 1.0 / math.sqrt(head_dim)

    def forward(self, N: int, device: torch.device):
        """
        Build the (N, N) circulant matrix on `device`.

        Equivalent to circularly convolving each row of the identity with the
        kernel via FFT, but built with an exact gather: no FFT work on an N x N
        identity and no FFT round-off.

        Raises:
            ValueError: if N differs from the configured seq_len.
        """
        if N != self.seq_len:
            raise ValueError(
                f"CirculantSTRING requires fixed seq_len={self.seq_len}, got N={N}"
            )

        idx = torch.arange(N, device=device)
        rel = (idx[None, :] - idx[:, None]) % N  # rel[i, j] = (j - i) mod N
        return self.scale * self.kernel[rel]  # (N,N)

class RotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE) for per-head queries and keys.

    Uses the "rotate-half" layout: channel i is paired with channel i + d/2,
    and the pair is rotated by angle pos * inv_freq[i]. Because every pair is a
    true 2-D rotation, q_rot . k_rot depends on positions only through their
    difference.

    Args:
        head_dim: per-head dimension d (must be even).
        max_seq_len: largest sequence length accepted by forward().
    """

    def __init__(self, head_dim: int, max_seq_len: int = 512):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("RoPE nécessite un head_dim pair")

        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        # inv_freq[i] = 10000^(-2i/d), i = 0 .. d/2 - 1
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def _build_sin_cos(self, seq_len: int, device: torch.device):
        """
        Return (sin, cos), each of shape (1, 1, seq_len, head_dim), laid out as
        [angles, angles] so that channels i and i + d/2 share the same angle.
        """
        pos = torch.arange(seq_len, dtype=torch.float32, device=device)  # (N,)

        # angles per frequency: (N, head_dim/2)
        freqs = torch.einsum("n,d->nd", pos, self.inv_freq)

        # Concatenate (not interleave): _rotate_half pairs channel i with
        # i + d/2, so both must use the same angle for this to be a rotation.
        angles = torch.cat([freqs, freqs], dim=-1)  # (N, head_dim)

        return angles.sin()[None, None], angles.cos()[None, None]

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """[x1, x2] -> [-x2, x1], where x1 / x2 are the first / second halves of the last dim."""
        d = x.shape[-1]
        x1 = x[..., : d // 2]
        x2 = x[..., d // 2 :]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k):
        """
        Rotate q and k, both (B, H, N, d), according to token position.

        Raises:
            ValueError: if N exceeds max_seq_len.
        """
        B, H, N, d = q.shape
        if N > self.max_seq_len:
            raise ValueError(f"seq_len {N} > max_seq_len {self.max_seq_len}")

        sin, cos = self._build_sin_cos(N, q.device)

        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot

# ================================================================
# 4. Helper functions to build model from config
# ================================================================
@dataclass
class ViTConfig:
    """
    Hyper-parameters consumed by build_model() to construct a ViTClassifier.

    The defaults describe a small CIFAR-10 sized model (32x32 RGB, 10 classes).
    """

    img_size: int = 32  # input height/width (images are square)
    patch_size: int = 4  # patch side; img_size must be divisible by it
    in_channels: int = 3  # 1 for MNIST, 3 for CIFAR-10
    num_classes: int = 10
    embed_dim: int = 64  # token dimension D
    depth: int = 4  # number of encoder blocks
    num_heads: int = 4
    mlp_ratio: float = 4.0  # MLP hidden width = embed_dim * mlp_ratio
    attn_type: Literal["full", "favor+", "relu"] = "favor+"
    nb_features: int = 64  # random features per head (FAVOR+ only)
    drop: float = 0.0
    attn_drop: float = 0.0
    rpe_type: Literal["none", "rope", "classic", "string"] = "none"


def build_model(cfg: ViTConfig) -> nn.Module:
    """
    Instantiate a ViTClassifier from a ViTConfig.

    Every config field is forwarded as the keyword argument of the same name.
    """
    forwarded_fields = (
        "img_size",
        "patch_size",
        "in_channels",
        "num_classes",
        "embed_dim",
        "depth",
        "num_heads",
        "mlp_ratio",
        "attn_type",
        "nb_features",
        "drop",
        "attn_drop",
        "rpe_type",
    )
    return ViTClassifier(**{field: getattr(cfg, field) for field in forwarded_fields})

# ================================================================
# 5. Training loop for MNIST / CIFAR-10 (simple baseline)
# ================================================================
import csv
import time

def get_mnist_loaders(batch_size=128):
    """
    Build (train_loader, test_loader) for MNIST stored under MNIST_PATH.

    Images are resized from 28x28 to 32x32 so they match the ViT patch grid,
    then converted to tensors. The dataset is downloaded if absent. Only the
    training loader shuffles.
    """
    preprocess = T.Compose([T.Resize(32), T.ToTensor()])

    splits = {
        is_train: torchvision.datasets.MNIST(
            root=MNIST_PATH, train=is_train, download=True, transform=preprocess
        )
        for is_train in (True, False)
    }

    return (
        DataLoader(splits[True], batch_size=batch_size, shuffle=True),
        DataLoader(splits[False], batch_size=batch_size, shuffle=False),
    )


def get_cifar10_loaders(batch_size=128):
    """
    Build (train_loader, test_loader) for CIFAR-10 stored under CIFAR10_PATH.

    Training images get a random horizontal flip before tensor conversion;
    test images are only converted. The dataset is downloaded if absent.
    Only the training loader shuffles.
    """
    train_tf = T.Compose([T.RandomHorizontalFlip(), T.ToTensor()])
    test_tf = T.Compose([T.ToTensor()])

    def _load_split(train, tf):
        # One CIFAR-10 split with the given transform
        return torchvision.datasets.CIFAR10(
            root=CIFAR10_PATH, train=train, download=True, transform=tf
        )

    train_ds = _load_split(True, train_tf)
    test_ds = _load_split(False, test_tf)

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(test_ds, batch_size=batch_size, shuffle=False),
    )


def train_one_epoch(model, loader, optimizer, device):
    """
    Run one optimisation pass over `loader` with cross-entropy loss.

    Args:
        model: classifier returning logits.
        loader: iterable of (inputs, integer targets) batches.
        optimizer: optimizer over model's parameters.
        device: device the batches are moved to.

    Returns:
        (sample-weighted mean loss, accuracy, wall-clock seconds)
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    tic = time.time()

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size so the epoch mean is per sample.
        batch_n = inputs.size(0)
        loss_sum += loss.item() * batch_n
        n_correct += (logits.argmax(dim=-1) == targets).sum().item()
        n_seen += batch_n

    elapsed = time.time() - tic
    return loss_sum / n_seen, n_correct / n_seen, elapsed


def evaluate(model, loader, device):
    """
    Measure cross-entropy loss and accuracy of `model` on `loader` without gradients.

    Returns:
        (sample-weighted mean loss, accuracy, wall-clock seconds)
    """
    model.eval()
    running_loss = 0.0
    hits = 0
    count = 0
    started = time.time()

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            size = batch_x.size(0)
            running_loss += F.cross_entropy(logits, batch_y).item() * size
            hits += (logits.argmax(dim=-1) == batch_y).sum().item()
            count += size

    duration = time.time() - started
    return running_loss / count, hits / count, duration


# ================================================================
# CSV logging helper (appends without overwriting existing data)
# ================================================================

def log_results(filepath, row_dict):
    """
    Append one row to a CSV file, writing the header first when the file is new or empty.

    Args:
        filepath: path of the CSV file. Missing parent directories are created,
            so the first run does not fail when e.g. Results/ does not exist yet.
        row_dict: mapping column name -> value. Its key order defines the header
            when the file is created; later rows are expected to use the same keys.
    """
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # An existing but empty file (e.g. left by an interrupted run) still needs a header.
    needs_header = not os.path.isfile(filepath) or os.path.getsize(filepath) == 0

    with open(filepath, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row_dict.keys()))
        if needs_header:
            writer.writeheader()
        writer.writerow(row_dict)

# ================================================================
# RUN EXAMPLE
# ================================================================

def run_training_RPE(
    dataset: Literal["mnist", "cifar10"] = "mnist",
    attn_type: Literal["full", "favor+", "relu"] = "full",
    rpe_type: Literal["none", "rope", "classic", "string"] = "none",
    epochs: int = 5,
    lr: float = 1e-3,
    batch_size: int = 128,
):
    """
    Train and evaluate one ViT / Performer-ViT configuration with Adam.

    After every epoch the train/test metrics are printed and appended as one
    row to RESULTS_FILE.

    Args:
        dataset: "mnist" or "cifar10".
        attn_type: attention variant passed to the model.
        rpe_type: relative position scheme passed to the model.
        epochs: number of training epochs.
        lr: Adam learning rate.
        batch_size: batch size for both loaders.

    Raises:
        ValueError: if `dataset` is not "mnist" or "cifar10".
    """
    # Validate first: an unknown name used to leave `cfg` unbound and fail later
    # with an UnboundLocalError.
    if dataset not in ("mnist", "cifar10"):
        raise ValueError(f"unknown dataset {dataset!r}; expected 'mnist' or 'cifar10'")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1) Dataset selection: both datasets reach the model as 32x32 images;
    #    only the channel count differs.
    if dataset == "mnist":
        train_loader, test_loader = get_mnist_loaders(batch_size)
        in_channels = 1
    else:
        train_loader, test_loader = get_cifar10_loaders(batch_size)
        in_channels = 3

    cfg = ViTConfig(
        img_size=32,
        patch_size=4,
        in_channels=in_channels,
        num_classes=10,
        embed_dim=64,
        depth=4,
        num_heads=4,
        mlp_ratio=4.0,
        attn_type=attn_type,
        nb_features=64,
        rpe_type=rpe_type,
    )

    # 2) Build model
    model = build_model(cfg).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # 3) Training loop
    for epoch in range(1, epochs + 1):
        train_loss, train_acc, train_time = train_one_epoch(model, train_loader, optimizer, device)
        test_loss, test_acc, eval_time = evaluate(model, test_loader, device)

        print(
            f"[{dataset}][attn={attn_type}][rpe={rpe_type}] "
            f"Epoch {epoch:02d}: "
            f"train loss={train_loss:.4f}, train acc={train_acc:.4f}, "
            f"test loss={test_loss:.4f}, test acc={test_acc:.4f}"
        )

        # Log to CSV
        log_results(
            RESULTS_FILE,
            {
                "dataset": dataset,
                "attn_type": attn_type,
                "rpe_type": rpe_type,
                "epochs_total": epochs,
                "epoch": epoch,
                "lr": lr,
                "batch_size": batch_size,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "train_time_sec": train_time,
                "test_loss": test_loss,
                "test_acc": test_acc,
                "eval_time_sec": eval_time,
            },
        )

# Experiment grid. Only CIFAR-10 is enabled; restore the next line to also run MNIST.
# datasets = ["mnist", "cifar10"]
datasets = ["cifar10"]
# "full" is absent from this list, so only the Performer variants are trained here.
attn_types = ["favor+", "relu"]
# Full softmax attention would only be paired with rpe="none"; the Performer
# variants are swept over every RPE scheme.
rpe_types_full = ["none"]
rpe_types_performer = ["none", "rope", "classic", "string"]

# Epoch budget per dataset.
EPOCHS_MNIST = 5
EPOCHS_CIFAR = 20

# NOTE(review): trains every configuration sequentially (long-running), and the
# loop variables dataset / attn / rpe / to_run / epochs remain as module globals.
for dataset in datasets:
    for attn in attn_types:
        if attn == "full":
            to_run = rpe_types_full
        else:
            to_run = rpe_types_performer

        for rpe in to_run:
            epochs = EPOCHS_MNIST if dataset == "mnist" else EPOCHS_CIFAR

            print("\n==============================================")
            print(
                f" RUN: dataset={dataset} | attn={attn} | rpe={rpe} | epochs={epochs}"
            )
            print("==============================================\n")

            run_training_RPE(
                dataset=dataset,
                attn_type=attn,
                rpe_type=rpe,
                epochs=epochs,
                lr=1e-3,
                batch_size=128,
            )


 RUN: dataset=cifar10 | attn=favor+ | rpe=none | epochs=20

Using device: cuda


KeyboardInterrupt: 

In [None]:
# Debug helper: list the Colab working directory.
!ls -l "/content"


In [None]:
# Debug helper: check that the Drive folder containing BASE_DIR is mounted.
!ls -l "/content/drive/MyDrive/Colab Notebooks"


# Special case: settings tuned to maximize accuracy (for illustration only)

In [5]:
from torch.optim.lr_scheduler import CosineAnnealingLR
# Separate results CSV for the accuracy-maximising experiments, kept apart from RESULTS_FILE.
RESULTS_FILE_MAX = f"{BASE_DIR}/Results/results_2_RPE_MAX_{TODAY}.csv"

# ================================================================
# 1. Patch embedding + CLS token (no absolute positional embedding)
# ================================================================

class PatchEmbedding(nn.Module):
    """
    Image -> sequence of patch embeddings via a single strided Conv2d.

    Args:
        img_size: side length of the (square) input image
        patch_size: side length of each (square) patch; must divide img_size
        in_channels: number of input channels
        embed_dim: dimension D of each output token
    """

    def __init__(
        self, img_size: int, patch_size: int, in_channels: int, embed_dim: int
    ):
        super().__init__()
        # The previous `assert ... ,` had a dangling comma (SyntaxError),
        # which prevented this whole cell from running.
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """(B, C, H, W) -> (B, N, D) with N = (H / P) * (W / P)."""
        x = self.proj(x)  # (B, D, H/P, W/P)
        x = x.flatten(2)  # (B, D, N)
        x = x.transpose(1, 2)  # (B, N, D)
        return x


class ViTInputLayer(nn.Module):
    """
    Turn images into a token sequence: patch embeddings, optionally preceded by
    a learnable [CLS] token (zero-initialised). No positional embedding is added.
    """

    def __init__(self, img_size, patch_size, in_channels, embed_dim, cls_token=True):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        if cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

    def forward(self, x):
        """(B, C, H, W) -> (B, N, D), or (B, N + 1, D) with [CLS] at index 0."""
        tokens = self.patch_embed(x)
        if self.cls_token is None:
            return tokens
        prefix = self.cls_token.expand(x.size(0), -1, -1)
        return torch.cat((prefix, tokens), dim=1)

# ================================================================
# 2. Multi-Head Attention variants
#    - Full softmax attention (baseline)
#    - Performer FAVOR+ (softmax approx with positive random features)
#    - Performer-ReLU (ReLU feature map)
# ================================================================

class MultiHeadSelfAttentionFull(nn.Module):
    """
    Standard (quadratic) multi-head softmax self-attention.

    Args:
        embed_dim: token dimension D; must be divisible by num_heads
        num_heads: number of attention heads H
        attn_drop: dropout probability applied to the attention weights
        proj_drop: dropout probability applied after the output projection
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Softmax temperature 1/sqrt(d)
        self.scale = self.head_dim**-0.5

        # One fused projection producing q, k and v
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over tokens x of shape (B, N, D); returns (B, N, D)."""
        batch, tokens, width = x.shape

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, N, d)

        weights = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)  # (B, H, N, N)
        weights = self.attn_drop(weights)

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, width)
        return self.proj_drop(self.out_proj(mixed))

class PerformerAttentionFavorPlus(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        nb_features: int = 64,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
        rpe_type: Literal["none", "rope", "classic", "string"] = "none",
        seq_len: int | None = None,
    ):
        super().__init__()

        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.nb_features = nb_features
        self.eps = eps

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        W = torch.randn(num_heads, nb_features, self.head_dim)
        self.register_buffer("W", W)

        self.rotary = None
        self.rpe_classic = None
        self.rpe_string = None

        if rpe_type == "rope":
            if seq_len is None:
                raise ValueError("seq_len required for RoPE")
            self.rotary = RotaryEmbedding(self.head_dim, max_seq_len=seq_len)

        elif rpe_type == "classic":
            if seq_len is None:
                raise ValueError("seq_len required for classic RPE")
            self.rpe_classic = ClassicRPEPerformer(seq_len)

        elif rpe_type == "string":
            if seq_len is None:
                raise ValueError("seq_len required for STRING RPE")
            self.rpe_string = CirculantSTRING(seq_len, self.head_dim)

    def _favor_feature_map(self, x):
        B, H, N, d_k = x.shape
        x_proj = torch.einsum("b h n d, h m d -> b h n m", x, self.W)
        sq_norm = (x**2).sum(dim=-1, keepdim=True) / 2.0
        return torch.exp(x_proj - sq_norm) / math.sqrt(self.nb_features)

    def forward(self, x):
        B, N, D = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        if self.rotary is not None:
            q, k = self.rotary(q, k)

        q_feat = self._favor_feature_map(q)
        k_feat = self._favor_feature_map(k)

        kv = torch.einsum("b h n m, b h n d -> b h m d", k_feat, v)
        k_sum = k_feat.sum(dim=2)

        denom = torch.einsum("b h n m, b h m -> b h n", q_feat, k_sum)
        denom = denom.unsqueeze(-1) + self.eps

        out = torch.einsum("b h n m, b h m d -> b h n d", q_feat, kv)
        out = out / denom

        if self.rpe_classic is not None:
            weight = self.rpe_classic(N, x.device)
            out = torch.einsum("b h n d, n m -> b h m d", out, weight)

        if self.rpe_string is not None:
            weight = self.rpe_string(N, x.device)
            out = torch.einsum("b h n d, n m -> b h m d", out, weight)

        out = out.permute(0, 2, 1, 3).reshape(B, N, D)

        out = self.out_proj(out)
        out = self.proj_drop(out)
        return out


# ================================================================
# 3. Transformer block + ViT backbone
# ================================================================

class MLPBlock(nn.Module):
    """
    Position-wise feed-forward network of a Transformer block.

    Layout: Linear(D -> int(D * mlp_ratio)) -> GELU -> Dropout -> Linear(-> D) -> Dropout.

    Args:
        embed_dim: input/output dimension D.
        mlp_ratio: hidden width as a multiple of D.
        drop: dropout probability (shared by both dropout points).
    """

    def __init__(self, embed_dim: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        width = int(embed_dim * mlp_ratio)

        self.fc1 = nn.Linear(embed_dim, width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, embed_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class TransformerEncoderBlock(nn.Module):
  def __init__(
      self,
      embed_dim: int,
      num_heads: int,
      mlp_ratio: float = 4.0,
      attn_type: Literal["full", "favor+", "relu"] = "full",
      nb_features: int = 64,
      drop: float = 0.0,
      attn_drop: float = 0.0,
      rpe_type: Literal["none", "rope", "classic", "string"] = "none",
      seq_len: int | None = None,
  ):
      super().__init__()

      self.norm1 = nn.LayerNorm(embed_dim)
      self.norm2 = nn.LayerNorm(embed_dim)

      if attn_type == "full":
          self.attn = MultiHeadSelfAttentionFull(
              embed_dim,
              num_heads,
              attn_drop,
              drop,
          )

      elif attn_type == "favor+":
          self.attn = PerformerAttentionFavorPlus(
              embed_dim,
              num_heads,
              nb_features,
              attn_drop,
              drop,
              rpe_type=rpe_type,
              seq_len=seq_len,
          )
      elif attn_type == "relu":
          self.attn = PerformerAttentionReLU(
              embed_dim,
              num_heads,
              attn_drop=attn_drop,
              proj_drop=drop,
              rpe_type=rpe_type,
              seq_len=seq_len,
          )

      self.mlp = MLPBlock(embed_dim, mlp_ratio, drop)

  def forward(self, x):
      x = x + self.attn(self.norm1(x))
      x = x + self.mlp(self.norm2(x))
      return x

class ViTClassifier(nn.Module):
    """
    Vision Transformer image classifier.

    Pipeline: ViTInputLayer (patch tokens with a prepended CLS token)
    -> `depth` TransformerEncoderBlocks -> LayerNorm -> linear head on the CLS token.

    Args:
        img_size, patch_size, in_channels: image / patching geometry.
        num_classes: number of output logits.
        embed_dim, depth, num_heads, mlp_ratio: Transformer size.
        attn_type, nb_features, rpe_type: attention variant options forwarded
            to every encoder block.
        drop, attn_drop: dropout probabilities forwarded to every block.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_channels: int,
        num_classes: int,
        embed_dim: int = 64,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        attn_type: Literal["full", "favor+", "relu"] = "full",
        nb_features: int = 64,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        rpe_type: Literal["none", "rope", "classic", "string"] = "none",
    ):
        super().__init__()

        self.input_layer = ViTInputLayer(
            img_size, patch_size, in_channels, embed_dim, cls_token=True
        )

        # Sequence length seen by every block: the patch grid plus one CLS token.
        num_tokens = (img_size // patch_size) ** 2 + 1

        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            attn_type=attn_type,
            nb_features=nb_features,
            drop=drop,
            attn_drop=attn_drop,
            rpe_type=rpe_type,
            seq_len=num_tokens,
        )
        self.blocks = nn.ModuleList(
            TransformerEncoderBlock(**block_kwargs) for _ in range(depth)
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        tokens = self.input_layer(x)
        for block in self.blocks:
            tokens = block(tokens)
        # Classify from the (normalised) CLS token at position 0.
        cls_repr = self.norm(tokens)[:, 0]
        return self.head(cls_repr)

In [None]:
# ---------------------------------------------------------
# RPE MODULES
# ---------------------------------------------------------

class RotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE) for per-head queries and keys.

    Uses the "rotate-half" layout: channel i is paired with channel
    i + head_dim/2, and both are rotated by the same angle pos * inv_freq[i].
    Because each pair undergoes a true 2-D rotation, norms are preserved and
    q_m . k_n depends only on the offset m - n.

    Args:
        head_dim: per-head channel count; must be even.
        max_seq_len: longest sequence accepted by forward().

    Raises:
        ValueError: if head_dim is odd.
    """

    def __init__(self, head_dim: int, max_seq_len: int = 512):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("RoPE nécessite un head_dim pair")

        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def _build_sin_cos(self, seq_len: int, device: torch.device):
        """Return (sin, cos), each (1, 1, seq_len, head_dim), in rotate-half layout."""
        pos = torch.arange(seq_len, dtype=torch.float32, device=device)  # (N,)
        freqs = torch.einsum("n,d->nd", pos, self.inv_freq)  # (N, head_dim / 2)

        # Channel i and i + head_dim/2 must share an angle to match _rotate_half.
        # (An interleaved [f0, f0, f1, f1, ...] layout here would give the two
        # halves of a pair different angles and break the rotation.)
        angles = torch.cat([freqs, freqs], dim=-1)  # (N, head_dim)

        return (
            torch.sin(angles)[None, None],
            torch.cos(angles)[None, None],
        )

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Map (x1, x2) halves of the last dim to (-x2, x1)."""
        d = x.shape[-1]
        x1 = x[..., : d // 2]
        x2 = x[..., d // 2 :]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k):
        """
        Args:
            q, k: (B, H, N, head_dim) queries and keys.
        Returns:
            Rotated (q, k) with the same shapes.
        Raises:
            ValueError: if N exceeds max_seq_len or the last dim is not head_dim.
        """
        n_tokens, dim = q.shape[-2], q.shape[-1]
        if n_tokens > self.max_seq_len:
            raise ValueError(f"seq_len {n_tokens} > max_seq_len {self.max_seq_len}")
        if dim != self.head_dim:
            raise ValueError(f"head_dim {dim} != expected {self.head_dim}")

        sin, cos = self._build_sin_cos(n_tokens, q.device)

        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot

# ================================================================
# 4. Helper functions to build model from config
# ================================================================
@dataclass
class ViTConfig:
    """
    Hyper-parameters consumed by build_model().

    Field names (and positional order) mirror ViTClassifier's constructor
    arguments, so each field is forwarded under the same name.
    """

    img_size: int = 32
    patch_size: int = 4
    in_channels: int = 3
    num_classes: int = 10
    embed_dim: int = 64
    depth: int = 4
    num_heads: int = 4
    mlp_ratio: float = 4.0
    attn_type: Literal["full", "favor+", "relu"] = "favor+"
    nb_features: int = 64
    # Each dropout field is declared exactly once; 0.1 is the effective default.
    drop: float = 0.1
    attn_drop: float = 0.1
    rpe_type: Literal["none", "rope", "classic", "string"] = "rope"



def build_model(cfg: ViTConfig) -> nn.Module:
    """
    Instantiate a ViTClassifier from a ViTConfig.

    Each listed config attribute is passed to the constructor argument of the
    same name.
    """
    forwarded = (
        "img_size",
        "patch_size",
        "in_channels",
        "num_classes",
        "embed_dim",
        "depth",
        "num_heads",
        "mlp_ratio",
        "attn_type",
        "nb_features",
        "drop",
        "attn_drop",
        "rpe_type",
    )
    kwargs = {name: getattr(cfg, name) for name in forwarded}
    return ViTClassifier(**kwargs)

# ================================================================
# 5. Training loop for MNIST / CIFAR-10 (simple baseline)
# ================================================================
import csv
import time

def get_mnist_loaders(batch_size=128):
    """
    Build MNIST train/test DataLoaders under MNIST_PATH (downloading if needed).

    Images are resized to 32 and converted to tensors; no normalisation is applied.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    preprocess = T.Compose([T.Resize(32), T.ToTensor()])

    def _split(is_train):
        return torchvision.datasets.MNIST(
            root=MNIST_PATH, train=is_train, download=True, transform=preprocess
        )

    train_data = _split(True)
    test_data = _split(False)

    return (
        DataLoader(train_data, batch_size=batch_size, shuffle=True),
        DataLoader(test_data, batch_size=batch_size, shuffle=False),
    )


def get_cifar10_loaders(batch_size=128):
    """
    Build CIFAR-10 train/test DataLoaders under CIFAR10_PATH (downloading if needed).

    Training images get a random horizontal flip; both splits are converted to
    tensors without normalisation.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    train_tf = T.Compose([T.RandomHorizontalFlip(), T.ToTensor()])
    test_tf = T.Compose([T.ToTensor()])

    train_data = torchvision.datasets.CIFAR10(
        root=CIFAR10_PATH, train=True, download=True, transform=train_tf
    )
    test_data = torchvision.datasets.CIFAR10(
        root=CIFAR10_PATH, train=False, download=True, transform=test_tf
    )

    return (
        DataLoader(train_data, batch_size=batch_size, shuffle=True),
        DataLoader(test_data, batch_size=batch_size, shuffle=False),
    )


def train_one_epoch(model, loader, optimizer, device):
    """
    Run one optimisation pass over ``loader`` with cross-entropy loss.

    Args:
        model: classifier returning logits.
        loader: iterable of (inputs, targets) batches.
        optimizer: optimiser stepped once per batch.
        device: device the batches are moved to.

    Returns:
        (mean loss per sample, accuracy, wall-clock seconds). Loss and accuracy
        use the logits computed before each optimiser step.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    t0 = time.time()

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()

        batch_n = inputs.size(0)
        loss_sum += loss.item() * batch_n
        n_correct += (logits.argmax(dim=-1) == targets).sum().item()
        n_seen += batch_n

    elapsed = time.time() - t0
    return loss_sum / n_seen, n_correct / n_seen, elapsed


def evaluate(model, loader, device):
    """
    Measure mean cross-entropy loss and accuracy of ``model`` over ``loader``.

    The model is put in eval mode and gradients are disabled.

    Returns:
        (mean loss per sample, accuracy, wall-clock seconds).
    """
    model.eval()
    running_loss = 0.0
    hits = 0
    count = 0
    tic = time.time()

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            n = batch_x.size(0)
            running_loss += F.cross_entropy(logits, batch_y).item() * n
            hits += (logits.argmax(dim=-1) == batch_y).sum().item()
            count += n

    duration = time.time() - tic
    return running_loss / count, hits / count, duration


# ================================================================
# CSV logging helper (appends without overwriting existing data)
# ================================================================

def log_results(filepath, row_dict):
    """
    Append one row to a CSV file without touching existing rows.

    The header (the keys of ``row_dict``, in insertion order) is written when the
    file does not exist yet or exists but is empty. Missing parent directories
    are created.

    Args:
        filepath: path of the CSV file.
        row_dict: mapping column name -> value for the row to append.
    """
    parent = os.path.dirname(filepath)
    if parent:
        # e.g. a fresh "Results/" folder on Drive; open(..., "a") would fail otherwise.
        os.makedirs(parent, exist_ok=True)

    # An empty file (e.g. created by an interrupted run) still needs a header.
    needs_header = not os.path.isfile(filepath) or os.path.getsize(filepath) == 0

    with open(filepath, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row_dict.keys()))
        if needs_header:
            writer.writeheader()
        writer.writerow(row_dict)

# ================================================================
# RUN EXAMPLE
# ================================================================

def _warmup_cosine_lr(epoch, total_epochs, warmup_epochs, base_lr):
    """Per-epoch LR (1-indexed): linear warm-up to base_lr, then cosine decay towards 0."""
    if warmup_epochs > 0 and epoch <= warmup_epochs:
        return base_lr * epoch / warmup_epochs
    decay_epochs = max(1, total_epochs - warmup_epochs)
    # progress is 0 on the first post-warm-up epoch and < 1 on the last one,
    # so no epoch trains with a learning rate of exactly 0.
    progress = (epoch - warmup_epochs - 1) / decay_epochs
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


def run_training_RPE_maxperf(
    dataset="cifar10",
    attn_type="favor+",
    rpe_type="rope",
    epochs=None,
    lr=3e-4,
    batch_size=128,
    warmup_epochs=5,
    weight_decay=0.05,
    results_file=None,
):
    """
    Train a small ViT with AdamW, linear warm-up and cosine LR decay.

    Metrics are printed each epoch and appended as one CSV row per epoch.

    Args:
        dataset: "cifar10" (3 channels); any other value trains on MNIST (1 channel).
        attn_type: attention variant passed to ViTConfig.
        rpe_type: relative position scheme passed to ViTConfig.
        epochs: number of epochs; defaults to 60 for CIFAR-10 and 10 otherwise.
        lr: peak learning rate reached after warm-up (also logged in the "lr" column).
        batch_size: batch size for both loaders.
        warmup_epochs: number of linear warm-up epochs.
        weight_decay: AdamW weight decay.
        results_file: CSV path for per-epoch metrics; defaults to RESULTS_FILE.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    is_cifar = dataset == "cifar10"
    if epochs is None:
        epochs = 60 if is_cifar else 10
    if results_file is None:
        results_file = RESULTS_FILE

    if is_cifar:
        train_loader, test_loader = get_cifar10_loaders(batch_size)
    else:
        train_loader, test_loader = get_mnist_loaders(batch_size)

    # The two datasets share the architecture; only the channel count differs.
    cfg = ViTConfig(
        img_size=32,
        patch_size=4,
        in_channels=3 if is_cifar else 1,
        num_classes=10,
        embed_dim=64,
        depth=4,
        num_heads=4,
        mlp_ratio=4.0,
        attn_type=attn_type,
        nb_features=64,
        rpe_type=rpe_type,
        drop=0.1,
        attn_drop=0.1,
    )

    model = build_model(cfg).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(1, epochs + 1):
        epoch_lr = _warmup_cosine_lr(epoch, epochs, warmup_epochs, lr)
        for pg in optimizer.param_groups:
            pg["lr"] = epoch_lr

        train_loss, train_acc, train_time = train_one_epoch(
            model, train_loader, optimizer, device
        )
        test_loss, test_acc, eval_time = evaluate(model, test_loader, device)

        print(
            f"[{dataset}] Epoch {epoch:03d} | "
            f"train acc={train_acc:.4f} | test acc={test_acc:.4f}"
        )

        log_results(
            results_file,
            {
                "dataset": dataset,
                "attn_type": attn_type,
                "rpe_type": rpe_type,
                "epochs_total": epochs,
                "epoch": epoch,
                "lr": lr,
                "batch_size": batch_size,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "train_time_sec": train_time,
                "test_loss": test_loss,
                "test_acc": test_acc,
                "eval_time_sec": eval_time,
            },
        )

# Experiment grid: full softmax attention is only run with `rpe_types_full`,
# Performer variants sweep `rpe_types_performer`.
datasets = ["cifar10"]
attn_types = ["favor+"]
rpe_types_full = ["none"]
rpe_types_performer = ["rope"]

EPOCHS_MNIST = 10
EPOCHS_CIFAR = 60

for dataset in datasets:
    epochs = EPOCHS_MNIST if dataset == "mnist" else EPOCHS_CIFAR
    for attn in attn_types:
        to_run = rpe_types_full if attn == "full" else rpe_types_performer
        for rpe in to_run:
            print("\n==============================================")
            print(
                f" RUN: dataset={dataset} | attn={attn} | rpe={rpe} | epochs={epochs}"
            )
            print("==============================================\n")

            run_training_RPE(
                dataset=dataset,
                attn_type=attn,
                rpe_type=rpe,
                epochs=epochs,
                lr=1e-3,
                batch_size=128,
            )


 RUN: dataset=cifar10 | attn=favor+ | rpe=rope | epochs=60

Using device: cuda
[cifar10][attn=favor+][rpe=rope] Epoch 01: train loss=1.7617, train acc=0.3482, test loss=1.5459, test acc=0.4360
[cifar10][attn=favor+][rpe=rope] Epoch 02: train loss=1.5084, train acc=0.4496, test loss=1.4172, test acc=0.4897
[cifar10][attn=favor+][rpe=rope] Epoch 03: train loss=1.4252, train acc=0.4819, test loss=1.3808, test acc=0.4993
[cifar10][attn=favor+][rpe=rope] Epoch 04: train loss=1.3712, train acc=0.5050, test loss=1.3463, test acc=0.5136
[cifar10][attn=favor+][rpe=rope] Epoch 05: train loss=1.3258, train acc=0.5202, test loss=1.2983, test acc=0.5333
[cifar10][attn=favor+][rpe=rope] Epoch 06: train loss=1.2918, train acc=0.5352, test loss=1.2651, test acc=0.5410
[cifar10][attn=favor+][rpe=rope] Epoch 07: train loss=1.2570, train acc=0.5462, test loss=1.2244, test acc=0.5595
[cifar10][attn=favor+][rpe=rope] Epoch 08: train loss=1.2284, train acc=0.5567, test loss=1.2057, test acc=0.5674
[cifar10