In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
import math
import random
import numpy as np
from tqdm import tqdm

# --- CONFIGURATION ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VOCAB_SIZE = 512       
D_MODEL = 128          
N_LAYERS = 4
DROPOUT = 0.1
LR = 1e-3
BATCH_SIZE = 64
TRAIN_STEPS = 3000     # Slightly longer training for the harder task

# ==========================================
# 1. DATA GENERATORS
# ==========================================

class SyntheticDataset(IterableDataset):
    def __init__(self, batch_size, seq_len):
        self.batch_size = batch_size
        self.seq_len = seq_len

    def generate_batch(self):
        raise NotImplementedError

    def __iter__(self):
        while True:
            yield self.generate_batch()

class HybridArithmeticDataset(SyntheticDataset):
    """
    NOVEL TASK: Hybrid Recall + Arithmetic
    Context: 'a=3', 'b=5', 'c=2' ...
    Query: 'a + 1 = ?' -> Target: '4'
    
    Why this is novel: 
    It forces the model to retrieve a specific symbol ('3') associated with a key ('a'),
    move it to working memory, and then perform an algorithmic operation (+1) on it.
    """
    def __init__(self, batch_size, seq_len, vocab_size=VOCAB_SIZE, num_vars=8):
        super().__init__(batch_size, seq_len)
        self.vocab_size = vocab_size
        self.num_vars = num_vars

    def generate_batch(self):
        input_ids = torch.zeros((self.batch_size, self.seq_len), dtype=torch.long)
        targets = torch.zeros((self.batch_size, self.seq_len), dtype=torch.long)
        
        # Token Mapping:
        # 0: Pad
        # 1: '='
        # 2: '+'
        # 3: '?' (The "Transform/Solve" token)
        # 10-19: Variables ('a', 'b', etc.)
        # 20-50: Integer Values (0 to 30)
        
        VAR_START = 10
        VAL_START = 20
        
        for i in range(self.batch_size):
            # 1. Create Variable Assignments
            # Map random variables to random integer values (0-9)
            vars_idx = list(range(self.num_vars))
            random.shuffle(vars_idx)
            
            memory = {} # Truth Map: Token_ID -> Integer Value
            seq = []
            
            for v_idx in vars_idx:
                var_token = VAR_START + v_idx
                val_int = random.randint(0, 9) 
                val_token = VAL_START + val_int
                
                memory[var_token] = val_int
                
                # Context: "var = val" -> [VAR, 1, VAL]
                seq.extend([var_token, 1, val_token])
            
            # 2. Generate Query
            # Pick a variable we defined
            if not memory: # Safety check
                input_ids[i] = torch.zeros(self.seq_len)
                continue
                
            target_var_token = random.choice(list(memory.keys()))
            stored_int = memory[target_var_token]
            
            # Operation: +1 (Keep simple for now to test pure mechanism)
            op_token = 2 # '+'
            operand = 1  # We always add 1 in this specific task version
            
            # Calculate Answer
            answer_int = stored_int + operand
            answer_token = VAL_START + answer_int
            
            # Query Sequence: "var + 1 ?" -> [VAR, 2, VAL_1, 3]
            # Let's signify "1" as VAL_START + 1
            one_token = VAL_START + 1
            
            seq.extend([target_var_token, op_token, one_token, 3]) 
            
            # 3. Pad/Crop
            if len(seq) > self.seq_len:
                seq = seq[-(self.seq_len):] # Keep the end (query) if too long
            else:
                # Pad at the start (left padding) or end? 
                # For causal models, right padding is standard if using masks,
                # but simple 0-padding works here.
                seq = seq + [0]*(self.seq_len - len(seq))
            
            input_ids[i] = torch.tensor(seq)
            
            # 4. Target Generation
            # Standard next-token prediction
            t = torch.roll(input_ids[i], -1)
            
            # CRITICAL: We must ensure the token AFTER the '?' (3) is the ANSWER
            try:
                # Find the '?' token
                query_pos = seq.index(3) 
                if query_pos < self.seq_len - 1:
                    t[query_pos] = answer_token
            except ValueError:
                pass
            
            targets[i] = t
            
        return input_ids.to(DEVICE), targets.to(DEVICE)

class MQARDataset(SyntheticDataset):
    """ Baseline Retrieval Task (for comparison) """
    def __init__(self, batch_size, seq_len, vocab_size=VOCAB_SIZE, num_pairs=8):
        super().__init__(batch_size, seq_len)
        self.vocab_size = vocab_size
        self.num_pairs = num_pairs

    def generate_batch(self):
        input_ids = torch.zeros((self.batch_size, self.seq_len), dtype=torch.long)
        targets = torch.zeros((self.batch_size, self.seq_len), dtype=torch.long)
        
        for i in range(self.batch_size):
            keys = torch.randint(0, self.vocab_size, (self.num_pairs,))
            values = torch.randint(0, self.vocab_size, (self.num_pairs,))
            seq = torch.randint(0, self.vocab_size, (self.seq_len,))
            
            for j in range(self.num_pairs):
                seq[2*j] = keys[j]
                seq[2*j+1] = values[j]
            
            query_idx = random.randint(0, self.num_pairs - 1)
            seq[-2] = keys[query_idx]
            seq[-1] = values[query_idx]
            
            input_ids[i] = seq
            targets[i] = torch.roll(seq, -1)
            
        return input_ids.to(DEVICE), targets.to(DEVICE)

# ==========================================
# 2. MODEL ARCHITECTURES
# ==========================================

class BaselineTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.pos_embed = nn.Embedding(512, D_MODEL)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=D_MODEL, nhead=4, batch_first=True)
        self.transformer = nn.TransformerDecoder(self.decoder_layer, num_layers=N_LAYERS)
        self.head = nn.Linear(D_MODEL, VOCAB_SIZE)

    def forward(self, x):
        b, t = x.shape
        pos = torch.arange(t, device=x.device).unsqueeze(0)
        emb = self.embed(x) + self.pos_embed(pos)
        mask = nn.Transformer.generate_square_subsequent_mask(t).to(x.device)
        out = self.transformer(emb, memory=emb, tgt_mask=mask)
        return self.head(out)

class BaselineLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.lstm = nn.LSTM(input_size=D_MODEL, hidden_size=D_MODEL, num_layers=N_LAYERS, batch_first=True)
        self.head = nn.Linear(D_MODEL, VOCAB_SIZE)

    def forward(self, x):
        emb = self.embed(x)
        out, _ = self.lstm(emb)
        return self.head(out)

# --- PURE PYTORCH MAMBA ---
class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_inner = int(expand * d_model)
        self.dt_rank = math.ceil(d_model / 16)
        
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner, out_channels=self.d_inner,
            bias=True, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1,
        )
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.act = nn.SiLU()

    def parallel_scan(self, u, delta, A, B, C, D):
        batch, seq_len, d_inner = u.shape
        d_state = A.shape[1]
        deltaA = torch.exp(torch.einsum('b l d, d n -> b l d n', delta, A))
        deltaB_u = torch.einsum('b l d, b l n, b l d -> b l d n', delta, B, u)
        
        x = torch.zeros((batch, d_inner, d_state), device=u.device)
        ys = []
        for t in range(seq_len):
            x = deltaA[:, t] * x + deltaB_u[:, t]
            y = torch.einsum('b d n, b l n -> b d', x, C[:, t].unsqueeze(1))
            ys.append(y)
        y = torch.stack(ys, dim=1)
        y = y + u * D
        return y

    def forward(self, x):
        batch, seq_len, _ = x.shape
        xz = self.in_proj(x)
        x_in, z = xz.chunk(2, dim=-1)
        x_in = x_in.transpose(1, 2)
        x_conv = self.conv1d(x_in)[:, :, :seq_len]
        x_conv = self.act(x_conv).transpose(1, 2)
        x_dbl = self.x_proj(x_conv)
        delta, B, C = torch.split(x_dbl, [self.dt_rank, 16, 16], dim=-1)
        delta = F.softplus(self.dt_proj(delta))
        A = -torch.exp(self.A_log)
        y = self.parallel_scan(x_conv, delta, A, B, C, self.D)
        return self.out_proj(y * self.act(z))

class MambaWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.layers = nn.ModuleList([
            MambaBlock(d_model=D_MODEL, d_state=16, expand=2)
            for _ in range(N_LAYERS) 
        ])
        self.norm_f = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, VOCAB_SIZE)

    def forward(self, x):
        x = self.embed(x)
        for layer in self.layers:
            x = x + layer(x)
        x = self.norm_f(x)
        return self.head(x)

# ==========================================
# 3. RUNNER
# ==========================================

def get_accuracy(logits, targets):
    preds = torch.argmax(logits, dim=-1)
    correct = (preds == targets).float()
    return correct.mean().item()

def run_experiment(model_type, task_name):
    print(f"\n--- Running {task_name} with {model_type} ---")
    
    # Dataset Selection
    if task_name == "MQAR":
        dataset = MQARDataset(BATCH_SIZE, seq_len=64) 
    elif task_name == "Hybrid":
        dataset = HybridArithmeticDataset(BATCH_SIZE, seq_len=64)
    else:
        raise ValueError("Unknown Task")
    
    loader = DataLoader(dataset, batch_size=None)
    
    # Model Selection
    if model_type == "Transformer":
        model = BaselineTransformer().to(DEVICE)
    elif model_type == "LSTM":
        model = BaselineLSTM().to(DEVICE)
    elif model_type == "Mamba":
        model = MambaWrapper().to(DEVICE)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    
    model.train()
    iterator = iter(loader)
    
    for step in tqdm(range(TRAIN_STEPS)):
        inputs, targets = next(iterator)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits.view(-1, VOCAB_SIZE), targets.view(-1))
        
        # FIX: Gradient Clipping for Mamba Stability
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        if step % 500 == 0:
            print(f"Step {step}: Loss {loss.item():.4f}")

    # Final Eval (Last Token)
    model.eval()
    eval_acc = 0
    with torch.no_grad():
        for _ in range(50):
            inputs, targets = next(iterator)
            logits = model(inputs)
            
            # For Hybrid, we want to check the token AFTER the '?'
            # In our dataset generation, we ensure '?' is at a specific spot or end
            # We will just check the last predicted token against the last target token
            last_pred = torch.argmax(logits[:, -1, :], dim=-1)
            last_target = targets[:, -1]
            
            eval_acc += (last_pred == last_target).float().mean().item()
            
    print(f"Final Test Accuracy: {eval_acc / 50:.4f}")

# ==========================================
# 4. EXECUTION
# ==========================================

if __name__ == "__main__":
    # 1. Run Baseline Retrieval (MQAR)
    # run_experiment("Transformer", "MQAR")
    # run_experiment("Mamba", "MQAR")
    
    # 2. Run Novel Hybrid Task
    run_experiment("Transformer", "Hybrid")
    run_experiment("Mamba", "Hybrid")
    
    # 3. LSTM (Negative Control)
    run_experiment("LSTM", "Hybrid")


--- Running Hybrid with Transformer ---


  0%|          | 6/3000 [00:00<02:53, 17.22it/s]

Step 0: Loss 5.7731


 17%|█▋        | 506/3000 [00:12<01:05, 38.31it/s]

Step 500: Loss 0.0315


 34%|███▎      | 1005/3000 [00:26<00:53, 37.37it/s]

Step 1000: Loss 0.0314


 50%|█████     | 1505/3000 [00:39<00:38, 39.25it/s]

Step 1500: Loss 0.0275


 67%|██████▋   | 2005/3000 [00:51<00:22, 43.69it/s]

Step 2000: Loss 0.0288


 84%|████████▎ | 2510/3000 [01:02<00:10, 44.89it/s]

Step 2500: Loss 0.0275


100%|██████████| 3000/3000 [01:12<00:00, 41.57it/s]


Final Test Accuracy: 1.0000

--- Running Hybrid with Mamba ---


  0%|          | 1/3000 [00:00<15:11,  3.29it/s]

Step 0: Loss 6.3830


 17%|█▋        | 501/3000 [02:18<12:43,  3.27it/s]

Step 500: Loss 0.5328


 33%|███▎      | 1001/3000 [04:39<08:03,  4.13it/s]

Step 1000: Loss 0.5231


 50%|█████     | 1501/3000 [06:59<07:52,  3.17it/s]

Step 1500: Loss 0.5178


 67%|██████▋   | 2001/3000 [09:29<04:46,  3.48it/s]

Step 2000: Loss 0.5096


 83%|████████▎ | 2501/3000 [11:34<02:01,  4.11it/s]

Step 2500: Loss 0.5096


100%|██████████| 3000/3000 [13:45<00:00,  3.63it/s]


Final Test Accuracy: 0.1234

--- Running Hybrid with LSTM ---


  0%|          | 14/3000 [00:00<00:52, 56.55it/s]

Step 0: Loss 6.2252


 17%|█▋        | 515/3000 [00:04<00:20, 119.27it/s]

Step 500: Loss 1.2391


 34%|███▍      | 1020/3000 [00:08<00:16, 117.35it/s]

Step 1000: Loss 1.1745


 51%|█████     | 1523/3000 [00:12<00:12, 117.91it/s]

Step 1500: Loss 1.1312


 67%|██████▋   | 2014/3000 [00:17<00:08, 119.06it/s]

Step 2000: Loss 0.7013


 84%|████████▍ | 2524/3000 [00:21<00:04, 118.84it/s]

Step 2500: Loss 0.6205


100%|██████████| 3000/3000 [00:25<00:00, 118.26it/s]


Final Test Accuracy: 0.1244
