# Interactive Storyworld Training Notebook

This notebook provides an interactive walkthrough for training and analyzing the combined sparse autoencoder (SAE) and reinforcement learning (RL) system.

## Setup

In [None]:
import json
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

# Import our modules
from sae_narrative_features import (
    SparseAutoencoder,
    StoryWorldStateExtractor,
    train_sae_on_rollouts,
    FeatureAffordanceAnalyzer
)

from rl_training_infrastructure import (
    StoryWorldVerifiers,
    StoryWorldRLTrainer,
    RLConfig
)

from integrated_training_pipeline import (
    IterativeTrainingPipeline,
    FeatureAwareVerifiers
)

from quickstart_demo import generate_synthetic_storyworld, generate_synthetic_rollout

# Seed every RNG the pipeline may draw from so "Restart & Run All" reproduces
# the same synthetic storyworlds, rollouts, and SAE weight initialisation.
# NOTE(review): assumes the quickstart_demo generators use `random` and/or
# `np.random` — confirm; seeding both is harmless either way.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
print(f"Random seed: {SEED}")

## Part 1: Generate Synthetic Data

In [None]:
# Build the synthetic storyworld corpus that every later section depends on,
# using generate_synthetic_storyworld from quickstart_demo.
n_storyworlds = 20
preview_chars = 500  # how much of the first storyworld's JSON to echo

storyworlds = [generate_synthetic_storyworld() for _ in range(n_storyworlds)]

print(f"Generated {len(storyworlds)} storyworlds")
print("\nExample storyworld structure:")
preview = json.dumps(storyworlds[0], indent=2)
# Only mark the preview as truncated when something was actually cut off.
print(preview[:preview_chars] + ("..." if len(preview) > preview_chars else ""))

In [None]:
# Collect several simulated rollouts per storyworld. Empty rollouts are dropped
# because they contribute no states for the SAE to train on.
rollouts_per_storyworld = 5
rollouts = []
for sw in storyworlds:
    for _ in range(rollouts_per_storyworld):
        rollout = generate_synthetic_rollout(sw)
        if len(rollout) > 0:
            rollouts.append(rollout)

# Fail fast with a clear message instead of a NaN mean, an empty histogram, and
# a harder-to-diagnose failure later inside SAE training.
assert rollouts, "No non-empty rollouts were generated; check generate_synthetic_rollout"

rollout_lengths = [len(r) for r in rollouts]
print(f"Generated {len(rollouts)} rollouts")
print(f"Average rollout length: {np.mean(rollout_lengths):.1f} steps")

# Visualize rollout length distribution
fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(rollout_lengths, bins=20, edgecolor='black')
ax.set_xlabel('Rollout Length (steps)')
ax.set_ylabel('Count')
ax.set_title('Distribution of Rollout Lengths')
plt.show()

## Part 2: Train Sparse Autoencoder

In [None]:
# Fit the sparse autoencoder on the states contained in the collected rollouts.
# Produces the trained model, the state dataset it was fit on, and a training
# history whose curves are plotted in the next cell.
sae_hparams = dict(
    latent_dim=128,
    sparsity_coef=0.05,
    n_epochs=30,
    batch_size=16,
)
sae, dataset, history = train_sae_on_rollouts(rollouts, **sae_hparams)

print("\nTraining complete!")

In [None]:
# Draw the four SAE training curves on a 2x2 grid, one history key per panel.
l0_target = 20  # reference level for the number of active features (L0)
panel_specs = [
    ('total_loss', 'Total Loss'),
    ('mse_loss', 'Reconstruction Loss (MSE)'),
    ('sparsity_loss', 'Sparsity Loss (L1)'),
    ('l0_norm', 'Feature Sparsity (L0)'),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
for panel_ax, (history_key, panel_title) in zip(axes.flat, panel_specs):
    panel_ax.plot(history[history_key])
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Epoch')
    if history_key == 'l0_norm':
        # Only the sparsity panel gets the dashed target line and a legend.
        panel_ax.axhline(y=l0_target, color='r', linestyle='--', label='Target (~20)')
        panel_ax.legend()
    panel_ax.grid(True)

plt.tight_layout()
plt.show()

print(f"Final MSE: {history['mse_loss'][-1]:.4f}")
print(f"Final L0: {history['l0_norm'][-1]:.2f} active features")

## Part 3: Analyze Features

In [None]:
# Rank SAE features by how strongly their activations correlate with the
# affordance count, then look at the distribution of those correlations.
strong_corr = 0.3  # |correlation| above which a feature counts as strongly coupled

analyzer = FeatureAffordanceAnalyzer(sae, dataset)
correlations = analyzer.compute_feature_affordance_correlation()
top_features = analyzer.identify_top_features(n_top=20, method='correlation')

print("Top 10 Features by Affordance Correlation:")
for feat_idx, feat_score in top_features[:10]:
    print(f"  Feature {feat_idx:3d}: {feat_score:.3f}")

abs_correlations = np.abs(correlations)

# Visualize correlation distribution
plt.figure(figsize=(10, 5))
plt.hist(abs_correlations, bins=30, edgecolor='black')
plt.xlabel('|Correlation| with Affordance Count')
plt.ylabel('Number of Features')
plt.title('Distribution of Feature-Affordance Correlations')
plt.axvline(x=strong_corr, color='r', linestyle='--', label='Strong correlation (>0.3)')
plt.legend()
plt.grid(True)
plt.show()

print(f"\nFeatures with |correlation| > 0.3: {np.sum(abs_correlations > strong_corr)}")

In [None]:
# Inspect a single encoded state: which SAE features fire for it and how
# closely the SAE reconstructs the original state vector.
assert len(dataset) > 0, "SAE dataset is empty; nothing to inspect"
# Clamp the index so the cell still runs when the dataset is small.
sample_idx = min(10, len(dataset) - 1)
sample = dataset[sample_idx]

# Put the input on the same device as the SAE weights and add a batch axis.
# NOTE(review): assumes `sae` is a torch nn.Module (it exposes state_dict()).
sae_device = next(sae.parameters()).device
state_tensor = sample['state'].unsqueeze(0).to(sae_device)
with torch.no_grad():
    # Encode once and decode that same code; .cpu() keeps .numpy() valid on GPU.
    latent = sae.encode(state_tensor)
    features = latent.squeeze(0).cpu().numpy()
    reconstruction = sae.decode(latent).squeeze(0).cpu().numpy()
original_state = sample['state'].cpu().numpy()

print(f"Sample State {sample_idx}:")
print(f"  Affordance cardinality: {sample['cardinality'].item():.0f}")
print(f"  Active features (>0.01): {np.sum(features > 0.01)}")
print("  Top 5 feature activations:")
top_5_idx = np.argsort(features)[-5:][::-1]
for idx in top_5_idx:
    print(f"    Feature {idx}: {features[idx]:.3f}")

# Plot original vs reconstruction
state_dims = np.arange(len(original_state))
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(state_dims, original_state, label='Original', alpha=0.7)
ax.plot(state_dims, reconstruction, label='Reconstruction', alpha=0.7)
ax.set_xlabel('State Dimension')
ax.set_ylabel('Value')
ax.set_title('State Reconstruction Quality')
ax.legend()
ax.grid(True)
plt.show()

## Part 4: Test RL Verifiers

In [None]:
# Score the first ten synthetic storyworlds with the rule-based verifiers and
# plot the mean of each reward component.
verifiers = StoryWorldVerifiers(
    min_encounters=3,
    min_characters=2,
    min_endings=2
)

test_rewards = [verifiers.compute_total_reward(json.dumps(sw)) for sw in storyworlds[:10]]

# Visualize reward components
component_names = ['valid_json', 'schema', 'structure', 'effects', 'secrets', 'endings']
component_means = {k: np.mean([r[k] for r in test_rewards]) for k in component_names}
# Give matplotlib a concrete list (not a dict_values view) whose order is
# explicitly tied to component_names.
mean_scores = [component_means[k] for k in component_names]

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(component_names, mean_scores)
ax.set_ylabel('Average Score')
ax.set_title('Verifier Component Scores on Synthetic Data')
ax.tick_params(axis='x', labelrotation=45)
# NOTE(review): the fixed y-range assumes component scores lie in [0, 1] — confirm.
ax.set_ylim(0, 1)
ax.grid(True, axis='y')
fig.tight_layout()
plt.show()

print("\nAverage Scores:")
for k, v in component_means.items():
    print(f"  {k}: {v:.3f}")
print(f"  Total: {np.mean([r['total'] for r in test_rewards]):.3f}")

## Part 5: Mini RL Training (Optional)

In [None]:
# Optional short RL fine-tuning run. It takes several minutes and is meant for
# a GPU, so it is off by default; set RUN_RL_TRAINING = True to execute it.
# Keeping it as flag-guarded live code (instead of commented out) keeps it
# syntax-checked and trivially re-enabled.
RUN_RL_TRAINING = False

if RUN_RL_TRAINING:
    if not torch.cuda.is_available():
        print("Warning: no GPU detected; RL training will be very slow on CPU.")
    config = RLConfig(
        model_name="gpt2",
        max_length=512,
        batch_size=2,
        n_epochs=2,
        n_samples_per_epoch=10,
        learning_rate=5e-6
    )

    trainer = StoryWorldRLTrainer(config)
    trainer.train()

    # Test generation. A distinct name avoids clobbering Part 3's `sample`.
    generated_storyworld = trainer.generate_storyworld()
    print(generated_storyworld[:300])
else:
    print("Skipping RL training (set RUN_RL_TRAINING = True to run it).")

## Part 6: Feature-Aware Rewards

In [None]:
# Compare the rule-based rewards with the feature-aware variant, which is
# built on the trained SAE plus a state extractor.
extractor = StoryWorldStateExtractor()
feature_verifiers = FeatureAwareVerifiers(
    sae=sae,
    extractor=extractor,
    min_encounters=3,
    min_characters=2
)
# NOTE(review): the base `verifiers` also sets min_endings=2; confirm that
# FeatureAwareVerifiers' default matches so the comparison is like-for-like.

# Compare base vs feature-aware rewards
base_rewards = []
feature_rewards = []

for sw_number, sw in enumerate(storyworlds[:5], start=1):
    text = json.dumps(sw)

    base_r = verifiers.compute_total_reward(text)
    feat_r = feature_verifiers.compute_total_reward(text)

    base_rewards.append(base_r['total'])
    feature_rewards.append(feat_r['total'])

    print(f"Storyworld {sw_number}:")
    print(f"  Base reward: {base_r['total']:.3f}")
    print(f"  Feature-aware: {feat_r['total']:.3f}")
    print(f"  Feature quality: {feat_r['feature_quality']:.3f}")
    print()

# Visualize comparison; bars are numbered 1..N to match the printout above.
sw_positions = np.arange(1, len(base_rewards) + 1)
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(sw_positions - width/2, base_rewards, width, label='Base Reward')
ax.bar(sw_positions + width/2, feature_rewards, width, label='Feature-Aware Reward')
ax.set_xticks(sw_positions)
ax.set_xlabel('Storyworld')
ax.set_ylabel('Reward')
ax.set_title('Base vs Feature-Aware Rewards')
ax.legend()
ax.grid(True, axis='y')
plt.show()

## Part 7: Save Models and Results

In [None]:
# Persist the trained SAE weights and the analysis artifacts to ./notebook_outputs.
import pickle

save_dir = Path("./notebook_outputs")
save_dir.mkdir(exist_ok=True)

torch.save(sae.state_dict(), save_dir / "sae_notebook.pt")

# Pickle the training history and the feature correlations, in that order.
# NOTE: only unpickle these files from a trusted location.
pickled_artifacts = {
    "training_history.pkl": history,
    "feature_correlations.pkl": correlations,
}
for artifact_name, artifact in pickled_artifacts.items():
    with open(save_dir / artifact_name, 'wb') as fh:
        pickle.dump(artifact, fh)

print(f"Models and results saved to {save_dir}")

## Summary

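This notebook demonstrated:

1. ✅ Synthetic storyworld generation
2. ✅ Rollout collection
3. ✅ SAE training on narrative states
4. ✅ Feature-affordance analysis
5. ✅ Multi-component verifier testing
6. ✅ Feature-aware reward computation
7. ✅ Saving the trained SAE and analysis results

The mini RL training run in Part 5 is optional and does not run unless you enable it.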

Next steps:
- Scale to larger datasets
- Run full RL training
- Integrate with QFT-MCP corpus
- Deploy production pipeline