# Experiment 4: Deep PDE Solving with RGA KAN

## Allen-Cahn Equation

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.equations import ac_res as pde_res

from src.utils import _get_adam, _get_pde_collocs, _get_ic_collocs, model_eval, count_params, _get_colloc_indices, grad_norm
from src.rgakan import RGAKAN

import numpy as np

import optax
from flax import nnx

import os

# Output locations for CSV results and PDF plots; created on first run.
results_dir = "results"
plots_dir = "plots"
for _out_dir in (results_dir, plots_dir):
    os.makedirs(_out_dir, exist_ok=True)

# The grid search appends one row per (width, depth, run) to this CSV.
experiment_name = "rgakan_prelim"
results_file = os.path.join(results_dir, f"{experiment_name}.csv")
header = "pde, width, depth, run, l2"

# Only a brand-new file gets the header line; existing results are kept as-is.
if not os.path.exists(results_file):
    with open(results_file, "w") as _fh:
        _fh.write(f"{header}\n")

# Base seed: used for collocation sampling; model init uses seed + run.
seed = 42

### Basic Functions

In [None]:
# PDE Loss
def pde_loss(model, l_E, collocs):
    """Causally weighted, RBA-weighted mean-squared PDE residual.

    Args:
        model: network evaluated by ``pde_res`` at ``collocs``.
        l_E: current RBA weights of the batch, shape (batch_size, 1).
        collocs: PDE collocation points of the current batch.

    Returns:
        ``(weighted_loss, l_E_new)``: the scalar loss and the updated RBA
        weights (returned as the auxiliary output for ``value_and_grad``).
    """
    res = pde_res(model, collocs)  # shape (batch_size, 1)

    # RBA update: decayed previous weights + |residual| normalised by its max.
    mag = jnp.abs(res)
    l_E_new = (RBA_gamma * l_E) + (RBA_eta * mag / jnp.max(mag))
    # NOTE(review): l_E_new is not stop_gradient'ed, so gradients also flow
    # through the weight update into the model; confirm this is intended.

    # Split the weighted residuals into num_chunks consecutive groups and take
    # the MSE of each. Assumes collocs are ordered in time so that chunks are
    # time windows — TODO confirm against _get_colloc_indices.
    chunked = (l_E_new * res).reshape(num_chunks, -1)  # (num_chunks, points)
    chunk_mse = jnp.mean(chunked**2, axis=1)

    # Causal weights exp(-tol * M @ chunk_mse): M @ chunk_mse accumulates the
    # losses of earlier chunks. The weights are treated as constants.
    causal_w = jax.lax.stop_gradient(jnp.exp(-causal_tol * (M @ chunk_mse)))

    return jnp.mean(causal_w * chunk_mse), l_E_new


# IC Loss
def ic_loss(model, l_I, ic_collocs, ic_data):
    """RBA-weighted mean-squared initial-condition misfit.

    Args:
        model: network evaluated at ``ic_collocs``.
        l_I: current RBA weights of the IC points (same shape as the misfit).
        ic_collocs: initial-condition collocation points.
        ic_data: target values of the solution at ``ic_collocs``.

    Returns:
        ``(loss, l_I_new)``: the scalar loss and the updated RBA weights.
    """
    misfit = model(ic_collocs) - ic_data
    mag = jnp.abs(misfit)

    # Same RBA rule as the PDE term: decay + max-normalised |misfit|.
    l_I_new = (RBA_gamma * l_I) + (RBA_eta * mag / jnp.max(mag))

    weighted = l_I_new * misfit
    return jnp.mean(weighted**2), l_I_new


@nnx.jit
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I):
    """Run one optimizer step on ``λ_E * L_pde + λ_I * L_ic``.

    The two terms are differentiated separately so that their raw gradients
    can be returned to the caller (used for grad-norm rebalancing); the step
    itself uses the weighted sum ``λ_E * grads_E + λ_I * grads_I``.

    Returns:
        ``(loss, grads_E, grads_I, l_E_new, l_I_new)``
    """
    pde_value_and_grad = nnx.value_and_grad(pde_loss, has_aux=True)
    ic_value_and_grad = nnx.value_and_grad(ic_loss, has_aux=True)

    (loss_E, l_E_new), grads_E = pde_value_and_grad(model, l_E, collocs)
    (loss_I, l_I_new), grads_I = ic_value_and_grad(model, l_I, ic_collocs, ic_data)

    total = λ_E * loss_E + λ_I * loss_I

    def _mix(g_pde, g_ic):
        return λ_E * g_pde + λ_I * g_ic

    optimizer.update(jax.tree_util.tree_map(_mix, grads_E, grads_I))

    return total, grads_E, grads_I, l_E_new, l_I_new


@nnx.jit
def get_RAD_indices(model, collocs_pool, old_indices, l_E, l_E_pool):
    """Residual-based adaptive resampling (RAD) of the collocation batch.

    The current batch's RBA weights ``l_E`` are first written back into the
    pool at ``old_indices``. A new batch is then drawn with probability
    proportional to ``|λ r|**rad_a / mean(|λ r|**rad_a) + rad_c`` over the
    whole pool, where ``λ`` are the pool's RBA weights and ``r`` the residuals.

    Returns:
        ``(sorted_indices, updated_pool)``: indices from ``_get_colloc_indices``
        and the pool of RBA weights including the written-back batch.
    """
    pool_weights = l_E_pool.at[old_indices].set(l_E)

    score = jnp.power(jnp.abs(pool_weights * pde_res(model, collocs_pool)), rad_a)
    density = (score / jnp.mean(score)) + rad_c
    probs = (density / jnp.sum(density))[:, 0]  # flatten the trailing unit axis

    # NOTE(review): `seed` is the same module-level constant on every call;
    # whether successive draws are correlated depends on _get_colloc_indices.
    new_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=probs, seed=seed)

    return new_indices, pool_weights

### Data & Grid-Search Parameters

In [None]:
# Collocation pool for the PDE residual and points for the initial condition.
# ranges presumably ordered (t, x): t in [0, 1], x in [-1, 1].
collocs_pool = _get_pde_collocs(ranges=[(0, 1), (-1, 1)], sample_size=400)
ic_collocs = _get_ic_collocs(x_range=(-1, 1), sample_size=64)

# Initial condition u(0, x) = x^2 cos(pi x), evaluated on column 1 of ic_collocs.
_x0 = ic_collocs[:, 1]
ic_data = jnp.reshape(_x0**2 * jnp.cos(jnp.pi * _x0), (-1, 1))

# Reference solution and the flattened (t, x) grid it is evaluated on.
ref = np.load('data/allen_cahn.npz')
refsol = jnp.array(ref['usol'])
N_t, N_x = ref['usol'].shape

t = ref['t'].flatten()
x = ref['x'].flatten()
T, X = jnp.meshgrid(t, x, indexing='ij')
coords = jnp.hstack([T.reshape(-1, 1), X.reshape(-1, 1)])

In [None]:
# ---- Training length and learning-rate schedule (consumed by _get_adam) ----
num_epochs = 100000
learning_rate = 1e-3
decay_steps = 2000
decay_rate = 0.9
warmup_steps = 1000

# ---- Causal training (used by pde_loss) ------------------------------------
causal_tol = 1.0  # scale inside exp(-causal_tol * accumulated loss)
num_chunks = 32  # the batch residuals are reshaped into this many chunks
# Strictly lower-triangular matrix of ones: (M @ chunk_losses)[i] sums the
# losses of chunks 0..i-1.
M = jnp.triu(jnp.ones((num_chunks, num_chunks)), k=1).T

# ---- Grad-norm rebalancing of the global loss weights ----------------------
grad_mixing = 0.9  # passed to grad_norm
f_grad_norm = 1000  # rebalance every f_grad_norm epochs

# ---- Residual-based adaptive resampling (RAD) -------------------------------
batch_size = 4096  # 2**12 collocation points per batch
f_resample = 2000  # resample every f_resample epochs
rad_a = 1.0  # exponent on |weighted residual|
rad_c = 1.0  # constant added to the normalised sampling density

# ---- Residual-based attention (RBA) weights ---------------------------------
RBA_gamma = 0.999  # multiplier on the previous weights
RBA_eta = 0.01  # multiplier on the max-normalised |residual|

# ---- RGA KAN model ----------------------------------------------------------
n_in = collocs_pool.shape[1]
n_out = 1
D = 5
period_axes = {1: jnp.pi}  # presumably {input column: period parameter}; x is column 1
rff_std = None  # alternative setting: 2.0
sine_D = 5  # alternative setting: None
alpha = 0.0
beta = 0.0
init_scheme = {
    'type': 'glorot',
    'gain': None,
    'norm_pow': 0,
    'distribution': 'uniform',
    'sample_size': 10000,
}

In [None]:
# Architecture grid: hidden widths and numbers of residual blocks
# (the reported depth is 2 * num_blocks).
widths = [8, 16, 32]
block_depths = list(range(1, 7))

### Grid Search

In [None]:
 # Grid search over depth size
for num_blocks in block_depths:

    # Grid search over width size
    for n_hidden in widths:

        # The reported depth is 2 * num_blocks.
        print(f"Training RGA-KAN with depth = {int(2*num_blocks)} and width = {n_hidden}.")

        for run in [3, 5, 6, 7, 13]:
            # Fresh RBA weights: one per point of the full PDE pool, one per IC point.
            l_E_pool = jnp.ones((collocs_pool.shape[0], 1))
            l_I = jnp.ones((ic_collocs.shape[0], 1))

            # Initial collocation batch (px=None) and its slice of the RBA weights.
            # NOTE(review): this uses `seed`, not `seed + run`; presumably every
            # run starts from the same batch — confirm this is intended.
            sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

            collocs = collocs_pool[sorted_indices]
            l_E = l_E_pool[sorted_indices]

            # Optimizer transformation with the configured schedule
            opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

            # Per-run variation comes from the model seed (seed + run).
            model = RGAKAN(n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
                           init_scheme=init_scheme, alpha=alpha, beta=beta, ref=ref, period_axes=period_axes, rff_std=rff_std,
                           sine_D=sine_D, seed=seed+run)

            print(f"Initialized model with {count_params(model)} parameters.")

            # Global weights of the PDE and IC loss terms; rebalanced below.
            λ_E = jnp.array(1.0, dtype=float)
            λ_I = jnp.array(1.0, dtype=float)

            optimizer = nnx.Optimizer(model, opt_type)

            for epoch in range(num_epochs):

                loss, grads_E, grads_I, l_E, l_I = train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I)

                # Rebalance the global weights from the latest raw gradients.
                # Both targets must be spelled exactly like the names above: a
                # look-alike character (e.g. Greek "Ε" for Latin "E") would bind
                # a new variable and leave λ_E frozen at 1.0.
                if (epoch != 0) and (epoch % f_grad_norm == 0):
                    λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)

                # Resample the batch (RAD); the batch's RBA weights are written
                # back into the pool before new indices are drawn.
                if (epoch != 0) and (epoch % f_resample == 0):
                    sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                    collocs = collocs_pool[sorted_indices]
                    l_E = l_E_pool[sorted_indices]

                # Optional progress logging (disabled):
                #if epoch % 2500 == 0:
                #    l2error = model_eval(model, coords, refsol)
                #    print(f"\t\tEpoch {epoch}: Rel. L2 Error: {l2error:.2e}")

            # Final relative L2 error against the reference solution
            l2error = model_eval(model, coords, refsol)

            # Append one row matching `header`: pde, width, depth, run, l2
            new_row = f"ac, {n_hidden}, {int(2*num_blocks)}, {run}, {l2error}"
            with open(results_file, "a") as rfile:
                rfile.write(new_row + "\n")

            print(f"\t{run}. Final Rel. L2 Error: {l2error:.2e}")

## Plots

In [None]:
import warnings
import pandas as pd
# Silence pandas' notice about falling back to the python parsing engine.
warnings.filterwarnings(action="ignore", category=pd.errors.ParserWarning)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator, LogFormatterMathtext, NullLocator, NullFormatter
import seaborn as sns

# ===== configurable knobs =====
# Font sizes
TITLE_FS = 16
LABEL_FS = 16
TICK_FS = 14
LEGEND_FS = 16

# Line width and opacity of the ±SEM band
LINE_W = 2.0
ALPHA_FILL = 0.25

# Widths to plot (one panel each) and display name per model key
WIDTHS = [8, 16, 32]
MODEL_LABEL = {"cpikan": "cPIKAN", "rgakan": "RGA KAN"}

# Two distinct colours: samples 1 and -2 of 12 evenly spaced points along
# seaborn's "Spectral" colormap.
_cmap = sns.color_palette("Spectral", as_cmap=True)
cmap_points = np.linspace(0, 1, 12)
color_indices = [1, -2]
_colors = list(map(_cmap, cmap_points[color_indices]))

def load_and_merge(csv_rgakan: str, csv_compar: str) -> pd.DataFrame:
    """Load the RGA KAN and cPIKAN result CSVs and stack them into one frame.

    Both files use ", " between fields (the format written by the grid search).
    The RGA KAN rows are restricted to ``pde == "ac"``. The cPIKAN rows are
    restricted to ``init == "glorot"`` and ``pde == "ac"``. A ``model`` column
    ("rgakan" / "cpikan") records the source.

    Args:
        csv_rgakan: path to RGA KAN results (columns pde, width, depth, run, l2).
        csv_compar: path to comparative results; must also have an ``init`` column.

    Returns:
        The concatenated DataFrame. width/depth/run/l2 are numeric; unparsable
        entries become NaN.
    """
    def _read(path):
        # sep="," + skipinitialspace parses ", " with the C engine. A
        # multi-character sep would be treated as a regex, which forces the
        # python engine (ParserWarning) and fails on rows using a bare ",".
        frame = pd.read_csv(path, sep=",", skipinitialspace=True)
        frame.columns = [c.strip() for c in frame.columns]
        return frame

    def _is(frame, col, value):
        return frame[col].astype(str).str.strip() == value

    # RGA KAN
    df_rg = _read(csv_rgakan)
    df_rg = df_rg[_is(df_rg, "pde", "ac")].copy()
    df_rg["model"] = "rgakan"

    # cPIKAN (glorot init, Allen-Cahn only)
    df_cp = _read(csv_compar)
    df_cp = df_cp[_is(df_cp, "init", "glorot") & _is(df_cp, "pde", "ac")].copy()
    df_cp["model"] = "cpikan"

    # Numeric hygiene; l2 is included so that groupby aggregation cannot fail on strings.
    for frame in (df_rg, df_cp):
        for col in ["width", "depth", "run", "l2"]:
            frame[col] = pd.to_numeric(frame[col], errors="coerce")

    return pd.concat([df_rg, df_cp], ignore_index=True, sort=False)

def _set_log_ticks(ax):
    """Switch *ax* to a base-10 log y-axis with mathtext major labels and no minor ticks."""
    ax.set_yscale("log")

    yaxis = ax.yaxis
    yaxis.set_major_locator(LogLocator(base=10.0))
    yaxis.set_major_formatter(LogFormatterMathtext(base=10.0))
    yaxis.set_minor_locator(NullLocator())
    yaxis.set_minor_formatter(NullFormatter())

    ax.tick_params(axis="both", which="both", labelsize=TICK_FS)

def _plot_mean_sem(ax, x, y_mean, y_sem, label, color):
    """Draw the mean curve with a shaded ±SEM band on *ax*; return the line handle."""
    lower = y_mean - y_sem
    upper = y_mean + y_sem
    (handle,) = ax.plot(x, y_mean, label=label, linewidth=LINE_W, color=color, marker="o")
    ax.fill_between(x, lower, upper, alpha=ALPHA_FILL, color=color, linewidth=0)
    return handle

def _prep_stats(df):
    # expect columns: pde, model, width, depth, run, l2
    g = df.groupby(["model", "width", "depth"])["l2"].agg(["mean", "sem"]).reset_index()
    g = g.sort_values(["width", "model", "depth"]).reset_index(drop=True)
    return g

def plot_pde_grid(df):
    """Plot mean ± SEM relative L2 error against depth, one panel per width.

    Args:
        df: results frame with columns model, width, depth and l2 (as produced
            by ``load_and_merge``).

    Side effects:
        Saves ``{plots_dir}/rgakan_prelim.pdf`` and shows the figure.

    Raises:
        ValueError: if none of ``WIDTHS`` or none of the known models has data.
    """
    stats = _prep_stats(df)

    # Configured widths / known models that actually have data, in a fixed order.
    present_widths = set(stats["width"].unique())
    present_models = set(stats["model"].unique())
    widths = [w for w in WIDTHS if w in present_widths]
    models = [m for m in ["cpikan", "rgakan"] if m in present_models]

    ncols = len(widths)
    if ncols == 0 or not models:
        raise ValueError("No matching widths/models to plot after filtering.")

    # squeeze=False always returns a 2-D (1, ncols) array of axes.
    fig, axes = plt.subplots(
        1, ncols, figsize=(12, 3),
        sharex=True, sharey=False, constrained_layout=True, squeeze=False,
    )
    fig.set_constrained_layout_pads(w_pad=0.1, h_pad=0.1, hspace=0.06, wspace=0.06)

    # First line handle drawn for each model, gathered across ALL panels so the
    # legend is complete even if a model is missing from the first panel.
    handles_by_model = {}

    for c, width in enumerate(widths):
        ax = axes[0, c]
        _set_log_ticks(ax)

        at_width = stats[stats["width"] == width]
        # x positions shared by every model in this panel, ascending
        depths_here = at_width["depth"].drop_duplicates().sort_values().to_numpy()

        for j, model in enumerate(models):
            sub = at_width[at_width["model"] == model]
            if sub.empty:
                continue
            # Align to depths_here; depths this model lacks become NaN.
            sub = sub.set_index("depth").reindex(depths_here)
            if sub["mean"].isna().all():
                continue
            handle = _plot_mean_sem(
                ax, depths_here, sub["mean"].to_numpy(), sub["sem"].fillna(0).to_numpy(),
                label=MODEL_LABEL.get(model, model), color=_colors[j % len(_colors)],
            )
            handles_by_model.setdefault(model, handle)

        # Column title on every panel; y label only on the leftmost one.
        ax.set_title(f"Width = {width}", fontsize=TITLE_FS)
        if c == 0:
            ax.set_ylabel(r"Relative $L^2$ Error", fontsize=LABEL_FS)
        ax.set_xlabel("Depth", fontsize=LABEL_FS)

        # Fixed depth ticks: depth = 2 * num_blocks for the configured block counts.
        ax.set_xticks([2, 4, 6, 8, 10, 12])
        ax.set_xticklabels([2, 4, 6, 8, 10, 12])

        ax.grid(True, which="both", linestyle="--", alpha=0.5)

        # PDE name as a rotated label to the right of the last panel.
        if c == ncols - 1:
            ax.annotate(
                "Allen-Cahn",
                xy=(1.08, 0.5), xycoords="axes fraction",
                va="center", ha="left", rotation=90, fontsize=LABEL_FS
            )

    # One shared legend entry per model, in `models` order.
    legend_handles = [handles_by_model[m] for m in models if m in handles_by_model]
    if legend_handles:
        fig.legend(
            legend_handles, [h.get_label() for h in legend_handles],
            loc="lower center", ncol=len(legend_handles), frameon=False, fontsize=LEGEND_FS,
            bbox_to_anchor=(0.5, -0.18)
        )

    plt.savefig(f"{plots_dir}/rgakan_prelim.pdf", format="pdf", bbox_inches="tight")
    plt.show()
    

In [None]:
# Read back the CSV the grid search wrote (results_file) together with the
# cPIKAN comparison results from the same results directory.
df = load_and_merge(results_file, os.path.join(results_dir, "comparative_pde.csv"))
plot_pde_grid(df)