# JAX Optimization Tutorial

This notebook demonstrates gradient-based optimization through BIMFx's JAX-native
method of fundamental solutions (MFS) solver, whose linear solves are differentiated
implicitly. We recover a target field by adjusting the two parameters (major and
minor radius) of a torus boundary.

In [None]:
import os
os.environ.setdefault('JAX_ENABLE_X64', '1')

import jax
import jax.numpy as jnp
import numpy as np

from bimfx import solve_mfs_jax
from bimfx.objectives import boundary_residual_objective


## Parameterized torus boundary
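We parameterize a torus by its major radius `R` and minor radius `r`, and sample a
boundary point cloud together with the outward unit normal at each point:
$\mathbf{x}(\phi,\theta) = \big((R + r\cos\theta)\cos\phi,\ (R + r\cos\theta)\sin\phi,\ r\sin\theta\big)$,
$\mathbf{n}(\phi,\theta) = (\cos\theta\cos\phi,\ \cos\theta\sin\phi,\ \sin\theta)$.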

In [None]:
def torus_points_normals(R, r, nphi=6, ntheta=6):
    """Sample points and outward unit normals on a torus surface.

    The torus is centred at the origin with the z axis as its symmetry axis.
    Points are taken on a uniform (phi, theta) grid that excludes the endpoint
    2*pi, so no sample is duplicated.

    Args:
        R: Major radius (distance from the z axis to the tube centreline).
        r: Minor radius (tube radius).
        nphi: Number of samples in the toroidal angle phi.
        ntheta: Number of samples in the poloidal angle theta.

    Returns:
        (P, N): two arrays of shape (nphi * ntheta, 3). P holds the surface
        points and N the unit normals pointing away from the tube centreline.
        Rows are ordered with phi as the slow index:
        row = i_phi * ntheta + i_theta.
    """
    phi_axis = jnp.linspace(0.0, 2.0 * jnp.pi, nphi, endpoint=False)
    theta_axis = jnp.linspace(0.0, 2.0 * jnp.pi, ntheta, endpoint=False)
    phi, theta = jnp.meshgrid(phi_axis, theta_axis, indexing='ij')

    cos_phi, sin_phi = jnp.cos(phi), jnp.sin(phi)
    cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)

    # Distance of each sample from the z axis.
    ring = R + r * cos_theta
    points = jnp.stack(
        [ring * cos_phi, ring * sin_phi, r * sin_theta], axis=-1
    ).reshape(-1, 3)

    normals = jnp.stack(
        [cos_theta * cos_phi, cos_theta * sin_phi, sin_theta], axis=-1
    ).reshape(-1, 3)
    # These normals are already unit length analytically. Renormalising guards
    # against rounding, and the 1e-30 floor avoids dividing by zero.
    lengths = jnp.linalg.norm(normals, axis=1, keepdims=True)
    normals = normals / jnp.maximum(1e-30, lengths)
    return points, normals


## Target field from a reference geometry
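We generate a target field from a reference torus ($R = 3$, $r = 1$) and then fit a
perturbed torus to match it at probe points. The probes are placed a small distance
(0.05) inside the reference boundary, along the inward normal.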

In [None]:
# Reference geometry: the "true" torus whose field the optimization tries to recover.
R_ref = 3.0  # major radius
r_ref = 1.0  # minor radius
P_ref, N_ref = torus_points_normals(R_ref, r_ref)
field_ref = solve_mfs_jax(P_ref, N_ref, harmonic_coeffs=(1.0, 0.0))

# Probes are offset 0.05 inward from each reference boundary point (N_ref points outward).
# NOTE(review): at the initial guess used in the optimization cell (R=2.6, r=0.7),
# probes on the outer equator (radius 3.95) lie outside the candidate torus (outer
# radius 3.3), so the candidate field is evaluated outside its boundary there.
# Confirm this is intended.
probe = P_ref - 0.05 * N_ref

# Frozen target: treated as fixed data, never differentiated through.
B_target = jax.lax.stop_gradient(field_ref.B(probe))


## Objective and optimization
We minimize the mean squared mismatch between the candidate torus's field and the
target field at the probe points. The linear solve is differentiated implicitly, so
gradients propagate through the MFS solver to the geometry parameters `(R, r)`.

In [None]:
def objective(params):
    """Mean squared field mismatch for a candidate torus.

    Args:
        params: Length-2 array ``(R, r)`` holding the candidate's major and
            minor radius.

    Returns:
        Scalar mean of the squared difference between the candidate field at the
        global ``probe`` points and the global ``B_target``.
    """
    major_radius, minor_radius = params
    boundary_pts, boundary_nrm = torus_points_normals(major_radius, minor_radius)
    candidate = solve_mfs_jax(boundary_pts, boundary_nrm, harmonic_coeffs=(1.0, 0.0))
    residual = candidate.B(probe) - B_target
    return jnp.mean(residual ** 2)

# Compiled helpers. `objective_jit` and `grad_jit` are kept so any later cell that
# uses them still works. The loop uses `value_and_grad_jit`, which returns the loss
# and its gradient from one forward/backward pass. Each step therefore runs the MFS
# solve once instead of twice.
objective_jit = jax.jit(objective)
grad_jit = jax.jit(jax.grad(objective))
value_and_grad_jit = jax.jit(jax.value_and_grad(objective))

params = jnp.array([2.6, 0.7])  # initial guess (R, r)
lr = 0.5       # fixed step size for plain gradient descent
n_steps = 15   # number of gradient-descent iterations
history = []   # loss at the params *before* each update
for _ in range(n_steps):
    loss, grads = value_and_grad_jit(params)
    loss = float(loss)
    if not np.isfinite(loss):
        # Stop rather than propagate NaN/inf into further parameter updates.
        print('non-finite loss encountered; stopping early at params', np.array(params))
        break
    history.append(loss)
    params = params - lr * grads

print('history:', history)
print('final params:', np.array(params))
print('final loss:', float(objective_jit(params)))
