In [None]:
import torch
torch.manual_seed(1746)

# B: batch, H: heads, T/S: query/key lengths, N: total query/key width, P: total value width.
B, H, T, S, N, P = 8, 16, 1000, 1000, 32, 32
Q = torch.randn(B, T, H, N // H).transpose(1, 2)  # (B, H, T, N/H)
K = torch.randn(B, S, H, N // H).transpose(1, 2)  # (B, H, S, N/H)
V = torch.randn(B, S, H, P // H).transpose(1, 2)  # (B, H, S, P/H)
L = torch.tril(torch.ones(T, S)).view(1, 1, T, S)  # causal mask, broadcast over B and H

# Every per-head result below is (B, H, T, P/H). Heads are merged with
# transpose(1, 2) before reshaping so each output row holds the heads of a single
# time step; a bare reshape of (B, H, T, P/H) would mix different time steps.

# Standard masked attention
G = Q @ K.transpose(-2, -1)  # (B, H, T, S)
M = G * L
Y_standard = (M @ V).transpose(1, 2).reshape(B, T, P)

# Single contraction
Y_contraction = torch.einsum('bhtn,bhsn,bhsp,bhts->bhtp', Q, K, V, L).transpose(1, 2).reshape(B, T, P)
print(torch.allclose(Y_standard, Y_contraction, atol=1e-3))

# Step-by-step verification
G_contraction = torch.einsum('bhtn,bhsn->bhts', Q, K)
M_contraction = torch.einsum('bhts,bhts->bhts', G_contraction, L)
Y_step_contraction = torch.einsum('bhts,bhsp->bhtp', M_contraction, V).transpose(1, 2).reshape(B, T, P)
print(torch.allclose(Y_standard, Y_step_contraction, atol=1e-3))

# Derivation for cumsum proof: H_state[t] = sum_s L[t, s] * outer(v_s, k_s).
# (Named H_state so the head count H above is not shadowed.)
Z = torch.einsum('bhsp,bhsn->bhspn', V, K)
H_state = torch.einsum('bhts,bhspn->bhtpn', L, Z)
Y_derivation = torch.einsum('bhtn,bhtpn->bhtp', Q, H_state).transpose(1, 2).reshape(B, T, P)
print(torch.allclose(Y_standard, Y_derivation, atol=1e-3))

# Cumulative sum formulation: with the lower-triangular all-ones L (and T == S),
# the masked sum over s is a prefix sum along the sequence axis.
H_cumsum = torch.cumsum(Z, dim=2)
Y_cumsum = torch.einsum('bhtn,bhtpn->bhtp', Q, H_cumsum).transpose(1, 2).reshape(B, T, P)
print(torch.allclose(Y_standard, Y_cumsum, atol=1e-3))

In [None]:
import torch.nn.functional as F

from torch.utils.data import TensorDataset


def generate_copy(
    num_examples: int = 10,
    num_categories: int = 10,
    copy_len: int = 10,
    blank_len: int = 10,
    selective: bool = False,
    one_hot: bool = True,
    seed: int = 1_337,
    dtype: torch.dtype = torch.bfloat16,
) -> TensorDataset:
    """
    Generate a copy task dataset inspired by Arjovsky, Shah, and Bengio (2016).

    Task Description:
    - Input sequence: [copy_sequence][pre_delim_blanks][delimiter][post_delim_blanks]
    - Output sequence: [blank_tokens][copy_sequence]

    The task requires remembering a categorical sequence for a variable number of time steps.

    Args:
        num_examples: Number of examples to generate.
        num_categories: Number of token categories (must be at least 3).
            - Categories 0 to num_categories-3: Tokens to be copied.
            - Category num_categories-2: Blank token.
            - Category num_categories-1: Delimiter token.
        copy_len: Length of the sequence to be copied.
        blank_len: Number of blank tokens after the delimiter in the input sequence.
        selective: If True, the (num_categories-1) pre-delimiter blanks are scattered at
            random positions among the copied tokens instead of all following them.
        one_hot: If True, one-hot encode inputs and targets over num_categories classes
            and cast them to ``dtype``; otherwise return integer category indices.
        seed: Seed passed to ``torch.manual_seed`` (this reseeds torch's global RNG).
        dtype: dtype of the one-hot tensors; ignored when ``one_hot`` is False.

    Returns:
        TensorDataset with:
            - inputs: Shape (num_examples, seq_len), plus a trailing num_categories axis
              when one_hot is True, where
              seq_len = copy_len + (num_categories-1) + 1 + blank_len.
            - targets: Shape (num_examples, num_categories + blank_len + copy_len) (same
              trailing axis when one_hot is True): blank tokens followed by the copied
              sequence. This length equals seq_len, so inputs and targets align step by step.

    Raises:
        ValueError: If num_categories < 3, which leaves no category for copyable tokens.

    Example:
        >>> dataset = generate_copy(num_examples=10, num_categories=10, copy_len=3, blank_len=10, selective=True, one_hot=False)
        >>> inputs, targets = dataset[0]
        >>> print(inputs.shape, targets.shape)
        torch.Size([23]) torch.Size([23])

    Note:
        A memoryless baseline strategy would predict the blank token for the first
        (num_categories + blank_len) steps, then random tokens, yielding a categorical
        cross entropy of (copy_len * log(num_categories-2)) / (num_categories + blank_len + copy_len).
    """
    if num_categories < 3:
        raise ValueError(
            f"num_categories must be at least 3 (copy tokens + blank + delimiter), got {num_categories}."
        )

    torch.manual_seed(seed)

    # Assign characters
    blank_char = num_categories - 2  # Reserve penultimate token for blank
    delim_char = num_categories - 1  # Reserve last token for delimiter
    num_pre_delim_blanks = num_categories - 1

    # num_examples random sequences with tokens in [0, blank_char)
    to_copy = torch.randint(0, blank_char, (num_examples, copy_len))

    if selective:
        pre_delim_len = copy_len + num_pre_delim_blanks

        def insert_pre_delim_blanks(row):
            # Choose which of the pre_delim_len slots hold blanks; the rest receive the
            # copied tokens in their original order.
            blank_positions = torch.randperm(pre_delim_len)[:num_pre_delim_blanks]
            inserted_row = torch.full((pre_delim_len,), blank_char)
            keep = torch.ones(pre_delim_len, dtype=torch.bool)
            keep[blank_positions] = False
            inserted_row[keep] = row
            return inserted_row

        # One randperm per row, drawn in row order, so results are reproducible for a seed.
        rows = [insert_pre_delim_blanks(row) for row in to_copy]
        # torch.stack rejects an empty list, so handle num_examples == 0 explicitly.
        inputs = torch.stack(rows) if rows else to_copy.new_empty((0, pre_delim_len))
    else:
        pre_delim_blanks = torch.full((num_examples, num_pre_delim_blanks), blank_char)
        inputs = torch.cat((to_copy, pre_delim_blanks), dim=1)

    # Add delimiter and post-delimiter blanks
    delim = torch.full((num_examples, 1), delim_char)
    post_delim_blanks = torch.full((num_examples, blank_len), blank_char)
    inputs = torch.cat((inputs, delim, post_delim_blanks), dim=1)

    # Construct output sequences: blanks while the input is still being read, then the copy.
    blank_output = torch.full((num_examples, num_categories + blank_len), blank_char)
    outputs = torch.cat((blank_output, to_copy), dim=1)

    if one_hot:
        inputs = F.one_hot(inputs, num_classes=num_categories).to(dtype)
        outputs = F.one_hot(outputs, num_classes=num_categories).to(dtype)

    return TensorDataset(inputs, outputs)


In [None]:
import torch

from torch.utils.data import TensorDataset


def generate_document_similarity(
    num_examples: int = 10,
    num_documents: int = 10,
    num_elements: int = 10,
    top_k: int = 2,
    seed: int = 1_337,
    dtype: torch.dtype = torch.bfloat16,
) -> TensorDataset:
    """
    Generate a dataset for the cosine similarity task. The goal is to find the
    pair of documents (tensors) with the highest cosine similarity.

    Alman and Yu (2024) claimed that no sub-quadratic model can solve this task:
    https://arxiv.org/abs/2410.04271.

    Similarities are computed in float64 from the returned (``dtype``-quantized)
    inputs, so low-precision dtypes such as bfloat16 do not corrupt the ranking.

    Args:
        num_examples (int): Number of examples (sets of documents) to generate.
        num_documents (int): Number of documents (tensors) in each example.
        num_elements (int): Number of elements in each document.
        top_k (int): Number of top similar document pairs to identify.
        seed (int): Random seed passed to ``torch.manual_seed`` (reseeds the global RNG).
        dtype (torch.dtype): Data type for the generated input tensors.

    Returns:
        TensorDataset:
            - Inputs: Shape (num_examples, num_documents, num_elements)
            - Targets: Shape (num_examples, top_k, 2) - the (i, j) indices, i < j, of the
              document pairs with the highest cosine similarity, most similar first.

    Raises:
        ValueError: If top_k < 1, num_documents < 2, or top_k exceeds the number of
            unique document pairs.
    """
    # Validate parameters
    if top_k < 1:
        raise ValueError("top_k must be at least 1.")
    if num_documents < 2:
        raise ValueError("num_documents must be at least 2 to form pairs.")
    max_topk = num_documents * (num_documents - 1) // 2
    if top_k > max_topk:
        raise ValueError(f"top_k={top_k} exceeds the maximum number of unique document pairs ({max_topk}).")

    torch.manual_seed(seed)

    inputs = torch.randn((num_examples, num_documents, num_elements), dtype=dtype)

    # Rank in float64 so near-equal similarities are not collapsed by low-precision rounding.
    normalized = torch.nn.functional.normalize(inputs.to(torch.float64), p=2, dim=2)
    cosine_similarity = normalized @ normalized.transpose(1, 2)  # (num_examples, D, D)

    # Each unordered pair (i < j) once: the strict upper triangle.
    triu_indices = torch.triu_indices(num_documents, num_documents, offset=1)  # (2, num_pairs)
    sim_pairs = cosine_similarity[:, triu_indices[0], triu_indices[1]]  # (num_examples, num_pairs)

    _, topk_indices = torch.topk(sim_pairs, top_k, dim=1, largest=True, sorted=True)

    # Map flat pair indices back to (i, j) document indices.
    topk_pairs = triu_indices[:, topk_indices]  # (2, num_examples, top_k)
    targets = topk_pairs.permute(1, 2, 0)  # (num_examples, top_k, 2)

    return TensorDataset(inputs, targets)


In [None]:
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Return the top-K spectral filters of the (seq_len x seq_len) Hankel matrix.

    Args:
        seq_len: Filter length (size of the matrix built by get_hankel).
        K: Number of filters; at most seq_len are returned.
        use_hankel_L: Forwarded to get_hankel to select the Hankel variant.
        device: Device for the eigendecomposition and the returned tensor.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (seq_len, min(K, seq_len)) for non-negative K. Column c is an
        eigenvector scaled by its eigenvalue ** 0.25; columns are ordered by ascending
        eigenvalue.
    """
    Z = get_hankel(seq_len, use_hankel_L).to(device)
    sigma, phi = torch.linalg.eigh(Z)  # eigenvalues in ascending order
    # An explicit start index (instead of [-K:]) keeps K == 0 from selecting every column.
    start = max(sigma.shape[0] - K, 0)
    # Clamp negative eigenvalues (e.g. round-off on near-zero ones) so ** 0.25 cannot yield NaN.
    sigma_k = sigma[start:].clamp_min(0.0)
    phi_k = phi[:, start:] * sigma_k ** 0.25
    return phi_k.to(device=device, dtype=dtype)

def get_hankel(seq_len: int, use_hankel_L: bool = False) -> torch.Tensor:
    """Build the (seq_len x seq_len) float64 Hankel matrix used for spectral filters.

    Entries depend only on s = i + j with 1-based indices i, j in [1, seq_len]:
      - default:       Z[i, j] = 2 / (s^3 - s)
      - use_hankel_L:  Z[i, j] = ((-1)^(s - 2) + 1) * 8 / ((s + 3)(s - 1)(s + 1)),
                       i.e. 16 / ((s + 3)(s - 1)(s + 1)) for even s and 0 for odd s.

    Args:
        seq_len: Matrix size.
        use_hankel_L: Select the "L" variant (interpreted by truthiness).

    Returns:
        Symmetric float64 tensor of shape (seq_len, seq_len).
    """
    idx = torch.arange(1, seq_len + 1, dtype=torch.float64)
    s = idx[:, None] + idx[None, :]

    # Any value is either truthy or falsy, so these two branches are exhaustive.
    if use_hankel_L:
        sgn = (-1.0) ** (s - 2.0) + 1.0  # 2 for even s, 0 for odd s
        denom = (s + 3.0) * (s - 1.0) * (s + 1.0)
        return sgn * (8.0 / denom)

    return 2.0 / (s**3 - s)

def compute_dimensions(n: int) -> tuple[int, int, int]:
    """Compute the padded length used by the tensorized spectral filters.

    Args:
        n: Sequence length; must be greater than 2.

    Returns:
        (T_prime, sqrt_T_prime, k_max), where sqrt_T_prime = ceil(sqrt(n - 2)),
        T_prime = sqrt_T_prime**2 + 2, and k_max = sqrt_T_prime.

    Raises:
        ValueError: If n <= 2.
    """
    import math  # local import: this cell does not import math itself

    if n <= 2:
        raise ValueError("n must be greater than 2")

    # Exact integer ceil(sqrt(m)); float sqrt/ceil can be off by one for large m.
    # ceil(sqrt(x)) == ceil(sqrt(ceil(x))), so non-integer n is handled consistently too.
    m = math.ceil(n - 2)
    root = math.isqrt(m)
    if root * root < m:
        root += 1

    T_prime = root**2 + 2
    sqrt_T_prime = root  # == ceil(sqrt(T_prime - 2)), since T_prime - 2 is a perfect square
    k_max = sqrt_T_prime
    return T_prime, sqrt_T_prime, k_max

def get_tensorized_spectral_filters(
    n: int = 8192,
    k: int = 24,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Compute tensorized spectral filters for given sequence length and filter count.

    Each factor holds the top-k eigenvectors of an (m x m) Hankel matrix, with
    m = ceil(sqrt(n - 2)), scaled by eigenvalue ** 0.25. The result is the Kronecker
    product of the two factors.

    Args:
        n: Sequence length (must be greater than 2).
        k: Number of filters per factor; capped at m.
        use_hankel_L: Hankel_main ⊗ Hankel_L? Default is Hankel_main ⊗ Hankel_main.
        device: Device of the returned tensor.
        dtype: dtype of the returned tensor (the eigendecomposition runs on the
            float64 matrices produced by get_hankel).

    Returns:
        Tensor of shape (m**2, k_eff**2) with k_eff = min(k, m) for non-negative k.
        Note that m**2 >= n - 2 and is generally not equal to n.
    """
    _, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    def top_k_filters(Z: torch.Tensor) -> torch.Tensor:
        # Top-k eigenpairs (eigh returns eigenvalues ascending) scaled by sigma ** 0.25.
        sigma, phi = torch.linalg.eigh(Z)
        # Explicit start index: [-k:] would select every column when k == 0.
        start = max(sigma.shape[0] - k, 0)
        # Clamp negative eigenvalues (e.g. round-off on near-zero ones) so ** 0.25 cannot yield NaN.
        return phi[:, start:] * sigma[start:].clamp_min(0.0) ** 0.25

    phi_i = top_k_filters(get_hankel(sqrt_T_prime))

    # TODO: We may want to use Hankel_L above too if use_hankel_L is true, make another variable for this (mix != use_hankel_L)
    phi_j = top_k_filters(get_hankel(sqrt_T_prime, True)) if use_hankel_L else phi_i

    filters = torch.kron(phi_i, phi_j)
    return filters.to(device=device, dtype=dtype)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Multi-head causal self-attention built on ``F.scaled_dot_product_attention``.

    Args:
        n_embd: Model width; split evenly across heads.
        n_heads: Number of attention heads.
    """

    def __init__(self, n_embd, n_heads):
        super().__init__()
        assert n_embd % n_heads == 0, f"n_embd ({n_embd}) must be divisible by n_heads ({n_heads})"
        self.n_heads = n_heads
        self.head_dim = n_embd // n_heads
        self.wq = nn.Linear(n_embd, n_embd)
        self.wk = nn.Linear(n_embd, n_embd)
        self.wv = nn.Linear(n_embd, n_embd)
        self.o_proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Map x of shape (B, S, n_embd) to (B, S, n_embd); position t attends to positions <= t."""
        bsz, seqlen, dim = x.shape
        Q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)  # (B, N, S, D)
        K = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)  # (B, N, S, D)
        V = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)  # (B, N, S, D)

        # SDPA's boolean attn_mask marks positions that MAY be attended to, so a
        # triu(diagonal=1) mask would expose only future tokens and fully mask the last
        # row (NaN). is_causal=True applies the correct lower-triangular mask.
        attn_output = F.scaled_dot_product_attention(
            Q, K, V,
            dropout_p=0.0,
            is_causal=True,
        )  # (B, N, S, D)

        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, dim)
        return self.o_proj(attn_output)

class MLP(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))`` then dropout.

    Args:
        dim: Input/output width; the hidden width is 4 * dim.
        dropout: Dropout probability applied to the output (default 0.1, as before).
    """

    def __init__(self, dim, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = 4 * dim
        self.gate_proj = nn.Linear(dim, self.hidden_dim)
        self.up_proj = nn.Linear(dim, self.hidden_dim)
        self.down_proj = nn.Linear(self.hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = self.gate_proj(x)
        modulated_gate = F.silu(gate)
        up = self.up_proj(x)
        fuse = modulated_gate * up
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs

class AttentionLayer(nn.Module):
    """Pre-norm residual block: self-attention, then a feed-forward MLP.

    Each sub-layer sees a LayerNorm-ed copy of the running activations, and its
    output is added back onto them.
    """

    def __init__(self, n_embd, n_heads):
        super().__init__()
        # Construction order is kept fixed so seeded parameter initialization is unchanged.
        self.attn = Attention(n_embd, n_heads)
        self.mlp = MLP(n_embd)
        self.attn_norm = nn.LayerNorm(n_embd)
        self.mlp_norm = nn.LayerNorm(n_embd)

    def forward(self, x):
        hidden = x + self.attn(self.attn_norm(x))
        return hidden + self.mlp(self.mlp_norm(hidden))

class Transformer(nn.Module):
    """Token embedding -> stacked AttentionLayer blocks -> LayerNorm -> vocabulary logits.

    Reads ``num_layers``, ``vocab_size``, ``n_embd``, ``n_heads`` and ``dropout`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_layers = config.num_layers
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        blocks = [AttentionLayer(config.n_embd, config.n_heads) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(config.n_embd)
        self.output = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, x):
        """Return logits over ``vocab_size`` for each position of the token-id input ``x``."""
        hidden = self.dropout(self.tok_emb(x))
        for block in self.layers:
            hidden = block(hidden)
        return self.output(self.norm(hidden))


In [None]:
import torch
import torch.nn as nn
import math

class LearnableSpectralFilters(nn.Module):
    """Hold tensorized spectral filters as a single trainable ``nn.Parameter``.

    Args:
        seq_len: Sequence length passed to get_tensorized_spectral_filters as ``n``.
        k: Filter count passed to get_tensorized_spectral_filters.
        use_hankel_L: Forwarded to get_tensorized_spectral_filters.
        device: Device the filters are created on.
        dtype: dtype of the filters.
    """

    def __init__(self, seq_len, k, use_hankel_L=False, device=None, dtype=torch.float32):
        super().__init__()
        self.seq_len = seq_len
        self.k = k
        self.use_hankel_L = use_hankel_L
        self.device = device

        filters = get_tensorized_spectral_filters(
            n=seq_len,
            k=k,
            use_hankel_L=use_hankel_L,
            device=device,
            dtype=dtype,
        )
        # `filters` is already on `device`. Calling .to(device) on the Parameter would
        # return a plain non-leaf tensor whenever a copy is needed, and that tensor would be
        # stored as an ordinary attribute: never registered, trained, or saved.
        self.filters = nn.Parameter(filters)

    def forward(self):
        return self.filters

class SpectralAttention(nn.Module):
    """
    Implements the linear form of structured masked attention using spectral filters.
    According to the linear form:

    Z = contract(V,K) -> (B,S,i,j)
    H = contract(L,Z) -> (B,T,i,j)
    Y = contract(Q,H) -> (B,T,j)

    Then project Y back to (B,T,dim).

    The input is first projected to h = Q_filt.shape[0] features by ``i_proj``. It is
    then multiplied by the filter matrices ``Q_filt``, ``K_filt`` and ``V_filt`` (each
    of shape (h, d)) to get Q, K, V of width d = Q_filt.shape[1]. The three filter
    matrices are registered as trainable parameters.
    """
    def __init__(self, seq_len, n_embd, k, use_hankel_L=False, device=None):
        super().__init__()
        self.seq_len = seq_len
        self.k = k

        def make_filters():
            filters = LearnableSpectralFilters(seq_len, k, use_hankel_L, device).filters
            # Register a fresh leaf Parameter. Storing `filters.transpose(0, 1)` directly
            # would keep a non-leaf view outside the module's parameters, so it would never
            # be optimized or saved, and its gradients would pile up on an orphaned leaf.
            return nn.Parameter(filters.detach().transpose(0, 1).contiguous())

        self.Q_filt = make_filters()
        self.K_filt = make_filters()
        self.V_filt = make_filters()

        # i_proj maps input embedding to the filter-bank input width h
        self.i_proj = nn.Linear(n_embd, self.Q_filt.shape[0])
        # o_proj maps Y, whose width is the filter output width d (= Q_filt.shape[1]),
        # back to the embedding dimension. Using shape[0] only works when the filter
        # matrix happens to be square.
        self.o_proj = nn.Linear(self.Q_filt.shape[1], n_embd)

    def forward(self, x, L):
        """
        Args:
            x: (B, T, n_embd) input activations.
            L: (B, T, S) mask contracted against the sequence axis; typically a
               lower-triangular matrix of ones with S == T.

        Returns:
            (B, T, n_embd)
        """
        x_proj = self.i_proj(x)  # (B, T, h)

        # Apply each filter bank: (B, T, h) x (h, d) -> (B, T, d)
        Q = torch.einsum("bth,hd->btd", x_proj, self.Q_filt)
        K = torch.einsum("bth,hd->btd", x_proj, self.K_filt)
        V = torch.einsum("bth,hd->btd", x_proj, self.V_filt)

        # Step 1: per-position outer products, Z = (B, S, d, d)
        Z = torch.einsum("bsi,bsj->bsij", V, K)

        # Step 2: mask-weighted sum over s, H = (B, T, d, d)
        H = torch.einsum("bts,bsij->btij", L, Z)

        # Step 3: contract Q (B, T, i) with H (B, T, i, j) -> (B, T, d)
        Y = torch.einsum("bti,btij->btj", Q, H)

        # Project back to (B, T, n_embd)
        return self.o_proj(Y)

class SpectralAttentionLayer(nn.Module):
    """Pre-norm residual block: SpectralAttention, then a feed-forward MLP.

    Args:
        seq_len, n_embd, k, use_hankel_L, device: Forwarded to SpectralAttention.
        dropout: Accepted for signature compatibility but currently unused; the MLP is
            constructed as ``MLP(n_embd)`` without a dropout argument.
    """
    def __init__(self, seq_len, n_embd, k, dropout=0.1, use_hankel_L=False, device=None):
        super().__init__()
        self.attn_norm = nn.LayerNorm(n_embd)
        self.mlp_norm = nn.LayerNorm(n_embd)
        self.attn = SpectralAttention(seq_len, n_embd, k, use_hankel_L, device)
        self.mlp = MLP(n_embd)

    def forward(self, x, L):
        """x: (B, T, n_embd) activations; L: mask forwarded to SpectralAttention.forward."""
        x = x + self.attn(self.attn_norm(x), L)
        x = x + self.mlp(self.mlp_norm(x))
        return x

class SpectralTransformer(nn.Module):
    """Token embedding -> stacked SpectralAttentionLayer blocks -> LayerNorm -> logits.

    Reads ``num_layers``, ``vocab_size``, ``n_embd``, ``dropout``, ``seq_len``, ``k`` and
    ``device`` from ``config``. ``use_hankel_L`` is honoured when present (default False).
    """
    def __init__(self, config):
        super().__init__()
        self.num_layers = config.num_layers
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        # Some configs define use_hankel_L; fall back to SpectralAttentionLayer's default otherwise.
        use_hankel_L = getattr(config, "use_hankel_L", False)
        self.layers = nn.ModuleList([
            SpectralAttentionLayer(
                config.seq_len,
                config.n_embd,
                config.k,
                config.dropout,
                use_hankel_L=use_hankel_L,
                device=config.device,
            )
            for _ in range(config.num_layers)
        ])
        self.norm = nn.LayerNorm(config.n_embd)
        self.output = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, x):
        """x: (B, T) token ids -> (B, T, vocab_size) logits."""
        bsz, seq_len = x.size()

        # Construct lower-triangular mask L: (B,T,S)
        L = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
        L = L.unsqueeze(0).expand(bsz, -1, -1)  # (B,T,S)

        x = self.dropout(self.tok_emb(x))  # (B,T,embed)
        for layer in self.layers:
            x = layer(x, L)
        x = self.norm(x)
        return self.output(x)

class Config:
    """Plain attribute container for the SpectralTransformer example below."""

    def __init__(self, vocab_size, seq_len, n_embd, num_layers, k, dropout, device):
        # Store each constructor argument under an attribute of the same name, in order.
        settings = (
            ("vocab_size", vocab_size),
            ("seq_len", seq_len),
            ("n_embd", n_embd),
            ("num_layers", num_layers),
            ("k", k),
            ("dropout", dropout),
            ("device", device),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

# Example usage: one SpectralTransformer forward pass on random token ids.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = Config(vocab_size=12, seq_len=128, n_embd=32, num_layers=2, k=24, dropout=0.0, device=device)
model = SpectralTransformer(config).to(device)
input_ids = torch.randint(0, config.vocab_size, (4, config.seq_len)).to(device)
output = model(input_ids)
print(output.shape)  # (batch, seq_len, vocab_size) = (4, 128, 12)


In [None]:
# -----------------------------
# Required Imports
# -----------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
import math
import time
from tqdm import tqdm
import random
import numpy as np

# -----------------------------
# Utility Functions
# -----------------------------

def format_num(num: int) -> str:
    """Render ``num`` with a B/M/K suffix and two decimals, or unchanged below 1000."""
    scales = ((1_000_000_000, "B"), (1_000_000, "M"), (1000, "K"))
    for threshold, suffix in scales:
        if num >= threshold:
            return f"{num / threshold:.2f}{suffix}"
    return str(num)

def count_non_embedding_params(model):
    """Print the number of trainable parameters, total and excluding names containing 'emb'."""
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    total_params = sum(count for _, count in trainable)
    non_embedding_params = sum(count for param_name, count in trainable if 'emb' not in param_name)

    print(f"Total parameters: {format_num(total_params)}")
    print(f"Parameters (excluding embeddings): {format_num(non_embedding_params)}")

def set_seed(seed: int = 1746):
    """Seed Python's ``random``, NumPy and PyTorch, plus every CUDA device when available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Set the seed for reproducibility
set_seed(seed=1746)

# -----------------------------
# Configuration
# -----------------------------

class Config:
    """Hyper-parameters for the Transformer vs. SpectralTransformer copy-task experiment."""

    def __init__(self):
        # Dataset parameters
        self.num_examples = 30_000
        self.num_categories = 12
        self.copy_len = self.blank_len = 10
        self.selective = False

        # Model parameters
        self.num_layers, self.n_heads, self.n_embd = 2, 8, 512
        self.vocab_size = self.num_categories
        # Same length as generate_copy's inputs: copied tokens, (num_categories - 1)
        # pre-delimiter blanks, one delimiter, then blank_len blanks.
        self.seq_len = self.copy_len + (self.num_categories - 1) + 1 + self.blank_len
        self.k = 24
        self.use_hankel_L = False

        # Training parameters
        self.batch_size, self.lr, self.num_epochs = 1, 1e-2, 1

        # Others: set device
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")
        self.dropout = 0.0

config = Config()

print(f"Using device: {config.device}")

# -----------------------------
# Dataset Preparation
# -----------------------------

# Integer-token copy task (one_hot=False); targets are class indices for CrossEntropyLoss.
dataset = generate_copy(
    num_examples=config.num_examples,
    num_categories=config.num_categories,
    copy_len=config.copy_len,
    blank_len=config.blank_len,
    selective=config.selective,
    one_hot=False,
    seed=1746,
    dtype=torch.long,
)

# 80/20 train/test split; random_split draws from torch's global RNG.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Use pinned memory if on GPU for faster host-to-device transfers
pin_memory = config.device.type == 'cuda'

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, pin_memory=pin_memory)

# -----------------------------
# Model Initialization
# -----------------------------

# Construction order is kept fixed: each model's initialization consumes the global RNG.
transformer = Transformer(config).to(config.device)
spectral_transformer = SpectralTransformer(config).to(config.device)
count_non_embedding_params(transformer)
count_non_embedding_params(spectral_transformer)

# -----------------------------
# Loss and Optimizer
# -----------------------------

criterion = nn.CrossEntropyLoss()
optimizer_transformer = torch.optim.AdamW(transformer.parameters(), lr=config.lr)
optimizer_spectral = torch.optim.AdamW(spectral_transformer.parameters(), lr=config.lr)

# Optional: gradient clipping to improve stability (max grad norm; None disables it)
grad_clip = None

# -----------------------------
# Training and Evaluation Loops
# -----------------------------

def train(model, optimizer, loader, device, desc="Training"):
    """Run one training epoch over ``loader`` and return the mean per-batch loss.

    Uses the module-level ``criterion``, ``grad_clip`` and ``config.vocab_size``.
    Model outputs are flattened to (-1, vocab_size) and targets to (-1,) for the loss.
    Returns 0.0 for an empty loader instead of dividing by zero.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(loader, desc=desc, leave=False)
    for inputs, targets in progress_bar:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(inputs)
        # reshape (not view) tolerates non-contiguous outputs
        loss = criterion(logits.reshape(-1, config.vocab_size), targets.reshape(-1))
        loss.backward()

        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()

        loss_value = loss.item()  # one host/device sync per step
        total_loss += loss_value
        num_batches += 1
        progress_bar.set_postfix({'Loss': loss_value, 'LR': optimizer.param_groups[0]['lr']})

    return total_loss / max(num_batches, 1)

@torch.no_grad()
def evaluate(model, loader, device, desc="Evaluating"):
    """Evaluate ``model`` on ``loader`` and return (mean per-batch loss, per-position accuracy).

    Accuracy counts every flattened target position. Uses the module-level
    ``criterion`` and ``config.vocab_size``. Returns (0.0, 0.0) for an empty loader
    instead of dividing by zero.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    progress_bar = tqdm(loader, desc=desc, leave=False)
    for inputs, targets in progress_bar:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        logits = model(inputs).reshape(-1, config.vocab_size)
        targets = targets.reshape(-1)

        total_loss += criterion(logits, targets).item()
        num_batches += 1

        predicted = logits.argmax(dim=1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)

    avg_loss = total_loss / max(num_batches, 1)
    accuracy = correct / total if total else 0.0
    return avg_loss, accuracy

# -----------------------------
# Training Execution
# -----------------------------

best_test_acc_t, best_test_acc_s = 0.0, 0.0

for epoch in range(1, config.num_epochs + 1):
    epoch_start_time = time.time()

    print(f"\nEpoch {epoch}/{config.num_epochs}")
    print("-" * 30)
    # Baseline first, then the spectral model; both share train_loader and test_loader.
    loss_train_t = train(transformer, optimizer_transformer, train_loader, config.device, desc="Training Transformer")
    loss_test_t, acc_test_t = evaluate(transformer, test_loader, config.device, desc="Evaluating Transformer")

    loss_train_s = train(spectral_transformer, optimizer_spectral, train_loader, config.device, desc="Training SpectralTransformer")
    loss_test_s, acc_test_s = evaluate(spectral_transformer, test_loader, config.device, desc="Evaluating SpectralTransformer")

    epoch_elapsed = time.time() - epoch_start_time

    # Track best accuracy (max keeps the existing value on ties, like a strict '>' update)
    best_test_acc_t = max(best_test_acc_t, acc_test_t)
    best_test_acc_s = max(best_test_acc_s, acc_test_s)

    print(f"Epoch {epoch} completed in {epoch_elapsed:.2f}s")
    print(f"  Transformer    | Train Loss: {loss_train_t:.4f} | Test Loss: {loss_test_t:.4f} | Test Acc: {acc_test_t:.4f} (Best: {best_test_acc_t:.4f})")
    print(f"  SpectralTransf | Train Loss: {loss_train_s:.4f} | Test Loss: {loss_test_s:.4f} | Test Acc: {acc_test_s:.4f} (Best: {best_test_acc_s:.4f})")

print("\nFinal Evaluation on Test Set:")
loss_test_t, acc_test_t = evaluate(transformer, test_loader, config.device, desc="Final Evaluation Transformer")
loss_test_s, acc_test_s = evaluate(spectral_transformer, test_loader, config.device, desc="Final Evaluation SpectralTransformer")

print(f"Transformer    | Test Loss: {loss_test_t:.4f} | Test Accuracy: {acc_test_t:.4f}")
print(f"SpectralTransf | Test Loss: {loss_test_s:.4f} | Test Accuracy: {acc_test_s:.4f}")
