# 03 Optimization Goals

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imewei/NLSQ/blob/main/examples/notebooks/08_workflow_system/03_optimization_goals.ipynb)

Features demonstrated:
- OptimizationGoal enum values and their behaviors
- Adaptive tolerances based on dataset size
- Using workflow presets with fit()
- Combining workflow with custom tolerances

Run this example:
    python examples/scripts/08_workflow_system/03_optimization_goals.py

In [None]:
# @title Install NLSQ (run once in Colab)
# Install NLSQ only when running inside Google Colab; in any other environment
# the package is assumed to be installed already.
import sys

# The presence of the ``google.colab`` module in sys.modules is used as the
# "running in Colab" signal.
if 'google.colab' in sys.modules:
    print("Running in Google Colab - installing NLSQ...")
    # IPython shell escape (notebook-only syntax, not plain Python):
    # quiet pip install of the published ``nlsq`` package.
    !pip install -q nlsq
    print("NLSQ installed successfully!")
else:
    print("Not running in Colab - assuming NLSQ is already installed")

In [None]:
import time
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from nlsq import OptimizationGoal, fit
from nlsq.core.workflow import calculate_adaptive_tolerances

# Directory that receives the saved plots: ``figures/`` under the current
# working directory, created (with any missing parents) if absent.
FIG_DIR = Path.cwd().joinpath("figures")
FIG_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
def exponential_decay(x, a, b, c):
    """Evaluate the offset exponential-decay model ``a * exp(-b * x) + c``.

    Args:
        x: Independent variable (scalar or array).
        a: Amplitude of the decaying term.
        b: Decay rate.
        c: Constant offset the curve approaches as ``b * x`` grows.

    Returns:
        The model value(s), computed with ``jax.numpy``.
    """
    decay = jnp.exp(-b * x)
    return a * decay + c


def _print_goal_overview():
    """Print every ``OptimizationGoal`` member plus a hand-written behaviour summary."""
    print("1. OptimizationGoal Values:")
    print("-" * 60)

    for goal in OptimizationGoal:
        print(f"  {goal.name:<20} = {goal.value}")

    # Descriptive text for the tutorial only; nothing here is read from the library.
    goal_info = {
        OptimizationGoal.FAST: {
            "description": "Prioritize speed with local optimization only",
            "tolerances": "One tier looser",
            "multistart": "Disabled",
            "use_case": "Quick exploration, well-conditioned problems",
        },
        OptimizationGoal.ROBUST: {
            "description": "Standard tolerances with multi-start",
            "tolerances": "Dataset-appropriate",
            "multistart": "Enabled",
            "use_case": "Production use, unknown problem conditioning",
        },
        OptimizationGoal.QUALITY: {
            "description": "Highest precision/accuracy as TOP PRIORITY",
            "tolerances": "One tier tighter",
            "multistart": "Enabled + validation passes",
            "use_case": "Publication-quality results",
        },
    }

    print("\nGoal Details:")
    print("-" * 80)

    for goal, info in goal_info.items():
        print(f"\n  {goal.name}:")
        print(f"    Description:  {info['description']}")
        print(f"    Tolerances:   {info['tolerances']}")
        print(f"    Multi-start:  {info['multistart']}")
        print(f"    Use case:     {info['use_case']}")


def _print_tolerance_table(goals):
    """Print the adaptive ``gtol`` for several dataset sizes under each goal.

    Args:
        goals: The goals to tabulate; must include FAST, ROBUST and QUALITY,
            because the table columns are fixed to those three names.
    """
    print()
    print("2. Adaptive Tolerances by Dataset Size and Goal:")
    print("-" * 70)
    print(f"{'Dataset Size':<15} {'FAST':<15} {'ROBUST':<15} {'QUALITY':<15}")
    print("-" * 70)

    for n_points in (500, 5_000, 50_000, 500_000, 5_000_000):
        tols = {
            goal.name: calculate_adaptive_tolerances(n_points, goal)["gtol"]
            for goal in goals
        }
        print(
            f"{n_points:>12,}   {tols['FAST']:<15.0e} "
            f"{tols['ROBUST']:<15.0e} {tols['QUALITY']:<15.0e}"
        )


def _compare_workflows(x_data, y_data, p0, bounds, true_params, workflow_names):
    """Fit the decay model once per workflow preset and report the results.

    Args:
        x_data, y_data: Observations to fit.
        p0: Initial parameter guess ``[a, b, c]``.
        bounds: ``(lower, upper)`` parameter bounds.
        true_params: Ground-truth ``(a, b, c)`` used to compute absolute errors.
        workflow_names: Workflow preset names passed to ``fit(workflow=...)``.

    Returns:
        Dict keyed by workflow name with ``popt``, ``ssr`` (sum of squared
        residuals), ``time`` (seconds) and ``errors`` (absolute error per
        parameter, in ``true_params`` order).
    """
    # Untimed warm-up: JAX-based fitting typically compiles on first use, and
    # without this the one-time compilation cost would be charged entirely to
    # the first timed workflow, skewing the time comparison.
    fit(
        exponential_decay,
        x_data,
        y_data,
        p0=p0,
        bounds=bounds,
        workflow=workflow_names[0],
    )

    results = {}
    for workflow_name in workflow_names:
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        popt, _ = fit(
            exponential_decay,
            x_data,
            y_data,
            p0=p0,
            bounds=bounds,
            workflow=workflow_name,
        )
        elapsed = time.perf_counter() - start

        y_pred = exponential_decay(x_data, *popt)
        ssr = float(jnp.sum((y_data - y_pred) ** 2))
        param_errors = [
            abs(estimate - truth)
            for estimate, truth in zip(popt, true_params, strict=True)
        ]

        results[workflow_name] = {
            "popt": popt,
            "ssr": ssr,
            "time": elapsed,
            "errors": param_errors,
        }

        print(f"\n  {workflow_name.upper()}:")
        print(f"    Time:       {elapsed:.4f}s")
        print(f"    SSR:        {ssr:.6f}")
        print(f"    Parameters: a={popt[0]:.4f}, b={popt[1]:.4f}, c={popt[2]:.4f}")
        print(
            f"    Errors:     a_err={param_errors[0]:.4f}, "
            f"b_err={param_errors[1]:.4f}, c_err={param_errors[2]:.4f}"
        )

    return results


def _fit_with_custom_tolerances(x_data, y_data, p0, bounds):
    """Show that explicit gtol/ftol/xtol can be combined with a workflow preset."""
    print()
    print("4. Custom Tolerances with Workflow:")
    print("-" * 60)

    print("\n  Combining workflow='standard' with custom tolerances:")
    popt_custom, _ = fit(
        exponential_decay,
        x_data,
        y_data,
        p0=p0,
        bounds=bounds,
        workflow="standard",
        gtol=1e-12,
        ftol=1e-12,
        xtol=1e-12,
    )
    print(
        f"    Parameters: a={popt_custom[0]:.6f}, "
        f"b={popt_custom[1]:.6f}, c={popt_custom[2]:.6f}"
    )


def _annotate_bars(ax, bars, values, fmt):
    """Write each value (formatted with ``fmt``) centred on top of its bar."""
    for bar, value in zip(bars, values, strict=True):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            fmt.format(value),
            ha="center",
            va="bottom",
            fontsize=9,
        )


def _plot_comparison(results, goals, save_path):
    """Save a 2x2 figure comparing tolerances, SSR, time and parameter errors.

    Args:
        results: Mapping returned by ``_compare_workflows``.
        goals: Goals whose adaptive ``gtol`` curves are drawn in the first panel.
        save_path: Output PNG path.
    """
    print()
    print("5. Saving visualizations...")

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    colors = {"fast": "blue", "standard": "green", "quality": "red"}
    workflow_names = list(results)
    bar_colors = [colors[name] for name in workflow_names]

    # Top left: adaptive gtol vs dataset size for each goal.
    ax1 = axes[0, 0]
    sizes = np.logspace(2, 8, 50).astype(int)
    for goal in goals:
        tols = [calculate_adaptive_tolerances(n, goal)["gtol"] for n in sizes]
        ax1.loglog(sizes, tols, label=goal.name, linewidth=2)
    ax1.set_xlabel("Dataset Size (points)")
    ax1.set_ylabel("gtol")
    ax1.set_title("Adaptive Tolerances by Goal")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Top right: sum of squared residuals per workflow.
    ax2 = axes[0, 1]
    ssrs = [results[name]["ssr"] for name in workflow_names]
    bars = ax2.bar(workflow_names, ssrs, color=bar_colors)
    ax2.set_xlabel("Workflow")
    ax2.set_ylabel("Sum of Squared Residuals")
    ax2.set_title("Fit Quality by Workflow")
    _annotate_bars(ax2, bars, ssrs, "{:.4f}")

    # Bottom left: timed fit duration per workflow.
    ax3 = axes[1, 0]
    times = [results[name]["time"] for name in workflow_names]
    bars = ax3.bar(workflow_names, times, color=bar_colors)
    ax3.set_xlabel("Workflow")
    ax3.set_ylabel("Time (seconds)")
    ax3.set_title("Computation Time by Workflow")
    _annotate_bars(ax3, bars, times, "{:.3f}s")

    # Bottom right: grouped bars of absolute parameter error.
    ax4 = axes[1, 1]
    x_pos = np.arange(len(workflow_names))
    width = 0.25
    for i, param in enumerate(["a", "b", "c"]):
        errors = [results[name]["errors"][i] for name in workflow_names]
        ax4.bar(x_pos + i * width, errors, width, label=f"{param} error")
    ax4.set_xlabel("Workflow")
    ax4.set_ylabel("Absolute Error")
    ax4.set_title("Parameter Errors by Workflow")
    # Offset by one bar width so each tick sits under the middle bar of its group.
    ax4.set_xticks(x_pos + width)
    ax4.set_xticklabels(workflow_names)
    ax4.legend()

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved: {save_path}")


def _print_summary(true_params):
    """Print the closing recap and usage recommendations."""
    true_a, true_b, true_c = true_params
    print()
    print("=" * 70)
    print("Summary")
    print("=" * 70)
    print(f"True parameters: a={true_a}, b={true_b}, c={true_c}")
    print()
    print("Workflow recommendations:")
    print("  - Exploratory analysis:    workflow='fast'")
    print("  - Production fitting:      workflow='standard'")
    print("  - Publication quality:     workflow='quality'")
    print()
    print("Key behaviors:")
    print("  - fast: Looser tolerances, no multi-start")
    # NOTE(review): the ROBUST entry in the goal overview says multi-start is
    # "Enabled"; confirm how workflow='standard' maps onto the goals.
    print("  - standard: Balanced tolerances, optional multi-start")
    print("  - quality: Tighter tolerances, multi-start enabled")
    print()
    print("Usage:")
    print("  popt, pcov = fit(model, x, y, workflow='quality')")
    print("  popt, pcov = fit(model, x, y, workflow='standard', gtol=1e-12)")


def main():
    """Walk through optimization goals, adaptive tolerances and workflow presets.

    Prints a goal overview, an adaptive-tolerance table, a timed comparison of
    the ``fast``/``standard``/``quality`` workflows on a synthetic exponential
    decay, and a custom-tolerance fit. Saves ``03_goal_comparison.png`` under
    ``FIG_DIR``.
    """
    print("=" * 70)
    print("Optimization Goals and Adaptive Tolerances")
    print("=" * 70)
    print()

    # Seed the legacy global RNG so the synthetic noise is reproducible.
    np.random.seed(42)

    goals = (
        OptimizationGoal.FAST,
        OptimizationGoal.ROBUST,
        OptimizationGoal.QUALITY,
    )

    _print_goal_overview()
    _print_tolerance_table(goals)

    print()
    print("3. Testing Workflows on Exponential Decay Problem:")
    print("-" * 70)

    n_samples = 1000
    x_data = np.linspace(0, 5, n_samples)

    true_params = (3.0, 1.2, 0.5)
    true_a, true_b, true_c = true_params
    y_true = true_a * np.exp(-true_b * x_data) + true_c
    y_data = y_true + 0.1 * np.random.randn(n_samples)

    p0 = [1.0, 0.5, 0.0]
    bounds = ([0.1, 0.1, -1.0], [10.0, 5.0, 2.0])

    print(f"  True parameters: a={true_a}, b={true_b}, c={true_c}")
    print(f"  Dataset size: {n_samples} points")

    results = _compare_workflows(
        x_data,
        y_data,
        p0,
        bounds,
        true_params,
        ["fast", "standard", "quality"],
    )

    _fit_with_custom_tolerances(x_data, y_data, p0, bounds)
    _plot_comparison(results, goals, FIG_DIR / "03_goal_comparison.png")
    _print_summary(true_params)

In [None]:
# Entry point: run the walkthrough when executed as a script. A notebook kernel
# also sets __name__ to "__main__", so the cell runs there too.
if __name__ == "__main__":
    main()