[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imewei/NLSQ/blob/main/examples/notebooks/06_streaming/01_basic_fault_tolerance.ipynb)


In [None]:
# @title Install NLSQ (run once in Colab)
import sys
if 'google.colab' in sys.modules:
    print("Running in Google Colab - installing NLSQ...")
    !pip install -q nlsq
    print("✅ NLSQ installed successfully!")
else:
    print("Not running in Colab - assuming NLSQ is already installed")

In [1]:
# Configure matplotlib for inline plotting in VS Code/Jupyter.
# This cell is placed before any cell that imports matplotlib so the inline
# backend is selected first.
# NOTE(review): none of the visible cells import matplotlib — confirm whether
# this magic is still needed in this notebook.
%matplotlib inline

In [2]:
import jax.numpy as jnp
import numpy as np

from nlsq import StreamingConfig, StreamingOptimizer


def exponential_decay(x, a, b):
    """Evaluate the exponential decay model ``y = a * exp(-b * x)``.

    Args:
        x: Point(s) at which the model is evaluated.
        a: Multiplicative coefficient in front of the exponential.
        b: Rate coefficient multiplying ``x`` inside the exponential.

    Returns:
        The value of ``a * jnp.exp(-b * x)``.
    """
    decay_factor = jnp.exp(-b * x)
    return a * decay_factor


def _make_dataset(n_samples, true_a, true_b):
    """Generate the noisy synthetic exponential-decay dataset for the example.

    Seeds NumPy's global RNG with 42 so the added noise is reproducible.

    Args:
        n_samples: Number of evenly spaced x values generated on [0, 10].
        true_a: Ground-truth ``a`` passed to ``exponential_decay``.
        true_b: Ground-truth ``b`` passed to ``exponential_decay``.

    Returns:
        Tuple ``(x_data, y_data)`` where ``y_data`` is the model output plus
        Gaussian noise scaled by 0.1.
    """
    np.random.seed(42)
    x_data = np.linspace(0, 10, n_samples)
    y_true = exponential_decay(x_data, true_a, true_b)
    y_data = y_true + 0.1 * np.random.randn(n_samples)
    return x_data, y_data


def _build_config():
    """Create the ``StreamingConfig`` used by this example."""
    return StreamingConfig(
        batch_size=100,
        max_epochs=10,
        learning_rate=0.001,
        enable_fault_tolerance=True,  # Enable fault tolerance features
        validate_numerics=True,  # Check for NaN/Inf
        min_success_rate=0.5,  # Require 50% batch success
        max_retries_per_batch=2,  # Max 2 retry attempts
        checkpoint_dir="checkpoints",
        checkpoint_frequency=100,  # Save every 100 iterations
        enable_checkpoints=True,
    )


def _print_config(config):
    """Print the configuration fields relevant to this example."""
    print("Configuration:")
    print(f"  Batch size: {config.batch_size}")
    print(f"  Max epochs: {config.max_epochs}")
    print(f"  Learning rate: {config.learning_rate}")
    print(f"  Fault tolerance: {config.enable_fault_tolerance}")
    print(f"  Validate numerics: {config.validate_numerics}")
    print(f"  Min success rate: {config.min_success_rate:.0%}")
    print(f"  Max retries per batch: {config.max_retries_per_batch}")
    print()


def _print_results(best_params, success, message, best_loss, true_a, true_b):
    """Print the fit outcome next to the ground-truth parameters."""
    print("RESULTS")
    print("=" * 70)
    print(f"Success: {success}")
    print(f"Message: {message}")
    print()
    print("Best parameters found:")
    print(f"  a = {best_params[0]:.6f} (true: {true_a})")
    print(f"  b = {best_params[1]:.6f} (true: {true_b})")
    print(f"  Best loss = {best_loss:.6e}")
    print()


def _print_recent_batches(recent_stats, limit=5):
    """Print the last ``limit`` entries of the recent-batch buffer.

    Nothing is printed when ``recent_stats`` is empty.
    """
    if not recent_stats:
        return

    # Convert to a list first: the buffer may be a deque, which cannot be sliced.
    shown = list(recent_stats)[-limit:]
    # Report how many entries are actually listed, not just the buffer size.
    print(
        f"Recent batch statistics (last {len(shown)} of "
        f"{len(recent_stats)} buffered batches):"
    )
    for stats in shown:
        status = "SUCCESS" if stats["success"] else "FAILED"
        retry_info = (
            f" (retries: {stats['retry_count']})"
            if stats["retry_count"] > 0
            else ""
        )
        print(
            f"  Batch {stats['batch_idx']}: {status}, loss={stats['loss']:.6e}{retry_info}"
        )
    print()


def _print_diagnostics(diagnostics):
    """Print the streaming diagnostics mapping returned with the fit result."""
    print("DIAGNOSTICS")
    print("=" * 70)
    print(f"Batch success rate: {diagnostics['batch_success_rate']:.1%}")
    print(f"Total batches attempted: {diagnostics['total_batches_attempted']}")
    print(f"Total retries: {diagnostics['total_retries']}")
    print(f"Convergence achieved: {diagnostics['convergence_achieved']}")
    print(f"Final epoch: {diagnostics['final_epoch']}")
    print(f"Elapsed time: {diagnostics['elapsed_time']:.2f}s")
    print()

    failed_batches = diagnostics["failed_batches"]
    if failed_batches:
        print(f"Failed batches ({len(failed_batches)}):")
        print(f"  Indices: {failed_batches}")
        print(f"  Error types: {diagnostics['error_types']}")
        print()

    agg = diagnostics["aggregate_stats"]
    print("Aggregate Statistics (from batch buffer):")
    print(f"  Mean loss: {agg['mean_loss']:.6e}")
    print(f"  Std loss: {agg['std_loss']:.6e}")
    print(f"  Mean gradient norm: {agg['mean_grad_norm']:.6f}")
    # The value is scaled by 1000 for display in milliseconds.
    print(f"  Mean batch time: {agg['mean_batch_time'] * 1000:.2f}ms")
    print()

    _print_recent_batches(diagnostics["recent_batch_stats"])

    checkpoint = diagnostics["checkpoint_info"]
    if checkpoint:
        print("Checkpoint Information:")
        print(f"  Path: {checkpoint['path']}")
        print(f"  Saved at: {checkpoint['saved_at']}")
        print(f"  Batch index: {checkpoint['batch_idx']}")
        print()


def _print_takeaways():
    """Print the closing banner and the list of key takeaways."""
    print("=" * 70)
    print("Example complete!")
    print()
    print("Key takeaways:")
    print("  - Fault tolerance enabled by default (no configuration needed)")
    print("  - Best parameters always returned (never initial p0)")
    print("  - NaN/Inf detection at three validation points")
    print("  - Adaptive retry strategies for failed batches")
    print("  - Comprehensive diagnostics for analysis")
    print("  - Checkpoints saved automatically for recovery")


def main():
    """Run the basic fault-tolerance streaming example end to end.

    Generates a noisy exponential-decay dataset, fits it with
    ``StreamingOptimizer`` using a fault-tolerant ``StreamingConfig``, and
    prints the configuration, the fitted parameters and the diagnostics
    returned with the result.
    """
    print("=" * 70)
    print("Streaming Optimizer: Basic Fault Tolerance Example")
    print("=" * 70)
    print()

    n_samples = 10000
    true_a, true_b = 2.5, 0.3
    x_data, y_data = _make_dataset(n_samples, true_a, true_b)

    print(f"Dataset: {n_samples} samples")
    print(f"True parameters: a={true_a}, b={true_b}")
    print()

    config = _build_config()
    _print_config(config)

    optimizer = StreamingOptimizer(config)
    p0 = np.array([1.0, 0.1])
    print(f"Initial guess: a={p0[0]}, b={p0[1]}")
    print()

    print("Starting optimization...")
    print("-" * 70)
    result = optimizer.fit(
        (x_data, y_data),  # Data as tuple
        exponential_decay,  # Model function
        p0,  # Initial parameters
        verbose=1,  # Show progress
    )
    print("-" * 70)
    print()

    # Read every field up front so a missing key fails before any report
    # output is printed.
    best_params = result["x"]
    success = result["success"]
    message = result["message"]
    best_loss = result["best_loss"]
    diagnostics = result["streaming_diagnostics"]

    _print_results(best_params, success, message, best_loss, true_a, true_b)
    _print_diagnostics(diagnostics)
    _print_takeaways()

# Run the example only when this code is executed directly (``__name__`` is
# ``"__main__"``), not when it is imported as a module.
if __name__ == "__main__":
    main()

Streaming Optimizer: Basic Fault Tolerance Example

Dataset: 10000 samples
True parameters: a=2.5, b=0.3

Configuration:
  Batch size: 100
  Max epochs: 10
  Learning rate: 0.001
  Fault tolerance: True
  Validate numerics: True
  Min success rate: 50%
  Max retries per batch: 2

Initial guess: a=1.0, b=0.1

Starting optimization...
----------------------------------------------------------------------


----------------------------------------------------------------------

RESULTS
Success: True
Message: Optimization complete: 1000/1000 batches succeeded (100.0%)

Best parameters found:
  a = 1.189291 (true: 2.5)
  b = 0.134749 (true: 0.3)
  Best loss = 7.611412e-03

DIAGNOSTICS
Batch success rate: 100.0%
Total batches attempted: 1000
Total retries: 0
Convergence achieved: False
Final epoch: 9
Elapsed time: 0.62s

Aggregate Statistics (from batch buffer):
  Mean loss: 1.625896e-01
  Std loss: 3.034042e-01
  Mean gradient norm: 0.948855
  Mean batch time: 0.40ms

Recent batch statistics (last 100 batches):
  Batch 95: SUCCESS, loss=1.803398e-02
  Batch 96: SUCCESS, loss=1.725470e-02
  Batch 97: SUCCESS, loss=1.732153e-02
  Batch 98: SUCCESS, loss=1.525652e-02
  Batch 99: SUCCESS, loss=1.644096e-02

Checkpoint Information:
  Path: checkpoints/checkpoint_iter_1000.h5
  Saved at: 2025-12-18T15:00:26.340074
  Batch index: 99

Example complete!

Key takeaways:
  - Fault tolerance enabled by