In [6]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json
import time
import math
import seaborn as sns
from datetime import datetime
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import random

# Set device configuration
device = torch.device('cpu')
print(f"Using device: {device}")

# --- Utility Functions ---
def make_dir(path):
    """Create directory if it doesn't exist"""
    if not os.path.exists(path):
        os.makedirs(path)
        print(f"Created directory: {path}")

def set_seed(seed=42):
    """Set random seed for reproducibility"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Positional Embeddings and Attention Variants ---
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embeddings as in the original Transformer paper"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, seq_len, device):
        pos = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        i = torch.arange(self.dim, dtype=torch.float32, device=device)
        freqs = 1 / (10000 ** (2 * (i // 2) / self.dim))
        enc = pos * freqs
        enc[:, 0::2] = torch.sin(enc[:, 0::2])
        enc[:, 1::2] = torch.cos(enc[:, 1::2])
        return enc.unsqueeze(0)

class RotaryPositionalEmbedding(nn.Module):
    """Rotary Positional Embedding (RoPE) - https://arxiv.org/abs/2104.09864"""
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]

def apply_rotary_pos_emb(x, freqs):
    """Apply rotary positional embedding to input tensor"""
    # x shape: [batch_size, num_heads, seq_len, head_dim]
    # freqs shape: [1, 1, seq_len, dim]
    
    # Extract the sequence length from x
    seq_len = x.shape[-2]
    
    # Ensure freqs matches the sequence length of x
    freqs = freqs[:, :, :seq_len, :]
    
    head_dim = x.shape[-1]
    if freqs.shape[-1] != head_dim:
        # If dimension doesn't match, we need to interpolate or truncate
        # For simplicity, we'll take the first head_dim elements
        freqs = freqs[..., :head_dim]
    
    # Reshape for rotary application - handle both even and odd dimensions
    x_rot = x.float()
    x1 = x_rot[..., 0::2]  # Even indices
    x2 = x_rot[..., 1::2]  # Odd indices
    
    # Extract cosine and sine components for the relevant dimensions
    cos = torch.cos(freqs)[..., 0::2]  # Take even indices for cos
    sin = torch.sin(freqs)[..., 0::2]  # Take even indices for sin
    
    # Make sure dimensions match
    if cos.shape[-1] != x1.shape[-1]:
        # Truncate if necessary
        min_dim = min(cos.shape[-1], x1.shape[-1])
        cos = cos[..., :min_dim]
        sin = sin[..., :min_dim]
        x1 = x1[..., :min_dim]
        x2 = x2[..., :min_dim]
    
    # Apply rotary transformation
    rotated_x1 = x1 * cos - x2 * sin
    rotated_x2 = x2 * cos + x1 * sin
    
    rotated_x = torch.zeros_like(x_rot)
    rotated_x[..., 0::2] = rotated_x1
    rotated_x[..., 1::2] = rotated_x2
    
    return rotated_x.type_as(x)

class ALiBiBias(nn.Module):
    """ALiBi (Attention with Linear Biases) - https://arxiv.org/abs/2108.12409"""
    def __init__(self, heads):
        super().__init__()
        self.heads = heads
        slopes = torch.tensor(self._get_slopes(heads))
        self.register_buffer("slopes", slopes.view(1, heads, 1, 1))

    def _get_slopes(self, n):
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            return [start*(start**i) for i in range(n)]
        
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (get_slopes_power_of_2(closest_power_of_2) + 
                    self._get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    def forward(self, seq_len, device):
        # Create ALiBi bias matrix with proper dimensions [1, heads, seq_len, seq_len]
        bias = torch.arange(seq_len, device=device).float()
        bias = bias - torch.arange(seq_len, device=device).unsqueeze(1)
        bias = torch.abs(bias)  # Distance matrix
        bias = bias.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]
        bias = bias * self.slopes  # Apply slopes per head
        return -bias  # Negative because we want to bias against distant tokens

# Multihead with support for ALiBi, RoPE, recurrence, and sparse (Longformer/Performer)
class MultiHeadAttention(nn.Module):
    """Multi-head attention with various positional encoding and attention variants"""
    def __init__(self, dim, heads=4, pos_mode='learned', attn_type='full', memory_len=0, window_size=32):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.pos_mode = pos_mode
        self.attn_type = attn_type
        self.memory_len = memory_len
        self.window_size = window_size
        self.head_dim = dim // heads

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

        # Initialize positional embeddings based on mode
        if pos_mode == 'learned':
            self.pos_emb = nn.Embedding(2048, dim)
        elif pos_mode == 'sinusoidal':
            self.pos_emb = SinusoidalPosEmb(dim)
        elif pos_mode == 'rope':
            self.rope = RotaryPositionalEmbedding(self.head_dim)
        
        # Initialize attention bias if using ALiBi
        if attn_type == 'alibi':
            self.alibi = ALiBiBias(heads)
        else:
            self.alibi = None

        # Store window size for Longformer
        self.window_size = window_size

    def forward(self, x, mask=None, memory=None):
        B, T, C = x.shape
        H = self.heads
        
        # Apply query, key, value projections
        qkv = self.to_qkv(x).reshape(B, T, 3, H, C // H).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # [B, H, T, D]

        # Apply positional encodings
        if self.pos_mode == 'rope':
            freqs = self.rope(T, x.device)
            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)
        elif self.pos_mode == 'learned':
            # For learned positional embeddings, we need to create position indices
            pos_indices = torch.arange(T, device=x.device).unsqueeze(0)  # [1, T]
            pos_emb = self.pos_emb(pos_indices).reshape(1, T, H, C // H).permute(0, 2, 1, 3)
            q = q + pos_emb
            k = k + pos_emb
        elif self.pos_mode == 'sinusoidal':
            # For sinusoidal, use the SinusoidalPosEmb class
            pos_emb = self.pos_emb(T, x.device).reshape(1, T, H, C // H).permute(0, 2, 1, 3)
            q = q + pos_emb
            k = k + pos_emb

        # Handle memory (for recurrent transformers)
        if memory is not None:
            k_mem, v_mem = memory
            k = torch.cat([k_mem, k], dim=2)
            v = torch.cat([v_mem, v], dim=2)
            mem_len = k_mem.shape[2]
        else:
            mem_len = 0

        # Compute attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if self.attn_type == 'longformer':
            seq_len = attn_scores.size(-1)
            
            # Create band mask efficiently
            band_mask = torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool)
            
            # Create band mask using broadcasting
            i = torch.arange(seq_len, device=x.device)
            j = torch.arange(seq_len, device=x.device)
            
            # Create mask for positions within window
            distance = torch.abs(i[:, None] - j[None, :])
            band_mask = distance <= (self.window_size // 2)
            
            # Convert to attention mask (0 for allowed, -inf for masked)
            band_mask = band_mask.float()
            band_mask = (1.0 - band_mask) * -1e9  # 0 for allowed, -inf for masked
            
            attn_scores = attn_scores + band_mask.unsqueeze(0).unsqueeze(0)

        # Apply ALiBi bias
        elif self.attn_type == 'alibi' and self.alibi is not None:
            # ALiBi bias should be applied to the attention scores
            # The bias matrix should have shape [1, heads, T, T + mem_len]
            alibi_bias = self.alibi(T + mem_len, x.device)
            attn_scores = attn_scores + alibi_bias

        
        # Apply input mask if provided
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        # Compute attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        # Apply attention dropout during training
        if self.training:
            attn_weights = F.dropout(attn_weights, p=0.1)

        # Compute output
        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).reshape(B, T, C)
        
        # Return output and memory if needed
        if memory is not None:
            new_memory = (k[:, :, -self.memory_len:], v[:, :, -self.memory_len:])
            return self.to_out(out), new_memory, attn_weights
        
        return self.to_out(out), attn_weights

# Transformer Block
class TransformerBlock(nn.Module):
    """Standard transformer block with pre-normalization"""
    def __init__(self, dim, heads=4, ff_mult=4, pos_mode='learned', 
                 attn_type='full', memory_len=0, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(dim, heads, pos_mode, attn_type, memory_len)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * ff_mult, dim),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, memory=None):
        # Self-attention with residual connection
        if memory is not None:
            attn_out, new_memory, attn_weights = self.attn(self.norm1(x), mask, memory)
            x = x + self.dropout(attn_out)
        else:
            attn_out, attn_weights = self.attn(self.norm1(x), mask)
            x = x + self.dropout(attn_out)
            new_memory = None
        
        # Feed-forward with residual connection
        x = x + self.ff(self.norm2(x))
        
        return x, new_memory, attn_weights

# Transformer Model with optional recurrence/memory
class TransformerModel(nn.Module):
    """Transformer model with configurable architecture and attention mechanisms"""
    def __init__(self, dim=64, depth=4, heads=4, pos_mode='learned', 
                 attn_type='full', memory_len=0, vocab_size=10, ff_mult=4, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.memory_len = memory_len
        self.vocab_size = vocab_size
        
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)
        
        self.blocks = nn.ModuleList([
            TransformerBlock(dim, heads, ff_mult, pos_mode, attn_type, memory_len, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(dim)
        self.to_logits = nn.Linear(dim, vocab_size)
        
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, mask=None, memory=None):
        # Token embeddings
        x = self.token_emb(x)
        x = self.dropout(x)
        
        # Initialize memory if not provided
        if memory is None and self.memory_len > 0:
            B, T = x.shape[0], x.shape[1]
            memory = [
                torch.zeros(B, self.blocks[0].attn.heads, self.memory_len, self.dim // self.blocks[0].attn.heads, 
                           device=x.device) for _ in range(self.depth)
            ]
        
        # Process through transformer blocks
        new_memory = [] if self.memory_len > 0 else None
        attention_weights = []
        for i, blk in enumerate(self.blocks):
            if memory is not None:
                x, mem, attn = blk(x, mask, memory[i] if memory is not None else None)
                if new_memory is not None:
                    new_memory.append(mem)
                attention_weights.append(attn)
            else:
                x, _, attn = blk(x, mask)
                attention_weights.append(attn)
        
        # Final normalization and output
        x = self.norm(x)
        logits = self.to_logits(x)
        
        return logits, new_memory, attention_weights

# --- Algorithmic Task Datasets ---
class AlgorithmicTaskDataset:
    def __init__(self, task='sorting', seq_length=10, vocab_size=10, num_samples=10000):
        self.task = task
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.num_samples = num_samples
        
    def generate_sorting_data(self):
        data = []
        targets = []
        
        for _ in range(self.num_samples):
            # Generate random sequence
            seq = np.random.randint(1, self.vocab_size, self.seq_length)
            
            # Target is the sorted sequence
            target = np.sort(seq)
            
            data.append(seq)
            targets.append(target)
            
        return np.array(data), np.array(targets)
    
    def generate_addition_data(self):
        data = []
        targets = []
        
        for _ in range(self.num_samples):
            # Generate two numbers to add
            num1 = np.random.randint(0, self.vocab_size // 2, self.seq_length // 2)
            num2 = np.random.randint(0, self.vocab_size // 2, self.seq_length // 2)
            
            # Create input sequence: [num1, +, num2]
            seq = np.concatenate([num1, [self.vocab_size - 1], num2])
            
            # Calculate target (sum)
            total = np.sum(num1) + np.sum(num2)
            # Represent as a sequence of digits
            target_str = str(total)
            target = np.array([int(d) for d in target_str] + [0] * (self.seq_length - len(target_str)))
            
            data.append(seq)
            targets.append(target)
            
        return np.array(data), np.array(targets)
    
    def generate_data(self):
        if self.task == 'sorting':
            return self.generate_sorting_data()
        elif self.task == 'addition':
            return self.generate_addition_data()
        else:
            raise ValueError(f"Unknown task: {self.task}")
    
    def get_dataset(self):
        data, targets = self.generate_data()
        return torch.tensor(data, dtype=torch.long), torch.tensor(targets, dtype=torch.long)

# --- Training and Evaluation Functions ---
def train_model(model, train_data, train_targets, val_data, val_targets, 
                epochs=10, batch_size=32, lr=0.001, device='cpu', warmup_epochs=3):
    """Train the transformer model on algorithmic tasks"""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)
    
    train_dataset = torch.utils.data.TensorDataset(train_data, train_targets)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    val_dataset = torch.utils.data.TensorDataset(val_data, val_targets)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    train_losses = []
    val_losses = []
    val_accuracies = []
    grad_norms = []
    
    start_time = time.time()
    
    # Warm-up scheduler
    warmup_scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, 
        lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup_epochs)
    )
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        epoch_grad_norms = []
        
        for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output, _, _ = model(data)
            
            # Handle sequence length mismatch - truncate to the minimum length
            seq_len = min(output.size(1), target.size(1))
            output = output[:, :seq_len, :]  # [batch_size, seq_len, vocab_size]
            target = target[:, :seq_len]     # [batch_size, seq_len]
            
            # Reshape for loss calculation
            output = output.reshape(-1, output.size(-1))
            target = target.reshape(-1)
            
            loss = criterion(output, target)
            loss.backward()
            
            # Track gradient norms
            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** 0.5
            epoch_grad_norms.append(total_norm)
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_train_loss = total_loss / len(train_loader)
        avg_grad_norm = np.mean(epoch_grad_norms)
        train_losses.append(avg_train_loss)
        grad_norms.append(avg_grad_norm)
        
        # Update learning rate (warmup for first few epochs)
        if epoch < warmup_epochs:
            warmup_scheduler.step()
        
        # Validation
        model.eval()
        val_loss = 0
        all_predictions = []
        all_targets = []
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output, _, _ = model(data)
                
                # Handle sequence length mismatch
                seq_len = min(output.size(1), target.size(1))
                output = output[:, :seq_len, :]
                target = target[:, :seq_len]
                
                # Reshape for loss calculation
                output_flat = output.reshape(-1, output.size(-1))
                target_flat = target.reshape(-1)
                
                loss = criterion(output_flat, target_flat)
                val_loss += loss.item()
                
                # Calculate accuracy
                predictions = torch.argmax(output, dim=-1)
                all_predictions.extend(predictions.cpu().numpy().flatten())
                all_targets.extend(target.cpu().numpy().flatten())
        
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        accuracy = accuracy_score(all_targets, all_predictions)
        val_accuracies.append(accuracy)
        
        # Update learning rate based on validation loss
        if epoch >= warmup_epochs:
            scheduler.step(avg_val_loss)
        
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Accuracy: {accuracy:.4f} | Grad Norm: {avg_grad_norm:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    training_time = time.time() - start_time
    
    return train_losses, val_losses, val_accuracies, grad_norms, training_time

def test_length_generalization(model, original_length, new_lengths, task='sorting', 
                              vocab_size=10, device='cpu'):
    """Test model generalization to longer sequences"""
    results = {}
    
    for length in new_lengths:
        # Generate test data with new length
        if task == 'sorting':
            test_data = np.random.randint(1, vocab_size, (100, length))
            test_targets = np.sort(test_data, axis=1)
        elif task == 'addition':
            test_data = []
            test_targets = []
            for _ in range(100):
                num1 = np.random.randint(0, vocab_size // 2, length // 2)
                num2 = np.random.randint(0, vocab_size // 2, length // 2)
                seq = np.concatenate([num1, [vocab_size - 1], num2])
                total = np.sum(num1) + np.sum(num2)
                target_str = str(total)
                # Make target the same length as input for consistency
                target = np.array([int(d) for d in target_str] + [0] * (length - len(target_str)))
                test_data.append(seq)
                test_targets.append(target)
            test_data = np.array(test_data)
            test_targets = np.array(test_targets)
        
        test_data = torch.tensor(test_data, dtype=torch.long).to(device)
        test_targets = torch.tensor(test_targets, dtype=torch.long).to(device)
        
        # Test the model
        model.eval()
        all_predictions = []
        all_targets = []
        
        with torch.no_grad():
            # Process in batches to avoid memory issues
            batch_size = 32
            for i in range(0, len(test_data), batch_size):
                batch_data = test_data[i:i+batch_size]
                batch_targets = test_targets[i:i+batch_size]
                
                output, _, _ = model(batch_data)
                
                # Handle sequence length mismatch
                seq_len = min(output.size(1), batch_targets.size(1))
                output = output[:, :seq_len, :]
                batch_targets = batch_targets[:, :seq_len]
                
                predictions = torch.argmax(output, dim=-1)
                
                all_predictions.extend(predictions.cpu().numpy().flatten())
                all_targets.extend(batch_targets.cpu().numpy().flatten())
        
        accuracy = accuracy_score(all_targets, all_predictions)
        results[length] = accuracy
        print(f"Length: {length}, Accuracy: {accuracy:.4f}")
    
    return results

def visualize_attention(model, data_sample, task='sorting', out_dir='./', device='cpu'):
    """Visualize attention patterns for a sample"""
    model.eval()
    with torch.no_grad():
        output, _, attention_weights = model(data_sample.unsqueeze(0).to(device))
    
    # Plot attention weights for each layer and head
    num_layers = len(attention_weights)
    num_heads = attention_weights[0].size(1)  # [batch_size, num_heads, seq_len, seq_len]
    
    fig, axes = plt.subplots(num_layers, num_heads, figsize=(4*num_heads, 4*num_layers))
    
    if num_layers == 1:
        axes = [axes]
    if num_heads == 1:
        for i in range(num_layers):
            axes[i] = [axes[i]]
    
    for layer in range(num_layers):
        for head in range(num_heads):
            attn_map = attention_weights[layer][0, head].cpu().numpy()
            ax = axes[layer][head]
            sns.heatmap(attn_map, ax=ax, cmap='viridis')
            ax.set_title(f'Layer {layer+1}, Head {head+1}')
    
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, f'{task}_attention.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    return attention_weights

# --- Experiment Setup and Reporting ---
def run_algorithmic_experiment(cfg, out_dir):
    """Run a complete algorithmic task experiment with given configuration"""
    print(f"Starting experiment: {cfg['name']}")
    print(f"Output directory: {out_dir}")
    make_dir(out_dir)
    
    # Set seed for reproducibility
    set_seed(cfg.get('seed', 42))
    
    # Save configuration
    with open(os.path.join(out_dir, 'config.json'), 'w') as f:
        json.dump(cfg, f, indent=2)
    
    # Generate dataset
    print(f"Generating {cfg['task']} dataset with sequence length {cfg['seq_length']}...")
    dataset = AlgorithmicTaskDataset(
        task=cfg['task'], 
        seq_length=cfg['seq_length'], 
        vocab_size=cfg['vocab_size'],
        num_samples=10000
    )
    data, targets = dataset.get_dataset()
    
    # Split into train and validation
    split_idx = int(0.8 * len(data))
    train_data, val_data = data[:split_idx], data[split_idx:]
    train_targets, val_targets = targets[:split_idx], targets[split_idx:]
    
    # Initialize model
    model = TransformerModel(
        dim=cfg['dim'],
        depth=cfg['depth'], 
        heads=cfg['heads'],
        pos_mode=cfg['pos_mode'], 
        attn_type=cfg['attn_type'],
        memory_len=cfg['memory_len'], 
        vocab_size=cfg['vocab_size'],
        ff_mult=cfg.get('ff_mult', 4),
        dropout=cfg.get('dropout', 0.1)
    ).to(device)
    
    print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")
    
    # Train model
    print("Training model...")
    train_losses, val_losses, val_accuracies, grad_norms, training_time = train_model(
        model, train_data, train_targets, val_data, val_targets,
        epochs=cfg['epochs'], batch_size=cfg['batch_size'], lr=cfg['lr'], 
        device=device, warmup_epochs=cfg.get('warmup_epochs', 3)
    )
    
    # Plot training history
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    
    plt.subplot(1, 3, 2)
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Validation Accuracy')
    
    plt.subplot(1, 3, 3)
    plt.plot(grad_norms, label='Gradient Norm')
    plt.xlabel('Epoch')
    plt.ylabel('Norm')
    plt.legend()
    plt.title('Gradient Norms')
    
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, f"{cfg['task']}_training_history.png"), dpi=300, bbox_inches='tight')
    plt.close()
    
    # Visualize attention for a sample
    print("Visualizing attention patterns...")
    sample_idx = 0 
    attention_weights = visualize_attention(
        model, val_data[sample_idx], task=cfg['task'], out_dir=out_dir, device=device
    )
    
    # Test length generalization
    print("Testing length generalization...")
    new_lengths = [cfg['seq_length'], cfg['seq_length'] + 2, cfg['seq_length'] + 5, cfg['seq_length'] + 10]
    generalization_results = test_length_generalization(
        model, cfg['seq_length'], new_lengths, task=cfg['task'], 
        vocab_size=cfg['vocab_size'], device=device
    )
    
    # Plot generalization results
    plt.figure(figsize=(8, 6))
    lengths = list(generalization_results.keys())
    accuracies = list(generalization_results.values())
    
    plt.bar(range(len(lengths)), accuracies, tick_label=lengths)
    plt.xlabel('Sequence Length')
    plt.ylabel('Accuracy')
    plt.title(f"Length Generalization on {cfg['task']} task")
    plt.ylim(0, 1)
    
    for i, acc in enumerate(accuracies):
        plt.text(i, acc + 0.01, f'{acc:.2f}', ha='center')
    
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, f"{cfg['task']}_generalization.png"), dpi=300, bbox_inches='tight')
    plt.close()
    
    # Generate performance report
    report = {
        "training_summary": {
            "final_train_loss": train_losses[-1],
            "final_val_loss": val_losses[-1],
            "final_val_accuracy": val_accuracies[-1],
            "val_accuracies_history": val_accuracies, 
            "training_time": training_time,
            "best_val_accuracy": max(val_accuracies),
            "convergence_epoch": val_accuracies.index(max(val_accuracies)) + 1,
        },
        "generalization_results": generalization_results,
        "model_architecture": {
            "dimension": cfg['dim'],
            "depth": cfg['depth'],
            "heads": cfg['heads'],
            "position_encoding": cfg['pos_mode'],
            "attention_type": cfg['attn_type'],
            "memory_length": cfg['memory_len'],
            "vocabulary_size": cfg['vocab_size'],
            "parameters_total": sum(p.numel() for p in model.parameters()),
        },
        "training_config": cfg
    }
    
    # Save report as JSON
    with open(os.path.join(out_dir, 'performance_report.json'), 'w') as f:
        json.dump(report, f, indent=2)
    
    # Save as text summary
    with open(os.path.join(out_dir, 'summary.txt'), 'w') as f:
        f.write("ALGORITHMIC TASK EXPERIMENT REPORT\n")
        f.write("==================================\n\n")
        f.write(f"Experiment: {cfg['name']}\n")
        f.write(f"Task: {cfg['task']}\n")
        f.write(f"Training sequence length: {cfg['seq_length']}\n")
        f.write(f"Final training loss: {train_losses[-1]:.4f}\n")
        f.write(f"Final validation loss: {val_losses[-1]:.4f}\n")
        f.write(f"Final validation accuracy: {val_accuracies[-1]:.4f}\n")
        f.write(f"Best validation accuracy: {max(val_accuracies):.4f}\n")
        f.write(f"Training time: {training_time:.2f} seconds\n")
        f.write(f"Model parameters: {report['model_architecture']['parameters_total']:,}\n")
        f.write(f"Attention type: {cfg['attn_type']}\n")
        f.write(f"Position encoding: {cfg['pos_mode']}\n\n")
        
        f.write("Length Generalization Results:\n")
        for length, accuracy in generalization_results.items():
            f.write(f"  Length {length}: {accuracy:.4f}\n")
    
    # Save model
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': cfg,
        'history': {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_accuracies': val_accuracies,
            'grad_norms': grad_norms,
        }
    }, os.path.join(out_dir, 'model.pth'))
    
    print(f"Experiment completed! Results saved to {out_dir}")
    return report

# --- Main Execution ---
if __name__ == "__main__":
    # Create output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_out_dir = f"./experiments/{timestamp}"
    make_dir(base_out_dir)


    NUM_EPOCHS = 10
    TASK = 'sorting' # 'sorting' or 'addition'
    SEQ_LENGTH = 20
    # Define experiment configurations
    experiments = [
        {
            'name': 'baseline',
            'task': TASK,
            'seq_length': SEQ_LENGTH,
            'vocab_size': 10,
            'dim': 64,
            'depth': 4,
            'heads': 4,
            'pos_mode': 'learned',
            'attn_type': 'full',
            'memory_len': 0,
            'epochs': NUM_EPOCHS,
            'batch_size': 32,
            'lr': 0.001,
            'seed': 42
        },
        {
            'name': 'rope',
            'task': TASK,
            'seq_length': SEQ_LENGTH,
            'vocab_size': 10,
            'dim': 64,
            'depth': 4,
            'heads': 4,
            'pos_mode': 'rope',
            'attn_type': 'full',
            'memory_len': 0,
            'epochs': NUM_EPOCHS,
            'batch_size': 32,
            'lr': 0.001,
            'seed': 42
        },
        {
            'name': 'alibi',
            'task': TASK,
            'seq_length': SEQ_LENGTH,
            'vocab_size': 10,
            'dim': 64,
            'depth': 4,
            'heads': 4,
            'pos_mode': 'sinusoidal',
            'attn_type': 'alibi',
            'memory_len': 0,
            'epochs': NUM_EPOCHS,
            'batch_size': 32,
            'lr': 0.001,
            'seed': 42
        },
        {
            'name': 'longformer',
            'task': TASK,
            'seq_length': SEQ_LENGTH,
            'vocab_size': 10,
            'dim': 64,
            'depth': 4,
            'heads': 4,
            'pos_mode': 'learned',
            'attn_type': 'longformer',
            'memory_len': 0,
            'window_size': 5,
            'epochs': NUM_EPOCHS,
            'batch_size': 32,
            'lr': 0.001,
            'seed': 42
        }
    ]
    
    all_results = {}
    
    # Run each experiment
    for i, cfg in enumerate(experiments):
        print(f"\n{'='*50}")
        print(f"Running experiment {i+1}/{len(experiments)}: {cfg['name']}")
        print(f"{'='*50}")
        
        out_dir = os.path.join(base_out_dir, cfg['name'])
        result = run_algorithmic_experiment(cfg, out_dir)
        all_results[cfg['name']] = result
        
        # Clear GPU memory between experiments
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    # Generate comparative analysis
    print("\nGenerating comparative analysis...")
    
    # Compare generalization across experiments
    plt.figure(figsize=(12, 8))
    
    for exp_name, result in all_results.items():
        gen_results = result['generalization_results']
        lengths = list(gen_results.keys())
        accuracies = list(gen_results.values())
        
        plt.plot(lengths, accuracies, marker='o', label=exp_name, linewidth=2)
    
    plt.xlabel('Sequence Length')
    plt.ylabel('Accuracy')
    plt.title('Length Generalization Comparison Across Models')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.ylim(0, 1)
    
    plt.tight_layout()
    plt.savefig(os.path.join(base_out_dir, 'comparative_generalization.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    # Create summary table
    summary_data = []
    for exp_name, result in all_results.items():
        summary_data.append({
            'Experiment': exp_name,
            'Task': result['training_config']['task'],
            'Attention Type': result['training_config']['attn_type'],
            'Position Encoding': result['training_config']['pos_mode'],
            'Best Val Accuracy': result['training_summary']['best_val_accuracy'],
            'Training Time (s)': result['training_summary']['training_time'],
            'Parameters': result['model_architecture']['parameters_total'],
            'Trained Length': result['training_config']['seq_length'],
            **{f'Length {l} Acc': acc for l, acc in result['generalization_results'].items()}
        })
    
    df_summary = pd.DataFrame(summary_data)
    df_summary.to_csv(os.path.join(base_out_dir, 'experiment_summary.csv'), index=False)
    
    # Print summary
    print("\nEXPERIMENT SUMMARY")
    print("=" * 80)
    print(df_summary.to_string(index=False))
    
    print(f"\nAll experiments completed! Results saved to {base_out_dir}")

Using device: cpu
Created directory: ./experiments/20250903_191527

Running experiment 1/4: baseline
Starting experiment: baseline
Output directory: ./experiments/20250903_191527/baseline
Created directory: ./experiments/20250903_191527/baseline
Generating sorting dataset with sequence length 20...
Model has 724,874 parameters
Training model...


Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:14<00:00, 17.78it/s]


Epoch 1/10 | Train Loss: 1.6405 | Val Loss: 0.4717 | Val Accuracy: 0.8445 | Grad Norm: 2.9867 | LR: 0.000667


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:14<00:00, 17.27it/s]


Epoch 2/10 | Train Loss: 0.4062 | Val Loss: 0.1454 | Val Accuracy: 0.9653 | Grad Norm: 5.5576 | LR: 0.001000


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:14<00:00, 16.81it/s]


Epoch 3/10 | Train Loss: 0.2570 | Val Loss: 0.1335 | Val Accuracy: 0.9512 | Grad Norm: 5.9799 | LR: 0.001000


Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:14<00:00, 16.81it/s]


Epoch 4/10 | Train Loss: 0.1945 | Val Loss: 0.0631 | Val Accuracy: 0.9886 | Grad Norm: 5.4222 | LR: 0.001000


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:15<00:00, 15.94it/s]


Epoch 5/10 | Train Loss: 0.1585 | Val Loss: 0.0488 | Val Accuracy: 0.9900 | Grad Norm: 4.9270 | LR: 0.001000


Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:19<00:00, 12.88it/s]


Epoch 6/10 | Train Loss: 0.1378 | Val Loss: 0.0303 | Val Accuracy: 0.9967 | Grad Norm: 4.3415 | LR: 0.001000


Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:20<00:00, 12.44it/s]


Epoch 7/10 | Train Loss: 0.1297 | Val Loss: 0.0278 | Val Accuracy: 0.9966 | Grad Norm: 4.5416 | LR: 0.001000


Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:21<00:00, 11.49it/s]


Epoch 8/10 | Train Loss: 0.1165 | Val Loss: 0.0272 | Val Accuracy: 0.9953 | Grad Norm: 4.2401 | LR: 0.001000


Epoch 9/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:19<00:00, 12.70it/s]


Epoch 9/10 | Train Loss: 0.1037 | Val Loss: 0.0215 | Val Accuracy: 0.9970 | Grad Norm: 3.8751 | LR: 0.001000


Epoch 10/10: 100%|████████████████████████████████████████████████████████████████████| 250/250 [00:19<00:00, 13.03it/s]


Epoch 10/10 | Train Loss: 0.0933 | Val Loss: 0.0231 | Val Accuracy: 0.9949 | Grad Norm: 3.6912 | LR: 0.001000
Visualizing attention patterns...
Testing length generalization...
Length: 20, Accuracy: 0.9955
Length: 22, Accuracy: 0.6464
Length: 25, Accuracy: 0.3484
Length: 30, Accuracy: 0.2230
Experiment completed! Results saved to ./experiments/20250903_191527/baseline

Running experiment 2/4: rope
Starting experiment: rope
Output directory: ./experiments/20250903_191527/rope
Created directory: ./experiments/20250903_191527/rope
Generating sorting dataset with sequence length 20...
Model has 200,586 parameters
Training model...


Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:20<00:00, 12.36it/s]


Epoch 1/10 | Train Loss: 0.9896 | Val Loss: 0.2901 | Val Accuracy: 0.9033 | Grad Norm: 4.4458 | LR: 0.000667


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:20<00:00, 12.24it/s]


Epoch 2/10 | Train Loss: 0.3590 | Val Loss: 0.1576 | Val Accuracy: 0.9506 | Grad Norm: 7.4462 | LR: 0.001000


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:20<00:00, 12.15it/s]


Epoch 3/10 | Train Loss: 0.2732 | Val Loss: 0.1348 | Val Accuracy: 0.9577 | Grad Norm: 8.6026 | LR: 0.001000


Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:22<00:00, 11.03it/s]


Epoch 4/10 | Train Loss: 0.2187 | Val Loss: 0.0703 | Val Accuracy: 0.9871 | Grad Norm: 7.2765 | LR: 0.001000


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:24<00:00, 10.03it/s]


Epoch 5/10 | Train Loss: 0.1850 | Val Loss: 0.0477 | Val Accuracy: 0.9936 | Grad Norm: 6.4653 | LR: 0.001000


Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:24<00:00, 10.02it/s]


Epoch 6/10 | Train Loss: 0.1650 | Val Loss: 0.0576 | Val Accuracy: 0.9875 | Grad Norm: 6.5915 | LR: 0.001000


Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:24<00:00, 10.10it/s]


Epoch 7/10 | Train Loss: 0.1503 | Val Loss: 0.0450 | Val Accuracy: 0.9894 | Grad Norm: 6.0703 | LR: 0.001000


Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:21<00:00, 11.56it/s]


Epoch 8/10 | Train Loss: 0.1320 | Val Loss: 0.0501 | Val Accuracy: 0.9845 | Grad Norm: 5.6749 | LR: 0.001000


Epoch 9/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:24<00:00, 10.41it/s]


Epoch 9/10 | Train Loss: 0.1264 | Val Loss: 0.1724 | Val Accuracy: 0.9188 | Grad Norm: 6.0963 | LR: 0.001000


Epoch 10/10: 100%|████████████████████████████████████████████████████████████████████| 250/250 [00:20<00:00, 12.11it/s]


Epoch 10/10 | Train Loss: 0.1165 | Val Loss: 0.0265 | Val Accuracy: 0.9935 | Grad Norm: 5.9424 | LR: 0.001000
Visualizing attention patterns...
Testing length generalization...
Length: 20, Accuracy: 0.9945
Length: 22, Accuracy: 0.9100
Length: 25, Accuracy: 0.7800
Length: 30, Accuracy: 0.6197
Experiment completed! Results saved to ./experiments/20250903_191527/rope

Running experiment 3/4: alibi
Starting experiment: alibi
Output directory: ./experiments/20250903_191527/alibi
Created directory: ./experiments/20250903_191527/alibi
Generating sorting dataset with sequence length 20...
Model has 200,586 parameters
Training model...


Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:20<00:00, 12.30it/s]


Epoch 1/10 | Train Loss: 0.8489 | Val Loss: 0.2741 | Val Accuracy: 0.9356 | Grad Norm: 4.6834 | LR: 0.000667


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:20<00:00, 12.30it/s]


Epoch 2/10 | Train Loss: 0.3785 | Val Loss: 0.1691 | Val Accuracy: 0.9502 | Grad Norm: 7.2009 | LR: 0.001000


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:18<00:00, 13.81it/s]


Epoch 3/10 | Train Loss: 0.2975 | Val Loss: 0.1268 | Val Accuracy: 0.9699 | Grad Norm: 7.7801 | LR: 0.001000


Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:16<00:00, 15.39it/s]


Epoch 4/10 | Train Loss: 0.2329 | Val Loss: 0.1046 | Val Accuracy: 0.9724 | Grad Norm: 6.2757 | LR: 0.001000


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:16<00:00, 15.61it/s]


Epoch 5/10 | Train Loss: 0.1904 | Val Loss: 0.0736 | Val Accuracy: 0.9858 | Grad Norm: 5.9633 | LR: 0.001000


Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:24<00:00, 10.28it/s]


Epoch 6/10 | Train Loss: 0.1669 | Val Loss: 0.0554 | Val Accuracy: 0.9892 | Grad Norm: 5.7036 | LR: 0.001000


Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:17<00:00, 14.50it/s]


Epoch 7/10 | Train Loss: 0.1563 | Val Loss: 0.0761 | Val Accuracy: 0.9752 | Grad Norm: 5.7439 | LR: 0.001000


Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:16<00:00, 15.49it/s]


Epoch 8/10 | Train Loss: 0.1380 | Val Loss: 0.0525 | Val Accuracy: 0.9850 | Grad Norm: 5.0279 | LR: 0.001000


Epoch 9/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:18<00:00, 13.20it/s]


Epoch 9/10 | Train Loss: 0.1288 | Val Loss: 0.0602 | Val Accuracy: 0.9813 | Grad Norm: 4.6817 | LR: 0.001000


Epoch 10/10: 100%|████████████████████████████████████████████████████████████████████| 250/250 [00:21<00:00, 11.80it/s]


Epoch 10/10 | Train Loss: 0.1214 | Val Loss: 0.0274 | Val Accuracy: 0.9965 | Grad Norm: 5.2100 | LR: 0.001000
Visualizing attention patterns...
Testing length generalization...
Length: 20, Accuracy: 0.9990
Length: 22, Accuracy: 0.6977
Length: 25, Accuracy: 0.4184
Length: 30, Accuracy: 0.2450
Experiment completed! Results saved to ./experiments/20250903_191527/alibi

Running experiment 4/4: longformer
Starting experiment: longformer
Output directory: ./experiments/20250903_191527/longformer
Created directory: ./experiments/20250903_191527/longformer
Generating sorting dataset with sequence length 20...
Model has 724,874 parameters
Training model...


Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:17<00:00, 14.12it/s]


Epoch 1/10 | Train Loss: 1.6477 | Val Loss: 0.4469 | Val Accuracy: 0.8673 | Grad Norm: 2.9739 | LR: 0.000667


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:16<00:00, 15.44it/s]


Epoch 2/10 | Train Loss: 0.4105 | Val Loss: 0.1580 | Val Accuracy: 0.9625 | Grad Norm: 5.5020 | LR: 0.001000


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:18<00:00, 13.26it/s]


Epoch 3/10 | Train Loss: 0.2722 | Val Loss: 0.1531 | Val Accuracy: 0.9371 | Grad Norm: 6.0405 | LR: 0.001000


Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:17<00:00, 14.07it/s]


Epoch 4/10 | Train Loss: 0.2070 | Val Loss: 0.1014 | Val Accuracy: 0.9707 | Grad Norm: 5.2857 | LR: 0.001000


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:19<00:00, 13.10it/s]


Epoch 5/10 | Train Loss: 0.1688 | Val Loss: 0.0678 | Val Accuracy: 0.9823 | Grad Norm: 4.5516 | LR: 0.001000


Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:16<00:00, 14.99it/s]


Epoch 6/10 | Train Loss: 0.1476 | Val Loss: 0.0494 | Val Accuracy: 0.9891 | Grad Norm: 4.5069 | LR: 0.001000


Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:16<00:00, 14.84it/s]


Epoch 7/10 | Train Loss: 0.1377 | Val Loss: 0.0386 | Val Accuracy: 0.9929 | Grad Norm: 4.9098 | LR: 0.001000


Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:16<00:00, 14.81it/s]


Epoch 8/10 | Train Loss: 0.1196 | Val Loss: 0.0293 | Val Accuracy: 0.9958 | Grad Norm: 4.2416 | LR: 0.001000


Epoch 9/10: 100%|█████████████████████████████████████████████████████████████████████| 250/250 [00:16<00:00, 15.57it/s]


Epoch 9/10 | Train Loss: 0.1099 | Val Loss: 0.0214 | Val Accuracy: 0.9967 | Grad Norm: 4.1396 | LR: 0.001000


Epoch 10/10: 100%|████████████████████████████████████████████████████████████████████| 250/250 [00:18<00:00, 13.26it/s]


Epoch 10/10 | Train Loss: 0.1034 | Val Loss: 0.0247 | Val Accuracy: 0.9949 | Grad Norm: 4.2451 | LR: 0.001000
Visualizing attention patterns...
Testing length generalization...
Length: 20, Accuracy: 0.9965
Length: 22, Accuracy: 0.6332
Length: 25, Accuracy: 0.3768
Length: 30, Accuracy: 0.2073
Experiment completed! Results saved to ./experiments/20250903_191527/longformer

Generating comparative analysis...

EXPERIMENT SUMMARY
Experiment    Task Attention Type Position Encoding  Best Val Accuracy  Training Time (s)  Parameters  Trained Length  Length 20 Acc  Length 22 Acc  Length 25 Acc  Length 30 Acc
  baseline sorting           full           learned           0.996950         182.433403      724874              20         0.9955       0.646364         0.3484       0.223000
      rope sorting           full              rope           0.993625         238.043910      200586              20         0.9945       0.910000         0.7800       0.619667
     alibi sorting          alibi    