# SAE Training on Conditional Features (Colab GPU)

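Run this notebook on Google Colab with a GPU runtime enabled (Runtime → Change runtime type → GPU). It can also be run locally for small tests, but results are only saved to Google Drive when running in Colab.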

**Steps:**
1. Clone repository from GitHub
2. Install dependencies
3. Load prepared dataset
4. Extract model activations
5. Train Sparse Autoencoder
6. Save results to Google Drive

## 1. Setup Environment

In [None]:
# Check if running in Colab
import os
IN_COLAB = 'COLAB_GPU' in os.environ

if IN_COLAB:
    print("Running in Google Colab")
    !nvidia-smi
else:
    print("Not in Colab - make sure to upload this to Colab for GPU access")

In [None]:
# Clone repository
if IN_COLAB:
    !git clone https://github.com/ychleee/SAE_Functional.git
    %cd SAE_Functional
else:
    # If running locally for testing
    import sys
    sys.path.append('..')

In [None]:
# Install dependencies
if IN_COLAB:
    !pip install -q -r requirements_colab.txt
    print("Dependencies installed")

In [None]:
# Fix NumPy compatibility issues
import subprocess
import sys

# Restart runtime after installing to avoid conflicts
if IN_COLAB:
    # Uninstall and reinstall numpy/pandas with compatible versions
    !pip uninstall -y numpy pandas numexpr
    !pip install numpy==1.23.5 pandas==1.5.3
    
    print("NumPy and Pandas reinstalled with compatible versions")
    print("You may need to restart the runtime: Runtime -> Restart runtime")
    print("Then run from the next cell onwards (skip the setup cells)")

In [None]:
# Mount Google Drive for saving results
import os

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

    # Create results directory in Drive.
    # os.makedirs is used instead of a shell `mkdir` so the path is never
    # subject to shell word-splitting and failures raise a Python exception.
    DRIVE_DIR = '/content/drive/MyDrive/sae_conditionals_results'
    os.makedirs(DRIVE_DIR, exist_ok=True)
    print(f"Results will be saved to: {DRIVE_DIR}")

## 2. Load Configuration and Data

In [None]:
import yaml
import torch
import pandas as pd
import numpy as np
from pathlib import Path

# Parse the YAML training configuration into `config`, which the rest of the
# notebook indexes by section ('model', 'sae', 'data', 'training', ...).
config_file = Path('configs/training_config.yaml')
with config_file.open('r') as handle:
    config = yaml.safe_load(handle)

# Echo the settings that matter most for this run.
print("Configuration loaded:")
print(f"Model: {config['model']['name']}")
sae_settings = config['sae']
print(f"SAE features: {sae_settings['hidden_dim']}")
print(f"Sparsity coefficient: {sae_settings['sparsity_coeff']}")

In [None]:
# Load the prepared dataset, or fall back to a small synthetic one.
# Note: You need to have run 01_data_preparation.ipynb first and pushed to GitHub
data_path = Path(config['paths']['data_dir']) / 'conditionals_dataset.csv'

if data_path.exists():
    df = pd.read_csv(data_path)
    print(f"Loaded {len(df)} sentences")

    # Coerce the label column to bool before counting: on a non-boolean column
    # (e.g. 0/1 integers) `~` is a bitwise NOT and would give negative counts.
    is_conditional = df['has_conditional'].astype(bool)
    print(f"Conditionals: {is_conditional.sum()}")
    print(f"Controls: {(~is_conditional).sum()}")
else:
    print(f"Dataset not found at {data_path}")
    print("Creating a small synthetic dataset for testing...")

    # Import data generation utilities
    from src.data_utils import ConditionalDatasetGenerator

    generator = ConditionalDatasetGenerator()
    df = generator.generate_dataset(n_samples=500)
    print(f"Generated {len(df)} sentences")

## 3. Extract Model Activations

In [None]:
from src.activation_extraction import ActivationExtractor, prepare_activations_for_sae

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Build the extractor for the model named in the configuration.
extractor = ActivationExtractor(model_name=config['model']['name'], device=device)

In [None]:
# Run the model over the dataset and collect activations from the configured layer.
print("Extracting activations...")
extraction_options = dict(
    text_column='text',
    layer_idx=config['model']['layer_idx'],
    batch_size=config['data']['batch_size'],
    max_samples=config['data']['n_samples'],
)
activations_dict = extractor.extract_dataset_activations(df, **extraction_options)

# Report what came back, using the 'shape' and 'layer_idx' entries of the result.
print(f"Activations shape: {activations_dict['shape']}")
print(f"Extracted from layer: {activations_dict['layer_idx']}")

In [None]:
# Pool the extracted activations into the 2-D input expected by the SAE.
pooling = config['data']['pool_method']
pooled_activations = prepare_activations_for_sae(activations_dict, pool_method=pooling)

print(f"Pooled activations shape: {pooled_activations.shape}")

# The SAE input width is taken from the data (second axis of the pooled
# activations) and written back into the config so later cells use it.
actual_input_dim = pooled_activations.shape[1]
print(f"Input dimension for SAE: {actual_input_dim}")
config['sae']['input_dim'] = actual_input_dim

## 4. Train Sparse Autoencoder

In [None]:
# Split data into train/validation
#
# The split is taken over a seeded random permutation rather than the first
# N rows: if the dataset is ordered (e.g. conditionals before controls), a
# sequential split would leave the validation set unrepresentative.
# `pooled_activations` itself is left untouched so it stays aligned with
# activations_dict['metadata'] for the analysis cells.
import numpy as np
import torch

SPLIT_SEED = 42  # fixed so the split is reproducible across runs

n_total = len(pooled_activations)
n_train = int(n_total * config['data']['train_val_split'])

split_order = np.random.default_rng(SPLIT_SEED).permutation(n_total)
if isinstance(pooled_activations, torch.Tensor):
    # Index tensors with a tensor living on the same device as the data.
    split_order = torch.as_tensor(split_order, device=pooled_activations.device)

train_data = pooled_activations[split_order[:n_train]]
val_data = pooled_activations[split_order[n_train:]]

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

if len(train_data) == 0 or len(val_data) == 0:
    print("Warning: one of the splits is empty - check data.train_val_split and the dataset size")

In [None]:
from src.sae_training import SparseAutoencoder, SAETrainer

# Build the SAE from the 'sae' section of the config; 'input_dim' was set from
# the pooled activations in an earlier cell.
sae_cfg = config['sae']
sae_kwargs = {
    option: sae_cfg[option]
    for option in ('input_dim', 'hidden_dim', 'sparsity_coeff', 'use_bias')
}
sae_model = SparseAutoencoder(**sae_kwargs)

# Summarise the model that was just created.
print("SAE Model:")
print(f"  Input dimension: {sae_cfg['input_dim']}")
print(f"  Hidden dimension: {sae_cfg['hidden_dim']}")
print(f"  Sparsity coefficient: {sae_cfg['sparsity_coeff']}")
parameter_count = sum(param.numel() for param in sae_model.parameters())
print(f"  Total parameters: {parameter_count}")

In [None]:
# Wrap the model in a trainer bound to the selected device.
trainer = SAETrainer(sae_model, device=device)

# Run training with the hyperparameters from the 'training' config section;
# the returned history is used by the plotting and saving cells.
print("\nStarting training...")
training_cfg = config['training']
history = trainer.train(
    train_data=train_data,
    val_data=val_data,
    epochs=training_cfg['epochs'],
    batch_size=training_cfg['batch_size'],
    learning_rate=training_cfg['learning_rate'],
    weight_decay=training_cfg['weight_decay'],
    patience=training_cfg['patience'],
    verbose=True,
)

print("\nTraining complete!")

In [None]:
# Show the training curves produced by the trainer.
import matplotlib.pyplot as plt

fig = trainer.plot_history()
plt.show()

# Print the metrics recorded for the last training epoch.
final_train = history['train'][-1]
print("\nFinal Training Metrics:")
metric_rows = (
    ("Total Loss", 'total_loss', '.4f'),
    ("Reconstruction Loss", 'recon_loss', '.4f'),
    ("Sparsity Loss", 'sparsity_loss', '.4f'),
    ("Active Features", 'active_features', '.1f'),
)
for label, metric_key, number_format in metric_rows:
    print(f"  {label}: {format(final_train[metric_key], number_format)}")

## 5. Analyze Learned Features

In [None]:
from src.sae_training import analyze_features

# Score the learned features against the conditional/control metadata.
analysis = analyze_features(
    model=sae_model,
    activations=pooled_activations,
    metadata=activations_dict['metadata'],
    top_k=config['analysis']['top_k_features'],
)

# List the five highest-ranked conditional features with their scores.
print("Top Conditional Features:")
top_pairs = zip(
    analysis['conditional_features'][:5],
    analysis['conditional_scores'][:5],
)
for rank, (feat_idx, score) in enumerate(top_pairs, start=1):
    print(f"  {rank}. Feature {feat_idx}: differential score = {score:.3f}")

# Dump the aggregate statistics reported by the analysis.
print("\nFeature Statistics:")
for stat_name in analysis['feature_stats']:
    stat_value = analysis['feature_stats'][stat_name]
    print(f"  {stat_name}: {stat_value:.2f}")

In [None]:
from src.feature_analysis import FeatureInterpreter, create_feature_report

# The interpreter receives the pooled activations together with the original
# per-sample metadata, packaged in the dict layout it expects.
interpreter_inputs = {
    'activations': pooled_activations,
    'metadata': activations_dict['metadata'],
}
interpreter = FeatureInterpreter(
    sae_model=sae_model,
    activations_dict=interpreter_inputs,
    device=device,
)

# Analyze conditional features
conditional_analysis = interpreter.analyze_conditional_features(top_k=10)

# Print the first five indices of each feature group.
print("\nConditional Feature Analysis:")
feature_groups = (
    ("'if'", 'if_features'),
    ("'then'", 'then_features'),
    ("conditional", 'conditional_features'),
)
for group_label, group_key in feature_groups:
    print(f"Top {group_label} features: {conditional_analysis[group_key]['indices'][:5]}")

In [None]:
# Generate report for top conditional feature
# `analysis` comes from the analyze_features cell; the first entry of
# 'conditional_features' is taken as the top-ranked feature
# (presumably sorted by differential score - confirm in src.sae_training).
top_feature = analysis['conditional_features'][0]
# Build a report for that feature, requesting five examples.
report = create_feature_report(interpreter, top_feature, n_examples=5)
print(report)

## 6. Save Results to Google Drive

In [None]:
import json
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch


def _json_default(obj):
    """Convert values that `json` cannot encode natively into plain Python.

    Used as the `default=` hook of json.dump, so it is only called for objects
    the encoder does not already understand. Tensors and NumPy arrays become
    nested lists, NumPy scalars become Python scalars.

    Raises:
        TypeError: for any other unsupported type (the standard json contract).
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


if IN_COLAB:
    # Create timestamp for this run
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = f"{DRIVE_DIR}/run_{timestamp}"
    # Pure-Python directory creation: no shell interpolation, errors raise.
    os.makedirs(run_dir, exist_ok=True)

    # Save trained model together with everything needed to interpret it
    model_path = f"{run_dir}/sae_model.pt"
    torch.save({
        'model_state_dict': sae_model.state_dict(),
        'config': config,
        'history': history,
        'analysis': analysis
    }, model_path)
    print(f"Model saved to: {model_path}")

    # Save activations
    activations_path = f"{run_dir}/activations.pt"
    torch.save(activations_dict, activations_path)
    print(f"Activations saved to: {activations_path}")

    # Save analysis results.
    # The analysis may hold NumPy/torch values (e.g. feature indices), which
    # json cannot encode on its own; _json_default converts them.
    analysis_path = f"{run_dir}/analysis.json"
    with open(analysis_path, 'w') as f:
        json.dump(conditional_analysis, f, indent=2, default=_json_default)
    print(f"Analysis saved to: {analysis_path}")

    # Save training plot, then close the figure so it is not rendered a
    # second time and does not stay alive in the kernel.
    fig = trainer.plot_history()
    fig.savefig(f"{run_dir}/training_history.png")
    plt.close(fig)
    print(f"Training plot saved to: {run_dir}/training_history.png")

    print(f"\nAll results saved to: {run_dir}")
else:
    print("Not in Colab - results not saved to Drive")

## 7. Download Results for Local Analysis

In [None]:
# Tell the user how to get the saved run onto their own machine.
if not IN_COLAB:
    # Nothing was saved outside Colab, so point the user there instead.
    print("Upload this notebook to Colab to train with GPU and save results")
else:
    print("To download results to your local machine:")
    print(f"1. Navigate to Google Drive: {DRIVE_DIR}")
    print("2. Download the latest run_TIMESTAMP folder")
    print("3. Place in your local project's models/checkpoints directory")
    print("\nAlternatively, use rclone or Google Drive sync on your local machine")

## Next Steps

1. **Download results** from Google Drive to your local machine
2. **Run analysis notebook** (03_analysis.ipynb) locally to:
   - Perform detailed feature interpretation
   - Create visualizations
   - Test causal interventions
3. **Experiment with hyperparameters** by modifying `configs/training_config.yaml`
4. **Try different models** (Pythia-410M, GPT-2-medium) for comparison