# Interactive SAE Pipeline Tutorial

This notebook walks through the Sparse Autoencoder (SAE) training pipeline one step at a time. It reproduces what `SAETrainingPipeline.run_complete_pipeline()` in `sae_pipeline.py` does, so that each intermediate result can be inspected.

## Overview

The SAE pipeline consists of five steps:
1. **Embedding Generation**: Extract embeddings from the HelicalmRNA model
2. **Data Preparation**: Create training and validation dataloaders
3. **Model Setup**: Initialize the Sparse Autoencoder
4. **Training**: Train the model with sparsity constraints
5. **Visualization**: Plot training progress and results

Let's explore each step interactively!

## Setup and Imports

In [None]:
import sys
import os
from pathlib import Path

# Add the src directory to the path
notebook_dir = Path.cwd()
src_path = notebook_dir.parent / "src"
sys.path.append(str(src_path))

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import pandas as pd

# Import SAE components
from sae.pipeline.sae_pipeline import SAETrainingPipeline
from sae.models.sae import SAE
from sae.training.trainer import SAETrainer, TrainingConfig
from sae.losses.losses import SAELoss

# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("✅ Imports completed successfully!")
print(f"📁 Working directory: {notebook_dir}")
print(f"📁 Source path: {src_path}")

## Configuration

Set up the pipeline parameters. You can modify these values to experiment with different configurations.

In [None]:
# Pipeline configuration
config = {
    'refseq_file': '../data/vertebrate_mammalian.1.rna.gbff',  # Path to your RefSeq file
    'max_samples': 200,  # Number of samples to process (smaller for faster demo)
    'hidden_dim': 500,   # Number of SAE features to learn
    'epochs': 15,        # Number of training epochs
    'batch_size': 8,     # Batch size for training
    'sparsity_weight': 0.01,  # Weight for sparsity penalty
    'learning_rate': 0.001,   # Learning rate
    'layer_idx': None,   # Layer index to extract embeddings from (None for final layer)
    'layer_name': 'final'  # Layer name for identification
}

print("🔧 Pipeline Configuration:")
for key, value in config.items():
    print(f"   {key}: {value}")

# Check if RefSeq file exists
refseq_path = Path(config['refseq_file'])
if refseq_path.exists():
    print(f"\n✅ RefSeq file found: {refseq_path}")
else:
    print(f"\n⚠️  RefSeq file not found: {refseq_path}")
    print("   Please update the 'refseq_file' path in the config above.")
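Because `hidden_dim` and `sparsity_weight` are the main knobs to experiment with, a small grid over them is a natural next step. The sketch below only builds the list of run configurations from a base dictionary shaped like `config`; the grid values are illustrative, and `make_sweep` is a hypothetical helper, not part of the `sae` package. Each resulting dictionary would still need to be run through Steps 1 to 6.

```python
import itertools

def make_sweep(base_config, grid):
    """Return one config dict per combination of the values in `grid`.

    Hypothetical helper: `grid` maps config keys to lists of candidate values;
    keys not in `grid` keep their value from `base_config`.
    """
    keys = sorted(grid)  # fixed key order so the run list is reproducible
    runs = []
    for values in itertools.product(*(grid[k] for k in keys)):
        run = dict(base_config)  # copy so the base config is never mutated
        run.update(zip(keys, values))
        runs.append(run)
    return runs

base = {'hidden_dim': 500, 'sparsity_weight': 0.01, 'epochs': 15}
sweep = make_sweep(base, {'hidden_dim': [250, 500, 1000], 'sparsity_weight': [0.001, 0.01]})
print(len(sweep))  # 6
print(sweep[0])    # {'hidden_dim': 250, 'sparsity_weight': 0.001, 'epochs': 15}
```

In the notebook you would pass `config` as `base_config`; copying it per run keeps the original settings intact for the single-run walkthrough.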

## Step 1: Initialize Pipeline

Create the SAE training pipeline with the specified configuration.

In [None]:
# Initialize the SAE training pipeline
pipeline = SAETrainingPipeline(
    embedding_dim=None,  # Will be auto-detected from embeddings
    hidden_dim=config['hidden_dim'],
    sparsity_weight=config['sparsity_weight'],
    learning_rate=config['learning_rate'],
    layer_idx=config['layer_idx'],
    layer_name=config['layer_name']
)

print("✅ SAE Training Pipeline initialized!")
print(f"   Hidden dimension: {config['hidden_dim']}")
print(f"   Sparsity weight: {config['sparsity_weight']}")
print(f"   Learning rate: {config['learning_rate']}")
print(f"   Layer: {config['layer_name']}")

## Step 2: Setup Embedding Generator

Initialize the embedding generator to extract embeddings from the HelicalmRNA model.

In [None]:
print("🔧 Setting up embedding generator...")

# Setup embedding generator
pipeline.setup_embedding_generator()

print("✅ Embedding generator setup complete!")
print(f"   Model: {pipeline.embedding_generator.model_name}")
print(f"   Wrapper: {pipeline.embedding_generator.wrapper_name}")
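The pipeline is created with `embedding_dim=None` and `layer_idx=None`, and `pipeline.embedding_dim` is only read back after data preparation. One plausible way to resolve both is sketched below with NumPy: pick a layer from a list of per-layer hidden states (the final layer when `layer_idx` is `None`) and read the width off the last axis. The per-layer list layout and the helper `select_layer_embeddings` are assumptions for illustration; the actual `setup_embedding_generator()` and `prepare_data()` may work differently.

```python
import numpy as np

def select_layer_embeddings(hidden_states, layer_idx=None):
    """Pick one layer's embeddings and infer the embedding dimension.

    Assumes `hidden_states` is a list with one array per model layer, each of
    shape (num_samples, embedding_dim) or (num_samples, seq_len, embedding_dim).
    """
    idx = -1 if layer_idx is None else layer_idx  # None -> final layer
    embeddings = np.asarray(hidden_states[idx])
    embedding_dim = embeddings.shape[-1]  # the feature width is the last axis
    return embeddings, embedding_dim

# Toy example: 4 layers, 10 samples, 16-dimensional embeddings
rng = np.random.default_rng(0)
layers = [rng.normal(size=(10, 16)) + i for i in range(4)]
emb, dim = select_layer_embeddings(layers)
print(dim)  # 16
```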

## Step 3: Prepare Training Data

Generate embeddings from the RefSeq file and create training/validation dataloaders.

In [None]:
print("🔧 Preparing training data...")

# Prepare data
train_loader, val_loader = pipeline.prepare_data(
    refseq_file=config['refseq_file'],
    max_samples=config['max_samples'],
    batch_size=config['batch_size'],
    filter_by_type="mRNA",
    use_cds=True,
    dataset_name="Interactive_Demo"
)

print("✅ Data preparation complete!")
print(f"   Embedding dimension: {pipeline.embedding_dim}")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")
print(f"   Batch size: {config['batch_size']}")

# Inspect a single batch to confirm shapes and value ranges
data, target = next(iter(train_loader))
print("\n📊 Sample batch:")
print(f"   Input shape: {data.shape}")
print(f"   Target shape: {target.shape}")
print(f"   Data type: {data.dtype}")
print(f"   Value range: [{data.min():.3f}, {data.max():.3f}]")
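The batch counts printed above follow from the split sizes: a loader that does not drop its last partial batch yields ceil(n / batch_size) batches. The NumPy sketch below assumes an 80/20 random split (the ratio `prepare_data()` actually uses is not shown in this notebook) and a hypothetical 64-dimensional embedding; `split_and_batch` is an illustrative helper, not the pipeline's implementation.

```python
import math
import numpy as np

def split_and_batch(embeddings, batch_size, val_fraction=0.2, seed=0):
    """Randomly split samples into train/val and chunk each split into batches.

    Assumes an 80/20 split by default; the last batch may be smaller,
    mirroring a PyTorch DataLoader with drop_last=False.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(embeddings))
    n_val = int(round(len(embeddings) * val_fraction))
    val_idx, train_idx = order[:n_val], order[n_val:]

    def batches(idx):
        return [embeddings[idx[i:i + batch_size]] for i in range(0, len(idx), batch_size)]

    return batches(train_idx), batches(val_idx)

X = np.zeros((200, 64), dtype=np.float32)  # 200 samples, hypothetical 64-dim embeddings
train_batches, val_batches = split_and_batch(X, batch_size=8)
print(len(train_batches), len(val_batches))    # 20 5
print(math.ceil(160 / 8), math.ceil(40 / 8))   # 20 5
```

Note that `max_samples=200` is an upper bound; filtering by `filter_by_type="mRNA"` may leave fewer records, which changes the batch counts accordingly.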

## Step 4: Setup SAE Model

Initialize the Sparse Autoencoder model with the specified architecture.

In [None]:
print("🔧 Setting up SAE model...")

# Setup SAE model
pipeline.setup_sae_model()

print("✅ SAE model setup complete!")
print(f"   Input size: {pipeline.embedding_dim}")
print(f"   Hidden size: {pipeline.hidden_dim}")
print(f"   Model parameters: {sum(p.numel() for p in pipeline.sae_model.parameters()):,}")

# Test a forward pass on one batch (moved to the model's device in case it runs on GPU)
device = next(pipeline.sae_model.parameters()).device
sample_data, _ = next(iter(train_loader))
sample_data = sample_data.to(device)
with torch.no_grad():
    reconstructed, encoded = pipeline.sae_model(sample_data)

print("\n🧪 Forward pass test:")
print(f"   Input shape: {sample_data.shape}")
print(f"   Encoded shape: {encoded.shape}")
print(f"   Reconstructed shape: {reconstructed.shape}")
print(f"   Reconstruction MSE: {torch.nn.functional.mse_loss(reconstructed, sample_data):.6f}")

# Fraction of encoded activations that are exactly zero before training
sparsity = (encoded == 0).float().mean()
print(f"   Initial sparsity: {sparsity:.3f} ({sparsity*100:.1f}%)")
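The model returns `(reconstructed, encoded)`. A common SAE design, assumed here because the `SAE` class itself is not shown, is a linear encoder followed by ReLU and a linear decoder. The NumPy sketch below (class `TinySAE` and helper `num_parameters` are illustrative, not the package's code) also explains two numbers from this cell. The parameter count is two weight matrices plus two bias vectors for an untied SAE, one matrix fewer if the decoder reuses the transposed encoder weights. The initial fraction of zeros sits near 50% for zero-mean random inputs, because about half the random pre-activations are negative and ReLU clips them to exactly zero.

```python
import numpy as np

class TinySAE:
    """Minimal ReLU sparse autoencoder (illustrative; not `sae.models.sae.SAE`)."""

    def __init__(self, input_dim, hidden_dim, seed=0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(input_dim)
        self.W_enc = rng.normal(0.0, scale, size=(input_dim, hidden_dim))
        self.b_enc = np.zeros(hidden_dim)
        self.W_dec = rng.normal(0.0, scale, size=(hidden_dim, input_dim))
        self.b_dec = np.zeros(input_dim)

    def __call__(self, x):
        encoded = np.maximum(0.0, x @ self.W_enc + self.b_enc)  # ReLU zeroes negatives
        reconstructed = encoded @ self.W_dec + self.b_dec
        return reconstructed, encoded

def num_parameters(input_dim, hidden_dim, tied=False):
    """Encoder weights and bias plus decoder bias, plus decoder weights unless tied."""
    n = input_dim * hidden_dim + hidden_dim + input_dim
    return n if tied else n + hidden_dim * input_dim

toy_model = TinySAE(input_dim=64, hidden_dim=500)
toy_x = np.random.default_rng(1).normal(size=(8, 64))
toy_recon, toy_encoded = toy_model(toy_x)
print(toy_recon.shape, toy_encoded.shape)   # (8, 64) (8, 500)
print(num_parameters(64, 500))              # 64564
print(f"{(toy_encoded == 0).mean():.2f}")   # close to 0.50 for zero-mean inputs
```

Real embeddings are rarely zero-mean, so the initial sparsity you see above can deviate noticeably from 50%.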

## Step 5: Setup Trainer

Configure the training setup with optimizer, loss function, and training configuration.

In [None]:
print("🔧 Setting up trainer...")

# Setup trainer
pipeline.setup_trainer(train_loader, val_loader)

print("✅ Trainer setup complete!")
print(f"   Optimizer: {type(pipeline.trainer.optimizer).__name__}")
print(f"   Learning rate: {pipeline.trainer.optimizer.param_groups[0]['lr']}")
print(f"   Loss function: {type(pipeline.trainer.loss_fn).__name__}")
print(f"   Sparsity weight: {pipeline.trainer.loss_fn.sparsity_weight}")
print(f"   Sparsity target: {pipeline.trainer.loss_fn.sparsity_target}")

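# Test loss computation (run the forward pass once and reuse both outputs)
sample_data, _ = next(iter(train_loader))
sample_data = sample_data.to(next(pipeline.sae_model.parameters()).device)
with torch.no_grad():
    reconstructed, encoded = pipeline.sae_model(sample_data)
    loss, loss_dict = pipeline.trainer.loss_fn(
        reconstructed,  # reconstruction
        sample_data,    # target
        encoded         # encoded activations
    )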
    
print(f"\n🧪 Loss computation test:")
print(f"   Total loss: {loss:.6f}")
for key, value in loss_dict.items():
    print(f"   {key}: {value:.6f}")
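The loss combines a reconstruction term with a sparsity penalty scaled by `sparsity_weight`. The trainer also exposes a `sparsity_target`, which suggests (but does not confirm) a KL-divergence penalty that pushes each feature's mean activation toward a target rate, as in classic sparse autoencoders; an L1 penalty on the activations is the other common choice. The NumPy sketch below implements both under those assumptions. It illustrates the two formulas and is not `SAELoss`.

```python
import numpy as np

def l1_penalty(encoded):
    """Mean absolute activation: pushes most features toward exactly zero."""
    return np.abs(encoded).mean()

def kl_penalty(encoded, sparsity_target=0.05, eps=1e-8):
    """Sum over features of KL(target rate || observed mean activation).

    Assumes activations roughly in [0, 1]; mean rates are clipped for log safety.
    """
    rho_hat = np.clip(encoded.reshape(-1, encoded.shape[-1]).mean(axis=0), eps, 1 - eps)
    rho = sparsity_target
    return np.sum(rho * np.log(rho / rho_hat) + (1 - rho) * np.log((1 - rho) / (1 - rho_hat)))

def sae_loss(reconstructed, target, encoded, sparsity_weight=0.01, penalty="l1", **kwargs):
    """Total loss = MSE reconstruction + sparsity_weight * sparsity penalty."""
    recon = np.mean((reconstructed - target) ** 2)
    sparsity = l1_penalty(encoded) if penalty == "l1" else kl_penalty(encoded, **kwargs)
    total = recon + sparsity_weight * sparsity
    return total, {"reconstruction_loss": recon, "sparsity_loss": sparsity}

toy_target = np.ones((4, 3))
toy_recon = np.zeros((4, 3))
toy_encoded = np.array([[0.0, 2.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
total, parts = sae_loss(toy_recon, toy_target, toy_encoded, sparsity_weight=0.01)
print(f"{parts['reconstruction_loss']:.3f} {parts['sparsity_loss']:.3f} {total:.3f}")  # 1.000 0.500 1.005
```

Whichever penalty is used, increasing `sparsity_weight` shifts the optimum toward sparser codes at the cost of higher reconstruction loss.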

## Step 6: Training

Train the SAE model and monitor the training progress with detailed metrics.

In [None]:
print(f"🚀 Starting training for {config['epochs']} epochs...")
print("=" * 60)

# Train the model
history = pipeline.train(epochs=config['epochs'])

print("\n✅ Training completed!")
print(f"   Total epochs: {len(history)}")
print(f"   Final training loss: {history[-1]['total_loss']:.6f}")
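final_val_loss = history[-1].get('val_total_loss')
print(f"   Final validation loss: {final_val_loss:.6f}" if final_val_loss is not None
      else "   Final validation loss: N/A")
if 'val_l0_sparsity' in history[-1]:
    print(f"   Final L0 sparsity: {history[-1]['val_l0_sparsity']:.1f} features per token")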
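`pipeline.train()` returns `history`, which this notebook indexes as a list of per-epoch dictionaries. The final epoch is not necessarily the best one, so it can help to locate the epoch with the lowest validation loss. A small sketch, assuming the key names used in this notebook (`total_loss`, `val_total_loss`) and falling back to the training loss when no validation loss was logged; `best_epoch` is a hypothetical helper:

```python
def best_epoch(history, key="val_total_loss", fallback="total_loss"):
    """Return (1-based epoch, loss) with the lowest value of `key`.

    Falls back to `fallback` if no epoch recorded `key`.
    """
    if not any(key in h for h in history):
        key = fallback
    scored = [(h[key], i + 1) for i, h in enumerate(history) if key in h]
    loss, epoch = min(scored)
    return epoch, loss

# Toy history: validation loss bottoms out at epoch 2
toy_history = [
    {"total_loss": 1.0, "val_total_loss": 1.1},
    {"total_loss": 0.8, "val_total_loss": 0.9},
    {"total_loss": 0.7, "val_total_loss": 0.95},
]
print(best_epoch(toy_history))  # (2, 0.9)
```

In the notebook, `best_epoch(history)` gives a quick check of whether more epochs are still helping or the model has started to overfit.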

## Step 7: Training Visualization

Create comprehensive plots to visualize the training progress and results.

In [None]:
# Convert history to DataFrame for easier plotting
history_df = pd.DataFrame(history)
history_df['epoch'] = range(1, len(history_df) + 1)

print("📊 Training History Summary:")
print(history_df[['epoch', 'total_loss', 'reconstruction_loss', 'sparsity_loss']].tail())

# Create comprehensive training plots
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle(f'SAE Training Progress - {config["hidden_dim"]} Features, {config["epochs"]} Epochs', fontsize=16)

# 1. Total Loss
ax1 = axes[0, 0]
ax1.plot(history_df['epoch'], history_df['total_loss'], 'b-', linewidth=2, label='Training')
if 'val_total_loss' in history_df.columns:
    ax1.plot(history_df['epoch'], history_df['val_total_loss'], 'r--', linewidth=2, label='Validation')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Total Loss')
ax1.set_title('Total Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Reconstruction Loss
ax2 = axes[0, 1]
ax2.plot(history_df['epoch'], history_df['reconstruction_loss'], 'g-', linewidth=2, label='Training')
if 'val_reconstruction_loss' in history_df.columns:
    ax2.plot(history_df['epoch'], history_df['val_reconstruction_loss'], 'm--', linewidth=2, label='Validation')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Reconstruction Loss')
ax2.set_title('Reconstruction Loss')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Sparsity Loss
ax3 = axes[0, 2]
ax3.plot(history_df['epoch'], history_df['sparsity_loss'], 'orange', linewidth=2, label='Training')
if 'val_sparsity_loss' in history_df.columns:
    ax3.plot(history_df['epoch'], history_df['val_sparsity_loss'], 'brown', linewidth=2, label='Validation')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Sparsity Loss')
ax3.set_title('Sparsity Loss')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. L0 Sparsity (if available)
ax4 = axes[1, 0]
if 'val_l0_sparsity' in history_df.columns:
    ax4.plot(history_df['epoch'], history_df['val_l0_sparsity'], 'purple', linewidth=2, marker='o')
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('L0 Sparsity')
    ax4.set_title('L0 Sparsity (Non-zero features per token)')
    ax4.grid(True, alpha=0.3)
else:
    ax4.text(0.5, 0.5, 'L0 Sparsity\nNot Available', ha='center', va='center', transform=ax4.transAxes)
    ax4.set_title('L0 Sparsity')

# 5. Loss Components Comparison
ax5 = axes[1, 1]
ax5.plot(history_df['epoch'], history_df['reconstruction_loss'], 'g-', linewidth=2, label='Reconstruction')
ax5.plot(history_df['epoch'], history_df['sparsity_loss'], 'orange', linewidth=2, label='Sparsity')
ax5.set_xlabel('Epoch')
ax5.set_ylabel('Loss')
ax5.set_title('Loss Components')
ax5.legend()
ax5.grid(True, alpha=0.3)

# 6. Final Loss Comparison
ax6 = axes[1, 2]
final_epoch = history_df.iloc[-1]
loss_types = ['reconstruction_loss', 'sparsity_loss']
loss_values = [final_epoch['reconstruction_loss'], final_epoch['sparsity_loss']]
colors = ['green', 'orange']

bars = ax6.bar(loss_types, loss_values, color=colors, alpha=0.7)
ax6.set_ylabel('Loss Value')
ax6.set_title('Final Loss Breakdown')
ax6.grid(True, alpha=0.3)

# Add value labels on bars
for bar, value in zip(bars, loss_values):
    ax6.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{value:.4f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()
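Panel 5 plots the reconstruction and sparsity losses on a shared axis, but the two terms can differ by orders of magnitude, which flattens the smaller curve. Dividing each component by its first-epoch value puts both on a common relative scale (1.0 = starting value). A minimal sketch on a list of per-epoch dictionaries with the key names used above; the helper name is hypothetical, and a first-epoch value of zero is rejected because the ratio is undefined.

```python
def relative_curves(history, keys=("reconstruction_loss", "sparsity_loss")):
    """Return {key: [value / first_value, ...]} for each loss component."""
    curves = {}
    for key in keys:
        values = [h[key] for h in history]
        first = values[0]
        if first == 0:
            raise ValueError(f"{key} starts at 0; relative scaling is undefined")
        curves[key] = [v / first for v in values]
    return curves

toy_history = [
    {"reconstruction_loss": 2.0, "sparsity_loss": 0.04},
    {"reconstruction_loss": 1.0, "sparsity_loss": 0.02},
]
print(relative_curves(toy_history))
```

With the real run, `curves = relative_curves(history)` can be plotted against `history_df['epoch']` in place of the raw values in panel 5.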

## Step 8: Model Analysis

Analyze the trained model to understand the learned representations and sparsity patterns.

In [None]:
# Analyze the trained model
print("🔍 Model Analysis")
print("=" * 40)

# Test the model on validation data
pipeline.sae_model.eval()
val_reconstruction_losses = []
val_sparsity_values = []
val_l0_values = []

with torch.no_grad():
    for data, target in val_loader:
        reconstructed, encoded = pipeline.sae_model(data)
        
        # Calculate reconstruction loss
        recon_loss = torch.nn.functional.mse_loss(reconstructed, target)
        val_reconstruction_losses.append(recon_loss.item())
        
        # Calculate sparsity
        sparsity = (encoded == 0).float().mean()
        val_sparsity_values.append(sparsity.item())
        
        # Calculate L0 sparsity (non-zero features per token)
        if len(encoded.shape) == 3:
            # (batch_size, seq_len, hidden_dim)
            non_zero_per_token = (encoded != 0).sum(dim=2).float()  # (batch_size, seq_len)
            l0_val = non_zero_per_token.mean()
        else:
            # (batch_size, hidden_dim)
            l0_val = (encoded != 0).sum(dim=1).float().mean()
        val_l0_values.append(l0_val.item())

print(f"📊 Validation Results:")
print(f"   Average reconstruction loss: {np.mean(val_reconstruction_losses):.6f}")
print(f"   Average fraction of zero activations: {np.mean(val_sparsity_values):.3f} ({np.mean(val_sparsity_values)*100:.1f}%)")
print(f"   Average L0 sparsity: {np.mean(val_l0_values):.1f} features per token")

# Visualize learned features
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Learned SAE Features Analysis', fontsize=16)

# 1. Feature activation distribution
ax1 = axes[0, 0]
sample_data, _ = next(iter(val_loader))
with torch.no_grad():
    _, encoded = pipeline.sae_model(sample_data)
    encoded_flat = encoded.flatten().cpu().numpy()
    
ax1.hist(encoded_flat, bins=50, alpha=0.7, color='blue', edgecolor='black')
ax1.set_xlabel('Feature Activation Value')
ax1.set_ylabel('Frequency')
ax1.set_title('Feature Activation Distribution')
ax1.grid(True, alpha=0.3)
ax1.axvline(x=0, color='red', linestyle='--', alpha=0.7, label='Zero threshold')
ax1.legend()

# 2. Sparsity per sample
ax2 = axes[0, 1]
with torch.no_grad():
    _, encoded = pipeline.sae_model(sample_data)
    if len(encoded.shape) == 3:
        # Calculate sparsity per sample
        sparsity_per_sample = (encoded == 0).float().mean(dim=(1, 2)).cpu().numpy()
    else:
        sparsity_per_sample = (encoded == 0).float().mean(dim=1).cpu().numpy()
    
ax2.hist(sparsity_per_sample, bins=20, alpha=0.7, color='green', edgecolor='black')
ax2.set_xlabel('Sparsity per Sample')
ax2.set_ylabel('Frequency')
ax2.set_title('Sparsity Distribution Across Samples')
ax2.grid(True, alpha=0.3)

# 3. Feature usage heatmap (top 20 most active features)
ax3 = axes[1, 0]
with torch.no_grad():
    _, encoded = pipeline.sae_model(sample_data)
    if len(encoded.shape) == 3:
        # Average across sequence dimension
        feature_activity = encoded.mean(dim=1)  # (batch_size, hidden_dim)
    else:
        feature_activity = encoded
    
    # Get top 20 most active features
    feature_usage = feature_activity.abs().mean(dim=0)  # (hidden_dim,)
    top_features = torch.topk(feature_usage, min(20, len(feature_usage))).indices
    
    # Create heatmap
    heatmap_data = feature_activity[:, top_features].cpu().numpy()
    im = ax3.imshow(heatmap_data.T, aspect='auto', cmap='viridis')
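    ax3.set_xlabel('Sample Index')
    ax3.set_ylabel('Feature Index')
    # Label rows with the actual SAE feature indices rather than their rank
    ax3.set_yticks(range(len(top_features)))
    ax3.set_yticklabels(top_features.cpu().tolist())
    ax3.set_title(f'Top {len(top_features)} Most Active Features')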
    plt.colorbar(im, ax=ax3)

# 4. L0 sparsity across sequence positions (only for token-level data)
ax4 = axes[1, 1]
with torch.no_grad():
    _, encoded = pipeline.sae_model(sample_data)
    if len(encoded.shape) == 3:
        # Calculate L0 per sequence position
        l0_per_position = (encoded != 0).sum(dim=2).float().mean(dim=0).cpu().numpy()
        positions = range(len(l0_per_position))
        ax4.plot(positions, l0_per_position, 'purple', linewidth=2, marker='o')
        ax4.set_xlabel('Sequence Position')
        ax4.set_ylabel('L0 Sparsity')
        ax4.set_title('L0 Sparsity Across Sequence Positions')
        ax4.grid(True, alpha=0.3)
    else:
        ax4.text(0.5, 0.5, 'Sequence data\nnot available', ha='center', va='center', transform=ax4.transAxes)
        ax4.set_title('L0 Sparsity Across Sequence')

plt.tight_layout()
plt.show()
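The two sparsity metrics in this step are linked. For activations of shape (..., hidden_dim), L0 counts the non-zero entries along the last axis, so within a batch the fraction of zeros equals 1 − mean L0 / hidden_dim. The NumPy sketch below checks that identity and also computes each feature's activation frequency (the fraction of tokens on which it fires), a common complement to the mean-absolute-activation ranking used for the heatmap; features with frequency 0 never fired. The function names are illustrative.

```python
import numpy as np

def l0_per_token(encoded):
    """Number of non-zero features for each token (last axis = features)."""
    return (encoded != 0).sum(axis=-1)

def fraction_zero(encoded):
    """Fraction of all activations that are exactly zero."""
    return float((encoded == 0).mean())

def activation_frequency(encoded):
    """Fraction of tokens on which each feature is non-zero, shape (hidden_dim,)."""
    flat = encoded.reshape(-1, encoded.shape[-1])  # (tokens, hidden_dim)
    return (flat != 0).mean(axis=0)

rng = np.random.default_rng(0)
toy_encoded = np.maximum(0.0, rng.normal(size=(4, 6, 10)))  # (batch, seq_len, hidden_dim)
toy_encoded[..., 3] = 0.0  # make feature 3 "dead" for the demo

mean_l0 = l0_per_token(toy_encoded).mean()
hidden_dim = toy_encoded.shape[-1]
print(np.isclose(fraction_zero(toy_encoded), 1 - mean_l0 / hidden_dim))  # True
freq = activation_frequency(toy_encoded)
print(np.flatnonzero(freq == 0))  # features that never fired; includes 3
```

Averaging per-batch values, as the validation loop above does, matches the pooled numbers only when all batches have the same size.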

## Step 9: Save Model and Results

Save the trained model and training history for future use.

In [None]:
# Save the trained model
save_dir = Path("../outputs/interactive_demo")
save_dir.mkdir(parents=True, exist_ok=True)

model_path = save_dir / "sae_model_interactive.pth"
history_path = save_dir / "training_history_interactive.json"
config_path = save_dir / "model_config_interactive.json"

# Save model
torch.save({
    'model_state_dict': pipeline.sae_model.state_dict(),
    'embedding_dim': pipeline.embedding_dim,
    'hidden_dim': pipeline.hidden_dim,
    'sparsity_weight': config['sparsity_weight'],
    'layer_idx': config['layer_idx'],
    'layer_name': config['layer_name']
}, model_path)

# Save training history (default=float converts NumPy/PyTorch scalars that json cannot encode)
import json
with open(history_path, 'w') as f:
    json.dump(history, f, indent=2, default=float)

# Save configuration
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)

print("💾 Model and results saved!")
print(f"   Model: {model_path}")
print(f"   History: {history_path}")
print(f"   Config: {config_path}")

# Print final summary
print("\n🎉 Interactive SAE Pipeline Complete!")
print("=" * 50)
print(f"📊 Final Results:")
print(f"   Model: {pipeline.embedding_dim} → {pipeline.hidden_dim} features")
print(f"   Training epochs: {len(history)}")
print(f"   Final reconstruction loss: {history[-1]['reconstruction_loss']:.6f}")
print(f"   Final sparsity loss: {history[-1]['sparsity_loss']:.6f}")
if 'val_l0_sparsity' in history[-1]:
    print(f"   Final L0 sparsity: {history[-1]['val_l0_sparsity']:.1f} features per token")
print(f"   Model parameters: {sum(p.numel() for p in pipeline.sae_model.parameters()):,}")
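`json.dump` only understands built-in Python types. If the history contains NumPy `float32` scalars or 0-d PyTorch tensors (whether this trainer logs those is not shown here), saving fails with a `TypeError`. Passing `default=float` converts such scalars when they are encountered. A NumPy-only sketch of the failure, the fix, and reading the result back; `toy_history` is a stand-in for the real `history`.

```python
import json
import numpy as np

toy_history = [{"total_loss": np.float32(0.5), "reconstruction_loss": np.float32(0.25)}]

try:
    json.dumps(toy_history)
except TypeError as err:
    print(err)  # Object of type float32 is not JSON serializable

# default= is called only for objects json cannot encode natively
text = json.dumps(toy_history, indent=2, default=float)
restored = json.loads(text)
print(restored[0]["total_loss"])  # 0.5
```

`np.float64` already subclasses Python `float`, so only narrower NumPy types and tensors need the conversion.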

## Summary

This interactive notebook has demonstrated the complete SAE training pipeline:

1. **Data Loading**: Successfully loaded and processed RefSeq data
2. **Embedding Generation**: Extracted embeddings from the HelicalmRNA model
3. **Model Architecture**: Set up a sparse autoencoder with tied weights
4. **Training**: Trained the model with sparsity constraints
5. **Analysis**: Visualized training progress and learned features
6. **Results**: Measured reconstruction loss and sparsity (fraction of zeros and L0) on held-out validation data

### Key Insights:
- The `sparsity_weight` penalty controls how sparse the learned representations are
- L0 sparsity counts how many features are active (non-zero) per token
- Training history reveals the trade-off between reconstruction and sparsity
- Feature analysis shows which features are most active, ranked by mean absolute activation

### Next Steps:
- Experiment with different `hidden_dim` values
- Adjust `sparsity_weight` to control sparsity levels
- Try different layers of the HelicalmRNA model
- Compare with BatchTopK SAE for different sparsity patterns
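For the BatchTopK comparison: instead of a sparsity penalty, a BatchTopK SAE keeps only the k × batch_size largest activations across the whole batch and zeros the rest. The average L0 per sample is therefore at most k (exactly k when enough activations are positive), while individual samples may use more or fewer features. A NumPy sketch of that activation step, assuming ReLU is applied before selection; the function name is illustrative and this is not code from the `sae` package.

```python
import numpy as np

def batch_topk(pre_activations, k):
    """Keep the k * batch_size largest ReLU activations across the whole batch.

    pre_activations: (batch_size, hidden_dim). Ties at the threshold are broken
    arbitrarily by argpartition.
    """
    acts = np.maximum(0.0, pre_activations)
    flat = acts.ravel()
    n_keep = min(k * acts.shape[0], flat.size)
    if n_keep <= 0:
        return np.zeros_like(acts)
    keep_idx = np.argpartition(flat, -n_keep)[-n_keep:]  # indices of the n_keep largest
    out = np.zeros_like(flat)
    out[keep_idx] = flat[keep_idx]
    return out.reshape(acts.shape)

rng = np.random.default_rng(0)
pre = rng.normal(size=(8, 500))
sparse = batch_topk(pre, k=16)
l0 = (sparse != 0).sum(axis=1)
print(l0.sum(), l0.mean())  # 128 16.0
```

Unlike `sparsity_weight`, k fixes the average number of active features directly, so no penalty tuning is needed to hit a target L0.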

The saved model can be loaded and used for feature extraction or further analysis!