# Shesha Tutorial: Training Drift Monitoring

This notebook demonstrates how to use Shesha to monitor representation stability during model training.

**What you'll learn:**
- How to track geometric stability during training
- How to detect representation collapse or divergence
- How to balance stability vs task alignment

**Requirements:**
```bash
pip install shesha-geometry torch matplotlib
```

## 1. Setup

In [None]:
# Optional: Install dependencies
# !pip install shesha-geometry torch matplotlib

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import shesha

SEED = 320
np.random.seed(SEED)

print(f"Shesha version: {shesha.__version__}")

## 2. Simulate Training Dynamics

We'll simulate different training scenarios to show how Shesha captures representation changes.

In [None]:
def simulate_training(n_samples=500, n_features=256, n_epochs=20, scenario='healthy'):
    """
    Simulate representation evolution during training.

    Scenarios:
    - 'healthy': Gradual improvement in task alignment while maintaining stability
    - 'collapse': Representations collapse to low-dimensional manifold
    - 'divergence': Representations become increasingly noisy
    - 'overfit': Initial improvement followed by degradation
    """
    rng = np.random.default_rng(SEED)

    # Create labels (5 contiguous classes); also works when n_samples is not divisible by 5
    labels = np.sort(np.arange(n_samples) % 5)

    # Base signal: low-rank structure
    latent_dim = 30
    latent = rng.standard_normal((n_samples, latent_dim))
    projection = rng.standard_normal((latent_dim, n_features))
    base_signal = latent @ projection
    base_signal /= np.std(base_signal)

    # Class centroids
    class_directions = rng.standard_normal((5, n_features))
    class_directions /= np.linalg.norm(class_directions, axis=1, keepdims=True)

    embeddings_history = []

    for epoch in range(n_epochs):
        t = epoch / n_epochs

        if scenario == 'healthy':
            # Gradually add class signal while keeping structure
            class_signal = class_directions[labels] * t * 2
            noise = rng.standard_normal((n_samples, n_features)) * 0.3
            embeddings = base_signal * (1 - 0.3*t) + class_signal + noise

        elif scenario == 'collapse':
            # Collapse to fewer dimensions over time
            effective_dims = max(5, int(n_features * (1 - t * 0.9)))
            embeddings = base_signal.copy()
            embeddings[:, effective_dims:] *= (1 - t)

        elif scenario == 'divergence':
            # Increasing noise
            noise_scale = 0.1 + t * 3
            noise = rng.standard_normal((n_samples, n_features)) * noise_scale
            embeddings = base_signal + noise

        elif scenario == 'overfit':
            # Good until midpoint, then degrade
            if t < 0.5:
                class_signal = class_directions[labels] * t * 4
                noise = rng.standard_normal((n_samples, n_features)) * 0.2
            else:
                # Overfit: class signal plateaus while noise grows, eroding structure
                class_signal = class_directions[labels] * 2
                noise = rng.standard_normal((n_samples, n_features)) * (t - 0.3) * 2
            embeddings = base_signal * (1 - t * 0.5) + class_signal + noise

        else:
            raise ValueError(f"Unknown scenario: {scenario!r}")

        embeddings_history.append(embeddings)

    return embeddings_history, labels

## 3. Monitor Training with Shesha

In [None]:
def monitor_training(embeddings_history, labels):
    """Compute Shesha metrics at each epoch."""
    metrics = {'stability': [], 'alignment': [], 'var_ratio': []}

    for embeddings in embeddings_history:
        metrics['stability'].append(shesha.feature_split(embeddings, seed=SEED))
        metrics['alignment'].append(shesha.supervised_alignment(embeddings, labels, seed=SEED))
        metrics['var_ratio'].append(shesha.variance_ratio(embeddings, labels))

    return metrics

# Run all scenarios
scenarios = ['healthy', 'collapse', 'divergence', 'overfit']
all_metrics = {}

for scenario in scenarios:
    print(f"Simulating {scenario} training...")
    history, labels = simulate_training(scenario=scenario)
    all_metrics[scenario] = monitor_training(history, labels)

print("Done!")
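Before plotting, it can help to reduce each metric trajectory to a few numbers. The sketch below is a plain-Python helper that reports the first value, the last value, the overall change, the minimum, and the epoch where the minimum occurs. The name `summarize_trajectory` is hypothetical and not part of Shesha; the helper only assumes that each metric is a list of floats with one entry per epoch, as in `all_metrics` above. The numbers in the example are made up for illustration.

```python
def summarize_trajectory(values):
    """Summarize one per-epoch metric list (hypothetical helper, not a Shesha API)."""
    if not values:
        raise ValueError("values must contain at least one epoch")
    # Index of the smallest value; ties resolve to the earliest epoch
    min_epoch = min(range(len(values)), key=lambda i: values[i])
    return {
        "first": values[0],
        "last": values[-1],
        "change": values[-1] - values[0],
        "min": values[min_epoch],
        "min_epoch": min_epoch,
    }


# Illustrative values only, not output from Shesha
example = [0.62, 0.60, 0.41, 0.45, 0.58]
summary = summarize_trajectory(example)
print(summary)
```

In this notebook, `summarize_trajectory(all_metrics['collapse']['stability'])` would produce the same kind of summary for the collapse scenario.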

## 4. Visualize Training Dynamics

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
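# Derive the epoch axis from the recorded metrics rather than hard-coding 20
epochs = range(len(all_metrics[scenarios[0]]['stability']))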

titles = {
    'healthy': 'Healthy Training',
    'collapse': 'Representation Collapse',
    'divergence': 'Divergence (Noisy)',
    'overfit': 'Overfitting'
}

for ax, scenario in zip(axes.flat, scenarios):
    m = all_metrics[scenario]

    ax.plot(epochs, m['stability'], 'b-o', label='Stability', markersize=4)
    ax.plot(epochs, m['alignment'], 'r-s', label='Alignment', markersize=4)
    ax.plot(epochs, m['var_ratio'], 'g-^', label='Var Ratio', markersize=4)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Score')
    ax.set_title(titles[scenario])
    ax.legend(loc='best', fontsize='small')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.1, 1.0)

plt.tight_layout()
plt.savefig('training_dynamics.png', dpi=150)
plt.show()

## 5. Interpretation

| Scenario | Stability | Alignment | Diagnosis |
|----------|-----------|-----------|----------|
| **Healthy** | Stable/slight decrease | Increases | Normal - learning task while maintaining structure |
| **Collapse** | Drops sharply | Variable | Bad - representations losing information |
| **Divergence** | Drops to ~0 | Drops to ~0 | Bad - signal overwhelmed by noise |
| **Overfit** | Drops late | Peaks then drops | Warning - memorizing rather than generalizing |
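The table can be turned into a rough rule-based check. The following is a minimal sketch, assuming per-epoch lists of stability and alignment scores. The function name `diagnose` and the thresholds `near_zero`, `sharp_drop`, and `peak_margin` are illustrative assumptions, not Shesha defaults, and would need tuning for a real model.

```python
def diagnose(stability, alignment, near_zero=0.1, sharp_drop=0.3, peak_margin=0.1):
    """Map two per-epoch metric lists to a row of the table (heuristic sketch).

    All thresholds are illustrative assumptions, not values from Shesha.
    """
    if len(stability) < 2 or len(stability) != len(alignment):
        raise ValueError("need two equal-length lists with at least 2 epochs")

    stability_drop = stability[0] - stability[-1]
    peak_epoch = max(range(len(alignment)), key=lambda i: alignment[i])
    # Alignment peaked before the final epoch and has since fallen noticeably
    fell_after_peak = (
        peak_epoch < len(alignment) - 1
        and alignment[peak_epoch] - alignment[-1] > peak_margin
    )

    if stability[-1] < near_zero and alignment[-1] < near_zero:
        return "divergence"  # both scores end near zero
    if fell_after_peak:
        return "overfit"     # alignment peaks, then drops
    if stability_drop > sharp_drop:
        return "collapse"    # stability drops sharply
    if alignment[-1] > alignment[0]:
        return "healthy"     # alignment rises, stability roughly holds
    return "unclear"


# Illustrative values only, not output from Shesha
healthy = diagnose([0.60, 0.58, 0.55], [0.10, 0.30, 0.50])
overfit = diagnose([0.60, 0.60, 0.40], [0.20, 0.60, 0.30])
print(healthy, overfit)  # healthy overfit
```

The rules are checked in a fixed order, so a run that matches several patterns is reported under the first one that applies. Treat the result as a prompt to look at the curves, not as a verdict.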

## 6. Early Stopping with Shesha

In [None]:
def should_stop_early(metrics, patience=3, stability_threshold=0.2):
    """
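    Simple early stopping based on stability.

    Return True if stability has stayed below `stability_threshold` for the
    last `patience` consecutive epochs; return False otherwise, including
    when fewer than `patience` epochs have been recorded.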
    """
    recent = metrics['stability'][-patience:]
    if len(recent) < patience:
        return False
    return all(s < stability_threshold for s in recent)

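# Demo: replay each run epoch by epoch, as if training were still in progress
print("Early stopping analysis:")
for scenario in scenarios:
    m = all_metrics[scenario]
    n_epochs = len(m['stability'])
    # Start after 5 epochs and include the final epoch (the original range stopped one short)
    for n_seen in range(5, n_epochs + 1):
        # Only epochs 0 .. n_seen - 1 are visible at this point
        partial = {k: v[:n_seen] for k, v in m.items()}
        if should_stop_early(partial):
            print(f"  {scenario}: Stop at epoch {n_seen - 1}")
            break
    else:
        print(f"  {scenario}: No early stop triggered")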
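Stopping on stability alone ignores task alignment. One way to balance the two is to pick, after the fact, the epoch with the highest alignment among the epochs that are still stable enough. The sketch below assumes per-epoch lists of both scores. The name `select_checkpoint` is hypothetical, and `min_stability=0.2` simply mirrors the default threshold used in the early stopping function above; it is not a Shesha recommendation.

```python
def select_checkpoint(stability, alignment, min_stability=0.2):
    """Return the epoch with the highest alignment among sufficiently stable epochs.

    Hypothetical helper. Returns None when no epoch meets the stability floor.
    """
    if len(stability) != len(alignment):
        raise ValueError("stability and alignment must have the same length")
    # Keep only epochs whose stability is at or above the floor
    candidates = [i for i, s in enumerate(stability) if s >= min_stability]
    if not candidates:
        return None
    # Ties resolve to the earliest epoch
    return max(candidates, key=lambda i: alignment[i])


# Illustrative values only, not output from Shesha
stability = [0.55, 0.50, 0.45, 0.15, 0.10]
alignment = [0.10, 0.30, 0.40, 0.50, 0.45]
best = select_checkpoint(stability, alignment)
print(best)  # 2
```

In the example, epoch 3 has the highest alignment but falls below the stability floor, so epoch 2 is chosen instead.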

## 7. Quick Reference

```python
import shesha
import wandb  # assumed logger; any experiment tracker works

# train_one_epoch and get_embeddings are placeholders for your own code
for epoch in range(n_epochs):
    train_one_epoch(model, data)

    # Extract embeddings from the validation set
    embeddings = get_embeddings(model, val_data)

    # Monitor stability
    stability = shesha.feature_split(embeddings, seed=320)

    # Log metrics
    wandb.log({'stability': stability, 'epoch': epoch})

    # Red-flag check: warn when stability falls very low
    if stability < 0.1:
        print("Warning: possible representation collapse or noise")
```

**Red flags:**
- Stability < 0.1: Possible collapse or noise
- Sudden stability drop: Check for bugs or bad hyperparameters
- Alignment drops while training loss keeps decreasing: Likely overfitting
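The two stability red flags can be checked automatically over a logged history. This is a minimal sketch, assuming a list of per-epoch stability scores. The name `find_red_flags` is hypothetical; `floor=0.1` comes from the list above, while `max_drop=0.2` is an assumed definition of a "sudden" drop between consecutive epochs. The alignment-versus-loss flag is not covered because it also needs the training loss.

```python
def find_red_flags(stability, floor=0.1, max_drop=0.2):
    """Return (epoch, message) pairs for the stability red flags (heuristic sketch).

    max_drop is an assumed threshold for a sudden epoch-to-epoch drop.
    """
    flags = []
    for epoch, value in enumerate(stability):
        if value < floor:
            flags.append((epoch, "stability below floor"))
        # Compare against the previous epoch to catch abrupt changes
        if epoch > 0 and stability[epoch - 1] - value > max_drop:
            flags.append((epoch, "sudden stability drop"))
    return flags


# Illustrative values only, not output from Shesha
flags = find_red_flags([0.60, 0.58, 0.30, 0.05])
for epoch, message in flags:
    print(f"epoch {epoch}: {message}")
```

A flagged epoch is a reason to inspect the run (data pipeline, learning rate, recent code changes), not proof that something is wrong.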