# Experiment 10.1: Advection Benchmarks

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.equations import advection_res as pde_res

from src.utils import _get_adam, _get_pde_collocs, _get_ic_collocs, model_eval, count_params, _get_colloc_indices, grad_norm
from src.utils import count_rga, count_pirate, count_pikan
from src.kan import KAN
from src.rgakan import RGAKAN
from src.piratenet import PirateNet

import numpy as np
from jax import device_get

import optax
from flax import nnx

import pickle
import time

import os

# Output locations: pickled benchmark metrics and generated figures.
results_dir = "results"
plots_dir = "plots"

# Make sure both output folders exist before anything is written to them.
os.makedirs(results_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

# Pickle file that collects the metrics of every run in this notebook.
result_file = os.path.join(results_dir, "benchmarks_adv.pkl")

# Filled by the training cells, keyed by architecture name or (alpha, beta).
RESULTS = {}

# Base seed; each run adds its own offset when the model is built.
seed = 42

### Basic Functions

In [None]:
# PDE Loss
def pde_loss(model, l_E, collocs):
    """Causal, RBA-weighted loss of the PDE residual.

    Args:
        model: Network forwarded to ``pde_res``.
        l_E: Current RBA weights of the collocation batch (the original
            annotations give shape (batch_size, 1)).
        collocs: Collocation points at which the residual is evaluated.

    Returns:
        Tuple ``(weighted_loss, l_E_new)``: the scalar loss and the updated
        RBA weights.
    """
    res = pde_res(model, collocs)

    # RBA update: decay the previous weights and add the residual magnitude
    # normalised by its maximum over the batch.
    magnitude = jnp.abs(res)
    l_E_new = (RBA_gamma*l_E) + (RBA_eta*magnitude/jnp.max(magnitude))

    # Arrange the weighted residuals as num_chunks rows (presumably
    # consecutive time slabs -- this relies on the ordering of collocs) and
    # take the mean squared value of every row.
    chunked = (l_E_new * res).reshape(num_chunks, -1)
    chunk_losses = jnp.mean(chunked**2, axis=1)

    # Causal weights: M @ chunk_losses accumulates the losses of the
    # preceding chunks; no gradient is propagated through the weights.
    causal_w = jax.lax.stop_gradient(jnp.exp(-causal_tol * (M @ chunk_losses)))

    return jnp.mean(causal_w * chunk_losses), l_E_new


# IC Loss
def ic_loss(model, l_I, ic_collocs, ic_data):
    """RBA-weighted mean squared error on the initial condition.

    Args:
        model: Callable network, evaluated on ``ic_collocs``.
        l_I: Current RBA weights of the initial-condition points.
        ic_collocs: Points at which the initial condition is imposed.
        ic_data: Target values at ``ic_collocs``.

    Returns:
        Tuple ``(loss, l_I_new)``: the scalar loss and the updated RBA weights.
    """
    mismatch = model(ic_collocs) - ic_data

    # RBA update: decay the previous weights and add the mismatch magnitude
    # normalised by its maximum over the batch.
    magnitude = jnp.abs(mismatch)
    l_I_new = (RBA_gamma*l_I) + (RBA_eta*magnitude/jnp.max(magnitude))

    weighted = l_I_new * mismatch

    return jnp.mean(weighted**2), l_I_new


@nnx.jit
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I):
    """Run one optimisation step on the weighted PDE + IC objective.

    Args:
        model: Network being trained.
        optimizer: Optimizer wrapping ``model``; updated in place.
        collocs: PDE collocation batch.
        ic_collocs: Initial-condition points.
        ic_data: Initial-condition target values.
        λ_E: Global weight of the PDE loss.
        λ_I: Global weight of the IC loss.
        l_E: RBA weights of the PDE batch.
        l_I: RBA weights of the IC points.

    Returns:
        ``(loss, grads_E, grads_I, l_E_new, l_I_new)``: the combined loss, the
        separate gradients of both terms and the updated RBA weights.
    """
    pde_value_and_grad = nnx.value_and_grad(pde_loss, has_aux=True)
    ic_value_and_grad = nnx.value_and_grad(ic_loss, has_aux=True)

    # Each term is differentiated on its own so that the caller receives the
    # individual gradients as well.
    (loss_E, l_E_new), grads_E = pde_value_and_grad(model, l_E, collocs)
    (loss_I, l_I_new), grads_I = ic_value_and_grad(model, l_I, ic_collocs, ic_data)

    def _combine(g_pde, g_ic):
        return λ_E * g_pde + λ_I * g_ic

    # Gradient of the weighted sum, assembled leaf by leaf.
    grads = jax.tree_util.tree_map(_combine, grads_E, grads_I)
    optimizer.update(grads)

    loss = λ_E*loss_E + λ_I*loss_I

    return loss, grads_E, grads_I, l_E_new, l_I_new


@nnx.jit
def get_RAD_indices(model, collocs_pool, old_indices, l_E, l_E_pool):
    """Draw a new collocation batch with residual-based adaptive sampling.

    Args:
        model: Network forwarded to ``pde_res``.
        collocs_pool: Full pool of candidate collocation points.
        old_indices: Pool indices of the batch that was just trained on.
        l_E: RBA weights of that batch after training.
        l_E_pool: RBA weights of the whole pool.

    Returns:
        ``(sorted_indices, updated_pool)``: the indices returned by
        ``_get_colloc_indices`` and the pool weights with the batch values
        written back.
    """
    # Write the trained batch weights back into their pool slots.
    updated_pool = l_E_pool.at[old_indices].set(l_E)

    # Sampling score: |RBA weight * residual| ** rad_a over the whole pool.
    pool_res = pde_res(model, collocs_pool)
    score = jnp.power(jnp.abs(updated_pool * pool_res), rad_a)

    # Scale by the mean, shift by rad_c and normalise to a probability vector.
    px = (score/jnp.mean(score)) + rad_c
    px_norm = (px / jnp.sum(px))[:,0]

    sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size,
                                         px=px_norm, seed=seed)

    return sorted_indices, updated_pool

### Data & Grid-Search Parameters

In [None]:
# Pool of PDE collocation points plus the initial-condition points and values
# (the initial profile is the sine of column 1 of ic_collocs).
collocs_pool = _get_pde_collocs(ranges=[(0, 1), (0, 2.0*np.pi)], sample_size=400)
ic_collocs = _get_ic_collocs(x_range=(0, 2.0*np.pi), sample_size=2**6)
ic_data = jnp.sin(ic_collocs[:, 1]).reshape(-1, 1)

# Reference solution and the grid it is stored on.
ref = np.load('data/advection.npz')
refsol = jnp.array(ref['usol'])
N_t, N_x = ref['usol'].shape
t = ref['t'].flatten()
x = ref['x'].flatten()

# Evaluation coordinates: one (t, x) row per node of the reference grid.
T, X = jnp.meshgrid(t, x, indexing='ij')
coords = jnp.hstack([T.reshape(-1, 1), X.reshape(-1, 1)])

In [None]:
# ---- Optimisation -------------------------------------------------------
num_epochs = 100_000

# Learning-rate schedule settings forwarded to _get_adam
learning_rate = 1e-3
decay_steps = 2000
decay_rate = 0.9
warmup_steps = 1000

# ---- Causal training ----------------------------------------------------
causal_tol = 1.0
num_chunks = 32
# Strictly lower-triangular matrix of ones: (M @ v)[i] sums v[j] for j < i.
M = jnp.tril(jnp.ones((num_chunks, num_chunks)), k=-1)

# ---- Gradient-norm loss balancing ---------------------------------------
grad_mixing = 0.9
f_grad_norm = 1000  # rebalance every f_grad_norm epochs

# ---- Residual-based adaptive resampling (RAD) ---------------------------
batch_size = 2**12
f_resample = 2000  # resample every f_resample epochs
rad_a = 1.0
rad_c = 1.0

# ---- Residual-based attention (RBA) -------------------------------------
RBA_gamma = 0.999
RBA_eta = 0.01

# ---- Model --------------------------------------------------------------
n_in = collocs_pool.shape[1]
n_out = 1
D = 5
period_axes = {1 : 1.0}
sine_D = 5
init_scheme = {'type': 'glorot', 'gain': None, 'norm_pow': 0,
               'distribution': 'uniform', 'sample_size': 10000}

In [None]:
# Architectures under comparison and their size settings
archs = ['RGA KAN', 'PirateNet', 'cPIKAN']

arch_params = {
    'RGA KAN': {'n_hidden': 16, 'num_blocks': 6},
    'PirateNet': {'n_hidden': 36, 'num_blocks': 4},
    'cPIKAN': {'n_hidden': 18, 'num_layers': 12},
}

print("Expected to train models with following number of parameters:")

# RGA KAN
_rga_cfg = arch_params['RGA KAN']
rga_width, rga_blocks = _rga_cfg['n_hidden'], _rga_cfg['num_blocks']
rga_params = count_rga(n_in, period_axes, n_out, rga_width, rga_blocks, D, sine_D)
print(f"RGA KAN: {rga_params} parameters")

# PirateNet
_pirate_cfg = arch_params['PirateNet']
pirate_width, pirate_blocks = _pirate_cfg['n_hidden'], _pirate_cfg['num_blocks']
pirate_params = count_pirate(n_in, period_axes, n_out, pirate_width, pirate_blocks)
print(f"PirateNet: {pirate_params} parameters")

# cPIKAN
_pikan_cfg = arch_params['cPIKAN']
pikan_width, pikan_depth = _pikan_cfg['n_hidden'], _pikan_cfg['num_layers']
pikan_params = count_pikan(n_in, period_axes, n_out, pikan_width, pikan_depth, D)
print(f"cPIKAN: {pikan_params} parameters")

### PirateNet / cPIKAN Runs

In [None]:
# Train PirateNet and cPIKAN three times each (seed offsets 0, 7 and 42) and
# record the final relative L^2 error, the final loss and the time per epoch.
for arch in ['PirateNet', 'cPIKAN']:

    RESULTS[arch] = dict()

    n_hidden = arch_params[arch]['n_hidden']

    if arch == 'PirateNet':
        num_blocks = arch_params[arch]['num_blocks']
        depth = int(3*num_blocks)
    elif arch == 'cPIKAN':
        num_layers = arch_params[arch]['num_layers']
        depth = num_layers

    print(f"Training {arch} with depth = {depth} and width = {n_hidden}.")

    for idx, run in enumerate([0, 7, 42]):

        RESULTS[arch][idx] = dict()

        # Initialize RBA weights - full pool
        l_E_pool = jnp.ones((collocs_pool.shape[0], 1))
        # Also get RBAs for ICs
        l_I = jnp.ones((ic_collocs.shape[0], 1))

        # Get starting collocation points & RBA weights
        sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

        collocs = collocs_pool[sorted_indices]
        l_E = l_E_pool[sorted_indices]

        # Get opt_type
        opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

        # Define model (only the model seed differs between runs)
        if arch == 'PirateNet':
            model = PirateNet(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_blocks = num_blocks,
                              alpha = 0.0, ref = ref, period_axes = period_axes, rff_std = 1.0,
                              RWF={"mean": 1.0, "std": 0.1}, seed=seed+run)
        elif arch == 'cPIKAN':
            model = KAN(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_layers = num_layers, D = D,
                        init_scheme = init_scheme, period_axes = period_axes, rff_std = None,
                        seed = seed+run)

        if idx == 0:
            print(f"Initialized model with {count_params(model)} parameters.")

        # Define global loss weights
        λ_E = jnp.array(1.0, dtype=float)
        λ_I = jnp.array(1.0, dtype=float)

        # Set optimizer
        optimizer = nnx.Optimizer(model, opt_type)

        tick = time.time()

        # Start training
        for epoch in range(num_epochs):

            loss, grads_E, grads_I, l_E, l_I = train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I)

            # Perform grad norm
            if (epoch != 0) and (epoch % f_grad_norm == 0):

                # NOTE: the target must be the very same identifier λ_E
                # (Latin "E") that is passed to train_step. Binding the result
                # to a look-alike name (e.g. with a Greek capital epsilon)
                # creates a different variable and leaves the PDE weight
                # frozen at its initial value.
                λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)

            # Perform RAD
            if (epoch != 0) and (epoch % f_resample == 0):

                # Get new indices after resampling
                sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                # Set new batch of collocs and l_E
                collocs = collocs_pool[sorted_indices]
                l_E = l_E_pool[sorted_indices]

        tack = time.time()
        final_error = model_eval(model, coords, refsol)

        print(f"\tRun = {idx}\t L^2 = {final_error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

        RESULTS[arch][idx]['l2'] = np.asarray(device_get(final_error))
        RESULTS[arch][idx]['loss'] = np.asarray(device_get(loss))
        # Average wall-clock seconds per epoch over the whole training loop
        RESULTS[arch][idx]['time'] = (tack-tick)/num_epochs

### RGA KAN Runs

In [None]:
# Train RGA KAN for every (alpha, beta) in {0, 1} x {0, 1}, three runs each
# (seed offsets 0, 7 and 42). Besides the metrics, the final prediction on the
# reference grid is stored for plotting.
num_blocks = arch_params['RGA KAN']['num_blocks']
n_hidden = arch_params['RGA KAN']['n_hidden']
depth = int(2*num_blocks)

print(f"Training RGA KAN with depth = {depth} and width = {n_hidden}.")

for alpha in [0.0, 1.0]:
    for beta in [0.0, 1.0]:

        RESULTS[(alpha,beta)] = dict()
        print(f"Training for alpha = {alpha} and beta = {beta}.")

        for idx, run in enumerate([0, 7, 42]):

            RESULTS[(alpha,beta)][idx] = dict()

            # Initialize RBA weights - full pool
            l_E_pool = jnp.ones((collocs_pool.shape[0], 1))
            # Also get RBAs for ICs
            l_I = jnp.ones((ic_collocs.shape[0], 1))

            # Get starting collocation points & RBA weights
            sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

            collocs = collocs_pool[sorted_indices]
            l_E = l_E_pool[sorted_indices]

            # Get opt_type
            opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

            # Define model (only the model seed differs between runs)
            model = RGAKAN(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_blocks = num_blocks, D = D,
                           init_scheme = init_scheme, alpha = alpha, beta = beta, ref = ref,
                           period_axes = period_axes, rff_std = None, sine_D = sine_D, seed = seed+run)

            if idx == 0:
                print(f"Initialized model with {count_params(model)} parameters.")

            # Define global loss weights
            λ_E = jnp.array(1.0, dtype=float)
            λ_I = jnp.array(1.0, dtype=float)

            # Set optimizer
            optimizer = nnx.Optimizer(model, opt_type)

            tick = time.time()

            # Start training
            for epoch in range(num_epochs):

                loss, grads_E, grads_I, l_E, l_I = train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I)

                # Perform grad norm
                if (epoch != 0) and (epoch % f_grad_norm == 0):

                    # NOTE: the target must be the very same identifier λ_E
                    # (Latin "E") that is passed to train_step. Binding the
                    # result to a look-alike name (e.g. with a Greek capital
                    # epsilon) creates a different variable and leaves the PDE
                    # weight frozen at its initial value.
                    λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)

                # Perform RAD
                if (epoch != 0) and (epoch % f_resample == 0):

                    # Get new indices after resampling
                    sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                    # Set new batch of collocs and l_E
                    collocs = collocs_pool[sorted_indices]
                    l_E = l_E_pool[sorted_indices]

            tack = time.time()
            final_output = model(coords).reshape(refsol.shape)
            final_error = model_eval(model, coords, refsol)

            print(f"\tRun = {idx}\t L^2 = {final_error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

            RESULTS[(alpha,beta)][idx]['l2'] = np.asarray(device_get(final_error))
            RESULTS[(alpha,beta)][idx]['loss'] = np.asarray(device_get(loss))
            # Average wall-clock seconds per epoch over the whole training loop
            RESULTS[(alpha,beta)][idx]['time'] = (tack-tick)/num_epochs
            RESULTS[(alpha,beta)][idx]['output'] = np.asarray(device_get(final_output))

In [None]:
# Persist every collected metric so the analysis cells can run on their own.
with open(result_file, "wb") as out_file:
    pickle.dump(RESULTS, out_file)

## Plots & Analysis

In [None]:
# Reload the metrics written by the training cells.
# NOTE(review): pickle.load executes arbitrary code from the file -- only load
# result files that were produced by this notebook.
with open(result_file, "rb") as in_file:
    RESULTS = pickle.load(in_file)

In [None]:
def metric_stats(RESULTS, model_idx, metric='l2'):
    """Return the mean and standard error of a metric over runs 0, 1 and 2.

    Args:
        RESULTS: Mapping ``model key -> run index -> {metric name: value}``.
        model_idx: Key of the model inside ``RESULTS``.
        metric: Name of the metric to aggregate.

    Returns:
        ``(mean, se)``. Runs lacking the metric are skipped. With no usable
        run both entries are NaN; with exactly one run the standard error is
        NaN. Otherwise ``se`` is the sample standard deviation (ddof=1)
        divided by the square root of the number of runs.
    """
    runs = RESULTS.get(model_idx, {})

    samples = []
    for run_id in range(3):
        try:
            raw = runs[run_id][metric]
        except (KeyError, TypeError):
            continue
        # Stored values may be 0-d / singleton arrays; reduce them to a float.
        samples.append(float(np.array(raw, dtype=float).squeeze()))

    count = len(samples)
    if count == 0:
        return np.nan, np.nan
    if count == 1:
        return float(samples[0]), np.nan

    spread = np.std(samples, ddof=1)
    return float(np.mean(samples)), float(spread / np.sqrt(count))

def _collect_metric_vals(runs_dict, metric='l2', run_ids=(0,1,2)):
    """Collect ``metric`` as floats from the runs in ``run_ids``; runs that lack it are skipped."""
    vals = []
    for i in run_ids:
        try:
            v = runs_dict[i][metric]
        except (KeyError, TypeError):
            continue
        # Stored values may be 0-d / singleton arrays; reduce them to a float.
        v = np.array(v, dtype=float).squeeze()
        vals.append(float(v))
    return vals


def pick_best_rgakan(RESULTS, metric='l2', configs=None, run_ids=(0,1,2)):
    """Select the best configuration by mean metric, then its best single run.

    Args:
        RESULTS: Mapping ``config key -> run index -> {metric name: value}``.
        metric: Name of the metric to minimise.
        configs: Configuration keys to compare; defaults to the four
            (alpha, beta) combinations of 0 and 1.
        run_ids: Run indices to consider for every configuration.

    Returns:
        Dict with keys ``best_config``, ``config_mean``, ``config_se``,
        ``best_run_idx``, ``best_run_value``, ``best_run_time`` and
        ``best_run_output``. ``best_run_idx`` is None and ``best_run_value``
        is ``inf`` when the chosen configuration has no usable run;
        ``best_run_time`` / ``best_run_output`` are None when not stored.
    """
    if configs is None:
        configs = [(0,0), (0,1), (1,0), (1,1)]

    # 1) Mean and standard error of the metric for every configuration
    config_stats = {}
    for cfg in configs:
        runs = RESULTS.get(cfg, {})
        vals = _collect_metric_vals(runs, metric=metric, run_ids=run_ids)
        if len(vals) == 0:
            mean, se = np.inf, np.nan
        elif len(vals) == 1:
            mean, se = float(vals[0]), np.nan
        else:
            mean = float(np.mean(vals))
            se = float(np.std(vals, ddof=1) / np.sqrt(len(vals)))
        config_stats[cfg] = (mean, se)

    def _sort_key(cfg):
        # NaN compares False against everything, so a configuration with a
        # NaN mean (e.g. a diverged run) could otherwise win the min() below
        # merely by coming first. Rank NaN as +inf for the mean and the SE.
        mean, se = config_stats[cfg]
        mean = np.inf if np.isnan(mean) else mean
        se = np.inf if np.isnan(se) else se
        return (mean, se, cfg)

    # choose lowest mean; tie-break by smaller SE, then lexicographic cfg
    best_config = min(configs, key=_sort_key)
    best_mean, best_se = config_stats[best_config]

    # 2) Within best config, pick best run by lowest metric
    #    (NaN values never satisfy ``v < best_run_value`` and are skipped)
    runs = RESULTS.get(best_config, {})
    best_run_idx, best_run_value = None, np.inf
    for i in run_ids:
        try:
            v = runs[i][metric]
            v = float(np.array(v, dtype=float).squeeze())
        except (KeyError, TypeError, ValueError):
            continue
        if v < best_run_value:
            best_run_value = v
            best_run_idx = i

    best_run_time = None
    best_run_output = None
    if best_run_idx is not None:
        best_entry = runs.get(best_run_idx, {})
        best_run_time = best_entry.get('time', None)
        # 'output' is stored only by the RGA KAN training loop
        best_run_output = best_entry.get('output', None)
        if best_run_output is not None:
            best_run_output = np.array(best_run_output)

    return {
        'best_config': best_config,
        'config_mean': best_mean,
        'config_se': best_se,
        'best_run_idx': best_run_idx,
        'best_run_value': best_run_value,
        'best_run_time': best_run_time,
        'best_run_output': best_run_output,
    }

In [None]:
# Rows of the report: RESULTS key and the label (with its tab padding).
_report_rows = [
    ('PirateNet', "PirateNet:\t\t"),
    ('cPIKAN', "cPIKAN:\t\t\t"),
    ((0,0), "RGAKAN (α = 0, β = 0):\t"),
    ((1,0), "RGAKAN (α = 1, β = 0):\t"),
    ((0,1), "RGAKAN (α = 0, β = 1):\t"),
    ((1,1), "RGAKAN (α = 1, β = 1):\t"),
]

print("RESULTS")
print("------------------------------------------------------------------")
# First table: L^2 error (mean and standard error) and time per epoch in ms.
for key, label in _report_rows:
    m, s = metric_stats(RESULTS, key, 'l2')
    tt, _ = metric_stats(RESULTS, key, 'time')
    print(f"{label} L^2 = {m:.3e}\t Error = {s:.3e}\t Time = {tt*1000:.2f} ms.")
print("------------------------------------------------------------------")
# Second table: final training loss (mean and standard error).
for key, label in _report_rows:
    m, s = metric_stats(RESULTS, key, 'loss')
    print(f"{label} Loss = {m:.3e}\t Error = {s:.3e}")

In [None]:
# Best (alpha, beta) by mean L^2 error, and the best single run within it.
res = pick_best_rgakan(RESULTS, metric='l2')
a, b = res['best_config']
best_l2 = res['best_run_value']
print(f"The lowest average L^2 error is obtained for alpha = {a} and beta = {b}.")
print(f"Among the runs with this configuration, the lowest L^2 error is {best_l2:.2e}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import ScalarFormatter
from matplotlib import ticker

# Font sizes shared by the figure
LABEL_FS = 16
TITLE_FS = 18
TICK_FS = 14
CBAR_FS = 16


def plot_pred_ref_diff(u_pred, refsol, x, t, save_fig=False,
                       clim_pred=None, clim_ref=None, clim_diff=None):
    """Plot reference, prediction and absolute error as three heat maps.

    NOTE: ``x`` supplies the range of the horizontal axis, which is labelled
    "t", and ``t`` the range of the vertical axis, which is labelled "x".
    Callers therefore have to pass the time grid as ``x`` and the space grid
    as ``t`` for the extents to agree with the labels.

    Args:
        u_pred: Predicted field; same shape as ``refsol`` (JAX arrays allowed).
        refsol: Reference field; its transpose is what gets drawn.
        x: Coordinates whose min/max bound the horizontal axis.
        t: Coordinates whose min/max bound the vertical axis.
        save_fig: If True, write the figure to ``{plots_dir}/final_adv.pdf``.
        clim_pred: Optional (min, max) colour limits of the prediction panel.
        clim_ref: Optional (min, max) colour limits of the reference panel.
        clim_diff: Optional (min, max) colour limits of the error panel.
    """
    # allow jax arrays
    u_pred = np.array(u_pred)
    refsol = np.array(refsol)
    abs_err = np.abs(u_pred - refsol)

    fig, axs = plt.subplots(1, 3, figsize=(14, 3.5), constrained_layout=False)

    extent = [np.min(x), np.max(x), np.min(t), np.max(t)]
    cmap = sns.color_palette("Spectral", as_cmap=True)

    # Panel order: reference, prediction, absolute error
    fields = (refsol, u_pred, abs_err)
    titles = ("Reference", "Prediction", "Absolute Error")
    cbar_labels = (r"$u_{\mathrm{ref}}$", r"$u_{\mathrm{pred}}$",
                   r"$|u_{\mathrm{pred}}-u_{\mathrm{ref}}|$")
    clims = (clim_ref, clim_pred, clim_diff)

    imgs = [
        ax.imshow(field.T, origin='lower', aspect='auto', extent=extent, cmap=cmap)
        for ax, field in zip(axs, fields)
    ]

    # Apply the colour limits that were requested
    for img, clim in zip(imgs, clims):
        if clim is not None:
            img.set_clim(*clim)

    for ax, title, img, cblab in zip(axs, titles, imgs, cbar_labels):
        ax.set_title(title, fontsize=TITLE_FS)
        ax.set_xlabel(r"$t$", fontsize=LABEL_FS)
        ax.set_ylabel(r"$x$", fontsize=LABEL_FS)
        ax.tick_params(axis='both', which='major', labelsize=TICK_FS)

        # Horizontal colour bar in its own axes underneath the panel
        cax = make_axes_locatable(ax).append_axes("bottom", size="7%", pad=0.65)
        cbar = fig.colorbar(img, cax=cax, orientation='horizontal')

        sci_fmt = ScalarFormatter(useMathText=True)
        cbar.formatter = sci_fmt
        sci_fmt.set_scientific(True)
        sci_fmt.set_powerlimits((0, 0))   # always sci notation
        sci_fmt.set_useOffset(False)
        cbar.update_ticks()
        cbar.ax.xaxis.get_offset_text().set_fontsize(TICK_FS)

        cbar.ax.tick_params(labelsize=CBAR_FS)
        cbar.set_label(cblab, fontsize=LABEL_FS, labelpad=6)

    plt.subplots_adjust(wspace=0.35, bottom=0.0)

    if save_fig:
        plt.savefig(f"{plots_dir}/final_adv.pdf", format="pdf", bbox_inches="tight")

    plt.show()


In [None]:
# Prediction of the best RGA KAN run on the reference grid
final_output = res['best_run_output']

# The time grid goes in as ``x`` and the space grid as ``t`` on purpose:
# plot_pred_ref_diff uses ``x`` for the horizontal axis, which it labels "t".
plot_pred_ref_diff(final_output, refsol, x=t, t=x, save_fig=True,
                   clim_ref=(-1,1), clim_pred=(-1,1))