# NTK Results

## Preliminaries

In [None]:
# Switch JAX to 64-bit (double-precision) arithmetic. This cell runs first so the
# flag is set before any array is created by the cells below.
from jax import config
config.update("jax_enable_x64", True)

In [None]:
import pickle
import os

import jax
import jax.numpy as jnp

from src.utils import *
from src.functions import *
from src.pdes import *
from src.ntk import *

from flax import nnx
import optax

from jaxkan.KAN import KAN

from sklearn.model_selection import train_test_split

## Function Fitting

We first run the NTK experiments for the function-fitting case; PDEs are treated separately below because they require their own NTK formulation.

### Parameters

In [None]:
# Number of generated samples and number of points on which the NTK is evaluated
N = 5000
n_ntk = 256

seed = 42

# Training length, and the epochs at which the NTK spectrum is recorded
num_epochs = 2001
checkpoints = list(range(0, num_epochs, 500))

opt_type = optax.adam(learning_rate=0.001)

# Exponents handed to the 'power' initialization scheme (basis and residual terms)
pow_basis = 1.75
pow_res = 0.25

# Model input/output
n_in = 2
n_out = 1

# Studied functions, as (label, callable) pairs
funcs = list(zip(("f1", "f2", "f3", "f4", "f5"), (f1, f2, f3, f4, f5)))

In [None]:
# ----------------------------------------------------
# Builders for the settings shared by all architectures
# ----------------------------------------------------
def _spline_params(G, init_scheme):
    """Return the KAN layer settings for grid size ``G`` and the given ``init_scheme``."""
    return {'k': 3, 'G': G, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu,
            'external_weights': True, 'add_bias': True, 'init_scheme': init_scheme}


def _default_scheme():
    """Return a fresh 'default' initialization scheme."""
    return {'type': 'default'}


def _glorot_scheme():
    """Return a fresh uniform 'glorot' initialization scheme."""
    return {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}


def _power_scheme():
    """Return a fresh 'power' initialization scheme using ``pow_basis`` and ``pow_res``."""
    return {'type': 'power', "const_b": 1.0, "const_r": 1.0,
            "pow_b1": pow_basis, "pow_b2": pow_basis, "pow_r1": pow_res, "pow_r2": pow_res}


# --------------------------
# Small architecture details
# --------------------------
G_small = 5
hidden_small = [8, 8]

params_small_baseline = _spline_params(G_small, _default_scheme())
params_small_glorot = _spline_params(G_small, _glorot_scheme())
params_small_power = _spline_params(G_small, _power_scheme())

# ------------------------
# Big architecture details
# ------------------------
G_big = 20
hidden_big = [32, 32, 32]

params_big_baseline = _spline_params(G_big, _default_scheme())
params_big_glorot = _spline_params(G_big, _glorot_scheme())
params_big_power = _spline_params(G_big, _power_scheme())

### Helper Functions

In [None]:
def get_data(func, N, n_ntk, seed):
    """Build the training set for ``func`` and select the NTK evaluation points.

    Returns the training inputs, the training targets, and ``n_ntk`` training
    inputs drawn without replacement on which the NTK is evaluated.
    """
    # Sample the target function
    inputs, targets = generate_func_data(func, 2, N, seed)

    # Keep the 80% training part of the split; the held-out part is not used here
    X_train, _, y_train, _ = train_test_split(inputs, targets, test_size=0.2, random_state=seed)

    # Draw n_ntk distinct row indices of the training set
    n_train = X_train.shape[0]
    ntk_rows = jax.random.choice(jax.random.PRNGKey(seed), n_train, shape=(n_ntk,), replace=False)

    return X_train, y_train, X_train[ntk_rows]

In [None]:
def run_experiment(func, model, opt, X_train, y_train, X_ntk, title):
    """Train ``model`` on the fitting task and track its NTK spectrum over training.

    The NTK on ``X_ntk`` is evaluated before training and at every epoch listed in
    ``checkpoints[1:]``. For each evaluation the descending eigenvalues, the epoch,
    the condition number and the effective rank ``(sum λ)^2 / sum λ^2`` are stored.
    Reads the module-level ``num_epochs`` and ``checkpoints``.

    Returns ``(spec_list, tau_list, conds, ranks)``.
    """
    spec_list, tau_list, conds, ranks = [], [], [], []

    def record(tau):
        # Eigenvalues of the stabilized kernel, largest first
        kernel = stabilize_kernel(ntk_matrix(model, X_ntk))
        eigs = jnp.sort(jnp.linalg.eigvalsh(kernel))[::-1]

        spec_list.append(eigs)
        tau_list.append(tau)
        conds.append(cond_from_eigs(eigs))

        # The small constant guards the division when all eigenvalues vanish
        eff_rank = (eigs.sum() ** 2) / (jnp.sum(eigs ** 2) + 1e-12)
        ranks.append(float(eff_rank))

    # τ = 0 (before any updates)
    record(0)

    later_checkpoints = checkpoints[1:]
    for epoch in range(num_epochs):
        loss = func_fit_step(model, opt, X_train, y_train)
        if epoch in later_checkpoints:
            record(epoch)

    l2error = func_fit_eval(model, func, 2, 200)

    print(f"\t{title} Model Metrics:")
    print(f"\tCond Number: τ=0 → {conds[0]:.2e}, τ={tau_list[-1]} → {conds[-1]:.2e}")
    print(f"\tEffective Rank: τ=0 → {ranks[0]:.2f}, τ={tau_list[-1]} → {ranks[-1]:.2f}")
    print(f"\tFinal Loss = {loss:.2e}\t L^2 Error = {l2error:.2e}\n")

    return spec_list, tau_list, conds, ranks

### Main Routine

In [None]:
# Nested results: results[function][architecture size][init scheme][metric]
results = dict()

# Every studied configuration: (size label, hidden layer widths, {init label: KAN parameters})
architectures = [
    ("small", hidden_small,
     {"Baseline": params_small_baseline, "Glorot": params_small_glorot, "Power": params_small_power}),
    ("big", hidden_big,
     {"Baseline": params_big_baseline, "Glorot": params_big_glorot, "Power": params_big_power}),
]

for func_name, func in funcs:

    results[func_name] = dict()

    # Get the data for the function (shared by all models trained on it)
    X_train, y_train, X_ntk = get_data(func, N, n_ntk, seed)

    for size, hidden, init_params in architectures:

        layer_dims = [n_in, *hidden, n_out]

        results[func_name][size] = dict()

        print(f"Training model with dimensions {layer_dims} for function {func_name}.")

        for title, params in init_params.items():

            results[func_name][size][title] = dict()

            # Every model gets the same seed and its own optimizer
            model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params, seed = seed+1)
            opt = nnx.Optimizer(model, opt_type)

            spec_list, tau_list, conds, ranks = run_experiment(func, model, opt, X_train, y_train, X_ntk, title)
            results[func_name][size][title]["spec_list"] = spec_list
            results[func_name][size][title]["tau_list"] = tau_list
            results[func_name][size][title]["conds"] = conds
            results[func_name][size][title]["ranks"] = ranks

In [None]:
# Save results for further processing
results_dir = 'ff_results/'

# Create the output folder if it is missing, so a finished run is not lost
# to a FileNotFoundError when the file is opened
os.makedirs(results_dir, exist_ok=True)

with open(os.path.join(results_dir, "ntk.pkl"), "wb") as f:
    pickle.dump(results, f)

## PDE Solving

We then extend these ideas to PDEs solved with the PIKAN (physics-informed KAN) framework, trained with residual-based attention (RBA) weights.

### Parameters

In [None]:
# Setup: studied PDEs as (label, residual function) pairs
pdes = list(zip(("allen-cahn", "burgers", "helmholtz"), (ac_res, burgers_res, helmholtz_res)))

# Value passed to get_collocs, and the number of PDE / boundary points used for the NTK
N = 2**6
n_ntk_pde = 256
n_ntk_bc = 32

# Coefficients of the RBA weight update
RBA_gamma = 0.999
RBA_eta = 0.01

seed = 42

# Training length, and the epochs at which the NTK spectra are recorded
num_epochs = 5001
checkpoints = list(range(0, num_epochs, 1000))

n_in = 2
n_out = 1

opt_type = optax.adam(learning_rate=0.001)

# Exponents handed to the 'power' initialization scheme (basis and residual terms)
pow_basis = 1.75
pow_res = 0.25


# ----------------------------------------------------
# Builders for the settings shared by all architectures
# ----------------------------------------------------
def _spline_params(G, init_scheme):
    """Return the KAN layer settings for grid size ``G`` and the given ``init_scheme``."""
    return {'k': 3, 'G': G, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu,
            'external_weights': True, 'add_bias': True, 'init_scheme': init_scheme}


def _default_scheme():
    """Return a fresh 'default' initialization scheme."""
    return {'type': 'default'}


def _glorot_scheme():
    """Return a fresh uniform 'glorot' initialization scheme."""
    return {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}


def _power_scheme():
    """Return a fresh 'power' initialization scheme using ``pow_basis`` and ``pow_res``."""
    return {'type': 'power', "const_b": 1.0, "const_r": 1.0,
            "pow_b1": pow_basis, "pow_b2": pow_basis, "pow_r1": pow_res, "pow_r2": pow_res}


# --------------------------
# Small architecture details
# --------------------------
G_small = 5
hidden_small = [8, 8]

params_small_baseline = _spline_params(G_small, _default_scheme())
params_small_glorot = _spline_params(G_small, _glorot_scheme())
params_small_power = _spline_params(G_small, _power_scheme())

# ------------------------
# Big architecture details
# ------------------------
G_big = 20
hidden_big = [32, 32, 32]

params_big_baseline = _spline_params(G_big, _default_scheme())
params_big_glorot = _spline_params(G_big, _glorot_scheme())
params_big_power = _spline_params(G_big, _power_scheme())

### Helper Functions

In [None]:
def get_data(pde_name, N, n_ntk_pde=256, n_ntk_bc=32, seed=42):
    """Load the reference solution and collocation points of a PDE and pick NTK subsets.

    Args:
        pde_name: Label of the PDE, forwarded to ``get_ref`` and ``get_collocs``.
        N: Size argument forwarded to ``get_collocs``.
        n_ntk_pde: Maximum number of PDE collocation points used for the NTK.
        n_ntk_bc: Maximum number of boundary points used for the NTK.
        seed: Seed of the PRNG key from which both index sets are drawn.

    Returns:
        ``(refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc)``,
        where ``idx_pde`` / ``idx_bc`` are distinct row indices into
        ``pde_collocs`` / ``bc_collocs``.
    """
    # Get the reference solution
    refsol, coords = get_ref(pde_name)

    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N)

    # Consistent NTK subsets per experiment. Each draw gets its own subkey:
    # reusing a single key for both draws would make the two index sets correlated.
    key_pde, key_bc = jax.random.split(jax.random.PRNGKey(seed))

    # Never request more indices than there are points (sampling is without replacement)
    n_pde = pde_collocs.shape[0]
    idx_pde = jax.random.choice(key_pde, n_pde, shape=(min(n_ntk_pde, n_pde),), replace=False)

    n_bc = bc_collocs.shape[0]
    idx_bc = jax.random.choice(key_bc, n_bc, shape=(min(n_ntk_bc, n_bc),), replace=False)

    return refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc

In [None]:
def run_experiment_pde(pde_res_fn, model, opt, refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc, title):
    """Train ``model`` on a PDE and track the weighted NTK spectra over training.

    The PDE and boundary spectra are evaluated on the subsets selected by
    ``idx_pde`` / ``idx_bc``, before training and at every epoch listed in
    ``checkpoints[1:]``, using the RBA weights current at that point.
    Relies on the module-level ``train_step``, ``num_epochs`` and ``checkpoints``.

    Args:
        pde_res_fn: PDE residual function forwarded to ``pinntk_diag_spectra_weighted``.
        model: Model to train (updated in place through ``opt``).
        opt: Optimizer bound to ``model``.
        refsol: Reference solution used for the final relative L^2 error.
        coords: Points on which ``model`` is evaluated against ``refsol``.
        pde_collocs, bc_collocs, bc_data: Training collocation points and boundary data.
        idx_pde, idx_bc: Row indices of the points used for the NTK.
        title: Label used in the printed summary.

    Returns:
        ``(specE_list, specB_list, tau_list, condsE, condsB)``.
    """
    # NTK X, Y
    X_pde_ntk = pde_collocs[idx_pde]

    X_bc_ntk = bc_collocs[idx_bc]
    Y_bc_ntk = bc_data[idx_bc]

    # init RBA weights for training loop
    l_E = jnp.ones((pde_collocs.shape[0], 1))
    l_B = jnp.ones((bc_collocs.shape[0], 1))

    specE_list, specB_list, tau_list = [], [], []
    condsE, condsB = [], []

    def record(tau, weights_E, weights_B):
        # Restrict the RBA weights to the NTK subsets
        wE = weights_E[idx_pde].ravel()
        wB = weights_B[idx_bc].ravel()

        # Use the residual function passed by the caller, not a module-level global
        lamE, lamB = pinntk_diag_spectra_weighted(model, pde_res_fn, X_pde_ntk, X_bc_ntk, Y_bc_ntk, wE, wB)

        specE_list.append(lamE)
        specB_list.append(lamB)
        tau_list.append(tau)

        condsE.append(cond_from_eigs(lamE))
        condsB.append(cond_from_eigs(lamB))

    # τ = 0
    record(0, l_E, l_B)

    # Built once: the set of recording epochs does not change during training
    record_epochs = set(checkpoints[1:])

    for epoch in range(num_epochs):
        loss, l_E, l_B = train_step(model, opt, l_E, l_B, pde_collocs, bc_collocs, bc_data)

        if epoch in record_epochs:
            record(epoch, l_E, l_B)

    # Relative L^2 error against the reference solution
    output = model(coords).reshape(refsol.shape)
    l2error = jnp.linalg.norm(output-refsol)/jnp.linalg.norm(refsol)

    print(f"\t{title} Metrics:")
    print(f"\tPDE Cond#: τ=0 → {condsE[0]:.2e}, τ={tau_list[-1]} → {condsE[-1]:.2e}")
    print(f"\tBC  Cond#: τ=0 → {condsB[0]:.2e}, τ={tau_list[-1]} → {condsB[-1]:.2e}")
    print(f"\tFinal Loss = {loss:.2e}\t L^2 Error = {l2error:.2e}\n")

    return specE_list, specB_list, tau_list, condsE, condsB


### Main Routine

In [None]:
# Nested results: results[pde][architecture size][init scheme][metric]
results = dict()

# Configurations trained for every PDE: (size label, hidden layer widths, {init label: KAN parameters}).
# The "small" architecture is switched off for the PDE runs; uncomment its entry to train it as well.
architectures = [
    # ("small", hidden_small,
    #  {"Baseline": params_small_baseline, "Glorot": params_small_glorot, "Power": params_small_power}),
    ("big", hidden_big,
     {"Baseline": params_big_baseline, "Glorot": params_big_glorot, "Power": params_big_power}),
]

for pde_name, pde_res in pdes:

    results[pde_name] = dict()

    # loss_fn and train_step are re-created for every PDE: loss_fn reads the current
    # `pde_res`, and a new train_step starts with its own jit cache for that PDE.

    # Define the loss function for this PDE
    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):
        """Return the RBA-weighted loss and the updated RBA weights ``(l_E_new, l_B_new)``.

        Both weight sets follow ``l_new = RBA_gamma * l + RBA_eta * |r| / max|r|``,
        where ``r`` are the PDE residuals or the boundary residuals.
        """
        # NOTE(review): the updated weights are not wrapped in stop_gradient, so
        # gradients also flow through the weight update - confirm this is intended.
        # NOTE(review): the division by max|r| is undefined if all residuals are zero.

        # ------------- PDE ---------------------------- #
        pde_residuals = pde_res(model, pde_collocs)

        # Get new RBA weights
        abs_pde_res = jnp.abs(pde_residuals)
        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))

        # Multiply by RBA weights
        w_resids_pde = l_E_new * pde_residuals

        # Get loss
        pde_loss = jnp.mean(w_resids_pde**2)


        # ------------- BC ----------------------------- #
        bc_residuals = model(bc_collocs) - bc_data

        # Get new RBA weights
        abs_bc_res = jnp.abs(bc_residuals)
        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))

        # Multiply by RBA weights
        w_resids_bc = l_B_new * bc_residuals

        # Loss
        bc_loss = jnp.mean(w_resids_bc**2)


        # ------------- Total --------------------------- #
        total_loss = pde_loss + bc_loss

        return total_loss, (l_E_new, l_B_new)

    # Define the train step
    @nnx.jit
    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):
        """Apply one optimizer update and return the loss and the updated RBA weights."""

        (loss, (l_E_new, l_B_new)), grads = nnx.value_and_grad(loss_fn, has_aux = True)(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)

        optimizer.update(grads)

        return loss, l_E_new, l_B_new

    # Get the data (shared by all models trained on this PDE)
    refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc = get_data(pde_name, N, n_ntk_pde, n_ntk_bc, seed)

    for size, hidden, init_params in architectures:

        layer_dims = [n_in, *hidden, n_out]

        results[pde_name][size] = dict()

        print(f"Training model with dimensions {layer_dims} for PDE {pde_name}.")

        for title, params in init_params.items():

            results[pde_name][size][title] = dict()

            # Every model gets the same seed and its own optimizer
            model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params, seed = seed+1)
            opt = nnx.Optimizer(model, opt_type)

            specE_list, specB_list, tau_list, condsE, condsB = run_experiment_pde(pde_res, model, opt, refsol, coords, pde_collocs,
                                                                                  bc_collocs, bc_data, idx_pde, idx_bc, title)

            results[pde_name][size][title]["specE_list"] = specE_list
            results[pde_name][size][title]["specB_list"] = specB_list
            results[pde_name][size][title]["tau_list"] = tau_list
            results[pde_name][size][title]["condsE"] = condsE
            results[pde_name][size][title]["condsB"] = condsB

In [None]:
# Save results for further processing
results_dir = 'pde_results/'

# Create the output folder if it is missing, so a finished run is not lost
# to a FileNotFoundError when the file is opened
os.makedirs(results_dir, exist_ok=True)

with open(os.path.join(results_dir, "ntk.pkl"), "wb") as f:
    pickle.dump(results, f)