In [25]:
"""
1. Imports
"""

import os
import math
import time
from dataclasses import dataclass
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

torch.backends.cudnn.benchmark = True

In [26]:
"""
2. Configs
"""

@dataclass
class Config:
    dataset: str = "CIFAR10"  # "MNIST" or "CIFAR10"
    data_root: str = "./data"
    batch_size: int = 128
    img_size: int = 32
    patch_size: int = 4
    num_classes: int = 10

    dim: int = 256          # model dimension
    num_heads: int = 4
    depth: int = 6          # number of Transformer layers
    dropout: float = 0.1

    max_epochs: int = 50
    lr: float = 3e-4
    weight_decay: float = 1e-4

    # checkpoint
    ckpt_dir: str = "./checkpoints"
    save_every: int = 1     # save every N epochs

    # positional encoding
    max_seq_len: int = 256
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

cfg = Config()

os.makedirs(cfg.ckpt_dir, exist_ok=True)

from google.colab import drive
drive.mount('/content/drive')
cfg.ckpt_dir = "/content/drive/MyDrive/E6617_final_project/models"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [27]:
"""
3. Data: MNIST / CIFAR-10
"""

def get_dataloaders(cfg: Config):
    if cfg.dataset.upper() == "MNIST":
        transform = transforms.Compose([
            transforms.Resize(cfg.img_size),
            transforms.ToTensor(),
            # MNIST is 1-ch; expand to 3-ch
            transforms.Lambda(lambda x: x.expand(3, -1, -1)),
            transforms.Normalize((0.5, 0.5, 0.5),
                                 (0.5, 0.5, 0.5)),
        ])
        train_dataset = datasets.MNIST(
            root=cfg.data_root, train=True, download=True, transform=transform
        )
        test_dataset = datasets.MNIST(
            root=cfg.data_root, train=False, download=True, transform=transform
        )
    elif cfg.dataset.upper() == "CIFAR10":
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(cfg.img_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5),
                                 (0.5, 0.5, 0.5)),
        ])

        test_transform = transforms.Compose([
            transforms.Resize(cfg.img_size),
            transforms.CenterCrop(cfg.img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5),
                                 (0.5, 0.5, 0.5)),
        ])

        train_dataset = datasets.CIFAR10(
            root=cfg.data_root, train=True, download=True,
            transform=train_transform
        )
        test_dataset = datasets.CIFAR10(
            root=cfg.data_root, train=False, download=True,
            transform=test_transform
        )
    else:
        raise ValueError("Unsupported dataset")

    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size,
                              shuffle=True, num_workers=2, pin_memory=True,
                              persistent_workers=True)
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size,
                             shuffle=False, num_workers=2, pin_memory=True,
                             persistent_workers=True)
    return train_loader, test_loader

In [28]:
"""
4. Positional Encoding Utilities
>> 4.1 RoPE (Rotary)
>> 4.2 Cayley-STRING / MG-STRING
"""

# 4.1 RoPE

def build_rope_freqs(head_dim: int, max_seq_len: int, device):
    half_dim = head_dim // 2
    theta = torch.arange(half_dim, device=device, dtype=torch.float32)
    theta = 1.0 / (10000 ** (theta / half_dim))
    positions = torch.arange(max_seq_len, device=device, dtype=torch.float32)
    freqs = torch.einsum('p,f->pf', positions, theta)  # [max_seq_len, half_dim]
    cos = freqs.cos()
    sin = freqs.sin()
    return cos, sin  # [max_seq_len, half_dim]


def apply_rope(q, k, cos, sin, seq_len):
    # q,k: [B, H, N, D], D even
    # cos,sin: [max_seq_len, D/2];
    B, H, N, D = q.shape
    assert D % 2 == 0
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # [1,1,N,D/2]
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)  # [1,1,N,D/2]

    def rotate(x):
        x1, x2 = x[..., :D//2], x[..., D//2:]
        # rotary: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin)
        x1_new = x1 * cos - x2 * sin
        x2_new = x2 * cos + x1 * sin
        return torch.cat([x1_new, x2_new], dim=-1)

    return rotate(q), rotate(k)

In [29]:
# 4.2 Cayley-based PE (single & multi-generator)

class CayleyPE(nn.Module):
    """
    Implements A(p) = sum_l rho_l(p) S_l  with S_l skew-symmetric,
    then R(p) = Cayley(A(p)) = (I - A)(I + A)^{-1}, applied per position.
    - L = 1: Cayley-STRING style
    - L > 1: multi-generator (MG-STRING)
    """
    def __init__(
        self,
        head_dim: int,
        max_seq_len: int,
        num_generators: int = 1,
        sparse: bool = False,
        sparsity_band: int | None = None
    ):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.num_generators = num_generators

        self.register_buffer(
            "I",
            torch.eye(head_dim, dtype=torch.float32),
            persistent=False,
        )

        # unconstrained parameters that become skew-symmetric
        self.L_list = nn.ParameterList([
            nn.Parameter(torch.randn(head_dim, head_dim) * 1e-4)
            for _ in range(num_generators)
        ])

        # simple learnable scalar per generator
        self.rho_scale = nn.Parameter(torch.ones(num_generators))

        self.sparse = sparse
        D = head_dim

        if self.sparse:
            idx = torch.arange(D)

            if sparsity_band is None:
                band = max(1, D // 4)
            else:
                band = max(1, sparsity_band)

            mask = (idx[None, :] - idx[:, None]).abs() <= band
            mask = mask.to(torch.float32)

            mask.fill_diagonal_(0.0)

            self.register_buffer(
                "sparsity_mask",
                mask,
                persistent=False,
            )
        else:
            # dense
            self.register_buffer(
                "sparsity_mask",
                torch.ones(D, D),
                persistent=False,
            )

    def forward(self, q, k, pos_ids):
        """
        q, k: [B, H, N, D]
        pos_ids: [N] or [B,N] integer positions in [0, max_seq_len)
        """
        device = q.device
        B, H, N, D = q.shape
        assert D == self.head_dim

        # normalize pos to [0,1]
        if pos_ids.dim() == 2:
            # assume [B,N], but we only support same positions per batch
            pos_ids = pos_ids[0]
        pos_float = pos_ids.to(q.dtype) / float(self.max_seq_len)        # [N]

        S_list: list[torch.Tensor] = []
        for Lmat in self.L_list:
            S_dense = Lmat - Lmat.T                 # [D, D], skew-symmetric
            S_sparse = S_dense * self.sparsity_mask # apply fixed sparsity pattern
            # Normalize S_l
            norm = S_sparse.norm(p='fro') + 1e-6
            S_sparse = S_sparse / norm
            S_list.append(S_sparse)

        # S: [L, D, D]
        S = torch.stack(S_list, dim=0)

        # rho_l(p) = w_l * p_norm
        rho_l = torch.clamp(self.rho_scale, -2.0, 2.0)
        rho = rho_l[None, :, None, None] * pos_float[:, None, None, None]  # [N,L,1,1]
        A = (rho * S[None, ...]).sum(dim=1)                              # [N,D,D]

        I = self.I.to(dtype=q.dtype, device=device)
        I_plus  = I + A                                                  # [N,D,D]
        I_minus = I - A                                                  # [N,D,D]

        epsI = 1e-3 * I
        R_stack = torch.linalg.solve(I_plus + epsI, I_minus)

        # Apply R(p) to each position's q,k: new_q[p] = R_p @ q[p]
        q_rot = torch.einsum('pij,bhpj->bhpi', R_stack, q)
        k_rot = torch.einsum('pij,bhpj->bhpi', R_stack, k)
        return q_rot, k_rot


In [30]:
"""
5. Model: Patch Embedding + Transformer + Head
"""

class PatchEmbed(nn.Module):
    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size * self.grid_size

        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        # x: [B,3,H,W] -> [B, N, D]
        x = self.proj(x)  # [B,D,H',W']
        x = x.flatten(2).transpose(1, 2)  # [B, N, D]
        return x


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads,
                 pe_type: str = "rope",
                 max_seq_len: int = 256,
                 num_generators: int = 1,
                 sparsity_band: int | None = None):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.pe_type = pe_type
        self.max_seq_len = max_seq_len
        self.sparsity_band = sparsity_band

        self.register_buffer(
            "pos_ids_buffer",
            torch.arange(max_seq_len, dtype=torch.long),
            persistent=False,
        )

        # Use for "rope" and as baseline for others
        if pe_type in ["rope", "cayley", "cayley_sparse", "mg", "mg_sparse"]:
            cos, sin = build_rope_freqs(self.head_dim, max_seq_len,
                                        device=cfg.device)
            self.register_buffer("rope_cos", cos, persistent=False)
            self.register_buffer("rope_sin", sin, persistent=False)
        else:
            self.register_buffer("rope_cos", None, persistent=False)
            self.register_buffer("rope_sin", None, persistent=False)

        # Cayley or MG-STRING
        if pe_type in ["cayley", "mg", "cayley_sparse", "mg_sparse"]:
            gens = num_generators if pe_type in ["mg", "mg_sparse"] else 1
            use_sparse = pe_type in ["cayley_sparse", "mg_sparse"]
            self.cayley_pe = CayleyPE(
                self.head_dim,
                max_seq_len,
                gens,
                sparse=use_sparse,
                sparsity_band=self.sparsity_band,
            )
        else:
            self.cayley_pe = None

        if pe_type in ["mg", "mg_sparse"]:
            # small positive init so it starts close to RoPE
            init_alpha = 0.23
            init_logit = torch.log(torch.tensor(init_alpha / (1 - init_alpha)))
            self.mg_alpha = nn.Parameter(init_logit)
        else:
            self.mg_alpha = None

    def forward(self, x):
        # x: [B, N, D]
        B, N, D = x.shape

        qkv = self.qkv(x)  # [B, N, 3*D]
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3,B,H,N,D]
        q, k, v = qkv[0], qkv[1], qkv[2]  # [B,H,N,D]

        if self.pe_type in ["rope", "cayley", "cayley_sparse", "mg", "mg_sparse"]:
            q, k = apply_rope(q, k, self.rope_cos, self.rope_sin, N)

        # After q,k have RoPE applied
        if self.pe_type in ["cayley", "mg", "cayley_sparse", "mg_sparse"]:
            # 1) Cayley-STRING base with position-independent P
            pos_ids_const = self.pos_ids_buffer.new_full(
                (N,),
                self.max_seq_len - 1,
                dtype=self.pos_ids_buffer.dtype,
            )
            q_base, k_base = self.cayley_pe(q, k, pos_ids_const)  # Cayley-STRING

            if self.pe_type in ["cayley", "cayley_sparse"]:
                # used as is
                q, k = q_base, k_base
            else:
                # 2) MG part: position-dependent multi-generator Cayley
                pos_ids_full = self.pos_ids_buffer[:N]
                q_mg, k_mg = self.cayley_pe(q, k, pos_ids_full)

                # 3) MG-STRING = Cayley-STRING + residual towards position-dependent version
                alpha = torch.sigmoid(self.mg_alpha)  # (0,1)

                q = q_base + alpha * (q_mg - q_base)
                k = k_base + alpha * (k_mg - k_base)

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale  # [B,H,N,N]
        attn = attn_scores.softmax(dim=-1)
        out = attn @ v  # [B,H,N,D]

        out = out.transpose(1, 2).reshape(B, N, D)
        out = self.proj(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1,
                 pe_type="rope", max_seq_len=256, num_generators=1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(
            dim, num_heads, pe_type=pe_type,
            max_seq_len=max_seq_len,
            num_generators=num_generators,
        )
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class ViTWithPE(nn.Module):
    def __init__(self, cfg: Config,
                 pe_type: str = "rope",
                 num_generators: int = 1):
        super().__init__()
        self.cfg = cfg
        self.pe_type = pe_type
        self.num_generators = num_generators

        self.patch_embed = PatchEmbed(
            img_size=cfg.img_size,
            patch_size=cfg.patch_size,
            in_chans=3,
            embed_dim=cfg.dim,
        )
        self.num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.dim))
        self.pos_drop = nn.Dropout(cfg.dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=cfg.dim,
                num_heads=cfg.num_heads,
                mlp_ratio=4.0,
                dropout=cfg.dropout,
                pe_type=pe_type,
                max_seq_len=self.num_patches + 1,
                num_generators=num_generators,
            )
            for _ in range(cfg.depth)
        ])

        self.norm = nn.LayerNorm(cfg.dim)
        self.head = nn.Linear(cfg.dim, cfg.num_classes)

    def forward(self, x):
        # x: [B,3,H,W]
        B = x.size(0)
        x = self.patch_embed(x)  # [B, N, D]
        N = x.size(1)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B,1,D]
        x = torch.cat([cls_tokens, x], dim=1)  # [B, N+1, D]

        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls_out = x[:, 0]  # [B,D]
        logits = self.head(cls_out)
        return logits



In [31]:
"""
6. Training / Evaluation / Checkpoint Utils
"""

def save_checkpoint(path, model, optimizer, epoch, best_acc):
    state = {
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
        "epoch": epoch,
        "best_acc": best_acc,
    }
    torch.save(state, path)
    print(f"[Checkpoint] Saved at epoch {epoch} to {path}")


def load_checkpoint(path, model, optimizer, device, strict_model=False, load_optim=False):
    ckpt = torch.load(path, map_location=device)

    # 1) Load model
    try:
        msg = model.load_state_dict(ckpt["model_state"], strict=strict_model)
        if not strict_model:
            print("[Checkpoint] Loaded with strict=False")
            print("    Missing / unexpected keys:", msg)
    except RuntimeError as e:
        print("[Checkpoint] Model state_dict mismatch:")
        print("   ", e)
        print("[Checkpoint] Starting from scratch for model.")
        return 0, 0.0

    start_epoch = ckpt.get("epoch", -1) + 1
    best_acc = ckpt.get("best_acc", 0.0)

    # 2) Load optimizer
    if load_optim:
        try:
            optimizer.load_state_dict(ckpt["optim_state"])
        except Exception as e:
            print("[Checkpoint] Optimizer state mismatch, reinitializing optimizer.")
            print("   ", e)
            start_epoch = 0
            best_acc = 0.0

    return start_epoch, best_acc


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for images, labels in loader:
            with torch.amp.autocast("cuda", enabled=(device == "cuda")):
                images = images.to(device)
                labels = labels.to(device)
                logits = model(images)
                loss = criterion(logits, labels)
                loss_sum += loss.item() * labels.size(0)
                preds = logits.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
    return loss_sum / total, correct / total


# Relative-Position Metrics
def compute_relative_position_metrics(cayley_pe, max_positions=64, n_attn_samples=500):
    """
    Computes:
      - relative rotation error (mean, max)
      - commutator norms (mean, max) if L>1
      - orthogonality error (mean, max)
      - smoothness (mean)
      - attention relative error (mean)
    """
    device = cayley_pe.rho_scale.device
    D = cayley_pe.head_dim
    N = min(max_positions, cayley_pe.max_seq_len)

    S_list = []
    for Lmat in cayley_pe.L_list:
        S = (Lmat - Lmat.T).detach()
        if hasattr(cayley_pe, "sparsity_mask") and cayley_pe.sparsity_mask is not None:
            S = S * cayley_pe.sparsity_mask.detach()
        S_list.append(S)

    pos_ids = torch.arange(N, device=device, dtype=torch.float32)
    pos_norm = pos_ids / float(cayley_pe.max_seq_len)
    I = torch.eye(D, device=device)
    R_list = []
    for pi in pos_norm:
        if cayley_pe.num_generators == 1:
            A_p = cayley_pe.rho_scale[0] * pi * S_list[0]
        else:
            A_p = torch.zeros_like(S_list[0])
            for l, S in enumerate(S_list):
                A_p = A_p + cayley_pe.rho_scale[l] * pi * S
        R_p = torch.linalg.solve(I + A_p, I - A_p)
        R_list.append(R_p.detach())
    R_list = torch.stack(R_list, dim=0)  # [N,D,D]

    # Metric 1: Relative rotation error
    rel_errors = []
    for i in range(N):
        for j in range(i + 1, N):
            delta = j - i
            if delta >= N:
                continue
            R_rel_ideal = R_list[delta]           # R(p_j - p_i)
            R_rel_actual = R_list[i].T @ R_list[j]
            rel = torch.norm(R_rel_actual - R_rel_ideal, p='fro').item()
            rel_errors.append(rel)
    rel_mean = float(sum(rel_errors) / len(rel_errors)) if rel_errors else None
    rel_max = float(max(rel_errors)) if rel_errors else None

    # Metric 2: Commutator norms (Abelianity)
    comm_vals = []
    Lg = len(S_list)
    if Lg > 1:
        for i in range(Lg):
            for j in range(i + 1, Lg):
                comm = S_list[i] @ S_list[j] - S_list[j] @ S_list[i]
                comm_vals.append(torch.norm(comm, p='fro').item())
        comm_mean = float(sum(comm_vals) / len(comm_vals))
        comm_max = float(max(comm_vals))
    else:
        comm_mean = None
        comm_max = None

    # Metric 3: Orthogonality error
    ortho_vals = []
    for R in R_list:
        ortho_vals.append(torch.norm(R.T @ R - I, p='fro').item())
    ortho_mean = float(sum(ortho_vals) / len(ortho_vals))
    ortho_max = float(max(ortho_vals))

    # Metric 4: Smoothness (adjacent rotation distance)
    smooth_vals = []
    for i in range(N - 1):
        smooth_vals.append(torch.norm(R_list[i + 1] - R_list[i], p='fro').item())
    smooth_mean = float(sum(smooth_vals) / len(smooth_vals)) if smooth_vals else None

    # Metric 5: Attention-relative error
    attn_errors = []
    for _ in range(n_attn_samples):
        q = torch.randn(D, device=device)
        k = torch.randn(D, device=device)
        i = torch.randint(0, N - 1, (1,)).item()
        j = torch.randint(i + 1, N, (1,)).item()
        actual = (R_list[i] @ q).dot(R_list[j] @ k)
        ideal = (R_list[j - i] @ q).dot(k)
        attn_errors.append(float(abs(actual - ideal).item()))
    attn_mean = float(sum(attn_errors) / len(attn_errors))

    metrics = {
        "rel_rot_error_mean": rel_mean,
        "rel_rot_error_max": rel_max,
        "comm_mean": comm_mean,
        "comm_max": comm_max,
        "ortho_error_mean": ortho_mean,
        "ortho_error_max": ortho_max,
        "smoothness_mean": smooth_mean,
        "attn_rel_error_mean": attn_mean,
    }
    return metrics


def train_one_experiment(cfg: Config,
                         pe_type: str,
                         num_generators: int,
                         exp_name: str,
                         strict_model=True,
                         load_optim=False):
    device = cfg.device
    train_loader, test_loader = get_dataloaders(cfg)

    model = ViTWithPE(cfg, pe_type=pe_type, num_generators=num_generators).to(device)

    if "mg" in pe_type:
        base_params = []
        mg_params = []
        for name, p in model.named_parameters():
            if ("cayley_pe" in name) or ("mg_alpha" in name):
                mg_params.append(p)
            else:
                base_params.append(p)

        optimizer = torch.optim.AdamW(
            [
                {
                    "params": base_params,
                    "lr": cfg.lr,
                    "weight_decay": cfg.weight_decay,
                    "betas": (0.9, 0.95),
                },
                {
                    "params": mg_params,
                    "lr": cfg.lr * 1.5,
                    "weight_decay": 0.0,
                    "betas": (0.9, 0.95),
                },
            ]
        )
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay,
            betas=(0.9, 0.95),
        )
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))

    ckpt_path = os.path.join(cfg.ckpt_dir, f"{exp_name}.pt")
    start_epoch = 0
    best_acc = 0.0

    if os.path.exists(ckpt_path):
        start_epoch, best_acc = load_checkpoint(ckpt_path, model, optimizer, device, strict_model, load_optim)

    # lr scheduling
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * cfg.max_epochs
    warmup_steps = int(0.05 * total_steps)

    def lr_lambda(step: int):
        if step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        # cosine decay from 1 -> 0
        progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lr_lambda
    )

    train_start = time.perf_counter()

    for epoch in range(start_epoch, cfg.max_epochs):
        model.train()
        running_loss = 0.0
        total = 0
        correct = 0

        for images, labels in train_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast("cuda", enabled=(device == "cuda")):
                logits = model(images)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            total += batch_size
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()

        train_loss = running_loss / total
        train_acc = correct / total
        val_loss, val_acc = evaluate(model, test_loader, device)

        print(f"[{exp_name}] Epoch {epoch+1}/{cfg.max_epochs} "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

        if getattr(model.blocks[0].attn, "mg_alpha", None) is not None:
            alpha = torch.sigmoid(model.blocks[0].attn.mg_alpha).item()
            print(f"[MG alpha] {alpha:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(ckpt_path, model, optimizer, epoch, best_acc)

        elif (epoch + 1) % cfg.save_every == 0:
            save_checkpoint(ckpt_path, model, optimizer, epoch, best_acc)

    train_end = time.perf_counter()
    total_train_time = train_end - train_start

    print(f"[{exp_name}] Best Val Acc: {best_acc:.4f}")

    model.eval()
    infer_start = time.perf_counter()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            _ = model(images)
    infer_end = time.perf_counter()
    total_infer_time = infer_end - infer_start

    print(f"[{exp_name}] Training time (s): {total_train_time:.2f}")
    print(f"[{exp_name}] Inference time (s, 1x test loader): {total_infer_time:.2f}")

    if pe_type in ("cayley", "mg", "cayley_sparse", "mg_sparse"):
        with torch.no_grad():
            cayley_pe = None
            for blk in model.blocks:
                if hasattr(blk.attn, "cayley_pe") and blk.attn.cayley_pe is not None:
                    cayley_pe = blk.attn.cayley_pe
                    break
            if cayley_pe is not None:
                metrics = compute_relative_position_metrics(
                    cayley_pe,
                    max_positions=min(64, model.num_patches + 1),
                    n_attn_samples=500,
                )
                print(f"[{exp_name}] Relative-position metrics:")
                for k, v in metrics.items():
                    if v is not None:
                        print(f"   {k}: {v:.6e}")
                    else:
                        print(f"   {k}: None (not applicable)")


In [32]:
"""
7. Main: run experiments
"""

EXPERIMENTS = [
    {"pe_type": "rope",   "num_generators": 0, "name": f"{cfg.dataset}_rope"},
    {"pe_type": "cayley", "num_generators": 1, "name": f"{cfg.dataset}_cayley1"},
    {"pe_type": "mg",     "num_generators": 1, "name": f"{cfg.dataset}_mg1"},
    {"pe_type": "mg",     "num_generators": 4, "name": f"{cfg.dataset}_mg4"},
    {"pe_type": "cayley_sparse", "num_generators": 1, "name": f"{cfg.dataset}_cayley1_sparse"},
    {"pe_type": "mg_sparse",     "num_generators": 4, "name": f"{cfg.dataset}_mg4_sparse"}
]

strict_model = True
load_optim = False

for exp in EXPERIMENTS:
    print("-" * 60)
    print(f"Running experiment: {exp['name']} "
          f"(pe_type={exp['pe_type']}, L={exp['num_generators']})")
    print("-" * 60)

    train_one_experiment(
        cfg,
        pe_type=exp["pe_type"],
        num_generators=exp["num_generators"],
        exp_name=exp["name"],
        strict_model=strict_model,
        load_optim=load_optim
    )

------------------------------------------------------------
Running experiment: CIFAR10_rope (pe_type=rope, L=0)
------------------------------------------------------------
[CIFAR10_rope] Best Val Acc: 0.8204
[CIFAR10_rope] Training time (s): 0.00
[CIFAR10_rope] Inference time (s, 1x test loader): 1.51
------------------------------------------------------------
Running experiment: CIFAR10_cayley1 (pe_type=cayley, L=1)
------------------------------------------------------------
[CIFAR10_cayley1] Best Val Acc: 0.8602
[CIFAR10_cayley1] Training time (s): 0.00
[CIFAR10_cayley1] Inference time (s, 1x test loader): 1.75
[CIFAR10_cayley1] Relative-position metrics:
   rel_rot_error_mean: 7.760215e-02
   rel_rot_error_max: 2.704786e-01
   comm_mean: None (not applicable)
   comm_max: None (not applicable)
   ortho_error_mean: 3.048992e-06
   ortho_error_max: 3.848920e-06
   smoothness_mean: 7.500699e-02
   attn_rel_error_mean: 2.091826e+00
--------------------------------------------------