# Experiment 14.2: Navier-Stokes Ablations

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.equations import nst_res as pde_res
from src.equations import nst_w_func

from src.utils import _get_adam, count_params, _get_colloc_indices, grad_norm
from src.utils import count_rga, count_pirate, count_pikan
from src.rgakan import RGAKAN
from src.kan import KAN
from src.piratenet import PirateNet

from functools import partial

import numpy as np
from jax import device_get

import optax
from flax import nnx

import pickle
import time

import os

# Output directories (created on first use if they are missing)
results_dir = "results"
plots_dir = "plots"  # NOTE(review): not referenced in the visible cells; presumably for figures
os.makedirs(results_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

# Pickle file the ablation metrics are written to
result_file = os.path.join(results_dir, "ablations_nst.pkl")

# Nested container filled by the ablation loop: RESULTS[case_name][run_idx][metric]
RESULTS = {}

# Base seed used for collocation sampling and (offset per run) for model initialisation
seed = 42

## Basic Functions

In [None]:
# PDE loss built from the two residual columns of pde_res (vorticity transport, continuity)
def pde_loss(model, l_E, collocs, use_rba, use_causal, cont_weight):
    """Return the scalar PDE loss and the RBA weights to carry to the next step.

    Args:
        model: network handed to ``pde_res``.
        l_E: current RBA weights of the collocation batch (one row per point).
        collocs: collocation points at which the residuals are evaluated.
        use_rba: when true, update the RBA weights from the residuals and
            scale the residuals by the updated weights.
        use_causal: when true, split the batch into ``num_chunks`` consecutive
            chunks and down-weight each chunk by the loss of the chunks before it.
        cont_weight: multiplier applied to the continuity term.

    Returns:
        ``(loss, l_E_new)``. ``l_E_new`` is ``l_E`` itself when ``use_rba`` is false.
    """
    # Column 0 is the vorticity residual, column 1 the continuity residual
    res = pde_res(model, collocs, Re=Re)

    # Defaults for the no-RBA path; overridden below when RBA is active
    l_E_new = l_E
    scaled = res
    if use_rba:
        # One weight per point, driven by the mean absolute residual over both columns
        abs_res = jnp.abs(res).mean(axis=1, keepdims=True)
        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_res/jnp.max(abs_res))
        scaled = l_E_new * res

    if not use_causal:
        # Plain mean-squared loss over the whole batch
        return jnp.mean(scaled[:, 0]**2) + cont_weight*jnp.mean(scaled[:, 1]**2), l_E_new

    # Consecutive rows are grouped into chunks
    # (assumes the batch is ordered along time -- TODO confirm against _get_colloc_indices)
    chunks = scaled.reshape(num_chunks, -1, 2)

    # Mean squared residual of every chunk, per residual type
    vort_per_chunk = jnp.mean(chunks[:, :, 0]**2, axis=1)
    cont_per_chunk = jnp.mean(chunks[:, :, 1]**2, axis=1)

    # Causal weights are treated as constants (no gradient); the smaller of the
    # two is applied to both terms
    w_vort = jax.lax.stop_gradient(jnp.exp(-causal_tol * (M @ vort_per_chunk)))
    w_cont = jax.lax.stop_gradient(jnp.exp(-causal_tol * (M @ cont_per_chunk)))
    causal_w = jnp.minimum(w_vort, w_cont)

    total = jnp.mean(causal_w * vort_per_chunk) + cont_weight*jnp.mean(causal_w * cont_per_chunk)

    return total, l_E_new


# Initial-condition loss on u, v and on w (w comes from nst_w_func, not from a network output)
def ic_loss(model, l_I_u, l_I_v, l_I_w, ic_collocs, u0_data, v0_data, w0_data, use_rba):
    """Return the initial-condition loss and the RBA weights for u, v and w.

    Args:
        model: network; its output columns 0 and 1 are taken as u and v.
        l_I_u, l_I_v, l_I_w: current RBA weights of the IC points for each field.
        ic_collocs: points at which the initial condition is enforced.
        u0_data, v0_data, w0_data: target values, one column each.
        use_rba: when true, update the RBA weights and scale the residuals by them.

    Returns:
        ``(loss, (l_I_u_new, l_I_v_new, l_I_w_new))``; the weights are returned
        unchanged when ``use_rba`` is false.
    """
    def _rba_update(weights, err):
        # Decay the old weights and add the residual magnitude normalised by its maximum
        magnitude = jnp.abs(err)
        return (RBA_gamma*weights) + (RBA_eta*magnitude/jnp.max(magnitude))

    # Network outputs, kept as columns so they line up with the data
    uv = model(ic_collocs)
    w_hat = nst_w_func(model, ic_collocs)

    errs = (uv[:, 0:1] - u0_data, uv[:, 1:2] - v0_data, w_hat - w0_data)
    weights = (l_I_u, l_I_v, l_I_w)

    if use_rba:
        weights = tuple(_rba_update(w, e) for w, e in zip(weights, errs))
        scaled = tuple(w * e for w, e in zip(weights, errs))
    else:
        scaled = errs

    # Sum of the three mean-squared (weighted) residuals
    loss = jnp.mean(scaled[0]**2) + jnp.mean(scaled[1]**2) + jnp.mean(scaled[2]**2)

    return loss, weights


@partial(nnx.jit, static_argnums=(7, 8))
def train_step(model, optimizer, collocs, ic_collocs, u0_data, v0_data, w0_data, 
               use_rba, use_causal, cont_weight,
               λ_E, λ_I, l_E, l_I_u, l_I_v, l_I_w):
    """Run one optimisation step on the weighted sum of the PDE and IC losses.

    ``use_rba`` and ``use_causal`` (positional arguments 7 and 8) are static,
    so each combination of the two flags is compiled separately.

    Args:
        model, optimizer: network and its NNX optimizer; the optimizer is updated in place.
        collocs, ic_collocs: PDE and IC collocation points.
        u0_data, v0_data, w0_data: initial-condition targets.
        use_rba, use_causal: feature switches forwarded to the loss functions.
        cont_weight: continuity multiplier forwarded to ``pde_loss``.
        λ_E, λ_I: global weights of the PDE and IC terms.
        l_E, l_I_u, l_I_v, l_I_w: current RBA weights.

    Returns:
        ``(loss, grads_E, grads_I, l_E_new, l_I_u_new, l_I_v_new, l_I_w_new)``.
        The two gradient trees are returned separately (the training loop
        passes them to ``grad_norm``).
    """
    pde_value_and_grad = nnx.value_and_grad(pde_loss, has_aux=True)
    ic_value_and_grad = nnx.value_and_grad(ic_loss, has_aux=True)

    # Each term is differentiated on its own so that both gradient trees are available
    (loss_E, l_E_new), grads_E = pde_value_and_grad(
        model, l_E, collocs, use_rba, use_causal, cont_weight
    )
    (loss_I, ic_weights), grads_I = ic_value_and_grad(
        model, l_I_u, l_I_v, l_I_w, ic_collocs, u0_data, v0_data, w0_data, use_rba
    )
    l_I_u_new, l_I_v_new, l_I_w_new = ic_weights

    def _mix(g_pde, g_ic):
        # Same weighting as the scalar loss below, applied leaf by leaf
        return λ_E * g_pde + λ_I * g_ic

    grads = jax.tree_util.tree_map(_mix, grads_E, grads_I)

    # Apply the combined gradients
    optimizer.update(grads)

    loss = λ_E*loss_E + λ_I*loss_I

    return loss, grads_E, grads_I, l_E_new, l_I_u_new, l_I_v_new, l_I_w_new


@nnx.jit
def get_RAD_indices(model, collocs_pool, old_indices, l_E, l_E_pool):
    """Draw a new collocation batch from the pool with residual-based probabilities.

    Args:
        model: network handed to ``pde_res``.
        collocs_pool: full pool of candidate collocation points.
        old_indices: pool indices of the batch that was just trained on.
        l_E: RBA weights of that batch (written back into the pool first).
        l_E_pool: RBA weights of the whole pool.

    Returns:
        ``(indices, updated_pool)``: the indices returned by
        ``_get_colloc_indices`` and the pool weights with ``l_E`` written back.
    """
    # Store the batch's current RBA weights in their pool slots
    updated_pool = l_E_pool.at[old_indices].set(l_E)

    # Mean over the two residual columns of the squared residual, kept as a column
    pool_res = pde_res(model, collocs_pool, Re=Re)
    mean_sq = jnp.mean(pool_res**2, axis=1, keepdims=True)

    # |RBA weight * mean squared residual| raised to the power rad_a
    score = jnp.power(jnp.abs(updated_pool * mean_sq), rad_a)

    # Sampling density: score relative to its mean, offset by rad_c, normalised to sum to one
    density = (score/jnp.mean(score)) + rad_c
    probs = (density / jnp.sum(density))[:,0]

    new_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=probs, seed=seed)

    return new_indices, updated_pool


def model_eval_nst(model, coords, u_ref, v_ref, w_ref):
    """Predict u, v, w at ``coords`` and score them against the references.

    Each prediction is reshaped to its reference's shape. The error is the
    norm of the difference divided by the norm of the reference
    (``jnp.linalg.norm`` with default arguments).

    Returns:
        ``(u_error, v_error, w_error, u_pred, v_pred, w_pred)``.
    """
    def _rel_l2(pred, ref):
        return jnp.linalg.norm(pred - ref) / jnp.linalg.norm(ref)

    # u and v are the two network output columns; w comes from nst_w_func
    uv = model(coords)
    u_pred = uv[:, 0].reshape(u_ref.shape)
    v_pred = uv[:, 1].reshape(v_ref.shape)
    w_pred = nst_w_func(model, coords).reshape(w_ref.shape)

    errors = (_rel_l2(u_pred, u_ref), _rel_l2(v_pred, v_ref), _rel_l2(w_pred, w_ref))

    return (*errors, u_pred, v_pred, w_pred)

## Data & Grid-Search Parameters

In [None]:
# Custom function for 3D collocation points (t, x, y) on torus
def _get_pde_collocs_3d(t_range, x_range, y_range, sample_size_t, sample_size_xy):
    """Generate collocation points for 3D (t, x, y) domain."""
    t = jnp.linspace(t_range[0], t_range[1], sample_size_t)
    x = jnp.linspace(x_range[0], x_range[1], sample_size_xy)
    y = jnp.linspace(y_range[0], y_range[1], sample_size_xy)
    T, X, Y = jnp.meshgrid(t, x, y, indexing='ij')
    collocs_pool = jnp.stack([T.flatten(), X.flatten(), Y.flatten()], axis=1)
    return collocs_pool


def _get_ic_collocs_2d(x_range, y_range, sample_size):
    """Generate IC collocation points for 2D spatial domain at t=0."""
    t = jnp.array([0.0], dtype=float)
    x = jnp.linspace(x_range[0], x_range[1], sample_size)
    y = jnp.linspace(y_range[0], y_range[1], sample_size)
    T, X, Y = jnp.meshgrid(t, x, y, indexing='ij')
    ic_collocs = jnp.stack([T.flatten(), X.flatten(), Y.flatten()], axis=1)
    return ic_collocs

In [None]:
# Load reference data (new format with full u, v, w solutions)
# NOTE: allow_pickle=True unpickles the file's contents, which can execute
# arbitrary code; only load data files from a trusted source.
ref = np.load('data/ns.npy', allow_pickle=True).item()

# Full reference solutions: shape (11, 64, 64)
u_sol = jnp.array(ref['u'])
v_sol = jnp.array(ref['v'])
w_sol = jnp.array(ref['w'])

# Coordinates
t_ref = ref['t']  # shape (11,)
x_ref = ref['x']  # shape (64,)
y_ref = ref['y']  # shape (64,)

# Initial conditions (64x64 grid)
u0_ref = jnp.array(ref['u0'])
v0_ref = jnp.array(ref['v0'])
w0_ref = jnp.array(ref['w0'])

# Reynolds number (reciprocal of the stored viscosity)
Re = 1.0 / ref['viscosity']  # Re = 100

# Domain ranges: upper bounds are the largest reference coordinates
t_max = float(t_ref.max())
x_max = float(x_ref.max())
y_max = float(y_ref.max())

print(f"Initial conditions shapes: u0={u0_ref.shape}, v0={v0_ref.shape}, w0={w0_ref.shape}")

# Grid size
N_t, N_x, N_y = w_sol.shape
print(f"Spatial grid: {N_x}x{N_y} = {N_x*N_y} IC points")

print(f"Reference solution shapes: u={u_sol.shape}, v={v_sol.shape}, w={w_sol.shape}")

print(f"Domain: t ∈ [0, {t_max:.4f}], x ∈ [0, {x_max:.4f}], y ∈ [0, {y_max:.4f}]")
print(f"Reynolds number: Re = {Re:.0f}")

In [None]:
# Pool of PDE collocation points on a regular (t, x, y) grid
collocs_pool = _get_pde_collocs_3d(
    t_range=(0, t_max),
    x_range=(0, x_max),
    y_range=(0, y_max),
    sample_size_t=32,
    sample_size_xy=64,
)

# IC points: the reference spatial grid, placed at t = 0
ic_x = jnp.array(x_ref)
ic_y = jnp.array(y_ref)
IC_X, IC_Y = jnp.meshgrid(ic_x, ic_y, indexing='ij')
ic_collocs = jnp.column_stack(
    (jnp.zeros_like(IC_X.ravel()), IC_X.ravel(), IC_Y.ravel())
)

# IC targets as single-column arrays, in the same row order as ic_collocs
u0_data = u0_ref.reshape(-1, 1)
v0_data = v0_ref.reshape(-1, 1)
w0_data = w0_ref.reshape(-1, 1)

# Evaluation points: every (t, x, y) combination of the reference coordinates
T_eval, X_eval, Y_eval = jnp.meshgrid(jnp.array(t_ref), jnp.array(x_ref), jnp.array(y_ref), indexing='ij')
coords = jnp.stack([T_eval.ravel(), X_eval.ravel(), Y_eval.ravel()], axis=1)

print(f"Collocs pool shape: {collocs_pool.shape}")
print(f"IC collocs shape: {ic_collocs.shape}")
print(f"Coords shape: {coords.shape}")

In [None]:
# Number of training iterations per run
num_epochs = 100_000

# Learning-rate schedule settings (forwarded to _get_adam)
learning_rate = 1e-3
decay_steps = 2000
decay_rate = 0.9
warmup_steps = 1000

# Causal training: M is strictly lower-triangular, so (M @ chunk_losses)[i]
# is the sum of the losses of the chunks that come before chunk i
causal_tol = 1.0
num_chunks = 32
M = jnp.tril(jnp.ones((num_chunks, num_chunks)), k=-1)

# Grad-norm balancing: mixing factor and how often (in iterations) it is applied
grad_mixing = 0.9
f_grad_norm = 1000

# RAD resampling: batch size, resampling period (in iterations), exponent and offset
batch_size = 2**12
f_resample = 2000
rad_a = 1.0
rad_c = 1.0

# RBA weights: decay factor and update step
RBA_gamma = 0.999
RBA_eta = 0.01

# Model dimensions
n_in = 3  # (t, x, y)
n_out = 2  # (u, v) - w is derived
D = 5

# Input axes 1 (x) and 2 (y) are declared periodic with value 1.0 each
# NOTE(review): the exact meaning of the value is defined by RGAKAN -- confirm there
period_axes = {axis: 1.0 for axis in (1, 2)}

sine_D = 5
init_scheme = dict(
    type='glorot',
    gain=None,
    norm_pow=0,
    distribution='uniform',
    sample_size=10000,
)

### Ablation Runs

In [None]:
# Ablation cases. Each entry switches RBA, RAD resampling, causal weighting and
# grad-norm balancing on or off, and sets the initial IC weight (lambda_I) and
# the continuity multiplier (cont_weight).
cases = [
    {'name': 'No Custom Weights', 'rba': False, 'rad': False, 'causal': False, 'grad_norm': False, 'lambda_I': 1.0, 'cont_weight': 1.0},
    {'name': 'Only RBA', 'rba': True, 'rad': False, 'causal': False, 'grad_norm': False, 'lambda_I': 1e5, 'cont_weight': 100.0},
    {'name': 'No RBA', 'rba': False, 'rad': True, 'causal': True, 'grad_norm': True, 'lambda_I': 1e5, 'cont_weight': 100.0},
    {'name': 'No RBA, No RAD', 'rba': False, 'rad': False, 'causal': True, 'grad_norm': True, 'lambda_I': 1e5, 'cont_weight': 100.0},
    {'name': 'No RBA, No Causal', 'rba': False, 'rad': True, 'causal': False, 'grad_norm': True, 'lambda_I': 1e5, 'cont_weight': 100.0},
    {'name': 'No RBA, No Grad Norm', 'rba': False, 'rad': True, 'causal': True, 'grad_norm': False, 'lambda_I': 1e5, 'cont_weight': 100.0},
]

# Architecture settings shared by every case
num_blocks = 6
n_hidden = 16
alpha = 1.0
beta = 0.0

print(f"Training RGAKAN (alpha={alpha}, beta={beta}) for ablation study.")

for case in cases:
    case_name = case['name']
    print(f"\n--- Starting Case: {case_name} ---")
    
    RESULTS[case_name] = dict()

    # Three runs per case; `run` offsets the model seed, `idx` keys the results
    for idx, run in enumerate([0, 7, 42]):
        RESULTS[case_name][idx] = dict()
        
        # Initialize RBA weights - full pool
        l_E_pool = jnp.ones((collocs_pool.shape[0], 1))
        # Also get RBAs for ICs (separate for u, v, w)
        l_I_u = jnp.ones((ic_collocs.shape[0], 1))
        l_I_v = jnp.ones((ic_collocs.shape[0], 1))
        l_I_w = jnp.ones((ic_collocs.shape[0], 1))
    
        # Get starting collocation points & RBA weights
        # (the base seed is used here for every run; only the model seed varies)
        sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)
        
        collocs = collocs_pool[sorted_indices]
        l_E = l_E_pool[sorted_indices]
        
        # Get opt_type
        opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

        # Define model
        model = RGAKAN(n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
                       init_scheme=init_scheme, alpha=alpha, beta=beta, ref=None,
                       period_axes=period_axes, rff_std=None, sine_D=sine_D, seed=seed+run)

        if idx == 0:
            print(f"Initialized model with {count_params(model)} parameters.")
        
        # Define global loss weights
        λ_E = jnp.array(1.0, dtype=float)
        λ_I = jnp.array(case['lambda_I'], dtype=float)

        # Set optimizer
        optimizer = nnx.Optimizer(model, opt_type)

        tick = time.time()
    
        # Start training
        for epoch in range(num_epochs):
        
            loss, grads_E, grads_I, l_E, l_I_u, l_I_v, l_I_w = train_step(
                model, optimizer, collocs, ic_collocs, u0_data, v0_data, w0_data, 
                case['rba'], case['causal'], case['cont_weight'],
                λ_E, λ_I, l_E, l_I_u, l_I_v, l_I_w
            )
            
            # Perform grad norm
            if case['grad_norm'] and (epoch != 0) and (epoch % f_grad_norm == 0):
                # Both results must be bound to the exact names passed to
                # train_step (Latin "E" in λ_E); otherwise the PDE weight
                # update is silently dropped.
                λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)
        
            # Perform RAD
            if case['rad'] and (epoch != 0) and (epoch % f_resample == 0):
                # Get new indices after resampling
                sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                # Set new batch of collocs and l_E
                collocs = collocs_pool[sorted_indices]
                l_E = l_E_pool[sorted_indices]

        tack = time.time()
        
        # Evaluate on full trajectory
        u_err, v_err, w_err, u_pred, v_pred, w_pred = model_eval_nst(model, coords, u_sol, v_sol, w_sol)

        print(f"\tRun = {idx}\t L²(u)={u_err:.2e}\t L²(v)={v_err:.2e}\t L²(w)={w_err:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

        # Scalar metrics are moved to host memory before being stored
        RESULTS[case_name][idx]['l2_u'] = np.asarray(device_get(u_err))
        RESULTS[case_name][idx]['l2_v'] = np.asarray(device_get(v_err))
        RESULTS[case_name][idx]['l2_w'] = np.asarray(device_get(w_err))
        RESULTS[case_name][idx]['loss'] = np.asarray(device_get(loss))
        # Wall-clock seconds per training iteration
        RESULTS[case_name][idx]['time'] = (tack-tick)/num_epochs

        RESULTS[case_name][idx]['u_pred'] = np.asarray(device_get(u_pred))
        RESULTS[case_name][idx]['w_pred'] = np.asarray(device_get(w_pred))
        RESULTS[case_name][idx]['v_pred'] = np.asarray(device_get(v_pred))

    # Save intermediate results
    with open(result_file, "wb") as f:
        pickle.dump(RESULTS, f)

In [None]:
# Write the complete RESULTS dictionary to disk (overwrites any existing file)
with open(result_file, "wb") as out_file:
    pickle.dump(RESULTS, out_file)

## Analysis

In [None]:
# Reload the stored results for analysis.
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open(result_file, "rb") as in_file:
    RESULTS = pickle.load(in_file)

In [None]:
def print_ablation_summary(results):
    """
    Print a summary table of the ablation results.

    For every case the table shows mean ± standard error of the relative L²
    errors of u, v and w across runs, and the mean time per iteration in ms.

    Args:
        results: mapping ``results[case_name][run_idx][metric]`` with the
            metrics ``'l2_u'``, ``'l2_v'``, ``'l2_w'`` and ``'time'``
            (seconds per iteration).
    """

    def get_stat_str(values):
        """Format ``values`` as "mean ± standard error"."""
        arr = np.array(values, dtype=np.float64)
        mean = np.mean(arr)
        # Standard Error = sample Std Dev (ddof=1) / Sqrt(N).
        # It is undefined for a single run: report nan directly rather than
        # letting numpy divide by zero degrees of freedom and emit a warning.
        if arr.size > 1:
            se = np.std(arr, ddof=1) / np.sqrt(arr.size)
        else:
            se = float("nan")
        return f"{mean:.2e} ± {se:.1e}"

    # Header row, framed by rules of matching width
    header = f"{'Ablation Type':<25} | {'L²(u)':<20} | {'L²(v)':<20} | {'L²(w)':<20} | {'Time/Iter (ms)'}"
    rule = "-" * len(header)
    print(rule)
    print(header)
    print(rule)

    for case_name, runs in results.items():
        # Gather each metric across the runs of this case
        per_run = list(runs.values())
        u_str = get_stat_str([r['l2_u'] for r in per_run])
        v_str = get_stat_str([r['l2_v'] for r in per_run])
        w_str = get_stat_str([r['l2_w'] for r in per_run])

        # Time is the plain mean across runs, converted from seconds to milliseconds
        time_mean = np.mean([r['time'] for r in per_run])*1000

        print(f"{case_name:<25} | {u_str:<20} | {v_str:<20} | {w_str:<20} | {time_mean:.4f}")


In [None]:
# --- Usage ---
# Print the table for the RESULTS dictionary (filled by the training loop or loaded from disk)
print_ablation_summary(results=RESULTS)