# 4-Layer Defense Strategy for L-BFGS Warmup Divergence Prevention

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imewei/NLSQ/blob/main/examples/notebooks/05_feature_demos/defense_layers_demo.ipynb)

**Level**: Intermediate | **Time**: 25-30 minutes | **Version**: 0.3.6+

---

## Overview

This notebook demonstrates NLSQ's **4-layer defense strategy** that prevents Adam optimizer divergence during the warmup phase when initial parameters are already near optimal.

### What You'll Learn

- Understanding each defense layer and when it activates
- Using telemetry to monitor defense layer behavior
- Configuring defense sensitivity for different scenarios
- Using preset configurations (strict, relaxed, scientific)
- Troubleshooting common defense layer issues

### Why Defense Layers?

The hybrid streaming optimizer uses Adam for initial warmup before switching to Gauss-Newton. However, if initial parameters are already close to optimal, Adam's aggressive updates can push parameters **away** from the optimum, causing divergence.

The 4-layer defense strategy prevents this by:
1. **Layer 1**: Detecting warm starts and skipping warmup
2. **Layer 2**: Adapting learning rate based on initial fit quality
3. **Layer 3**: Aborting if loss increases beyond tolerance
4. **Layer 4**: Clipping update magnitude to prevent large jumps

In [None]:
# @title Install NLSQ (run once in Colab)
import sys

if 'google.colab' in sys.modules:
    print("Running in Google Colab - installing NLSQ...")
    !pip install -q nlsq
    print("NLSQ installed successfully!")
else:
    print("Not running in Colab - assuming NLSQ is already installed")

In [None]:
# Configure matplotlib for inline plotting
%matplotlib inline

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from nlsq import (
    HybridStreamingConfig,
    curve_fit,
    get_defense_telemetry,
    reset_defense_telemetry,
)

np.random.seed(42)
print("Setup complete!")

---

## 1. Understanding the 4 Defense Layers

Let's set up a test model and explore each defense layer.

In [None]:
def exponential_decay(x, a, b, c):
    """Three-parameter exponential decay: y = a * exp(-b * x) + c"""
    return a * jnp.exp(-b * x) + c

# Generate synthetic data
true_params = np.array([5.0, 0.5, 1.0])
x = np.linspace(0, 10, 500)  # Reduced for faster execution
y_true = exponential_decay(x, *true_params)
y = y_true + np.random.normal(0, 0.1, len(x))

print(f"Dataset: {len(x)} samples")
print(f"True parameters: a={true_params[0]}, b={true_params[1]}, c={true_params[2]}")

### Layer 1: Warm Start Detection

**Purpose**: Skip L-BFGS warmup entirely when initial parameters are already near optimal.

**How it works**: Computes `relative_loss = initial_loss / y_variance`. If `relative_loss < warm_start_threshold` (default 0.01, i.e. 1%), warmup is skipped.

**When it helps**: Refinement workflows, iterative fitting, warm starts from previous fits.
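
The check described above can be sketched in plain NumPy. This is an illustration, not NLSQ's internal code: the helper names (`relative_loss`, `is_warm_start`, `decay_model`) are hypothetical, and it assumes the loss is the mean squared residual.

```python
import numpy as np


def relative_loss(residuals, y):
    """Initial loss divided by the variance of the data.

    Assumption: the loss is the mean squared residual.
    """
    initial_loss = np.mean(np.square(residuals))
    return float(initial_loss / np.var(y))


def is_warm_start(residuals, y, warm_start_threshold=0.01):
    """True when the relative loss is below the warm start threshold."""
    return relative_loss(residuals, y) < warm_start_threshold


def decay_model(x, a, b, c):
    """Same functional form as the notebook's exponential_decay, in NumPy."""
    return a * np.exp(-b * x) + c


# Small synthetic dataset (independent of the notebook's x and y)
rng = np.random.default_rng(0)
x_demo = np.linspace(0, 10, 200)
y_demo = decay_model(x_demo, 5.0, 0.5, 1.0) + rng.normal(0, 0.05, x_demo.size)

# Residuals at a near-optimal guess and at a poor guess
near_residuals = y_demo - decay_model(x_demo, 5.0, 0.5, 1.0)
far_residuals = y_demo - decay_model(x_demo, 1.0, 0.1, 5.0)

print("near-optimal guess is a warm start:", is_warm_start(near_residuals, y_demo))
print("poor guess is a warm start:", is_warm_start(far_residuals, y_demo))
```

Under these assumptions the near-optimal guess falls below the 1% threshold and the poor guess does not, so only the first would skip warmup.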

In [None]:
# Demonstrate Layer 1: Near-optimal initial guess
print("=" * 60)
print("Layer 1: Warm Start Detection Demo")
print("=" * 60)

# Reset telemetry to track this specific fit
reset_defense_telemetry()

# Use parameters very close to true values (simulating a refinement scenario)
near_optimal_p0 = true_params + np.random.normal(0, 0.01, 3)
print(f"\nNear-optimal p0: {near_optimal_p0}")
print(f"True params:     {true_params}")

# Fit with hybrid_streaming
popt, pcov = curve_fit(
    exponential_decay,
    x, y,
    p0=near_optimal_p0,
    method='hybrid_streaming',
    verbose=0,
)

# Check telemetry
telemetry = get_defense_telemetry()
summary = telemetry.get_summary()
layer1_triggers = summary['layer1']['triggers']

print(f"\nFitted params: {popt}")
print("\n--- Telemetry ---")
print(f"Layer 1 triggers: {layer1_triggers}")
print(f"Total warmup calls: {summary['total_warmup_calls']}")

if layer1_triggers > 0:
    print("\n-> Layer 1 detected warm start and skipped L-BFGS warmup!")
else:
    print("\n-> Initial guess not close enough for warm start detection")

### Layer 2: Adaptive Step Size

**Purpose**: Automatically select the learning rate from the quality of the initial fit, as measured by `relative_loss`.

**Learning rate tiers**:
- **Refinement** (1e-6): `relative_loss < 0.1` - ultra-conservative for excellent fits
- **Careful** (1e-5): `0.1 <= relative_loss < 1.0` - conservative for good fits
- **Exploration** (0.001): `relative_loss >= 1.0` - standard rate for poor fits

**When it helps**: Multi-scale parameter problems, mixed starting point quality.
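
The tier selection above amounts to a simple threshold lookup. The sketch below is illustrative only: `select_warmup_lr` is a hypothetical helper, not an NLSQ function, and the thresholds and rates are copied from the tier list in this section.

```python
def select_warmup_lr(
    relative_loss,
    lr_refinement=1e-6,
    lr_careful=1e-5,
    lr_exploration=1e-3,
):
    """Return (mode, learning_rate) for a given relative loss.

    Thresholds follow the tiers listed above:
    refinement below 0.1, careful from 0.1 up to 1.0, exploration otherwise.
    """
    if relative_loss < 0.1:
        return "refinement", lr_refinement
    if relative_loss < 1.0:
        return "careful", lr_careful
    return "exploration", lr_exploration


for rl in (0.05, 0.5, 3.0):
    mode, lr = select_warmup_lr(rl)
    print(f"relative_loss={rl}: mode={mode}, lr={lr}")
```

Note that the boundary values belong to the less conservative tier: a relative loss of exactly 0.1 selects the careful rate, and exactly 1.0 selects the exploration rate.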

In [None]:
# Demonstrate Layer 2: Different LR modes
print("=" * 60)
print("Layer 2: Adaptive Step Size Demo")
print("=" * 60)

# Test with different initial guess qualities
test_cases = [
    ("Excellent fit (refinement LR)", true_params * np.array([1.02, 1.01, 0.99])),
    ("Good fit (careful LR)", true_params * np.array([1.2, 0.8, 1.3])),
    ("Poor fit (exploration LR)", np.array([1.0, 0.1, 5.0])),
]

for name, p0 in test_cases:
    reset_defense_telemetry()

    # Disable Layer 1 to see Layer 2 in action
    config = HybridStreamingConfig(
        enable_warm_start_detection=False,  # Force warmup to run
        warmup_iterations=50,  # Short warmup for demo
    )

    popt, _ = curve_fit(
        exponential_decay, x, y, p0=p0,
        method='hybrid_streaming',
        config=config,
        verbose=0,
    )

    telemetry = get_defense_telemetry()
    lr_counts = telemetry.layer2_lr_mode_counts

    print(f"\n{name}")
    print(f"  p0: {p0}")
    print(f"  LR mode counts: {lr_counts}")

### Layer 3: Cost-Increase Guard

**Purpose**: Abort warmup immediately if loss increases beyond tolerance.

**How it works**: After each Adam step, checks whether `current_loss > initial_loss * (1 + cost_increase_tolerance)`. If so, warmup stops and returns the **best parameters found** (not the diverged ones).

**Default tolerance**: 5% (`cost_increase_tolerance=0.05`, configurable)

**When it helps**: Preventing divergence from near-optimal starting points.
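
A minimal sketch of the guard logic follows. To stay self-contained it uses plain gradient descent rather than Adam, and `guarded_warmup` is a hypothetical function written for this illustration, not part of NLSQ.

```python
import numpy as np


def guarded_warmup(loss_fn, grad_fn, p0, lr, n_steps=100, cost_increase_tolerance=0.05):
    """Gradient-descent warmup that aborts when the loss rises too far.

    Returns (best_params, aborted). The best parameters seen so far are
    tracked separately so that an abort never returns diverged values.
    """
    params = np.asarray(p0, dtype=float)
    initial_loss = loss_fn(params)
    best_params, best_loss = params.copy(), initial_loss

    for _ in range(n_steps):
        params = params - lr * grad_fn(params)
        current_loss = loss_fn(params)

        if current_loss < best_loss:
            best_params, best_loss = params.copy(), current_loss

        # Cost-increase guard: compare against the initial loss
        if current_loss > initial_loss * (1 + cost_increase_tolerance):
            return best_params, True

    return best_params, False


def quad_loss(p):
    """Toy quadratic loss with its minimum at the origin."""
    return float(np.sum(p**2))


def quad_grad(p):
    return 2.0 * p


start = np.array([0.1, -0.1])

# lr=1.5 overshoots: each step maps p to -2p, so the loss quadruples
best_bad, aborted_bad = guarded_warmup(quad_loss, quad_grad, start, lr=1.5)

# lr=0.1 contracts: each step maps p to 0.8p, so the loss shrinks
best_ok, aborted_ok = guarded_warmup(quad_loss, quad_grad, start, lr=0.1)

print("aggressive LR aborted:", aborted_bad, "returned:", best_bad)
print("moderate LR aborted:", aborted_ok, "final loss:", quad_loss(best_ok))
```

With the aggressive rate the guard fires on the first step and the starting point is returned unchanged, which is the behavior Layer 3 is meant to guarantee for near-optimal starts.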

In [None]:
# Demonstrate Layer 3: Cost Guard Protection
print("=" * 60)
print("Layer 3: Cost-Increase Guard Demo")
print("=" * 60)

# Scenario: Aggressive LR that would cause divergence
# Layer 3 should catch this and return best params
reset_defense_telemetry()

# Use a config that might cause divergence (high LR, disabled other defenses)
config = HybridStreamingConfig(
    enable_warm_start_detection=False,
    enable_adaptive_warmup_lr=False,  # Use fixed (high) LR
    warmup_learning_rate=0.1,  # Very aggressive LR
    enable_cost_guard=True,  # But cost guard is ON
    cost_increase_tolerance=0.05,  # Abort if loss increases >5%
    enable_step_clipping=False,
    warmup_iterations=100,
)

# Start from excellent initial guess
excellent_p0 = true_params * np.array([1.01, 0.99, 1.005])
print(f"Starting from near-optimal p0: {excellent_p0}")
print(f"With aggressive LR={config.warmup_learning_rate}")

popt, _ = curve_fit(
    exponential_decay, x, y, p0=excellent_p0,
    method='hybrid_streaming',
    config=config,
    verbose=0,
)

telemetry = get_defense_telemetry()
summary = telemetry.get_summary()
layer3_triggers = summary['layer3']['triggers']

print(f"\nFinal params: {popt}")
print("\n--- Telemetry ---")
print(f"Layer 3 (cost guard) triggers: {layer3_triggers}")

if layer3_triggers > 0:
    print("\n-> Layer 3 detected loss increase and aborted warmup!")
    print("   Returned best parameters found, not diverged values.")

### Layer 4: Trust Region Constraint (Step Clipping)

**Purpose**: Limit parameter update magnitude to prevent large destabilizing jumps.

**How it works**: Clips each Adam parameter update to a maximum L2 norm of `max_warmup_step_size` (default 0.1).

**When it helps**: Multi-scale parameters, ill-conditioned problems, preventing overshooting.
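
Norm-based clipping of this kind can be sketched in a few lines of NumPy. The helper `clip_update` is hypothetical and only illustrates the rescaling; NLSQ's internal implementation may differ.

```python
import numpy as np


def clip_update(update, max_step_size=0.1):
    """Rescale an update so its L2 norm does not exceed max_step_size.

    Returns (clipped_update, was_clipped). The direction of the update is
    preserved; only its length is reduced.
    """
    update = np.asarray(update, dtype=float)
    norm = np.linalg.norm(update)
    if norm > max_step_size:
        return update * (max_step_size / norm), True
    return update, False


large_step, large_clipped = clip_update(np.array([3.0, 4.0]))    # norm 5.0
small_step, small_clipped = clip_update(np.array([0.03, 0.04]))  # norm 0.05

print("large step:", large_step, "clipped:", large_clipped)
print("small step:", small_step, "clipped:", small_clipped)
```

Because the whole vector is rescaled rather than each component clipped separately, the update keeps its direction, which matters when parameters live on different scales.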

In [None]:
# Demonstrate Layer 4: Step Clipping
print("=" * 60)
print("Layer 4: Trust Region Constraint Demo")
print("=" * 60)

reset_defense_telemetry()

# Config with step clipping
config = HybridStreamingConfig(
    enable_warm_start_detection=False,
    enable_adaptive_warmup_lr=False,
    warmup_learning_rate=0.01,  # High LR that would produce large steps
    enable_cost_guard=False,
    enable_step_clipping=True,  # Step clipping ON
    max_warmup_step_size=0.1,  # Max L2 norm of update
    warmup_iterations=100,
)

# Poor initial guess that needs exploration
poor_p0 = np.array([10.0, 0.1, 5.0])
print(f"Starting from poor p0: {poor_p0}")

popt, _ = curve_fit(
    exponential_decay, x, y, p0=poor_p0,
    method='hybrid_streaming',
    config=config,
    verbose=0,
)

telemetry = get_defense_telemetry()
summary = telemetry.get_summary()
layer4_triggers = summary['layer4']['triggers']

print(f"\nFinal params: {popt}")
print(f"True params:  {true_params}")
print("\n--- Telemetry ---")
print(f"Layer 4 (step clipping) triggers: {layer4_triggers}")

if layer4_triggers > 0:
    print(f"\n-> Layer 4 clipped {layer4_triggers} large parameter updates")

---

## 2. Using Defense Layer Telemetry

Telemetry records how often each defense layer triggers, which lets you monitor defense behavior in production.

In [None]:
# Run multiple fits and collect telemetry
print("=" * 60)
print("Defense Layer Telemetry Demo")
print("=" * 60)

reset_defense_telemetry()

# Simulate a batch of fits with varying starting points
n_fits = 5  # Reduced for faster execution
for i in range(n_fits):
    # Vary initial guess quality
    noise_scale = 0.01 if i < 2 else (0.3 if i < 4 else 1.0)
    p0 = true_params * (1 + np.random.uniform(-noise_scale, noise_scale, 3))

    popt, _ = curve_fit(
        exponential_decay, x, y, p0=p0,
        method='hybrid_streaming',
        verbose=0,
    )

# Get comprehensive telemetry report
telemetry = get_defense_telemetry()

print(f"\n--- Summary after {n_fits} fits ---")
summary = telemetry.get_summary()
for key, value in summary.items():
    print(f"  {key}: {value}")

print("\n--- Trigger Rates ---")
rates = telemetry.get_trigger_rates()
for key, value in rates.items():
    print(f"  {key}: {value:.1f}%")

In [None]:
# View recent events
print("\n--- Recent Events (last 5) ---")
events = telemetry.get_recent_events(5)
for event in events:
    print(f"  {event['timestamp']}: {event['type']} - {event.get('data', {})}")

In [None]:
# Export metrics (Prometheus/Grafana compatible)
print("\n--- Prometheus-Compatible Metrics ---")
metrics = telemetry.export_metrics()
for metric_name, value in metrics.items():
    print(f"  {metric_name}: {value}")

---

## 3. Configuration Presets

NLSQ provides preset configurations for common scenarios.

In [None]:
# Available presets
print("=" * 60)
print("Configuration Presets")
print("=" * 60)

presets = [
    ("Default", HybridStreamingConfig()),
    ("defense_strict()", HybridStreamingConfig.defense_strict()),
    ("defense_relaxed()", HybridStreamingConfig.defense_relaxed()),
    ("defense_disabled()", HybridStreamingConfig.defense_disabled()),
    ("scientific_default()", HybridStreamingConfig.scientific_default()),
]

print("\n{:<22} {:>8} {:>10} {:>12} {:>8}".format(
    "Preset", "Layer1", "Layer2", "Layer3", "Layer4"
))
print("-" * 60)

for name, config in presets:
    print("{:<22} {:>8} {:>10} {:>12} {:>8}".format(
        name,
        "ON" if config.enable_warm_start_detection else "OFF",
        "ON" if config.enable_adaptive_warmup_lr else "OFF",
        "ON" if config.enable_cost_guard else "OFF",
        "ON" if config.enable_step_clipping else "OFF",
    ))

In [None]:
# Compare preset behavior
print("\n--- Preset Comparison: Near-Optimal Start ---")

near_optimal_p0 = true_params * np.array([1.005, 0.995, 1.002])

for name, config in presets[:4]:  # Skip scientific_default for brevity
    reset_defense_telemetry()

    popt, _ = curve_fit(
        exponential_decay, x, y, p0=near_optimal_p0,
        method='hybrid_streaming',
        config=config,
        verbose=0,
    )

    telemetry = get_defense_telemetry()
    rates = telemetry.get_trigger_rates()

    error = np.linalg.norm(popt - true_params)
    print(f"\n{name}:")
    print(f"  Parameter error: {error:.6f}")
    print(f"  Layer 1 rate: {rates.get('layer1_warm_start_rate', 0):.0f}%")

---

## 4. Custom Configuration

Fine-tune defense layers for your specific needs.

In [None]:
# Custom configuration example
print("=" * 60)
print("Custom Configuration Example")
print("=" * 60)

# Scenario: Scientific computing with multi-scale parameters
custom_config = HybridStreamingConfig(
    # Layer 1: Stricter warm start detection
    enable_warm_start_detection=True,
    warm_start_threshold=0.005,  # 0.5% instead of 1%

    # Layer 2: Conservative learning rates
    enable_adaptive_warmup_lr=True,
    warmup_lr_refinement=1e-7,  # Ultra-conservative
    warmup_lr_careful=1e-6,
    warmup_learning_rate=1e-4,  # Lower than default 0.001

    # Layer 3: Tighter cost tolerance
    enable_cost_guard=True,
    cost_increase_tolerance=0.02,  # 2% instead of 5%

    # Layer 4: Smaller step limit
    enable_step_clipping=True,
    max_warmup_step_size=0.05,  # Half the default

    # Other settings
    precision='float64',  # Full precision for scientific work
    warmup_iterations=300,
)

print("Custom config created with:")
print(f"  warm_start_threshold: {custom_config.warm_start_threshold}")
print(f"  warmup_lr_refinement: {custom_config.warmup_lr_refinement}")
print(f"  cost_increase_tolerance: {custom_config.cost_increase_tolerance}")
print(f"  max_warmup_step_size: {custom_config.max_warmup_step_size}")

# Use the custom config
reset_defense_telemetry()
popt, pcov = curve_fit(
    exponential_decay, x, y,
    p0=np.array([4.5, 0.45, 1.1]),
    method='hybrid_streaming',
    config=custom_config,
    verbose=0,
)

print(f"\nFitted params: {popt}")
print(f"Std errors: {np.sqrt(np.diag(pcov))}")

---

## 5. Practical Scenarios

### Scenario A: Warm Start Refinement

Refining parameters from a previous fit.

In [None]:
# Warm Start Refinement Scenario
print("=" * 60)
print("Scenario A: Warm Start Refinement")
print("=" * 60)

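# Step 1: Initial fit on the first half of the data
# (x has only 500 samples, so x[:500] would already be the full dataset)
n_first = len(x) // 2
x1 = x[:n_first]
y1 = y[:n_first]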

popt_v1, _ = curve_fit(
    exponential_decay, x1, y1,
    p0=np.array([3.0, 0.3, 0.5]),
    method='hybrid_streaming',
    verbose=0,
)
print(f"Initial fit (v1): {popt_v1}")

# Step 2: Refinement with full data (using v1 as starting point)
reset_defense_telemetry()

popt_v2, pcov_v2 = curve_fit(
    exponential_decay, x, y,
    p0=popt_v1,  # Use previous result as initial guess
    method='hybrid_streaming',
    config=HybridStreamingConfig.defense_strict(),  # Use strict for refinement
    verbose=0,
)

telemetry = get_defense_telemetry()
rates = telemetry.get_trigger_rates()

print(f"Refined fit (v2): {popt_v2}")
print(f"True params:      {true_params}")
print(f"\nLayer 1 (warm start) trigger rate: {rates.get('layer1_warm_start_rate', 0):.0f}%")
print("-> Defense layers protected against divergence from good initial guess")

### Scenario B: Production Monitoring

Monitoring defense layer activations in a batch processing pipeline.

In [None]:
# Production Monitoring Scenario
print("=" * 60)
print("Scenario B: Production Monitoring")
print("=" * 60)

reset_defense_telemetry()

# Simulate production batch with varying data quality
results = []
for i in range(3):  # Reduced from 20 for faster execution
    # Simulate different starting point qualities (one fit per tier,
    # since the loop was reduced to 3 iterations)
    if i < 1:
        p0 = true_params * (1 + np.random.uniform(-0.01, 0.01, 3))  # Excellent
    elif i < 2:
        p0 = true_params * (1 + np.random.uniform(-0.2, 0.2, 3))  # Good
    else:
        p0 = np.array([1.0, 0.1, 5.0]) + np.random.uniform(-0.5, 0.5, 3)  # Poor

    # Add some noise to data
    y_noisy = y + np.random.normal(0, 0.05 * (i % 3 + 1), len(y))

    popt, _ = curve_fit(
        exponential_decay, x, y_noisy, p0=p0,
        method='hybrid_streaming',
        verbose=0,
    )
    results.append(popt)

# Production monitoring report
telemetry = get_defense_telemetry()
rates = telemetry.get_trigger_rates()

print(f"\n--- Production Report ({len(results)} fits) ---")
print("\nDefense Layer Activation Rates:")
print(f"  Layer 1 (Warm Start):     {rates.get('layer1_warm_start_rate', 0):.1f}%")
print(f"  Layer 2 (Refinement LR):  {rates.get('layer2_refinement_rate', 0):.1f}%")
print(f"  Layer 2 (Careful LR):     {rates.get('layer2_careful_rate', 0):.1f}%")
print(f"  Layer 2 (Exploration LR): {rates.get('layer2_exploration_rate', 0):.1f}%")
print(f"  Layer 3 (Cost Guard):     {rates.get('layer3_cost_guard_rate', 0):.1f}%")
print(f"  Layer 4 (Step Clipping):  {rates.get('layer4_clip_rate', 0):.1f}%")

# Alerts based on rates
print("\n--- Alerts ---")
if rates.get('layer1_warm_start_rate', 0) > 50:
    print("INFO: >50% warm starts - consider using defense_strict()")
if rates.get('layer3_cost_guard_rate', 0) > 20:
    print("WARNING: >20% cost guard triggers - review initial guess quality")
if rates.get('layer3_cost_guard_rate', 0) == 0 and rates.get('layer1_warm_start_rate', 0) == 0:
    print("OK: Layer 1 and Layer 3 did not trigger in this batch")

---

## 6. Troubleshooting

### Common Issues and Solutions

In [None]:
# Troubleshooting guide
print("=" * 60)
print("Troubleshooting Guide")
print("=" * 60)

issues = [
    {
        "problem": "Warmup always skipped (Layer 1 always triggers)",
        "cause": "warm_start_threshold too high for your use case",
        "solution": "config = HybridStreamingConfig(warm_start_threshold=0.001)",
    },
    {
        "problem": "Convergence too slow after upgrading to 0.3.6",
        "cause": "Layer 2 using ultra-conservative LR",
        "solution": "config = HybridStreamingConfig.defense_relaxed()",
    },
    {
        "problem": "Cost guard aborts warmup too early",
        "cause": "cost_increase_tolerance too strict",
        "solution": "config = HybridStreamingConfig(cost_increase_tolerance=0.2)",
    },
    {
        "problem": "Results different from pre-0.3.6",
        "cause": "Defense layers preventing previous divergence",
        "solution": "config = HybridStreamingConfig.defense_disabled() # For testing only",
    },
    {
        "problem": "Need pre-0.3.6 behavior for regression tests",
        "cause": "Defense layers change optimization path",
        "solution": "config = HybridStreamingConfig.defense_disabled()",
    },
]

for i, issue in enumerate(issues, 1):
    print(f"\n{i}. Problem: {issue['problem']}")
    print(f"   Cause: {issue['cause']}")
    print(f"   Solution: {issue['solution']}")

---

## Summary

### Key Takeaways

1. **4 Defense Layers** protect against L-BFGS warmup divergence:
   - Layer 1: Warm start detection (skip warmup if near optimal)
   - Layer 2: Adaptive step size (scale step size by fit quality)
   - Layer 3: Cost-increase guard (abort on loss increase)
   - Layer 4: Step clipping (limit update magnitude)

2. **Enabled by default** - no code changes required for most users

3. **Telemetry** helps monitor defense behavior in production:
   - `get_defense_telemetry()` - access telemetry singleton
   - `reset_defense_telemetry()` - reset counters
   - `.get_summary()` - comprehensive report
   - `.get_trigger_rates()` - percentage rates
   - `.export_metrics()` - Prometheus-compatible format

4. **Presets** for common scenarios:
   - `defense_strict()` - for refinement/warm starts
   - `defense_relaxed()` - for exploration
   - `defense_disabled()` - pre-0.3.6 behavior
   - `scientific_default()` - optimized for physics/scientific computing

### When to Customize

| Scenario | Recommendation |
|----------|----------------|
| Default usage | Use defaults (all layers ON) |
| Warm start refinement | `defense_strict()` |
| Exploration from poor guess | `defense_relaxed()` |
| Regression testing | `defense_disabled()` |
| Scientific computing | `scientific_default()` |
| Custom requirements | Create `HybridStreamingConfig(...)` |

### Next Steps

- [Hybrid Streaming API](../06_streaming/05_hybrid_streaming_api.ipynb) - Complete hybrid streaming guide
- [Troubleshooting Guide](../03_advanced/troubleshooting_guide.ipynb) - General troubleshooting
- [Migration Guide](https://nlsq.readthedocs.io/en/latest/migration/v0.3.6_defense_layers.html) - v0.3.6 migration details