# Experiment 5.2: Supplementary Plots for IB Phases

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.equations import ac_res as pde_res
from jaxkan.utils.PIKAN import gradf

from src.utils import _get_adam, _get_pde_collocs, _get_ic_collocs, model_eval, count_params, _get_colloc_indices, grad_norm
from src.kan import KAN
from src.rgakan import RGAKAN

import numpy as np
from jax import device_get

import optax
from flax import nnx

import os

# Output folders: pickled training snapshots go to `results_dir`, figures to
# `plots_dir`. Both are created up front; existing folders are left as-is.
results_dir = "results"
plots_dir = "plots"
for _out_dir in (results_dir, plots_dir):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

# Global seed used for collocation sampling and (offset by `run`) model init.
seed = 42

## Basic Functions

In [None]:
# PDE Loss
def pde_loss(model, l_E, collocs):
    """Causally weighted, RBA-weighted mean-squared PDE residual.

    Args:
        model: network handed to ``pde_res``.
        l_E: current RBA weights of the collocation batch; must broadcast
            against the residuals (the original notes say (batch_size, 1)).
        collocs: collocation points of the current batch.

    Returns:
        ``(loss, l_E_new)``: the scalar causal loss and the updated RBA
        weights (returned as an auxiliary output for the next step).

    Reads the module-level settings RBA_gamma, RBA_eta, num_chunks,
    causal_tol and M.
    """
    res = pde_res(model, collocs)  # (batch_size, 1) per the original notes

    # RBA update: decay the previous weights and add the residual magnitude
    # normalised by its batch maximum.
    res_mag = jnp.abs(res)
    l_E_new = (RBA_gamma*l_E) + (RBA_eta*res_mag/jnp.max(res_mag))
    # NOTE(review): l_E_new is not wrapped in stop_gradient, so the loss
    # gradient also flows through the weight update - confirm intended.

    # Split the weighted residuals into num_chunks consecutive groups and
    # take the mean squared value of each group.
    chunk_losses = jnp.mean((l_E_new * res).reshape(num_chunks, -1)**2, axis=1)

    # Causal weights: chunk i is down-weighted by exp(-causal_tol * (M @ L))[i];
    # with M strictly lower-triangular that is the summed loss of the chunks
    # that come before it in the batch ordering. Treated as constants.
    causal_w = jax.lax.stop_gradient(jnp.exp(-causal_tol * (M @ chunk_losses)))

    return jnp.mean(causal_w * chunk_losses), l_E_new


# IC Loss
def ic_loss(model, l_I, ic_collocs, ic_data):
    """RBA-weighted mean-squared error on the initial condition.

    The RBA weights follow the same rule as the PDE term: decay the previous
    weights by RBA_gamma and add RBA_eta times the absolute residual divided
    by its maximum.

    Returns:
        ``(loss, l_I_new)``: scalar loss and the updated IC RBA weights.
    """
    mismatch = model(ic_collocs) - ic_data
    mag = jnp.abs(mismatch)
    l_I_new = (RBA_gamma*l_I) + (RBA_eta*mag/jnp.max(mag))
    weighted = l_I_new * mismatch
    return jnp.mean(weighted**2), l_I_new


@nnx.jit
def train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I):
    """One jitted optimisation step on the weighted PDE + IC objective.

    The two loss terms are differentiated separately so their raw gradients
    can be handed back to the caller (the training loop feeds them to
    ``grad_norm``). The update itself uses ``λ_E*grad_E + λ_I*grad_I``.

    Returns:
        ``(loss, grads_E, grads_I, l_E_new, l_I_new)``, where ``loss`` is the
        combined weighted loss and ``l_*_new`` are the updated RBA weights.
    """
    pde_vg = nnx.value_and_grad(pde_loss, has_aux=True)
    ic_vg = nnx.value_and_grad(ic_loss, has_aux=True)

    (loss_E, l_E_new), grads_E = pde_vg(model, l_E, collocs)
    (loss_I, l_I_new), grads_I = ic_vg(model, l_I, ic_collocs, ic_data)

    def _mix(g_pde, g_ic):
        # Leaf-wise weighted sum of the two gradient pytrees.
        return λ_E * g_pde + λ_I * g_ic

    optimizer.update(jax.tree_util.tree_map(_mix, grads_E, grads_I))

    total = λ_E*loss_E + λ_I*loss_I
    return total, grads_E, grads_I, l_E_new, l_I_new


@nnx.jit
def get_RAD_indices(model, collocs_pool, old_indices, l_E, l_E_pool):
    """Residual-based adaptive resampling of the collocation batch.

    The current batch's RBA weights are first written back into the pool-wide
    weight array. Each pool point is then scored by
    ``|w*r|**rad_a / mean(|w*r|**rad_a) + rad_c`` (r = PDE residual, w = the
    point's RBA weight). The scores are normalised into probabilities and
    passed to ``_get_colloc_indices``, which presumably samples
    ``batch_size`` indices from them.

    Returns:
        ``(sorted_indices, updated_pool)``: the new batch indices and the
        pool of RBA weights including the write-back.
    """
    # Sync the batch weights back into the pool before scoring.
    updated_pool = l_E_pool.at[old_indices].set(l_E)

    # Weighted residual magnitude of every pool point, raised to rad_a.
    score = jnp.power(jnp.abs(updated_pool * pde_res(model, collocs_pool)), rad_a)

    # Mean-normalise, add the floor rad_c, normalise to sum 1 and drop the
    # trailing axis (assumes scores are shaped (N, 1) - TODO confirm).
    unnorm = (score/jnp.mean(score)) + rad_c
    probs = (unnorm / jnp.sum(unnorm))[:, 0]

    # NOTE(review): `seed` is a module-level int captured at trace time, so
    # every resampling call passes the same seed - confirm that is intended.
    new_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=probs, seed=seed)

    return new_indices, updated_pool

### Data & Grid-Search Parameters

In [None]:
# Collocation pool over the ranges (0, 1) x (-1, 1) and initial-condition
# points over x in (-1, 1). Column 1 is assumed to be x, matching the (t, x)
# ordering of `coords` below.
collocs_pool = _get_pde_collocs(ranges = [(0,1), (-1,1)], sample_size = 400)
ic_collocs = _get_ic_collocs(x_range = (-1, 1), sample_size = 2**6)
# Initial condition u(0, x) = x^2 * cos(pi * x), as a column vector.
ic_data = ((ic_collocs[:,1]**2) * jnp.cos(jnp.pi*ic_collocs[:,1]))[:, None]

# Reference solution; `usol` is laid out as (N_t, N_x) over the t and x grids.
ref = np.load('data/allen_cahn.npz')
refsol = jnp.array(ref['usol'])
N_t, N_x = refsol.shape
t = ref['t'].flatten()
x = ref['x'].flatten()

# Flattened (t, x) evaluation grid in 'ij' (row-major) order, so that
# model(coords).reshape(refsol.shape) lines up with the reference solution.
T, X = jnp.meshgrid(t, x, indexing='ij')
coords = jnp.stack([T.ravel(), X.ravel()], axis=1)

In [None]:
# Optimisation length and learning-rate schedule (fed to _get_adam).
num_epochs = 100_000
learning_rate, decay_steps, decay_rate, warmup_steps = 1e-3, 2000, 0.9, 1000

# Causal training: batch residuals are split into num_chunks groups and
# chunk i is weighted by exp(-causal_tol * (M @ chunk_losses)[i]). M is the
# strictly lower-triangular matrix of ones.
causal_tol, num_chunks = 1.0, 32
M = jnp.tril(jnp.ones((num_chunks, num_chunks)), k=-1)

# Grad-norm rebalancing of the PDE/IC weights, applied every f_grad_norm epochs.
grad_mixing, f_grad_norm = 0.9, 1000

# Residual-based adaptive resampling: batch size, resampling period, and the
# exponent (rad_a) / additive floor (rad_c) of the sampling density.
batch_size = 2**12
f_resample = 2000
rad_a, rad_c = 1.0, 1.0

# RBA weights: decay factor and step size of the update rule.
RBA_gamma, RBA_eta = 0.999, 0.01

# Model configuration.
n_in = collocs_pool.shape[1]  # input dimension taken from the collocation points
n_out = 1
D = 5
period_axes = {1: jnp.pi}   # presumably marks input axis 1 (x) as periodic - see RGAKAN/KAN
rff_std = None              # alternative: 2.0
sine_D = 5                  # alternative: None
alpha, beta = 0.0, 0.0
init_scheme = dict(type='glorot', gain=None, norm_pow=0, distribution='uniform', sample_size=10000)

# Width, run id (added to `seed` for model init) and block count; the cPIKAN
# baseline is built with 2 * num_blocks layers.
n_hidden = 16
run = 7
num_blocks = 6

In [None]:
# Architectures trained below; each name must match a branch of the
# model-construction `if` in the training loop.
archs: list[str] = ["RGA", "cPIKAN"]

### Runs

In [None]:
RESULTS = dict()

print(f"Training models with depth = {int(2*num_blocks)} and width = {n_hidden}.")

# Epochs at which the full-grid prediction is stored for the phase plots.
# Built once as a frozenset so the per-epoch membership test is O(1).
snapshot_epochs = frozenset((20, 50, 100, 150, 180, 200, 220, 250, 300, 500, 1000, 2000, 4000,
                             6000, 7000, 8000, 9000, 10_000, 20_000, 80_000, 90_000))

# Train every architecture from scratch with identical data, schedule and seeds.
for arch in archs:
    RESULTS[arch] = dict()

    RESULTS[arch]['outputs'] = dict()

    # RBA weights: one per pool point for the PDE term, one per IC point.
    l_E_pool = jnp.ones((collocs_pool.shape[0], 1))
    l_I = jnp.ones((ic_collocs.shape[0], 1))

    # Initial batch (px=None) and the RBA weights of its points.
    sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

    collocs = collocs_pool[sorted_indices]
    l_E = l_E_pool[sorted_indices]

    # Optimiser transformation built from the shared schedule settings.
    opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

    # Build the model; both variants use seed + run so runs are comparable.
    if arch == "RGA":
        model = RGAKAN(n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
                       init_scheme=init_scheme, alpha=alpha, beta=beta, ref=ref, period_axes=period_axes, rff_std=rff_std,
                       sine_D=sine_D, seed=seed+run)
    elif arch == "cPIKAN":
        model = KAN(n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_layers=int(2*num_blocks), D=D,
                    init_scheme=init_scheme, period_axes=period_axes, rff_std=rff_std, seed=seed+run)
    else:
        # Without this an unknown name would silently reuse the previous
        # architecture's model (or raise NameError on the first iteration).
        raise ValueError(f"Unknown architecture {arch!r}; expected 'RGA' or 'cPIKAN'.")

    print(f"Initialized {arch} model with {count_params(model)} parameters.")

    # Global loss weights, periodically rebalanced by grad_norm.
    λ_E = jnp.array(1.0, dtype=float)
    λ_I = jnp.array(1.0, dtype=float)

    # Set optimizer
    optimizer = nnx.Optimizer(model, opt_type)

    # Start training
    for epoch in range(num_epochs):

        loss, grads_E, grads_I, l_E, l_I = train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I)

        # Grad-norm rebalancing. The left-hand names must be exactly the
        # identifiers passed to train_step (ASCII 'E'); a look-alike Unicode
        # letter (e.g. Greek 'Ε') would bind a new variable and leave λ_E
        # frozen at 1.0.
        if (epoch != 0) and (epoch % f_grad_norm == 0):
            λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)

        # RAD: write the batch RBA weights back into the pool and draw a new batch.
        if (epoch != 0) and (epoch % f_resample == 0):
            sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
            collocs = collocs_pool[sorted_indices]
            l_E = l_E_pool[sorted_indices]

        # Store full-grid predictions as host NumPy arrays at snapshot epochs.
        if epoch in snapshot_epochs:
            output = model(coords).reshape(refsol.shape)
            woutput = np.asarray(device_get(output))
            RESULTS[arch]['outputs'][epoch] = woutput

        if epoch % 10_000 == 0:
            print(f"Epoch {epoch} done.")

In [None]:
import pickle

# Serialise every architecture's prediction snapshots so the plotting cells
# below can be rerun without retraining.
save_file = os.path.join(results_dir, "phase_outputs.pkl")
with open(save_file, "wb") as out_fh:
    pickle.dump(RESULTS, out_fh)

## Plots

In [None]:
import pickle

# Reload the prediction snapshots so plots can be regenerated without
# retraining. pickle.load can execute arbitrary code, so only open files this
# notebook produced.
filepath = os.path.join(results_dir, "phase_outputs.pkl")
with open(filepath, "rb") as in_fh:
    RESULTS = pickle.load(in_fh)

#### Horizontal Version

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

from matplotlib.ticker import FormatStrFormatter

def plot_architectures_side_by_side(
    RESULTS,
    refsol,
    epochs=(1000, 20000, 90000),
    archs=("RGA", "cPIKAN"),
    cmap_pred="Spectral",
    cmap_res="Spectral",
    coords=None,  # optional mapping {"t": ..., "x": ...}; if None, imshow index axes are used
    figsize=(12, 3),
    dpi=150,
    save_path=None,
):
    """Plot predictions and residuals of two architectures at three epochs.

    The figure has two side-by-side panels, one per architecture. Each panel
    is a 2 x 3 grid with predictions on the top row and residuals
    ``refsol - prediction`` on the bottom row, one column per epoch. Every
    image gets its own colorbar scaled to that image's min/max.

    Args:
        RESULTS: mapping ``RESULTS[arch]['outputs'][epoch] -> array`` with
            the same shape as ``refsol``.
        refsol: reference solution. Arrays are transposed before ``imshow``,
            so axis 0 runs horizontally and axis 1 vertically.
        epochs: exactly three epochs to show, left to right.
        archs: exactly two keys of ``RESULTS``; "RGA" is titled "RGA KAN".
        cmap_pred, cmap_res: colormaps for the prediction / residual rows.
        coords: optional mapping with keys ``"t"`` and ``"x"`` (both 1-D or
            both 2-D arrays), used only to set the image extent. This is NOT
            the flattened (N, 2) point array; pass e.g. ``{"t": t, "x": x}``.
        figsize, dpi: forwarded to ``plt.figure``.
        save_path: PDF output path; defaults to
            ``os.path.join(plots_dir, "phase_outputs.pdf")``.

    Raises:
        ValueError: wrong number of archs/epochs, or inconsistent coords.
        KeyError: a requested epoch is not stored for an architecture.
    """
    if len(archs) != 2:
        raise ValueError("Please pass exactly two architectures in `archs`.")
    if len(epochs) != 3:
        raise ValueError("Please pass exactly three epochs in `epochs`.")

    # Work in NumPy so the residual arithmetic stays on the host.
    refsol = np.asarray(refsol)

    # Optional physical extent. Images are drawn transposed, so the
    # horizontal axis spans t and the vertical axis spans x.
    extent = None
    if coords is not None:
        t = coords.get("t", None)
        x = coords.get("x", None)
        if t is not None and x is not None:
            if t.ndim != x.ndim or t.ndim not in (1, 2):
                raise ValueError("coords['t'] and coords['x'] must both be 1D or both be 2D.")
            extent = [t.min(), t.max(), x.min(), x.max()]

    # GridSpec of 2 rows x 13 columns: 6 columns for the left panel, a
    # 0.4-wide spacer (column 6), 6 columns for the right panel. Each epoch
    # image spans two columns.
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = GridSpec(
        2, 13, figure=fig,
        wspace=0.7, hspace=0.2,
        width_ratios=[1, 1, 1, 1, 1, 1, 0.4, 1, 1, 1, 1, 1, 1],
    )

    # Row labels to the left of the first column, drawn on invisible axes.
    for row, row_label in enumerate(("Prediction", "Residuals")):
        ax = fig.add_subplot(gs[row, 0])
        ax.set_axis_off()
        ax.text(
            -0.25, 0.5, row_label,
            va="center", ha="right", rotation=90, transform=ax.transAxes, fontsize=11
        )

    def _add_colorbar(im, ax):
        # Slim colorbar with small one-decimal tick labels.
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
        cbar.ax.tick_params(labelsize=6)
        cbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        return cbar

    def plot_block(arch_idx, arch_name, col_start):
        # Draw one architecture's 2 x 3 block starting at GridSpec column col_start.
        ax_title = fig.add_subplot(gs[0, col_start:col_start + 6])
        ax_title.set_axis_off()
        display_name = "RGA KAN" if arch_name == "RGA" else arch_name
        ax_title.text(0.5, 1.1, display_name,
                      ha="center", va="bottom",
                      fontsize=12, transform=ax_title.transAxes)

        if arch_idx == 0:
            column_labels = ["Fitting", "Diffusion", "Diffusion Equilibrium"]
        else:
            column_labels = ["Fitting", "Diffusion", "-"]

        outputs = RESULTS[arch_name]['outputs']
        bottom_axes = []
        for j, e in enumerate(epochs):
            if e not in outputs:
                raise KeyError(
                    f"No stored output for '{arch_name}' at epoch {e}; "
                    f"available epochs: {sorted(outputs)}"
                )
            pred = np.asarray(outputs[e])
            res = refsol - pred
            cols = slice(col_start + 2*j, col_start + 2*j + 2)

            # Top row: prediction, coloured over its own range.
            axp = fig.add_subplot(gs[0, cols])
            im1 = axp.imshow(pred.T, origin="lower", aspect="auto", extent=extent,
                             vmin=float(np.nanmin(pred)), vmax=float(np.nanmax(pred)),
                             cmap=cmap_pred)
            _add_colorbar(im1, axp)
            axp.set_xticks([])
            axp.set_yticks([])

            # Bottom row: residuals, coloured over their own range.
            axr = fig.add_subplot(gs[1, cols])
            im2 = axr.imshow(res.T, origin="lower", aspect="auto", extent=extent,
                             vmin=float(np.nanmin(res)), vmax=float(np.nanmax(res)),
                             cmap=cmap_res)
            cbar2 = _add_colorbar(im2, axr)

            if arch_idx == 0 and j == 2:
                # Hard-coded ticks for the bottom-right panel of the first
                # block; only meaningful for this run's residual range.
                cbar2.set_ticks([0.0, -0.1])
                cbar2.set_ticklabels([f"{tick:.1f}" for tick in [0.0, -0.1]])

            axr.set_xticks([])
            axr.set_yticks([])
            bottom_axes.append(axr)

        # Column captions just below each residual image, in figure coordinates.
        for ax, lbl in zip(bottom_axes, column_labels):
            bbox = ax.get_position()
            x_center = bbox.x0 + bbox.width / 2
            y_below = bbox.y0 - 0.04   # tweak spacing as needed
            fig.text(x_center, y_below, lbl, ha="center", va="top", fontsize=10)

    plot_block(0, archs[0], 0)   # left panel: GridSpec columns 0-5
    plot_block(1, archs[1], 7)   # right panel: GridSpec columns 7-12

    if save_path is None:
        save_path = os.path.join(plots_dir, "phase_outputs.pdf")
    plt.savefig(save_path, format="pdf", bbox_inches="tight")
    plt.show()


In [None]:
# Epochs shown as the three columns; each must be stored in RESULTS[arch]['outputs'].
epochs = (20, 250, 20000)

# Independent Spectral colormaps for the prediction and residual rows.
cmap_pred = sns.color_palette("Spectral", as_cmap=True)
cmap_res = sns.color_palette("Spectral", as_cmap=True)

plot_architectures_side_by_side(
    RESULTS=RESULTS,
    refsol=refsol,
    epochs=epochs,
    archs=("RGA", "cPIKAN"),
    cmap_pred=cmap_pred,
    cmap_res=cmap_res,
    # coords expects a mapping such as {"t": t, "x": x}, not the flattened
    # (N, 2) `coords` point array; None plots in index units.
    coords=None,
)
