# Experiment 12.1: Poisson Benchmarks

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.utils import _get_adam, _get_pde_collocs, _get_ic_collocs, model_eval, count_params, _get_colloc_indices, grad_norm
from src.utils import count_rga, count_pirate, count_pikan
from src.wrappers import PoissonKAN, PoissonModel, PoissonPirate

import numpy as np
from jax import device_get

import optax
from flax import nnx

import pickle
import time

import os

# Output locations: pickled benchmark results and saved figures.
results_dir = "results"
plots_dir = "plots"
for _out_dir in (results_dir, plots_dir):
    os.makedirs(_out_dir, exist_ok=True)  # no-op when the directory already exists

result_file = os.path.join(results_dir, "benchmarks_poisson.pkl")

# One results dictionary per Poisson case; keys match the omega sections below.
RESULTS = {omega: dict() for omega in (1, 2, 4)}

seed = 42

In [None]:
def metric_stats(RESULTS, model_idx, metric='l2'):
    vals = []
    runs = RESULTS.get(model_idx, {})
    for i in (0, 1, 2):
        try:
            v = runs[i][metric]
        except (KeyError, TypeError):
            continue
        v = np.array(v, dtype=float).squeeze()
        vals.append(float(v))

    if len(vals) == 0:
        return np.nan, np.nan
    if len(vals) == 1:
        return float(vals[0]), np.nan

    mean = float(np.mean(vals))
    se = float(np.std(vals, ddof=1) / np.sqrt(len(vals)))
    return mean, se

def _collect_metric_vals(runs_dict, metric='l2', run_ids=(0,1,2)):
    vals = []
    for i in run_ids:
        try:
            v = runs_dict[i][metric]
        except (KeyError, TypeError):
            continue
        v = np.array(v, dtype=float).squeeze()
        vals.append(float(v))
    return vals

def pick_best_rgakan(RESULTS, metric='l2', configs=None, run_ids=(0,1,2)):
    if configs is None:
        configs = [(0,0), (0,1), (1,0), (1,1)]

    # 1) Find best config by mean metric
    config_stats = {}
    for cfg in configs:
        runs = RESULTS.get(cfg, {})
        vals = _collect_metric_vals(runs, metric=metric, run_ids=run_ids)
        if len(vals) == 0:
            mean, se = np.inf, np.nan
        elif len(vals) == 1:
            mean, se = float(vals[0]), np.nan
        else:
            mean = float(np.mean(vals))
            se = float(np.std(vals, ddof=1) / np.sqrt(len(vals)))
        config_stats[cfg] = (mean, se)

    # choose lowest mean; tie-break by smaller SE, then lexicographic cfg
    best_config = min(configs, key=lambda c: (config_stats[c][0], np.nan_to_num(config_stats[c][1], nan=np.inf), c))
    best_mean, best_se = config_stats[best_config]

    # 2) Within best config, pick best run by lowest metric
    runs = RESULTS.get(best_config, {})
    best_run_idx, best_run_value = None, np.inf
    for i in run_ids:
        try:
            v = runs[i][metric]
            v = float(np.array(v, dtype=float).squeeze())
        except (KeyError, TypeError, ValueError):
            continue
        if v < best_run_value:
            best_run_value = v
            best_run_idx = i

    best_run_time = None
    best_run_output = None
    if best_run_idx is not None:
        best_entry = runs.get(best_run_idx, {})
        best_run_time = best_entry.get('time', None)
        # 'output' was stored only for RGAKAN in your training loop
        best_run_output = best_entry.get('output', None)
        if best_run_output is not None:
            best_run_output = np.array(best_run_output)

    return {
        'best_config': best_config,
        'config_mean': best_mean,
        'config_se': best_se,
        'best_run_idx': best_run_idx,
        'best_run_value': best_run_value,
        'best_run_time': best_run_time,
        'best_run_output': best_run_output,
    }

In [None]:
# Jitted RAD steps keyed by (residual fn, rad_a, rad_c, batch_size, seed).
_RAD_STEP_CACHE = {}


def _make_rad_step(res_fn, a, c, n_batch, rng_seed):
    """Build a jitted RAD step with the residual function and sampling settings baked in."""

    @nnx.jit
    def _rad_step(model, collocs_pool, old_indices, l_E, l_E_pool):
        # Write the batch's latest RBA weights back into the full pool
        updated_pool = l_E_pool.at[old_indices].set(l_E)

        # RBA-weighted absolute residuals over the whole pool
        resids = res_fn(model, collocs_pool)
        wa_resids = jnp.abs(updated_pool * resids)

        # Density |w r|^a / mean(|w r|^a) + c, normalised to sum to one.
        # res_fn presumably returns shape (N, 1); [:, 0] selects that column (TODO confirm).
        ea = jnp.power(wa_resids, a)
        px = (ea/jnp.mean(ea)) + c
        px_norm = (px / jnp.sum(px))[:,0]

        sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=n_batch, px=px_norm, seed=rng_seed)
        return sorted_indices, updated_pool

    return _rad_step


def get_RAD_indices(model, collocs_pool, old_indices, l_E, l_E_pool):
    """Resample the collocation batch with residual-based adaptive distribution (RAD).

    The function does three things:
    1. Writes the batch RBA weights ``l_E`` back into ``l_E_pool`` at ``old_indices``.
    2. Evaluates ``pde_res`` on the whole ``collocs_pool``.
    3. Asks ``_get_colloc_indices`` for ``batch_size`` indices drawn with
       probabilities proportional to ``|w r|**rad_a / mean(|w r|**rad_a) + rad_c``.

    Returns ``(indices, updated_pool)``.

    ``pde_res``, ``rad_a``, ``rad_c``, ``batch_size`` and ``seed`` are module globals
    read at call time. Each distinct combination gets its own jitted step.

    A single ``@nnx.jit`` function would freeze ``pde_res`` at its first trace.
    After ``pde_res`` is re-imported for a new PDE, models with the same structure
    would then reuse the stale compiled residual.

    NOTE(review): the same ``seed`` is passed on every call, so each resampling
    round reuses the same random stream — confirm this is intended.
    """
    key = (pde_res, rad_a, rad_c, batch_size, seed)
    step = _RAD_STEP_CACHE.get(key)
    if step is None:
        step = _RAD_STEP_CACHE[key] = _make_rad_step(*key)
    return step(model, collocs_pool, old_indices, l_E, l_E_pool)

## Parameters

In [None]:
# ---- Optimisation schedule (arguments of _get_adam) ----
num_epochs = 100_000          # training iterations per run
learning_rate = 1e-3
decay_steps, decay_rate = 2000, 0.9
warmup_steps = 1000

# ---- RAD resampling of the collocation batch ----
batch_size = 4096             # 2**12 collocation points per batch
f_resample = 2000             # resample every f_resample epochs
rad_a, rad_c = 1.0, 1.0       # density exponent and additive constant

# ---- RBA weight update: l_E <- RBA_gamma * l_E + RBA_eta * |r| / max|r| ----
RBA_gamma, RBA_eta = 0.999, 0.01

# ---- Model settings shared by the architectures ----
n_in, n_out = 2, 1            # (x, y) -> u
D = 5
period_axes = None
sine_D = 5
init_scheme = dict(type='glorot', gain=None, norm_pow=0, distribution='uniform', sample_size=10000)

In [None]:
# Architectures benchmarked below and their width / depth settings.
archs = ['RGA KAN', 'PirateNet', 'cPIKAN']

arch_params = {
    'RGA KAN': dict(n_hidden=16, num_blocks=6),
    'PirateNet': dict(n_hidden=36, num_blocks=4),
    'cPIKAN': dict(n_hidden=18, num_layers=12),
}

print("Expected to train models with following number of parameters:")

# Parameter counts from the count_* helpers. Compare them with the
# "Initialized model with ... parameters." lines printed during training.
_rga_cfg = arch_params['RGA KAN']
rga_width, rga_blocks = _rga_cfg['n_hidden'], _rga_cfg['num_blocks']
rga_params = count_rga(n_in, period_axes, n_out, rga_width, rga_blocks, D, sine_D)
print(f"RGA KAN: {rga_params} parameters")

_pirate_cfg = arch_params['PirateNet']
pirate_width, pirate_blocks = _pirate_cfg['n_hidden'], _pirate_cfg['num_blocks']
pirate_params = count_pirate(n_in, period_axes, n_out, pirate_width, pirate_blocks)
print(f"PirateNet: {pirate_params} parameters")

_pikan_cfg = arch_params['cPIKAN']
pikan_width, pikan_depth = _pikan_cfg['n_hidden'], _pikan_cfg['num_layers']
pikan_params = count_pikan(n_in, period_axes, n_out, pikan_width, pikan_depth, D)
print(f"cPIKAN: {pikan_params} parameters")

## $\omega = 1$

In [None]:
from src.equations import poisson_1_res as pde_res

In [None]:
# Collocation pool over [-1, 1] x [-1, 1]; training batches are drawn from it
collocs_pool = _get_pde_collocs(ranges = [(-1,1), (-1,1)], sample_size = 400)

# Reference solution. The context manager closes the .npz archive once the
# arrays have been copied out, instead of leaving the file handle open.
with np.load('data/poisson_1.npz') as ref:
    refsol = jnp.array(ref['usol'])
    N_x, N_y = ref['usol'].shape
    x, y = ref['x'].flatten(), ref['y'].flatten()

# Evaluation grid with 'ij' indexing, so model(coords) can be reshaped to refsol.shape
X, Y = jnp.meshgrid(x, y, indexing='ij')
coords = jnp.hstack([X.flatten()[:, None], Y.flatten()[:, None]])

In [None]:
# PDE Loss
def pde_loss(model, l_E, collocs):
    """RBA-weighted PDE loss.

    Evaluates ``r = pde_res(model, collocs)`` and refreshes the residual-based
    attention weights as ``RBA_gamma * l_E + RBA_eta * |r| / max|r|``.
    Returns ``(mean((l_E_new * r) ** 2), l_E_new)``.

    ``pde_res`` and the ``RBA_*`` constants are module globals. Under ``nnx.jit``
    they are read when ``train_step`` below is traced.

    NOTE(review): the weight update happens inside the function that
    ``nnx.value_and_grad`` differentiates and is not wrapped in
    ``jax.lax.stop_gradient``, so gradients also flow through ``l_E_new``.
    Confirm this is intended.
    """
    res = pde_res(model, collocs)  # presumably shape (batch_size, 1) — TODO confirm
    mag = jnp.abs(res)
    weights = (RBA_gamma*l_E) + (RBA_eta*mag/jnp.max(mag))
    loss_value = jnp.mean((weights * res)**2)
    return loss_value, weights


@nnx.jit
def train_step(model, optimizer, collocs, l_E):
    """Run one optimiser step on ``pde_loss``; return the loss and refreshed RBA weights."""
    loss_and_grad = nnx.value_and_grad(pde_loss, has_aux=True)
    (loss, new_l_E), grads = loss_and_grad(model, l_E, collocs)
    optimizer.update(grads)  # apply gradients to the model wrapped by the optimizer
    return loss, new_l_E

### PirateNet / cPIKAN Runs

In [None]:
# Baseline training (PirateNet, cPIKAN): three seeds per architecture, RBA-weighted
# residuals and RAD resampling every `f_resample` epochs.
# Results go to RESULTS[1][arch][run_index].
for arch in ['PirateNet', 'cPIKAN']:

    RESULTS[1][arch] = dict()

    n_hidden = arch_params[arch]['n_hidden']

    # `depth` is only used in the log line below
    if arch == 'PirateNet':
        num_blocks = arch_params[arch]['num_blocks']
        depth = int(3*num_blocks)
    elif arch == 'cPIKAN':
        num_layers = arch_params[arch]['num_layers']
        depth = num_layers

    print(f"Training {arch} with depth = {depth} and width = {n_hidden}.")

    for idx, run in enumerate([0, 7, 42]):

        RESULTS[1][arch][idx] = dict()

        # Initialize RBA weights - full pool
        l_E_pool = jnp.ones((collocs_pool.shape[0], 1))

        # Starting collocation batch & its RBA weights.
        # NOTE(review): uses `seed`, not `seed+run`, so presumably every run starts from the same batch.
        sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

        collocs = collocs_pool[sorted_indices]
        l_E = l_E_pool[sorted_indices]

        # Get opt_type
        opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

        # Define model (runs differ through seed+run)
        if arch == 'PirateNet':
            model = PoissonPirate(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_blocks = num_blocks,
                                  alpha = 0.0, ref = None, period_axes = period_axes, rff_std = 1.0,
                                  RWF={"mean": 1.0, "std": 0.1}, seed=seed+run)
        elif arch == 'cPIKAN':
            model = PoissonKAN(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_layers = num_layers, D = D,
                               init_scheme = init_scheme, period_axes = period_axes, rff_std = None,
                               seed = seed+run)

        if idx == 0:
            print(f"Initialized model with {count_params(model)} parameters.")

        # Set optimizer
        optimizer = nnx.Optimizer(model, opt_type)

        # Wall-clock training time. On the first run of a model structure this
        # likely also includes JIT compilation of the jitted steps.
        tick = time.time()

        # Start training
        for epoch in range(num_epochs):

            loss, l_E = train_step(model, optimizer, collocs, l_E)

            # Perform RAD (never at epoch 0)
            if (epoch != 0) and (epoch % f_resample == 0):

                # Get new indices after resampling
                sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                # Set new batch of collocs and l_E
                collocs = collocs_pool[sorted_indices]
                l_E = l_E_pool[sorted_indices]

        # JAX dispatches device work asynchronously. Wait for the queued steps to
        # finish so the timer measures the computation, not just its dispatch.
        jax.block_until_ready((loss, l_E))
        tack = time.time()
        final_error = model_eval(model, coords, refsol)

        print(f"\tRun = {idx}\t L^2 = {final_error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

        RESULTS[1][arch][idx]['l2'] = np.asarray(device_get(final_error))
        RESULTS[1][arch][idx]['loss'] = np.asarray(device_get(loss))
        # Seconds per epoch (shown in ms by the summary cell)
        RESULTS[1][arch][idx]['time'] = (tack-tick)/num_epochs

### RGA KAN Runs

In [None]:
# RGA KAN training: sweep alpha, beta in {0, 1}, three seeds each, with the same
# RBA + RAD loop as the baselines. The final prediction on `coords` is stored
# as 'output' so the best run can be plotted later.
num_blocks = arch_params['RGA KAN']['num_blocks']
n_hidden = arch_params['RGA KAN']['n_hidden']
depth = int(2*num_blocks)

print(f"Training RGA KAN with depth = {depth} and width = {n_hidden}.")

for alpha in [0.0, 1.0]:
    for beta in [0.0, 1.0]:

        # Float tuple keys; pick_best_rgakan's int keys such as (0, 1) match them since 0 == 0.0
        RESULTS[1][(alpha,beta)] = dict()
        print(f"Training for alpha = {alpha} and beta = {beta}.")

        for idx, run in enumerate([0, 7, 42]):

            RESULTS[1][(alpha,beta)][idx] = dict()

            # Initialize RBA weights - full pool
            l_E_pool = jnp.ones((collocs_pool.shape[0], 1))

            # Get starting collocation points & RBA weights
            sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

            collocs = collocs_pool[sorted_indices]
            l_E = l_E_pool[sorted_indices]

            # Get opt_type
            opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

            # Define model
            model = PoissonModel(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_blocks = num_blocks, D = D,
                                 init_scheme = init_scheme, alpha = alpha, beta = beta, ref = None,
                                 period_axes = period_axes, rff_std = None, sine_D = sine_D, seed = seed+run)

            if idx == 0:
                print(f"Initialized model with {count_params(model)} parameters.")

            # Set optimizer
            optimizer = nnx.Optimizer(model, opt_type)

            # Wall-clock training time (likely includes JIT compilation on the first call for a new model structure)
            tick = time.time()

            # Start training
            for epoch in range(num_epochs):

                loss, l_E = train_step(model, optimizer, collocs, l_E)

                # Perform RAD (never at epoch 0)
                if (epoch != 0) and (epoch % f_resample == 0):

                    # Get new indices after resampling
                    sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                    # Set new batch of collocs and l_E
                    collocs = collocs_pool[sorted_indices]
                    l_E = l_E_pool[sorted_indices]

            # Wait for the asynchronously dispatched steps before stopping the clock
            jax.block_until_ready((loss, l_E))
            tack = time.time()
            final_output = model(coords).reshape(refsol.shape)
            final_error = model_eval(model, coords, refsol)

            print(f"\tRun = {idx}\t L^2 = {final_error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

            RESULTS[1][(alpha,beta)][idx]['l2'] = np.asarray(device_get(final_error))
            RESULTS[1][(alpha,beta)][idx]['loss'] = np.asarray(device_get(loss))
            # Seconds per epoch (shown in ms by the summary cell)
            RESULTS[1][(alpha,beta)][idx]['time'] = (tack-tick)/num_epochs
            RESULTS[1][(alpha,beta)][idx]['output'] = np.asarray(device_get(final_output))

In [None]:
# Checkpoint 1: persist everything collected so far to the results pickle.
with open(result_file, mode="wb") as f:
    pickle.dump(RESULTS, f, protocol=pickle.DEFAULT_PROTOCOL)

In [None]:
# Reload the checkpoint so the summary below reads what was written to disk
# (a trusted local file produced by this notebook).
with open(result_file, mode="rb") as f:
    RESULTS = pickle.load(f)

this_res = RESULTS[1]  # omega = 1 results

In [None]:
# Summary table: mean and standard error over seeds of the L^2 error,
# per-epoch time (ms) and final loss.
_summary_rows = [
    ("PirateNet:\t\t", 'PirateNet'),
    ("cPIKAN:\t\t\t", 'cPIKAN'),
    ("RGAKAN (α = 0, β = 0):\t", (0,0)),
    ("RGAKAN (α = 1, β = 0):\t", (1,0)),
    ("RGAKAN (α = 0, β = 1):\t", (0,1)),
    ("RGAKAN (α = 1, β = 1):\t", (1,1)),
]
_rule = "------------------------------------------------------------------"

print("RESULTS")
print(_rule)
for _label, _key in _summary_rows:
    m, s = metric_stats(this_res, _key, 'l2')
    tt, _ = metric_stats(this_res, _key, 'time')
    print(f"{_label} L^2 = {m:.3e}\t Error = {s:.3e}\t Time = {tt*1000:.2f} ms.")
print(_rule)
for _label, _key in _summary_rows:
    m, s = metric_stats(this_res, _key, 'loss')
    print(f"{_label} Loss = {m:.3e}\t Error = {s:.3e}")

In [None]:
# Report the RGA KAN (alpha, beta) config with the lowest mean L^2 error and
# the best single run within it.
res_1 = pick_best_rgakan(this_res, metric='l2')
best_cfg_1 = res_1['best_config']
a, b = best_cfg_1
print(f"The lowest average L^2 error is obtained for alpha = {a} and beta = {b}.")
print(f"Among the runs with this configuration, the lowest L^2 error is {res_1['best_run_value']:.2e}")

## $\omega = 2$

In [None]:
from src.equations import poisson_2_res as pde_res

In [None]:
# Collocation pool over [-1, 1] x [-1, 1]; training batches are drawn from it
collocs_pool = _get_pde_collocs(ranges = [(-1,1), (-1,1)], sample_size = 400)

# Reference solution. The context manager closes the .npz archive once the
# arrays have been copied out, instead of leaving the file handle open.
with np.load('data/poisson_2.npz') as ref:
    refsol = jnp.array(ref['usol'])
    N_x, N_y = ref['usol'].shape
    x, y = ref['x'].flatten(), ref['y'].flatten()

# Evaluation grid with 'ij' indexing, so model(coords) can be reshaped to refsol.shape
X, Y = jnp.meshgrid(x, y, indexing='ij')
coords = jnp.hstack([X.flatten()[:, None], Y.flatten()[:, None]])

In [None]:
# PDE Loss
def pde_loss(model, l_E, collocs):
    """RBA-weighted PDE loss.

    Evaluates ``r = pde_res(model, collocs)`` and refreshes the residual-based
    attention weights as ``RBA_gamma * l_E + RBA_eta * |r| / max|r|``.
    Returns ``(mean((l_E_new * r) ** 2), l_E_new)``.

    ``pde_res`` and the ``RBA_*`` constants are module globals. Under ``nnx.jit``
    they are read when ``train_step`` below is traced; redefining both functions
    here makes the freshly imported ``pde_res`` take effect.

    NOTE(review): the weight update happens inside the function that
    ``nnx.value_and_grad`` differentiates and is not wrapped in
    ``jax.lax.stop_gradient``, so gradients also flow through ``l_E_new``.
    Confirm this is intended.
    """
    res = pde_res(model, collocs)  # presumably shape (batch_size, 1) — TODO confirm
    mag = jnp.abs(res)
    weights = (RBA_gamma*l_E) + (RBA_eta*mag/jnp.max(mag))
    loss_value = jnp.mean((weights * res)**2)
    return loss_value, weights


@nnx.jit
def train_step(model, optimizer, collocs, l_E):
    """Run one optimiser step on ``pde_loss``; return the loss and refreshed RBA weights."""
    loss_and_grad = nnx.value_and_grad(pde_loss, has_aux=True)
    (loss, new_l_E), grads = loss_and_grad(model, l_E, collocs)
    optimizer.update(grads)  # apply gradients to the model wrapped by the optimizer
    return loss, new_l_E

### PirateNet / cPIKAN Runs

In [None]:
# Baseline training (PirateNet, cPIKAN): three seeds per architecture, RBA-weighted
# residuals and RAD resampling every `f_resample` epochs.
# Results go to RESULTS[2][arch][run_index].
for arch in ['PirateNet', 'cPIKAN']:

    RESULTS[2][arch] = dict()

    n_hidden = arch_params[arch]['n_hidden']

    # `depth` is only used in the log line below
    if arch == 'PirateNet':
        num_blocks = arch_params[arch]['num_blocks']
        depth = int(3*num_blocks)
    elif arch == 'cPIKAN':
        num_layers = arch_params[arch]['num_layers']
        depth = num_layers

    print(f"Training {arch} with depth = {depth} and width = {n_hidden}.")

    for idx, run in enumerate([0, 7, 42]):

        RESULTS[2][arch][idx] = dict()

        # Initialize RBA weights - full pool
        l_E_pool = jnp.ones((collocs_pool.shape[0], 1))

        # Starting collocation batch & its RBA weights.
        # NOTE(review): uses `seed`, not `seed+run`, so presumably every run starts from the same batch.
        sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

        collocs = collocs_pool[sorted_indices]
        l_E = l_E_pool[sorted_indices]

        # Get opt_type
        opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

        # Define model (runs differ through seed+run)
        if arch == 'PirateNet':
            model = PoissonPirate(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_blocks = num_blocks,
                                  alpha = 0.0, ref = None, period_axes = period_axes, rff_std = 1.0,
                                  RWF={"mean": 1.0, "std": 0.1}, seed=seed+run)
        elif arch == 'cPIKAN':
            model = PoissonKAN(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_layers = num_layers, D = D,
                               init_scheme = init_scheme, period_axes = period_axes, rff_std = None,
                               seed = seed+run)

        if idx == 0:
            print(f"Initialized model with {count_params(model)} parameters.")

        # Set optimizer
        optimizer = nnx.Optimizer(model, opt_type)

        # Wall-clock training time. On the first run of a model structure this
        # likely also includes JIT compilation of the jitted steps.
        tick = time.time()

        # Start training
        for epoch in range(num_epochs):

            loss, l_E = train_step(model, optimizer, collocs, l_E)

            # Perform RAD (never at epoch 0)
            if (epoch != 0) and (epoch % f_resample == 0):

                # Get new indices after resampling
                sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                # Set new batch of collocs and l_E
                collocs = collocs_pool[sorted_indices]
                l_E = l_E_pool[sorted_indices]

        # JAX dispatches device work asynchronously. Wait for the queued steps to
        # finish so the timer measures the computation, not just its dispatch.
        jax.block_until_ready((loss, l_E))
        tack = time.time()
        final_error = model_eval(model, coords, refsol)

        print(f"\tRun = {idx}\t L^2 = {final_error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

        RESULTS[2][arch][idx]['l2'] = np.asarray(device_get(final_error))
        RESULTS[2][arch][idx]['loss'] = np.asarray(device_get(loss))
        # Seconds per epoch (shown in ms by the summary cell)
        RESULTS[2][arch][idx]['time'] = (tack-tick)/num_epochs

### RGA KAN Runs

In [None]:
# RGA KAN training: sweep alpha, beta in {0, 1}, three seeds each, with the same
# RBA + RAD loop as the baselines. The final prediction on `coords` is stored
# as 'output' so the best run can be plotted later.
num_blocks = arch_params['RGA KAN']['num_blocks']
n_hidden = arch_params['RGA KAN']['n_hidden']
depth = int(2*num_blocks)

print(f"Training RGA KAN with depth = {depth} and width = {n_hidden}.")

for alpha in [0.0, 1.0]:
    for beta in [0.0, 1.0]:

        # Float tuple keys; pick_best_rgakan's int keys such as (0, 1) match them since 0 == 0.0
        RESULTS[2][(alpha,beta)] = dict()
        print(f"Training for alpha = {alpha} and beta = {beta}.")

        for idx, run in enumerate([0, 7, 42]):

            RESULTS[2][(alpha,beta)][idx] = dict()

            # Initialize RBA weights - full pool
            l_E_pool = jnp.ones((collocs_pool.shape[0], 1))

            # Get starting collocation points & RBA weights
            sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

            collocs = collocs_pool[sorted_indices]
            l_E = l_E_pool[sorted_indices]

            # Get opt_type
            opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

            # Define model
            model = PoissonModel(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_blocks = num_blocks, D = D,
                                 init_scheme = init_scheme, alpha = alpha, beta = beta, ref = None,
                                 period_axes = period_axes, rff_std = None, sine_D = sine_D, seed = seed+run)

            if idx == 0:
                print(f"Initialized model with {count_params(model)} parameters.")

            # Set optimizer
            optimizer = nnx.Optimizer(model, opt_type)

            # Wall-clock training time (likely includes JIT compilation on the first call for a new model structure)
            tick = time.time()

            # Start training
            for epoch in range(num_epochs):

                loss, l_E = train_step(model, optimizer, collocs, l_E)

                # Perform RAD (never at epoch 0)
                if (epoch != 0) and (epoch % f_resample == 0):

                    # Get new indices after resampling
                    sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                    # Set new batch of collocs and l_E
                    collocs = collocs_pool[sorted_indices]
                    l_E = l_E_pool[sorted_indices]

            # Wait for the asynchronously dispatched steps before stopping the clock
            jax.block_until_ready((loss, l_E))
            tack = time.time()
            final_output = model(coords).reshape(refsol.shape)
            final_error = model_eval(model, coords, refsol)

            print(f"\tRun = {idx}\t L^2 = {final_error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

            RESULTS[2][(alpha,beta)][idx]['l2'] = np.asarray(device_get(final_error))
            RESULTS[2][(alpha,beta)][idx]['loss'] = np.asarray(device_get(loss))
            # Seconds per epoch (shown in ms by the summary cell)
            RESULTS[2][(alpha,beta)][idx]['time'] = (tack-tick)/num_epochs
            RESULTS[2][(alpha,beta)][idx]['output'] = np.asarray(device_get(final_output))

In [None]:
# Checkpoint 2: persist everything collected so far to the results pickle.
with open(result_file, mode="wb") as f:
    pickle.dump(RESULTS, f, protocol=pickle.DEFAULT_PROTOCOL)

In [None]:
# Reload the checkpoint so the summary below reads what was written to disk
# (a trusted local file produced by this notebook).
with open(result_file, mode="rb") as f:
    RESULTS = pickle.load(f)

this_res = RESULTS[2]  # omega = 2 results

In [None]:
# Summary table: mean and standard error over seeds of the L^2 error,
# per-epoch time (ms) and final loss.
_summary_rows = [
    ("PirateNet:\t\t", 'PirateNet'),
    ("cPIKAN:\t\t\t", 'cPIKAN'),
    ("RGAKAN (α = 0, β = 0):\t", (0,0)),
    ("RGAKAN (α = 1, β = 0):\t", (1,0)),
    ("RGAKAN (α = 0, β = 1):\t", (0,1)),
    ("RGAKAN (α = 1, β = 1):\t", (1,1)),
]
_rule = "------------------------------------------------------------------"

print("RESULTS")
print(_rule)
for _label, _key in _summary_rows:
    m, s = metric_stats(this_res, _key, 'l2')
    tt, _ = metric_stats(this_res, _key, 'time')
    print(f"{_label} L^2 = {m:.3e}\t Error = {s:.3e}\t Time = {tt*1000:.2f} ms.")
print(_rule)
for _label, _key in _summary_rows:
    m, s = metric_stats(this_res, _key, 'loss')
    print(f"{_label} Loss = {m:.3e}\t Error = {s:.3e}")

In [None]:
# Report the RGA KAN (alpha, beta) config with the lowest mean L^2 error and
# the best single run within it.
res_2 = pick_best_rgakan(this_res, metric='l2')
best_cfg_2 = res_2['best_config']
a, b = best_cfg_2
print(f"The lowest average L^2 error is obtained for alpha = {a} and beta = {b}.")
print(f"Among the runs with this configuration, the lowest L^2 error is {res_2['best_run_value']:.2e}")

## $\omega = 4$

In [None]:
from src.equations import poisson_4_res as pde_res

In [None]:
# Collocation pool over [-1, 1] x [-1, 1]; training batches are drawn from it
collocs_pool = _get_pde_collocs(ranges = [(-1,1), (-1,1)], sample_size = 400)

# Reference solution. The context manager closes the .npz archive once the
# arrays have been copied out, instead of leaving the file handle open.
with np.load('data/poisson_4.npz') as ref:
    refsol = jnp.array(ref['usol'])
    N_x, N_y = ref['usol'].shape
    x, y = ref['x'].flatten(), ref['y'].flatten()

# Evaluation grid with 'ij' indexing, so model(coords) can be reshaped to refsol.shape
X, Y = jnp.meshgrid(x, y, indexing='ij')
coords = jnp.hstack([X.flatten()[:, None], Y.flatten()[:, None]])

In [None]:
# PDE Loss
def pde_loss(model, l_E, collocs):
    """RBA-weighted PDE loss.

    Evaluates ``r = pde_res(model, collocs)`` and refreshes the residual-based
    attention weights as ``RBA_gamma * l_E + RBA_eta * |r| / max|r|``.
    Returns ``(mean((l_E_new * r) ** 2), l_E_new)``.

    ``pde_res`` and the ``RBA_*`` constants are module globals. Under ``nnx.jit``
    they are read when ``train_step`` below is traced; redefining both functions
    here makes the freshly imported ``pde_res`` take effect.

    NOTE(review): the weight update happens inside the function that
    ``nnx.value_and_grad`` differentiates and is not wrapped in
    ``jax.lax.stop_gradient``, so gradients also flow through ``l_E_new``.
    Confirm this is intended.
    """
    res = pde_res(model, collocs)  # presumably shape (batch_size, 1) — TODO confirm
    mag = jnp.abs(res)
    weights = (RBA_gamma*l_E) + (RBA_eta*mag/jnp.max(mag))
    loss_value = jnp.mean((weights * res)**2)
    return loss_value, weights


@nnx.jit
def train_step(model, optimizer, collocs, l_E):
    """Run one optimiser step on ``pde_loss``; return the loss and refreshed RBA weights."""
    loss_and_grad = nnx.value_and_grad(pde_loss, has_aux=True)
    (loss, new_l_E), grads = loss_and_grad(model, l_E, collocs)
    optimizer.update(grads)  # apply gradients to the model wrapped by the optimizer
    return loss, new_l_E

### PirateNet / cPIKAN Runs

In [None]:
# Baseline training (PirateNet, cPIKAN): three seeds per architecture, RBA-weighted
# residuals and RAD resampling every `f_resample` epochs.
# Results go to RESULTS[4][arch][run_index].
for arch in ['PirateNet', 'cPIKAN']:

    RESULTS[4][arch] = dict()

    n_hidden = arch_params[arch]['n_hidden']

    # `depth` is only used in the log line below
    if arch == 'PirateNet':
        num_blocks = arch_params[arch]['num_blocks']
        depth = int(3*num_blocks)
    elif arch == 'cPIKAN':
        num_layers = arch_params[arch]['num_layers']
        depth = num_layers

    print(f"Training {arch} with depth = {depth} and width = {n_hidden}.")

    for idx, run in enumerate([0, 7, 42]):

        RESULTS[4][arch][idx] = dict()

        # Initialize RBA weights - full pool
        l_E_pool = jnp.ones((collocs_pool.shape[0], 1))

        # Starting collocation batch & its RBA weights.
        # NOTE(review): uses `seed`, not `seed+run`, so presumably every run starts from the same batch.
        sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

        collocs = collocs_pool[sorted_indices]
        l_E = l_E_pool[sorted_indices]

        # Get opt_type
        opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

        # Define model (runs differ through seed+run)
        if arch == 'PirateNet':
            model = PoissonPirate(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_blocks = num_blocks,
                                  alpha = 0.0, ref = None, period_axes = period_axes, rff_std = 1.0,
                                  RWF={"mean": 1.0, "std": 0.1}, seed=seed+run)
        elif arch == 'cPIKAN':
            model = PoissonKAN(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_layers = num_layers, D = D,
                               init_scheme = init_scheme, period_axes = period_axes, rff_std = None,
                               seed = seed+run)

        if idx == 0:
            print(f"Initialized model with {count_params(model)} parameters.")

        # Set optimizer
        optimizer = nnx.Optimizer(model, opt_type)

        # Wall-clock training time. On the first run of a model structure this
        # likely also includes JIT compilation of the jitted steps.
        tick = time.time()

        # Start training
        for epoch in range(num_epochs):

            loss, l_E = train_step(model, optimizer, collocs, l_E)

            # Perform RAD (never at epoch 0)
            if (epoch != 0) and (epoch % f_resample == 0):

                # Get new indices after resampling
                sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                # Set new batch of collocs and l_E
                collocs = collocs_pool[sorted_indices]
                l_E = l_E_pool[sorted_indices]

        # JAX dispatches device work asynchronously. Wait for the queued steps to
        # finish so the timer measures the computation, not just its dispatch.
        jax.block_until_ready((loss, l_E))
        tack = time.time()
        final_error = model_eval(model, coords, refsol)

        print(f"\tRun = {idx}\t L^2 = {final_error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

        RESULTS[4][arch][idx]['l2'] = np.asarray(device_get(final_error))
        RESULTS[4][arch][idx]['loss'] = np.asarray(device_get(loss))
        # Seconds per epoch (shown in ms by the summary cell)
        RESULTS[4][arch][idx]['time'] = (tack-tick)/num_epochs

### RGA KAN Runs

In [None]:
# RGA KAN training: sweep alpha, beta in {0, 1}, three seeds each, with the same
# RBA + RAD loop as the baselines. The final prediction on `coords` is stored
# as 'output' so the best run can be plotted later.
num_blocks = arch_params['RGA KAN']['num_blocks']
n_hidden = arch_params['RGA KAN']['n_hidden']
depth = int(2*num_blocks)

print(f"Training RGA KAN with depth = {depth} and width = {n_hidden}.")

for alpha in [0.0, 1.0]:
    for beta in [0.0, 1.0]:

        # Float tuple keys; pick_best_rgakan's int keys such as (0, 1) match them since 0 == 0.0
        RESULTS[4][(alpha,beta)] = dict()
        print(f"Training for alpha = {alpha} and beta = {beta}.")

        for idx, run in enumerate([0, 7, 42]):

            RESULTS[4][(alpha,beta)][idx] = dict()

            # Initialize RBA weights - full pool
            l_E_pool = jnp.ones((collocs_pool.shape[0], 1))

            # Get starting collocation points & RBA weights
            sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

            collocs = collocs_pool[sorted_indices]
            l_E = l_E_pool[sorted_indices]

            # Get opt_type
            opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

            # Define model
            model = PoissonModel(n_in = n_in, n_out = n_out, n_hidden = n_hidden, num_blocks = num_blocks, D = D,
                                 init_scheme = init_scheme, alpha = alpha, beta = beta, ref = None,
                                 period_axes = period_axes, rff_std = None, sine_D = sine_D, seed = seed+run)

            if idx == 0:
                print(f"Initialized model with {count_params(model)} parameters.")

            # Set optimizer
            optimizer = nnx.Optimizer(model, opt_type)

            # Wall-clock training time (likely includes JIT compilation on the first call for a new model structure)
            tick = time.time()

            # Start training
            for epoch in range(num_epochs):

                loss, l_E = train_step(model, optimizer, collocs, l_E)

                # Perform RAD (never at epoch 0)
                if (epoch != 0) and (epoch % f_resample == 0):

                    # Get new indices after resampling
                    sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                    # Set new batch of collocs and l_E
                    collocs = collocs_pool[sorted_indices]
                    l_E = l_E_pool[sorted_indices]

            # Wait for the asynchronously dispatched steps before stopping the clock
            jax.block_until_ready((loss, l_E))
            tack = time.time()
            final_output = model(coords).reshape(refsol.shape)
            final_error = model_eval(model, coords, refsol)

            print(f"\tRun = {idx}\t L^2 = {final_error:.2e}\t Loss = {loss:.2e}\t Time = {(tack-tick)/60:.2f} mins")

            RESULTS[4][(alpha,beta)][idx]['l2'] = np.asarray(device_get(final_error))
            RESULTS[4][(alpha,beta)][idx]['loss'] = np.asarray(device_get(loss))
            # Seconds per epoch (shown in ms by the summary cell)
            RESULTS[4][(alpha,beta)][idx]['time'] = (tack-tick)/num_epochs
            RESULTS[4][(alpha,beta)][idx]['output'] = np.asarray(device_get(final_output))

In [None]:
# Checkpoint 3: persist everything collected so far to the results pickle.
with open(result_file, mode="wb") as f:
    pickle.dump(RESULTS, f, protocol=pickle.DEFAULT_PROTOCOL)

In [None]:
# Reload the checkpoint so the summary below reads what was written to disk
# (a trusted local file produced by this notebook).
with open(result_file, mode="rb") as f:
    RESULTS = pickle.load(f)

this_res = RESULTS[4]  # omega = 4 results

In [None]:
# Summary table: mean and standard error over seeds of the L^2 error,
# per-epoch time (ms) and final loss.
_summary_rows = [
    ("PirateNet:\t\t", 'PirateNet'),
    ("cPIKAN:\t\t\t", 'cPIKAN'),
    ("RGAKAN (α = 0, β = 0):\t", (0,0)),
    ("RGAKAN (α = 1, β = 0):\t", (1,0)),
    ("RGAKAN (α = 0, β = 1):\t", (0,1)),
    ("RGAKAN (α = 1, β = 1):\t", (1,1)),
]
_rule = "------------------------------------------------------------------"

print("RESULTS")
print(_rule)
for _label, _key in _summary_rows:
    m, s = metric_stats(this_res, _key, 'l2')
    tt, _ = metric_stats(this_res, _key, 'time')
    print(f"{_label} L^2 = {m:.3e}\t Error = {s:.3e}\t Time = {tt*1000:.2f} ms.")
print(_rule)
for _label, _key in _summary_rows:
    m, s = metric_stats(this_res, _key, 'loss')
    print(f"{_label} Loss = {m:.3e}\t Error = {s:.3e}")

In [None]:
# Report the RGA KAN (alpha, beta) config with the lowest mean L^2 error and
# the best single run within it.
res_4 = pick_best_rgakan(this_res, metric='l2')
best_cfg_4 = res_4['best_config']
a, b = best_cfg_4
print(f"The lowest average L^2 error is obtained for alpha = {a} and beta = {b}.")
print(f"Among the runs with this configuration, the lowest L^2 error is {res_4['best_run_value']:.2e}")

## Final Plot

In [None]:
# Reference solutions for omega = 1, 2 and 4. `ref` stays bound to the last
# (omega = 4) archive, since later cells may still read from it.
ref = np.load('data/poisson_1.npz')
refsol1 = jnp.array(ref['usol'])
ref = np.load('data/poisson_2.npz')
refsol2 = jnp.array(ref['usol'])
ref = np.load('data/poisson_4.npz')
refsol4 = jnp.array(ref['usol'])

# Predictions of the best RGA KAN run for each omega ('output' stored during training)
pred1, pred2, pred4 = (res['best_run_output'] for res in (res_1, res_2, res_4))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import ticker

# Matplotlib font sizes for the final plot: axis labels, panel titles,
# tick labels, and colorbar labels.
LABEL_FS, TITLE_FS, TICK_FS, CBAR_FS = 16, 18, 14, 16

def plot_pred_ref_diff_grid(u_preds, refsols, x, t, save_fig=False,
                            clim_pred=None, clim_ref=None, clim_diff=None,
                            row_titles=None, fig_name="final_poisson.pdf",
                            field_ticks=(-1, 0, 1)):
    """Plot reference, prediction and absolute error side by side, one row per case.

    Row ``i`` shows ``refsols[i]``, ``u_preds[i]`` and
    ``|u_preds[i] - refsols[i]|`` as heatmaps, each with its own vertical
    colorbar on the right.

    Parameters
    ----------
    u_preds, refsols : sequence of array-like
        Predicted and reference fields, one entry per row; both sequences must
        have the same length. Each array is transposed before ``imshow`` with
        ``origin='lower'``, so its first axis runs horizontally. Presumably
        2-D grids indexed as ``[x, y]`` — TODO confirm against the data files.
    x, t : array-like
        Coordinates; only their min/max are used, to set the image extent
        (``x`` horizontally, ``t`` vertically — the vertical axis is labelled
        ``y``).
    save_fig : bool
        If True, save the figure (PDF format) to ``{plots_dir}/{fig_name}``.
    clim_pred, clim_ref, clim_diff : None, (vmin, vmax), or per-row sequence
        Color limits for the prediction, reference and error columns. A single
        ``(vmin, vmax)`` pair applies to every row, a sequence of ``nrows``
        pairs applies row by row, and ``None`` (or any unrecognised value)
        keeps matplotlib's autoscaling.
    row_titles : sequence of str, optional
        Label drawn vertically to the right of each row's error panel.
    fig_name : str
        File name used when ``save_fig`` is True.
    field_ticks : sequence of float or None
        Colorbar ticks for the reference and prediction columns. The default
        ``(-1, 0, 1)`` assumes the fields span roughly [-1, 1]; pass ``None``
        to keep matplotlib's automatic ticks.
    """
    # normalize inputs (e.g. jax arrays) to NumPy arrays
    u_preds = [np.array(u) for u in u_preds]
    refsols = [np.array(r) for r in refsols]
    assert len(u_preds) == len(refsols), "u_preds and refsols must have same length"
    nrows = len(u_preds)

    extent = [np.min(x), np.max(x), np.min(t), np.max(t)]
    cmap = sns.color_palette("Spectral", as_cmap=True)

    def _clim_for(clim, i):
        """Resolve the (vmin, vmax) pair for row ``i``, or None to autoscale."""
        if clim is None:
            return None
        # per-row list/tuple of (vmin, vmax) pairs
        if isinstance(clim, (list, tuple)) and len(clim) == nrows and isinstance(clim[0], (list, tuple)):
            return clim[i]
        # single (vmin, vmax) pair applied to all rows
        if isinstance(clim, (list, tuple)) and len(clim) == 2 and np.isscalar(clim[0]) and np.isscalar(clim[1]):
            return clim
        # unrecognised specification: fall back to autoscaling
        return None

    fig, axs = plt.subplots(nrows, 3, figsize=(14, 2.5*nrows), constrained_layout=False)
    if nrows == 1:
        axs = np.expand_dims(axs, axis=0)  # unify indexing as axs[i, j]

    col_titles = ["Reference", "Prediction", "Absolute Error"]
    cbar_labels = [r"$u_{\mathrm{ref}}$", r"$u_{\mathrm{pred}}$", r"$|u_{\mathrm{pred}}-u_{\mathrm{ref}}|$"]
    # color-limit specs in the same column order as col_titles
    col_clims = [clim_ref, clim_pred, clim_diff]

    for i in range(nrows):
        ref = refsols[i]
        pred = u_preds[i]
        diff = np.abs(pred - ref)

        panels = [ref.T, pred.T, diff.T]
        for j in range(3):
            ax = axs[i, j]
            img = ax.imshow(panels[j], origin='lower', aspect='auto', extent=extent, cmap=cmap)

            # column titles only on top row
            if i == 0:
                ax.set_title(col_titles[j], fontsize=TITLE_FS)

            # rotated row label placed to the right of the error panel
            if j == 2 and row_titles is not None:
                ax.annotate(
                    row_titles[i],
                    xy=(1.70, 0.5), xycoords='axes fraction',
                    ha='left', va='center',
                    rotation=90,
                    fontsize=TITLE_FS
                )

            # x label only on the bottom row, y label only on the first column
            if i == nrows - 1:
                ax.set_xlabel(r"$x$", fontsize=LABEL_FS)
            else:
                ax.set_xlabel("")
            if j == 0:
                ax.set_ylabel(r"$y$", fontsize=LABEL_FS)
            else:
                ax.set_ylabel("")

            ax.tick_params(axis='both', which='major', labelsize=TICK_FS)

            clim = _clim_for(col_clims[j], i)
            if clim is not None:
                img.set_clim(*clim)

            # colorbar to the right of each subplot
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.15)
            cbar = fig.colorbar(img, cax=cax, orientation='vertical')

            if j == 2:
                # error column: scientific notation, with the x10^k multiplier
                # drawn as the axis offset text
                cbar.formatter = ticker.ScalarFormatter(useMathText=True)
                cbar.formatter.set_scientific(True)
                cbar.formatter.set_powerlimits((0, 0))
                cbar.formatter.set_useOffset(False)
                cbar.update_ticks()
                offset_text = cbar.ax.yaxis.get_offset_text()
                offset_text.set_fontsize(TICK_FS)
                offset_text.set_x(7.5)  # hand-tuned horizontal position of the offset text

                cbar.ax.tick_params(labelsize=TICK_FS)
                cbar.set_label(cbar_labels[j], fontsize=CBAR_FS, labelpad=10)
            else:
                # reference / prediction columns: fixed ticks unless disabled
                if field_ticks is not None:
                    cbar.set_ticks(list(field_ticks))

                cbar.ax.tick_params(labelsize=TICK_FS)
                cbar.set_label(cbar_labels[j], fontsize=CBAR_FS, labelpad=6)

    plt.subplots_adjust(left=0.12, wspace=0.65, hspace=0.3, bottom=0.0)

    if save_fig:
        # use the caller-supplied file name (previously ignored)
        plt.savefig(f"{plots_dir}/{fig_name}", format="pdf", bbox_inches="tight")

    plt.show()


In [None]:
# Final figure: one row per omega (reference | prediction | absolute error),
# saved as a PDF under plots_dir. The helper names its vertical coordinate `t`,
# which here is the Poisson domain's y grid.
plot_pred_ref_diff_grid(
    refsols=[refsol1, refsol2, refsol4],
    u_preds=[pred1, pred2, pred4],
    x=x,
    t=y,
    row_titles=[r"$\omega=1$", r"$\omega=2$", r"$\omega=4$"],
    clim_ref=None,
    clim_pred=None,
    clim_diff=None,
    save_fig=True,
)
