# 05 YAML Configuration

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imewei/NLSQ/blob/main/examples/notebooks/08_workflow_system/05_yaml_configuration.ipynb)

Features demonstrated:
- Create an `nlsq.yaml` configuration file
- Configure solver tolerances, multistart options, and memory limits via YAML
- Use a YAML-defined workflow with `curve_fit()`
- Environment variables that override YAML settings (`NLSQ_MEMORY_LIMIT_GB`, `NLSQ_DEFAULT_WORKFLOW`)

Run this example:
    python examples/scripts/08_workflow_system/05_yaml_configuration.py

In [None]:
import os
from pathlib import Path

import jax.numpy as jnp
import numpy as np

try:
    import yaml
except ImportError:
    print("pyyaml is required. Install with: pip install pyyaml")
    raise

from nlsq import curve_fit

# Set NLSQ_EXAMPLES_QUICK=1 to shrink the synthetic data set and the
# multistart counts so the example runs faster.
QUICK = os.getenv("NLSQ_EXAMPLES_QUICK") == "1"

In [None]:
def exponential_decay(x, a, b, c):
    """Evaluate the exponential-decay model ``a * exp(-b * x) + c``.

    Parameters
    ----------
    x : array_like
        Points at which to evaluate the model.
    a, b, c : float
        Amplitude, decay rate and constant offset.

    Returns
    -------
    jax.Array
        Model values, computed with ``jax.numpy`` operations.
    """
    decay = jnp.exp(-b * x)
    return a * decay + c


def load_yaml_config(config_path="nlsq.yaml"):
    """Load a YAML configuration file.

    Parameters
    ----------
    config_path : str or os.PathLike, optional
        Path to the YAML file. Defaults to ``nlsq.yaml`` relative to the
        current working directory.

    Returns
    -------
    dict or None
        The parsed top-level mapping, or ``None`` when the file does not
        exist or contains no document (an empty file).

    Raises
    ------
    ValueError
        If the file's top-level value is not a mapping. Callers index it
        with ``.get``, which would otherwise fail with an opaque
        ``AttributeError`` later on.
    """
    path = Path(config_path)
    if not path.exists():
        return None
    # Read explicitly as UTF-8 so the result does not depend on the locale.
    with path.open(encoding="utf-8") as f:
        # safe_load builds only plain Python objects (dict/list/str/...),
        # never arbitrary classes, so an untrusted file cannot run code.
        data = yaml.safe_load(f)
    if data is None:
        return None
    if not isinstance(data, dict):
        raise ValueError(
            f"{path}: expected a mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def get_workflow_settings(yaml_config, workflow_name):
    """Return the settings mapping for one named workflow.

    Parameters
    ----------
    yaml_config : dict or None
        Parsed configuration, as returned by ``load_yaml_config``.
    workflow_name : str
        Key under the top-level ``workflows`` mapping.

    Returns
    -------
    dict
        The workflow's settings. An empty dict is returned when there is no
        config, no ``workflows`` section, or no entry for ``workflow_name``.
    """
    if yaml_config is None:
        return {}
    # ``or {}`` rather than a ``.get`` default: a key written with no value
    # (e.g. ``workflows:`` or ``high_precision:``) loads as None, and
    # ``.get(key, {})`` would return that None instead of the fallback.
    workflows = yaml_config.get("workflows") or {}
    return workflows.get(workflow_name) or {}


def fit_with_yaml_config(
    f,
    xdata,
    ydata,
    p0=None,
    bounds=(-np.inf, np.inf),
    workflow_name=None,
    config_path="nlsq.yaml",
):
    """Curve fit using YAML-defined workflow configuration.

    The workflow is chosen with this precedence:

    1. the explicit ``workflow_name`` argument;
    2. the ``NLSQ_DEFAULT_WORKFLOW`` environment variable;
    3. the file's ``default_workflow`` key;
    4. ``"standard"``.

    Any setting that the selected workflow does not define falls back to the
    defaults used below (tolerances 1e-8, multistart off, ``"lhs"`` sampler).

    Parameters
    ----------
    f : callable
        Model function
    xdata, ydata : array_like
        Data to fit
    p0 : array_like, optional
        Initial parameters
    bounds : tuple, optional
        Parameter bounds
    workflow_name : str, optional
        Name of workflow from nlsq.yaml
    config_path : str or os.PathLike, optional
        Path to YAML config file

    Returns
    -------
    popt, pcov : tuple
        Fitted parameters and covariance, as returned by ``curve_fit``.
    """
    yaml_config = load_yaml_config(config_path)

    if yaml_config is None:
        print("No YAML config found, using defaults")
        return curve_fit(f, xdata, ydata, p0=p0, bounds=bounds)

    if workflow_name is None:
        workflow_name = (
            os.environ.get("NLSQ_DEFAULT_WORKFLOW")
            or yaml_config.get("default_workflow")
            or "standard"
        )

    settings = get_workflow_settings(yaml_config, workflow_name)
    if not settings:
        # Make a misspelled or undefined workflow visible instead of
        # silently fitting with the defaults.
        print(
            f"Workflow '{workflow_name}' not defined (or empty) in "
            f"{config_path}, using defaults"
        )

    multistart = settings.get("enable_multistart", False)

    return curve_fit(
        f,
        xdata,
        ydata,
        p0=p0,
        bounds=bounds,
        # PyYAML follows YAML 1.1, where a float needs a dot: a hand-written
        # ``gtol: 1e-10`` loads as the *string* "1e-10". Coerce explicitly.
        gtol=float(settings.get("gtol", 1e-8)),
        ftol=float(settings.get("ftol", 1e-8)),
        xtol=float(settings.get("xtol", 1e-8)),
        multistart=multistart,
        # n_starts only matters when multistart is enabled; force 0 otherwise.
        n_starts=int(settings.get("n_starts", 0)) if multistart else 0,
        sampler=settings.get("sampler", "lhs"),
    )

In [None]:
def _example_config():
    """Return the example configuration that the demo writes to ``nlsq.yaml``.

    In QUICK mode the multistart counts are reduced so the example runs faster.
    """
    return {
        "default_workflow": "standard",
        "memory_limit_gb": 16.0,
        "workflows": {
            "high_precision": {
                "gtol": 1e-10,
                "ftol": 1e-10,
                "xtol": 1e-10,
                "enable_multistart": True,
                "n_starts": 4 if QUICK else 20,
                "sampler": "lhs",
            },
            "quick_explore": {
                "gtol": 1e-5,
                "ftol": 1e-5,
                "xtol": 1e-5,
                "enable_multistart": False,
            },
            "large_data": {
                "gtol": 1e-8,
                "ftol": 1e-8,
                "xtol": 1e-8,
                "memory_limit_gb": 8.0,
                "enable_multistart": True,
                "n_starts": 4 if QUICK else 10,
            },
        },
    }


def _run_yaml_demo(config_path):
    """Run demo sections 1-5 against the config file at ``config_path``.

    Writes the example config to ``config_path``, loads it back, prints the
    workflow presets, then fits synthetic exponential-decay data with the
    ``high_precision`` and ``quick_explore`` workflows.
    """
    max_preview_lines = 15

    # =========================================================================
    # 1. Create example nlsq.yaml
    # =========================================================================
    print("1. Creating example nlsq.yaml...")

    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(_example_config(), f, default_flow_style=False, sort_keys=False)

    print("  Created nlsq.yaml (in a temporary directory)")
    print()
    print("  Contents:")
    print("  " + "-" * 40)
    lines = config_path.read_text(encoding="utf-8").splitlines()
    for line in lines[:max_preview_lines]:
        print(f"  {line}")
    if len(lines) > max_preview_lines:
        print("  ...")

    # =========================================================================
    # 2. Load YAML configuration
    # =========================================================================
    print()
    print("2. Loading YAML configuration...")

    loaded_config = load_yaml_config(config_path)

    print(f"  default_workflow: {loaded_config.get('default_workflow')}")
    print(f"  memory_limit_gb: {loaded_config.get('memory_limit_gb')}")
    print(f"  workflows defined: {list(loaded_config.get('workflows', {}).keys())}")

    # =========================================================================
    # 3. Get workflow settings
    # =========================================================================
    print()
    print("3. Getting workflow settings:")

    for wf_name in ["high_precision", "quick_explore", "large_data"]:
        settings = get_workflow_settings(loaded_config, wf_name)
        if settings:
            print(f"\n  {wf_name}:")
            print(f"    gtol: {settings.get('gtol', 'default')}")
            print(f"    enable_multistart: {settings.get('enable_multistart', False)}")
            if settings.get("enable_multistart"):
                print(f"    n_starts: {settings.get('n_starts')}")

    # =========================================================================
    # 4. Using YAML config with curve_fit
    # =========================================================================
    print()
    print("4. Using YAML config with curve_fit:")

    x_data = np.linspace(0, 5, 100 if QUICK else 300)
    true_a, true_b, true_c = 2.5, 1.2, 0.5
    y_true = true_a * np.exp(-true_b * x_data) + true_c
    y_data = y_true + 0.1 * np.random.randn(len(x_data))

    print(f"  True parameters: a={true_a}, b={true_b}, c={true_c}")

    # Shared by every fit below; config_path is passed explicitly so the fits
    # read the temporary file, not an nlsq.yaml in the working directory.
    fit_kwargs = {
        "p0": [1.0, 1.0, 0.0],
        "bounds": ([0, 0, -1], [10, 5, 2]),
        "config_path": config_path,
    }

    popt_hp, _ = fit_with_yaml_config(
        exponential_decay,
        x_data,
        y_data,
        workflow_name="high_precision",
        **fit_kwargs,
    )

    print()
    print("  high_precision workflow result:")
    print(f"    a={popt_hp[0]:.6f}, b={popt_hp[1]:.6f}, c={popt_hp[2]:.6f}")

    # =========================================================================
    # 5. Comparing workflows
    # =========================================================================
    print()
    print("5. Comparing workflows:")

    # Reuse the section-4 result: re-running the identical multistart fit
    # would only cost time.
    print(f"  high_precision: a={popt_hp[0]:.4f}, b={popt_hp[1]:.4f}, c={popt_hp[2]:.4f}")

    popt_qe, _ = fit_with_yaml_config(
        exponential_decay,
        x_data,
        y_data,
        workflow_name="quick_explore",
        **fit_kwargs,
    )
    print(f"  quick_explore:  a={popt_qe[0]:.4f}, b={popt_qe[1]:.4f}, c={popt_qe[2]:.4f}")


def main():
    """Walk through creating, loading and using an ``nlsq.yaml`` config.

    The example file is written to a fresh temporary directory rather than
    the current working directory. This means an existing project
    ``nlsq.yaml`` is never overwritten or deleted. The temporary file is
    removed even if one of the fits raises.
    """
    import tempfile  # only the demo needs it; keeps top-level imports unchanged

    print("=" * 70)
    print("YAML Configuration for NLSQ Workflows")
    print("=" * 70)
    print()

    np.random.seed(42)

    with tempfile.TemporaryDirectory() as tmp_dir:
        _run_yaml_demo(Path(tmp_dir) / "nlsq.yaml")

    print()
    print("  Cleaned up nlsq.yaml")

    # =========================================================================
    # Summary
    # =========================================================================
    print()
    print("=" * 70)
    print("Summary")
    print("=" * 70)
    print()
    print("YAML configuration enables:")
    print("  - Reproducible workflow settings")
    print("  - Easy sharing between collaborators")
    print("  - Project-specific configurations")
    print()
    print("Key pattern:")
    print("  1. Define workflows in nlsq.yaml")
    print("  2. Load config and extract settings")
    print("  3. Pass settings to curve_fit() or fit()")
    print()
    print("Environment variables:")
    print("  - NLSQ_MEMORY_LIMIT_GB")
    print("  - NLSQ_DEFAULT_WORKFLOW")

In [None]:
# Run the walkthrough only when executed as a script, not when imported.
if __name__ == "__main__":
    main()