# Experiment 5.1: Information Bottleneck Analysis

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.equations import ac_res as pde_res
from jaxkan.utils.PIKAN import gradf

from src.utils import _get_adam, _get_pde_collocs, _get_ic_collocs, model_eval, count_params, _get_colloc_indices, grad_norm
from src.kan import KAN
from src.rgakan import RGAKAN

import numpy as np
from jax import device_get

import optax
from flax import nnx

import os

# Output locations: pickled metrics are written under `results_dir`,
# figures under `plots_dir`. Both folders are created up front and an
# already-existing folder is not an error.
results_dir = "results"
plots_dir = "plots"
os.makedirs(results_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

# Base random seed used for collocation sampling and model initialisation
seed = 42

## Basic Functions

In [None]:
# Causal weights computed from RBA-weighted residuals
def get_causal_weights(model, collocs, l_E):
    """Return one causal weight per collocation point, detached from autodiff.

    The RBA-weighted PDE residuals are reshaped into ``num_chunks`` rows, the
    mean squared value of each row is taken as that chunk's loss, and the
    chunk weight is ``exp(-causal_tol * (M @ chunk_losses))``. Each chunk
    weight is then repeated for every point of the chunk.

    Reads the module-level ``num_chunks``, ``causal_tol`` and ``M``.

    Args:
        model: network forwarded to ``pde_res``.
        collocs: collocation points; ``collocs.shape[0]`` is assumed to be a
            multiple of ``num_chunks`` (TODO confirm against the sampler).
        l_E: RBA weights, multiplied elementwise with the residuals.

    Returns:
        1-D array of length ``num_chunks * (collocs.shape[0] // num_chunks)``
        wrapped in ``jax.lax.stop_gradient``.
    """
    n_points = collocs.shape[0]

    # RBA-weighted residuals, one row per chunk
    chunked = (l_E * pde_res(model, collocs)).reshape(num_chunks, -1)

    # Mean squared weighted residual of every chunk
    chunk_loss = jnp.mean(chunked**2, axis=1)

    # Exponential damping driven by the M-weighted sum of chunk losses
    chunk_weights = jnp.exp(-causal_tol * (M @ chunk_loss))

    # Spread each chunk weight over all points belonging to that chunk
    per_point = jnp.repeat(chunk_weights, n_points // num_chunks)

    return jax.lax.stop_gradient(per_point)


# PDE Loss (E)
def single_pde_loss(model, weight, x):
    """Weighted squared PDE residual for a single input ``x``.

    The residual is ``u_t - D*u_xx - c*(u - u**3)`` with ``D = 1e-4`` and
    ``c = 5.0``, where ``u`` is the model output. Only the first entry of the
    flattened residual enters the loss.

    Args:
        model: callable network.
        weight: scalar multiplier applied to the squared residual.
        x: input point(s) handed to the model and to the derivative operators;
            callers in this file pass one point with a leading axis of size 1.

    Returns:
        ``weight * residual[0]**2``.
    """
    # Equation coefficients
    diffusion = jnp.array(1e-4, dtype=jnp.float32)
    reaction = jnp.array(5.0, dtype=jnp.float32)

    def u(t):
        return model(t)

    # Derivative operators from jaxkan's gradf
    # (presumably: 1st derivative along input 0, 2nd derivative along input 1 - verify)
    u_t = gradf(u, 0, 1)
    u_xx = gradf(u, 1, 2)

    # Residual of the equation at x
    residual = u_t(x) - diffusion*u_xx(x) - reaction*(u(x)-(u(x)**3))
    residual = residual.flatten()

    # Weighted squared residual of the (single) point
    return weight * residual[0]**2


# IC Loss (I)
def single_ic_loss(model, weight, x, y):
    """Weighted squared initial-condition mismatch for a single sample.

    Args:
        model: callable network.
        weight: multiplier for the squared mismatch; it is passed through
            ``jax.lax.stop_gradient`` so no gradient flows into it.
        x: model input.
        y: target value compared against ``model(x)``.

    Returns:
        ``weight * r**2`` where ``r`` is the first entry of the flattened
        difference ``model(x) - y``.
    """
    # Treat the weight as a constant during differentiation
    frozen_weight = jax.lax.stop_gradient(weight)

    # First entry of the flattened mismatch
    mismatch = (model(x) - y).flatten()[0]

    return frozen_weight * (mismatch**2)


@nnx.jit
def snr_train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I):
    """Perform one optimizer update and measure the gradient signal-to-noise ratio.

    Per-sample losses and gradients are computed separately for the PDE
    collocation points and the initial-condition points, scaled by the global
    weights ``λ_E`` / ``λ_I``, averaged, summed and handed to
    ``optimizer.update``. The SNR is the global norm of the per-sample gradient
    mean divided by the global norm of the per-sample gradient standard
    deviation (plus ``1e-8``), taken over the PDE and IC samples together.

    Reads the module-level ``RBA_gamma`` and ``RBA_eta``.

    Args:
        model: network being trained.
        optimizer: optimizer object; updated in place via ``optimizer.update``.
        collocs: PDE collocation points of the current batch.
        ic_collocs: initial-condition points.
        ic_data: target values at ``ic_collocs``.
        λ_E, λ_I: global loss weights for the PDE and IC terms.
        l_E, l_I: current RBA weights for the PDE and IC samples.

    Returns:
        Tuple ``(loss, grads_E, grads_I, snr, l_E_new, l_I_new)`` where the
        gradients are the λ-scaled sample means and ``l_*_new`` are the
        updated RBA weights.
    """

    # ----------------------- PDE --------------------------------------------- #

    # RBA updates: decay the old weights and add the residual magnitude normalised by its maximum
    # NOTE(review): the normalisation divides by max|res|, which is undefined if every residual is exactly zero
    res = pde_res(model, collocs)
    abs_res = jnp.abs(res)
    l_E_new = (RBA_gamma * l_E) + (RBA_eta * (abs_res/jnp.max(abs_res)))

    # causal weights computed on RBA-weighted residuals
    causal_w = get_causal_weights(model, collocs, l_E_new)
    # Per-point weight = causal weight * RBA weight
    total_w_E = causal_w * l_E_new.flatten()

    # Define a batched PDE loss (model shared, weights and points mapped over axis 0)
    batched_E = nnx.vmap(nnx.value_and_grad(single_pde_loss), in_axes=(None, 0, 0))
    # Get all (weighted) losses and all gradients; each point keeps a leading axis of size 1
    all_loss_E, all_grads_E = batched_E(model, total_w_E, collocs[:, None, :])
    # Get average PDE loss
    loss_E = jnp.mean(all_loss_E)
    # Get weighted gradients
    all_grads_E = jax.tree_util.tree_map(lambda g: λ_E * g, all_grads_E)
    # Get average gradients for Loss Annealing
    grads_E = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), all_grads_E)


    # ----------------------- ICs --------------------------------------------- #

    # RBA updates (same rule as for the PDE residuals)
    ic_res_full = model(ic_collocs) - ic_data
    abs_ic = jnp.abs(ic_res_full)
    l_I_new = (RBA_gamma * l_I) + (RBA_eta * (abs_ic/jnp.max(abs_ic)))

    # Define a batched ICs loss
    batched_I = nnx.vmap(nnx.value_and_grad(single_ic_loss), in_axes=(None, 0, 0, 0))
    # Get all losses and all gradients
    all_loss_I, all_grads_I = batched_I(model, l_I_new.flatten(), ic_collocs[:, None, :], ic_data[:, None, :])
    # Get average IC loss
    loss_I = jnp.mean(all_loss_I)
    # Get weighted gradients
    all_grads_I = jax.tree_util.tree_map(lambda g: λ_I * g, all_grads_I)
    # Get average gradients for Loss Annealing
    grads_I = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), all_grads_I)

    # --------------------- TOTAL -----------------------#

    # Compute total loss
    loss = λ_E*loss_E + λ_I*loss_I
    # Compute total gradients (both terms are already scaled by their λ)
    grads = jax.tree_util.tree_map(lambda g1, g2: g1 + g2, grads_E, grads_I)
    # Update optimizer
    optimizer.update(grads)

    # Get mean and std for SNR: stack the per-sample PDE and IC gradients along the sample axis
    all_grads = jax.tree_util.tree_map(lambda g1, g2: jnp.concatenate([g1, g2], axis=0), all_grads_E, all_grads_I)

    mean_grads = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), all_grads)
    std_grads = jax.tree_util.tree_map(lambda g: jnp.std(g, axis=0),  all_grads)
    
    # Global norms over all parameters; 1e-8 guards against a zero denominator
    mu = optax.global_norm(mean_grads)
    sd = optax.global_norm(std_grads)
    snr = mu / (sd + 1e-8)
    
    return loss, grads_E, grads_I, snr, l_E_new, l_I_new


@nnx.jit
def get_RAD_indices(model, collocs_pool, old_indices, l_E, l_E_pool):
    """Draw a new collocation batch from the pool using residual-based probabilities.

    The RBA weights of the current batch are first written back into the pool.
    Sampling probabilities are then built from ``|RBA weight * residual|`` over
    the whole pool: raised to ``rad_a``, divided by the mean, shifted by
    ``rad_c`` and normalised to sum to one.

    Reads the module-level ``rad_a``, ``rad_c``, ``batch_size`` and ``seed``.

    Args:
        model: network forwarded to ``pde_res``.
        collocs_pool: full pool of collocation points.
        old_indices: pool indices of the batch currently in use.
        l_E: RBA weights of the current batch.
        l_E_pool: RBA weights of the full pool.

    Returns:
        Tuple ``(indices, updated_pool)``: the indices returned by
        ``_get_colloc_indices`` and the pool weights after the write-back.
    """
    # Write the batch's current RBA weights back into their pool slots
    pool_weights = l_E_pool.at[old_indices].set(l_E)

    # |RBA weight * residual| over the whole pool, raised to the power rad_a
    scores = jnp.power(jnp.abs(pool_weights * pde_res(model, collocs_pool)), rad_a)

    # Mean-normalised scores shifted by rad_c
    density = (scores/jnp.mean(scores)) + rad_c

    # Probabilities summing to one; keep column 0 to obtain a 1-D vector
    probs = (density / jnp.sum(density))[:,0]

    new_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=probs, seed=seed)

    return new_indices, pool_weights

In [None]:
def get_frob(model, x):
    """Squared norm of the gradient of the model's first output with respect to ``x``.

    Args:
        model: callable network.
        x: a single input; a 1-D array is given a leading axis of size 1
            before being passed to the model, anything else is used as is.

    Returns:
        Scalar ``vdot(g, g)`` where ``g`` is the gradient of the first entry
        of the flattened model output with respect to the input.
    """
    # The model is fed a 2-D input, so promote a bare point of shape (d,) to (1, d)
    if x.ndim == 1:
        x = x[None, :]

    def first_output(point):
        return model(point).flatten()[0]

    input_grad = jax.grad(first_output)(x)

    return jnp.vdot(input_grad, input_grad)

# get_frob mapped over the leading axis of the points (model shared), compiled with nnx.jit
batched_frob = nnx.jit(jax.vmap(get_frob, in_axes=(None, 0)))


def get_complexity(model, collocs, ic_collocs):
    """Mean of ``get_frob`` over the PDE collocation points and the IC points.

    Args:
        model: network whose input-gradient norm is measured.
        collocs: PDE collocation points.
        ic_collocs: initial-condition points, appended to ``collocs`` along axis 0.

    Returns:
        Scalar average of the per-point values returned by ``batched_frob``.
    """
    all_points = jnp.concatenate((collocs, ic_collocs), axis=0)
    return jnp.mean(batched_frob(model, all_points))


### Data & Grid-Search Parameters

In [None]:
# Pool of PDE collocation points and the initial-condition points
collocs_pool = _get_pde_collocs(ranges=[(0, 1), (-1, 1)], sample_size=400)
ic_collocs = _get_ic_collocs(x_range=(-1, 1), sample_size=2**6)

# Initial condition x^2 * cos(pi * x), evaluated on column 1 of the IC points
# and stored as a column vector
ic_data = ((ic_collocs[:, 1]**2) * jnp.cos(jnp.pi * ic_collocs[:, 1])).reshape(-1, 1)

# Reference solution loaded from disk
ref = np.load('data/allen_cahn.npz')
refsol = jnp.array(ref['usol'])

# Evaluation grid: every (t, x) pair of the reference axes, one pair per row
N_t, N_x = ref['usol'].shape
t = ref['t'].flatten()
x = ref['x'].flatten()
T, X = jnp.meshgrid(t, x, indexing='ij')
coords = jnp.hstack([T.reshape(-1, 1), X.reshape(-1, 1)])

In [None]:
# ---- Training length ----
num_epochs = 100_000

# ---- Learning-rate schedule (forwarded to _get_adam) ----
learning_rate = 1e-3
decay_steps = 2000
decay_rate = 0.9
warmup_steps = 1000

# ---- Causal training ----
causal_tol = 1.0
num_chunks = 32
# Strictly lower-triangular matrix of ones: row i selects chunks 0..i-1
M = jnp.tril(jnp.ones((num_chunks, num_chunks)), k=-1)

# ---- Grad Norm (global loss-weight balancing) ----
grad_mixing = 0.9
f_grad_norm = 1000  # applied every f_grad_norm epochs

# ---- Resampling of collocation points ----
batch_size = 2**12
f_resample = 2000  # applied every f_resample epochs
rad_a = 1.0
rad_c = 1.0

# ---- RBA weights ----
RBA_gamma = 0.999
RBA_eta = 0.01

# ---- Model ----
n_in = collocs_pool.shape[1]
n_out = 1
D = 5
period_axes = {1: jnp.pi}
rff_std = None  # alternative value noted by the author: 2.0
sine_D = 5  # alternative value noted by the author: None
alpha = 0.0
beta = 0.0
init_scheme = {'type': 'glorot', 'gain': None, 'norm_pow': 0, 'distribution': 'uniform', 'sample_size': 10000}

num_blocks = 6

In [None]:
# Architecture settings
# Hidden widths swept by the grid search
widths = [8, 16, 32]
# Model families trained for every width ("RGA" builds an RGAKAN, "cPIKAN" builds a KAN)
archs = ["RGA", "cPIKAN"]

### Runs

In [None]:
# Nested results: RESULTS[width][arch][metric] is a list with one value per epoch
RESULTS = dict()

# Grid search over width size
for n_hidden in widths:
    RESULTS[n_hidden] = dict()

    if n_hidden in [16, 32]:
        run = 7 # Worst performing run for RGAKANs with widths = 16, 32 from previous experiment
    else:
        run = 13 # Worst performing run for RGAKAN with width = 8 from previous experiment

    print(f"Training models with depth = {int(2*num_blocks)} and width = {n_hidden}.")

    # Grid search over architecture types
    for arch in archs:
        # Per-epoch histories: relative L2 error, gradient SNR and complexity
        RESULTS[n_hidden][arch] = {'L2': [], 'SNR': [], 'C': []}

        # Initialize RBA weights - full pool
        l_E_pool = jnp.ones((collocs_pool.shape[0], 1))
        # Also get RBAs for ICs
        l_I = jnp.ones((ic_collocs.shape[0], 1))

        # Get starting collocation points & RBA weights
        sorted_indices = _get_colloc_indices(collocs_pool=collocs_pool, batch_size=batch_size, px=None, seed=seed)

        collocs = collocs_pool[sorted_indices]
        l_E = l_E_pool[sorted_indices]

        # Get opt_type
        opt_type = _get_adam(learning_rate=learning_rate, decay_steps=decay_steps, decay_rate=decay_rate, warmup_steps=warmup_steps)

        # Define model
        if arch == "RGA":
            model = RGAKAN(n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_blocks=num_blocks, D=D,
                           init_scheme=init_scheme, alpha=alpha, beta=beta, ref=ref, period_axes=period_axes, rff_std=rff_std,
                           sine_D=sine_D, seed=seed+run)
        elif arch == "cPIKAN":
            model = KAN(n_in=n_in, n_out=n_out, n_hidden=n_hidden, num_layers=int(2*num_blocks), D=D,
                        init_scheme=init_scheme, period_axes=period_axes, rff_std=rff_std, seed=seed+run)
        else:
            # Fail loudly instead of silently reusing the previous iteration's model
            raise ValueError(f"Unknown architecture: {arch}")

        print(f"Initialized {arch} model with {count_params(model)} parameters.")

        # Define global loss weights
        λ_E = jnp.array(1.0, dtype=float)
        λ_I = jnp.array(1.0, dtype=float)

        # Set optimizer
        optimizer = nnx.Optimizer(model, opt_type)

        # Start training
        for epoch in range(num_epochs):

            loss, grads_E, grads_I, snr, l_E, l_I = snr_train_step(model, optimizer, collocs, ic_collocs, ic_data, λ_E, λ_I, l_E, l_I)

            # Perform grad norm
            if (epoch != 0) and (epoch % f_grad_norm == 0):

                # The target must be spelled with a Latin "E": a look-alike Greek
                # capital Epsilon is a different identifier in Python, which would
                # leave λ_E at its initial value for the whole run.
                λ_E, λ_I = grad_norm(grads_E, grads_I, λ_E, λ_I, grad_mixing)

            # Perform RAD
            if (epoch != 0) and (epoch % f_resample == 0):

                # Get new indices after resampling
                sorted_indices, l_E_pool = get_RAD_indices(model, collocs_pool, sorted_indices, l_E, l_E_pool)
                # Set new batch of collocs and l_E
                collocs = collocs_pool[sorted_indices]
                l_E = l_E_pool[sorted_indices]

            # Calculate L2 Error at the end of EACH epoch
            l2error = model_eval(model, coords, refsol)
            # Get complexity
            compl = get_complexity(model, collocs, ic_collocs)

            RESULTS[n_hidden][arch]['L2'].append(l2error.item())
            RESULTS[n_hidden][arch]['SNR'].append(snr.item())
            RESULTS[n_hidden][arch]['C'].append(compl.item())

            if epoch % 5000 == 0:
                print(f"Epoch: {epoch}. Current snr = {snr:.2e}, Current L2 Error = {l2error:.2e}, Current complexity = {compl:.2e}")

In [None]:
import pickle

# Persist the complete RESULTS dictionary as "snr.pkl" inside results_dir
snr_file = os.path.join(results_dir, "snr.pkl")

# Binary mode is required by pickle; the context manager closes the file
with open(snr_file, "wb") as f:
    pickle.dump(RESULTS, f)

## Plots

In [None]:
import pickle

# Reload the saved metrics so the plotting cells can run without retraining.
# NOTE: pickle.load can execute arbitrary code - only open files written by this notebook.
filepath = os.path.join(results_dir, "snr.pkl")

with open(filepath, "rb") as f:
    RESULTS = pickle.load(f)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator, LogFormatterMathtext, NullLocator, NullFormatter
import seaborn as sns

def plot_snr_grid(RESULTS, vlines=None, widths_order=(8, 16, 32), savepath=None, filename=None):
    """Plot the L2 error, SNR and complexity histories on a metrics-by-widths grid.

    The figure has one row per metric ("L2", "SNR", "C") and one column per
    width of ``widths_order`` that is present in ``RESULTS``. The x axis
    (training iteration, starting at 1) is logarithmic. The y axis is set to
    logarithmic here unless a callable named ``_set_log_ticks`` exists in the
    module globals, in which case that helper is called with the axis instead.

    Font sizes and line width are read from the optional globals ``TITLE_FS``,
    ``LABEL_FS``, ``TICK_FS``, ``LEGEND_FS`` and ``LINE_W``.

    Args:
        RESULTS: nested mapping ``RESULTS[width][arch][metric]`` holding one
            value per training iteration; ``arch`` is "cPIKAN" and/or "RGA".
        vlines: optional mapping ``width -> iterable of x positions``; a dashed
            vertical line is drawn at each position in every panel of that width.
        widths_order: order of the columns.
        savepath: output directory; used only together with ``filename``.
        filename: output file name without extension; the figure is saved as
            ``<savepath>/<filename>.pdf``.

    Raises:
        ValueError: if none of ``widths_order`` is a key of ``RESULTS``.
    """
    # ---- styling: reuse globally defined knobs if present, else fall back ----
    TITLE_FS   = globals().get("TITLE_FS", 16)
    LABEL_FS   = globals().get("LABEL_FS", 16)
    TICK_FS    = globals().get("TICK_FS", 14)
    LEGEND_FS  = globals().get("LEGEND_FS", 16)
    LINE_W     = globals().get("LINE_W", 2.0)

    # Optional y-axis formatting helper. The lookup key must match the
    # identifier exactly; a stray leading space would make it unfindable.
    log_tick_helper = globals().get("_set_log_ticks")

    # two distinct tones taken from the "Spectral" palette
    _cmap = sns.color_palette("Spectral", as_cmap=True)
    cmap_points = np.linspace(0, 1, 12)
    color_indices = [1, -2]
    _colors = [_cmap(cmap_points[i]) for i in color_indices]

    # fixed plotting order (also fixes the colour of each model) and legend text
    arch_order = ("cPIKAN", "RGA")
    MODEL_LABEL = {"cPIKAN": "cPIKAN", "RGA": "RGA KAN"}

    metrics = ["L2", "SNR", "C"]
    ylabels = [r"Relative $L^2$ Error", "SNR", "Complexity"]

    # figure layout: rows = metrics, cols = widths present in RESULTS
    widths = [w for w in widths_order if w in RESULTS]
    if not widths:
        raise ValueError("No matching widths found in RESULTS.")

    nrows, ncols = len(metrics), len(widths)
    # squeeze=False always yields a 2-D array of axes, even for a single column
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(4.2*ncols, 2.2*nrows),
        sharex=True, sharey=False, constrained_layout=True, squeeze=False
    )
    fig.set_constrained_layout_pads(w_pad=0.08, h_pad=0.08, hspace=0.06, wspace=0.06)

    # first line handle seen for every legend label, across all panels
    legend_by_label = {}

    for c, width in enumerate(widths):
        col_data = RESULTS[width]

        for r, (metric, ylabel) in enumerate(zip(metrics, ylabels)):
            ax = axes[r, c]

            # leave the panel untouched when no model provides this metric
            if not any(a in col_data and metric in col_data[a] for a in arch_order):
                continue

            # plot in consistent order with consistent colors
            for j, arch in enumerate(arch_order):
                if arch not in col_data or metric not in col_data[arch]:
                    continue
                y = np.asarray(col_data[arch][metric], dtype=float)
                # iterations are numbered from 1 so that they fit on a log axis
                xx = np.arange(1, len(y) + 1)
                h, = ax.plot(xx, y, linewidth=LINE_W, color=_colors[j % len(_colors)],
                             label=MODEL_LABEL.get(arch, arch))
                legend_by_label.setdefault(h.get_label(), h)

            # vertical markers (optional)
            if vlines and width in vlines:
                for v in vlines[width]:
                    ax.axvline(v, linestyle="--", color="black", alpha=0.7, linewidth=1.0)

            # scales, ticks, grid
            ax.set_xscale("log")
            if callable(log_tick_helper):
                log_tick_helper(ax)
            else:
                ax.set_yscale("log")
                ax.tick_params(axis="both", which="both", labelsize=TICK_FS)

            ax.grid(True, which="both", linestyle="--", alpha=0.5)

            # labels: y label on the first column, x label on the last row, title on the first row
            if c == 0:
                ax.set_ylabel(ylabel, fontsize=LABEL_FS)
            if r == nrows - 1:
                ax.set_xlabel("Training Iteration", fontsize=LABEL_FS)
            if r == 0:
                ax.set_title(f"Width = {width}", fontsize=TITLE_FS)

    # common legend (models) below all subplots
    if legend_by_label:
        fig.legend(
            list(legend_by_label.values()), list(legend_by_label.keys()),
            loc="lower center", ncol=len(legend_by_label), frameon=False, fontsize=LEGEND_FS,
            bbox_to_anchor=(0.5, -0.10)
        )

    # save (explicitly this figure rather than whichever one is current)
    if savepath is not None and filename is not None:
        fig.savefig(os.path.join(savepath, f"{filename}.pdf"), format="pdf", bbox_inches="tight")

    plt.show()


In [None]:
# Training iterations to mark with dashed vertical lines, keyed by network width
vlines = {
    8: (230, 9500),
    16: (200, 7000),
    32: (140, 4500)
}

# Draw the grid and save it as "snr.pdf" inside plots_dir
plot_snr_grid(RESULTS, vlines, widths_order=(8, 16, 32), savepath=plots_dir, filename="snr")