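# NTK Results (PDEs)

This notebook trains KAN-based PINNs on the Burgers and Helmholtz equations. It tracks the eigenvalue spectra of the PINN neural tangent kernel (NTK) from initialization through training, separately for the PDE-residual part and the boundary-condition part. It compares default and Glorot-like initialization for a small and a big architecture, saves the spectra to `results/ntk_pde.pkl`, and plots them.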

## Preliminaries

In [None]:
from jax import config

# Global numerics for the whole notebook, set in the first cell before any other
# JAX work: double precision everywhere and full-precision matrix products.
for _flag, _value in (("jax_enable_x64", True),
                      ("jax_default_matmul_precision", "highest")):
    config.update(_flag, _value)

In [None]:
import pickle
import os

import jax
import jax.numpy as jnp

from src.equations import burgers_res, helmholtz_14_res
from src.ntk import pinntk_diag_spectra
from src.kan import KAN

from flax import nnx
import optax

import numpy as np

from sklearn.model_selection import train_test_split

# Output locations: pickled spectra go under `results/`, figures under `plots/`.
# Both directories are created up front (no-op if they already exist).
results_dir, plots_dir = "results", "plots"
for _out_dir in (results_dir, plots_dir):
    os.makedirs(_out_dir, exist_ok=True)

result_file = os.path.join(results_dir, "ntk_pde.pkl")

# Nested store filled by the main routine: results[pde][arch][init] -> spectra lists.
results = {}

# Single seed shared by model construction and NTK subset sampling.
seed = 42

## Parameters

In [None]:
# PDE residual functions, keyed by the name used for data files and result keys.
pde_dict = {
    "burgers": burgers_res,
    "helmholtz_1-4": helmholtz_14_res,
}

# KAN hyper-parameters, passed straight through to `KAN(...)`.
D = 8
period_axes = None
rff_std = None

# Collocation grid resolution (points per axis) and sizes of the NTK evaluation subsets.
N_points = 2 ** 6
n_ntk_pde = 256
n_ntk_bc = 32

# Training length and the epochs at which spectra are recorded
# (entry 0 is the pre-training spectrum).
num_epochs = 5001
checkpoint_every = 1000
checkpoints = list(range(0, num_epochs, checkpoint_every))  # [0, 1000, 2000, 3000, 4000, 5000]

opt_type = optax.adam(learning_rate=0.001)

# The two initialization schemes being compared.
base_init = {'type': 'default'}
glorot_init = {'type': 'glorot', 'gain': None, 'norm_pow': 0, 'distribution': 'uniform', 'sample_size': 10000}


### Helper Functions

In [None]:
def _grid_points(first, second):
    """Cartesian product of two 1-D coordinate arrays as a (len(first)*len(second), 2) array.

    Uses 'ij' indexing, so the first coordinate varies slowest in the output rows.
    """
    A, B = jnp.meshgrid(first, second, indexing='ij')
    return jnp.stack([A.flatten(), B.flatten()], axis=1)


def _zero_targets(points):
    """Homogeneous (all-zero) targets, shape (n_points, 1), for a set of boundary points."""
    return jnp.zeros((points.shape[0], 1))


def get_collocs(pde_name, N):
    """Build collocation points and initial/boundary targets on an N x N grid.

    Parameters
    ----------
    pde_name : str
        "burgers" (t in [0, 1], x in [-1, 1]) or "helmholtz_1-4" ((x, y) in [-1, 1]^2).
    N : int
        Number of equispaced points per axis.

    Returns
    -------
    pde_collocs : (N**2, 2) array
        Full tensor grid (domain boundary included), first coordinate varying slowest.
    bc_collocs : (M, 2) array
        Burgers: M = 3N, ordered as initial line t=0, then x=-1, then x=1.
        Helmholtz: M = 4N, ordered as x=-1, x=1, y=-1, y=1.
        Corner points appear more than once.
    bc_data : (M, 1) array
        -sin(pi x) on the Burgers initial line, 0 on every other boundary point.

    Raises
    ------
    ValueError
        If ``pde_name`` is not one of the supported names.
    """
    if pde_name == "burgers":
        t_pde = jnp.linspace(0, 1, N)
        x_pde = jnp.linspace(-1, 1, N)
        pde_collocs = _grid_points(t_pde, x_pde)  # (N^2, 2), columns (t, x)

        # Initial condition: u(0, x) = -sin(pi x)
        ic_collocs = _grid_points(jnp.array([0.0]), x_pde)  # (N, 2)
        ic_data = -jnp.sin(jnp.pi * ic_collocs[:, 1]).reshape(-1, 1)  # (N, 1)

        # Boundary conditions: u(t, -1) = u(t, 1) = 0
        bc_left = _grid_points(t_pde, jnp.array([-1.0]))  # (N, 2)
        bc_right = _grid_points(t_pde, jnp.array([1.0]))  # (N, 2)

        bc_collocs = jnp.concatenate([ic_collocs, bc_left, bc_right], axis=0)
        bc_data = jnp.concatenate([ic_data, _zero_targets(bc_left), _zero_targets(bc_right)], axis=0)

    elif pde_name == "helmholtz_1-4":
        x_pde = jnp.linspace(-1, 1, N)
        y_pde = jnp.linspace(-1, 1, N)
        pde_collocs = _grid_points(x_pde, y_pde)  # (N^2, 2), columns (x, y)

        # Dirichlet conditions: u = 0 on all four edges of the square.
        edges = [
            _grid_points(jnp.array([-1.0]), y_pde),  # x = -1
            _grid_points(jnp.array([1.0]), y_pde),   # x = 1
            _grid_points(x_pde, jnp.array([-1.0])),  # y = -1
            _grid_points(x_pde, jnp.array([1.0])),   # y = 1
        ]
        bc_collocs = jnp.concatenate(edges, axis=0)
        bc_data = jnp.concatenate([_zero_targets(edge) for edge in edges], axis=0)

    else:
        raise ValueError(
            f"Unknown PDE {pde_name!r}; expected 'burgers' or 'helmholtz_1-4'.")

    return pde_collocs, bc_collocs, bc_data


def get_ref(pde_name, data_dir="data"):
    """Load the reference solution for ``pde_name`` and the matching evaluation grid.

    Reads ``<data_dir>/<pde_name>.npz``. For "helmholtz_1-4" the file must contain
    ``usol``, ``x`` and ``y``. For every other PDE it must contain ``usol``, ``t`` and
    ``x``. ``coords`` is the 'ij' meshgrid of the two axes, flattened row-major.
    Therefore ``model(coords).reshape(refsol.shape)`` lines up point-by-point
    with ``refsol``.

    Parameters
    ----------
    pde_name : str
        PDE key; also the file stem of the reference data.
    data_dir : str, optional
        Directory holding the ``.npz`` files (default ``"data"``).

    Returns
    -------
    refsol : jax array, shape (n_first_axis, n_second_axis)
    coords : jax array, shape (n_first_axis * n_second_axis, 2)

    Raises
    ------
    ValueError
        If ``usol``'s shape does not equal (len(first axis), len(second axis)).
        For example, a solution stored transposed would otherwise be compared
        against the wrong coordinates without any error.
    """
    path = os.path.join(data_dir, f"{pde_name}.npz")

    # The context manager closes the .npz file handle once the arrays are read.
    with np.load(path) as ref:
        usol = np.asarray(ref['usol'])
        if pde_name != "helmholtz_1-4":
            axis_a, axis_b = ref['t'].flatten(), ref['x'].flatten()
        else:
            axis_a, axis_b = ref['x'].flatten(), ref['y'].flatten()

    expected_shape = (axis_a.size, axis_b.size)
    if usol.shape != expected_shape:
        raise ValueError(
            f"Reference solution in {path} has shape {usol.shape}, expected "
            f"{expected_shape} to match its coordinate axes.")

    A, B = jnp.meshgrid(axis_a, axis_b, indexing='ij')
    coords = jnp.stack([A.flatten(), B.flatten()], axis=1)

    return jnp.array(usol), coords

In [None]:
def get_data(pde_name, N, n_ntk_pde=256, n_ntk_bc=32, seed=42):
    """Load the reference solution, build collocation points and pick NTK subsets.

    Parameters
    ----------
    pde_name : str
        PDE key, forwarded to ``get_ref`` and ``get_collocs``.
    N : int
        Grid points per axis for the collocation grid.
    n_ntk_pde, n_ntk_bc : int
        Sizes of the random subsets of PDE / boundary collocation points used for
        the NTK spectra. Each is capped at the number of available points.
    seed : int
        Seed for the subset selection; the same seed always yields the same subsets.

    Returns
    -------
    (refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc)
        ``idx_pde`` / ``idx_bc`` are index arrays into ``pde_collocs`` / ``bc_collocs``,
        sampled without replacement.
    """
    refsol, coords = get_ref(pde_name)
    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N)

    # Derive one sub-key per draw. Feeding the same key to both `choice` calls
    # would make the PDE and boundary subsets statistically correlated.
    key_pde, key_bc = jax.random.split(jax.random.PRNGKey(seed))

    n_pde_points = pde_collocs.shape[0]
    n_bc_points = bc_collocs.shape[0]

    idx_pde = jax.random.choice(key_pde, n_pde_points,
                                shape=(min(n_ntk_pde, n_pde_points),), replace=False)
    idx_bc = jax.random.choice(key_bc, n_bc_points,
                               shape=(min(n_ntk_bc, n_bc_points),), replace=False)

    return refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc

In [None]:
# PINN loss and a single jitted optimization step.
def loss_fn(model, pde_collocs, bc_collocs, bc_data):
    """Unweighted PINN loss: mean squared PDE residual plus mean squared BC mismatch.

    NOTE(review): the residual function is not an argument. It is read from the
    module-level name ``pde_res``, which the main loop sets. Inside the jitted
    ``train_step`` that global is captured when the function is traced. A reused
    cached trace (e.g. same argument shapes) could therefore keep an earlier PDE's
    residual. Keep ``pde_res`` consistent with the data passed in.
    """
    residual_loss = jnp.mean(pde_res(model, pde_collocs) ** 2)
    boundary_loss = jnp.mean((model(bc_collocs) - bc_data) ** 2)
    return residual_loss + boundary_loss


@nnx.jit
def train_step(model, optimizer, pde_collocs, bc_collocs, bc_data):
    """Take one optimizer step on ``loss_fn`` and return the loss evaluated before the update."""
    loss_and_grad = nnx.value_and_grad(loss_fn)
    loss, grads = loss_and_grad(model, pde_collocs, bc_collocs, bc_data)
    # The optimizer was built around `model`, so the update is applied to it.
    optimizer.update(grads)
    return loss

In [None]:
def run_experiment_pde(pde_res_fn, model, opt, refsol, coords, pde_collocs, bc_collocs, bc_data,
                       idx_pde, idx_bc, epochs=None, checkpoint_epochs=None):
    """Train ``model`` and record PINN-NTK spectra before training and at checkpoints.

    Parameters
    ----------
    pde_res_fn : callable
        PDE residual function, used for the NTK spectra. NOTE(review): the training
        step itself goes through ``train_step``/``loss_fn``, which read the
        module-level ``pde_res``. Callers must pass the same object (the main loop does).
    model, opt
        The model and the ``nnx.Optimizer`` built around it.
    refsol, coords
        Reference solution and its evaluation points (from ``get_ref``).
    pde_collocs, bc_collocs, bc_data
        Training collocation points and boundary targets.
    idx_pde, idx_bc
        Indices of the fixed subsets used to evaluate the NTK spectra.
    epochs : int, optional
        Number of training steps; defaults to the notebook-level ``num_epochs``.
    checkpoint_epochs : sequence of int, optional
        Defaults to the notebook-level ``checkpoints``. Its first entry stands for
        the pre-training spectrum. For every later entry ``e``, spectra are
        recorded right after the optimizer step of loop iteration ``e``.

    Returns
    -------
    (specE_list, specB_list)
        Lists of the two spectra returned by ``pinntk_diag_spectra`` (presumably the
        PDE-residual and boundary blocks), one entry per recorded checkpoint.

    Prints the final training loss and the relative L2 error against ``refsol``.
    """
    if epochs is None:
        epochs = num_epochs
    if checkpoint_epochs is None:
        checkpoint_epochs = checkpoints

    # Fixed NTK evaluation subsets
    X_pde_ntk = pde_collocs[idx_pde]
    X_bc_ntk = bc_collocs[idx_bc]
    Y_bc_ntk = bc_data[idx_bc]

    specE_list, specB_list = [], []

    def record_spectra():
        # Use the residual passed in, not a global, so the spectra match the caller's PDE.
        lamE, lamB = pinntk_diag_spectra(model, pde_res_fn, X_pde_ntk, X_bc_ntk, Y_bc_ntk)
        specE_list.append(lamE)
        specB_list.append(lamB)

    record_spectra()  # tau = 0

    later_checkpoints = set(checkpoint_epochs[1:])  # O(1) membership inside the loop
    loss = float('nan')  # reported if no training step is run
    for epoch in range(epochs):
        loss = train_step(model, opt, pde_collocs, bc_collocs, bc_data)
        if epoch in later_checkpoints:
            record_spectra()

    output = model(coords).reshape(refsol.shape)
    l2error = jnp.linalg.norm(output - refsol) / jnp.linalg.norm(refsol)

    print(f"\tFinal Loss = {loss:.2e}\t L^2 Error = {l2error:.2e}\n")

    return specE_list, specB_list

### Main Routine

In [None]:
# NOTE: `pde_res` must stay a module-level name. `loss_fn` (and hence the jitted
# `train_step`) reads it as a global, so the loop variable keeps it on the current PDE.
for pde_name, pde_res in pde_dict.items():
    print(f"Running Experiments for {pde_name} PDE.")
    results[pde_name] = {}

    # Reference solution, training points and fixed NTK subsets for this PDE
    (refsol, coords, pde_collocs, bc_collocs, bc_data,
     idx_pde, idx_bc) = get_data(pde_name, N_points, n_ntk_pde, n_ntk_bc, seed)

    # Input dimension from the collocation points, output dimension from the BC targets
    n_in = pde_collocs.shape[1]
    n_out = bc_data.shape[1]

    # Architectures as (width, depth)
    for arch_name, (width, depth) in (("small", (4, 3)), ("big", (16, 5))):
        results[pde_name][arch_name] = {}

        # Baseline initialization vs. Glorot-like initialization
        for init_scheme in (base_init, glorot_init):
            type_init = init_scheme['type']
            record = results[pde_name][arch_name][type_init] = {}

            print(f"\tTraining model with depth = {depth} and width = {width} ({type_init} init).")

            model = KAN(n_in=n_in, n_out=n_out, n_hidden=width, num_layers=depth, D=D,
                        init_scheme=init_scheme, period_axes=period_axes, rff_std=rff_std,
                        seed=seed)
            optimizer = nnx.Optimizer(model, opt_type)

            specE_list, specB_list = run_experiment_pde(
                pde_res, model, optimizer, refsol, coords, pde_collocs,
                bc_collocs, bc_data, idx_pde, idx_bc)

            record["specE_list"] = specE_list
            record["specB_list"] = specB_list
            

In [None]:
# Persist every spectrum so the visualization section can run without retraining.
with open(result_file, mode="wb") as fh:
    pickle.dump(results, fh)

## Visualizations

In [None]:
# Reload the saved spectra (lets the plots below run in a fresh kernel).
# Only unpickle files this notebook produced itself: pickle.load can execute arbitrary code.
with open(result_file, mode="rb") as fh:
    results = pickle.load(fh)

In [None]:
def pde_title(name: str) -> str:
    """Human-readable PDE name for plot titles.

    Known keys map to fixed titles ("burgers" -> "Burgers",
    "helmholtz_1-4" -> "Helmholtz"). Any other key is shown with underscores
    turned into spaces and title-cased (e.g. "allen_cahn" -> "Allen Cahn"),
    rather than being mislabelled as "Burgers".
    """
    known_titles = {"burgers": "Burgers", "helmholtz_1-4": "Helmholtz"}
    return known_titles.get(name, name.replace("_", " ").title())

In [None]:
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns
import numpy as np
import os

# ---------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------
# The 2x5 grid shows exactly TWO PDEs side by side. To compare other PDEs,
# call `plot_pde_grid` with a different pair of keys from `pde_dict`.
PDES_TO_PLOT = ["burgers", "helmholtz_1-4"]

# Initialization keys (as stored in `results`) -> labels used on the y-axes.
init_dict = {"default": "Default", "glorot": "Glorot"}
INIT_KEYS = list(init_dict)  # grid row order: ["default", "glorot"]

ARCHITECTURES = ["small", "big"]

def plot_pde_row(ax, spec_list, color_palette, tick_fs=12):
    """Draw one spectrum panel: eigenvalue vs. index on log-log axes.

    The first spectrum (initialization) is drawn solid. The last (final checkpoint)
    is dashed and thick. Any checkpoints in between are faint dashed lines.
    Does nothing when ``spec_list`` is empty or None.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    spec_list : sequence of 1-D array-likes
        One spectrum per recorded checkpoint, all of the same length.
    color_palette : tuple
        (initialization color, intermediate color, final color).
    tick_fs : int
        Tick-label font size.
    """
    if not spec_list:
        return

    init_color, mid_color, final_color = color_palette
    spectra = list(map(np.asarray, spec_list))
    ranks = np.arange(1, spectra[0].size + 1)  # 1-based so every index is valid on a log axis

    ax.plot(ranks, spectra[0], color=init_color, lw=2, label="Initialization")
    # An empty slice (fewer than 3 spectra) simply draws nothing here.
    for lam in spectra[1:-1]:
        ax.plot(ranks, lam, "--", color=mid_color, alpha=0.5)
    ax.plot(ranks, spectra[-1], "--", color=final_color, lw=2, label="Final Iteration")

    for set_scale in (ax.set_xscale, ax.set_yscale):
        set_scale("log")
    ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)
    ax.tick_params(axis='both', labelsize=tick_fs)


def _get_record(results, pde_name, arch_name, init_key):
    """Return ``results[pde_name][arch_name][init_key]``, or None if any level is missing."""
    try:
        return results[pde_name][arch_name].get(init_key)
    except KeyError:
        return None


def plot_pde_grid(results, arch_name, pde_pair, save_dir=None):
    """Plot NTK spectra for one architecture and two PDEs as a (len(INIT_KEYS) x 5) grid.

    Column layout:
        0, 1 : first PDE  (equation "PDE" spectra, boundary "BC" spectra)
        2    : invisible spacer
        3, 4 : second PDE (equation, boundary)
    There is one row per entry of ``INIT_KEYS``. The figure is saved as
    ``ntk_pde_<arch_name>.pdf``, shown, and then closed.

    Parameters
    ----------
    results : dict
        Nested mapping ``results[pde][arch][init_key] -> {"specE_list", "specB_list"}``
        as filled by the main routine. Missing entries leave their panels empty.
    arch_name : str
        Architecture key to plot (e.g. "small" or "big").
    pde_pair : sequence of str
        Exactly two PDE keys.
    save_dir : str, optional
        Output directory; defaults to the notebook-level ``plots_dir``.

    Raises
    ------
    ValueError
        If ``pde_pair`` does not contain exactly two names.
    """
    pde_pair = tuple(pde_pair)
    if len(pde_pair) != 2:
        raise ValueError(
            f"plot_pde_grid expects exactly two PDE names, got {len(pde_pair)}: {pde_pair!r}")
    if save_dir is None:
        save_dir = plots_dir

    TITLE_FS = 18
    LABEL_FS = 16
    TICK_FS = 14

    palette = sns.color_palette("Spectral", 20)
    colors = (palette[0], palette[16], palette[-1])  # Init, Mid, Final

    # One row per initialization scheme; squeeze=False keeps `axes` 2-D even for one row.
    n_rows = len(INIT_KEYS)
    fig, axes = plt.subplots(nrows=n_rows, ncols=5, figsize=(20, 4 * n_rows),
                             sharex='col', sharey=False, squeeze=False,
                             gridspec_kw={'width_ratios': [1, 1, 0.15, 1, 1]})

    # Each PDE occupies two adjacent columns, starting at the given column index.
    pde_columns = ((pde_pair[0], 0), (pde_pair[1], 3))

    for row_idx, init_key in enumerate(INIT_KEYS):
        axes[row_idx, 2].set_visible(False)  # spacer between the two PDE blocks

        for pde_name, col in pde_columns:
            ax_eq, ax_bc = axes[row_idx, col], axes[row_idx, col + 1]

            rec = _get_record(results, pde_name, arch_name, init_key)
            if rec:
                plot_pde_row(ax_eq, rec.get("specE_list"), colors, TICK_FS)
                plot_pde_row(ax_bc, rec.get("specB_list"), colors, TICK_FS)

            if row_idx == 0:  # column titles on the top row only
                ax_eq.set_title(f"{pde_title(pde_name)} - PDE", fontsize=TITLE_FS, pad=15)
                ax_bc.set_title(f"{pde_title(pde_name)} - BC", fontsize=TITLE_FS, pad=15)

            if row_idx == n_rows - 1:  # x labels on the bottom row only
                ax_eq.set_xlabel("Indices", fontsize=LABEL_FS)
                ax_bc.set_xlabel("Indices", fontsize=LABEL_FS)

            # The y label at the start of each PDE block names the row's init scheme.
            ax_eq.set_ylabel(f"Eigenvalues ({init_dict[init_key]})", fontsize=LABEL_FS)

    # Line up the y labels of both labelled columns.
    fig.align_ylabels(axes[:, [0, 3]])

    # Shared legend built from proxy artists (the per-axes labels are not used).
    c_init, c_mid, c_final = colors
    handles = [
        mlines.Line2D([], [], color=c_init,  label="Initialization",        linewidth=2),
        mlines.Line2D([], [], color=c_mid,   label="Intermediate Iterations", linewidth=2, linestyle="--"),
        mlines.Line2D([], [], color=c_final, label="Final Iteration",       linewidth=2, linestyle="--"),
    ]
    fig.legend(handles=handles, loc="lower center", ncol=3, fontsize=LABEL_FS,
               frameon=False, bbox_to_anchor=(0.5, 0.0))

    fig.subplots_adjust(top=0.85, bottom=0.15, wspace=0.3, hspace=0.15)

    save_path = os.path.join(save_dir, f"ntk_pde_{arch_name}.pdf")
    fig.savefig(save_path, bbox_inches='tight')
    print(f"Saved plot to {save_path}")
    plt.show()
    plt.close(fig)  # release the figure; this function is called in a loop


In [None]:
# --- Execution ---
# One figure per architecture. PDES_TO_PLOT must name exactly two PDEs present in `results`.
for arch_key in ARCHITECTURES:
    plot_pde_grid(results, arch_key, PDES_TO_PLOT)