# JAX Integration in NetKet: High-Performance Automatic Differentiation

This notebook explores how NetKet leverages JAX for high-performance automatic differentiation and just-in-time compilation, making quantum many-body calculations efficient and scalable.

## Table of Contents
1. [Introduction to JAX](#introduction)
2. [Automatic Differentiation Basics](#autodiff)
3. [JIT Compilation](#jit)
4. [Vectorization with vmap](#vmap)
5. [Parallelization with pmap](#pmap)
6. [Custom Gradients](#custom-gradients)
7. [Memory and Performance Optimization](#optimization)
8. [NetKet-specific JAX Usage](#netket-jax)
9. [Advanced Topics](#advanced)
10. [Conclusion](#conclusion)

## 1. Introduction to JAX {#introduction}

JAX is a NumPy-compatible library that provides:
- **Automatic differentiation**: Forward and reverse mode
- **JIT compilation**: XLA-powered optimization
- **Vectorization**: Automatic batching with `vmap`
- **Parallelization**: Multi-device computation with `pmap`

These features make JAX ideal for machine learning and scientific computing, especially for quantum many-body problems where gradients and performance are crucial.

In [None]:
# Import necessary libraries
import jax
import jax.numpy as jnp
import numpy as np
import netket as nk
import flax.linen as nn
import matplotlib.pyplot as plt
from functools import partial
import time
import warnings
warnings.filterwarnings('ignore')

# Configure JAX
jax.config.update('jax_platform_name', 'cpu')  # Use CPU for this tutorial
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

# Set random seeds
key = jax.random.PRNGKey(42)
np.random.seed(42)

## 2. Automatic Differentiation Basics {#autodiff}

Automatic differentiation is fundamental to optimizing neural quantum states, because the variational parameters are updated from gradients of the energy. JAX provides both forward- and reverse-mode AD.

In [None]:
# Basic automatic differentiation examples

# Simple function
def f(x):
    return x**3 + 2*x**2 - x + 1

# Compute derivative
f_grad = jax.grad(f)

x = 2.0
print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {f_grad(x)}")
print(f"Analytical f'({x}) = {3*x**2 + 4*x - 1}")

In [None]:
# Vector function example
def vector_func(params):
    """Function that takes a parameter vector"""
    a, b, c = params
    return a**2 + b*c + jnp.sin(a*b)

# Gradient with respect to all parameters
vector_grad = jax.grad(vector_func)

params = jnp.array([1.0, 2.0, 3.0])
print(f"Function value: {vector_func(params)}")
print(f"Gradient: {vector_grad(params)}")

# Verify with finite differences
eps = 1e-7
finite_diff_grad = jnp.array([
    (vector_func(params + eps * jnp.array([1, 0, 0])) - 
     vector_func(params - eps * jnp.array([1, 0, 0]))) / (2 * eps),
    (vector_func(params + eps * jnp.array([0, 1, 0])) - 
     vector_func(params - eps * jnp.array([0, 1, 0]))) / (2 * eps),
    (vector_func(params + eps * jnp.array([0, 0, 1])) - 
     vector_func(params - eps * jnp.array([0, 0, 1]))) / (2 * eps)
])
print(f"Finite difference: {finite_diff_grad}")
print(f"Difference: {jnp.abs(vector_grad(params) - finite_diff_grad)}")

In [None]:
# Higher-order derivatives
def quartic(x):
    return x**4 - 2*x**3 + x**2

# First derivative
first_deriv = jax.grad(quartic)
# Second derivative
second_deriv = jax.grad(first_deriv)
# Third derivative
third_deriv = jax.grad(second_deriv)

x = 1.5
print(f"f({x}) = {quartic(x)}")
print(f"f'({x}) = {first_deriv(x)}")
print(f"f''({x}) = {second_deriv(x)}")
print(f"f'''({x}) = {third_deriv(x)}")

# Compare with analytical results
print(f"\nAnalytical:")
print(f"f'({x}) = {4*x**3 - 6*x**2 + 2*x}")
print(f"f''({x}) = {12*x**2 - 12*x + 2}")
print(f"f'''({x}) = {24*x - 12}")
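
The forward and reverse modes mentioned at the start of this section can be compared directly. The following is a minimal sketch using a toy function `f` (an illustrative choice, not part of NetKet): `jax.jvp` pushes a tangent vector forward, `jax.vjp` pulls a cotangent vector back, and `jax.jacfwd` / `jax.jacrev` assemble the full Jacobian from either mode.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy map from R^3 to R^2 (illustrative only)
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0] ** 2])

x = jnp.array([1.0, 2.0, 3.0])

# Forward mode: Jacobian-vector product, yields one column of the Jacobian
v = jnp.array([1.0, 0.0, 0.0])
y, jvp_out = jax.jvp(f, (x,), (v,))

# Reverse mode: vector-Jacobian product, yields one row of the Jacobian
y_again, vjp_fn = jax.vjp(f, x)
u = jnp.array([1.0, 0.0])
(vjp_out,) = vjp_fn(u)

# Full Jacobians assembled from either mode agree
J_fwd = jax.jacfwd(f)(x)
J_rev = jax.jacrev(f)(x)

print("Jacobian shape:", J_fwd.shape)
print("Modes agree:", bool(jnp.allclose(J_fwd, J_rev)))
```

Forward mode costs roughly one pass per input direction and reverse mode one pass per output, which is why reverse mode (`jax.grad`) is the usual choice for a scalar energy depending on many parameters.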

## 3. JIT Compilation {#jit}

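Just-in-time compilation traces a function once and compiles it with XLA, which can fuse operations and remove per-operation Python dispatch overhead. The first call pays the compilation cost; later calls with the same input shapes and dtypes reuse the compiled code.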

In [None]:
# JIT compilation example
def slow_function(x):
    """A function with many operations"""
    y = x
    for i in range(100):
        y = jnp.sin(y) + jnp.cos(y) * jnp.exp(-y/10)
    return y

# JIT-compiled version
fast_function = jax.jit(slow_function)

# Test input
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

# Time the functions
print("Timing comparison:")

# JAX dispatches work asynchronously, so block_until_ready() is needed
# to time the computation itself rather than just the dispatch

# Regular function (executed operation by operation)
start = time.time()
result_slow = slow_function(x).block_until_ready()
time_slow = time.time() - start
print(f"Regular function: {time_slow:.4f} seconds")

# JIT function (first call - includes compilation)
start = time.time()
result_fast = fast_function(x).block_until_ready()
time_fast_first = time.time() - start
print(f"JIT function (first call): {time_fast_first:.4f} seconds")

# JIT function (subsequent calls)
start = time.time()
result_fast = fast_function(x).block_until_ready()
time_fast_subsequent = time.time() - start
print(f"JIT function (subsequent): {time_fast_subsequent:.4f} seconds")

print(f"\nSpeedup: {time_slow / time_fast_subsequent:.1f}x")
print(f"Results match: {jnp.allclose(result_slow, result_fast)}")

In [None]:
# Matrix operations with JIT
def matrix_operations(A, B):
    """Complex matrix operations"""
    C = jnp.dot(A, B)
    D = jnp.linalg.inv(C + jnp.eye(C.shape[0]) * 1e-6)
    E = jnp.dot(D, A.T)
    return jnp.trace(E)

# JIT-compiled version
matrix_operations_jit = jax.jit(matrix_operations)

# Test matrices
key, subkey1, subkey2 = jax.random.split(key, 3)
A = jax.random.normal(subkey1, (100, 100))
B = jax.random.normal(subkey2, (100, 100))

# Benchmark
n_runs = 10

# Regular function (block_until_ready waits for asynchronous dispatch to finish)
start = time.time()
for _ in range(n_runs):
    result_regular = matrix_operations(A, B).block_until_ready()
time_regular = (time.time() - start) / n_runs

# JIT function (warm up first)
_ = matrix_operations_jit(A, B).block_until_ready()
start = time.time()
for _ in range(n_runs):
    result_jit = matrix_operations_jit(A, B).block_until_ready()
time_jit = (time.time() - start) / n_runs

print(f"Matrix operations benchmark ({n_runs} runs):")
print(f"Regular: {time_regular:.4f} seconds per run")
print(f"JIT: {time_jit:.4f} seconds per run")
print(f"Speedup: {time_regular / time_jit:.1f}x")
print(f"Results match: {jnp.allclose(result_regular, result_jit)}")
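
A compiled function is cached per input shape and dtype, so calling it with a new shape triggers a fresh trace and compilation. The sketch below makes this visible with a Python-side list, `trace_log` (a hypothetical helper name), that is only appended to while JAX traces the function body.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the Python function body

@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)  # Python side effect: runs at trace time only
    return jnp.sum(2.0 * x)

a = scaled_sum(jnp.ones(3))  # traces and compiles for shape (3,)
b = scaled_sum(jnp.ones(3))  # same shape and dtype: cached executable is reused
c = scaled_sum(jnp.ones(4))  # new shape: traced and compiled again

print("Shapes that triggered tracing:", trace_log)
```

This is one reason to keep array shapes fixed inside hot loops: every distinct shape pays the compilation cost once.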

## 4. Vectorization with vmap {#vmap}

`vmap` automatically vectorizes functions, eliminating the need for explicit loops and improving performance.

In [None]:
# Basic vmap example
def single_computation(x):
    """Function that works on a single input"""
    return jnp.sin(x**2) + jnp.cos(x**3)

# Vectorized version
vectorized_computation = jax.vmap(single_computation)

# Test data
x_single = 2.0
x_batch = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

print(f"Single input: {single_computation(x_single)}")
print(f"Batch input: {vectorized_computation(x_batch)}")

# Compare with manual loop
manual_result = jnp.array([single_computation(x) for x in x_batch])
print(f"Manual loop: {manual_result}")
print(f"Results match: {jnp.allclose(vectorized_computation(x_batch), manual_result)}")

In [None]:
# Advanced vmap with multiple arguments
def pairwise_distance(x1, x2):
    """Compute distance between two points"""
    return jnp.sqrt(jnp.sum((x1 - x2)**2))

# Vectorize over first argument only
distances_from_point = jax.vmap(pairwise_distance, in_axes=(0, None))

# Vectorize over both arguments
elementwise_distances = jax.vmap(pairwise_distance, in_axes=(0, 0))

# Test data
key, subkey = jax.random.split(key)
points = jax.random.normal(subkey, (5, 3))  # 5 points in 3D
reference_point = jnp.array([0.0, 0.0, 0.0])

print("Points:")
print(points)
print(f"\nDistances from origin: {distances_from_point(points, reference_point)}")
print(f"Elementwise distances (points[i] vs points[i+2]): {elementwise_distances(points[:3], points[2:])}")
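
A full matrix of distances between every pair of points can be obtained by nesting two `vmap` calls with different `in_axes`. This is a small sketch with a locally defined `distance` function; the name `all_pairs` is illustrative.

```python
import jax
import jax.numpy as jnp

def distance(x1, x2):
    """Euclidean distance between two points"""
    return jnp.sqrt(jnp.sum((x1 - x2) ** 2))

# Inner vmap: one fixed x1 against every row of x2
# Outer vmap: repeat that for every row of x1
all_pairs = jax.vmap(jax.vmap(distance, in_axes=(None, 0)), in_axes=(0, None))

points = jax.random.normal(jax.random.PRNGKey(0), (5, 3))  # 5 points in 3D
D = all_pairs(points, points)  # D[i, j] = distance(points[i], points[j])

print("Distance matrix shape:", D.shape)
```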

In [None]:
# Performance comparison: vmap vs loops
def complex_function(params):
    """A more complex function for benchmarking"""
    a, b, c = params
    result = 0.0
    for i in range(10):
        result += jnp.sin(a * i) * jnp.cos(b * i) * jnp.exp(-c * i / 10)
    return result

# Create test data
key, subkey = jax.random.split(key)
param_batch = jax.random.normal(subkey, (1000, 3))

# Manual loop version
def manual_batch(param_batch):
    results = []
    for params in param_batch:
        results.append(complex_function(params))
    return jnp.array(results)

# Vectorized version
vectorized_batch = jax.vmap(complex_function)

# JIT-compiled versions
# Note: under jit the Python loop in manual_batch is unrolled into one
# operation per batch element, so it compiles slowly
manual_batch_jit = jax.jit(manual_batch)
vectorized_batch_jit = jax.jit(vectorized_batch)

# Warm up (compilation happens here)
_ = manual_batch_jit(param_batch).block_until_ready()
_ = vectorized_batch_jit(param_batch).block_until_ready()

# Benchmark
n_runs = 10

# Manual loop
start = time.time()
for _ in range(n_runs):
    result_manual = manual_batch_jit(param_batch).block_until_ready()
time_manual = (time.time() - start) / n_runs

# Vectorized
start = time.time()
for _ in range(n_runs):
    result_vectorized = vectorized_batch_jit(param_batch).block_until_ready()
time_vectorized = (time.time() - start) / n_runs

print(f"Batch processing benchmark ({param_batch.shape[0]} items, {n_runs} runs):")
print(f"Manual loop (JIT): {time_manual:.4f} seconds per run")
print(f"Vectorized (JIT): {time_vectorized:.4f} seconds per run")
print(f"Speedup: {time_manual / time_vectorized:.1f}x")
print(f"Results match: {jnp.allclose(result_manual, result_vectorized)}")
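
Composing `vmap` with `grad` gives per-sample gradients in one call, a pattern that matters for neural quantum states, where the log-derivative of the wave function is needed for every sampled configuration. The sketch below uses a toy `log_amplitude` function (an assumption for illustration, not NetKet's RBM).

```python
import jax
import jax.numpy as jnp

def log_amplitude(w, sigma):
    """Toy log-amplitude: sum of log cosh over hidden units (illustrative)"""
    return jnp.sum(jnp.log(jnp.cosh(w @ sigma)))

key_w, key_s = jax.random.split(jax.random.PRNGKey(0))
w = 0.1 * jax.random.normal(key_w, (4, 6))  # 4 hidden units, 6 spins
sigmas = jnp.where(jax.random.bernoulli(key_s, 0.5, (8, 6)), 1.0, -1.0)

# grad differentiates with respect to w; vmap maps over the batch of configurations
per_sample_grad = jax.vmap(jax.grad(log_amplitude), in_axes=(None, 0))
grads = per_sample_grad(w, sigmas)

print("Per-sample gradient shape:", grads.shape)  # (batch, hidden, spins)
```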

## 5. Parallelization with pmap {#pmap}

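`pmap` replicates a function across multiple devices (GPUs, TPUs, or several CPU devices when JAX is configured to expose them) and maps it over the leading axis of its inputs. The size of that axis must not exceed the number of available devices.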

In [None]:
# Check available devices
devices = jax.devices()
print(f"Available devices: {len(devices)}")
for i, device in enumerate(devices):
    print(f"  Device {i}: {device}")

if len(devices) > 1:
    print("\nMultiple devices available - demonstrating pmap")
    
    def compute_sum_of_squares(x):
        """Simple computation for parallel execution"""
        return jnp.sum(x**2)
    
    # Parallel version
    parallel_compute = jax.pmap(compute_sum_of_squares)
    
    # Create data that can be split across devices
    key, subkey = jax.random.split(key)
    # Shape: (n_devices, batch_size_per_device, features)
    data = jax.random.normal(subkey, (len(devices), 100, 50))
    
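    # Sequential computation
    start = time.time()
    sequential_results = jnp.array([compute_sum_of_squares(data[i]) for i in range(len(devices))])
    sequential_results.block_until_ready()
    time_sequential = time.time() - start
    
    # Parallel computation (warm up first so compilation is not timed)
    _ = parallel_compute(data).block_until_ready()
    start = time.time()
    parallel_results = parallel_compute(data).block_until_ready()
    time_parallel = time.time() - start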
    
    print(f"Sequential time: {time_sequential:.4f} seconds")
    print(f"Parallel time: {time_parallel:.4f} seconds")
    print(f"Speedup: {time_sequential / time_parallel:.1f}x")
    print(f"Results match: {jnp.allclose(sequential_results, parallel_results)}")
    
else:
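    print("\nOnly a single device available - pmap would not provide a speedup")
    print("pmap helps with multiple GPUs/TPUs; on CPU, JAX exposes one device by default")
    print("(set XLA_FLAGS=--xla_force_host_platform_device_count=N before importing JAX to emulate more)")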
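
Devices inside a `pmap` can also communicate through collectives. The sketch below normalizes data by a global sum computed with `jax.lax.psum`; the axis name `"devices"` and the function `normalize` are illustrative choices. It also runs on a single device, where the mapped axis simply has length one.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# One block of 4 values per device; the leading axis is the mapped axis
x = jnp.arange(1.0, n_dev * 4 + 1).reshape(n_dev, 4)

def normalize(block):
    # psum adds the per-device partial sums across the named axis
    total = jax.lax.psum(jnp.sum(block), axis_name="devices")
    return block / total

out = jax.pmap(normalize, axis_name="devices")(x)

print("Devices used:", n_dev)
print("Global sum after normalization:", float(jnp.sum(out)))
```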

## 6. Custom Gradients {#custom-gradients}

Sometimes we need custom gradient rules for numerical stability or efficiency.

In [None]:
# Example: Custom gradient for numerical stability
@jax.custom_jvp
def safe_log(x):
    """Logarithm with custom gradient for numerical stability"""
    return jnp.log(jnp.maximum(x, 1e-10))

@safe_log.defjvp
def safe_log_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    # Custom gradient: the denominator is clamped so the derivative stays
    # bounded (at most 1e8) instead of diverging as x -> 0
    primal_out = safe_log(x)
    tangent_out = x_dot / jnp.maximum(x, 1e-8)  # clamped denominator
    return primal_out, tangent_out

# Test the custom gradient
x_values = jnp.array([1e-12, 1e-8, 1e-4, 1.0, 10.0])

print("Comparing regular log vs safe_log:")
for x in x_values:
    # Regular log and gradient
    try:
        regular_val = jnp.log(x)
        regular_grad = jax.grad(lambda y: jnp.log(y))(x)
    except:
        regular_val = float('nan')
        regular_grad = float('nan')
    
    # Safe log and gradient
    safe_val = safe_log(x)
    safe_grad = jax.grad(safe_log)(x)
    
    print(f"x = {x:.2e}:")
    print(f"  Regular: val = {regular_val:.6f}, grad = {regular_grad:.6f}")
    print(f"  Safe:    val = {safe_val:.6f}, grad = {safe_grad:.6f}")
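
A second common use of custom rules is a function whose value is well behaved but whose automatically derived gradient overflows. The sketch below uses the softplus function as an illustration (it is not part of NetKet): the naive form produces `0 * inf` in the chain rule for large inputs, while a `custom_jvp` rule supplies the sigmoid derivative directly.

```python
import jax
import jax.numpy as jnp

def naive_softplus(x):
    # exp(x) overflows for large x, which poisons the gradient
    return jnp.log1p(jnp.exp(x))

@jax.custom_jvp
def softplus(x):
    # logaddexp(0, x) = log(1 + exp(x)) evaluated without overflow
    return jnp.logaddexp(0.0, x)

@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # d/dx log(1 + exp(x)) = sigmoid(x), which is bounded in [0, 1]
    return softplus(x), jax.nn.sigmoid(x) * x_dot

x = 1000.0
naive_grad = jax.grad(naive_softplus)(x)   # nan: 0 * inf in the chain rule
stable_grad = jax.grad(softplus)(x)        # finite: sigmoid(1000) saturates at 1

print("Naive gradient is nan:", bool(jnp.isnan(naive_grad)))
print("Custom-rule gradient:", float(stable_grad))
```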

In [None]:
# Example: Custom gradient for computational efficiency
@jax.custom_vjp
def matrix_sqrt(A):
    """Matrix square root with custom backward pass"""
    # Forward pass: compute matrix square root via eigendecomposition
    eigenvals, eigenvecs = jnp.linalg.eigh(A)
    sqrt_eigenvals = jnp.sqrt(jnp.maximum(eigenvals, 0))
    return eigenvecs @ jnp.diag(sqrt_eigenvals) @ eigenvecs.T

def matrix_sqrt_fwd(A):
    # Forward pass and save intermediate values
    eigenvals, eigenvecs = jnp.linalg.eigh(A)
    sqrt_eigenvals = jnp.sqrt(jnp.maximum(eigenvals, 0))
    sqrt_A = eigenvecs @ jnp.diag(sqrt_eigenvals) @ eigenvecs.T
    return sqrt_A, (eigenvals, eigenvecs, sqrt_eigenvals)

def matrix_sqrt_bwd(res, g):
    # Custom backward pass using the stored intermediate values
    eigenvals, eigenvecs, sqrt_eigenvals = res
    # For symmetric positive-definite A with S = sqrt(A), the differential
    # satisfies the Sylvester equation S dS + dS S = dA. In the eigenbasis
    # of A this is solved by elementwise division by (s_i + s_j).
    g_proj = eigenvecs.T @ g @ eigenvecs
    sqrt_eigenvals_expanded = sqrt_eigenvals[:, None]
    denom = sqrt_eigenvals_expanded + sqrt_eigenvals_expanded.T
    denom = jnp.where(denom == 0, 1, denom)  # Avoid division by zero
    grad_eigenvals = g_proj / denom
    grad_A = eigenvecs @ grad_eigenvals @ eigenvecs.T
    return (grad_A,)

matrix_sqrt.defvjp(matrix_sqrt_fwd, matrix_sqrt_bwd)

# Test the custom matrix square root
key, subkey = jax.random.split(key)
A = jax.random.normal(subkey, (4, 4))
A = A @ A.T + 0.1 * jnp.eye(4)  # Make positive definite

sqrt_A = matrix_sqrt(A)
reconstructed = sqrt_A @ sqrt_A

print(f"Matrix square root test:")
print(f"Original matrix A shape: {A.shape}")
print(f"sqrt(A) @ sqrt(A) ≈ A: {jnp.allclose(reconstructed, A, atol=1e-6)}")
print(f"Max reconstruction error: {jnp.max(jnp.abs(reconstructed - A))}")

# Test gradient
def loss(A):
    sqrt_A = matrix_sqrt(A)
    return jnp.sum(sqrt_A**2)

grad_A = jax.grad(loss)(A)
print(f"Gradient computed successfully: {grad_A.shape}")

## 7. Memory and Performance Optimization {#optimization}

JAX provides several tools for optimizing memory usage and computational performance.

In [None]:
# Memory-efficient computation with gradient checkpointing

def memory_intensive_function(x):
    """Function that requires a lot of intermediate memory"""
    # Simulate memory-intensive computation
    for i in range(10):
        x = jnp.tanh(x @ x.T + x)
        x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
    return jnp.sum(x)

# Regular gradient computation
regular_grad = jax.grad(memory_intensive_function)

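# Gradient checkpointing version: wrap the function itself (not its gradient),
# so intermediates are recomputed during the backward pass instead of stored
checkpointed_grad = jax.grad(jax.checkpoint(memory_intensive_function))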

# Test with smaller matrix for demonstration
key, subkey = jax.random.split(key)
test_matrix = jax.random.normal(subkey, (20, 20))

print("Memory optimization with gradient checkpointing:")

# Regular gradient
start = time.time()
grad_regular = regular_grad(test_matrix)
time_regular = time.time() - start

# Checkpointed gradient
start = time.time()
grad_checkpointed = checkpointed_grad(test_matrix)
time_checkpointed = time.time() - start

print(f"Regular gradient time: {time_regular:.4f} seconds")
print(f"Checkpointed gradient time: {time_checkpointed:.4f} seconds")
print(f"Results match: {jnp.allclose(grad_regular, grad_checkpointed, atol=1e-6)}")
print(f"Note: Checkpointing trades computation for memory (useful for large models)")

In [None]:
# Static vs argument-dependent loop bounds
def dynamic_computation(x, n):
    """Computation whose loop count is passed as an argument"""
    result = x
    for i in range(n):  # range() needs a concrete Python int
        result = result + jnp.sin(result)
    return result

def static_computation(x):
    """Equivalent computation with a hard-coded loop count"""
    result = x
    for i in range(5):  # Unrolled at trace time
        result = result + jnp.sin(result)
    return result

# With a hard-coded loop count, jit unrolls the loop while tracing
static_jit = jax.jit(static_computation)

# Marking n as static makes it a compile-time constant, so range(n) works;
# each new value of n triggers a recompilation
dynamic_jit = jax.jit(dynamic_computation, static_argnums=(1,))

# Test data
x = jnp.array([1.0, 2.0, 3.0])
n = 5

print("Hard-coded vs static-argument loop bounds:")

# Warm up so that compilation is not included in the timings
_ = static_jit(x).block_until_ready()
_ = dynamic_jit(x, n).block_until_ready()

# Hard-coded loop count
start = time.time()
for _ in range(100):
    result_static = static_jit(x).block_until_ready()
time_static = time.time() - start

# Loop count as a static argument
start = time.time()
for _ in range(100):
    result_dynamic = dynamic_jit(x, n).block_until_ready()
time_dynamic = time.time() - start

print(f"Hard-coded loop count: {time_static:.4f} seconds")
print(f"Static-argument loop count: {time_dynamic:.4f} seconds")
print(f"Results match: {jnp.allclose(result_static, result_dynamic)}")
print(f"Time ratio (static argument / hard-coded): {time_dynamic / time_static:.1f}x")
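
When the loop count changes often, recompiling for every value can be avoided by keeping the count as a traced value and using `jax.lax.fori_loop`. The following is a minimal sketch; `iterate` and `trace_log` are illustrative names.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time the function body is traced

@jax.jit
def iterate(x, n):
    trace_log.append("traced")
    # n stays a traced value: the loop is compiled once, not unrolled
    return jax.lax.fori_loop(0, n, lambda i, r: r + jnp.sin(r), x)

def iterate_reference(x, n):
    # Plain Python loop used only to check the result
    for _ in range(n):
        x = x + jnp.sin(x)
    return x

x = jnp.array([1.0, 2.0, 3.0])
out3 = iterate(x, 3)
out5 = iterate(x, 5)  # different n, same compiled function

print("Number of traces:", len(trace_log))
print("Matches Python loop:", bool(jnp.allclose(out5, iterate_reference(x, 5))))
```

The trade-off is that a loop with a traced bound cannot be differentiated in reverse mode, so `static_argnums` or `jax.lax.scan` with a fixed length remains the better fit when gradients are needed.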

## 8. NetKet-specific JAX Usage {#netket-jax}

Let's see how NetKet leverages JAX features for quantum many-body calculations.

In [None]:
# Setup a simple quantum system
N = 6
hilbert = nk.hilbert.Spin(s=1/2, N=N)
lattice = nk.graph.Chain(length=N, pbc=True)
hamiltonian = nk.operator.Ising(hilbert=hilbert, graph=lattice, J=1.0, h=0.5)

# Create RBM model
# Recent NetKet releases name this argument `param_dtype` (older ones used `dtype`)
model = nk.models.RBM(alpha=2, param_dtype=jnp.complex128)

print(f"Quantum system: {N} spins, Hilbert space dimension = {hilbert.n_states}")

In [None]:
# Demonstrate JAX transformations in NetKet context

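# Initialize model parameters (separate keys for sampling and initialization)
key, key_sample, key_init = jax.random.split(key, 3)
sample_input = hilbert.random_state(key_sample, size=1)
params = model.init(key_init, sample_input)

print("Model parameters:")
# The parameter tree is nested (the Dense layer holds its own kernel and bias),
# so walk over its leaves rather than over the top-level dictionary
for path, leaf in jax.tree_util.tree_leaves_with_path(params):
    print(f"  {jax.tree_util.keystr(path)}: {leaf.shape}")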

# Function to compute log-amplitude
def log_psi(params, sigma):
    return model.apply(params, sigma)

# Vectorized version for batch evaluation
log_psi_batched = jax.vmap(log_psi, in_axes=(None, 0))

# Generate batch of configurations
key, subkey = jax.random.split(key)
batch_configs = hilbert.random_state(subkey, size=100)

# Evaluate log-amplitudes
log_amplitudes = log_psi_batched(params, batch_configs)
print(f"\nBatch evaluation:")
print(f"  Input shape: {batch_configs.shape}")
print(f"  Output shape: {log_amplitudes.shape}")
print(f"  Mean log|ψ|: {jnp.mean(jnp.real(log_amplitudes)):.4f}")
print(f"  Std log|ψ|: {jnp.std(jnp.real(log_amplitudes)):.4f}")

In [None]:
# Local energy computation with JAX
def local_energy(params, sigma, configs, matrix_elements):
    """Compute local energy E_loc = <σ|H|ψ>/<σ|ψ> for one configuration"""
    # This is a simplified version - NetKet handles this more efficiently
    log_psi_sigma = log_psi(params, sigma)

    # Compute log-amplitudes for the configurations connected to sigma by H
    log_psi_connected = log_psi_batched(params, configs)

    # Compute local energy
    ratios = jnp.exp(log_psi_connected - log_psi_sigma)
    return jnp.sum(matrix_elements * ratios)

# Vectorized local energy: batch over samples and their connections
local_energy_batched = jax.vmap(local_energy, in_axes=(None, 0, 0, 0))

# Look up connected configurations and matrix elements once, outside of any
# JAX transformation, since the operator lookup itself may not be traceable.
# get_conn_padded pads every sample to the same number of connections.
conn_configs, conn_mels = hamiltonian.get_conn_padded(np.asarray(batch_configs))
conn_configs, conn_mels = jnp.asarray(conn_configs), jnp.asarray(conn_mels)

# Compute local energies for batch
local_energies = local_energy_batched(params, batch_configs, conn_configs, conn_mels)

# batch_configs are drawn uniformly rather than from |ψ|², so these statistics
# illustrate the mechanics and are not a variational energy estimate
print(f"Local energy computation:")
print(f"  Mean E_loc: {jnp.mean(local_energies).real:.6f}")
print(f"  Std E_loc: {jnp.std(local_energies):.6f}")
print(f"  Standard error of the mean: {jnp.std(local_energies)/jnp.sqrt(len(local_energies)):.6f}")

In [None]:
# Gradient computation for variational optimization
def energy_function(params, configs, conn_configs, conn_mels):
    """Mean local energy over a fixed batch (real part, as jax.grad needs a real output)"""
    local_energies = local_energy_batched(params, configs, conn_configs, conn_mels)
    return jnp.mean(local_energies).real

# Differentiating a sample mean over fixed configurations demonstrates the JAX
# mechanics only; VMC normally uses a covariance-based estimator with samples
# drawn from |ψ|², which NetKet's variational states handle internally
energy_grad = jax.grad(energy_function)
energy_args = (params, batch_configs, conn_configs, conn_mels)

# Evaluate gradient
gradient = energy_grad(*energy_args)

print(f"Gradient computation:")
print(f"  Parameter gradients computed successfully")
for path, grad in jax.tree_util.tree_leaves_with_path(gradient):
    grad_norm = jnp.linalg.norm(grad.ravel())
    print(f"  {jax.tree_util.keystr(path)}: gradient norm = {grad_norm:.6f}")

# JIT-compiled version for performance
energy_and_grad_jit = jax.jit(jax.value_and_grad(energy_function))

# Benchmark
n_runs = 10

# Regular computation (block_until_ready waits for asynchronous dispatch)
start = time.time()
for _ in range(n_runs):
    energy_val = energy_function(*energy_args)
    grad_val = jax.block_until_ready(energy_grad(*energy_args))
time_regular = time.time() - start

# JIT-compiled computation (warm up first)
_ = jax.block_until_ready(energy_and_grad_jit(*energy_args))
start = time.time()
for _ in range(n_runs):
    energy_val_jit, grad_val_jit = jax.block_until_ready(energy_and_grad_jit(*energy_args))
time_jit = time.time() - start

print(f"\nPerformance comparison ({n_runs} runs):")
print(f"  Regular: {time_regular:.4f} seconds")
print(f"  JIT: {time_jit:.4f} seconds")
print(f"  Speedup: {time_regular / time_jit:.1f}x")
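
In variational Monte Carlo the energy gradient is usually estimated from the covariance between local energies and log-derivatives. For a real wave function this reads $\partial_k E = 2\,\langle (E_{\mathrm{loc}} - \langle E \rangle)\, O_k \rangle$ with $O_k = \partial_k \log\psi$. The sketch below checks this identity on a toy system that is independent of NetKet: a 4-spin transverse-field Ising chain built as a dense matrix and a small real-valued RBM-like `log_psi`. All names, parameter values, and the sign convention of `H` are assumptions made for illustration, and the expectation values are computed exactly over all 16 configurations rather than by sampling.

```python
import itertools
import jax
import jax.numpy as jnp
import numpy as np

N, J, h = 4, 1.0, 0.5  # toy chain; convention H = -h Σ σx + J Σ σz σz (assumed)
sx = np.array([[0.0, 1.0], [1.0, 0.0]])
sz = np.array([[1.0, 0.0], [0.0, -1.0]])

def site_op(op, i):
    """Embed a single-site operator at site i of the N-site chain"""
    out = np.eye(1)
    for k in range(N):
        out = np.kron(out, op if k == i else np.eye(2))
    return out

H = sum(-h * site_op(sx, i) for i in range(N))
H = H + sum(J * site_op(sz, i) @ site_op(sz, (i + 1) % N) for i in range(N))
H = jnp.asarray(H)

# All 2^N configurations, ordered consistently with the Kronecker products
states = jnp.asarray(list(itertools.product([1.0, -1.0], repeat=N)))

def log_psi(params, sigma):
    """Real-valued RBM-like log-amplitude (illustrative ansatz)"""
    theta = params["b"] + params["W"] @ sigma
    return params["a"] @ sigma + jnp.sum(jnp.log(jnp.cosh(theta)))

def amplitudes(params):
    return jnp.exp(jax.vmap(log_psi, in_axes=(None, 0))(params, states))

def exact_energy(params):
    psi = amplitudes(params)
    return psi @ H @ psi / (psi @ psi)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "a": 0.1 * jax.random.normal(k1, (N,)),
    "b": 0.1 * jax.random.normal(k2, (2 * N,)),
    "W": 0.1 * jax.random.normal(k3, (2 * N, N)),
}

# Covariance form, with exact probabilities p(σ) = ψ(σ)² / Σ ψ²
psi = amplitudes(params)
p = psi**2 / jnp.sum(psi**2)
e_loc = (H @ psi) / psi
energy = jnp.sum(p * e_loc)
O = jax.vmap(jax.grad(log_psi), in_axes=(None, 0))(params, states)
weights = 2.0 * p * (e_loc - energy)
grad_cov = jax.tree_util.tree_map(lambda o: jnp.tensordot(weights, o, axes=1), O)

# Reference: differentiate the exact energy directly
grad_ad = jax.grad(exact_energy)(params)

for name in params:
    print(name, bool(jnp.allclose(grad_cov[name], grad_ad[name], rtol=1e-3, atol=1e-4)))
```

In an actual VMC run the exact sum over configurations is replaced by an average over samples drawn from $|\psi|^2$, which is why the gradient of a sample mean over fixed configurations is not a substitute for this estimator.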

## 9. Advanced Topics {#advanced}

Advanced JAX features useful for quantum many-body calculations.

In [None]:
# Scan for iterative computations
def make_rk4_step(force_fn, dt):
    """Build a single RK4 step in the (carry, x) -> (carry, y) form used by scan"""
    def rk4_step(state, _):
        y, t = state

        k1 = dt * force_fn(y, t)
        k2 = dt * force_fn(y + k1/2, t + dt/2)
        k3 = dt * force_fn(y + k2/2, t + dt/2)
        k4 = dt * force_fn(y + k3, t + dt)

        y_new = y + (k1 + 2*k2 + 2*k3 + k4) / 6
        t_new = t + dt

        return (y_new, t_new), y_new
    return rk4_step

def simulate_system(initial_state, times, force_fn):
    """Simulate a dynamical system using RK4 on a uniform time grid"""
    dt = times[1] - times[0]

    # force_fn and dt are captured by closure: scan inputs must be arrays,
    # so a Python function cannot be passed through the scanned inputs
    final_state, trajectory = jax.lax.scan(
        make_rk4_step(force_fn, dt),
        (initial_state, times[0]),
        None,
        length=len(times) - 1,
    )

    return jnp.concatenate([initial_state[None], trajectory])

# Example: harmonic oscillator
def harmonic_force(state, t):
    """Force for harmonic oscillator: d²x/dt² = -ω²x"""
    x, v = state
    omega = 2.0
    return jnp.array([v, -omega**2 * x])

# Simulate
initial_state = jnp.array([1.0, 0.0])  # x=1, v=0
times = jnp.linspace(0, 2*jnp.pi, 100)

trajectory = simulate_system(initial_state, times, harmonic_force)

print(f"Time evolution simulation:")
print(f"  Time points: {len(times)}")
print(f"  Trajectory shape: {trajectory.shape}")
print(f"  Energy conservation (should be ~constant):")
energy = 0.5 * (trajectory[:, 1]**2 + 4 * trajectory[:, 0]**2)
print(f"    Initial energy: {energy[0]:.6f}")
print(f"    Final energy: {energy[-1]:.6f}")
print(f"    Energy drift: {abs(energy[-1] - energy[0])/energy[0]*100:.2e}%")

# Plot trajectory
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(times, trajectory[:, 0], label='Position')
plt.plot(times, trajectory[:, 1], label='Velocity')
plt.xlabel('Time')
plt.ylabel('Value')
plt.title('Harmonic Oscillator Trajectory')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(trajectory[:, 0], trajectory[:, 1])
plt.xlabel('Position')
plt.ylabel('Velocity')
plt.title('Phase Space')
plt.grid(True, alpha=0.3)
plt.axis('equal')

plt.tight_layout()
plt.show()

In [None]:
# Conditional computations with jax.lax.cond
def adaptive_step_size(current_error, target_error, current_dt):
    """Adaptive step size control"""
    safety_factor = 0.9
    max_factor = 2.0
    min_factor = 0.5
    
    def increase_step(args):
        current_dt, factor = args
        return jnp.minimum(current_dt * factor, current_dt * max_factor)
    
    def decrease_step(args):
        current_dt, factor = args
        return jnp.maximum(current_dt * factor, current_dt * min_factor)
    
    error_ratio = current_error / target_error
    factor = safety_factor * (1.0 / error_ratio)**0.2
    
    new_dt = jax.lax.cond(
        error_ratio < 1.0,
        increase_step,
        decrease_step,
        (current_dt, factor)
    )
    
    return new_dt

# Test adaptive step size
errors = jnp.array([0.001, 0.01, 0.1, 1.0, 10.0])
target_error = 0.01
current_dt = 0.01

print(f"Adaptive step size control:")
print(f"Target error: {target_error}")
print(f"Current dt: {current_dt}")

for error in errors:
    new_dt = adaptive_step_size(error, target_error, current_dt)
    print(f"  Error {error:.3f} -> dt {new_dt:.6f} (factor: {new_dt/current_dt:.2f})")
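
A function built on `jax.lax.cond` can also be batched with `vmap` instead of being called in a Python loop. The sketch below uses a compact step-size rule with the same safety factor and bounds as above; `new_step_size` is an illustrative name. When the predicate is batched, JAX evaluates both branches and selects per element.

```python
import jax
import jax.numpy as jnp

def new_step_size(error, target_error, dt):
    """Step-size update with safety factor 0.9 and growth/shrink bounds 2.0 / 0.5"""
    factor = 0.9 * (target_error / error) ** 0.2
    return jax.lax.cond(
        error < target_error,
        lambda f: dt * jnp.minimum(f, 2.0),  # error below target: grow, capped
        lambda f: dt * jnp.maximum(f, 0.5),  # error above target: shrink, floored
        factor,
    )

# Map over the error estimates only; target and current dt are shared
batched_step_size = jax.vmap(new_step_size, in_axes=(0, None, None))

errors = jnp.array([0.001, 0.1, 10.0])
new_dts = batched_step_size(errors, 0.01, 0.01)

print("New step sizes:", new_dts)
```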

## 10. Conclusion {#conclusion}

This notebook has demonstrated the key JAX features that make NetKet powerful and efficient:

### JAX Features Covered:

1. **Automatic Differentiation**:
   - Essential for variational optimization
   - Forward and reverse mode AD
   - Higher-order derivatives

2. **JIT Compilation**:
   - Dramatic performance improvements
   - XLA optimization
   - Static vs dynamic shapes

3. **Vectorization with vmap**:
   - Automatic batching
   - Eliminates explicit loops
   - Better hardware utilization

4. **Parallelization with pmap**:
   - Multi-device computation
   - Scales to large systems
   - Efficient for Monte Carlo sampling

5. **Advanced Features**:
   - Custom gradients for stability
   - Memory optimization
   - Control flow primitives

### NetKet Benefits:

- **Performance**: JIT compilation and vectorization
- **Scalability**: Parallelization across devices with pmap
- **Flexibility**: Custom neural architectures
- **Numerical Stability**: Custom gradient rules
- **Memory Efficiency**: Gradient checkpointing

### Best Practices:

1. **Use JIT compilation** for computational hot spots
2. **Vectorize operations** with vmap instead of loops
3. **Leverage static shapes** when possible
4. **Use custom gradients** for numerical stability
5. **Profile and optimize** memory usage for large systems

### Impact on Quantum Many-Body Physics:

JAX's features enable:
- Efficient neural quantum state optimization
- Large-scale Monte Carlo simulations
- Real-time quantum dynamics
- Custom quantum algorithms
- Scalable quantum machine learning

The combination of JAX's computational power and NetKet's quantum-specific abstractions makes it possible to tackle previously intractable quantum many-body problems with neural network ansatzes.