# 3D Magnetic Field Diffusion (Rm << 1)

**Validation case**: Magnetic field diffusion in the resistive limit

## Learning Objectives

After completing this notebook, you will understand:

1. The magnetic induction equation and the role of resistivity
2. The magnetic Reynolds number and what Rm << 1 means physically
3. How a Gaussian magnetic field profile spreads due to diffusion
4. The analytic solution for magnetic diffusion in 2D and 3D

---
## 1. Physics Background

### The Magnetic Induction Equation

The evolution of the magnetic field in a conducting plasma with uniform resistivity is governed by the **induction equation**:

$$\frac{\partial \mathbf{B}}{\partial t} = \nabla \times (\mathbf{v} \times \mathbf{B}) + D \nabla^2 \mathbf{B}$$

where:
- $\mathbf{v}$ is the plasma velocity
- $\eta = 1/\sigma$ is the electrical resistivity [Ω·m], with $\sigma$ the electrical conductivity (not to be confused with the Gaussian width $\sigma_0$ used below)
- $D = \eta/\mu_0 = 1/(\mu_0 \sigma)$ is the magnetic diffusivity [m²/s]

The two terms represent competing physics:

| Term | Physics | Effect |
|------|---------|--------|
| $\nabla \times (\mathbf{v} \times \mathbf{B})$ | Advection | Field frozen to plasma, moves with flow |
| $D \nabla^2 \mathbf{B}$ | Diffusion | Field slips through plasma, spreads out |

### The Magnetic Reynolds Number

The **magnetic Reynolds number** measures the relative importance of advection vs diffusion:

$$R_m = \frac{v L}{D} = \frac{\text{advection rate}}{\text{diffusion rate}}$$

where $v$ is a characteristic velocity and $L$ is a characteristic length scale.

| Regime | Condition | Physics |
|--------|-----------|--------|
| **Ideal MHD** | $R_m \gg 1$ | Advection dominates, field frozen to plasma |
| **Resistive (diffusive) limit** | $R_m \ll 1$ | Diffusion dominates, field slips through plasma |

In this notebook, we study the **resistive limit** ($R_m \ll 1$) by setting $\mathbf{v} = 0$, so the advection term drops out and the induction equation reduces to:

$$\frac{\partial \mathbf{B}}{\partial t} = D \nabla^2 \mathbf{B}$$

Each Cartesian component of $\mathbf{B}$ therefore obeys the **heat equation** (diffusion equation).
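
The definitions of $D$ and $R_m$ translate directly into code. This sketch evaluates $D = \eta/\mu_0$ for the resistivity used in Section 2 and then $R_m$ for a flow speed that is purely hypothetical (the simulation itself sets $\mathbf{v} = 0$); the length scale is taken to be the initial Gaussian width. The helper names are illustrative, not part of `jax_frc`.

```python
import math

def magnetic_diffusivity(eta, mu0=4e-7 * math.pi):
    """Magnetic diffusivity D = eta / mu0 [m^2/s] from resistivity eta [Ohm*m]."""
    return eta / mu0

def magnetic_reynolds_number(v, L, D):
    """Rm = v * L / D: advection rate over diffusion rate."""
    return v * L / D

D_example = magnetic_diffusivity(1.26e-10)  # resistivity value used in Section 2
v_ref = 1.0e-5                              # hypothetical flow speed [m/s]
L_ref = 0.1                                 # length scale: initial Gaussian width [m]
Rm_example = magnetic_reynolds_number(v_ref, L_ref, D_example)
print(f"D = {D_example:.3e} m^2/s, Rm = {Rm_example:.3e}")
```

With $L = 0.1$ m, any flow speed well below $D/L \approx 10^{-3}$ m/s keeps this setup in the $R_m \ll 1$ regime.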

### Analytic Solution: Spreading Gaussian

For a 2D Gaussian initial condition in the x-z plane:

$$B_z(x, z, t=0) = B_{\text{peak}} \exp\left(-\frac{x^2 + z^2}{2\sigma_0^2}\right)$$

The exact solution is a **spreading Gaussian**:

$$B_z(x, z, t) = B_{\text{peak}} \frac{\sigma_0^2}{\sigma_t^2} \exp\left(-\frac{x^2 + z^2}{2\sigma_t^2}\right)$$

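where $\sigma_t^2 = \sigma_0^2 + 2Dt$ and $D = \eta/\mu_0$ is the magnetic diffusivity. If $B_z(x, z)$ is the only nonzero component, then $\nabla \cdot \mathbf{B} = \partial B_z/\partial z \neq 0$, so this case checks the diffusion operator component by component rather than modeling a physically realizable field.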

Key features:
- **Width grows**: $\sigma(t) = \sqrt{\sigma_0^2 + 2Dt}$
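- **Peak decreases**: amplitude $\propto \sigma_0^2/\sigma_t^2$ in 2D (for a 3D Gaussian it falls as $(\sigma_0/\sigma_t)^3$)
- **Characteristic timescale**: $\tau_{\text{diff}} = \sigma_0^2 / (2D)$, the time for $\sigma_t^2$ to double from $\sigma_0^2$ to $2\sigma_0^2$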
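
A quick way to confirm that the spreading Gaussian solves $\partial_t B_z = D \nabla^2 B_z$ is to differentiate it with JAX automatic differentiation and check that the residual vanishes. This sketch uses the notebook's $B_{\text{peak}}$ and $\sigma_0$ with $D$ rounded to $10^{-4}$ m²/s; the function names are illustrative only.

```python
import jax
import jax.numpy as jnp

B_PEAK, SIGMA0, DIFFUSIVITY = 1.0, 0.1, 1.0e-4  # notebook values (D rounded)

def bz_exact(x, z, t):
    """Spreading 2D Gaussian: B_peak * (sigma0^2 / sigma_t^2) * exp(-(x^2 + z^2) / (2 sigma_t^2))."""
    sigma_t2 = SIGMA0**2 + 2.0 * DIFFUSIVITY * t
    return B_PEAK * (SIGMA0**2 / sigma_t2) * jnp.exp(-(x**2 + z**2) / (2.0 * sigma_t2))

def residual(x, z, t):
    """dBz/dt - D * (d2Bz/dx2 + d2Bz/dz2), all derivatives from autodiff."""
    dB_dt = jax.grad(bz_exact, argnums=2)(x, z, t)
    d2B_dx2 = jax.grad(jax.grad(bz_exact, argnums=0), argnums=0)(x, z, t)
    d2B_dz2 = jax.grad(jax.grad(bz_exact, argnums=1), argnums=1)(x, z, t)
    return dB_dt - DIFFUSIVITY * (d2B_dx2 + d2B_dz2)

# Residual should be at float32 round-off level at any (x, z, t)
for point in [(0.0, 0.0, 0.0), (0.05, 0.02, 10.0), (-0.12, 0.08, 50.0)]:
    print(point, float(residual(*point)))
```

At $t = 50$ s (one $\tau_{\text{diff}}$ for these values) `bz_exact(0.0, 0.0, 50.0)` returns half the initial peak, matching $\sigma_0^2/\sigma_t^2 = 1/2$.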

### Physical Intuition

Magnetic diffusion is analogous to heat conduction:

- A localized "hot spot" of magnetic field spreads into surrounding regions
- The field diffuses from high-B to low-B regions
- Higher resistivity ($\eta$) means faster diffusion
- The process is irreversible: magnetic energy is converted into heat by Ohmic dissipation

In fusion plasmas, magnetic diffusion is usually slow because $R_m$ is very large ($R_m \gtrsim 10^6$ in tokamaks), but it becomes important:
- During magnetic reconnection events
- In resistive instabilities (tearing modes)
- Near plasma boundaries where temperature (and thus conductivity) is lower

---
## 2. Configuration Setup

In [None]:
# Standard imports
import jax
import jax.numpy as jnp
import jax.lax as lax
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML

# jax-frc imports
from jax_frc.configurations import MagneticDiffusionConfiguration
from jax_frc.solvers.explicit import EulerSolver

# Notebook utilities
import sys
sys.path.insert(0, '.')
from _shared import plot_comparison, plot_error, compute_metrics, metrics_summary, animate_overlay

# Plotting style
%matplotlib inline
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 11

In [None]:
# === PHYSICS PARAMETERS ===

B_peak = 1.0       # Peak magnetic field [T]
sigma = 0.1        # Initial Gaussian width [m]

# Resistivity (not diffusivity!) - diffusivity D = eta/mu_0
eta = 1.26e-10     # Electrical resistivity [Ω·m]
MU0 = 1.2566e-6    # Permeability of free space [H/m]
diffusivity = eta / MU0  # ~1e-4 m²/s

print(f"Resistivity η = {eta:.2e} Ω·m")
print(f"Diffusivity D = η/μ₀ = {diffusivity:.2e} m²/s")

# === GRID PARAMETERS (3D Cartesian) ===

nx = 64            # X resolution
ny = 4             # Y resolution (thin for 2D-like behavior)
nz = 64            # Z resolution
extent = 1.0       # Domain: [-extent, extent] in each direction

print(f"\n3D Grid: {nx} × {ny} × {nz} (x × y × z)")
print(f"Domain: [{-extent}, {extent}]³")
print("Thin y-direction makes this effectively 2D in the x-z plane")

# === TIME PARAMETERS ===

# Diffusion timescale
tau_diff = sigma**2 / (2 * diffusivity)
print(f"\nDiffusion timescale: τ_diff = {tau_diff:.2f} s")

t_end = 0.1 * tau_diff    # Run for 0.1 diffusion times

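# Explicit (forward-Euler) stability limit for 3D diffusion with a central-difference Laplacian:
#   dt <= 1 / (2 D (1/dx² + 1/dy² + 1/dz²))
# Spacings assume 2*extent/n per cell; compare with geometry.dx etc. printed below.
dx = 2 * extent / nx
dy = 2 * extent / ny
dz = 2 * extent / nz
dt_max = 0.5 / (diffusivity * (1 / dx**2 + 1 / dy**2 + 1 / dz**2))
dt = dt_max * 0.5  # Safety factor

print(f"Simulation time: t_end = {t_end:.2f} s")
print(f"Timestep: dt = {dt:.4f} s (stability limit: {dt_max:.4f} s)")
print(f"Expected steps: ~{int(t_end/dt)}")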
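
The explicit timestep limit in the parameter cell follows from a von Neumann analysis of forward Euler with a second-order central-difference Laplacian: each Fourier mode is multiplied per step by $g = 1 - 4\,\Delta t\,D \sum_i \sin^2(k_i h_i/2)/h_i^2$, and the worst case ($\sin^2 = 1$) stays within $[-1, 1]$ only if $\Delta t \le 1/(2D\sum_i h_i^{-2})$. This sketch assumes `EulerSolver` uses that discretization and that the spacing is $2\cdot\text{extent}/n$; neither is confirmed by `jax_frc` here, and the helper names are hypothetical.

```python
def euler_diffusion_dt_limit(D, spacings):
    """Largest stable forward-Euler step for dB/dt = D * lap(B) with a
    central-difference Laplacian: dt <= 1 / (2 D * sum(1/h^2))."""
    return 1.0 / (2.0 * D * sum(1.0 / h**2 for h in spacings))

def worst_amplification(dt, D, spacings):
    """Largest |g| over all grid Fourier modes.

    g runs from 1 (k = 0) down to 1 - 4 dt D sum(1/h^2) at the Nyquist mode."""
    s = 4.0 * dt * D * sum(1.0 / h**2 for h in spacings)
    return max(1.0, abs(1.0 - s))

D_example = 1.26e-10 / 1.2566e-6           # notebook diffusivity [m^2/s]
spacings = (2.0 / 64, 2.0 / 4, 2.0 / 64)   # assumed (dx, dy, dz) for extent = 1
dt_lim = euler_diffusion_dt_limit(D_example, spacings)
print(f"dt limit = {dt_lim:.3f} s")
print("0.5 * limit ->", worst_amplification(0.5 * dt_lim, D_example, spacings))  # 1.0: stable
print("1.1 * limit ->", worst_amplification(1.1 * dt_lim, D_example, spacings))  # > 1: unstable
```

Because $\Delta y = 0.5$ m is far coarser than $\Delta x = \Delta z$, the thin y-direction contributes only $4$ out of $2052$ m⁻² to $\sum_i h_i^{-2}$ and barely moves the limit.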

In [None]:
# Create configuration with 3D Cartesian geometry
config = MagneticDiffusionConfiguration(
    B_peak=B_peak,
    sigma=sigma,
    eta=eta,
    nx=nx,
    ny=ny,
    nz=nz,
    extent=extent,
)

# Build simulation components
geometry = config.build_geometry()
initial_state = config.build_initial_state(geometry)
model = config.build_model()
solver = EulerSolver()

print(f"Grid: {geometry.nx} × {geometry.ny} × {geometry.nz} (x × y × z)")
print(f"Domain: x ∈ [{geometry.x_min:.1f}, {geometry.x_max:.1f}]")
print(f"        y ∈ [{geometry.y_min:.1f}, {geometry.y_max:.1f}]")
print(f"        z ∈ [{geometry.z_min:.1f}, {geometry.z_max:.1f}]")
print(f"Resolution: Δx = {geometry.dx:.4f}, Δy = {geometry.dy:.4f}, Δz = {geometry.dz:.4f}")
print(f"\nModel: {type(model).__name__}")
print(f"Magnetic Reynolds number: Rm = {config.magnetic_reynolds_number():.2e} (should be << 1)")

### Visualize Initial Condition

In [None]:
# Get 3D grid coordinates
x_3d = np.array(geometry.x_grid)
y_3d = np.array(geometry.y_grid)
z_3d = np.array(geometry.z_grid)

# Initial B_z profile: B_z is uniform in y, so take the x-z slice at the middle y index
y_mid_idx = geometry.ny // 2
Bz_init_2d = np.array(initial_state.B[:, y_mid_idx, :, 2])  # x-z slice

# 2D coordinates for plotting (x-z slice)
x_2d = x_3d[:, y_mid_idx, :]
z_2d = z_3d[:, y_mid_idx, :]

# Create figure with 2D contour and cross-sections
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 2D contour plot (x-z slice)
im = axes[0].contourf(z_2d, x_2d, Bz_init_2d, levels=20, cmap='RdBu_r')
axes[0].set_xlabel('z [m]', fontsize=12)
axes[0].set_ylabel('x [m]', fontsize=12)
axes[0].set_title('Initial $B_z$ (x-z slice at y=0)', fontsize=14)
plt.colorbar(im, ax=axes[0], label='$B_z$ [T]')
# Mark center
axes[0].plot(0, 0, 'k+', markersize=15, markeredgewidth=2)

# X slice at z=0
z_mid_idx = geometry.nz // 2
x_1d = x_2d[:, z_mid_idx]
Bz_x = Bz_init_2d[:, z_mid_idx]
axes[1].plot(x_1d, Bz_x, 'b-', linewidth=2)
axes[1].axvline(0, color='gray', linestyle='--', alpha=0.5)
axes[1].set_xlabel('x [m]', fontsize=12)
axes[1].set_ylabel('$B_z$ [T]', fontsize=12)
axes[1].set_title('X slice at z=0', fontsize=14)
axes[1].grid(True, alpha=0.3)

# Z slice at x=0
x_mid_idx = geometry.nx // 2
z_1d = z_2d[x_mid_idx, :]
Bz_z = Bz_init_2d[x_mid_idx, :]
axes[2].plot(z_1d, Bz_z, 'b-', linewidth=2)
axes[2].axvline(0, color='gray', linestyle='--', alpha=0.5)
axes[2].set_xlabel('z [m]', fontsize=12)
axes[2].set_ylabel('$B_z$ [T]', fontsize=12)
axes[2].set_title('Z slice at x=0', fontsize=14)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Peak B_z = {np.max(Bz_init_2d):.3f} T at origin")
print(f"Initial width σ₀ = {sigma} m")

---
## 3. Run Simulation

We evolve the Gaussian profile forward in time. The magnetic field should spread isotropically according to the analytic solution.
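
`EulerSolver` is used as a black box here. As a rough illustration of what a forward-Euler diffusion step involves, this standalone sketch advances a 2D Gaussian with a periodic central-difference Laplacian inside `lax.scan` and compares its peak with the analytic value. The node-centred grid, periodic x and z boundaries, and all function names are assumptions for illustration, not the `jax_frc` implementation.

```python
import jax.numpy as jnp
from jax import lax

def laplacian_periodic(f, dx, dz):
    """Second-order central-difference Laplacian on a periodic x-z grid (axis 0 = x, axis 1 = z)."""
    d2x = (jnp.roll(f, -1, axis=0) - 2.0 * f + jnp.roll(f, 1, axis=0)) / dx**2
    d2z = (jnp.roll(f, -1, axis=1) - 2.0 * f + jnp.roll(f, 1, axis=1)) / dz**2
    return d2x + d2z

def evolve_euler(bz0, D, dx, dz, dt, n_steps):
    """Forward-Euler integration of dBz/dt = D * lap(Bz), looped with lax.scan."""
    def step(bz, _):
        return bz + dt * D * laplacian_periodic(bz, dx, dz), None
    bz_final, _ = lax.scan(step, bz0, None, length=n_steps)
    return bz_final

def gaussian_demo(n=64, extent=1.0, sigma0=0.1, D=1.0e-4, safety=0.5):
    """Run to t ~ 0.1 * sigma0^2 / (2D); return (numerical peak, analytic peak, initial sum, final sum)."""
    h = 2 * extent / n
    coords = -extent + h * jnp.arange(n)       # node grid with x = z = 0 at index n // 2
    X, Z = jnp.meshgrid(coords, coords, indexing="ij")
    bz0 = jnp.exp(-(X**2 + Z**2) / (2 * sigma0**2))
    dt = safety * 0.5 / (D * 2 / h**2)         # fraction of the 2D forward-Euler limit
    n_steps = int(0.1 * sigma0**2 / (2 * D) / dt)
    bz = evolve_euler(bz0, D, h, h, dt, n_steps)
    sigma_t2 = sigma0**2 + 2 * D * n_steps * dt
    return float(bz[n // 2, n // 2]), sigma0**2 / sigma_t2, float(jnp.sum(bz0)), float(jnp.sum(bz))

peak_num, peak_exact, total_0, total_1 = gaussian_demo()
print(f"peak: numerical {peak_num:.4f}, analytic {peak_exact:.4f}")
print(f"relative change in sum(Bz): {abs(total_1 - total_0) / total_0:.1e}")
```

With periodic boundaries the discrete Laplacian sums to zero, so $\sum B_z$ is conserved up to round-off, which gives a second sanity check alongside the peak comparison.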

In [None]:
# Time stepping setup
n_steps = int(t_end / dt)
save_interval = max(1, n_steps // 50)  # ~50 snapshots (every step if n_steps < 50)

print(f"Running {n_steps} steps with dt = {dt:.2e}")
print(f"Saving every {save_interval} steps")

In [None]:
# Run simulation using lax.scan for efficiency
@jax.jit
def scan_step(state, _):
    new_state = solver.step(state, dt, model, geometry)
    return new_state, new_state.B[:, y_mid_idx, :, 2]  # Save x-z slice of B_z

# Run simulation
print("Running simulation...")
final_state, Bz_history = lax.scan(scan_step, initial_state, None, length=n_steps)

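# lax.scan returns B_z after each step (t = dt, 2*dt, ..., n_steps*dt).
# Prepend the initial slice so that row i is at t = i*dt, then subsample.
Bz_history = np.concatenate([Bz_init_2d[None], np.array(Bz_history)], axis=0)[::save_interval]
times = np.arange(0, n_steps + 1, save_interval) * dt

# Final x-z slice of B_z (used in the comparison below)
Bz_final_2d = np.array(final_state.B[:, y_mid_idx, :, 2])

print("Simulation complete")
print(f"Saved {len(times)} snapshots")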

---
## 4. Compare with Analytic Solution

In [None]:
# Compute 2D analytic solution at final time (x-z plane)
t_final = n_steps * dt
Bz_analytic_2d = np.array(config.analytic_solution_2d(
    jnp.array(x_2d), jnp.array(z_2d), t_final
))

# Expected width at final time
sigma_final = np.sqrt(sigma**2 + 2 * diffusivity * t_final)
print(f"Initial width: σ₀ = {sigma:.4f} m")
print(f"Final width:   σ(t) = {sigma_final:.4f} m")
print(f"Width increase: {(sigma_final/sigma - 1)*100:.1f}%")
print(f"\nPeak decay: σ₀²/σ_t² = {(sigma/sigma_final)**2:.3f} (2D scaling)")
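
The peak amplitude is only one check; the width can also be measured directly from the numerical field through its second moment. For an isotropic 2D Gaussian centred at the origin, $\sum x^2 B_z / \sum B_z = \sigma^2$. The helper below is hypothetical (not part of `_shared`) and assumes the profile is centred at $x = 0$ with negligible field at the domain edges; the self-check uses an illustrative synthetic Gaussian.

```python
import numpy as np

def gaussian_width_from_moments(x, field):
    """Estimate sigma from the x second moment: sigma^2 = sum(x^2 * B) / sum(B).

    For an isotropic 2D Gaussian B ~ exp(-(x^2 + z^2) / (2 sigma^2)) this equals sigma^2.
    """
    weights = np.clip(field, 0.0, None)  # ignore small negative undershoots
    return np.sqrt(np.sum(x**2 * weights) / np.sum(weights))

# Self-contained check on a synthetic Gaussian (illustrative parameters)
coords = np.linspace(-1.0, 1.0, 129)
X, Z = np.meshgrid(coords, coords, indexing="ij")
sigma_true = 0.105
field = np.exp(-(X**2 + Z**2) / (2 * sigma_true**2))
print(gaussian_width_from_moments(X, field))
```

In this notebook, `gaussian_width_from_moments(x_2d, Bz_final_2d)` can be compared with `sigma_final` from the cell above.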

In [None]:
# 2D comparison plots
fig, axes = plt.subplots(2, 3, figsize=(15, 9))

# Top row: Initial, Final Numerical, Final Analytic (2D contours)
vmax = B_peak
levels = np.linspace(0, vmax, 21)

im0 = axes[0, 0].contourf(z_2d, x_2d, Bz_init_2d, levels=levels, cmap='RdBu_r')
axes[0, 0].set_title('Initial $B_z$', fontsize=12)
axes[0, 0].set_ylabel('x [m]')
plt.colorbar(im0, ax=axes[0, 0])

im1 = axes[0, 1].contourf(z_2d, x_2d, Bz_final_2d, levels=levels, cmap='RdBu_r')
axes[0, 1].set_title(f'Numerical $B_z$ at t={t_final:.2f}s', fontsize=12)
plt.colorbar(im1, ax=axes[0, 1])

im2 = axes[0, 2].contourf(z_2d, x_2d, Bz_analytic_2d, levels=levels, cmap='RdBu_r')
axes[0, 2].set_title(f'Analytic $B_z$ at t={t_final:.2f}s', fontsize=12)
plt.colorbar(im2, ax=axes[0, 2])

# Bottom row: Error, X slice, Z slice
error_2d = Bz_final_2d - Bz_analytic_2d
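err_max = np.max(np.abs(error_2d))
err_levels = np.linspace(-err_max, err_max, 21) if err_max > 0 else 21  # symmetric about zero
im3 = axes[1, 0].contourf(z_2d, x_2d, error_2d, levels=err_levels, cmap='coolwarm')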
axes[1, 0].set_title('Error (Num - Ana)', fontsize=12)
axes[1, 0].set_xlabel('z [m]')
axes[1, 0].set_ylabel('x [m]')
plt.colorbar(im3, ax=axes[1, 0], label='Error [T]')

# X slice comparison at z=0
axes[1, 1].plot(x_1d, Bz_init_2d[:, z_mid_idx], 'gray', linestyle=':', label='Initial')
axes[1, 1].plot(x_1d, Bz_final_2d[:, z_mid_idx], 'b-', linewidth=2, label='Numerical')
axes[1, 1].plot(x_1d, Bz_analytic_2d[:, z_mid_idx], 'r--', linewidth=2, label='Analytic')
axes[1, 1].set_xlabel('x [m]')
axes[1, 1].set_ylabel('$B_z$ [T]')
axes[1, 1].set_title('X slice at z=0')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

# Z slice comparison at x=0
axes[1, 2].plot(z_1d, Bz_init_2d[x_mid_idx, :], 'gray', linestyle=':', label='Initial')
axes[1, 2].plot(z_1d, Bz_final_2d[x_mid_idx, :], 'b-', linewidth=2, label='Numerical')
axes[1, 2].plot(z_1d, Bz_analytic_2d[x_mid_idx, :], 'r--', linewidth=2, label='Analytic')
axes[1, 2].set_xlabel('z [m]')
axes[1, 2].set_ylabel('$B_z$ [T]')
axes[1, 2].set_title('Z slice at x=0')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Error distribution (use shared utility on flattened arrays)
fig, axes = plot_error(
    np.arange(Bz_final_2d.size), Bz_final_2d.ravel(), Bz_analytic_2d.ravel(),
    xlabel='Grid point index',
    title='Error Analysis (flattened)'
)
plt.show()

In [None]:
# Compute validation metrics on full 2D field
metrics = compute_metrics(Bz_final_2d.ravel(), Bz_analytic_2d.ravel())

# Display with thresholds
thresholds = {
    'l2_error': 0.05,       # 5% relative L2 error
    'max_rel_error': 0.10,  # 10% max relative error
}

print("Validation Metrics:")
print("=" * 50)
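for name, value in metrics.items():
    threshold = thresholds.get(name)
    if threshold is not None:
        status = '✓ PASS' if value <= threshold else '✗ FAIL'
        print(f"{name:20s}: {value:.4e}  (threshold: {threshold:.2e}) {status}")
    else:
        print(f"{name:20s}: {value:.4e}")

# Overall result: fail if none of the thresholded metrics were returned
missing = [k for k in thresholds if k not in metrics]
if missing:
    print(f"Not returned by compute_metrics (unchecked): {missing}")
checked = [k for k in thresholds if k in metrics]
overall_pass = bool(checked) and all(metrics[k] <= thresholds[k] for k in checked)
print("=" * 50)
print(f"Overall: {'✓ PASS' if overall_pass else '✗ FAIL'}")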
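
`compute_metrics` comes from the local `_shared` module, whose definitions are not shown here. For reference, one common convention for the two thresholded quantities is sketched below; treat it as an assumption about what `_shared` computes, not its actual code. Normalizing the maximum error by the peak analytic value avoids the blow-up that a pointwise relative error suffers where the Gaussian tails approach zero.

```python
import numpy as np

def relative_l2_error(numerical, analytic):
    """||num - ana||_2 / ||ana||_2 over all grid points."""
    num = np.asarray(numerical, dtype=float).ravel()
    ana = np.asarray(analytic, dtype=float).ravel()
    return np.linalg.norm(num - ana) / np.linalg.norm(ana)

def peak_normalized_max_error(numerical, analytic):
    """max |num - ana| / max |ana|; avoids dividing by near-zero tail values."""
    num = np.asarray(numerical, dtype=float).ravel()
    ana = np.asarray(analytic, dtype=float).ravel()
    return np.max(np.abs(num - ana)) / np.max(np.abs(ana))

ana = np.array([1.0, 2.0, 2.0])
num = np.array([1.0, 2.0, 2.3])
print(relative_l2_error(num, ana), peak_normalized_max_error(num, ana))
```

Applied to `Bz_final_2d` and `Bz_analytic_2d`, these give values that can be read against the 5% and 10% thresholds above, provided `_shared` uses the same conventions.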

---
## 5. Time Evolution Animation

In [None]:
# 2D animation of spreading Gaussian
from matplotlib.animation import FuncAnimation

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Fixed contour levels so all frames share one color scale
levels = np.linspace(0, B_peak, 21)

def animate(frame):
    t = times[frame]
    
    # Clear and replot (contourf doesn't support set_array)
    axes[0].clear()
    axes[0].contourf(z_2d, x_2d, Bz_history[frame], levels=levels, cmap='RdBu_r')
    axes[0].set_xlabel('z [m]')
    axes[0].set_ylabel('x [m]')
    axes[0].set_title(f'Numerical (t={t:.3f}s)')
    
    axes[1].clear()
    Bz_ana_t = np.array(config.analytic_solution_2d(jnp.array(x_2d), jnp.array(z_2d), t))
    axes[1].contourf(z_2d, x_2d, Bz_ana_t, levels=levels, cmap='RdBu_r')
    axes[1].set_xlabel('z [m]')
    axes[1].set_ylabel('x [m]')
    axes[1].set_title(f'Analytic (t={t:.3f}s)')
    
    return []

anim = FuncAnimation(fig, animate, frames=len(times), interval=200, blit=False)
plt.close()

HTML(anim.to_jshtml())

---
## 6. Interactive Exploration

In [None]:
import ipywidgets as widgets
from IPython.display import display

def explore_time(time_idx):
    """Show solution at a specific snapshot."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    t = times[time_idx]
    Bz_num = Bz_history[time_idx]
    Bz_ana = np.array(config.analytic_solution_2d(jnp.array(x_2d), jnp.array(z_2d), t))
    
    # 2D numerical
    levels = np.linspace(0, B_peak, 21)
    im0 = axes[0, 0].contourf(z_2d, x_2d, Bz_num, levels=levels, cmap='RdBu_r')
    axes[0, 0].set_xlabel('z [m]')
    axes[0, 0].set_ylabel('x [m]')
    axes[0, 0].set_title(f'Numerical $B_z$ at t = {t:.4f} s')
    plt.colorbar(im0, ax=axes[0, 0])
    
    # 2D analytic
    im1 = axes[0, 1].contourf(z_2d, x_2d, Bz_ana, levels=levels, cmap='RdBu_r')
    axes[0, 1].set_xlabel('z [m]')
    axes[0, 1].set_ylabel('x [m]')
    axes[0, 1].set_title(f'Analytic $B_z$ at t = {t:.4f} s')
    plt.colorbar(im1, ax=axes[0, 1])
    
    # X slice at z=0
    axes[1, 0].plot(x_1d, Bz_init_2d[:, z_mid_idx], 'gray', linestyle=':', alpha=0.7, label='Initial')
    axes[1, 0].plot(x_1d, Bz_num[:, z_mid_idx], 'b-', linewidth=2, label='Numerical')
    axes[1, 0].plot(x_1d, Bz_ana[:, z_mid_idx], 'r--', linewidth=2, label='Analytic')
    axes[1, 0].set_xlabel('x [m]')
    axes[1, 0].set_ylabel('$B_z$ [T]')
    axes[1, 0].set_title('X slice at z=0')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].set_ylim(0, B_peak * 1.1)
    
    # Width indicator
    sigma_t = np.sqrt(sigma**2 + 2 * diffusivity * t)
    axes[1, 0].axvline(-sigma_t, color='green', linestyle='--', alpha=0.5)
    axes[1, 0].axvline(sigma_t, color='green', linestyle='--', alpha=0.5, 
                       label=f'σ(t) = {sigma_t:.3f}')
    axes[1, 0].legend()
    
    # Z slice at x=0
    axes[1, 1].plot(z_1d, Bz_init_2d[x_mid_idx, :], 'gray', linestyle=':', alpha=0.7, label='Initial')
    axes[1, 1].plot(z_1d, Bz_num[x_mid_idx, :], 'b-', linewidth=2, label='Numerical')
    axes[1, 1].plot(z_1d, Bz_ana[x_mid_idx, :], 'r--', linewidth=2, label='Analytic')
    axes[1, 1].set_xlabel('z [m]')
    axes[1, 1].set_ylabel('$B_z$ [T]')
    axes[1, 1].set_title('Z slice at x=0')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].set_ylim(0, B_peak * 1.1)
    
    plt.tight_layout()
    plt.show()
    
    # Print metrics
    m = compute_metrics(Bz_num.ravel(), Bz_ana.ravel())
    print(f"L2 error: {m['l2_error']:.4e}  |  Max error: {m['linf_error']:.4e}")
    print(f"Current width σ(t) = {sigma_t:.4f} m")

# Time slider
time_slider = widgets.IntSlider(
    value=0, min=0, max=len(times)-1, step=1,
    description='Snapshot:',
    continuous_update=False
)

ui = widgets.VBox([
    widgets.HTML('<h4>Time Evolution Explorer</h4>'),
    widgets.HTML('<p>Drag the slider to see the solution at different times. Watch the Gaussian spread!</p>'),
    time_slider,
])

interactive_output = widgets.interactive_output(explore_time, {'time_idx': time_slider})

display(ui, interactive_output)

---
## Summary

This notebook demonstrated **magnetic field diffusion** using 3D Cartesian coordinates:

1. **Physics**: In the resistive limit ($R_m \ll 1$), the magnetic induction equation reduces to a diffusion equation. Magnetic field spreads through the plasma like heat through a conductor.

2. **Analytic solution**: A Gaussian initial condition evolves as:
   - Width: $\sigma(t) = \sqrt{\sigma_0^2 + 2Dt}$ (isotropic spreading)
   - Peak: $B_z(0, 0, t) = B_{\text{peak}} \, \sigma_0^2/\sigma_t^2$ (2D amplitude decay)

3. **3D Cartesian coordinates**: The simulation uses native Cartesian (x, y, z) coordinates. For 2D-like behavior, we use a thin y dimension (ny=4) with periodic boundaries.

4. **Validation**: When the metrics in Section 4 pass their thresholds, the numerical solution agrees with the analytic one, which supports a correct implementation of resistive diffusion.
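
The thin periodic y-dimension leaves the physics unchanged as long as the initial $B_z$ does not depend on $y$ (which the 2D analytic comparison assumes): the y-part of a periodic central-difference Laplacian is then exactly zero, so every x-z slice evolves like the 2D problem. This standalone check illustrates the point on a small hypothetical grid; the helper names are illustrative and not taken from `jax_frc`.

```python
import jax.numpy as jnp

def periodic_laplacian(f, spacings):
    """Central-difference Laplacian of f with periodic wrap along every axis."""
    lap = jnp.zeros_like(f)
    for axis, h in enumerate(spacings):
        lap = lap + (jnp.roll(f, -1, axis=axis) - 2.0 * f + jnp.roll(f, 1, axis=axis)) / h**2
    return lap

def thin_y_mismatch(nx=16, ny=4, nz=16, width=0.3):
    """Max |3D Laplacian - 2D Laplacian| for a y-independent B_z on [-1, 1)^3."""
    dx, dy, dz = 2.0 / nx, 2.0 / ny, 2.0 / nz
    x = -1.0 + dx * jnp.arange(nx)
    z = -1.0 + dz * jnp.arange(nz)
    X, Z = jnp.meshgrid(x, z, indexing="ij")
    bz_xz = jnp.exp(-(X**2 + Z**2) / (2 * width**2))            # 2D profile in the x-z plane
    bz_3d = jnp.broadcast_to(bz_xz[:, None, :], (nx, ny, nz))   # same profile at every y
    lap_3d = periodic_laplacian(bz_3d, (dx, dy, dz))
    lap_2d = periodic_laplacian(bz_xz, (dx, dz))
    return float(jnp.max(jnp.abs(lap_3d - lap_2d[:, None, :])))

print(thin_y_mismatch())  # ny = 4 mirrors the notebook's thin y-direction
```

The same holds for any `ny`, which is why adding y resolution alone does not turn this setup into a 3D diffusion problem.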

### Physical Significance

Magnetic diffusion is important in:
- **Magnetic reconnection**: Field lines can break and reconnect when diffusion is significant
- **Resistive instabilities**: Tearing modes, resistive wall modes
- **Edge physics**: Lower temperature means higher resistivity and faster diffusion

### Try Next

- Increase resolution (`nx`, `nz`) to reduce numerical error
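- Increase `ny` together with a y-dependent (3D Gaussian) initial condition to see full 3D diffusion, where the peak decays as $(\sigma_0/\sigma_t)^3$; with a y-independent field, extra y resolution leaves the solution unchanged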
- Compare with the **frozen flux** notebook to see the opposite limit ($R_m \gg 1$)