# SAE Training Loss Visualization

This notebook visualizes training dynamics for the Gated Sparse Autoencoder (gSAE).

**Purpose:**
- Monitor training progress and convergence
- Compare loss components (reconstruction, sparsity)
- Verify the target sparsity is achieved (L0 ≈ 0.05, i.e. ~5% of features active)
- Identify training issues or anomalies

**Input Data:**
- Training logs: CSV with columns `step` and `total_loss` (required), plus optional `recon_loss`, `sparsity_loss`, and `sparsity_l0`

**Output:**
- Multi-panel loss curves
- Sparsity evolution
- Training statistics

In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Repository root, assuming the kernel's working directory is this notebook's
# folder two levels below the root (e.g. <root>/notebooks/visualizations).
# NOTE(review): this breaks if the kernel is started from another cwd — confirm.
PROJECT_ROOT = Path(os.getcwd()).parent.parent

# Make the project's `src` package importable. The membership guard keeps this
# cell idempotent: re-running it does not append duplicate sys.path entries.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from src.visualization import (
    setup_korean_font,
    load_training_logs,
    plot_training_loss
)

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Setup.
# Order matters: `sns.set_style` also writes `font.family` / `font.sans-serif`
# into matplotlib's rcParams, so it must run BEFORE the Korean font setup.
# Running it afterwards could replace the Korean font and make the Korean plot
# titles in this notebook render as missing-glyph boxes.
# NOTE(review): assumes setup_korean_font configures the font via rcParams — confirm.
sns.set_style('whitegrid')
setup_korean_font()

In [None]:
# Configuration: where inputs are read from, where outputs are written, and
# which training stage to visualize.
RESULTS_DIR = PROJECT_ROOT / "results"  # forwarded to load_training_logs
ASSETS_DIR = PROJECT_ROOT.joinpath("notebooks", "visualizations", "assets")  # saved figures / CSVs
ASSETS_DIR.mkdir(parents=True, exist_ok=True)  # created on first run, no-op afterwards

STAGE = "mock"  # forwarded to load_training_logs and used as the suffix of every saved file
print(f"Stage: {STAGE}")

## Load Training Logs

In [None]:
# Load logs for the configured stage.
logs = load_training_logs(RESULTS_DIR, stage=STAGE)

# Fail fast with a clear message: later cells index `step` and `total_loss`
# unconditionally and read the first/last logged row.
missing_columns = {'step', 'total_loss'} - set(logs.columns)
assert not missing_columns, f"Training logs are missing required columns: {sorted(missing_columns)}"
assert len(logs) > 0, "Training logs are empty"

print(f"Training logs shape: {logs.shape}")
# NOTE: this counts logged rows; it equals the number of steps only if every step is logged.
print(f"Total steps: {len(logs)}")
print(f"\nColumns: {list(logs.columns)}")
print(f"\nFirst few rows:")
print(logs.head())

## Plot Training Loss Curves

In [None]:
# Multi-panel loss overview from the project plotting helper. The helper
# presumably also writes the figure to `save_path` (not verifiable from here).
loss_plot_path = ASSETS_DIR / f"sae_training_loss_{STAGE}.png"
fig = plot_training_loss(training_logs=logs, save_path=loss_plot_path, figsize=(14, 8))

plt.show()

## Detailed Loss Analysis

In [None]:
# Smoothed curves.
# Rolling-mean window: ~10% of the logged rows, capped at 50 and clamped to at
# least 1 (with fewer than 10 rows, `len(logs) // 10` alone is 0, which yields
# no usable smoothed line).
window = max(1, min(50, len(logs) // 10))


def _smooth(series: pd.Series) -> pd.Series:
    """Return the centered rolling mean of `series` over `window` rows.

    `min_periods=1` keeps the first and last ~window/2 points, which a centered
    window would otherwise leave as NaN (hiding the early-training drop).
    """
    return series.rolling(window=window, center=True, min_periods=1).mean()


fig, axes = plt.subplots(2, 1, figsize=(14, 10))

# Total loss (original + smoothed)
ax = axes[0]
ax.plot(logs['step'], logs['total_loss'], alpha=0.3, color='blue', label='Original')
ax.plot(logs['step'], _smooth(logs['total_loss']),
        linewidth=2, color='blue', label=f'Smoothed (window={window})')
ax.set_xlabel('Step', fontsize=12)
ax.set_ylabel('Total Loss', fontsize=12)
ax.set_title('전체 손실 (평활화)\nTotal Loss (Smoothed)', fontsize=14, pad=10)
ax.legend(fontsize=11)
ax.grid(alpha=0.3)

# Loss components (smoothed); each component column is optional in the logs.
ax = axes[1]
loss_components = [
    ('recon_loss', 'green', 'Reconstruction Loss'),
    ('sparsity_loss', 'red', 'Sparsity Loss'),
]
for column, color, label in loss_components:
    if column in logs.columns:
        ax.plot(logs['step'], _smooth(logs[column]), linewidth=2, color=color, label=label)
ax.set_xlabel('Step', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('손실 구성 요소\nLoss Components', fontsize=14, pad=10)
# Only draw a legend if something was plotted (avoids matplotlib's
# "no artists with labels" warning when neither component column exists).
if ax.get_legend_handles_labels()[0]:
    ax.legend(fontsize=11)
ax.grid(alpha=0.3)

fig.tight_layout()
fig.savefig(ASSETS_DIR / f"sae_training_detailed_{STAGE}.png", dpi=300, bbox_inches='tight')
plt.show()

## Sparsity Evolution

In [None]:
if 'sparsity_l0' in logs.columns:
    target_sparsity = 0.05  # expected L0: ~5% of features active
    sparsity_l0 = logs['sparsity_l0']
    mean_l0 = sparsity_l0.mean()
    final_l0 = sparsity_l0.iloc[-1]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Sparsity over time
    ax = axes[0]
    ax.plot(logs['step'], sparsity_l0, linewidth=2, color='purple')
    ax.axhline(target_sparsity, color='red', linestyle='--',
               label=f'Target ({target_sparsity:.0%})', alpha=0.7)
    ax.set_xlabel('Step', fontsize=12)
    ax.set_ylabel('Sparsity (L0)', fontsize=12)
    ax.set_title('희소성 진화\nSparsity Evolution', fontsize=14, pad=10)
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)

    # Distribution of sparsity values across all logged rows
    ax = axes[1]
    ax.hist(sparsity_l0, bins=50, color='purple', alpha=0.7, edgecolor='black')
    ax.axvline(mean_l0, color='blue', linestyle='--',
               linewidth=2, label=f"Mean: {mean_l0:.3f}")
    ax.axvline(target_sparsity, color='red', linestyle='--', linewidth=2,
               label=f'Target: {target_sparsity}')
    ax.set_xlabel('Sparsity (L0)', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('희소성 분포\nSparsity Distribution', fontsize=14, pad=10)
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(ASSETS_DIR / f"sae_sparsity_analysis_{STAGE}.png", dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nSparsity Statistics:")
    print(f"  Final sparsity: {final_l0:.4f}")
    print(f"  Mean sparsity:  {mean_l0:.4f}")
    print(f"  Target:         {target_sparsity:.4f}")
    print(f"  Deviation:      {abs(final_l0 - target_sparsity):.4f}")
else:
    # Make the skip visible instead of silently producing no output.
    print("Column 'sparsity_l0' not found in training logs; skipping sparsity analysis.")

## Training Statistics

In [None]:
# Compute statistics
def _column_stats(series: pd.Series) -> dict:
    """Summary statistics for one logged metric, in first-to-last-row order.

    `Change (%)` is relative to the first logged value. It is NaN when that
    value is 0 (e.g. a loss term that starts at zero) rather than +/-inf with a
    division-by-zero warning.
    """
    initial = series.iloc[0]
    final = series.iloc[-1]
    change = final - initial
    return {
        'Initial': initial,
        'Final': final,
        'Min': series.min(),
        'Max': series.max(),
        'Mean': series.mean(),
        'Std': series.std(),
        'Change': change,
        'Change (%)': change / initial * 100 if initial != 0 else np.nan,
    }


stats = {
    col: _column_stats(logs[col])
    for col in ['total_loss', 'recon_loss', 'sparsity_loss', 'sparsity_l0']
    if col in logs.columns
}

df_stats = pd.DataFrame(stats).T

print("\nTraining Statistics:")
print("=" * 100)
print(df_stats.to_string())

# Save to CSV
stats_path = ASSETS_DIR / f"training_statistics_{STAGE}.csv"
df_stats.to_csv(stats_path)
print(f"\nSaved to {stats_path}")

## Convergence Analysis

In [None]:
# Check convergence based on recent stability.
# Recent window: ~20% of the logged rows, capped at 100. Use at least 2 rows
# (the std of fewer values is NaN) but never more than the log contains; with
# fewer than 5 rows, `len(logs) // 5` alone would give an empty window.
window_size = min(len(logs), max(2, min(100, len(logs) // 5)))
recent_window = logs.tail(window_size)

cv_threshold = 5.0  # coefficient of variation (%) below which a loss counts as converged

print(f"\nConvergence Analysis (last {window_size} steps):")
print("=" * 80)

for col in ['total_loss', 'recon_loss', 'sparsity_loss']:
    if col in logs.columns:
        mean = recent_window[col].mean()
        std = recent_window[col].std()
        # Coefficient of variation in %; NaN (undefined) when the mean is 0
        # instead of inf / a division-by-zero warning.
        cv = (std / abs(mean)) * 100 if mean != 0 else np.nan

        converged = cv < cv_threshold  # NaN compares False -> reported as not converged
        status = "✓ Converged" if converged else "⚠ Not converged"

        print(f"{col:20s}: Mean={mean:.4f}, Std={std:.4f}, CV={cv:.2f}% {status}")

# Sparsity convergence
if 'sparsity_l0' in logs.columns:
    final_sparsity = logs['sparsity_l0'].iloc[-1]
    target_sparsity = 0.05
    deviation = abs(final_sparsity - target_sparsity)

    # Absolute tolerance of 0.01 on the L0 fraction, i.e. within 1 percentage
    # point of the 5% target (this is NOT a 1% relative tolerance).
    within_target = deviation < 0.01
    status = "✓ Target achieved" if within_target else "⚠ Target not achieved"

    print(f"\nSparsity Target:")
    print(f"  Final:  {final_sparsity:.4f}")
    print(f"  Target: {target_sparsity:.4f}")
    print(f"  Status: {status}")

## Loss Rate Analysis

In [None]:
# Compute loss gradient (rate of change)
# np.gradient needs at least 2 points.
if len(logs) > 1:
    fig, ax = plt.subplots(figsize=(14, 5))

    loss_values = logs['total_loss'].to_numpy(dtype=float)
    step_values = logs['step'].to_numpy(dtype=float)
    # Differentiate with respect to the logged step so the result is "loss change
    # per step" even when rows are logged every N steps. Fall back to unit
    # (per-row) spacing if steps are not strictly increasing, which would
    # otherwise divide by zero.
    if np.all(np.diff(step_values) > 0):
        gradient = np.gradient(loss_values, step_values)
    else:
        gradient = np.gradient(loss_values)

    # Smoothing window computed here (~10% of rows, capped at 50, at least 1)
    # so this cell does not rely on a variable left over from another cell.
    grad_window = max(1, min(50, len(logs) // 10))
    smoothed_gradient = (
        pd.Series(gradient)
        .rolling(window=grad_window, center=True, min_periods=1)
        .mean()
    )

    ax.plot(logs['step'], smoothed_gradient, linewidth=2, color='orange')
    ax.axhline(0, color='black', linestyle='--', linewidth=1, alpha=0.5)
    ax.set_xlabel('Step', fontsize=12)
    ax.set_ylabel('Loss Gradient (rate of change)', fontsize=12)
    ax.set_title('손실 변화율\nLoss Rate of Change', fontsize=14, pad=10)
    ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(ASSETS_DIR / f"loss_gradient_{STAGE}.png", dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nLoss Gradient Statistics:")
    print(f"  Mean:   {np.mean(gradient):.6f}")
    print(f"  Median: {np.median(gradient):.6f}")
    print(f"  Std:    {np.std(gradient):.6f}")

## Interpretation

### What to Look For:

1. **Total Loss:**
   - Should trend downward overall (minibatch noise means it will not be strictly monotonic)
   - Converge to stable minimum
   - No sudden spikes or divergence

2. **Reconstruction Loss:**
   - Measures how well SAE reconstructs inputs
   - Should stabilize at low value
   - Balance with sparsity constraint

3. **Sparsity Loss:**
   - Encourages sparse activations
   - Should decrease as model learns sparsity
   - Target: ~5% active features (L0 ≈ 0.05)

4. **Sparsity (L0):**
   - Should converge to ~0.05 (5% active)
   - Too high: Features not sparse enough
   - Too low: May lose reconstruction quality

### Training Quality Indicators:

- ✓ **Good:** Smooth decrease, stable convergence, sparsity on target
- ⚠ **Warning:** High variance, slow convergence, sparsity off target
- ✗ **Bad:** Loss increase, divergence, training instability

### Next Steps:

1. If not converged: Continue training or adjust hyperparameters
2. If sparsity is off target: increase the sparsity coefficient if L0 is too high, or decrease it if L0 is too low
3. If reconstruction is poor: reduce the sparsity coefficient (accepting a somewhat higher L0)
4. If training good: Proceed to IG² computation and bias analysis