# 06 Auto Selection

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imewei/NLSQ/blob/main/examples/notebooks/08_workflow_system/06_auto_selection.ipynb)

Features demonstrated:
- MemoryBudgetSelector decision algorithm
- MemoryBudget computation for dataset sizing
- Strategy selection: standard, chunked, streaming
- Memory detection with MemoryEstimator.get_available_memory_gb()
- Adaptive tolerance calculation

Run this example:
    python examples/scripts/08_workflow_system/06_auto_selection.py

In [None]:
# @title Install NLSQ (run once in Colab)
import sys

if 'google.colab' in sys.modules:
    print("Running in Google Colab - installing NLSQ...")
    !pip install -q nlsq
    print("NLSQ installed successfully!")
else:
    print("Not running in Colab - assuming NLSQ is already installed")

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from nlsq import OptimizationGoal
from nlsq.core.workflow import (
    MemoryBudget,
    MemoryBudgetSelector,
    calculate_adaptive_tolerances,
)
from nlsq.streaming.large_dataset import MemoryEstimator

# Output directory for saved figures: "<current working directory>/figures".
# Created up front (including missing parents); an existing directory is fine.
FIG_DIR = Path.cwd().joinpath("figures")
FIG_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
def main():
    """Print a guided tour of automatic workflow selection and save a figure.

    The walkthrough has seven numbered sections:

    1. A textual description of the selector's decision tree.
    2. A table of ``MemoryBudget.compute`` results for several dataset sizes.
    3. The strategy ``MemoryBudgetSelector.select`` picks for those sizes
       using the memory detected on this machine.
    4. The strategy picked for a fixed 5M-point dataset under explicit
       ``memory_limit_gb`` overrides.
    5. A table of ``calculate_adaptive_tolerances`` results.
    6. A textual description of the streaming defense presets.
    7. A heat map of the selected strategy over (dataset size, memory limit),
       written to ``FIG_DIR / "06_selection_algorithm.png"``.

    A summary of the key functions and takeaways is printed at the end.

    Returns:
        None. All results are printed to stdout or saved to ``FIG_DIR``.
    """
    print("=" * 70)
    print("Automatic Workflow Selection Deep Dive")
    print("=" * 70)
    print()

    np.random.seed(42)

    # One shared value so the budget table, the selector and the printed
    # threshold cannot drift apart.
    safety_factor = 0.75
    n_params = 5

    # =========================================================================
    # 1. Display the selection decision tree
    # =========================================================================
    print("1. MemoryBudgetSelector Decision Tree")
    # Same rule width as the other section headers.
    print("-" * 70)
    print()
    print("The selector uses a simple decision tree based on memory requirements:")
    print()
    print("  1. Compute MemoryBudget (data_gb, jacobian_gb, peak_gb)")
    print("  2. If data_gb > threshold -> STREAMING")
    print("  3. Else if peak_gb > threshold -> CHUNKED")
    print("  4. Else -> STANDARD")
    print()
    print("Available strategies:")
    print("  standard  - Full in-memory computation")
    print("  chunked   - Memory-managed chunk processing")
    print("  streaming - Mini-batch gradient descent")

    # =========================================================================
    # 2. MemoryBudget Computation
    # =========================================================================
    print()
    print()
    print("2. MemoryBudget Computation")
    print("-" * 70)
    print(f"{'Dataset':<12} {'Data GB':<12} {'Jacobian GB':<15} {'Peak GB':<12} {'Fits?':<8}")
    print("-" * 70)

    # (number of points, number of parameters, display label)
    dataset_configs = [
        (10_000, 5, "10K"),
        (100_000, 5, "100K"),
        (1_000_000, 5, "1M"),
        (10_000_000, 5, "10M"),
        (100_000_000, 5, "100M"),
    ]

    for n_points, budget_n_params, label in dataset_configs:
        budget = MemoryBudget.compute(
            n_points=n_points, n_params=budget_n_params, safety_factor=safety_factor
        )
        fits = "Yes" if budget.fits_in_memory else "No"
        print(
            f"{label:<12} {budget.data_gb:<12.4f} {budget.jacobian_gb:<15.4f} "
            f"{budget.peak_gb:<12.4f} {fits:<8}"
        )

    # =========================================================================
    # 3. MemoryBudgetSelector in Action
    # =========================================================================
    print()
    print()
    print("3. MemoryBudgetSelector Decisions")
    print("-" * 70)

    available_memory = MemoryEstimator.get_available_memory_gb()
    selector = MemoryBudgetSelector(safety_factor=safety_factor)

    print(f"Available memory: {available_memory:.1f} GB")
    print(
        f"Threshold ({safety_factor:.0%}): "
        f"{available_memory * safety_factor:.1f} GB"
    )
    print()

    test_sizes = [10_000, 100_000, 1_000_000, 10_000_000, 100_000_000]

    print(f"{'Dataset Size':<15} {'Strategy':<15} {'Config Type':<20}")
    print("-" * 50)

    for n_points in test_sizes:
        strategy, config = selector.select(n_points=n_points, n_params=n_params)
        config_type = type(config).__name__ if config else "None"
        size_str = f"{n_points:,}"
        print(f"{size_str:<15} {strategy:<15} {config_type:<20}")

    # =========================================================================
    # 4. Selection with Different Memory Limits
    # =========================================================================
    print()
    print()
    print("4. Selection with Different Memory Limits")
    print("-" * 70)

    memory_limits = [8.0, 32.0, 64.0, 128.0]  # GB
    n_points = 5_000_000  # 5M points

    for mem_limit in memory_limits:
        strategy, config = selector.select(
            n_points=n_points, n_params=n_params, memory_limit_gb=mem_limit
        )
        config_type = type(config).__name__ if config else "None"
        print(f"  Memory: {mem_limit:>6.0f} GB -> {strategy:<12} ({config_type})")

    # =========================================================================
    # 5. Adaptive Tolerances
    # =========================================================================
    print()
    print()
    print("5. Adaptive Tolerance Calculation")
    print("-" * 60)

    # (number of points, optimization goal or None for the default)
    test_configs = [
        (1_000, None),
        (1_000_000, None),
        (1_000_000, OptimizationGoal.FAST),
        (1_000_000, OptimizationGoal.QUALITY),
        (100_000_000, OptimizationGoal.ROBUST),
    ]

    print(f"{'n_points':<12s} | {'Goal':<12s} | {'gtol':<12s} | {'ftol':<12s}")
    print("-" * 60)

    for n_points, goal in test_configs:
        tolerances = calculate_adaptive_tolerances(n_points, goal)

        if n_points >= 1_000_000:
            size_str = f"{n_points / 1_000_000:.0f}M"
        else:
            size_str = f"{n_points / 1_000:.0f}K"

        goal_str = goal.name if goal else "None"

        # Pad the numeric cells to the header's 12-character columns so the
        # table lines up.
        print(
            f"{size_str:<12s} | {goal_str:<12s} | "
            f"{tolerances['gtol']:<12.0e} | {tolerances['ftol']:<12.0e}"
        )

    # =========================================================================
    # 6. Defense Layer Awareness
    # =========================================================================
    print()
    print()
    print("6. Defense Layer Awareness for Streaming")
    print("-" * 70)
    print()
    print("When MemoryBudgetSelector chooses STREAMING strategy,")
    print("the returned HybridStreamingConfig includes 4-layer defense.")
    print()
    print("Defense presets for streaming workflows:")
    print("  defense_strict()     - Warm-start refinement (previous fit as p0)")
    print("  defense_relaxed()    - Exploration (rough initial guesses)")
    print("  scientific_default() - Production scientific computing (default)")
    print("  defense_disabled()   - Pre-0.3.6 behavior (no protection)")
    print()
    print("The 4-layer defense strategy protects against L-BFGS warmup divergence:")
    print("  Layer 1: Warm Start Detection - Skip warmup if near optimal")
    print("  Layer 2: Adaptive Step Size - Scale step size based on fit quality")
    print("  Layer 3: Cost-Increase Guard - Abort if loss increases > 5%")
    print("  Layer 4: Step Clipping - Limit parameter update magnitude")

    # =========================================================================
    # 7. Visualization
    # =========================================================================
    print()
    print()
    print("7. Saving selection algorithm visualization...")

    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot ranges, shared by the sampling grid, the image extent and the
    # annotation placement below.
    log_size_lo, log_size_hi = 4, 9  # log10(points): 10K to 1B
    mem_lo, mem_hi = 4, 128  # GB

    # Strategy boundaries visualization
    dataset_sizes = np.logspace(log_size_lo, log_size_hi, 100)
    memory_limits = np.linspace(mem_lo, mem_hi, 50)

    # Rows follow memory limits, columns follow dataset sizes.
    # Cell codes: 0 = standard, 1 = chunked, 2 = streaming.
    strategy_map = np.zeros((len(memory_limits), len(dataset_sizes)))

    for i, mem_limit in enumerate(memory_limits):
        for j, n_points in enumerate(dataset_sizes):
            strategy, _ = selector.select(
                n_points=int(n_points), n_params=n_params, memory_limit_gb=mem_limit
            )
            if strategy == "streaming":
                strategy_map[i, j] = 2
            elif strategy == "chunked":
                strategy_map[i, j] = 1
            else:
                strategy_map[i, j] = 0

    cmap = plt.cm.RdYlGn_r
    im = ax.imshow(
        strategy_map,
        aspect="auto",
        origin="lower",
        cmap=cmap,
        extent=[log_size_lo, log_size_hi, mem_lo, mem_hi],
    )

    ax.set_xlabel("Dataset Size (log10)")
    ax.set_ylabel("Memory Limit (GB)")
    ax.set_title("Strategy Selection Boundaries (5 parameters)")

    cbar = plt.colorbar(im, ax=ax, ticks=[0, 1, 2])
    cbar.ax.set_yticklabels(["Standard", "Chunked", "Streaming"])

    if mem_lo <= available_memory <= mem_hi:
        ax.axhline(y=available_memory, color="white", linestyle="--", linewidth=2)
        # Keep the label inside the axes: white text placed right of the
        # axes lands on the white figure background and cannot be read.
        # The dark box keeps it legible over any of the map colours.
        ax.text(
            log_size_hi - 0.05,
            available_memory,
            f"Current: {available_memory:.0f} GB",
            color="white",
            ha="right",
            va="bottom",
            bbox={"facecolor": "black", "alpha": 0.5, "edgecolor": "none"},
        )
    else:
        # A marker outside the plotted memory range would not point at
        # anything on the map, so report the value in text instead.
        print(
            f"  Note: current memory ({available_memory:.1f} GB) is outside "
            f"the plotted {mem_lo}-{mem_hi} GB range; marker omitted"
        )

    plt.tight_layout()
    plt.savefig(FIG_DIR / "06_selection_algorithm.png", dpi=300, bbox_inches="tight")
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
    print(f"  Saved: {FIG_DIR / '06_selection_algorithm.png'}")

    # =========================================================================
    # Summary
    # =========================================================================
    print()
    print()
    print("=" * 70)
    print("Summary")
    print("=" * 70)
    print()
    print("Key Functions:")
    print("  MemoryBudget.compute(n_points, n_params)")
    print("  MemoryBudgetSelector().select(n_points, n_params)")
    print("  calculate_adaptive_tolerances(n_points, goal)")
    print("  MemoryEstimator.get_available_memory_gb()")
    print()
    print(f"Current System Memory: {available_memory:.1f} GB")
    print()
    print("Key Takeaways:")
    print("  - MemoryBudgetSelector uses a decision tree based on memory requirements")
    print("  - Decision: data_gb > threshold -> streaming; peak_gb > threshold -> chunked")
    print("  - Override with memory_limit_gb for reproducible behavior")
    print("  - Streaming configs include 4-layer defense by default")
    print()
    print("Usage with fit():")
    print("  popt, pcov = fit(model, x, y, workflow='auto')  # Automatic selection")
    print("  popt, pcov = fit(model, x, y, workflow='standard')  # Force standard")

In [None]:
# Run the walkthrough only when this module is the entry point, not on import.
if __name__ == "__main__":
    main()