# NTK Results (Function Fitting)

## Preliminaries

In [None]:
from jax import config
# Enable 64-bit (double-precision) types in JAX instead of the 32-bit default.
config.update("jax_enable_x64", True)
# Request the highest available precision for matrix multiplications.
# NOTE(review): presumably chosen for numerically stable NTK spectra - confirm.
config.update("jax_default_matmul_precision", "highest")

In [None]:
import pickle
import os

import jax
import jax.numpy as jnp

from src.utils import generate_func_data, func_fit_step, func_fit_eval
from src.functions import *
from src.ntk import ntk_spectrum
from src.kan import KAN

from flax import nnx
import optax


from sklearn.model_selection import train_test_split

# Output locations: pickled results go to `result_file`, figures to `plots_dir`.
results_dir = "results"
plots_dir = "plots"
result_file = os.path.join(results_dir, "ntk_ff.pkl")

# Make sure both output directories exist before anything is written to them.
os.makedirs(results_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

# Container filled by the main routine, and the seed handed to the data
# generation and model construction calls.
results = {}
seed = 42

## Parameters

In [None]:
# Data sizes: N is forwarded to the data generator, and n_ntk training points
# are subsampled for the NTK computation.
N = 5000
n_ntk = 256

# Number of training epochs and the epochs at which the NTK spectrum is
# recorded (the first entry corresponds to the spectrum before training).
num_epochs = 2001
checkpoints = [0, 500, 1000, 1500, 2000]

opt_type = optax.adam(learning_rate=0.001)

# The two initialization schemes that are compared
base_init = {'type': 'default'}
glorot_init = {
    'type': 'glorot',
    'gain': None,
    'norm_pow': 0,
    'distribution': 'uniform',
    'sample_size': 10000,
}

# Remaining KAN constructor arguments, shared by every model
D = 8
period_axes = None
rff_std = None

# Studied functions: name -> callable and input dimension
func_dict = {
    name: {'func': fn, 'dim': n_dims}
    for name, fn, n_dims in (
        ("f1", f1, 1),
        ("f2", f2, 2),
        ("f3", f3, 2),
        ("f4", f4, 3),
        ("f5", f5, 5),
    )
}

### Helper Functions

In [None]:
def get_data(func, dim, N, n_ntk, seed):
    """Generate data for ``func`` and pick the points used for the NTK.

    The data returned by ``generate_func_data(func, dim, N, seed)`` is split
    80/20 with ``train_test_split``; only the training part is kept. From it,
    ``n_ntk`` rows are drawn without replacement.

    Args:
        func: Target function, forwarded to ``generate_func_data``.
        dim: Input dimension, forwarded to ``generate_func_data``.
        N: Sample count, forwarded to ``generate_func_data``.
        n_ntk: Number of training rows to subsample for the NTK.
        seed: Seed used for data generation, the split and the subsampling.

    Returns:
        Tuple ``(X_train, y_train, X_ntk)``.
    """
    inputs, targets = generate_func_data(func, dim, N, seed)

    # The held-out part of the split is discarded; the split is only kept
    # so the training set stays consistent with the rest of the project.
    X_train, _, y_train, _ = train_test_split(
        inputs, targets, test_size=0.2, random_state=seed
    )

    # Draw n_ntk distinct row indices of the training set.
    n_train = X_train.shape[0]
    ntk_rows = jax.random.choice(
        jax.random.PRNGKey(seed), n_train, shape=(n_ntk,), replace=False
    )

    return X_train, y_train, X_train[ntk_rows]

In [None]:
def run_experiment(func_name, dim, func, model, opt, X_train, y_train, X_ntk):
    """Train ``model`` and record its NTK spectrum at the checkpoint epochs.

    The spectrum is computed once before training and then again after the
    update of every epoch listed in ``checkpoints[1:]``. Training runs for
    the module-level ``num_epochs`` epochs. Afterwards the fit is evaluated
    with ``func_fit_eval`` and the final loss and error are printed.

    Args:
        func_name: Key of the studied function; selects the evaluation
            resolution passed to ``func_fit_eval``.
        dim: Input dimension, forwarded to ``func_fit_eval``.
        func: Target function, forwarded to ``func_fit_eval``.
        model: Model being trained.
        opt: Optimizer handed to ``func_fit_step``.
        X_train: Training inputs.
        y_train: Training targets.
        X_ntk: Points on which ``ntk_spectrum`` is evaluated.

    Returns:
        List with one ``ntk_spectrum`` result per recorded checkpoint,
        starting with the pre-training spectrum.
    """
    # Built once so the membership test does not re-slice the list each epoch.
    record_epochs = set(checkpoints[1:])

    # τ = 0 (before any updates)
    spec_list = [ntk_spectrum(model, X_ntk)]

    # Placeholder so the final report still works if no epoch is run.
    loss = float("nan")
    for epoch in range(num_epochs):
        loss = func_fit_step(model, opt, X_train, y_train)

        if epoch in record_epochs:
            spec_list.append(ntk_spectrum(model, X_ntk))

    # Evaluation resolution per function; anything not listed falls back to 10.
    res = {"f1": 1000, "f2": 200, "f3": 200, "f4": 30}.get(func_name, 10)

    l2error = func_fit_eval(model, func, dim, res)

    print(f"\tFinal Loss = {loss:.2e}\t L^2 Error = {l2error:.2e}\n")

    return spec_list

### Main Routine

In [None]:
# Sweep over every studied function, architecture and initialization scheme,
# storing the spectra returned by run_experiment in the nested `results` dict
# as results[func_name][arch_name][type_init]["spec_list"].
for func_name, func_spec in func_dict.items():
    print(f"Running Experiments for {func_name} function.")
    function = func_spec['func']
    dim = func_spec['dim']

    results[func_name] = {}

    # Training data and NTK evaluation points for this function
    X_train, y_train, X_ntk = get_data(function, dim, N, n_ntk, seed)

    # Model input/output sizes are read off the data
    n_in = X_train.shape[1]
    n_out = y_train.shape[1]

    # Architectures are given as (width, depth) pairs
    for arch_name, arch in (("small", (4, 3)), ("big", (16, 5))):

        results[func_name][arch_name] = {}
        width, depth = arch

        # Baseline initialization first, then the Glorot-like one
        for init_scheme in (base_init, glorot_init):

            type_init = init_scheme['type']
            results[func_name][arch_name][type_init] = {}

            print(f"\tTraining model with depth = {depth} and width = {width} ({type_init} init).")

            # A fresh model and optimizer for every configuration
            model = KAN(
                n_in=n_in,
                n_out=n_out,
                n_hidden=width,
                num_layers=depth,
                D=D,
                init_scheme=init_scheme,
                period_axes=period_axes,
                rff_std=rff_std,
                seed=seed,
            )
            optimizer = nnx.Optimizer(model, opt_type)

            spec_list = run_experiment(
                func_name, dim, function, model, optimizer, X_train, y_train, X_ntk
            )

            results[func_name][arch_name][type_init]["spec_list"] = spec_list
            

In [None]:
# Persist the collected results so the plots can be rebuilt without retraining.
with open(result_file, "wb") as out_file:
    pickle.dump(results, out_file)

## Visualizations

In [None]:
# Reload the stored results. pickle.load can execute arbitrary code, so only
# use it on a results file written by this notebook.
with open(result_file, "rb") as in_file:
    results = pickle.load(in_file)

In [None]:
def func_title(name: str) -> str:
    """Return the math-mode plot title for the studied function ``name``.

    The title has the form ``$f_{k}(args)$``, where ``k`` is ``name`` without
    its first character and ``args`` lists the function's variables.

    Args:
        name: Function key such as ``'f1'``. Any name other than
            ``'f1'``-``'f4'`` gets the five-variable argument list.

    Returns:
        The title string, including the surrounding ``$`` delimiters.
    """
    if name == 'f1':
        args = "x"
    elif name in ('f2', 'f3'):
        args = "x_1, x_2"
    elif name == 'f4':
        args = "x_1, x_2, x_3"
    else:
        # Raw string: "\d" is not a valid escape sequence, so a plain string
        # literal triggers a SyntaxWarning on recent Python versions.
        args = r"x_1,\dots,x_5"
    return rf"$f_{{{name[1:]}}}({args})$"

In [None]:
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns
import numpy as np
import os

# Plot layout: grid columns follow FUNCTIONS and grid rows follow INIT_KEYS.
# Both must match the keys used in the stored results dictionary; INIT_KEYS
# holds the 'type' strings of the initialization schemes, and init_dict maps
# each of them to the label shown on the axes.
FUNCTIONS = [f"f{i}" for i in range(1, 6)]
INIT_KEYS = ["default", "glorot"]
init_dict = dict(zip(INIT_KEYS, ("Default", "Glorot")))

ARCHITECTURES = ["small", "big"]

def plot_ntk_grid(results, arch_name):
    """
    Plot the recorded NTK spectra of one architecture as a grid and save it.

    Rows: initialization schemes (one per entry of INIT_KEYS)
    Cols: functions (one per entry of FUNCTIONS)

    Each panel shows the first recorded spectrum (solid), the intermediate
    ones (dashed, translucent) and the last one (dashed) on log-log axes.
    Panels whose data is missing or empty are hidden. The figure is written
    to ``<plots_dir>/ntk_ff_<arch_name>.pdf`` and then shown.

    Args:
        results: Nested dict indexed as
            ``results[func_name][arch_name][init_key]["spec_list"]``.
        arch_name: Architecture key whose spectra are plotted.
    """
    TITLE_FS = 18
    LABEL_FS = 16
    TICK_FS  = 14

    # Setup Colors
    palette = sns.color_palette("Spectral", 20)
    c_init  = palette[0]
    c_mid   = palette[16]
    c_final = palette[-1]

    # Grid shape follows the configured keys (2 rows x 5 columns by default,
    # which gives the original 20 x 7 figure). squeeze=False keeps `axes`
    # two-dimensional even for a single row or column.
    n_rows, n_cols = len(INIT_KEYS), len(FUNCTIONS)
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols,
                             figsize=(4 * n_cols, 3.5 * n_rows),
                             sharex=True, sharey=False, squeeze=False)

    # Iterate over rows (Initializations) and columns (Functions)
    for row_idx, init_key in enumerate(INIT_KEYS):
        for col_idx, func_name in enumerate(FUNCTIONS):

            ax = axes[row_idx, col_idx]

            # Safely retrieve data: results -> func -> arch -> init
            try:
                rec = results[func_name][arch_name].get(init_key)
            except KeyError:
                rec = None

            spec_list = rec.get("spec_list") if rec else None

            # Nothing recorded for this panel (missing entry or empty list)
            if spec_list is None or len(spec_list) == 0:
                ax.set_visible(False)
                continue

            # Process Data
            specs = [np.asarray(e) for e in spec_list]
            idx   = np.arange(1, specs[0].size + 1)

            # --- Plotting ---
            # 1. Initialization
            ax.plot(idx, specs[0], color=c_init, lw=2, label="Initialization")

            # 2. Intermediates (the slice is empty when there are fewer than
            #    three recorded spectra)
            for lam in specs[1:-1]:
                ax.plot(idx, lam, "--", color=c_mid, alpha=0.5)

            # 3. Final
            ax.plot(idx, specs[-1], "--", color=c_final, lw=2, label="Final Iteration")

            # --- Styling ---
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)
            ax.tick_params(axis='both', labelsize=TICK_FS)

            # Set Column Titles (Function Name) - Only on top row
            if row_idx == 0:
                ax.set_title(f"{func_title(func_name)}", fontsize=TITLE_FS, pad=15)

            # Set Row Labels (Init Type) - Only on left column
            if col_idx == 0:
                ax.set_ylabel(f"Eigenvalues ({init_dict[init_key]})", fontsize=LABEL_FS)

            # Set X Labels - Only on bottom row
            if row_idx == n_rows - 1:
                ax.set_xlabel("Indices", fontsize=LABEL_FS)

    fig.align_ylabels(axes[:, 0])

    # Global Legend (Create custom handles to avoid clutter)
    handles = [
        mlines.Line2D([], [], color=c_init,  label="Initialization",        linewidth=2),
        mlines.Line2D([], [], color=c_mid,   label="Intermediate Iterations", linewidth=2, linestyle="--"),
        mlines.Line2D([], [], color=c_final, label="Final Iteration",       linewidth=2, linestyle="--"),
    ]

    fig.legend(handles=handles, loc="lower center", ncol=3, fontsize=LABEL_FS,
               frameon=False, bbox_to_anchor=(0.5, -0.02))

    # Adjust layout to make room for titles and legend
    plt.subplots_adjust(top=0.88, bottom=0.15, wspace=0.25, hspace=0.1)

    # Save
    save_path = os.path.join(plots_dir, f"ntk_ff_{arch_name}.pdf")
    fig.savefig(save_path, bbox_inches='tight')
    print(f"Saved plot to {save_path}")
    plt.show()

In [None]:
# --- Execution ---
# Plot every architecture that has stored results for at least one function.
for arch in ARCHITECTURES:
    # .get keeps the check from raising KeyError when a function is absent
    # from `results` (e.g. a partial run); plot_ntk_grid hides such panels.
    if any(arch in results.get(f, {}) for f in FUNCTIONS):
        plot_ntk_grid(results, arch)