# The Validation Experiment Design

## 1.3 Implementing the confusion detection experiment

In [1]:
import torch
import torch.nn as nn
import numpy as np
import copy
from model_classes import InstrumentedTransformer, ModelConfig, ConfusionDetector
from data_generation import generate_arithmetic_batch, VOCAB

def _measure_phase(model, detector, criterion, operation, num_steps,
                   batch_size, log_every=50):
    """
    Measure confusion metrics on freshly generated `operation` batches
    without updating the model.

    The caller is responsible for putting `model` in the desired mode
    (train/eval) beforehand; this helper only disables gradient tracking.

    Args:
        model: called as ``model(batch['input'], return_metadata=True)`` and
            expected to return ``(logits, metadata)``.
        detector: object exposing ``compute_all_metrics(model, batch, metadata, loss)``.
        criterion: loss callable applied to flattened logits and targets.
        operation: operation name forwarded to ``generate_arithmetic_batch``.
        num_steps: number of batches to measure.
        batch_size: batch size forwarded to ``generate_arithmetic_batch``.
        log_every: print a progress line every this many steps.

    Returns:
        List with one metrics object per step, in step order.
    """
    phase_metrics = []
    with torch.no_grad():
        for step in range(num_steps):
            batch = generate_arithmetic_batch(operation=operation, batch_size=batch_size)

            logits, metadata = model(batch['input'], return_metadata=True)
            # Flatten every leading dim so each position is scored as one sample.
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                batch['target'].view(-1)
            )

            metrics = detector.compute_all_metrics(model, batch, metadata, loss)
            phase_metrics.append(metrics)

            if step % log_every == 0:
                print(f"Step {step:3d} | Loss: {loss.item():.4f} | "
                      f"Attn Entropy: {metrics['attention_entropy']:.4f}")
    return phase_metrics


def run_confusion_validation(*, train_steps=500, eval_steps=200, batch_size=32,
                             lr=1e-3, seed=None, eval_baseline_steps=0):
    """
    The experiment that proves (or disproves) confusion detection.

    Trains a fresh InstrumentedTransformer on addition, then measures the
    detector's metrics on multiplication and subtraction batches with the
    weights frozen (eval mode, no gradients).

    Called with no arguments this behaves exactly like the original
    hard-coded experiment (500 training steps, 200 measurement steps per
    challenge, batch size 32, Adam with lr=1e-3, no seeding).

    Args:
        train_steps: number of addition training steps (Phase 1).
        eval_steps: number of measurement steps for each of Phase 2a / 2b.
        batch_size: batch size used for every generated batch.
        lr: Adam learning rate.
        seed: if not None, seeds the global `random`, NumPy and torch RNGs
            before the model is built. NOTE(review): whether this also makes
            the generated batches deterministic depends on which RNG
            `generate_arithmetic_batch` uses - confirm in data_generation.
        eval_baseline_steps: if > 0, after training the model is measured on
            addition in eval mode for this many steps and the result is
            stored under the extra key 'phase1_addition_eval'. Phase-1
            metrics are recorded while the weights are still changing and
            before `model.eval()` is called, whereas Phase 2 uses the final
            weights in eval mode; this baseline gives a like-for-like
            reference. Defaults to 0 so the returned keys are unchanged.

    Returns:
        Dict mapping phase name ('phase1_addition', 'phase2a_multiplication',
        'phase2b_subtraction', plus 'phase1_addition_eval' when requested)
        to the list of per-step metrics objects returned by
        `ConfusionDetector.compute_all_metrics`.
    """
    if seed is not None:
        # Seed before constructing the model so weight init is covered too.
        import random
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Initialize
    config = ModelConfig()
    config.vocab_size = VOCAB.vocab_size  # Use actual vocab size
    model = InstrumentedTransformer(config)
    detector = ConfusionDetector()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # One shared loss module; padding positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=VOCAB.pad_id)

    results = {
        'phase1_addition': [],
        'phase2a_multiplication': [],
        'phase2b_subtraction': []
    }

    # Phase 1: Addition Training
    print("="*60)
    print("Phase 1: Training on Addition")
    print("="*60)

    for step in range(train_steps):
        batch = generate_arithmetic_batch(operation='add', batch_size=batch_size)

        # Forward pass
        logits, metadata = model(batch['input'], return_metadata=True)
        loss = criterion(
            logits.view(-1, logits.size(-1)),
            batch['target'].view(-1)
        )

        # Compute confusion metrics on the pre-update weights; kept ahead of
        # zero_grad()/backward() to preserve the original ordering, since the
        # detector receives the model and the loss tensor.
        metrics = detector.compute_all_metrics(model, batch, metadata, loss)
        results['phase1_addition'].append(metrics)

        # Train
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Step {step:3d} | Loss: {loss.item():.4f} | "
                  f"Attn Entropy: {metrics['attention_entropy']:.4f}")

    print("\n✓ Addition baseline established")
    if results['phase1_addition']:
        # Guarded so train_steps=0 does not raise IndexError.
        print(f"  Final loss: {results['phase1_addition'][-1]['loss']:.4f}")

    model.eval()  # Important: eval mode for testing

    if eval_baseline_steps > 0:
        # Optional like-for-like reference: trained task, final weights, eval mode.
        print("\n" + "="*60)
        print("Phase 1b: Testing on ADDITION (Eval-mode Baseline)")
        print("="*60)
        results['phase1_addition_eval'] = _measure_phase(
            model, detector, criterion, 'add', eval_baseline_steps, batch_size
        )

    # Phase 2a: Multiplication (NO training, just measure confusion)
    print("\n" + "="*60)
    print("Phase 2a: Testing on MULTIPLICATION (Architectural Challenge)")
    print("="*60)
    results['phase2a_multiplication'] = _measure_phase(
        model, detector, criterion, 'multiply', eval_steps, batch_size
    )

    # Phase 2b: Subtraction (NO training, just measure confusion)
    print("\n" + "="*60)
    print("Phase 2b: Testing on SUBTRACTION (Parametric Challenge)")
    print("="*60)
    results['phase2b_subtraction'] = _measure_phase(
        model, detector, criterion, 'subtract', eval_steps, batch_size
    )

    print("\n" + "="*60)
    print("Experiment Complete!")
    print("="*60)

    return results

In [6]:
# Driver cell: refresh the model code and run the experiment.
import torch  # already imported in the first cell; harmless repeat
import importlib
import model_classes
# Re-execute model_classes so edits made to it on disk are picked up without
# restarting the kernel. Note that data_generation is NOT reloaded here.
importlib.reload(model_classes)
# Re-bind the class names in this namespace to the reloaded definitions.
# run_confusion_validation looks these names up when it is called, so the
# re-import must come before the call below.
from model_classes import InstrumentedTransformer, ModelConfig, ConfusionDetector

# Runs the full train-then-measure experiment; progress is printed to stdout.
results = run_confusion_validation()

Phase 1: Training on Addition
Step   0 | Loss: 3.1543 | Attn Entropy: 2.1979
Step 100 | Loss: 1.2235 | Attn Entropy: 1.7412
Step 200 | Loss: 1.0045 | Attn Entropy: 1.6309
Step 300 | Loss: 0.8446 | Attn Entropy: 1.5665
Step 400 | Loss: 0.8942 | Attn Entropy: 1.5330

✓ Addition baseline established
  Final loss: 0.7900

Phase 2a: Testing on MULTIPLICATION (Architectural Challenge)
Step   0 | Loss: 2.9072 | Attn Entropy: 1.4525
Step  50 | Loss: 2.7322 | Attn Entropy: 1.4705
Step 100 | Loss: 3.2209 | Attn Entropy: 1.4417
Step 150 | Loss: 3.1024 | Attn Entropy: 1.4484

Phase 2b: Testing on SUBTRACTION (Parametric Challenge)
Step   0 | Loss: 4.0218 | Attn Entropy: 1.4939
Step  50 | Loss: 3.8147 | Attn Entropy: 1.4842
Step 100 | Loss: 4.8311 | Attn Entropy: 1.5326
Step 150 | Loss: 3.8786 | Attn Entropy: 1.5384

Experiment Complete!
