# Experiment 1: Function Fitting with Chebyshev Layers

We perform a grid search over a range of architectures to understand how small, medium and large KANs behave.

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.functions import *
from src.utils import generate_func_data, func_fit_step, func_fit_eval

from src.kan import KAN

from sklearn.model_selection import train_test_split

import optax
from flax import nnx

import os

# Output locations for the CSV results and the figures (created on first run,
# reused afterwards).
results_dir, plots_dir = "results", "plots"
os.makedirs(results_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

# All runs of this experiment are appended to a single CSV named after it.
experiment_name = "func_fit_D8"
results_file = os.path.join(results_dir, f"{experiment_name}.csv")

# Column names; fields in this file are separated by ", " (comma + space).
header = "function, width, depth, init_type, run, loss, l2"

# Only a brand-new results file gets the header line; an existing file is left
# untouched so earlier results are kept.
if not os.path.exists(results_file):
    with open(results_file, "w") as file:
        file.write(f"{header}\n")

seed = 42

## Grid-Search Parameters

In [None]:
# Target functions (from src.functions) and the input dimension of each.
func_dict = {
    name: {'func': fn, 'dim': dim}
    for name, fn, dim in (
        ("f1", f1, 1),
        ("f2", f2, 2),
        ("f3", f3, 2),
        ("f4", f4, 3),
        ("f5", f5, 5),
    )
}

# Value passed to every KAN as `D` (the "D8" in `experiment_name`).
D = 8
period_axes = None  # forwarded unchanged to KAN
rff_std = None      # forwarded unchanged to KAN

# The two initialization schemes being compared.
base_init = {'type': 'default'}
glorot_init = dict(type='glorot', gain=None, norm_pow=0,
                   distribution='uniform', sample_size=10000)

N = 5_000           # number of sampled points
num_epochs = 2_000  # number of training iterations per model

# Plain Adam with a fixed learning rate.
opt_type = optax.adam(learning_rate=1e-3)

# Architecture grid: widths are the powers of two 2..64, depths 2..5.
widths = [2 ** k for k in range(1, 7)]
depths = list(range(2, 6))

## Grid Search

In [None]:
# Procedure
for func_name in func_dict.keys():
    print(f"Running Experiments for {func_name} function.")
    function = func_dict[func_name]['func']
    dim = func_dict[func_name]['dim']

    # Generate data
    x, y = generate_func_data(function, dim, N, seed)

    # Split data, 80%-20%
    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)

    # Model input/output
    n_in, n_out = X_train.shape[1], y_train.shape[1]

    # Grid search over depth size
    for depth in depths:

        # Grid search over width size
        for width in widths:

            # Discern between baseline initialization and Glorot-like initialization
            for init_scheme in [base_init, glorot_init]:

                type_init = init_scheme['type']
            
                print(f"\tTraining model with depth = {depth} and width = {width} ({type_init} init).")

                for run in [1, 2, 3, 4, 5]:

                    model = KAN(n_in = n_in, n_out = n_out, n_hidden = width, num_layers = depth, D = D,
                                init_scheme = init_scheme, period_axes = period_axes, rff_std = rff_std,
                                seed = seed+run)
            
                    optimizer = nnx.Optimizer(model, opt_type)
                
                    # Train
                    for epoch in range(num_epochs):
                        train_loss = func_fit_step(model, optimizer, X_train, y_train)
                
                    # Evaluate
                    y_pred = model(X_test)
                    if func_name == "f1":
                        res = 1000
                    elif func_name in ["f2", "f3"]:
                        res = 200
                    elif func_name in ["f4"]:
                        res = 30
                    else:
                        res = 10
                    l2error = func_fit_eval(model, function, dim, resolution=res)
                
                    # Log results
                    new_row = f"{func_name}, {width}, {depth}, {type_init}, {run}, {train_loss}, {l2error}"
                                    
                    # Append the row to the file
                    with open(results_file, "a") as rfile:
                        rfile.write(new_row + "\n")

                    print(f"\t\t\t{run}. Final loss: {train_loss:.2e} \tRel. L2 Error: {l2error:.2e}")

## Analysis

Let's first determine how often the Glorot-like initialization outperforms the default initialization, and by how much.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import colors

# -----------------------------
# Config
# -----------------------------
agg_func = "mean"             # reduction over runs per grid cell: "mean" or "median"
# NOTE(review): the three settings below are not read anywhere in this cell;
# the heatmap is clipped to a fixed 0-100 % range instead.
comparison_mode = "percent"   # "percent" or "logratio"
percent_clip_pct = 98         # robust clipping percentile for % heatmap
eps = 1e-12                   # guard against division-by-zero

# -----------------------------
# Font sizes
# -----------------------------
FONT = {
    "title": 20,
    "xlabel": 18,
    "ylabel": 18,
    "ticks": 16,
    "cbar_label": 18,
    "cbar_ticks": 16,
}

# -----------------------------
# Load & reduce data
# -----------------------------
# Fields in the results CSV are separated by ", " (comma + space). A
# multi-character separator requires the python parser; requesting it
# explicitly avoids pandas' ParserWarning about silently falling back to it.
df = pd.read_csv(results_file, sep=", ", engine="python")

df = df[df["depth"] != 1]                      # exclude depth=1

# One value per (function, width, depth, init_type): aggregate over runs.
df_red = (
    df.groupby(["function", "width", "depth", "init_type"])["l2"]
      .agg(agg_func)
      .reset_index()
)

# Wide format: one l2 column per initialization type.
df_pivot = df_red.pivot(index=["function", "width", "depth"],
                        columns="init_type", values="l2").reset_index()

df_pivot["width"] = pd.to_numeric(df_pivot["width"])
df_pivot["depth"] = pd.to_numeric(df_pivot["depth"])

# Heatmap axes. These deliberately do NOT reuse the names `W`/`D`: `D` is the
# value passed to `KAN(..., D=D)` by the training cells, and overwriting it
# with a list of depths would corrupt any training cell run after this one.
funcs = sorted(df_pivot["function"].unique().tolist())
width_levels = sorted(df_pivot["width"].unique().tolist())
depth_levels = sorted(df_pivot["depth"].unique().tolist())


def build_matrix_for_col(f, col):
    """Return a (len(depth_levels), len(width_levels)) array of `col` values for function `f`.

    Entry [i, j] belongs to depth_levels[i] and width_levels[j]. Grid cells
    without a result, or a `col` that is absent from `df_pivot`, give NaN.
    """
    sub = df_pivot[df_pivot["function"] == f]
    M = np.full((len(depth_levels), len(width_levels)), np.nan)
    if col not in sub.columns:
        return M
    for i, d in enumerate(depth_levels):
        for j, w in enumerate(width_levels):
            row = sub[(sub["width"] == w) & (sub["depth"] == d)]
            if not row.empty:
                M[i, j] = row.iloc[0][col]
    return M


mats_glorot  = {f: build_matrix_for_col(f, "glorot")  for f in funcs}
mats_default = {f: build_matrix_for_col(f, "default") for f in funcs}


def func_title(name: str) -> str:
    """Return the LaTeX axis title for function `name` (e.g. "f2" -> $f_{2}(x_1, x_2)$)."""
    if name == 'f1':
        x = "x"
    elif name in ['f2', 'f3']:
        x = "x_1, x_2"
    elif name == 'f4':
        x = "x_1, x_2, x_3"
    else:
        # Raw string: "\d" is not a valid Python escape sequence.
        x = r"x_1,\dots,x_5"
    return rf"$f_{{{name[1:]}}}({x})$"


mats_pct = {}
mats_default_wins = {}
for f in funcs:
    G = mats_glorot[f]
    Df = mats_default[f]
    with np.errstate(divide="ignore", invalid="ignore"):
        # A cell can be compared only if both errors exist and the default
        # error is positive, because the percentage divides by it.
        comparable = np.isfinite(G) & np.isfinite(Df) & (Df > 0)
        pct = (Df - G) / Df * 100.0
        # Relative error reduction achieved by Glorot; cells where Glorot is
        # worse (negative percentage) stay NaN.
        mats_pct[f] = np.where(comparable & (pct >= 0), pct, np.nan)
        # True where the Default initialization has the strictly lower error.
        mats_default_wins[f] = comparable & (G > Df)

# colormap
cmap = sns.color_palette("Spectral", as_cmap=True)


# -------- Figure: centered 3-over-2 mosaic --------
fig = plt.figure(figsize=(16, 8), constrained_layout=True)
gs = fig.add_gridspec(2, 6)

# Each panel spans two grid columns. The bottom row starts one column in, so
# its two panels sit centred under the three panels on top.
slots = [(0, slice(0, 2)), (0, slice(2, 4)), (0, slice(4, 6)),
         (1, slice(1, 3)), (1, slice(3, 5))]
axes = [fig.add_subplot(gs[r, c]) for (r, c) in slots]

im_last = None
for ax, f in zip(axes, funcs):
    # main heatmap (Glorot improvements, clipped 0–100 %)
    im_last = ax.imshow(
        mats_pct[f],
        cmap=cmap,
        vmin=0, vmax=100,
        origin="lower",
        aspect="auto",
    )

    # overlay solid black where Default wins (NaN cells stay transparent)
    default_wins_mask = np.where(mats_default_wins[f], 1.0, np.nan)
    ax.imshow(
        default_wins_mask,
        cmap=colors.ListedColormap(["black"]),
        origin="lower",
        aspect="auto",
        alpha=0.85,
        vmin=0.0, vmax=1.0,
    )

    ax.set_xticks(range(len(width_levels)))
    ax.set_xticklabels(width_levels, fontsize=FONT["ticks"])
    ax.set_yticks(range(len(depth_levels)))
    ax.set_yticklabels(depth_levels, fontsize=FONT["ticks"])
    ax.set_xlabel("Hidden Layer Dimension", fontsize=FONT["xlabel"])
    ax.set_ylabel("Hidden Layers", fontsize=FONT["ylabel"])
    ax.set_title(func_title(f), fontsize=FONT["title"])

# single shared colorbar at the right
cbar = fig.colorbar(im_last, ax=axes, shrink=0.85, location="right", pad=0.02)
cbar.set_label("Initialization improvement over Default (%)", fontsize=FONT["cbar_label"])
cbar.ax.tick_params(labelsize=FONT["cbar_ticks"])

fig.savefig(f"{plots_dir}/func_fit_heat.pdf", format="pdf", bbox_inches="tight")
plt.show()

## Loss Plots

Given these results, we rerun a subset of the experiments (one small and one large architecture per function) to record the training-loss curves and plot them.

In [None]:
# loss_dict[func_name][arch_name][init_type] -> list with one per-iteration
# training-loss array per run.
loss_dict = {}

# Resolution of the evaluation grid passed to `func_fit_eval`, per function;
# functions not listed here (f5) use 10.
eval_resolution = {"f1": 1000, "f2": 200, "f3": 200, "f4": 30}

# Procedure
for func_name, func_spec in func_dict.items():
    print(f"Running Experiments for {func_name} function.")
    function = func_spec['func']
    dim = func_spec['dim']
    res = eval_resolution.get(func_name, 10)

    loss_dict[func_name] = {}

    # Generate data
    x, y = generate_func_data(function, dim, N, seed)

    # Split data, 80%-20%; only the training part is used for fitting.
    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)

    # Model input/output
    n_in, n_out = X_train.shape[1], y_train.shape[1]

    # Two representative architectures, given as (width, depth)
    for arch_name, (width, depth) in zip(["small", "big"], [(4, 3), (16, 5)]):

        loss_dict[func_name][arch_name] = {}

        # Discern between baseline initialization and Glorot-like initialization
        for init_scheme in [base_init, glorot_init]:

            type_init = init_scheme['type']
            loss_dict[func_name][arch_name][type_init] = []

            print(f"\tTraining model with depth = {depth} and width = {width} ({type_init} init).")

            for run in [1, 2, 3, 4, 5]:

                model = KAN(n_in = n_in, n_out = n_out, n_hidden = width, num_layers = depth, D = D,
                            init_scheme = init_scheme, period_axes = period_axes, rff_std = rff_std,
                            seed = seed+run)

                optimizer = nnx.Optimizer(model, opt_type)

                # Train. Losses are collected in a list and converted to an
                # array once at the end. Calling `.at[epoch].set(...)` on a
                # preallocated array copies the whole array on every iteration
                # outside of `jit`.
                epoch_losses = []
                for epoch in range(num_epochs):
                    train_loss = func_fit_step(model, optimizer, X_train, y_train)
                    epoch_losses.append(train_loss)
                train_losses = jnp.asarray(epoch_losses)

                # Evaluate
                l2error = func_fit_eval(model, function, dim, resolution=res)

                loss_dict[func_name][arch_name][type_init].append(train_losses)

                print(f"\t\t{run}. Final loss: {train_loss:.2e} \tRel. L2 Error: {l2error:.2e}")

In [None]:
import pickle

# Cache the per-iteration loss curves on disk so the figure can be redrawn
# without retraining.
loss_file = os.path.join(results_dir, "func_losses.pkl")
with open(loss_file, mode="wb") as f:
    pickle.dump(loss_dict, f)

In [None]:
import pickle

# Reload the cached loss curves from results/func_losses.pkl.
# NOTE: unpickling can execute arbitrary code, so only load files that this
# notebook produced itself.
with open(os.path.join(results_dir, "func_losses.pkl"), mode="rb") as f:
    loss_dict = pickle.load(f)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import (
    LogLocator, LogFormatterMathtext,
    FixedLocator, NullLocator, NullFormatter
)

# --- config you can tweak ---
FUNCS = ["f1", "f2", "f3", "f4", "f5"]
ARCHS = ["small", "big"]
TITLE_FS = 22
LABEL_FS = 20
TICK_FS  = 18
LEGEND_FS = 20

# Two line colours from the reversed Spectral map: [Default, Proposed].
# Named `line_colors` instead of `colors` so it does not shadow the
# `matplotlib.colors` module used by the heatmap cell.
cmap = plt.get_cmap("Spectral_r")
cmap_points = np.linspace(0, 1, 12)
color_indices = [-2, 1]
line_colors = [cmap(cmap_points[i]) for i in color_indices]


def _func_title(name: str) -> str:
    """Return the LaTeX subplot title for function `name` (e.g. "f4" -> $f_{4}(x_1, x_2, x_3)$)."""
    if name == 'f1':
        x = "x"
    elif name in ['f2', 'f3']:
        x = "x_1, x_2"
    elif name == 'f4':
        x = "x_1, x_2, x_3"
    else:
        # Raw string: "\d" is not a valid Python escape sequence.
        x = r"x_1,\dots,x_5"
    return rf"$f_{{{name[1:]}}}({x})$"


def _stack_runs(runs):
    """Stack a list of 1D arrays to shape (n_runs, n_epochs), trimming to min length if needed.

    `None` entries are skipped. Returns None if no runs remain.
    """
    runs = [np.asarray(r).ravel() for r in runs if r is not None]
    if not runs:
        return None
    m = min(map(len, runs))
    return np.stack([r[:m] for r in runs], axis=0).astype(float)


def _set_log_ticks(ax, tick_fs, fixed_ticks=None):
    """Put `ax`'s y-axis on a log scale with major ticks only.

    With `fixed_ticks=None` the major ticks are powers of ten. Otherwise exactly
    the given tick positions are used.
    """
    ax.set_yscale("log")
    if fixed_ticks is None:
        ax.yaxis.set_major_locator(LogLocator(base=10.0))
        ax.yaxis.set_major_formatter(LogFormatterMathtext(base=10.0))
    else:
        ax.yaxis.set_major_locator(FixedLocator(fixed_ticks))
    ax.yaxis.set_minor_locator(NullLocator())
    ax.yaxis.set_minor_formatter(NullFormatter())
    ax.tick_params(axis="y", which="both", labelsize=tick_fs)


def _plot_mean_with_se(ax, runs, label, color):
    """Plot mean ± standard error (shaded), return the line handle (for legend).

    Returns None, and draws nothing, if `runs` holds no usable curves.
    """
    arr = _stack_runs(runs)
    if arr is None:
        return None
    n = arr.shape[0]
    x = np.arange(arr.shape[1])
    mean = arr.mean(axis=0)
    # ddof=1 for the sample std; a single run has no spread.
    se = (arr.std(axis=0, ddof=1) / np.sqrt(n)) if n > 1 else np.zeros_like(mean)

    line, = ax.plot(x, mean, label=label, linewidth=2.0, color=color)
    ax.fill_between(x, mean - se, mean + se, alpha=0.25, color=color, linewidth=0)
    return line


def plot_training_curves(loss_dict):
    """Draw a 2x5 grid of training-loss curves (rows: ARCHS, columns: FUNCS).

    Each panel shows mean ± standard error over runs for the Default and the
    Proposed initialization. Missing entries in `loss_dict` leave a panel empty.
    Returns (fig, axes).
    """
    fig, axes = plt.subplots(2, 5, figsize=(18, 7), sharex=True, sharey=False, constrained_layout=True)

    legend_handles = []

    for col, func in enumerate(FUNCS):
        for row, arch in enumerate(ARCHS):
            ax = axes[row, col]
            _set_log_ticks(ax, TICK_FS)

            # Pull runs safely; skip if missing
            arch_runs = loss_dict.get(func, {}).get(arch, {})
            h1 = _plot_mean_with_se(ax, arch_runs.get("default", []), "Default Initialization", color=line_colors[0])
            h2 = _plot_mean_with_se(ax, arch_runs.get("glorot", []),  "Proposed Initialization", color=line_colors[1])

            # Titles only on top row
            if row == 0:
                ax.set_title(_func_title(func), fontsize=TITLE_FS)

            # Axis labels on left and bottom edges
            if col == 0:
                ax.set_ylabel("Training Loss", fontsize=LABEL_FS)
            if row == len(ARCHS) - 1:
                ax.set_xlabel("Training Iteration", fontsize=LABEL_FS)

            ax.tick_params(labelsize=TICK_FS)
            ax.grid(True, which="both", linestyle="--", alpha=0.5)

            # Take the legend handles from the first subplot that drew at
            # least one curve.
            if not legend_handles:
                legend_handles = [h for h in (h1, h2) if h is not None]

    # annotate rows
    axes[0, -1].annotate(" 3 hidden layers\n(dimension = 4)",
                         xy=(1.05, 0.5), xycoords="axes fraction",
                         ha="left", va="center", rotation=90, fontsize=LABEL_FS)

    axes[1, -1].annotate("  5 hidden layers\n(dimension = 16)",
                         xy=(1.05, 0.5), xycoords="axes fraction",
                         ha="left", va="center", rotation=90, fontsize=LABEL_FS)

    # The last column (f5) uses a linear y-axis labelled in units of 1e4.
    # NOTE(review): the hard-coded ticks assume the f5 losses lie roughly in
    # [1e4, 1.8e4]; update them if the data change.
    scale = 1e4
    for r in [0, 1]:
        ax = axes[r, 4]

        yticks = [1.5e4, 1.6e4, 1.7e4, 1.8e4] if r == 0 else [1.0e4, 1.2e4, 1.4e4, 1.6e4, 1.8e4]
        ax.set_yscale("linear")
        ax.set_yticks(yticks)
        ax.set_yticklabels([f"{y/scale:.1f}" for y in yticks], fontsize=TICK_FS)

        # add "(×10^4)" annotation above the y-axis
        ax.set_ylabel("")  # clear the original ylabel
        ax.annotate(r"$(\times 10^{4})$",
                    xy=(0.08, 1.03), xycoords="axes fraction",
                    ha="right", va="bottom", fontsize=TICK_FS)

    if legend_handles:
        fig.legend(legend_handles, [h.get_label() for h in legend_handles],
                   loc="lower center", ncol=2, frameon=False, fontsize=LEGEND_FS,
                   bbox_to_anchor=(0.5, -0.1))

    return fig, axes


# ---- call it ----
fig, axes = plot_training_curves(loss_dict)
fig.savefig(f"{plots_dir}/func_fit_loss.pdf", format="pdf", bbox_inches="tight")
plt.show()