# Interactive Workbook: Time-Splitting Spectral Method for Scalar BEC (JAX)

Welcome. This notebook is designed to build **deep intuition + implementation skill** for the split-step Fourier method applied to the Gross–Pitaevskii equation (GPE).

You will find:
- **Conceptual checkpoints** (explain in your own words)
- **Fill-in code blocks** (complete TODOs)
- **Try-it-yourself experiments** (change parameters and reason about outcomes)
- **Caveat/performance labs** (where methods fail or slow down)

> Target equation (dimensionless):

$$
i\partial_t\psi = \left[-\frac{1}{2}\nabla^2 + V(x,y) + g|\psi|^2\right]\psi
$$


## 0) Learning objectives

By the end, you should be able to:
1. Derive and implement Strang splitting for GPE.
2. Explain why FFT diagonalizes the kinetic operator.
3. Quantify accuracy via temporal/grid convergence tests.
4. Diagnose stability/aliasing/normalization issues.
5. Reason about performance tradeoffs (JIT, grid size, backend).


In [None]:
# If needed, run once:
# !pip install jax jaxlib numpy matplotlib pandas scipy

import time
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scalar_bec.solver import SolverConfig, run_simulation, make_grid, harmonic_potential, kspace_operators, normalize
from scalar_bec.diagnostics import norm, energy, l2_error

print('JAX backend:', jax.default_backend())
print('Devices:', jax.devices())


## 1) Concept checkpoint: operator splitting

We separate the Hamiltonian into:
- $A = V + g|\psi|^2$ (local in real space)
- $B = -\frac{1}{2}\nabla^2$ (diagonal in Fourier space)

Strang step:
$$
\psi^{n+1} \approx e^{-i\frac{\Delta t}{2}A} e^{-i\Delta t B} e^{-i\frac{\Delta t}{2}A} \psi^n
$$

### Questions
1. Why is this second-order accurate in time?
2. Why is the kinetic step cheap with FFT?
3. What breaks if $\Delta t$ is too large?

_Write your answers in a markdown cell below._


### ✍️ Your answers (insert markdown cell below this one)


### Solutions (Concept checkpoint)

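1. **Why second-order?** Strang splitting is symmetric: half-step/full-step/half-step. In the Baker–Campbell–Hausdorff expansion this symmetry cancels the $O(\Delta t^2)$ term proportional to $[A, B]$, so the first surviving terms are double commutators at $O(\Delta t^3)$. That gives local truncation error $O(\Delta t^3)$ and global error $O(\Delta t^2)$. Each nonlinear half-step is integrated exactly, because multiplying by a pure phase leaves $|\psi|^2$ unchanged during that substep.

2. **Why does the FFT make the kinetic step cheap?** In Fourier space, $-\nabla^2$ becomes multiplication by $k^2 = k_x^2 + k_y^2$. So the kinetic propagator is pointwise multiplication:
$$
\hat{\psi}(k) \leftarrow e^{-i\frac{\Delta t}{2}k^2}\hat{\psi}(k),
$$
with cost dominated by the FFTs: $O(N\log N)$ for $N = n_x n_y$ grid points.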

3. **What breaks for large $\Delta t$?** Splitting error grows, phase errors accumulate, invariants drift more, and dynamics can become qualitatively wrong (even if numerically stable).
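To see the order claim in numbers, here is a minimal sketch that compares one step of Lie splitting (`A` then `B`) against one Strang step. It is a stand-in, not the GPE itself: `A` and `B` are small random Hermitian matrices, and `expm_herm`, `random_hermitian`, and `one_step_errors` are hypothetical helper names introduced only for this illustration. Halving $\Delta t$ should shrink the one-step error by roughly 4 for Lie (local error $O(\Delta t^2)$) and roughly 8 for Strang (local error $O(\Delta t^3)$).

```python
import numpy as np

def expm_herm(H, t):
    """Return exp(-i t H) for a Hermitian matrix H via its eigendecomposition."""
    w, U = np.linalg.eigh(H)
    return (U * np.exp(-1j * t * w)) @ U.conj().T

def random_hermitian(n, rng):
    """Random Hermitian matrix, used only as a non-commuting test operator."""
    M = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    return 0.5 * (M + M.conj().T)

rng = np.random.default_rng(0)
A = random_hermitian(4, rng)
B = random_hermitian(4, rng)

def one_step_errors(dt):
    """One-step error of Lie and Strang splitting against the exact propagator."""
    exact = expm_herm(A + B, dt)
    lie = expm_herm(A, dt) @ expm_herm(B, dt)
    strang = expm_herm(A, dt / 2) @ expm_herm(B, dt) @ expm_herm(A, dt / 2)
    return np.linalg.norm(lie - exact), np.linalg.norm(strang - exact)

lie_coarse, strang_coarse = one_step_errors(0.01)
lie_fine, strang_fine = one_step_errors(0.005)

lie_ratio = lie_coarse / lie_fine            # expected near 2**2
strang_ratio = strang_coarse / strang_fine   # expected near 2**3
print(f"Lie ratio    = {lie_ratio:.2f}")
print(f"Strang ratio = {strang_ratio:.2f}")
```

These are local (one-step) errors. Over a fixed final time the number of steps grows like $1/\Delta t$, which is why the global orders are one lower.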


## 2) Fill-in code: build grid, potential, initial condition

The cell below is shown in completed form. Cover it and try to write each line yourself before reading it.


In [None]:
nx, ny = 128, 128
lx, ly = 20.0, 20.0
g = 100.0
dt = 5e-4
steps = 200

cfg = SolverConfig(nx=nx, ny=ny, lx=lx, ly=ly, g=g, dt=dt, steps=steps)

x, y, X, Y, dx, dy = make_grid(cfg)

V = harmonic_potential(X, Y, omega=1.0)

psi0 = jnp.exp(-(X**2 + Y**2)/2.0).astype(jnp.complex64)
psi0 = normalize(psi0, dx, dy)

print('dx, dy =', float(dx), float(dy))
print('Initial norm =', float(norm(psi0, dx, dy)))


## 3) Fill-in code: one Strang step manually

Implement a manual one-step update to see each operation explicitly.


In [None]:
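# Manual one-step implementation exercise
# Assumption: kspace_operators(cfg) returns the full-step kinetic phase
# exp(-i * dt * k^2 / 2) laid out on the FFT grid.
kinetic_phase = kspace_operators(cfg)
half = 0.5 * dt

# Half step of A = V + g|psi|^2: a pointwise phase in real space
phase1 = jnp.exp(-1j * half * (V + g * jnp.abs(psi0)**2))
psi_half = phase1 * psi0

# Full step of B = -(1/2) * Laplacian: a pointwise phase in Fourier space
psi_k = jnp.fft.fft2(psi_half)
psi_k = kinetic_phase * psi_k
psi_full = jnp.fft.ifft2(psi_k)

# Second half step of A, using the density after the kinetic step
phase2 = jnp.exp(-1j * half * (V + g * jnp.abs(psi_full)**2))
psi1 = phase2 * psi_full

print('Norm after one step =', float(norm(psi1, dx, dy)))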


### Concept check
- Is the norm exactly conserved numerically? Why or why not?
- Which part introduces the most floating-point error?


### Solutions

- **Norm conservation:** In exact arithmetic, each substep is unitary, so norm is conserved. In floating-point arithmetic, tiny drift appears from FFT roundoff and finite precision.
- **Largest error source:** Typically FFT/ifft finite-precision roundoff plus repeated phase multiplications over many steps.
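A minimal sketch of that drift, assuming a 1D toy problem written in plain NumPy rather than the `scalar_bec` solver. The function `norm_drift` and all of its parameters are hypothetical choices for illustration. The state is rounded to the requested dtype after every substep, and the final norm is measured in double precision so that only the state carries rounding error.

```python
import numpy as np

def norm_drift(dtype, n=256, length=20.0, g=1.0, dt=1e-3, steps=200):
    """Return |norm - 1| after `steps` Strang steps of a 1D GPE toy problem."""
    dx = length / n
    x = -length / 2 + dx * np.arange(n)          # endpoint-exclusive periodic grid
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=dx)    # angular wavenumbers
    V = 0.5 * x**2                               # harmonic trap
    kinetic_phase = np.exp(-0.5j * dt * k**2).astype(dtype)

    psi = np.exp(-x**2 / 2.0)
    psi = (psi / np.sqrt(np.sum(np.abs(psi)**2) * dx)).astype(dtype)

    def half_potential_step(psi):
        # Pointwise phase from A = V + g|psi|^2 over dt/2
        phase = np.exp(-0.5j * dt * (V + g * np.abs(psi)**2))
        return (phase * psi).astype(dtype)

    for _ in range(steps):
        psi = half_potential_step(psi)
        psi = np.fft.ifft(kinetic_phase * np.fft.fft(psi)).astype(dtype)
        psi = half_potential_step(psi)

    final_norm = np.sum(np.abs(psi.astype(np.complex128))**2) * dx
    return abs(final_norm - 1.0)

drift_single = norm_drift(np.complex64)
drift_double = norm_drift(np.complex128)
print(f"complex64  drift = {drift_single:.3e}")
print(f"complex128 drift = {drift_double:.3e}")
```

Both drifts should be small, with the single-precision drift larger by many orders of magnitude. Neither indicates a bug in the splitting itself.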


## 4) Run full simulation + diagnostics


In [None]:
cfg = SolverConfig(nx=256, ny=256, lx=20.0, ly=20.0, g=100.0, dt=5e-4, steps=500)

start = time.perf_counter()
out = run_simulation(cfg)
out['psi'].block_until_ready()
elapsed = time.perf_counter() - start

psi = out['psi']
V = out['V']

norm_final = float(norm(psi, out['dx'], out['dy']))
E = float(energy(psi, V, cfg.g, out['dx'], out['dy']))

print(f'backend={out["backend"]}')
print(f'elapsed_s={elapsed:.4f}')
print(f'norm={norm_final:.8f}')
print(f'energy={E:.8f}')


In [None]:
# Visualize density
rho = np.array(jnp.abs(psi)**2)

plt.figure(figsize=(5,4))
extent = [float(out['x'][0]), float(out['x'][-1]), float(out['y'][0]), float(out['y'][-1])]
plt.imshow(rho.T, origin='lower', cmap='magma', extent=extent)
plt.colorbar(label='|psi|^2')
plt.title('Final Density')
plt.xlabel('x')
plt.ylabel('y')
plt.tight_layout()
plt.show()


## 5) Try-it-yourself block: explore nonlinearity strength $g$

Predict first, then run:
1. What happens to density shape as $g$ increases?
2. How does runtime change (if at all)?


In [None]:
g_values = [0.0, 10.0, 100.0, 300.0]
results = []

for gtest in g_values:
    cfg = SolverConfig(nx=256, ny=256, lx=20.0, ly=20.0, g=gtest, dt=5e-4, steps=300)
    t0 = time.perf_counter()
    out = run_simulation(cfg)
    out['psi'].block_until_ready()
    t1 = time.perf_counter()
    results.append((gtest, t1-t0, float(norm(out['psi'], out['dx'], out['dy']))))

for row in results:
    print(f'g={row[0]:6.1f}  elapsed={row[1]:.4f}s  norm={row[2]:.8f}')


## 6) Convergence study: error vs grid size

We compare each coarse-grid solution to a high-resolution reference, subsampled onto the coarse grid.
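Strided subsampling only compares like with like when the coarse grid points are a subset of the reference grid points. The sketch below checks this for two grid conventions. It is an assumption, not something this notebook confirms, that `make_grid` uses the endpoint-exclusive periodic layout. `periodic_grid` is a hypothetical helper written for this check.

```python
import numpy as np

def periodic_grid(n, length):
    """Endpoint-exclusive grid on [-length/2, length/2), natural for FFT methods."""
    return -length / 2 + (length / n) * np.arange(n)

fine = periodic_grid(512, 20.0)
coarse = periodic_grid(128, 20.0)
stride = fine.size // coarse.size

# Endpoint-exclusive grids nest: every 4th fine point is a coarse point
nested = np.allclose(fine[::stride], coarse)

# Endpoint-inclusive grids (np.linspace default) do not nest under striding
fine_inc = np.linspace(-10.0, 10.0, 512)
coarse_inc = np.linspace(-10.0, 10.0, 128)
max_offset = np.max(np.abs(fine_inc[::stride] - coarse_inc))

print("endpoint-exclusive grids nest:", nested)
print("endpoint-inclusive max offset:", max_offset)
```

If your grid does not nest, interpolate the reference (for example by Fourier zero-padding the coarse solution) instead of striding, otherwise the measured error includes a grid-offset term.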


In [None]:
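def resample_ref(ref, n):
    # Strided subsampling; only valid when n divides the reference grid size
    assert ref.shape[0] % n == 0 and ref.shape[1] % n == 0
    sx, sy = ref.shape[0] // n, ref.shape[1] // n
    return ref[::sx, ::sy]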

ref_cfg = SolverConfig(nx=512, ny=512, steps=500, dt=2.5e-4)
ref_out = run_simulation(ref_cfg)
ref_psi = ref_out['psi']

grid_list = [64, 128, 256]
errs = []

for n in grid_list:
    cfg = SolverConfig(nx=n, ny=n, steps=500, dt=2.5e-4)
    out = run_simulation(cfg)
    e = float(l2_error(out['psi'], resample_ref(ref_psi, n), out['dx'], out['dy']))
    errs.append((n, e))
    print(f'n={n:4d}, L2 error={e:.6e}')


In [None]:
# Plot grid convergence
nvals = np.array([e[0] for e in errs])
errvals = np.array([e[1] for e in errs])

plt.figure(figsize=(5,4))
plt.loglog(nvals, errvals, 'o-')
plt.title('Grid Convergence')
plt.xlabel('N')
plt.ylabel('L2 error vs reference')
plt.grid(True, which='both', alpha=0.3)
plt.tight_layout()
plt.show()


### Concept question
If the method is spectral in space, why might you *not* observe ideal exponential convergence in this test?

Hints:
- reference solution quality
- subsampling strategy
- finite precision
- non-smooth features


### Solutions

You may miss ideal exponential convergence because:
- the "reference" is not exact,
- subsampling can introduce mismatch/aliasing,
- finite precision limits asymptotic improvement,
- nonlinear dynamics can generate sharper features that need higher resolution,
- domain truncation/boundary effects contaminate error.
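The finite-precision point is easy to see in isolation. This sketch, which is separate from the GPE solver, differentiates the smooth periodic function $f(x) = e^{\sin x}$ with an FFT and measures the error against the exact derivative $\cos(x)\,e^{\sin x}$. The function name `spectral_derivative_error` is a hypothetical helper. The error should fall very quickly with $N$ and then stall near double-precision roundoff.

```python
import numpy as np

def spectral_derivative_error(n):
    """Max error of the FFT-based derivative of exp(sin x) on [0, 2*pi)."""
    x = 2.0 * np.pi * np.arange(n) / n
    f = np.exp(np.sin(x))
    k = np.fft.fftfreq(n, d=1.0 / n)              # integer wavenumbers
    df = np.fft.ifft(1j * k * np.fft.fft(f)).real
    return np.max(np.abs(df - np.cos(x) * f))

errors = {n: spectral_derivative_error(n) for n in (8, 16, 32, 64)}
for n, err in errors.items():
    print(f"N={n:3d}  max error={err:.3e}")
```

On a log-log plot like the one above, this shape (steep drop, then a flat floor) can look like a loss of convergence when it is really the precision limit.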


## 7) Temporal convergence: error vs $\Delta t$


In [None]:
ref_cfg = SolverConfig(nx=256, ny=256, steps=2400, dt=1.25e-4)
ref = run_simulation(ref_cfg)['psi']

tests = [(1e-3, 300), (5e-4, 600), (2.5e-4, 1200)]
rows = []
for dt_test, steps_test in tests:
    cfg = SolverConfig(nx=256, ny=256, dt=dt_test, steps=steps_test)
    out = run_simulation(cfg)
    rmse = float(jnp.sqrt(jnp.mean(jnp.abs(out['psi'] - ref)**2)))
    rows.append((dt_test, rmse))
    print(f'dt={dt_test:.2e}, rmse={rmse:.6e}')


In [None]:
dts = np.array([r[0] for r in rows])
erm = np.array([r[1] for r in rows])

plt.figure(figsize=(5,4))
plt.loglog(dts, erm, 'o-')
plt.gca().invert_xaxis()
plt.title('Temporal Convergence')
plt.xlabel('dt')
plt.ylabel('RMSE vs reference')
plt.grid(True, which='both', alpha=0.3)
plt.tight_layout()
plt.show()


### Fill-in: estimate convergence order
For two runs $(\Delta t_1, e_1), (\Delta t_2, e_2)$, the empirical order is:
$$
p \approx \frac{\log(e_1/e_2)}{\log(\Delta t_1/\Delta t_2)}
$$
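One caveat when the "exact" solution is itself a finite-$\Delta t$ run: if the leading error is $\psi_{\Delta t} - \psi_{\text{exact}} \approx \Delta t^2 E$, then the error measured against a reference computed with $\Delta t_{\text{ref}}$ scales like $\Delta t^2 - \Delta t_{\text{ref}}^2$, not $\Delta t^2$. The sketch below uses synthetic errors (the constant `C` is arbitrary, and `empirical_order` is a hypothetical helper) with the same step sizes as this section to show how much that biases the estimate.

```python
import math

def empirical_order(dt1, e1, dt2, e2):
    """Empirical convergence order from two (dt, error) pairs."""
    return math.log(e1 / e2) / math.log(dt1 / dt2)

C = 3.0                      # arbitrary error constant, cancels in the ratio
dt_coarse, dt_fine = 1e-3, 2.5e-4
dt_ref = 1.25e-4             # reference step, only 2x finer than dt_fine

# Errors against the true solution: pure second order
p_true = empirical_order(dt_coarse, C * dt_coarse**2,
                         dt_fine, C * dt_fine**2)

# Errors against the finite-dt reference: the finest run looks too accurate
p_vs_ref = empirical_order(dt_coarse, C * (dt_coarse**2 - dt_ref**2),
                           dt_fine, C * (dt_fine**2 - dt_ref**2))

print(f"order vs exact solution     = {p_true:.3f}")
print(f"order vs finite-dt reference = {p_vs_ref:.3f}")
```

Under this leading-order model the estimate comes out near 2.2 instead of 2, so a measured order somewhat above 2 does not by itself mean the scheme is better than second order.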


In [None]:
(dt1, e1), (dt2, e2) = rows[0], rows[-1]
p = np.log(e1/e2)/np.log(dt1/dt2)
print('Empirical temporal order p =', p)


## 8) Performance lab: warmup, JIT, and measurement pitfalls

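Important caveat: the first call to a JIT-compiled JAX function includes tracing and XLA compilation, so time it separately from later calls. If an earlier cell already ran the same configuration, the compiled function may be cached and the first run below may not show the overhead.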


In [None]:
cfg = SolverConfig(nx=256, ny=256, steps=300)

# First run (compile + execute)
t0 = time.perf_counter()
out = run_simulation(cfg)
out['psi'].block_until_ready()
t1 = time.perf_counter()

# Second run (mostly execute)
t2 = time.perf_counter()
out2 = run_simulation(cfg)
out2['psi'].block_until_ready()
t3 = time.perf_counter()

print(f'first_run_s  = {t1-t0:.4f}')
print(f'second_run_s = {t3-t2:.4f}')
print('speedup from removing compile overhead:', (t1-t0)/(t3-t2))


### Try-it-yourself
1. Change grid to 384 and 512. How does scaling behave?
2. Repeat 3 times and report median runtime.
3. On GPU, compare FFT-heavy workload vs CPU.
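For item 2, a small timing helper keeps the warmup and the median in one place. This is a sketch: `benchmark` is a hypothetical helper, and the stand-in workload exists only so the cell runs without JAX. The `block_until_ready` check matters because JAX dispatches work asynchronously, so stopping the timer without it can measure only the dispatch.

```python
import statistics
import time

def benchmark(fn, repeats=3, warmup=1):
    """Return (median_seconds, all_timings) for fn, after `warmup` untimed calls."""
    def run_and_sync():
        result = fn()
        # Wait for asynchronous JAX results; plain Python results pass through
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
        return result

    for _ in range(warmup):
        run_and_sync()               # absorbs compile time on the first call

    timings = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        run_and_sync()
        timings.append(time.perf_counter() - t0)
    return statistics.median(timings), timings

# Stand-in workload so the sketch runs anywhere
median_s, timings = benchmark(lambda: sum(i * i for i in range(10_000)), repeats=3)
print(f"median = {median_s:.6f} s over {len(timings)} runs")
```

In this notebook you would pass something like `lambda: run_simulation(cfg)['psi']`, so that the returned array is the object being synchronized.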


## 9) Optional C++ bridge (discussion + exercise)

You have `scalar_bec_cpp` available. In this project, JAX JIT/XLA often beats naive Python loops and can fuse ops efficiently.

### Conceptual question
When would a custom C++ (or CUDA) kernel still help?
- unsupported operation patterns in XLA
- custom memory layout control
- specialized fused kernels not produced by compiler

### Exercise
Try replacing the nonlinear phase update with a host callback to C++ and benchmark it. Does it help or hurt? Why?


### Suggested answer

A custom C++/CUDA kernel can help when:
- XLA cannot fuse the needed sequence efficiently,
- memory traffic dominates and custom fusion reduces reads/writes,
- you need specialized numerics/layouts not expressible in JAX primitives,
- integrating with existing HPC kernels/libraries.

It can hurt when host-device transfers or callback boundaries break JIT fusion.


## 10) Caveats and failure modes checklist

- **Aliasing:** nonlinear term can populate high-k modes.
- **Boundary artifacts:** FFT implies periodic boundaries.
- **Time-step too large:** phase errors accumulate; wrong dynamics.
- **Energy drift:** monitor when changing precision/backend.
- **Reference error contamination:** convergence studies depend on trustworthy reference.
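- **Precision choice:** complex64 is faster but less accurate than complex128. JAX defaults to 32-bit types, so complex128 requires enabling `jax_enable_x64` (for example `jax.config.update("jax_enable_x64", True)`) before creating arrays.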
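One way to watch for the aliasing item is to track how much spectral power sits near the Nyquist limit. The sketch below is a hypothetical diagnostic, not part of `scalar_bec.diagnostics`: `high_k_fraction` and its `cutoff` (here 2/3 of Nyquist, an arbitrary illustrative threshold) are assumptions. A well-resolved state should give a value near roundoff; a steadily growing value suggests the grid is too coarse.

```python
import numpy as np

def high_k_fraction(psi, cutoff=2.0 / 3.0):
    """Fraction of spectral power in modes above `cutoff` times Nyquist on either axis."""
    nx, ny = psi.shape
    # fftfreq with unit spacing has Nyquist at 0.5, so divide to normalize to [0, 1]
    fx = np.abs(np.fft.fftfreq(nx)) / 0.5
    fy = np.abs(np.fft.fftfreq(ny)) / 0.5
    high = (fx[:, None] > cutoff) | (fy[None, :] > cutoff)
    power = np.abs(np.fft.fft2(psi))**2
    return power[high].sum() / power.sum()

n = 64
x = np.linspace(-10.0, 10.0, n, endpoint=False)
X, Y = np.meshgrid(x, x, indexing="ij")

smooth = np.exp(-(X**2 + Y**2) / 2.0).astype(np.complex128)
rng = np.random.default_rng(0)
noisy = smooth + 0.1 * rng.standard_normal((n, n))   # grid-scale noise

frac_smooth = high_k_fraction(smooth)
frac_noisy = high_k_fraction(noisy)
print(f"smooth Gaussian: {frac_smooth:.3e}")
print(f"with noise:      {frac_noisy:.3e}")
```

Calling this on `psi` every few hundred steps is cheap compared with the simulation itself, since it costs one extra FFT.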


## 11) Capstone mini-project

Implement and test one of:
1. **Imaginary-time propagation** for ground state.
2. **Rotating frame + vortex initial condition.**
3. **Dealiasing filter** and show effect on stability/error.

In your report, include:
- equation modifications,
- implementation details,
- runtime + accuracy comparison,
- at least one plot and one caveat.


---

## Bonus prompt (for your own notes)

In 5–10 bullets, explain *why* the split-step Fourier method is a strong default for nonlinear Schrödinger/GPE problems, and where you’d avoid it.
