# MHD Shock Tube in Cylindrical Geometry

**Validation case**: Brio-Wu MHD shock tube adapted to cylindrical coordinates

## Learning Objectives

After completing this notebook, you will understand:

1. How discontinuous initial conditions evolve into shock waves in MHD
2. The characteristic wave structure: fast shock, slow shock, contact, rarefaction
3. Rankine-Hugoniot jump conditions relating upstream/downstream states
4. How cylindrical geometry affects wave propagation compared to Cartesian
5. Numerical challenges in capturing discontinuities

---
## 1. Physics Background

### The Riemann Problem

A **Riemann problem** is an initial value problem with piecewise constant initial data separated by a discontinuity:

$$\mathbf{U}(z, t=0) = \begin{cases} \mathbf{U}_L & z < 0 \\ \mathbf{U}_R & z > 0 \end{cases}$$

where $\mathbf{U} = (\rho, \mathbf{v}, p, \mathbf{B})$ is the vector of primitive MHD variables (density, velocity, pressure, magnetic field).
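Piecewise-constant Riemann data can be built with `np.where`, as in the minimal NumPy sketch below. The helper `riemann_initial` and the dict-of-scalars state layout are illustrative assumptions, not jax-frc's API. The point $z = z_0$ itself is assigned to the right state.

```python
import numpy as np

def riemann_initial(z, U_L, U_R, z0=0.0):
    """Piecewise-constant data: U_L where z < z0, U_R elsewhere (hypothetical helper)."""
    left = z < z0
    return {name: np.where(left, U_L[name], U_R[name]) for name in U_L}

z_demo = np.linspace(-0.5, 0.5, 8)  # no grid point lands exactly on z0 = 0
U_demo = riemann_initial(z_demo, {"rho": 1.0, "p": 1.0}, {"rho": 0.125, "p": 0.1})
print(U_demo["rho"])
```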

The Riemann problem is fundamental because:
1. It tests the ability of numerical schemes to capture discontinuities
2. It isolates wave physics without the complexity of smooth initial data
3. The solution structure is well-understood analytically

### MHD Wave Types

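In ideal MHD, there are three characteristic wave speeds. Each family propagates in both directions, and together with the entropy wave that moves with the flow they give seven waves in total:

| Wave | Speed | Physics |
|------|-------|---------|
| **Fast magnetosonic** | $c_f = \sqrt{\frac{1}{2}\left(c_s^2 + v_A^2 + \sqrt{(c_s^2 + v_A^2)^2 - 4c_s^2 v_{Az}^2}\right)}$ | Gas and magnetic pressure act in phase |
| **Alfvén** | $c_A = v_{Az} = B_z/\sqrt{\mu_0 \rho}$ | Field-line bending (no compression) |
| **Slow magnetosonic** | $c_{sl} = \sqrt{\frac{1}{2}\left(c_s^2 + v_A^2 - \sqrt{(c_s^2 + v_A^2)^2 - 4c_s^2 v_{Az}^2}\right)}$ | Gas and magnetic pressure act in antiphase |

where $c_s = \sqrt{\gamma p/\rho}$ is the sound speed, $v_A = |\mathbf{B}|/\sqrt{\mu_0 \rho}$ is the Alfvén speed based on the total field, and $v_{Az} = B_z/\sqrt{\mu_0 \rho}$ uses only the field component along the propagation direction $z$. The speeds are ordered $c_{sl} \le c_A \le c_f$.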
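The sketch below evaluates these formulas. It assumes units with $\mu_0 = 1$ (the convention the notebook's speed printout uses) and propagation along $z$. `mhd_speeds` is a hypothetical helper, not part of jax-frc. For the Brio-Wu left state it shows the expected ordering of the three speeds:

```python
import numpy as np

def mhd_speeds(rho, p, B_perp, B_par, gamma, mu0=1.0):
    """Return (c_slow, c_A, c_fast) for propagation along the direction of B_par."""
    c_s2 = gamma * p / rho                       # sound speed squared
    v_A2 = (B_perp**2 + B_par**2) / (mu0 * rho)  # Alfvén speed squared, total field
    v_Az2 = B_par**2 / (mu0 * rho)               # Alfvén speed squared, along propagation
    total = c_s2 + v_A2
    root = np.sqrt(total**2 - 4.0 * c_s2 * v_Az2)
    c_fast = np.sqrt(0.5 * (total + root))
    c_slow = np.sqrt(0.5 * (total - root))
    return c_slow, np.sqrt(v_Az2), c_fast

# Brio-Wu left state: rho = 1, p = 1, B_r = 1 (transverse), B_z = 0.75 (along z)
c_slow_L, c_A_L, c_fast_L = mhd_speeds(rho=1.0, p=1.0, B_perp=1.0, B_par=0.75, gamma=2.0)
print(f"left: slow={c_slow_L:.3f}, Alfven={c_A_L:.3f}, fast={c_fast_L:.3f}")
```

The identities $c_f^2 + c_{\text{slow}}^2 = c_s^2 + v_A^2$ and $c_f^2\, c_{\text{slow}}^2 = c_s^2 v_{Az}^2$ follow directly from the formulas and make convenient sanity checks.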

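Waves from these families can appear as:
- **Shock** (fast or slow): Compressive discontinuity (entropy increases across it)
- **Rarefaction** (fast or slow): Smooth expansion fan
- **Contact** (the entropy wave, moving with the flow): Density/temperature jump with continuous pressure and velocity
- **Rotational discontinuity** (Alfvén): Rotates $\mathbf{B}_\perp$ without compressing the plasma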

### Rankine-Hugoniot Jump Conditions

At a shock moving with speed $s$, conservation laws require:

$$\rho_L (v_L - s) = \rho_R (v_R - s) \quad \text{(mass)}$$

$$\rho_L (v_L - s) v_L + p_L + \frac{B_L^2}{2\mu_0} = \rho_R (v_R - s) v_R + p_R + \frac{B_R^2}{2\mu_0} \quad \text{(momentum)}$$

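$$s\,[\![\mathbf{B}_\perp]\!] = [\![v_n \mathbf{B}_\perp - B_n \mathbf{v}_\perp]\!] \quad \text{(induction)}$$

Tangential momentum and energy conservation complete the set, and $\nabla \cdot \mathbf{B} = 0$ gives $[\![B_n]\!] = 0$. The notation $[\![X]\!] = X_R - X_L$ denotes the jump. Here $v = v_n$ and $B_n$ are the components normal to the shock (for this problem $v_z$ and $B_z$), and $\mathbf{v}_\perp$, $\mathbf{B}_\perp$ are the tangential components. Because $B_n$ is continuous, the magnetic tension term $B_n^2/\mu_0$ cancels from the normal momentum balance, leaving only the magnetic pressure $B^2/2\mu_0$.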

These relations determine the downstream state given the upstream state and shock speed.
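The sketch below evaluates the mass, normal-momentum, and tangential-induction conditions as residuals. It assumes $\mu_0 = 1$ and a single tangential component; `rh_residuals` and the state keys are hypothetical. Induction is written as $s[\![B_t]\!] = [\![v_n B_t - B_n v_t]\!]$, the jump form of $\partial_t B_t + \partial_n(v_n B_t - B_n v_t) = 0$. A stationary contact discontinuity should satisfy all three with $s = 0$:

```python
import numpy as np

def rh_residuals(UL, UR, s, mu0=1.0):
    """Residuals of the mass, normal-momentum and tangential-induction jump conditions."""
    def jump(f):
        return f(UR) - f(UL)

    mass = jump(lambda U: U["rho"] * (U["vn"] - s))
    # B_n is continuous, so using the full |B|^2 magnetic pressure is consistent
    momentum = jump(lambda U: U["rho"] * (U["vn"] - s) * U["vn"]
                    + U["p"] + (U["Bn"]**2 + U["Bt"]**2) / (2 * mu0))
    induction = (s * jump(lambda U: U["Bt"])
                 - jump(lambda U: U["vn"] * U["Bt"] - U["Bn"] * U["vt"]))
    return np.array([mass, momentum, induction])

# Stationary contact: density jumps; pressure, velocity and field are continuous
UL = {"rho": 1.0, "vn": 0.0, "vt": 0.0, "p": 1.0, "Bn": 0.75, "Bt": 0.5}
UR = {"rho": 0.2, "vn": 0.0, "vt": 0.0, "p": 1.0, "Bn": 0.75, "Bt": 0.5}
res = rh_residuals(UL, UR, s=0.0)

# Halving the right pressure breaks the normal-momentum balance
res_bad = rh_residuals(UL, dict(UR, p=0.5), s=0.0)
print(res, res_bad)
```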

### The Brio-Wu Problem

The Brio-Wu test (1988) is a standard MHD shock tube with:

| State | $\rho$ | $p$ | $B_r$ | $B_z$ |
|-------|--------|-----|-------|-------|
| Left ($z < 0$) | 1.0 | 1.0 | 1.0 | 0.75 |
| Right ($z > 0$) | 0.125 | 0.1 | -1.0 | 0.75 |

Both states start at rest ($v = 0$). As in the original paper, $\gamma = 2$ (not the usual 5/3) for numerical convenience.
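The two states are far from pressure balance, and this imbalance drives the waves. The quick check below assumes $\mu_0 = 1$, so magnetic pressure is $|\mathbf{B}|^2/2$. The dict layout is illustrative:

```python
# Brio-Wu states from the table (mu0 = 1, both states at rest)
left = {"rho": 1.0, "p": 1.0, "Br": 1.0, "Bz": 0.75}
right = {"rho": 0.125, "p": 0.1, "Br": -1.0, "Bz": 0.75}

def total_pressure(state, mu0=1.0):
    """Gas plus magnetic pressure p + |B|^2 / (2 mu0)."""
    return state["p"] + (state["Br"]**2 + state["Bz"]**2) / (2 * mu0)

P_L = total_pressure(left)
P_R = total_pressure(right)
print(f"total pressure: left {P_L:.5f}, right {P_R:.5f}")
```

The left side has roughly twice the total pressure of the right. $B_z$, the component normal to the interface, is equal on both sides, as $\nabla \cdot \mathbf{B} = 0$ requires.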

The solution develops (from left to right):
1. **Fast rarefaction** propagating left
2. **Slow compound wave** (shock + rarefaction) propagating left
3. **Contact discontinuity** near $z \approx 0$
4. **Slow shock** propagating right
5. **Fast rarefaction** propagating right
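The heads of the two fast rarefactions run into undisturbed plasma at rest, so they move at $\mp c_f$ of the left and right states. The sketch below estimates their positions. It assumes $\mu_0 = 1$, an interface at $z = 0$, and solver time units that match these formulas; `fast_speed` is a hypothetical helper:

```python
import numpy as np

def fast_speed(rho, p, B_perp, B_par, gamma, mu0=1.0):
    """Fast magnetosonic speed for propagation along B_par."""
    c_s2 = gamma * p / rho
    v_A2 = (B_perp**2 + B_par**2) / (mu0 * rho)
    v_Az2 = B_par**2 / (mu0 * rho)
    total = c_s2 + v_A2
    return np.sqrt(0.5 * (total + np.sqrt(total**2 - 4.0 * c_s2 * v_Az2)))

t_ref = 0.1
c_f_L = fast_speed(1.0, 1.0, 1.0, 0.75, gamma=2.0)     # left state
c_f_R = fast_speed(0.125, 0.1, -1.0, 0.75, gamma=2.0)  # right state
z_head_left = -c_f_L * t_ref   # leading edge of the left-moving fast rarefaction
z_head_right = c_f_R * t_ref   # leading edge of the right-moving fast rarefaction
print(f"fast-wave heads at t={t_ref}: z ≈ {z_head_left:.3f} and z ≈ {z_head_right:.3f}")
```

These estimates give an independent sanity check on any reference positions used for the fast waves.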

### Cylindrical Geometry Effects

In cylindrical coordinates $(r, \phi, z)$, the divergence operator introduces **geometric source terms**:

$$\nabla \cdot \mathbf{v} = \frac{1}{r}\frac{\partial(r v_r)}{\partial r} + \frac{1}{r}\frac{\partial v_\phi}{\partial \phi} + \frac{\partial v_z}{\partial z}$$

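For a problem uniform in $r$ and $\phi$, the $\phi$-derivative drops out, but the radial term does not vanish in general:

$$\frac{1}{r}\frac{\partial(r v_r)}{\partial r} = \frac{\partial v_r}{\partial r} + \frac{v_r}{r},$$

which is nonzero whenever $v_r \neq 0$. The coordinate system therefore affects:

1. **Magnetic field structure**: With $\partial B_z/\partial z = 0$, $\nabla \cdot \mathbf{B} = 0$ requires $r B_r = \text{const}$, so a radially uniform $B_r$ carries a divergence $B_r/r$
2. **Geometric source terms**: Terms proportional to $1/r$ modify the dynamics and are largest near the axis
3. **Boundary conditions**: The solver must handle the coordinate singularity at the axis ($r=0$)

Our test is designed to be r-independent, so it is intended to reproduce 1D physics while exercising the cylindrical solver framework. Comparing profiles taken at different radii is a simple check that the geometric terms above are not contaminating the solution.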
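The check below computes the radial part of the divergence numerically for an axisymmetric field with $\partial B_z/\partial z = 0$. A radially uniform $B_r$ leaves $\nabla \cdot \mathbf{B} = B_r/r$, while $B_r \propto 1/r$ is divergence-free. `div_B_axisymmetric` is a hypothetical helper built on `np.gradient`, not a jax-frc function. The grid starts at $r = 0.1$ to stay off the axis.

```python
import numpy as np

def div_B_axisymmetric(r, Br, dBz_dz=0.0):
    """Cylindrical divergence (1/r) d(r Br)/dr + dBz/dz for fields with no phi dependence."""
    return np.gradient(r * Br, r) / r + dBz_dz

r_demo = np.linspace(0.1, 1.0, 200)  # avoid r = 0, where 1/r is singular
div_uniform = div_B_axisymmetric(r_demo, np.ones_like(r_demo))  # Br = 1 everywhere
div_inverse_r = div_B_axisymmetric(r_demo, 0.1 / r_demo)        # Br proportional to 1/r
print(div_uniform[:3], div_inverse_r[:3])
```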

---
## 2. Configuration Setup

In [None]:
# Standard imports
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML

# jax-frc imports
from jax_frc.configurations.brio_wu import BrioWuConfiguration
from jax_frc.solvers.explicit import RK4Solver

# Notebook utilities
import sys
sys.path.insert(0, '.')
from _shared import plot_comparison, compute_metrics, metrics_summary, animate_overlay

# Plotting style
%matplotlib inline
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 11

In [None]:
# === PHYSICS PARAMETERS (Brio-Wu standard) ===
# Reference values: BrioWuConfiguration below only receives nz, nr and gamma

# Left state (z < 0)
rho_L = 1.0      # Density
p_L = 1.0        # Pressure
Br_L = 1.0       # Radial magnetic field

# Right state (z > 0)
rho_R = 0.125    # Lower density (8x less)
p_R = 0.1        # Lower pressure (10x less)
Br_R = -1.0      # Reversed Br (magnetic shear)

# Common parameters
Bz = 0.75        # Guide field (constant across interface)
gamma = 2.0      # Adiabatic index

# === GRID PARAMETERS ===

nz = 512         # Axial resolution (high for shock capturing)
nr = 16          # Minimal radial (r-uniform problem)

# === TIME PARAMETERS ===

t_end = 0.1      # End time (Alfvén units)
dt = 1e-4        # Timestep

# Compute characteristic speeds (units with mu0 = 1; v_A uses the total field |B|)
c_s_L = np.sqrt(gamma * p_L / rho_L)  # Sound speed left
v_A_L = np.sqrt(Br_L**2 + Bz**2) / np.sqrt(rho_L)  # Alfvén speed left
print(f"Left state: c_s = {c_s_L:.2f}, v_A = {v_A_L:.2f}")

c_s_R = np.sqrt(gamma * p_R / rho_R)
v_A_R = np.sqrt(Br_R**2 + Bz**2) / np.sqrt(rho_R)
print(f"Right state: c_s = {c_s_R:.2f}, v_A = {v_A_R:.2f}")

In [None]:
# Create configuration
config = BrioWuConfiguration(
    nz=nz,
    nx=nr,
    gamma=gamma,
)

# Build simulation components
geometry = config.build_geometry()
initial_state = config.build_initial_state(geometry)
model = config.build_model()
solver = RK4Solver()  # 4th-order time integration; shock capturing itself depends on the spatial scheme

print(f"Grid: {geometry.nx} × {geometry.nz} (x × z)")
print(f"Domain: z ∈ [{geometry.z_min:.1f}, {geometry.z_max:.1f}]")
print(f"Resolution: Δz = {geometry.dz:.4f}")
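Explicit schemes keep the Courant number (signal speed × Δt / Δz) below a scheme-dependent limit of order one. The minimal estimate below assumes $\mu_0 = 1$ and a hypothetical unit-length domain with 512 cells. In the notebook, `geometry.dz` gives the actual spacing. `fast_speed` and `cfl_number` are illustrative helpers:

```python
import numpy as np

def fast_speed(rho, p, B_perp, B_par, gamma, mu0=1.0):
    """Fast magnetosonic speed for propagation along B_par."""
    c_s2 = gamma * p / rho
    v_A2 = (B_perp**2 + B_par**2) / (mu0 * rho)
    v_Az2 = B_par**2 / (mu0 * rho)
    total = c_s2 + v_A2
    return np.sqrt(0.5 * (total + np.sqrt(total**2 - 4.0 * c_s2 * v_Az2)))

def cfl_number(signal_speed, dt, dz):
    """Courant number: fraction of a cell a signal crosses in one step."""
    return signal_speed * dt / dz

# Initially v = 0, so the fastest signal is the larger of the two fast speeds.
# During the run, max(|v_z| + c_f) is the relevant speed, since flow adds to it.
c_max = max(fast_speed(1.0, 1.0, 1.0, 0.75, 2.0), fast_speed(0.125, 0.1, -1.0, 0.75, 2.0))
dz_assumed = 1.0 / 512  # hypothetical spacing; replace with geometry.dz
cfl = cfl_number(c_max, dt=1e-4, dz=dz_assumed)
print(f"max fast speed = {c_max:.3f}, CFL = {cfl:.3f}")
```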

### Visualize Initial Conditions

In [None]:
# Extract 1D profiles at mid-radius
r_mid = initial_state.n.shape[0] // 2  # radial index (state arrays are indexed [r, z])
z = np.array(geometry.z_grid[r_mid, :])

# Initial state
rho_init = np.array(initial_state.n[r_mid, :])
p_init = np.array(initial_state.p[r_mid, :])
Br_init = np.array(initial_state.B[r_mid, :, 0])
Bz_init = np.array(initial_state.B[r_mid, :, 2])

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Density
axes[0, 0].plot(z, rho_init, 'b-', linewidth=2)
axes[0, 0].set_ylabel('Density ρ')
axes[0, 0].set_title('Density')
axes[0, 0].grid(True, alpha=0.3)

# Pressure
axes[0, 1].plot(z, p_init, 'r-', linewidth=2)
axes[0, 1].set_ylabel('Pressure p')
axes[0, 1].set_title('Pressure')
axes[0, 1].grid(True, alpha=0.3)

# Br (radial magnetic field)
axes[1, 0].plot(z, Br_init, 'g-', linewidth=2)
axes[1, 0].axhline(0, color='gray', linestyle='--', alpha=0.5)
axes[1, 0].set_xlabel('z')
axes[1, 0].set_ylabel('$B_r$')
axes[1, 0].set_title('Radial B-field (reverses at z=0)')
axes[1, 0].grid(True, alpha=0.3)

# Bz (guide field)
axes[1, 1].plot(z, Bz_init, 'm-', linewidth=2)
axes[1, 1].set_xlabel('z')
axes[1, 1].set_ylabel('$B_z$')
axes[1, 1].set_title('Guide field (constant)')
axes[1, 1].set_ylim([0, 1])
axes[1, 1].grid(True, alpha=0.3)

fig.suptitle('Initial Conditions: Brio-Wu Shock Tube', fontsize=14)
plt.tight_layout()
plt.show()

---
## 3. Run Simulation

We evolve the initial discontinuity forward in time. MHD waves will propagate outward from $z=0$.

In [None]:
# Time stepping
n_steps = int(round(t_end / dt))  # round() guards against floating-point truncation
save_interval = max(1, n_steps // 50)  # ~50 snapshots

print(f"Running {n_steps} steps with dt = {dt:.2e}")
print(f"Saving every {save_interval} steps")

In [None]:
# Run simulation
state = initial_state

# Store history for animation
history = {
    'rho': [np.array(state.n[r_mid, :])],
    'p': [np.array(state.p[r_mid, :])],
    'Br': [np.array(state.B[r_mid, :, 0])],
    'times': [0.0]
}

for step in range(n_steps):
    state = solver.step(state, dt, model, geometry)
    
    # Check for NaN in density or pressure (shock simulations can be unstable)
    if jnp.any(jnp.isnan(state.n)) or jnp.any(jnp.isnan(state.p)):
        print(f"WARNING: NaN detected at step {step+1}")
        break
    
    # Save snapshot
    if (step + 1) % save_interval == 0:
        history['rho'].append(np.array(state.n[r_mid, :]))
        history['p'].append(np.array(state.p[r_mid, :]))
        history['Br'].append(np.array(state.B[r_mid, :, 0]))
        history['times'].append(state.time)
        
        if (step + 1) % max(1, n_steps // 5) == 0:
            print(f"  Step {step+1}/{n_steps} (t = {state.time:.3f})")

# Save final state
if history['times'][-1] != state.time:
    history['rho'].append(np.array(state.n[r_mid, :]))
    history['p'].append(np.array(state.p[r_mid, :]))
    history['Br'].append(np.array(state.B[r_mid, :, 0]))
    history['times'].append(state.time)

print(f"Simulation complete: t = {state.time:.3f}")

# Convert to arrays
for key in ['rho', 'p', 'Br', 'times']:
    history[key] = np.array(history[key])

---
## 4. Expected Solution Structure

The Brio-Wu problem has a well-known wave structure. Let's define the expected approximate positions of key features at $t = 0.1$.

In [None]:
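# Expected wave positions at t = 0.1 (approximate, from literature), measured
# from the initial interface at z = 0. They depend on the domain, the
# normalization (mu0 = 1 assumed) and t_end; check them against an exact or
# high-resolution reference solution before treating them as quantitative targets.

expected_features = {
    'fast_rarefaction_head': -0.45,  # Leftward fast rarefaction front
    'slow_compound': -0.15,          # Slow compound wave region
    'contact': 0.02,                 # Contact discontinuity (density jump)
    'slow_shock': 0.28,              # Slow shock position
    'fast_rarefaction_right': 0.45,  # Rightward fast wave (a rarefaction in Brio-Wu)
}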

print("Expected wave positions at t = 0.1:")
for name, pos in expected_features.items():
    print(f"  {name}: z ≈ {pos:.2f}")

In [None]:
# Plot final state with expected positions
rho_final = history['rho'][-1]
p_final = history['p'][-1]
Br_final = history['Br'][-1]

fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)

# Density
axes[0].plot(z, rho_init, 'gray', linestyle=':', alpha=0.7, label='Initial')
axes[0].plot(z, rho_final, 'b-', linewidth=2, label=f'Numerical (t={t_end})')
axes[0].set_ylabel('Density ρ', fontsize=12)
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Mark expected features
for name, pos in expected_features.items():
    axes[0].axvline(pos, color='red', linestyle='--', alpha=0.4)

# Pressure
axes[1].plot(z, p_init, 'gray', linestyle=':', alpha=0.7, label='Initial')
axes[1].plot(z, p_final, 'r-', linewidth=2, label=f'Numerical (t={t_end})')
axes[1].set_ylabel('Pressure p', fontsize=12)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

for name, pos in expected_features.items():
    axes[1].axvline(pos, color='red', linestyle='--', alpha=0.4)

# Magnetic field Br
axes[2].plot(z, Br_init, 'gray', linestyle=':', alpha=0.7, label='Initial')
axes[2].plot(z, Br_final, 'g-', linewidth=2, label=f'Numerical (t={t_end})')
axes[2].axhline(0, color='gray', linestyle='-', alpha=0.3)
axes[2].set_xlabel('z', fontsize=12)
axes[2].set_ylabel('$B_r$', fontsize=12)
axes[2].legend()
axes[2].grid(True, alpha=0.3)

for name, pos in expected_features.items():
    axes[2].axvline(pos, color='red', linestyle='--', alpha=0.4)

# Add feature labels
axes[0].text(-0.45, 0.9, 'Fast\nRaref.', ha='center', fontsize=9, color='red')
axes[0].text(-0.15, 0.9, 'Slow\nWave', ha='center', fontsize=9, color='red')
axes[0].text(0.02, 0.9, 'Contact', ha='center', fontsize=9, color='red')
axes[0].text(0.28, 0.9, 'Slow\nShock', ha='center', fontsize=9, color='red')
axes[0].text(0.45, 0.9, 'Fast\nRaref.', ha='center', fontsize=9, color='red')

fig.suptitle(f'Brio-Wu Shock Tube at t = {t_end}', fontsize=14)
plt.tight_layout()
plt.show()

---
## 5. Comparison and Metrics

In [None]:
def find_shock_position(z, field):
    """Return the z location of the steepest gradient in `field`.

    A crude single-discontinuity locator: it returns only the strongest jump,
    so it cannot separate neighbouring waves.
    """
    # Compute gradient
    grad = np.gradient(field, z)
    
    # Find location of maximum |gradient|
    max_grad_idx = np.argmax(np.abs(grad))
    return z[max_grad_idx]

def find_discontinuity_positions(z, rho):
    """Find positions of major discontinuities in density profile."""
    from scipy.signal import find_peaks
    
    # Compute absolute gradient
    grad = np.abs(np.gradient(rho, z))
    
    # Find peaks in gradient (discontinuities)
    peaks, properties = find_peaks(grad, prominence=0.1, distance=10)
    
    return z[peaks], grad[peaks]

# Find discontinuities in final state
try:
    disc_positions, disc_strengths = find_discontinuity_positions(z, rho_final)
    print("Detected discontinuities at:")
    for pos, strength in zip(disc_positions, disc_strengths):
        print(f"  z = {pos:.3f} (gradient magnitude: {strength:.2f})")
except ImportError:
    print("scipy not available for peak detection")
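When SciPy is unavailable, a NumPy-only fallback can do a similar job by flagging local maxima of $|\partial\rho/\partial z|$ above a fraction of the largest gradient. The function name and the 20% threshold are illustrative assumptions, not tuned values. The demo uses a synthetic profile with two smeared jumps:

```python
import numpy as np

def gradient_peaks(z, f, rel_threshold=0.2):
    """Return z where |df/dz| has a local maximum above rel_threshold * max|df/dz|."""
    g = np.abs(np.gradient(f, z))
    # Strict on the left, non-strict on the right: a flat-topped peak is counted once
    is_local_max = (g[1:-1] > g[:-2]) & (g[1:-1] >= g[2:])
    is_strong = g[1:-1] > rel_threshold * g.max()
    idx = np.nonzero(is_local_max & is_strong)[0] + 1  # shift back to full-array indices
    return z[idx]

z_demo = np.linspace(-0.5, 0.5, 401)
step_a = 0.5 * (1 + np.tanh((z_demo + 0.2) / 0.01))  # rises 0 -> 1 around z = -0.2
step_b = 0.5 * (1 + np.tanh((z_demo - 0.3) / 0.01))  # rises 0 -> 1 around z = 0.3
rho_demo = 1.0 - 0.4 * step_a - 0.3 * step_b
found = gradient_peaks(z_demo, rho_demo)
print(found)
```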

In [None]:
# Validation metrics

# 1. Check for numerical stability (no NaN/Inf)
has_nan = np.any(np.isnan(rho_final)) or np.any(np.isnan(p_final))
has_inf = np.any(np.isinf(rho_final)) or np.any(np.isinf(p_final))
numerical_stability = not (has_nan or has_inf)

# 2. Check positivity (density and pressure must be positive)
min_rho = np.min(rho_final)
min_p = np.min(p_final)
positivity = (min_rho > 0) and (min_p > 0)

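# 3. Approximate conservation (total mass along z at mid-radius)
# np.trapezoid requires NumPy >= 2.0; fall back to np.trapz on older versions
trapezoid = getattr(np, "trapezoid", None) or np.trapz
mass_initial = trapezoid(rho_init, z)
mass_final = trapezoid(rho_final, z)
mass_conservation_error = abs(mass_final - mass_initial) / abs(mass_initial)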

# 4. Heuristic: look for a steep density gradient near z ≈ 0.45, where the
#    rightward fast wave is expected. In the Brio-Wu solution this wave is a
#    fast rarefaction, so a smooth profile here is not by itself a failure;
#    this check is informational and not part of the overall result.
z_fast_region = (z > 0.35) & (z < 0.55)
grad_in_region = np.abs(np.gradient(rho_final, z))[z_fast_region]
steep_fast_front = grad_in_region.size > 0 and np.max(grad_in_region) > 0.5

print("=" * 50)
print("VALIDATION RESULTS")
print("=" * 50)
print(f"Numerical stability:     {'✓ PASS' if numerical_stability else '✗ FAIL'}")
print(f"Positivity preserved:    {'✓ PASS' if positivity else '✗ FAIL'} (min ρ = {min_rho:.2e}, min p = {min_p:.2e})")
print(f"Mass conservation error: {mass_conservation_error:.2%} {'✓ PASS' if mass_conservation_error < 0.05 else '✗ FAIL'}")
print(f"Steep front near z≈0.45: {'found' if steep_fast_front else 'not found'} (informational)")
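The mass check above integrates $\rho$ along $z$ at a single radius. In cylindrical geometry the total mass is $M = 2\pi \iint \rho\, r\, dr\, dz$. For an r-uniform profile the two differ only by a constant factor, but the volume-weighted form is what a conservative cylindrical scheme should preserve. The sketch below applies the midpoint rule on cell centres. The grid, the names, and the cell-centred layout are illustrative assumptions; it only relies on arrays indexed `[r, z]` as in the notebook.

```python
import numpy as np

def cylindrical_mass(rho, r_centres, dr, dz):
    """Total mass 2*pi * sum(rho * r * dr * dz) for rho indexed [r, z] on cell centres."""
    return 2.0 * np.pi * np.sum(rho * r_centres[:, None]) * dr * dz

# Uniform density in a cylinder of radius R and length L: expect pi * R^2 * L * rho0
R, L, nr_demo, nz_demo, rho0 = 1.0, 2.0, 16, 64, 0.5
dr_demo, dz_demo = R / nr_demo, L / nz_demo
r_centres = (np.arange(nr_demo) + 0.5) * dr_demo
M = cylindrical_mass(np.full((nr_demo, nz_demo), rho0), r_centres, dr_demo, dz_demo)
print(M, np.pi * R**2 * L * rho0)
```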

In [None]:
# Overall result
overall_pass = numerical_stability and positivity and (mass_conservation_error < 0.05)

print(f"\nOverall result: {'✓ PASS' if overall_pass else '✗ FAIL'}")
if overall_pass:
    print("The simulation ran stably, kept ρ and p positive, and conserved mass to within 5%.")
else:
    print("The simulation encountered issues:")
    if not numerical_stability:
        print("  - NaN or Inf values detected")
    if not positivity:
        print("  - Negative density or pressure")
    if mass_conservation_error >= 0.05:
        print(f"  - Mass conservation error ({mass_conservation_error:.2%}) exceeds 5%")

---
## 6. Time Evolution Animation

In [None]:
from matplotlib.animation import FuncAnimation

# Create animation showing density evolution
fig, ax = plt.subplots(figsize=(10, 5))

line, = ax.plot(z, history['rho'][0], 'b-', linewidth=2)
ax.plot(z, history['rho'][0], 'gray', linestyle=':', alpha=0.5, label='Initial')

ax.set_xlim(z.min(), z.max())
ax.set_ylim(0, 1.2)
ax.set_xlabel('z', fontsize=12)
ax.set_ylabel('Density ρ', fontsize=12)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)

title = ax.set_title('Shock Evolution | t = 0.000', fontsize=14)

def update(frame):
    line.set_ydata(history['rho'][frame])
    title.set_text(f'Shock Evolution | t = {history["times"][frame]:.3f}')
    return line, title

anim = FuncAnimation(fig, update, frames=len(history['times']),
                     interval=100, blit=False)
plt.close(fig)

HTML(anim.to_jshtml())

---
## 7. Interactive Exploration

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

output = widgets.Output()

def explore_time(time_idx):
    """Show solution at a specific snapshot."""
    with output:
        clear_output(wait=True)
        
        fig, axes = plt.subplots(1, 3, figsize=(14, 4))
        
        t = history['times'][time_idx]
        
        # Density
        axes[0].plot(z, history['rho'][0], 'gray', linestyle=':', alpha=0.7, label='t=0')
        axes[0].plot(z, history['rho'][time_idx], 'b-', linewidth=2, label=f't={t:.3f}')
        axes[0].set_xlabel('z')
        axes[0].set_ylabel('ρ')
        axes[0].set_title('Density')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # Pressure
        axes[1].plot(z, history['p'][0], 'gray', linestyle=':', alpha=0.7, label='t=0')
        axes[1].plot(z, history['p'][time_idx], 'r-', linewidth=2, label=f't={t:.3f}')
        axes[1].set_xlabel('z')
        axes[1].set_ylabel('p')
        axes[1].set_title('Pressure')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        
        # Br
        axes[2].plot(z, history['Br'][0], 'gray', linestyle=':', alpha=0.7, label='t=0')
        axes[2].plot(z, history['Br'][time_idx], 'g-', linewidth=2, label=f't={t:.3f}')
        axes[2].axhline(0, color='gray', linestyle='-', alpha=0.3)
        axes[2].set_xlabel('z')
        axes[2].set_ylabel('$B_r$')
        axes[2].set_title('Radial B-field')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Time slider
time_slider = widgets.IntSlider(
    value=0, min=0, max=len(history['times'])-1, step=1,
    description='Snapshot:',
    continuous_update=False
)

ui = widgets.VBox([
    widgets.HTML('<h4>Time Evolution Explorer</h4>'),
    widgets.HTML('<p>Drag the slider to see the solution at different times:</p>'),
    time_slider,
])

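# Redraw into the `output` widget whenever the slider value changes
time_slider.observe(lambda change: explore_time(change['new']), names='value')
explore_time(time_slider.value)  # draw the first snapshot

display(ui, output)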

---
## Summary

This notebook demonstrated:

1. **Physics**: MHD shock tubes develop a rich wave structure including fast/slow shocks, rarefactions, and contact discontinuities, governed by the Rankine-Hugoniot jump conditions.

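2. **The Brio-Wu problem**: A standard test with magnetic field reversal that produces fast rarefactions, a slow shock, a contact discontinuity, and a slow compound wave, a feature with no counterpart in ordinary gas dynamics.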

3. **Validation approach**: Since exact analytic solutions are complex (requiring iterative Riemann solvers), we validate by checking:
   - Numerical stability (no NaN/Inf)
   - Positivity (ρ > 0, p > 0)
   - Conservation (total mass; momentum and energy are not checked here)
   - Qualitative wave structure (detected discontinuities compared with approximate reference positions)

4. **Cylindrical effects**: This r-uniform test is intended to be effectively 1D, so it exercises the solver's cylindrical geometry framework rather than genuinely radial dynamics.

### Key Challenges in Shock Capturing

- **Numerical diffusion**: Shocks spread over several grid cells
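- **Oscillations**: Godunov's theorem states that a linear monotone scheme is at most first-order accurate, so linear higher-order schemes oscillate at discontinuities while first-order monotone schemes smear them
- **Positivity**: Naive schemes can produce negative density or pressure
- **Conservation**: Schemes not written in conservative (flux) form may not conserve mass, momentum, and energy exactly, and can converge to wrong shock speeds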
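The Godunov-theorem trade-off shows up on the simplest model problem, linear advection of a square pulse on a periodic grid. First-order upwind is monotone, so it keeps values within the initial bounds but smears the edges. The second-order Lax-Wendroff scheme keeps edges sharper but overshoots. This is a generic illustration, not the scheme jax-frc uses. `advect_step` is a hypothetical helper:

```python
import numpy as np

def advect_step(u0, courant, n_steps, scheme):
    """Advect u0 to the right on a periodic grid with upwind or Lax-Wendroff."""
    u = u0.copy()
    c = courant
    for _ in range(n_steps):
        um = np.roll(u, 1)   # u_{i-1}
        up = np.roll(u, -1)  # u_{i+1}
        if scheme == "upwind":
            u = u - c * (u - um)  # convex combination for 0 <= c <= 1: monotone
        elif scheme == "lax-wendroff":
            u = u - 0.5 * c * (up - um) + 0.5 * c**2 * (up - 2 * u + um)
        else:
            raise ValueError(f"unknown scheme: {scheme}")
    return u

x_demo = np.linspace(0.0, 1.0, 200, endpoint=False)
pulse = np.where((x_demo > 0.2) & (x_demo < 0.5), 1.0, 0.0)  # two discontinuities
u_upwind = advect_step(pulse, 0.5, 100, "upwind")
u_lw = advect_step(pulse, 0.5, 100, "lax-wendroff")
print(f"upwind range       [{u_upwind.min():.3f}, {u_upwind.max():.3f}]")
print(f"Lax-Wendroff range [{u_lw.min():.3f}, {u_lw.max():.3f}]")
```

Both schemes are in flux form, so the total $\sum_i u_i$ is conserved exactly on the periodic grid; only the shape of the pulse differs.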

### Try Next

- Increase `nz` to see sharper shocks (but more expensive)
- Try different initial pressure ratios
- Compare RK4Solver vs EulerSolver stability