In [None]:
    params = jnp.array([3.0, 0.5])
    def loss(p):
        return jnp.sum((y_single - exponential_model(x_batch_data, *p)) ** 2)
    for _ in range(20):
        g = jax.grad(loss)(params)
        params = params - 0.05 * g
    return params
# Build one batched solver: vmap maps fit_one_dataset over the leading axis of
# its input, and jit compiles the resulting batched function.
fit_batch = jit(vmap(fit_one_dataset))

# Warm up with the *full* batch. jit specializes the compiled code on input
# shape, so warming up on a slice with a different number of datasets would
# leave the compilation for the real shape inside the timed region below.
_ = fit_batch(y_batch_data).block_until_ready()

# perf_counter is monotonic and high-resolution, which suits short timings.
start = time.perf_counter()
results_batch = fit_batch(y_batch_data)
# JAX dispatches asynchronously; wait for the whole result before stopping
# the clock.
results_batch.block_until_ready()
time_batch = time.perf_counter() - start
print(
    f"  Time for {n_datasets} datasets: {time_batch * 1000:.0f} ms ({time_batch * 1000 / n_datasets:.3f} ms/fit)"
)
print(f"  Throughput: {n_datasets / time_batch:.0f} fits/second")
print()
# NOTE(review): the division by 100 assumes time_sequential was measured over
# 100 sequential fits -- confirm against the sequential benchmark cell.
estimated_sequential_time = time_sequential * n_datasets / 100
speedup = estimated_sequential_time / time_batch
print(f"Speedup: {speedup:.0f}x faster with vmap + JIT ✓")
print()
print("Key insight: vmap parallelizes across datasets, JIT compiles once")


Part 4: Memory Optimization

Avoiding out-of-memory (OOM) errors with large datasets.


In [None]:
# Print a short list of memory-saving strategies, with a live float32 vs
# float64 size comparison for the first one.
print("Memory Optimization Strategies:")
print("=" * 60)
print()
print("1. Use float32 instead of float64:")
x_f64 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float64)
x_f32 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
# The labels below say "per element", so report the dtype's item size rather
# than nbytes (which is the total size of the 3-element array).
# NOTE: if jax_enable_x64 is off, JAX truncates the float64 request to float32
# and both lines will report the same size.
bytes_per_element_f64 = x_f64.dtype.itemsize
bytes_per_element_f32 = x_f32.dtype.itemsize
print(f"   float64 memory: {bytes_per_element_f64} bytes per element")
print(f"   float32 memory: {bytes_per_element_f32} bytes per element")
print(
    f"   Savings: {(1 - bytes_per_element_f32 / bytes_per_element_f64) * 100:.0f}%"
)
print("   → Use float32 unless high precision is critical\n")
print("2. Process data in chunks (streaming):")
print("   # For very large datasets (millions of points)")
print("   chunk_size = 100000")
print("   for i in range(0, len(data), chunk_size):")
print("       chunk = data[i:i+chunk_size]")
print("       result = fit(chunk)")
print("       results.append(result)\n")
print("3. Clear JAX cache if needed:")
print("   from jax import clear_caches")
print("   clear_caches()  # Frees compilation cache\n")
print("4. Monitor memory usage:")
def get_array_memory_mb(arr):
    """Return the size of ``arr`` in MiB, derived from its ``nbytes`` attribute."""
    bytes_per_mib = 1024 * 1024
    return arr.nbytes / bytes_per_mib
# Concrete example for strategy 4: a 10000 x 1000 float32 array and its size.
large_array = jnp.ones((10000, 1000), dtype=jnp.float32)
print(
    f"   Example: {large_array.shape} array uses {get_array_memory_mb(large_array):.1f} MB"
)
print()

# Strategy 5: rule-of-thumb memory table, emitted one row per line.
print("5. Typical memory requirements:")
for requirement_row in (
    "   10K points:     ~0.1 MB (negligible)",
    "   1M points:      ~10 MB (easy)",
    "   100M points:    ~1 GB (manageable)",
    "   1B points:      ~10 GB (need chunking or distributed)",
):
    print(requirement_row)
print()
print("→ For datasets >100M points, use chunked processing or streaming")


Part 5: Performance Benchmarking

Systematic performance measurement and optimization.


In [None]:
def benchmark_nlsq(n_points_list, n_params=2, n_runs=5):
    """Benchmark NLSQ across different problem sizes.

    For every size in ``n_points_list`` this generates noisy data from
    ``3.0 * exp(-0.5 * x)`` on ``[0, 5]``, runs one untimed warm-up fit, and
    then times ``n_runs`` fits of ``exponential_model``.

    Parameters
    ----------
    n_points_list : list
        List of dataset sizes to test
    n_params : int
        Number of parameters to fit. Currently unused: the benchmark always
        fits ``exponential_model`` with the two-element initial guess
        ``[2.0, 0.3]``. Kept so existing callers keep working.
    n_runs : int
        Number of timed runs to average for each size
    Returns
    -------
    results : dict
        Benchmark results with keys ``"n_points"``, ``"mean_time_ms"`` and
        ``"std_time_ms"``; each value is a list with one entry per size.
    """
    results = {"n_points": [], "mean_time_ms": [], "std_time_ms": []}
    cf_bench = CurveFit()
    for n_points in n_points_list:
        x = jnp.linspace(0, 5, n_points)
        y = 3.0 * jnp.exp(-0.5 * x) + np.random.normal(0, 0.1, n_points)
        # Untimed warm-up fit for this input shape, so one-off first-call
        # cost is kept out of the measurements. Block on it so no pending
        # work leaks into the first timed run.
        warmup = cf_bench.curve_fit(
            exponential_model, x, y, p0=[2.0, 0.3], maxiter=20
        )
        jax.block_until_ready(warmup)
        times = []
        for _run in range(n_runs):
            # perf_counter is monotonic and high-resolution.
            start = time.perf_counter()
            popt, pcov = cf_bench.curve_fit(
                exponential_model, x, y, p0=[2.0, 0.3], maxiter=20
            )
            # JAX dispatch is asynchronous: wait for the results before
            # reading the clock (a no-op for non-JAX values).
            jax.block_until_ready((popt, pcov))
            times.append((time.perf_counter() - start) * 1000)
        results["n_points"].append(n_points)
        results["mean_time_ms"].append(np.mean(times))
        results["std_time_ms"].append(np.std(times))
    return results
# Run the size sweep, print a results table, and save scaling plots to disk.
print("Running comprehensive benchmark...")
print("(This may take 30-60 seconds)")
print()
sizes = [100, 500, 1000, 5000, 10000]
bench_results = benchmark_nlsq(sizes, n_runs=5)
print("Benchmark Results:")
print("=" * 60)
print(f"{'N Points':<12} {'Mean Time (ms)':<20} {'Throughput (fits/s)'}")
print("-" * 60)
for n, mean_t, std_t in zip(
    bench_results["n_points"],
    bench_results["mean_time_ms"],
    bench_results["std_time_ms"],
):
    # mean_t is in milliseconds, so 1000 / mean_t is fits per second.
    throughput = 1000 / mean_t
    print(f"{n:<12} {mean_t:>8.2f} ± {std_t:<8.2f} {throughput:>12.1f}")
print()

# Left panel: linear axes with error bars; right panel: log-log view of the
# same mean timings.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ax1.errorbar(
    bench_results["n_points"],
    bench_results["mean_time_ms"],
    yerr=bench_results["std_time_ms"],
    marker="o",
    capsize=5,
    label="NLSQ",
)
ax1.set_xlabel("Number of Data Points")
ax1.set_ylabel("Time (ms)")
ax1.set_title("Performance Scaling")
ax1.legend()
ax1.grid(alpha=0.3)
ax2.loglog(bench_results["n_points"], bench_results["mean_time_ms"], "o-", label="NLSQ")
ax2.set_xlabel("Number of Data Points")
ax2.set_ylabel("Time (ms)")
ax2.set_title("Scaling Behavior (log-log)")
ax2.legend()
ax2.grid(alpha=0.3, which="both")
plt.tight_layout()

# __file__ is not defined when this code runs as a notebook cell or in an
# interactive session; fall back to the current working directory there.
try:
    _base_dir = Path(__file__).parent
except NameError:
    _base_dir = Path.cwd()
fig_dir = _base_dir / "figures" / "gpu_optimization_deep_dive"
fig_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(fig_dir / "fig_01.png", dpi=300, bbox_inches="tight")
plt.close()
print("Interpretation:")
print("  - Nearly flat scaling: Well-optimized (GPU benefits)")
print("  - Linear scaling: Expected for iterative optimization")
print("  - Superlinear scaling: May indicate memory issues or poor caching")


Summary and Best Practices

Performance Optimization Checklist

**For Maximum Speed:**

1. ✅ **Use GPU** if available (5-50x speedup for large problems)
2. ✅ **Keep array shapes consistent** to avoid recompilation
3. ✅ **Use float32** unless high precision is needed (2x memory savings)
4. ✅ **Batch process** with `vmap` for multiple datasets (10-100x faster)
5. ✅ **Warm up JIT** with data of the same shape and dtype as the real workload before benchmarking (a different shape triggers recompilation)
6. ✅ **Use `block_until_ready()`** when timing (JAX is async)

**For Large Datasets:**

1. ✅ **Chunk data** if >100M points
2. ✅ **Monitor memory** usage
3. ✅ **Consider downsampling** for smooth, oversampled data
4. ✅ **Use streaming** for datasets that don't fit in memory

Performance Expectations

| **Scenario** | **Typical Time** | **Optimization** |
|--------------|------------------|------------------|
| First call (cold start) | 0.5-2 seconds | Expected (JIT compilation) |
| Subsequent calls (warm) | 1-50 ms | Cached compilation |
| Large dataset (10K points) | 5-100 ms | Use GPU if available |
| Batch (1000 fits) | 100-5000 ms | Use vmap for parallelization |
| Huge dataset (1M points) | 50-500 ms | GPU + chunking |

Troubleshooting Performance Issues

**Problem**: First call is slow (>5 seconds)
- **Solution**: Normal for JIT. Subsequent calls will be fast.

**Problem**: All calls are slow (>1 second for small data)
- **Solution**: Check if recompiling each time (varying shapes/dtypes)

**Problem**: Out of memory errors
- **Solution**: Use float32, chunk data, or downsample

**Problem**: GPU not being used
- **Solution**: Check `jax.devices()`, install jax[cuda] or jax[rocm]

**Problem**: Batch processing not faster than sequential
- **Solution**: Problem may be too small, try larger batches or datasets

Advanced Profiling

For detailed profiling:

```python
# JAX profiling (requires jax[profiling])
import jax.profiler

# Profile a code block
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    # Your NLSQ code here
    popt, pcov = cf.curve_fit(model, x, y, p0=...)

# Opens profiling UI in browser
```

Production Recommendations

```python
# Example: Optimized production setup
import jax
import jax.numpy as jnp
from nlsq import CurveFit

# Configure JAX for production
jax.config.update('jax_enable_x64', False)  # Use float32

# Pre-warm JIT cache at startup (use the same array shapes as production data)
cf = CurveFit()
x_dummy = jnp.linspace(0, 1, 100)
y_dummy = jnp.ones(100)
_ = cf.curve_fit(model, x_dummy, y_dummy, p0=initial_guess)

# Now ready for fast production fitting
```

Next Steps

- **Scale up**: Try batch processing 10,000+ datasets with vmap
- **Optimize models**: Simplify model functions for faster evaluation
- **Profile**: Use JAX profiler to identify bottlenecks
- **Distribute**: For massive scale, consider JAX's `pmap` for multi-GPU

References

1. **JAX Performance**: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
2. **JAX Profiling**: https://jax.readthedocs.io/en/latest/profiling.html
3. **GPU Acceleration**: https://jax.readthedocs.io/en/latest/gpu_performance_tips.html
4. **Related examples**:
   - `custom_algorithms_advanced.ipynb` - vmap for batch fitting
   - `troubleshooting_guide.ipynb` - Performance debugging

---

**Remember**: Premature optimization is the root of all evil. Profile first, optimize what matters!
