# Adaptive learning rate schedulers for improved training

The original version of this notebook is available at [docs/advanced_tutorials/22_adaptive_learning_rates.ipynb](https://github.com/sbi-dev/sbi/blob/main/docs/advanced_tutorials/22_adaptive_learning_rates.ipynb) in the `sbi` repository.

Learning rate schedulers can significantly improve neural network training by automatically adjusting the learning rate during training. This tutorial demonstrates how to use PyTorch's learning rate schedulers with all `sbi` inference methods (NPE, NLE, and NRE).

We'll show how different schedulers affect training dynamics and final performance using a simple but illustrative example.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions import MultivariateNormal

from sbi.analysis import pairplot
from sbi.inference import NLE_A, NPE, NRE_A
from sbi.simulators.linear_gaussian import linear_gaussian

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Setup: Linear Gaussian model

We'll use a simple 2D linear Gaussian model where we can easily visualize the results:

In [None]:
# Define prior and simulator
num_dim = 2
prior = MultivariateNormal(loc=torch.zeros(num_dim), covariance_matrix=torch.eye(num_dim))

# Simulator parameters
likelihood_shift = -0.5 * torch.ones(num_dim)
likelihood_cov = 0.3 * torch.eye(num_dim)

def simulator(theta):
    return linear_gaussian(theta, likelihood_shift, likelihood_cov)

# Generate training data
num_simulations = 1000
theta = prior.sample((num_simulations,))
x = simulator(theta)

print(f"Generated {num_simulations} training samples")
print(f"Parameter space dimensionality: {theta.shape[1]}")
print(f"Data space dimensionality: {x.shape[1]}")

## Comparing different learning rate schedulers

Let's train the same NPE model with different schedulers and compare their training dynamics:

In [None]:
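def train_with_scheduler(scheduler_name, scheduler_kwargs, learning_rate=1e-3, max_epochs=100):
    """Train NPE with a specific scheduler and return the training history."""

    inference = NPE(prior=prior, show_progress_bars=False)
    inference.append_simulations(theta, x)

    # Train with scheduler; `train` returns the trained density estimator,
    # the posterior itself is built later via `build_posterior()`
    density_estimator = inference.train(
        lr_scheduler=scheduler_name,
        lr_scheduler_kwargs=scheduler_kwargs,
        learning_rate=learning_rate,
        max_num_epochs=max_epochs,
        validation_fraction=0.1,
        training_batch_size=50
    )

    # Extract training history. `_summary` is a private attribute, so its
    # keys may differ between sbi versions.
    summary = inference._summary
    train_losses = summary['train_log_probs']
    return {
        # One x-value per recorded epoch, so the curves can be plotted even
        # if `epochs_trained` only stores the final epoch count.
        'epochs': list(range(1, len(train_losses) + 1)),
        'train_losses': train_losses,
        'val_losses': summary['validation_log_probs'],
        'learning_rates': summary['learning_rates'],
        'density_estimator': density_estimator,
        'inference': inference
    }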

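# Define scheduler configurations to test: label -> (shortcut, scheduler kwargs)
scheduler_configs = {
    'No Scheduler': (None, {}),
    # Halve the LR after 10 epochs without validation improvement
    'ReduceLROnPlateau': ('plateau', {'factor': 0.5, 'patience': 10}),
    # Multiply the LR by 0.95 every epoch
    'ExponentialLR': ('exponential', {'gamma': 0.95}),
    # Anneal to eta_min over T_max epochs; the LR rises again if training continues
    'CosineAnnealingLR': ('cosine', {'T_max': 50, 'eta_min': 1e-6}),
    # Multiply the LR by 0.3 every 25 epochs
    'StepLR': ('step', {'step_size': 25, 'gamma': 0.3})
}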

# Train with each scheduler
results = {}
for name, (scheduler, kwargs) in scheduler_configs.items():
    print(f"Training with {name}...")
    results[name] = train_with_scheduler(scheduler, kwargs, learning_rate=5e-3, max_epochs=150)
    final_epoch = results[name]['epochs'][-1]
    final_val_loss = results[name]['val_losses'][-1]
    print(f"  Converged after {final_epoch} epochs, final validation loss: {final_val_loss:.3f}")
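
For the three schedulers with a fixed (non-adaptive) schedule, the learning rate at each epoch follows a simple closed form. The sketch below reproduces these formulas in plain Python for the configurations used above. It assumes that the string shortcuts map to PyTorch's `ExponentialLR`, `StepLR`, and `CosineAnnealingLR`, as the labels suggest, and that the scheduler is stepped once per epoch. The helper names are illustrative and not part of `sbi`.

```python
import math


def exponential_lr(lr0, epoch, gamma):
    """Learning rate after `epoch` steps of exponential decay."""
    return lr0 * gamma ** epoch


def step_lr(lr0, epoch, step_size, gamma):
    """Learning rate that is multiplied by `gamma` every `step_size` epochs."""
    return lr0 * gamma ** (epoch // step_size)


def cosine_lr(lr0, epoch, T_max, eta_min):
    """Closed-form cosine annealing from `lr0` down to `eta_min` at `T_max`."""
    return eta_min + 0.5 * (lr0 - eta_min) * (1 + math.cos(math.pi * epoch / T_max))


lr0 = 5e-3  # initial learning rate used in the comparison above
for epoch in (0, 25, 50, 100):
    print(
        f"epoch {epoch:3d}: "
        f"exponential={exponential_lr(lr0, epoch, gamma=0.95):.2e}, "
        f"step={step_lr(lr0, epoch, step_size=25, gamma=0.3):.2e}, "
        f"cosine={cosine_lr(lr0, epoch, T_max=50, eta_min=1e-6):.2e}"
    )
```

The cosine closed form is periodic: with `T_max=50` it reaches `eta_min` at epoch 50 and returns to the initial rate by epoch 100, which matters when `max_num_epochs` exceeds `T_max`.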

## Visualizing training dynamics

Let's compare how different schedulers affect training:

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

colors = plt.cm.tab10(np.linspace(0, 1, len(results)))

# Plot training loss
for i, (name, result) in enumerate(results.items()):
    epochs = result['epochs']
    train_losses = result['train_losses']
    axes[0].plot(epochs, train_losses, label=name, color=colors[i], alpha=0.8)

axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss')
axes[0].set_title('Training Loss vs Epoch')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot validation loss
for i, (name, result) in enumerate(results.items()):
    epochs = result['epochs']
    val_losses = result['val_losses']
    axes[1].plot(epochs, val_losses, label=name, color=colors[i], alpha=0.8)

axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Validation Loss')
axes[1].set_title('Validation Loss vs Epoch')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Plot learning rate schedules
for i, (name, result) in enumerate(results.items()):
    if name != 'No Scheduler':  # Skip constant LR
        epochs = result['epochs']
        lrs = result['learning_rates']
        axes[2].plot(epochs, lrs, label=name, color=colors[i], alpha=0.8)

axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning Rate')
axes[2].set_title('Learning Rate Schedule')
axes[2].set_yscale('log')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

# Training efficiency: epochs to convergence
names = list(results.keys())
epochs_to_converge = [results[name]['epochs'][-1] for name in names]
final_val_losses = [results[name]['val_losses'][-1] for name in names]

bars = axes[3].bar(range(len(names)), epochs_to_converge, color=colors[:len(names)], alpha=0.7)
axes[3].set_xlabel('Scheduler Type')
axes[3].set_ylabel('Epochs to Convergence')
axes[3].set_title('Training Efficiency')
axes[3].set_xticks(range(len(names)))
axes[3].set_xticklabels(names, rotation=45, ha='right')

# Add final validation loss as text on bars
for bar, val_loss in zip(bars, final_val_losses, strict=False):
    height = bar.get_height()
    axes[3].text(bar.get_x() + bar.get_width()/2., height + 1,
                f'Val: {val_loss:.3f}', ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.show()

## Examining posterior quality

Let's compare the posterior quality for different schedulers by sampling from the trained models:

In [None]:
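# Set observation point
x_o = torch.tensor([[0.5, -0.2]])

# One panel per scheduler
fig, axes = plt.subplots(1, len(results), figsize=(4*len(results), 4))
if len(results) == 1:
    axes = [axes]

for i, (name, result) in enumerate(results.items()):
    # Build posterior and sample
    posterior = result['inference'].build_posterior()
    samples = posterior.sample((1000,), x=x_o).numpy()

    # Scatter the 2D samples directly into this panel; `pairplot` lays out
    # its own grid of axes, so it is not used for the side-by-side view.
    axes[i].scatter(samples[:, 0], samples[:, 1], s=3, alpha=0.3)
    # Mark the observation in red, as a visual reference point
    axes[i].scatter(x_o[0, 0].item(), x_o[0, 1].item(), color='red', s=30)
    axes[i].set_xlim(-3, 3)
    axes[i].set_ylim(-3, 3)
    axes[i].set_xlabel('theta[0]')
    axes[i].set_ylabel('theta[1]')
    axes[i].set_title(f'{name}\n(Final val loss: {result["val_losses"][-1]:.3f})')

plt.tight_layout()
plt.show()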
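
Because the prior and the likelihood are both Gaussian, the exact posterior is available in closed form and can serve as a reference for the samples above. The sketch below assumes that `linear_gaussian` generates `x = theta + likelihood_shift + noise` with noise covariance `likelihood_cov`. All covariances here are isotropic, so each dimension can be treated independently. The function name is illustrative.

```python
def analytic_posterior_1d(x_obs, shift, lik_var, prior_mean=0.0, prior_var=1.0):
    """Posterior mean and variance of theta for x = theta + shift + noise."""
    # Precisions (inverse variances) of independent Gaussian factors add up.
    post_var = 1.0 / (1.0 / prior_var + 1.0 / lik_var)
    # The posterior mean is the precision-weighted average of the prior mean
    # and the shift-corrected observation.
    post_mean = post_var * (prior_mean / prior_var + (x_obs - shift) / lik_var)
    return post_mean, post_var


x_o_values = [0.5, -0.2]  # observation used above
for dim, x_obs in enumerate(x_o_values):
    mean, var = analytic_posterior_1d(x_obs, shift=-0.5, lik_var=0.3)
    print(f"theta[{dim}]: posterior mean = {mean:.3f}, variance = {var:.3f}")
```

If a scheduler yields a well-trained estimator, the sample mean and spread in each panel should be close to these reference values.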

## Advanced scheduler configuration

### Using dictionary configuration for complex schedulers

In [None]:
# Complex scheduler configuration using dictionary
advanced_scheduler_config = {
    "type": "plateau",
    "factor": 0.3,        # Reduce LR by 70%
    "patience": 8,        # Wait 8 epochs
    "threshold": 1e-4,    # Minimum improvement threshold
    "cooldown": 5,        # Wait 5 epochs after reduction
    "min_lr": 1e-6,       # Don't reduce below this
}

inference_advanced = NPE(prior=prior, show_progress_bars=False)
inference_advanced.append_simulations(theta, x)

posterior_advanced = inference_advanced.train(
    lr_scheduler=advanced_scheduler_config,
    learning_rate=1e-2,    # Higher initial LR
    max_num_epochs=200,
    validation_fraction=0.1,
    min_lr_threshold=2e-6  # Stop if LR gets too small
)

print(f"Advanced scheduler training completed after {inference_advanced._summary['epochs_trained'][-1]} epochs")
print(f"Final learning rate: {inference_advanced._summary['learning_rates'][-1]:.2e}")
print(f"Final validation loss: {inference_advanced._summary['validation_log_probs'][-1]:.4f}")
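
To make the interplay of `factor`, `patience`, `cooldown`, and `min_lr` concrete, here is a simplified pure-Python sketch of a reduce-on-plateau rule. It is not PyTorch's or `sbi`'s implementation: the class name is hypothetical, the loss is assumed to be minimized, and `threshold` is treated as an absolute margin, whereas PyTorch's `ReduceLROnPlateau` uses a relative threshold by default.

```python
class PlateauSketch:
    """Simplified reduce-on-plateau rule for a loss that should decrease."""

    def __init__(self, lr, factor, patience, threshold=1e-4, cooldown=0, min_lr=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.best = float("inf")
        self.num_bad_epochs = 0
        self.cooldown_counter = 0

    def step(self, val_loss):
        """Update and return the learning rate given this epoch's validation loss."""
        # An epoch counts as an improvement only if it beats the best loss
        # by more than `threshold` (an absolute margin in this sketch).
        if val_loss < self.best - self.threshold:
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        # During cooldown, bad epochs are not counted.
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0

        # Reduce once more than `patience` epochs passed without improvement.
        if self.num_bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
        return self.lr


scheduler = PlateauSketch(lr=1e-2, factor=0.3, patience=8, cooldown=5, min_lr=1e-6)
stalled_losses = [1.0] * 30  # a validation loss that never improves
lrs_sketch = [scheduler.step(loss) for loss in stalled_losses]
print(f"LR before first reduction: {lrs_sketch[8]:.1e}, after: {lrs_sketch[9]:.1e}")
```

With a stalled loss, the first reduction happens once `patience` bad epochs have been exceeded, and the cooldown delays the next one. Each reduction multiplies the rate by `factor` until `min_lr` is reached.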

### Learning rate scheduling with early stopping

In [None]:
# Demonstrate early stopping based on minimum learning rate
inference_early_stop = NPE(prior=prior, show_progress_bars=False)
inference_early_stop.append_simulations(theta, x)

posterior_early_stop = inference_early_stop.train(
    lr_scheduler="exponential",
    lr_scheduler_kwargs={"gamma": 0.85},  # Fast decay
    learning_rate=5e-3,
    max_num_epochs=300,                   # High max epochs
    min_lr_threshold=1e-4,                # Stop when LR < 1e-4
    stop_after_epochs=100,                # Also stop if no improvement
    validation_fraction=0.1
)

# Plot the learning rate decay and stopping point
lrs = inference_early_stop._summary['learning_rates']
val_losses = inference_early_stop._summary['validation_log_probs']
# One x-value per recorded epoch (`epochs_trained` may only hold the final count)
epochs = range(1, len(val_losses) + 1)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

# Learning rate schedule
ax1.plot(epochs, lrs, 'b-', linewidth=2)
ax1.axhline(y=1e-4, color='r', linestyle='--', alpha=0.7, label='min_lr_threshold')
ax1.set_ylabel('Learning Rate')
ax1.set_yscale('log')
ax1.set_title('Learning Rate Schedule with Early Stopping')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Validation loss
ax2.plot(epochs, val_losses, 'g-', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Validation Loss')
ax2.set_title('Validation Loss During Training')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Training stopped after {epochs[-1]} epochs due to learning rate threshold")
print(f"Final learning rate: {lrs[-1]:.2e}")
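
For exponential decay, the number of epochs until the learning rate crosses `min_lr_threshold` can be estimated in advance from `lr0 * gamma**n < threshold`. The sketch below does this for the settings used above, assuming one decay step per epoch. The helper name is illustrative.

```python
import math


def decay_steps_to_threshold(lr0, gamma, threshold):
    """Smallest number of decay steps n with lr0 * gamma**n < threshold.

    Assumes lr0 > 0, threshold > 0, and 0 < gamma < 1.
    """
    n = max(0, math.ceil(math.log(threshold / lr0) / math.log(gamma)))
    # Guard against floating-point edge cases where the ratio is an exact integer.
    while lr0 * gamma ** n >= threshold:
        n += 1
    return n


n_steps = decay_steps_to_threshold(lr0=5e-3, gamma=0.85, threshold=1e-4)
print(f"LR falls below the threshold after {n_steps} decay steps")
```

With `gamma=0.85` the threshold is crossed after roughly 25 epochs, long before `max_num_epochs=300` or the 100-epoch patience could take effect. The exact stopping epoch depends on when the threshold is checked relative to the scheduler step.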

## Schedulers with NLE and NRE

Learning rate schedulers work identically across all `sbi` inference methods:

In [None]:
# Test schedulers with NLE and NRE
methods = {
    'NPE': NPE,
    'NLE': NLE_A,
    'NRE': NRE_A
}

scheduler_config = {
    "type": "plateau",
    "factor": 0.5,
    "patience": 5
}

method_results = {}

for name, Method in methods.items():
    print(f"Training {name} with plateau scheduler...")

    inference = Method(prior=prior, show_progress_bars=False)
    inference.append_simulations(theta, x)

    estimator = inference.train(
        lr_scheduler=scheduler_config,
        learning_rate=1e-3,
        max_num_epochs=100,
        validation_fraction=0.1,
        training_batch_size=50
    )

    epochs_trained = inference._summary['epochs_trained'][-1]
    final_lr = inference._summary['learning_rates'][-1]

    method_results[name] = {
        'epochs': epochs_trained,
        'final_lr': final_lr,
        'inference': inference
    }

    print(f"  Converged after {epochs_trained} epochs, final LR: {final_lr:.2e}")

# Compare learning rate evolution across methods
plt.figure(figsize=(10, 6))

for name, result in method_results.items():
    lrs = result['inference']._summary['learning_rates']
    # One x-value per recorded epoch (`epochs_trained` may only hold the final count)
    epochs = range(1, len(lrs) + 1)
    plt.plot(epochs, lrs, label=f'{name}', linewidth=2, alpha=0.8)

plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule Across Different Inference Methods')
plt.yscale('log')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## Best practices and recommendations

Based on the experiments above, here are key takeaways:

1. **ReduceLROnPlateau is often a good default**: It reacts to the validation loss instead of following a fixed schedule, so it adapts to your specific training dynamics and typically works well across different problems.

2. **Start with higher learning rates**: Schedulers allow you to begin with more aggressive learning rates that would otherwise cause instability.

3. **Use `min_lr_threshold`**: This prevents wasted computation when learning rates become ineffectively small.

4. **Monitor training curves**: The learning rate schedule should align with improvements in validation loss.

5. **Scheduler choice depends on your problem**:
   - Plateau: Good default choice, adapts to training progress
   - Exponential: Simple and predictable decay
   - Step: Piecewise-constant decay at fixed epoch intervals
   - Cosine: Good for longer training runs
   - Cyclic: Can help escape local minima (not demonstrated in this tutorial)

6. **Combine with early stopping**: Use both scheduler-based and patience-based stopping criteria for robust training.
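
The last recommendation can be sketched as a single check that combines both criteria. This is an illustration only, not `sbi`'s internal stopping logic: the function `should_stop` is hypothetical, its arguments merely mirror the `stop_after_epochs` and `min_lr_threshold` options used above, and it assumes a validation loss where lower is better.

```python
def should_stop(val_losses, current_lr, stop_after_epochs, min_lr_threshold):
    """Return (stop, reason) for a validation loss that should decrease."""
    # Scheduler-based criterion: the learning rate has become too small.
    if current_lr < min_lr_threshold:
        return True, "learning rate below threshold"
    # Patience-based criterion: too many epochs since the best validation loss.
    if val_losses:  # with an empty history there is nothing to compare yet
        best_epoch = val_losses.index(min(val_losses))
        epochs_since_best = len(val_losses) - 1 - best_epoch
        if epochs_since_best >= stop_after_epochs:
            return True, "no improvement"
    return False, "continue"


history = [1.00, 0.90, 0.95, 0.96]  # best value at the second epoch
print(should_stop(history, current_lr=1e-3, stop_after_epochs=2, min_lr_threshold=1e-4))
print(should_stop(history, current_lr=5e-5, stop_after_epochs=20, min_lr_threshold=1e-4))
```

Whichever criterion triggers first ends training, so a fast-decaying schedule is usually cut off by the learning rate check, while a slow or adaptive one is more likely to hit the patience limit.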

## Summary

Learning rate schedulers provide a powerful way to improve `sbi` training:

- **Easy to use**: Just add `lr_scheduler` and `lr_scheduler_kwargs` parameters to `.train()`
- **Universal compatibility**: Works with NPE, NLE, NRE, and all their variants
- **Flexible configuration**: Use string shortcuts or detailed dictionary configs
- **Automatic monitoring**: Learning rates are tracked in training summaries
- **Early stopping**: Avoid wasted epochs by stopping once the learning rate falls below `min_lr_threshold`

Experiment with different schedulers to find what works best for your specific inference problem!