In [None]:
# Install the package directly from GitHub
!pip install git+https://github.com/wcw100168/Cubed-Sphere-DG-Solver.git

# Advanced Physics: Shallow Water Equations

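This tutorial runs a full physics simulation based on **Williamson Case 2** (Global Steady State Zonal Flow).
We will use the **JAX** backend for this demonstration to show how to enable hardware acceleration.
Note: for simplicity, the initial condition below sets only the height field and leaves the momentum at zero, so it approximates Case 2 rather than reproducing its balanced state exactly.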

In [None]:
# Enable JAX x64 mode BEFORE other imports.
# Environment variable names are case-sensitive on POSIX systems, so the
# flag must be spelled in upper case for JAX to pick it up.
import os
os.environ["JAX_ENABLE_X64"] = "True"

# Also set the flag through the config API so double precision is enabled
# even if the environment variable is ignored.
import jax
jax.config.update("jax_enable_x64", True)

import numpy as np
import matplotlib.pyplot as plt
from cubed_sphere.solvers import CubedSphereSWE, SWEConfig

In [None]:
# 1. Stability & Config
# Resolution parameter, sphere radius, mean depth and gravity
# (presumably SI units: the printed wave speed is labelled m/s).
N, R, H_avg, g = 16, 6.37122e6, 10000.0, 9.80616

# Fastest signal speed = gravity-wave speed + peak background flow speed.
u_flow_max = 40.0  # approx for Case 2
c_wave = np.sqrt(g * H_avg)
v_max = c_wave + u_flow_max

# Time step from the target CFL number, using R / N**2 as the length scale.
target_cfl = 0.5
dt = target_cfl * R / (v_max * N**2)

print(
    f"Wavespeed: {c_wave:.1f} m/s",
    f"Stable dt: {dt:.2f} s",
    sep="\n",
)

# Gather the solver settings in one place, then build the solver from them.
swe_settings = dict(
    N=N,
    backend='jax',
    H_avg=H_avg,
    gravity=g,
    dt=dt,
)
config = SWEConfig(**swe_settings)
solver = CubedSphereSWE(config)

In [None]:
# Setup Williamson Case 2 Initial Condition
# (Manual setup for tutorial demonstration)
R, Omega, g = config.R, config.Omega, config.gravity
u0_vel = 2.0 * np.pi * config.R / (12.0 * 24.0 * 3600.0) # ~38 m/s

# State layout: (variable, face, grid point, grid point) with 3 variables
# on 6 faces and N + 1 points per direction.
grid_size = config.N + 1
state = np.zeros((3, 6, grid_size, grid_size))

# NOTE: Accessing internal implementation details for tutorial setup
faces = solver._impl.faces
topo = solver._impl.topology

# Latitude-independent factor of the height profile, computed once.
height_coeff = R * Omega * u0_vel + 0.5 * u0_vel**2

for face_idx, face_name in enumerate(topo.FACE_MAP):
    face_grid = faces[face_name]
    lat = face_grid.lat  # presumably latitude in radians (fed to np.sin)

    depth = config.H_avg - height_coeff * (np.sin(lat)**2) / g

    # Variable 0 holds the height scaled by the face's sqrt_g factor.
    state[0, face_idx] = depth * face_grid.sqrt_g
    # Momentum approx zero/steady for tutorial setup simplicity
    state[1:3, face_idx] = 0.0

In [None]:
# Run Simulation
n_steps = 20
t_accumulated = 0.0

# Reference value for the later conservation check: plain sum of the
# first state component before any step is taken.
initial_mass = np.sum(state[0])
print("Starting Integration...")

# Advance the state by n_steps fixed-size steps, tracking elapsed time.
steps_done = 0
while steps_done < n_steps:
    state = solver.step(t_accumulated, state, dt)
    t_accumulated += dt
    steps_done += 1

print("Integration Complete.")

In [None]:
# Conservation Check
# Compare the summed first state component before and after integration.
# The sum is many orders of magnitude larger than 1, so an absolute
# tolerance of 1e-9 sits below float64 round-off; the check is therefore
# done on the relative error instead.
# NOTE(review): this is an unweighted nodal sum used as a proxy for total
# mass — confirm whether quadrature weights should be applied.
MASS_REL_TOL = 1e-9

# float() turns 0-d NumPy/JAX results into plain Python floats for formatting.
mass_before = float(initial_mass)
final_mass = float(np.sum(state[0]))

diff = abs(final_mass - mass_before)
# Fall back to the absolute error when the reference is exactly zero.
rel_diff = diff / abs(mass_before) if mass_before != 0.0 else diff

print(f"Initial Mass: {mass_before:.5e}")
print(f"Final Mass:   {final_mass:.5e}")
print(f"Mass Error:   {diff:.5e}")
print(f"Relative Err: {rel_diff:.5e}")
assert rel_diff < MASS_REL_TOL, "Mass should be conserved!"

In [None]:
# Plot Height Field
# state[0, 0] belongs to the first face in topo.FACE_MAP iteration order
# (the same enumeration used when the state was filled), so the matching
# face is looked up through that order instead of a hard-coded key.
first_face = next(iter(topo.FACE_MAP))

# Divide out the sqrt_g scaling applied at setup to recover the height.
h_field = state[0, 0] / faces[first_face].sqrt_g

plt.figure(figsize=(8, 6))
plt.imshow(h_field, extent=[-1, 1, -1, 1])
plt.colorbar(label='Geopotential Height (m)')
plt.title(f"Block 0 Height Field at t={t_accumulated:.1f}s")
plt.show()