# Cubed-Sphere Advection Benchmark (NumPy vs JAX single/double)
此筆記本將 Cubed-Sphere DG 解法從 SWE 基準轉為純平流問題，並比較 NumPy、JAX Float32、JAX Float64 的效能。

In [None]:
# !pip install git+https://github.com/wcw100168/Cubed-Sphere-DG-Solver.git
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import time
import gc
from typing import Tuple, Dict, Any

from cubed_sphere.solvers import CubedSphereAdvectionSolver, AdvectionConfig

# Apply the seaborn dark-grid style sheet to every figure created afterwards.
plt.style.use('seaborn-v0_8-darkgrid')
# Report the platform (upper-cased) of the first device JAX sees.
# The printed label is Chinese for "JAX default device".
print(f"JAX 預設設備: {jax.devices()[0].platform.upper()}")

In [None]:
# --- Physical parameters ---
R_earth = 6371220.0  # Earth radius (m)
T_period = 12 * 24 * 3600.0  # rotation period: 12 days, in seconds
u0_velocity = (2 * np.pi * R_earth) / T_period  # circumference / period: peak speed of the solid-body rotation
T_final = 1.0 * T_period  # total simulated time: one full revolution

# Echo the derived values so the run log records the configuration.
print(f"Earth Radius (R): {R_earth} m")
print(f"Rotation Period (T): {T_period} s")
print(f"Max Velocity (u0): {u0_velocity:.2f} m/s")

# --- Initial condition: Gaussian distribution ---
def gaussian_hill(lon, lat):
    """Evaluate a Gaussian hill centred at ``(lon, lat) = (0, 0)``.

    Both arguments are angles in radians and may be scalars or NumPy arrays.
    The distance from the centre is measured in the plain lon/lat plane (not
    along a great circle) and is scaled by a width of 0.5 rad.

    Returns ``exp(-(r / 0.5) ** 2)`` with ``r = sqrt(lon**2 + lat**2)``.
    """
    width = 0.5
    planar_distance = np.sqrt(lon ** 2 + lat ** 2)
    scaled_distance = planar_distance / width
    return np.exp(-(scaled_distance ** 2))

In [None]:
def _time_solver_run(solver, state, dt, num_steps, is_jax):
    """Return the wall-clock seconds needed to advance ``state`` by ``num_steps`` steps.

    For the JAX backend the solver's ``solve`` method is used when it exists,
    otherwise the run falls back to an explicit ``step`` loop. JAX work is
    warmed up once before timing so that compilation is not measured, and the
    result is synchronised with ``jax.block_until_ready`` before the clock stops.
    The NumPy backend is timed with a plain ``step`` loop and no warm-up.
    """
    def run_step_loop():
        # The NumPy path works on a private copy of the initial state.
        current = state if is_jax else state.copy()
        t_current = 0.0
        for _ in range(num_steps):
            current = solver.step(t_current, current, dt)
            t_current += dt
        return current

    # Decide once which entry point is timed, so that the warm-up and the
    # measured run always exercise the same code path.
    use_solve = is_jax and hasattr(solver, 'solve')
    if use_solve:
        total_sim_time = dt * num_steps

        def run():
            return solver.solve((0, total_sim_time), state, callbacks=None)
    else:
        run = run_step_loop

    if is_jax:
        # Warm-up: a full solve when available, otherwise a single step.
        if use_solve:
            jax.block_until_ready(run())
        else:
            jax.block_until_ready(solver.step(0.0, state, dt))

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    result = run()
    if is_jax:
        # JAX dispatches asynchronously; wait for the result before stopping the clock.
        jax.block_until_ready(result)
    return time.perf_counter() - start


def benchmark_advection_scalability(
    n_values: Tuple[int, ...] = (16, 32, 64, 96, 128),
    num_steps: int = 100,
) -> pd.DataFrame:
    """Time the advection solver for several resolutions and backends.

    Three configurations are measured in order: NumPy, JAX with x64 disabled
    and JAX with x64 enabled. For every resolution ``N`` a solver is built
    with CFL 0.5, initialised with ``gaussian_hill`` and advanced by
    ``num_steps`` steps of size ``solver.compute_safe_dt(...)``.

    Args:
        n_values: Resolutions ``N`` to benchmark (any iterable of ints).
        num_steps: Number of time steps timed per configuration; must be >= 1.

    Returns:
        A DataFrame with one row per (backend, N) and the columns ``Backend``,
        ``N``, ``Grid_Points``, ``dt_safe``, ``Time_per_Step_ms`` and
        ``Throughput_TPS``.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.

    Note:
        The function toggles the global ``jax_enable_x64`` flag and leaves it
        at the value used by the last JAX configuration.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    backends = [
        {'name': 'numpy', 'backend_str': 'numpy', 'jax_x64': False},
        {'name': 'jax_single', 'backend_str': 'jax', 'jax_x64': False},
        {'name': 'jax_double', 'backend_str': 'jax', 'jax_x64': True},
    ]
    results = []

    print(f"--- 開始 Advection Scalability Benchmark ({num_steps} 步) ---")

    for backend_info in backends:
        b_name = backend_info['name']
        b_str = backend_info['backend_str']
        is_jax = b_str == 'jax'

        if is_jax:
            # NOTE(review): JAX recommends setting x64 at start-up; toggling it
            # mid-session is relied upon here to compare precisions - confirm
            # it behaves as intended with the installed JAX version.
            jax.config.update('jax_enable_x64', backend_info['jax_x64'])

        print(f"\n>> 測試後端: {b_name.upper()} (JAX x64: {jax.config.read('jax_enable_x64')})")
        print(f"{'N':<5} | {'Grid Pts':<10} | {'dt (s)':<10} | {'Time/Step (ms)':<15} | {'Throughput (TPS)':<15}")
        print('-' * 65)

        for n in n_values:
            config = AdvectionConfig(
                N=n,
                CFL=0.5,
                backend=b_str
            )

            solver = CubedSphereAdvectionSolver(config)
            state = solver.get_initial_condition(type='custom', func=gaussian_hill)
            dt_safe = solver.compute_safe_dt(state, cfl=0.5)

            duration = _time_solver_run(solver, state, dt_safe, num_steps, is_jax)

            avg_time = duration / num_steps
            avg_ms = avg_time * 1000
            # Guard against a zero reading on a coarse clock.
            tps = 1.0 / avg_time if avg_time > 0 else float('inf')
            grid_points = 6 * n * n
            # Store a plain Python float so the table and CSV never hold array scalars.
            dt_value = float(dt_safe)

            print(f"{n:<5} | {grid_points:<10} | {dt_value:<10.3f} | {avg_ms:<15.2f} | {tps:<15.1f}")

            results.append({
                'Backend': b_name,
                'N': n,
                'Grid_Points': grid_points,
                'dt_safe': dt_value,
                'Time_per_Step_ms': avg_ms,
                'Throughput_TPS': tps
            })

            # Release compiled kernels and buffers before the next resolution.
            if is_jax:
                jax.clear_caches()
            gc.collect()

    return pd.DataFrame(results)

# Run the benchmark once and keep the resulting table in `df_results`.
df_results = benchmark_advection_scalability()

In [None]:
# Write the benchmark table to a CSV file (without the index column).
df_results.to_csv('advection_scalability_benchmark.csv', index=False)
# NOTE(review): `display` is not imported in this file; presumably the IPython/Jupyter built-in - confirm.
display(df_results)

def plot_backend_comparison(df):
    """Plot time per step against grid size for the three benchmarked backends.

    Args:
        df: DataFrame with the columns ``Backend``, ``Grid_Points`` and
            ``Time_per_Step_ms``. Rows whose ``Backend`` is ``'numpy'``,
            ``'jax_single'`` or ``'jax_double'`` are drawn as one line each.
            When an ``N`` column is present its values become the ticks of
            the top axis; otherwise the default resolutions are used.

    Side effects:
        Saves the figure to ``advection_scalability_comparison.png`` and shows it.
    """
    # (Backend value, marker, colour, legend label) for each plotted series.
    series_styles = [
        ('numpy', 'o', '#e74c3c', 'NumPy (CPU)'),
        ('jax_single', 's', '#3498db', 'JAX Single (Float32)'),
        ('jax_double', '^', '#2c3e50', 'JAX Double (Float64)'),
    ]

    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)

    for backend, marker, color, label in series_styles:
        subset = df[df['Backend'] == backend]
        ax.plot(subset['Grid_Points'], subset['Time_per_Step_ms'],
                marker=marker, markersize=8, linestyle='-', linewidth=2,
                color=color, label=label)

    ax.set_title('Advection Solver: Numpy vs JAX (Single/Double)', fontsize=14, pad=15, fontweight='bold')
    # Raw string: in a normal literal "\t" is a TAB, which corrupted "\times".
    ax.set_xlabel(r'Total Grid Points ($6 \times N^2$)', fontsize=12)
    ax.set_ylabel('Wall-clock Time per Step (ms)', fontsize=12)
    ax.set_yscale('log')
    ax.set_xscale('log')

    # Top axis shows the resolution N; grid points = 6 * N**2, so N = sqrt(points / 6).
    secax = ax.secondary_xaxis('top', functions=(lambda x: (x / 6) ** 0.5, lambda x: 6 * x ** 2))
    secax.set_xlabel('Cubed-Sphere Resolution ($N$)', fontsize=11, labelpad=10)
    # Prefer the resolutions actually present in the data over a fixed list.
    if 'N' in df.columns and len(df) > 0:
        n_ticks = sorted(int(n) for n in df['N'].unique())
    else:
        n_ticks = [16, 32, 64, 96, 128]
    secax.set_xticks(n_ticks)
    secax.set_xticklabels([str(n) for n in n_ticks])

    ax.grid(True, alpha=0.6, linestyle='--')
    ax.legend(loc='upper left', fontsize=11)

    fig.tight_layout()
    fig.savefig('advection_scalability_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

# Draw, save and show the backend comparison figure for the benchmark table.
plot_backend_comparison(df_results)