# The Validation Experiment Design

## 1.3 Implementing the confusion detection experiment

In [2]:
def _probe_confusion(model, detector, batch_fn, operation, label,
                     start_step, num_steps, batch_size):
    """Measure confusion metrics for ``operation`` without training ``model``.

    A fresh instrumented forward pass is run on every probe batch so the
    metadata handed to the detector describes *that* batch; reusing metadata
    from an earlier batch would score the wrong activations.

    Args:
        model: instrumented model; called as ``model(x, return_metadata=True)``.
        detector: object exposing ``compute_all_metrics(model, batch, metadata)``.
        batch_fn: batch generator called as ``batch_fn(operation=..., batch_size=...)``.
        operation: operation name forwarded to ``batch_fn``.
        label: human-readable name used in the progress printout.
        start_step: step number of the first probe (used for logging only).
        num_steps: number of probe batches.
        batch_size: examples per batch.

    Returns:
        List of per-step metric dicts, as returned by the detector.
    """
    records = []
    for step in range(start_step, start_step + num_steps):
        batch = batch_fn(operation=operation, batch_size=batch_size)
        _, metadata = model(batch['input'], return_metadata=True)
        metrics = detector.compute_all_metrics(model, batch, metadata)
        records.append(metrics)

        # measurement only: no optimizer step in the probe phases
        if step % 20 == 0:
            print(f"Step {step}, {label} Confusion: "
                  f"Entropy: {metrics['attention_entropy']:.4f}, "
                  f"Loss: {metrics['loss']:.4f}")
    return records


def run_confusion_validation(*, model=None, detector=None, batch_fn=None,
                             phase1_steps=500, probe_steps=200,
                             batch_size=32, lr=1e-3):
    """
    The experiment that either proves or disproves confusion detection:

    - phase 1: train on addition -> low confusion
      (within architecture capability)
    - phase 2a: introduce multiplication -> high confusion
      (proving necessity for new architecture)
    - phase 2b: introduce subtraction -> low confusion
      (proving the same architecture works)

    If we can't distinguish 2a from 2b, the whole project fails.

    All arguments are optional; calling with none reproduces the original
    experiment.

    Args:
        model: instrumented model; defaults to ``InstrumentedTransformer(ModelConfig())``.
        detector: confusion detector; defaults to ``ConfusionDetector()``.
        batch_fn: batch generator; defaults to the notebook-level
            ``generate_arithmetic_batch`` (looked up at call time).
        phase1_steps: number of addition training steps.
        probe_steps: number of measurement batches in each of phases 2a/2b.
        batch_size: examples per batch.
        lr: Adam learning rate for phase 1.

    Returns:
        dict with keys ``'phase1_addition'``, ``'phase2a_multiplication'`` and
        ``'phase2b_subtraction'``, each a list of per-step detector metrics.
    """
    import copy

    import torch
    from torch import nn

    if model is None:
        model = InstrumentedTransformer(ModelConfig())
    if detector is None:
        detector = ConfusionDetector()
    if batch_fn is None:
        # resolved from the notebook namespace; must be defined before calling
        batch_fn = generate_arithmetic_batch

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    results = {
        'phase1_addition': [],
        'phase2a_multiplication': [],
        'phase2b_subtraction': []
    }

    # phase 1: addition training to establish baseline
    print("Phase 1: Training on addition...")
    for step in range(phase1_steps):
        batch = batch_fn(operation='add', batch_size=batch_size)

        # forward w/ instrumentation
        logits, metadata = model(batch['input'], return_metadata=True)
        loss = criterion(logits.view(-1, logits.size(-1)),
                         batch['target'].view(-1))

        # compute confusion metrics
        metrics = detector.compute_all_metrics(model, batch, metadata)
        results['phase1_addition'].append(metrics)

        # train!
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}, "
                  f"Entropy: {metrics['attention_entropy']:.4f}")

    print("Addition learned. Confusion metrics baseline established.\n")

    # Snapshot the post-addition model *before* any probing so phase 2b
    # starts from exactly the same point as phase 2a.
    baseline_model = copy.deepcopy(model)

    # phase 2a: introduce multiplication, which should spike confusion
    print("Phase 2a: introducing multiplication...")
    results['phase2a_multiplication'] = _probe_confusion(
        model, detector, batch_fn, 'multiply', 'Multiplication',
        phase1_steps, probe_steps, batch_size)

    # phase 2b: introduce subtraction (should NOT spike confusion much)
    print("\nPhase 2b: Introducing SUBTRACTION...")
    results['phase2b_subtraction'] = _probe_confusion(
        baseline_model, detector, batch_fn, 'subtract', 'Subtraction',
        phase1_steps, probe_steps, batch_size)

    return results

In [20]:
import torch
import importlib
import model_classes
# Reload so edits to model_classes.py are picked up without restarting the
# kernel; the `from ... import` below then binds the reloaded classes.
importlib.reload(model_classes)
from model_classes import InstrumentedTransformer, ModelConfig, ConfusionDetector

# NOTE(review): run_confusion_validation resolves generate_arithmetic_batch from
# the notebook's globals at call time. The recorded run of this cell failed with
# NameError because it was never defined/imported — define or import it first.
results = run_confusion_validation()

Phase 1: Training on addition...


NameError: name 'generate_arithmetic_batch' is not defined