# Heat Diffusion in a Slab Geometry

**Validation case**: 1D heat diffusion with analytic Gaussian solution

## Learning Objectives

After completing this notebook, you will understand:

1. How the heat equation governs thermal transport in plasmas
2. Why a Gaussian temperature profile is a self-similar solution
3. How to set up and run thermal diffusion simulations in jax-frc
4. How numerical solutions compare to analytic predictions
5. What controls numerical accuracy (grid resolution, timestep)

---
## 1. Physics Background

### The Heat Equation

Heat conduction in a plasma is governed by the energy equation. In the simplest case of pure diffusion (no convection, no sources), the temperature $T$ evolves according to:

$$\frac{\partial T}{\partial t} = \kappa \nabla^2 T$$

where $\kappa$ is the **thermal diffusivity** (units: m²/s). This is the classic **heat equation** or **diffusion equation**.

### Physical Meaning

The heat equation describes how temperature gradients drive heat flow:

1. **Heat flux**: $\mathbf{q} = -\kappa \nabla T$ (Fourier's law, written in temperature units, i.e. the heat flux divided by the volumetric heat capacity); heat flows from hot to cold.
2. **Energy conservation**: $\frac{\partial T}{\partial t} = -\nabla \cdot \mathbf{q}$; temperature rises where heat accumulates.
3. **Combining**: $\frac{\partial T}{\partial t} = \nabla \cdot (\kappa \nabla T) = \kappa \nabla^2 T$ (for constant $\kappa$)

In a plasma, the thermal diffusivity depends on collisionality and magnetic field geometry. For this test case, we use constant $\kappa$ to enable exact analytic comparison.
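The flux-then-divergence structure above maps directly onto a conservative finite-volume discretization. The following is a minimal NumPy sketch, not jax-frc's discretization: `flux_form_rhs` is a hypothetical helper, the grid is uniform, and both outer faces are assumed insulated (zero flux), so the total heat can only change through the boundaries.

```python
import numpy as np

def flux_form_rhs(T, dz, kappa):
    """Return dT/dt = -dq/dz on cell centres, with q = -kappa dT/dz on faces.

    Only interior faces get a computed flux; both outer faces carry zero
    flux (insulated ends) in this sketch.
    """
    # Fourier's law on the n-1 interior faces between neighbouring cells
    q_interior = -kappa * np.diff(T) / dz
    # Zero flux at both outer faces
    q = np.concatenate(([0.0], q_interior, [0.0]))
    # Energy conservation: net flux into each cell
    return -np.diff(q) / dz

# Example: Gaussian bump on a baseline (values as in Section 2)
z = np.linspace(-2.0, 2.0, 129)
dz = z[1] - z[0]
T = 200.0 * np.exp(-z**2 / (2 * 0.3**2)) + 50.0
dTdt = flux_form_rhs(T, dz, kappa=1e-3)

# The face fluxes telescope, so the total changes only via boundary flux (zero here)
total_rate = np.sum(dTdt) * dz
print(f"d/dt of total heat: {total_rate:.2e}")
print(f"dT/dt at z = 0: {dTdt[64]:.4f}  (exact kappa*T'' = {-1e-3 * 200.0 / 0.3**2:.4f})")
```

Because every interior flux is added to one cell and subtracted from its neighbour, conservation holds to round-off regardless of resolution; only the accuracy of $\partial T/\partial t$ depends on $\Delta z$.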

### Gaussian Solution

Consider the 1D heat equation along $z$:

$$\frac{\partial T}{\partial t} = \kappa \frac{\partial^2 T}{\partial z^2}$$

We seek solutions of the form $T(z,t) = A(t) \exp\left(-\frac{z^2}{2\sigma(t)^2}\right) + T_\text{base}$.

**Step 1: Compute the spatial derivative**

$$\frac{\partial T}{\partial z} = A(t) \cdot \left(-\frac{z}{\sigma^2}\right) \exp\left(-\frac{z^2}{2\sigma^2}\right)$$

$$\frac{\partial^2 T}{\partial z^2} = A(t) \left(\frac{z^2}{\sigma^4} - \frac{1}{\sigma^2}\right) \exp\left(-\frac{z^2}{2\sigma^2}\right)$$

**Step 2: Substitute into heat equation**

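Differentiating the ansatz in time (both $A$ and $\sigma^2$ depend on $t$) gives

$$\frac{\partial T}{\partial t} = \left(\frac{dA}{dt} + A\,\frac{z^2}{2\sigma^4}\frac{d\sigma^2}{dt}\right)\exp\left(-\frac{z^2}{2\sigma^2}\right)$$

Setting this equal to $\kappa\,\partial^2 T/\partial z^2$ from Step 1 and dividing out the common exponential, the equation must hold for every $z$, so the coefficients of $z^2$ and $z^0$ must match separately:

$$\frac{d\sigma^2}{dt} = 2\kappa, \qquad \frac{dA}{dt} = -\frac{\kappa}{\sigma^2}\,A$$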

**Step 3: Self-similar solution**

The key insight is that if we choose:

$$\sigma(t)^2 = \sigma_0^2 + 2\kappa t$$

and require total heat conservation (integral of $T - T_\text{base}$ is constant), then:

$$A(t) = T_\text{peak} \sqrt{\frac{\sigma_0^2}{\sigma_0^2 + 2\kappa t}}$$

**The complete analytic solution is:**

$$\boxed{T(z,t) = T_\text{peak} \sqrt{\frac{\sigma_0^2}{\sigma_0^2 + 2\kappa t}} \exp\left(-\frac{z^2}{2(\sigma_0^2 + 2\kappa t)}\right) + T_\text{base}}$$
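As a numerical sanity check on the boxed result, the sketch below plugs it into centred finite differences and evaluates the residual $\partial T/\partial t - \kappa\,\partial^2 T/\partial z^2$. The parameter values are copied from Section 2; `T_exact` is a hypothetical helper name. The residual should be small compared with $\partial T/\partial t$ and shrink roughly like $\Delta z^2$ as the grid is refined.

```python
import numpy as np

# Parameters matching the notebook defaults (T_peak, T_base, sigma_0, kappa)
T_peak, T_base, sigma0, kappa = 200.0, 50.0, 0.3, 1e-3

def T_exact(z, t):
    """Boxed analytic solution: spreading Gaussian on a constant baseline."""
    s = sigma0**2 + 2.0 * kappa * t          # sigma(t)^2
    return T_peak * np.sqrt(sigma0**2 / s) * np.exp(-z**2 / (2.0 * s)) + T_base

z = np.linspace(-1.5, 1.5, 301)
t = 20.0                                      # an arbitrary time [s]
dz, dt = z[1] - z[0], 1e-3

# Centred finite differences for dT/dt and d^2T/dz^2
dT_dt = (T_exact(z, t + dt) - T_exact(z, t - dt)) / (2.0 * dt)
Tz = T_exact(z, t)
d2T_dz2 = (Tz[2:] - 2.0 * Tz[1:-1] + Tz[:-2]) / dz**2

# Residual of the heat equation on interior points
residual = dT_dt[1:-1] - kappa * d2T_dz2
print(f"max |residual| = {np.max(np.abs(residual)):.2e}")
print(f"max |dT/dt|    = {np.max(np.abs(dT_dt)):.2e}")
```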

### Physical Interpretation

The solution describes a **spreading Gaussian**:

| Property | Behavior | Formula |
|----------|----------|---------|
| Width | Grows with time | $\sigma(t) = \sqrt{\sigma_0^2 + 2\kappa t}$ |
| Peak height above $T_\text{base}$ | Decreases with time | $T_\text{max}(t) = T_\text{peak} \sqrt{\sigma_0^2 / (\sigma_0^2 + 2\kappa t)}$ |
| Total heat | Conserved | $\int (T - T_\text{base}) dz = \text{const}$ |
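The conservation row follows from the Gaussian integral $\int A\,e^{-z^2/(2\sigma^2)}\,dz = A\,\sigma\sqrt{2\pi}$, and $A(t)\,\sigma(t) = T_\text{peak}\,\sigma_0$ at all times. A short sketch confirming this by quadrature, assuming the Section 2 parameter values and an unbounded domain approximated by $[-10, 10]$ m; `excess_heat` is a hypothetical helper.

```python
import numpy as np

T_peak, T_base, sigma0, kappa = 200.0, 50.0, 0.3, 1e-3

def excess_heat(t, L=10.0, n=4001):
    """Integrate T - T_base over a domain wide enough that the tails vanish."""
    z = np.linspace(-L, L, n)
    s = sigma0**2 + 2.0 * kappa * t
    excess = T_peak * np.sqrt(sigma0**2 / s) * np.exp(-z**2 / (2.0 * s))
    # Trapezoidal rule written out explicitly (works on any NumPy version)
    return np.sum(0.5 * (excess[1:] + excess[:-1]) * np.diff(z))

expected = T_peak * sigma0 * np.sqrt(2.0 * np.pi)   # A(t) * sigma(t) * sqrt(2*pi)
for t in [0.0, 45.0, 450.0]:
    print(f"t = {t:6.1f} s: integral = {excess_heat(t):.6f}, expected {expected:.6f}")
```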

At late times ($2\kappa t \gg \sigma_0^2$) the width grows as $\sigma \approx \sqrt{2\kappa t}$, the characteristic **diffusive scaling**. The spreading rate $d\sigma/dt = \kappa/\sigma$ is largest at $t = 0$ and falls as the profile widens and its gradients weaken.

**Diffusion timescale**: The characteristic time for the profile to spread significantly is:

$$\tau_D \sim \frac{\sigma_0^2}{2\kappa}$$

When $t \ll \tau_D$, the profile is nearly unchanged. When $t \gg \tau_D$, it has spread substantially.
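Since $2\kappa t = \sigma_0^2\,t/\tau_D$, the width can be written compactly as $\sigma(t)/\sigma_0 = \sqrt{1 + t/\tau_D}$, so at $t = \tau_D$ the profile is exactly $\sqrt{2}$ times wider. The sketch below tabulates this for the default parameters (values assumed from Section 2) and shows how quickly the $\sqrt{2\kappa t}$ asymptote takes over.

```python
import numpy as np

sigma0, kappa = 0.3, 1e-3                 # notebook defaults [m], [m^2/s]
tau_D = sigma0**2 / (2.0 * kappa)         # diffusion timescale [s]

for ratio in [1e-2, 1.0, 1e2]:
    t = ratio * tau_D
    sigma_t = np.sqrt(sigma0**2 + 2.0 * kappa * t)
    late_time = np.sqrt(2.0 * kappa * t)  # sqrt(t) asymptote, valid for t >> tau_D
    print(f"t/tau_D = {ratio:g}: sigma/sigma0 = {sigma_t / sigma0:.4f}, "
          f"sqrt(2*kappa*t)/sigma = {late_time / sigma_t:.4f}")
```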

---
## 2. Configuration Setup

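Let's set up the simulation. The parameter cell prints $t_\text{end}/\tau_D$: with the default values below this ratio is tiny (about $2\times10^{-6}$), so the profile barely spreads during the run. Increase `kappa` or `t_end` for visibly stronger diffusion.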

In [None]:
# Standard imports
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML

# jax-frc imports
from jax_frc.configurations import SlabDiffusionConfiguration
from jax_frc.solvers.semi_implicit import SemiImplicitSolver

# Notebook utilities
import sys
sys.path.insert(0, '.')
from _shared import plot_comparison, plot_error, animate_overlay, compute_metrics, metrics_summary

# Plotting style
%matplotlib inline
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 11

In [None]:
# === PHYSICS PARAMETERS ===

T_peak = 200.0    # Peak temperature [eV] - initial maximum above baseline
T_base = 50.0     # Base temperature [eV] - background level
sigma = 0.3       # Initial Gaussian width [m] - how localized the heat is
kappa = 1e-3      # Thermal diffusivity [m²/s] - controls spreading rate

# === GRID PARAMETERS ===

nz = 128          # Axial grid points - higher = better resolution
z_extent = 2.0    # Domain: z ∈ [-z_extent, z_extent]

# === TIME PARAMETERS ===

t_end = 1e-4      # End time [s]
dt = 1e-6         # Timestep [s]

# Compute characteristic diffusion time
tau_D = sigma**2 / (2 * kappa)
print(f"Diffusion timescale τ_D = {tau_D:.2e} s")
print(f"Simulation runs for t_end/τ_D = {t_end/tau_D:.2f} diffusion times")
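The printed ratio can be turned into the change you should expect to see. A minimal sketch, assuming the Gaussian solution from Section 1; `spreading_summary` is a hypothetical helper, and it uses its own argument names so running it does not overwrite the notebook's `kappa` or `t_end`. It does not check whether the jax-frc solver stays accurate at a larger $\kappa$ with this grid and timestep.

```python
import numpy as np

def spreading_summary(sigma0, kappa_val, t_stop):
    """Predicted change over a run of length t_stop for the Gaussian solution."""
    tau = sigma0**2 / (2.0 * kappa_val)            # diffusion timescale [s]
    width_ratio = np.sqrt(1.0 + t_stop / tau)       # sigma(t_stop) / sigma0
    peak_drop = 1.0 - 1.0 / width_ratio             # fractional drop of A(t)
    kappa_for_one_tau = sigma0**2 / (2.0 * t_stop)  # kappa making t_stop = tau
    return tau, width_ratio, peak_drop, kappa_for_one_tau

# Default values from the parameter cell (sigma, kappa, t_end)
tau_chk, ratio_chk, drop_chk, kappa_chk = spreading_summary(0.3, 1e-3, 1e-4)
print(f"t_end/tau_D = {1e-4 / tau_chk:.2e}")
print(f"width grows by {100 * (ratio_chk - 1):.1e} %, peak drops by {100 * drop_chk:.1e} %")
print(f"kappa needed for t_end = tau_D: {kappa_chk:.0f} m^2/s")
```

In the notebook, `spreading_summary(sigma, kappa, t_end)` gives the same numbers for whatever values you set above.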

In [None]:
# Create configuration object
config = SlabDiffusionConfiguration(
    T_peak=T_peak,
    T_base=T_base,
    sigma=sigma,
    kappa=kappa,
    nz=nz,
    z_extent=z_extent,
)

# Build simulation components
geometry = config.build_geometry()
initial_state = config.build_initial_state(geometry)
model = config.build_model()
solver = SemiImplicitSolver()

print(f"Grid: {geometry.nr} × {geometry.nz} (r × z)")
print(f"Domain: z ∈ [{geometry.z_min:.1f}, {geometry.z_max:.1f}] m")
print(f"Resolution: Δz = {geometry.dz:.4f} m")

### Visualize Initial Condition

The initial temperature profile is a Gaussian centered at $z=0$ with width $\sigma_0$.

In [None]:
# Extract z coordinates and initial temperature
# The grid is 2D (r, z) - take a slice at mid-radius for 1D analysis
r_mid = geometry.nr // 2
z = geometry.z_grid[r_mid, :]
T_initial = initial_state.T[r_mid, :]

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(z, T_initial, 'b-', linewidth=2, label='Initial T(z)')
ax.axhline(T_base, color='gray', linestyle='--', label=f'$T_{{base}} = {T_base}$ eV')
ax.axhline(T_base + T_peak, color='red', linestyle=':', alpha=0.5, label=f'$T_{{base}} + T_{{peak}} = {T_base + T_peak}$ eV')

# Mark the initial width
ax.axvline(-sigma, color='green', linestyle=':', alpha=0.5)
ax.axvline(sigma, color='green', linestyle=':', alpha=0.5, label=f'$\\pm\\sigma_0 = \\pm{sigma}$ m')

ax.set_xlabel('z [m]', fontsize=12)
ax.set_ylabel('Temperature [eV]', fontsize=12)
ax.set_title('Initial Temperature Profile', fontsize=14)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

---
## 3. Run Simulation

We advance the state forward in time using the semi-implicit solver. We save snapshots at regular intervals for visualization.
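The internals of `SemiImplicitSolver` are not shown in this notebook. For intuition only, here is a minimal NumPy sketch of one common implicit treatment of diffusion, the Crank-Nicolson scheme; it is an illustration of the general idea, not jax-frc's algorithm. It assumes a uniform grid, constant $\kappa$, and end values held fixed (Dirichlet). The step is chosen so that $\kappa\,\Delta t/\Delta z^2 = 4$, well beyond the forward-Euler limit of $1/2$, yet the scheme stays stable; helper names are hypothetical.

```python
import numpy as np

def crank_nicolson_step_matrices(n, dz, dt, kappa):
    """Return (I - r/2 L, I + r/2 L) for the 1D Laplacian L, with r = kappa*dt/dz^2.

    The first and last rows of L are zeroed, so both end values stay fixed.
    """
    r = kappa * dt / dz**2
    lap = (np.diag(-2.0 * np.ones(n)) + np.diag(np.ones(n - 1), 1)
           + np.diag(np.ones(n - 1), -1))
    lap[0, :] = 0.0
    lap[-1, :] = 0.0
    eye = np.eye(n)
    return eye - 0.5 * r * lap, eye + 0.5 * r * lap

def crank_nicolson_demo(T_peak=200.0, T_base=50.0, sigma0=0.3, kappa=1e-3,
                        n=129, z_extent=2.0, r=4.0, n_steps=10):
    """Diffuse a Gaussian with Crank-Nicolson and compare to the exact solution."""
    z = np.linspace(-z_extent, z_extent, n)
    dz = z[1] - z[0]
    dt = r * dz**2 / kappa                 # chosen so kappa*dt/dz^2 = r
    lhs, rhs = crank_nicolson_step_matrices(n, dz, dt, kappa)

    T = T_peak * np.exp(-z**2 / (2.0 * sigma0**2)) + T_base
    for _ in range(n_steps):
        T = np.linalg.solve(lhs, rhs @ T)  # one implicit linear solve per step

    t = n_steps * dt
    s = sigma0**2 + 2.0 * kappa * t
    T_exact = T_peak * np.sqrt(sigma0**2 / s) * np.exp(-z**2 / (2.0 * s)) + T_base
    return t / (sigma0**2 / (2.0 * kappa)), np.max(np.abs(T - T_exact))

t_over_tau, max_err = crank_nicolson_demo()
print(f"t/tau_D = {t_over_tau:.3f}, max |error| = {max_err:.3f} eV")
```

A production solver would use a tridiagonal (banded) solve instead of dense matrices; the dense form just keeps the sketch short.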

In [None]:
# Time stepping parameters
n_steps = int(round(t_end / dt))  # round() so float error in t_end/dt cannot drop a step
save_interval = max(1, n_steps // 50)  # Save ~50 snapshots for animation

print(f"Running {n_steps} steps with dt = {dt:.2e} s")
print(f"Saving every {save_interval} steps ({n_steps // save_interval} snapshots)")

In [None]:
# Run simulation with progress tracking
state = initial_state
history_T = [np.array(state.T[r_mid, :])]  # Store temperature at mid-radius
history_times = [0.0]

progress_every = max(1, n_steps // 5)  # guard against n_steps < 5

for step in range(n_steps):
    state = solver.step(state, dt, model, geometry)
    
    # Save snapshot
    if (step + 1) % save_interval == 0:
        history_T.append(np.array(state.T[r_mid, :]))
        history_times.append(float(state.time))
    
    # Progress indicator (independent of the save interval)
    if (step + 1) % progress_every == 0:
        print(f"  Step {step+1}/{n_steps} (t = {float(state.time):.2e} s)")

# Always save final state
if history_times[-1] != float(state.time):
    history_T.append(np.array(state.T[r_mid, :]))
    history_times.append(float(state.time))

print(f"Simulation complete: final time = {state.time:.2e} s")

# Convert to arrays for plotting
history_T = np.array(history_T)
history_times = np.array(history_times)

---
## 4. Analytic Solution

Now we compute the exact analytic solution at the same grid points and times.

In [None]:
def analytic_solution(z, t):
    """Compute exact Gaussian diffusion solution.
    
    T(z,t) = T_peak * sqrt(σ₀² / (σ₀² + 2κt)) * exp(-z² / (2(σ₀² + 2κt))) + T_base
    """
    sigma_eff_sq = sigma**2 + 2 * kappa * t
    amplitude = T_peak * np.sqrt(sigma**2 / sigma_eff_sq)
    return amplitude * np.exp(-z**2 / (2 * sigma_eff_sq)) + T_base

# Compute analytic solution at final time
z_np = np.array(z)
T_analytic_final = analytic_solution(z_np, history_times[-1])  # time actually reached by the run
T_numerical_final = history_T[-1]

In [None]:
# Compare initial vs final (both numerical and analytic)
fig, ax = plt.subplots(figsize=(10, 6))

# Initial
ax.plot(z_np, history_T[0], 'gray', linestyle=':', linewidth=1.5, label='Initial')

# Final numerical
ax.plot(z_np, T_numerical_final, 'b-', linewidth=2, label=f'Numerical (t = {t_end:.0e} s)')

# Final analytic
ax.plot(z_np, T_analytic_final, 'r--', linewidth=2, label=f'Analytic (t = {t_end:.0e} s)')

ax.set_xlabel('z [m]', fontsize=12)
ax.set_ylabel('Temperature [eV]', fontsize=12)
ax.set_title('Diffusion: Initial vs Final State', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Show spreading
sigma_final = np.sqrt(sigma**2 + 2 * kappa * t_end)
print(f"\nWidth evolution:")
print(f"  Initial: σ₀ = {sigma:.3f} m")
print(f"  Final:   σ  = {sigma_final:.3f} m")
print(f"  Ratio:   σ/σ₀ = {sigma_final/sigma:.3f}")

---
## 5. Comparison and Metrics

Let's quantify how well the numerical solution matches the analytic one.
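`compute_metrics` comes from the local `_shared` module, whose definition is not reproduced here. A minimal sketch, assuming a relative L2 error $\lVert T_\text{num} - T_\text{ana}\rVert_2 / \lVert T_\text{ana}\rVert_2$ (the actual helper may normalize differently). It also computes the same error for a "do-nothing" solution, the unchanged initial profile: if the solver's error is not clearly below that baseline, the test cannot distinguish diffusion from no evolution at all. Both helper names are hypothetical.

```python
import numpy as np

def relative_l2(numerical, reference):
    """Relative L2 error ||numerical - reference|| / ||reference|| (assumed definition)."""
    numerical = np.asarray(numerical, dtype=float)
    reference = np.asarray(reference, dtype=float)
    return np.linalg.norm(numerical - reference) / np.linalg.norm(reference)

def do_nothing_error(t, T_peak=200.0, T_base=50.0, sigma0=0.3, kappa=1e-3,
                     nz=128, z_extent=2.0):
    """Error of the unchanged initial profile against the exact solution at time t."""
    zz = np.linspace(-z_extent, z_extent, nz)

    def exact(time):
        s = sigma0**2 + 2.0 * kappa * time
        return T_peak * np.sqrt(sigma0**2 / s) * np.exp(-zz**2 / (2.0 * s)) + T_base

    return relative_l2(exact(0.0), exact(t))

print(f"do-nothing error at t = 1e-4 s: {do_nothing_error(1e-4):.1e}")
print(f"do-nothing error at t = tau_D (45 s): {do_nothing_error(45.0):.2f}")
```

In the notebook, `relative_l2(history_T[0], T_analytic_final)` gives this baseline for the actual run, to set beside `metrics['l2_error']`.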

In [None]:
# Overlay comparison
fig, ax = plot_comparison(
    z_np, T_numerical_final, T_analytic_final,
    xlabel='z [m]', ylabel='Temperature [eV]',
    title=f'Numerical vs Analytic at t = {t_end:.0e} s',
    initial=history_T[0]
)
plt.show()

In [None]:
# Error distribution
fig, axes = plot_error(
    z_np, T_numerical_final, T_analytic_final,
    xlabel='z [m]',
    title='Error Analysis'
)
plt.show()

In [None]:
# Compute validation metrics
metrics = compute_metrics(T_numerical_final, T_analytic_final)

# Check conservation: total heat should be constant
# ∫(T - T_base) dz ∝ A(t) * σ(t) = const for Gaussian
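# np.trapezoid requires NumPy >= 2.0; fall back to np.trapz on older versions
trapezoid = getattr(np, "trapezoid", None) or np.trapz
total_heat_initial = trapezoid(history_T[0] - T_base, z_np)
total_heat_final = trapezoid(T_numerical_final - T_base, z_np)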
conservation_error = abs(total_heat_final - total_heat_initial) / abs(total_heat_initial)
metrics['heat_conservation'] = conservation_error

# Define acceptance thresholds
thresholds = {
    'l2_error': 0.1,           # 10% relative L2 error
    'heat_conservation': 0.01,  # 1% heat conservation
}

print("=" * 50)
print("VALIDATION RESULTS")
print("=" * 50)
metrics_summary(metrics, thresholds)

In [None]:
# Overall pass/fail
l2_pass = metrics['l2_error'] <= thresholds['l2_error']
conservation_pass = metrics['heat_conservation'] <= thresholds['heat_conservation']
overall_pass = l2_pass and conservation_pass

print(f"\nOverall result: {'✓ PASS' if overall_pass else '✗ FAIL'}")
if overall_pass:
    print("The numerical solution matches the analytic solution within acceptable tolerances.")
else:
    print("The numerical solution does not meet acceptance criteria.")
    if not l2_pass:
        print(f"  - L2 error ({metrics['l2_error']:.2%}) exceeds threshold ({thresholds['l2_error']:.0%})")
    if not conservation_pass:
        print(f"  - Conservation error ({metrics['heat_conservation']:.2%}) exceeds threshold ({thresholds['heat_conservation']:.0%})")

---
## 6. Time Evolution Animation

Watch the temperature profile evolve over time, comparing numerical and analytic solutions.

In [None]:
# Create animation
anim = animate_overlay(
    z_np, history_T, analytic_solution, history_times,
    xlabel='z [m]', ylabel='Temperature [eV]',
    title='Heat Diffusion',
    initial=history_T[0],
    interval=100  # ms between frames
)

# Display in notebook
HTML(anim.to_jshtml())

---
## 7. Interactive Exploration

Use the widgets below to explore how different parameters affect the simulation.

**Note**: For interactive exploration, we run a shorter simulation for responsiveness.

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Output area for results
output = widgets.Output()

def run_exploration(kappa_val, sigma_val, nz_val):
    """Run a quick simulation with given parameters."""
    with output:
        clear_output(wait=True)
        
        # Create configuration with new parameters
        cfg = SlabDiffusionConfiguration(
            T_peak=T_peak,
            T_base=T_base,
            sigma=sigma_val,
            kappa=kappa_val,
            nz=int(nz_val),
            z_extent=z_extent,
        )
        
        # Build and run (shorter simulation for responsiveness)
        geom = cfg.build_geometry()
        state = cfg.build_initial_state(geom)
        mdl = cfg.build_model()
        slv = SemiImplicitSolver()
        
        # Quick run: 50 steps
        quick_dt = 1e-6
        quick_steps = 50
        quick_t_end = quick_dt * quick_steps
        
        for _ in range(quick_steps):
            state = slv.step(state, quick_dt, mdl, geom)
        
        # Extract results
        r_mid = geom.nr // 2
        z_vals = np.array(geom.z_grid[r_mid, :])
        T_num = np.array(state.T[r_mid, :])
        
        # Analytic solution with exploration parameters
        sigma_eff_sq = sigma_val**2 + 2 * kappa_val * quick_t_end
        amplitude = T_peak * np.sqrt(sigma_val**2 / sigma_eff_sq)
        T_ana = amplitude * np.exp(-z_vals**2 / (2 * sigma_eff_sq)) + T_base
        
        # Initial condition
        T_init = T_peak * np.exp(-z_vals**2 / (2 * sigma_val**2)) + T_base
        
        # Compute error
        err = compute_metrics(T_num, T_ana)
        
        # Plot
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        # Temperature profiles
        ax1.plot(z_vals, T_init, 'gray', linestyle=':', label='Initial')
        ax1.plot(z_vals, T_num, 'b-', linewidth=2, label='Numerical')
        ax1.plot(z_vals, T_ana, 'r--', linewidth=2, label='Analytic')
        ax1.set_xlabel('z [m]')
        ax1.set_ylabel('Temperature [eV]')
        ax1.set_title(f't = {quick_t_end:.2e} s')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Error
        ax2.plot(z_vals, np.abs(T_num - T_ana), 'purple', linewidth=2)
        ax2.set_xlabel('z [m]')
        ax2.set_ylabel('|Error| [eV]')
        ax2.set_title(f'L2 Error: {err["l2_error"]:.2%}')
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Show parameters
        tau_D_exp = sigma_val**2 / (2 * kappa_val)
        print(f"Parameters: κ = {kappa_val:.1e}, σ₀ = {sigma_val:.2f}, nz = {int(nz_val)}")
        print(f"Diffusion time: τ_D = {tau_D_exp:.2e} s")
        print(f"Simulation: {quick_steps} steps, t/τ_D = {quick_t_end/tau_D_exp:.3f}")

# Create widgets
kappa_slider = widgets.FloatLogSlider(
    value=1e-3, base=10, min=-5, max=-1, step=0.5,
    description='κ [m²/s]:', readout_format='.1e'
)
sigma_slider = widgets.FloatSlider(
    value=0.3, min=0.1, max=0.8, step=0.1,
    description='σ₀ [m]:'
)
nz_slider = widgets.IntSlider(
    value=128, min=32, max=256, step=32,
    description='nz:'
)

# Interactive widget
ui = widgets.VBox([
    widgets.HTML('<h4>Parameter Exploration</h4>'),
    widgets.HTML('<p>Adjust parameters and see how the solution changes:</p>'),
    kappa_slider,
    sigma_slider,
    nz_slider,
])

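# interactive_output wires the sliders to run_exploration. Because run_exploration
# draws into the `output` widget defined above, display that widget rather than
# the Output returned here, which stays empty.
widgets.interactive_output(
    run_exploration,
    {'kappa_val': kappa_slider, 'sigma_val': sigma_slider, 'nz_val': nz_slider}
)

display(ui, output)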

---
## Summary

This notebook demonstrated:

1. **Physics**: The heat equation describes diffusive spreading of temperature profiles. A Gaussian initial condition remains Gaussian but spreads as $\sigma(t) = \sqrt{\sigma_0^2 + 2\kappa t}$.

2. **Simulation setup**: Using `SlabDiffusionConfiguration` to create a 1D diffusion test case with known analytic solution.

3. **Validation**: Comparing numerical results to the exact solution using L2 error and conservation metrics.

4. **Key insight**: The numerical scheme accurately captures diffusive dynamics when:
   - Grid resolution is sufficient to resolve the initial profile
   - Timestep is small enough for both stability and accuracy
   - Boundary conditions don't interfere (domain large enough)
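The timestep condition can be made concrete with the diffusion number $r = \kappa\,\Delta t/\Delta z^2$. For the fully explicit forward-Euler (FTCS) scheme, stability requires $r \le 1/2$; implicit and semi-implicit schemes relax this, so for them $r$ is more a guide to accuracy than to stability. A short sketch for the notebook defaults, assuming $\Delta z = 2\,z_\text{extent}/(n_z - 1)$ (the printed `geometry.dz` is authoritative if the grid is laid out differently); `diffusion_number` is a hypothetical helper.

```python
def diffusion_number(kappa, dt, nz, z_extent):
    """r = kappa*dt/dz^2, assuming nz uniformly spaced nodes on [-z_extent, z_extent]."""
    dz = 2.0 * z_extent / (nz - 1)
    return kappa * dt / dz**2, dz

r_default, dz_default = diffusion_number(kappa=1e-3, dt=1e-6, nz=128, z_extent=2.0)
dt_explicit_max = 0.5 * dz_default**2 / 1e-3   # FTCS limit r <= 1/2
print(f"dz = {dz_default:.4f} m, r = {r_default:.1e}")
print(f"largest stable forward-Euler step: {dt_explicit_max:.3f} s")
```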

### Try next

- Increase `kappa` to see faster diffusion
- Reduce `nz` to see resolution effects
- Run longer (`t_end` larger) to see more spreading
- Compare different solvers (EulerSolver vs SemiImplicitSolver)