# SAE Training Loss Visualization

This notebook visualizes the training progress and loss curves for the Sparse Autoencoder (SAE).

**Purpose:**
- Monitor training convergence
- Analyze reconstruction vs. sparsity loss trade-off
- Verify target sparsity is achieved

**Input Data:**
- Training logs CSV: step, total_loss, recon_loss, sparsity_loss, sparsity_l0

**Output:**
- Training loss curves (total, reconstruction, sparsity)
- Sparsity (L0) progression
- Convergence analysis

In [None]:
import os
import sys
import warnings
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Make the project package importable. This assumes the notebook is run from
# notebooks/visualizations/, so the project root is two directories above the
# current working directory.
NOTEBOOK_DIR = Path.cwd()
PROJECT_ROOT = NOTEBOOK_DIR.parent.parent
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)

# Project-local plotting/loading helpers (importable only after the sys.path
# tweak above).
from src.visualization import (  # noqa: E402
    setup_korean_font,
    ensure_korean_font,
    load_training_logs,
    plot_training_loss,
)

# Silence all Python warnings to keep notebook output readable.
warnings.filterwarnings('ignore')

print(f"Project root: {PROJECT_ROOT}")
print(f"Notebook dir: {NOTEBOOK_DIR}")

In [None]:
# Apply the seaborn style FIRST. sns.set_style() writes matplotlib's
# 'font.family' and 'font.sans-serif' rcParams as part of every style, so
# running it after the Korean font setup can silently undo that font
# configuration. The undone setup would show up as missing glyphs for Korean
# text.
sns.set_style('whitegrid')

# Configure a Korean-capable font for matplotlib via the project helper
# (auto-detects an installed font). It runs after the seaborn style so the
# style does not override it.
# NOTE(review): assumes ensure_korean_font() applies the font via rcParams and
# returns the chosen font name — confirm.
font_name = ensure_korean_font()

In [None]:
# ---- Paths ----
RESULTS_DIR = PROJECT_ROOT / "results"
ASSETS_DIR = PROJECT_ROOT.joinpath("notebooks", "visualizations", "assets")
ASSETS_DIR.mkdir(parents=True, exist_ok=True)  # output folder for saved figures

# ---- Run selection ----
# These select which training log is loaded and are embedded in the names of
# the saved figures.
STAGE = "pilot"          # one of: 'pilot', 'medium', 'full'
SAE_TYPE = "gated"       # one of: 'standard', 'gated'
LAYER_QUANTILE = "q2"    # one of: 'q1', 'q2', 'q3'

# Reference L0 sparsity drawn on the plots (0.05 == 5% active features).
TARGET_SPARSITY_L0 = 0.05

print(f"Results directory: {RESULTS_DIR}")
print()
for _label, _value in (
    ("Stage", STAGE),
    ("SAE type", SAE_TYPE),
    ("Layer quantile", LAYER_QUANTILE),
    ("Target sparsity (L0)", TARGET_SPARSITY_L0),
):
    print(f"{_label}: {_value}")

## Load Training Logs

In [None]:
# Load the per-step training log for the configured run.
training_logs = load_training_logs(
    RESULTS_DIR,
    stage=STAGE,
    sae_type=SAE_TYPE,
    layer_quantile=LAYER_QUANTILE,
)

# Fail fast with a clear message. Later cells plot every metric against the
# 'step' column and index the last row (e.g. the sparsity summary). An empty or
# malformed log would otherwise surface as a confusing KeyError/IndexError far
# from the real cause.
if training_logs is None or len(training_logs) == 0:
    raise ValueError(
        f"No training logs loaded for stage={STAGE!r}, sae_type={SAE_TYPE!r}, "
        f"layer_quantile={LAYER_QUANTILE!r} (results dir: {RESULTS_DIR})"
    )
if 'step' not in training_logs.columns:
    raise ValueError(
        "Training logs are missing the required 'step' column; "
        f"found columns: {list(training_logs.columns)}"
    )

print(f"Loaded {len(training_logs)} training steps")
print(f"\nColumns: {list(training_logs.columns)}")
print("\nFirst few rows:")
print(training_logs.head())

## Plot Training Loss Curves

In [None]:
# Summary loss figure from the project plotting helper. The file name encodes
# the run configuration.
# NOTE(review): presumably the helper writes the PNG to save_path — confirm.
summary_plot_path = ASSETS_DIR / f"sae_training_loss_{STAGE}_{SAE_TYPE}_{LAYER_QUANTILE}.png"
fig = plot_training_loss(
    training_logs=training_logs,
    save_path=summary_plot_path,
    figsize=(14, 8),
)
plt.show()

## Detailed Loss Analysis

In [None]:
# Detailed 2x2 panel figure: raw metric traces with a rolling-mean overlay.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 10))
axes = axes.ravel()  # address panels as axes[0..3] in row-major order

# Rolling-mean window: ~2% of the logged steps, but never below 1.
window_size = max(len(training_logs) // 50, 1)

def smooth(series, window=window_size):
    """Return the trailing rolling mean of ``series``.

    ``window`` defaults to the module-level ``window_size`` value captured when
    this function is defined. With ``min_periods=1`` the first ``window - 1``
    points are averaged over the values available so far, not left as NaN.
    """
    trailing = series.rolling(window, min_periods=1)
    return trailing.mean()

# One panel per logged metric, in the order of the 2x2 grid:
# (column, colour, panel title, y-axis label).
_PANELS = [
    ('total_loss', 'blue', 'Total Loss', 'Loss'),
    ('recon_loss', 'green', 'Reconstruction Loss', 'Loss'),
    ('sparsity_loss', 'red', 'Sparsity Loss', 'Loss'),
    ('sparsity_l0', 'purple', 'Sparsity (L0)', 'L0 (Active Features Ratio)'),
]

for ax, (col, color, title, ylabel) in zip(axes, _PANELS):
    if col not in training_logs.columns:
        # Make a missing metric explicit instead of leaving a silently blank panel.
        ax.text(0.5, 0.5, f"'{col}' not logged", ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title(title, fontsize=14, pad=10)
        ax.set_xticks([])
        ax.set_yticks([])
        continue

    values = training_logs[col]
    ax.plot(training_logs['step'], values, alpha=0.3, color=color, label='Raw')
    ax.plot(training_logs['step'], smooth(values), color=color, linewidth=2, label='Smoothed')
    if col == 'sparsity_l0':
        # Reference line for the configured target sparsity.
        ax.axhline(TARGET_SPARSITY_L0, color='orange', linestyle='--', linewidth=2,
                   label=f'Target: {TARGET_SPARSITY_L0}')
    ax.set_xlabel('Training Step', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, pad=10)
    ax.legend()
    ax.grid(alpha=0.3)

plt.suptitle(f'SAE Training Progress ({SAE_TYPE}, {STAGE}, {LAYER_QUANTILE})', fontsize=16, y=1.02)
plt.tight_layout()
plt.savefig(ASSETS_DIR / f"sae_training_detailed_{STAGE}_{SAE_TYPE}_{LAYER_QUANTILE}.png",
            dpi=300, bbox_inches='tight')
plt.show()

## Training Statistics

In [None]:
# For each logged metric, compare its first logged value with the mean/std
# over the tail of training (last ~10% of steps, at least one step).
print("\nTraining Statistics:")
print("=" * 80)

n_steps = len(training_logs)
last_n = max(1, n_steps // 10)
final_metrics = training_logs.tail(last_n)  # end-of-training window

print(f"\nTotal training steps: {n_steps}")
print(f"\nFinal metrics (average of last {last_n} steps):")

# Column name -> display name; columns absent from the log are skipped.
metric_names = {
    'total_loss': 'Total Loss',
    'recon_loss': 'Reconstruction Loss',
    'sparsity_loss': 'Sparsity Loss',
    'sparsity_l0': 'Sparsity (L0)',
}

for col, name in metric_names.items():
    if col not in training_logs.columns:
        continue

    tail_values = final_metrics[col]
    final_mean = tail_values.mean()
    final_std = tail_values.std()  # sample std (ddof=1): NaN for a 1-step window
    initial = training_logs[col].iloc[0] if n_steps > 0 else 0

    print(f"\n{name}:")
    if col != 'sparsity_l0':
        # Loss columns: report the relative reduction from the first step.
        print(f"  Initial:   {initial:.6f}")
        print(f"  Final:     {final_mean:.6f} +/- {final_std:.6f}")
        if initial > 0:
            reduction = (1 - final_mean / initial) * 100
            print(f"  Reduction: {reduction:+.1f}%")
        continue

    # L0 sparsity: report the relative deviation from the configured target.
    print(f"  Initial:   {initial:.4f}")
    print(f"  Final:     {final_mean:.4f} +/- {final_std:.4f}")
    print(f"  Target:    {TARGET_SPARSITY_L0}")
    if TARGET_SPARSITY_L0 > 0:
        deviation = abs(final_mean - TARGET_SPARSITY_L0) / TARGET_SPARSITY_L0 * 100
        print(f"  Deviation: {deviation:.1f}%")

## Convergence Analysis

In [None]:
# Convergence check on the total loss. Training is split into equal contiguous
# segments (the last one absorbs any remainder) and the segment means are
# compared.
CONVERGENCE_TOL_PCT = 5.0  # |change| below this percentage counts as "stable"


def _pct_change(new, ref):
    """Percent change from ``ref`` to ``new``; NaN when ``ref`` is zero.

    Divides by ``abs(ref)`` so the sign always means "went up" / "went down".
    """
    if ref == 0:
        return float('nan')
    return (new - ref) / abs(ref) * 100


if 'total_loss' in training_logs.columns and len(training_logs) > 10:
    n_segments = 5
    segment_size = len(training_logs) // n_segments

    print("\nConvergence Analysis (Total Loss):")
    print("=" * 80)

    segment_means = []
    for i in range(n_segments):
        start = i * segment_size
        end = (i + 1) * segment_size if i < n_segments - 1 else len(training_logs)
        mean_loss = training_logs['total_loss'].iloc[start:end].mean()
        segment_means.append(mean_loss)

        # Segment boundaries as a percentage of training.
        pct_start = start / len(training_logs) * 100
        pct_end = end / len(training_logs) * 100
        print(f"  Segment {i+1} ({pct_start:.0f}%-{pct_end:.0f}%): Loss = {mean_loss:.6f}")

    mid = n_segments // 2  # index of the middle segment
    final_vs_initial = _pct_change(segment_means[-1], segment_means[0])
    final_vs_mid = _pct_change(segment_means[-1], segment_means[mid])

    print(f"\n  Total change: {final_vs_initial:+.2f}%")
    print(f"  Change from segment {mid + 1} to segment {n_segments}: {final_vs_mid:+.2f}%")

    if np.isnan(final_vs_mid):
        print("\n  Status: Undetermined (reference segment loss is zero or NaN)")
    elif abs(final_vs_mid) < CONVERGENCE_TOL_PCT:
        print("\n  Status: Converged (stable)")
    elif final_vs_mid < 0:
        print("\n  Status: Still improving")
    else:
        # A rising *training* loss points to divergence/instability, not
        # overfitting (overfitting only shows up on held-out data).
        print("\n  Status: Loss increasing late in training (possible divergence/instability)")

## Loss Component Balance

In [None]:
# Compare the *shape* of the reconstruction and sparsity loss curves by scaling
# each to its own peak magnitude. Then report their ratio at the end of training.
def _peak_normalize(series):
    """Scale ``series`` by its largest absolute value.

    The series is returned unchanged when that peak is 0 or NaN, which avoids a
    divide-by-zero for an all-zero loss column.
    """
    peak = series.abs().max()
    if peak > 0:
        return series / peak
    return series


if 'recon_loss' in training_logs.columns and 'sparsity_loss' in training_logs.columns:
    fig, ax = plt.subplots(figsize=(12, 5))

    recon_norm = _peak_normalize(training_logs['recon_loss'])
    sparse_norm = _peak_normalize(training_logs['sparsity_loss'])

    ax.plot(training_logs['step'], smooth(recon_norm), color='green', linewidth=2, label='Reconstruction (normalized)')
    ax.plot(training_logs['step'], smooth(sparse_norm), color='red', linewidth=2, label='Sparsity (normalized)')

    ax.set_xlabel('Training Step', fontsize=12)
    ax.set_ylabel('Normalized Loss', fontsize=12)
    ax.set_title('Reconstruction vs Sparsity Loss Balance', fontsize=14, pad=10)
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(ASSETS_DIR / f"loss_balance_{STAGE}_{SAE_TYPE}_{LAYER_QUANTILE}.png",
                dpi=300, bbox_inches='tight')
    plt.show()

    # End-of-training window (last ~10% of steps, at least one). It is
    # recomputed here so this cell does not depend on another cell having run.
    tail = training_logs.tail(max(1, len(training_logs) // 10))
    final_recon = tail['recon_loss'].mean()
    final_sparse = tail['sparsity_loss'].mean()
    if final_sparse > 0:
        ratio = final_recon / final_sparse
        print(f"\nFinal recon/sparsity ratio: {ratio:.2f}")
        print("  (>1 means reconstruction dominates, <1 means sparsity dominates)")

## Sparsity Evolution

In [None]:
if 'sparsity_l0' in training_logs.columns:
    l0 = training_logs['sparsity_l0']
    steps = training_logs['step']
    fig, (ax_time, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: L0 trace over training, with the target as a reference line.
    ax_time.plot(steps, l0, linewidth=2, color='purple')
    ax_time.axhline(TARGET_SPARSITY_L0, color='red', linestyle='--',
                    label=f'Target ({TARGET_SPARSITY_L0:.1%})', alpha=0.7)
    ax_time.set_xlabel('Step', fontsize=12)
    ax_time.set_ylabel('Sparsity (L0)', fontsize=12)
    ax_time.set_title('Sparsity Evolution', fontsize=14, pad=10)
    ax_time.legend(fontsize=11)
    ax_time.grid(alpha=0.3)

    # Right panel: histogram of L0 over *all* logged steps (early training included).
    ax_hist.hist(l0, bins=50, color='purple', alpha=0.7, edgecolor='black')
    l0_mean = l0.mean()
    ax_hist.axvline(l0_mean, color='blue', linestyle='--',
                    linewidth=2, label=f"Mean: {l0_mean:.4f}")
    ax_hist.axvline(TARGET_SPARSITY_L0, color='red', linestyle='--',
                    linewidth=2, label=f'Target: {TARGET_SPARSITY_L0}')
    ax_hist.set_xlabel('Sparsity (L0)', fontsize=12)
    ax_hist.set_ylabel('Frequency', fontsize=12)
    ax_hist.set_title('Sparsity Distribution', fontsize=14, pad=10)
    ax_hist.legend(fontsize=11)
    ax_hist.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(ASSETS_DIR / f"sae_sparsity_analysis_{STAGE}_{SAE_TYPE}_{LAYER_QUANTILE}.png",
                dpi=300, bbox_inches='tight')
    plt.show()

    # "Final" here is the single last logged step; the deviation is absolute.
    l0_last = l0.iloc[-1]
    print("\nSparsity Statistics:")
    print(f"  Final sparsity: {l0_last:.4f}")
    print(f"  Mean sparsity:  {l0_mean:.4f}")
    print(f"  Target:         {TARGET_SPARSITY_L0:.4f}")
    print(f"  Deviation:      {abs(l0_last - TARGET_SPARSITY_L0):.4f}")

## Interpretation

### Key Metrics:

1. **Total Loss:** Overall training objective
   - Should decrease and stabilize
   - Includes the sparsity penalty, so judge reconstruction quality from the reconstruction loss rather than the total

2. **Reconstruction Loss:** How well SAE reconstructs inputs
   - Lower is better
   - Trade-off with sparsity

3. **Sparsity Loss:** Penalty for non-sparse activations
   - Controls feature selectivity
   - Higher sparsity coefficient (lambda) = sparser features

4. **Sparsity (L0):** Fraction of active features per input
   - Target: ~5% (adjustable via `TARGET_SPARSITY_L0`)
   - Too low = information loss
   - Too high = less interpretable

### What to Look For:

- **Convergence:** Loss should stabilize, not oscillate
- **Balance:** Neither the reconstruction nor the sparsity term should dominate the objective (check the final recon/sparsity ratio)
- **L0 Target:** Should be close to target sparsity
- **No Divergence:** Loss shouldn't increase over training

### Next Steps:

1. If L0 too high: Increase sparsity lambda
2. If recon loss too high: Decrease sparsity lambda
3. If not converged: Train longer
4. If oscillating: Reduce learning rate