# Experiment 2: PDE Solving with Chebyshev Layers

We perform a grid search over a range of architectures to understand how small, medium, and large KANs behave.

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

from src.equations import burgers_res, helmholtz_14_res

from src.kan import KAN

import numpy as np

import optax
from flax import nnx

import os

def _tensor_grid(first, second):
    """Return all (first_i, second_j) pairs as rows of a 2-column array.

    Rows are ordered with ``first`` varying slowest (``indexing='ij'``), so the
    result has shape (len(first) * len(second), 2).
    """
    A, B = jnp.meshgrid(first, second, indexing='ij')
    return jnp.stack([A.flatten(), B.flatten()], axis=1)


def get_collocs(pde_name, N):
    """Build collocation points and boundary/initial data for a PDE.

    Args:
        pde_name: ``"burgers"`` or ``"helmholtz_1-4"``.
        N: number of equispaced points per axis.

    Returns:
        A tuple ``(pde_collocs, bc_collocs, bc_data)``:
        ``pde_collocs`` is the (N^2, 2) interior grid, ``bc_collocs`` the
        stacked boundary (and, for Burgers, initial-condition) points with two
        columns, and ``bc_data`` the target values at ``bc_collocs`` as a
        single column.

    Raises:
        ValueError: if ``pde_name`` is not one of the supported names.
    """
    if pde_name == "burgers":

        # PDE collocation points on (t, x) in [0, 1] x [-1, 1]
        t_pde = jnp.linspace(0, 1, N)
        x_pde = jnp.linspace(-1, 1, N)
        pde_collocs = _tensor_grid(t_pde, x_pde) # (N^2, 2)

        # Initial condition: u(0, x) = -sin(πx)
        ic_collocs = _tensor_grid(jnp.array([0.0]), x_pde) # (N, 2)
        ic_data = -jnp.sin(jnp.pi*ic_collocs[:,1]).reshape(-1,1) # (N, 1)

        # Boundary conditions: u(t,-1) = u(t,1) = 0
        bc_1 = _tensor_grid(t_pde, jnp.array([-1.0])) # (N, 2)
        bc_2 = _tensor_grid(t_pde, jnp.array([1.0])) # (N, 2)
        wall_data = jnp.zeros((bc_1.shape[0] + bc_2.shape[0], 1)) # (2N, 1)

        bc_collocs = jnp.concatenate([ic_collocs, bc_1, bc_2], axis=0)
        bc_data = jnp.concatenate([ic_data, wall_data], axis=0)

    elif pde_name == "helmholtz_1-4":

        # PDE collocation points on (x, y) in [-1, 1]^2
        x_pde = jnp.linspace(-1, 1, N)
        y_pde = jnp.linspace(-1, 1, N)
        pde_collocs = _tensor_grid(x_pde, y_pde) # (N^2, 2)

        # Boundary conditions: u(-1,y) = u(1,y) = u(x,-1) = u(x,1) = 0
        lower, upper = jnp.array([-1.0]), jnp.array([1.0])
        bc_collocs = jnp.concatenate(
            [
                _tensor_grid(lower, y_pde),  # x = -1
                _tensor_grid(upper, y_pde),  # x = +1
                _tensor_grid(x_pde, lower),  # y = -1
                _tensor_grid(x_pde, upper),  # y = +1
            ],
            axis=0,
        ) # (4N, 2)
        bc_data = jnp.zeros((bc_collocs.shape[0], 1)) # (4N, 1)

    else:
        # Without this branch the function would fail later with an
        # UnboundLocalError that hides the real cause.
        raise ValueError(
            f"Unknown PDE name {pde_name!r}; expected 'burgers' or 'helmholtz_1-4'."
        )

    return pde_collocs, bc_collocs, bc_data


def get_ref(pde_name):
    """Load the reference solution of ``pde_name`` and its evaluation points.

    Reads ``data/{pde_name}.npz``. The archive must hold a 2-D array ``usol``
    and the two coordinate axes: ``x`` and ``y`` for ``"helmholtz_1-4"``,
    ``t`` and ``x`` for every other name.

    Returns:
        A tuple ``(refsol, coords)``: ``refsol`` is ``usol`` as a JAX array and
        ``coords`` is a 2-column array of all axis combinations, ordered with
        the first axis varying slowest.

    Raises:
        ValueError: if ``usol`` is not two-dimensional.
    """
    axis_keys = ('x', 'y') if pde_name == "helmholtz_1-4" else ('t', 'x')

    # The context manager closes the underlying archive once the arrays are read.
    with np.load(f'data/{pde_name}.npz') as ref:
        usol = ref['usol']
        if usol.ndim != 2:
            raise ValueError(
                f"'usol' in data/{pde_name}.npz must be 2-D, got shape {usol.shape}."
            )
        first, second = (ref[key].flatten() for key in axis_keys)

    refsol = jnp.array(usol)

    A, B = jnp.meshgrid(first, second, indexing='ij')
    coords = jnp.stack([A.flatten(), B.flatten()], axis=1)

    return refsol, coords

# Output locations: numeric results go to results/, figures go to plots/
results_dir, plots_dir = "results", "plots"
for out_dir in (results_dir, plots_dir):
    os.makedirs(out_dir, exist_ok=True)

# The experiment name doubles as the stem of the results CSV
experiment_name = "small_pde_D8"
results_file = os.path.join(results_dir, experiment_name + ".csv")

# Column names of the results CSV (fields are separated by ", ")
header = "PDE, width, depth, init_type, run, loss, l2"

# Start the file with just the header, unless a results file already exists
results_exist = os.path.exists(results_file)
if not results_exist:
    with open(results_file, "w") as fh:
        fh.write(f"{header}\n")

# Base random seed
seed = 42

## Grid-Search Parameters

In [None]:
# PDEs under study: name -> residual function
pde_dict = {
    "burgers": burgers_res,
    "helmholtz_1-4": helmholtz_14_res,
}

# Model settings forwarded unchanged to every KAN that is built
D = 8
period_axes = None
rff_std = None

# The two initialization schemes that are compared
base_init = dict(type='default')
glorot_init = dict(
    type='glorot',
    gain=None,
    norm_pow=0,
    distribution='uniform',
    sample_size=10000,
)

# Number of sampled points per axis (2**6)
N_points = 64

# Number of training iterations
num_epochs = 5000

# Plain Adam with a fixed learning rate
opt_type = optax.adam(learning_rate=1e-3)

# Architecture grid: widths 2, 4, ..., 64 and depths 2, ..., 5
widths = [2**k for k in range(1, 7)]
depths = list(range(2, 6))

## Grid Search

In [None]:
# Grid search: for every PDE, sweep depth x width x initialization x run
for pde_name, pde_res in pde_dict.items():
    print(f"Running Experiments for {pde_name} PDE.")

    # Reference solution and collocation points for this PDE
    refsol, coords = get_ref(pde_name)
    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)

    # Model input/output sizes follow from the data
    n_in = pde_collocs.shape[1]
    n_out = bc_data.shape[1]

    # Loss for this PDE: mean squared PDE residual + mean squared boundary misfit
    def loss_fn(model, pde_collocs, bc_collocs, bc_data):
        residual_term = jnp.mean(pde_res(model, pde_collocs)**2)
        boundary_term = jnp.mean((model(bc_collocs) - bc_data)**2)
        return residual_term + boundary_term

    # One jitted optimizer step; returns the loss computed before the update
    @nnx.jit
    def train_step(model, optimizer, pde_collocs, bc_collocs, bc_data):
        grad_fn = nnx.value_and_grad(loss_fn, has_aux = False)
        loss, grads = grad_fn(model, pde_collocs, bc_collocs, bc_data)
        optimizer.update(grads)
        return loss

    for depth in depths:
        for width in widths:

            # Baseline initialization first, then the Glorot-like one
            for init_scheme in (base_init, glorot_init):

                type_init = init_scheme['type']

                print(f"\tTraining model with depth = {depth} and width = {width} ({type_init} init).")

                # Five independent runs, each with its own seed offset
                for run in range(1, 6):

                    model = KAN(
                        n_in = n_in,
                        n_out = n_out,
                        n_hidden = width,
                        num_layers = depth,
                        D = D,
                        init_scheme = init_scheme,
                        period_axes = period_axes,
                        rff_std = rff_std,
                        seed = seed+run,
                    )
                    optimizer = nnx.Optimizer(model, opt_type)

                    # Train; only the loss of the last step is kept
                    for epoch in range(num_epochs):
                        train_loss = train_step(model, optimizer, pde_collocs, bc_collocs, bc_data)

                    # Relative L2 error against the reference solution
                    output = model(coords).reshape(refsol.shape)
                    l2error = jnp.linalg.norm(output-refsol)/jnp.linalg.norm(refsol)

                    # Append this run to the results file right away
                    new_row = f"{pde_name}, {width}, {depth}, {type_init}, {run}, {train_loss}, {l2error}"
                    with open(results_file, "a") as rfile:
                        rfile.write(new_row + "\n")

                    print(f"\t\t\t{run}. Final loss: {train_loss:.2e} \tRel. L2 Error: {l2error:.2e}")

## Analysis

Let's first determine how often the Glorot-like initialization outperforms the default one.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import colors

# -----------------------------
# Config
# -----------------------------
agg_func = "mean"   # how the runs of one configuration are aggregated (e.g. "mean" or "median")

# -----------------------------
# Font sizes
# -----------------------------
FONT = {
    "title": 20,
    "xlabel": 18,
    "ylabel": 18,
    "ticks": 16,
    "cbar_label": 18,
    "cbar_ticks": 16,
}

# -----------------------------
# Load & reduce data
# -----------------------------
# The results file separates fields with ", ". Splitting on "," and skipping
# the blank that follows gives the same columns without a multi-character
# (regex) separator, which pandas can only handle with its slower python
# engine and a ParserWarning.
df = pd.read_csv(results_file, sep=",", skipinitialspace=True)

df = df[df["depth"] != 1]                      # exclude depth=1

# Aggregate the runs of each (PDE, width, depth, init) configuration
df_red = (
    df.groupby(["PDE", "width", "depth", "init_type"])["l2"]
      .agg(agg_func)
      .reset_index()
)

# One row per configuration, one column per initialization type
df_pivot = df_red.pivot(index=["PDE","width","depth"],
                        columns="init_type", values="l2").reset_index()

df_pivot["width"] = pd.to_numeric(df_pivot["width"])
df_pivot["depth"] = pd.to_numeric(df_pivot["depth"])

pdes = sorted(df_pivot["PDE"].unique().tolist())

# Deliberately NOT named `W` / `D`: `D` is the integer that the training cells
# pass to `KAN(..., D=D)`, and rebinding it here to a list would break any
# training cell executed after this one.
width_levels = sorted(df_pivot["width"].unique().tolist())
depth_levels = sorted(df_pivot["depth"].unique().tolist())


def build_matrix_for_col(f, col):
    """Return the aggregated L2 errors of PDE ``f`` for init type ``col``.

    The result has one row per entry of ``depth_levels`` and one column per
    entry of ``width_levels``; cells without data stay NaN.
    """
    sub = df_pivot[df_pivot["PDE"] == f]
    M = np.full((len(depth_levels), len(width_levels)), np.nan)
    if col not in sub.columns:
        return M
    for i, d in enumerate(depth_levels):
        for j, w in enumerate(width_levels):
            hit = sub.loc[(sub["width"] == w) & (sub["depth"] == d), col]
            if not hit.empty:
                M[i, j] = hit.iloc[0]
    return M


mats_glorot  = {f: build_matrix_for_col(f, "glorot")  for f in pdes}
mats_default = {f: build_matrix_for_col(f, "default") for f in pdes}


# helper: display title for a PDE key
def pde_title(name: str) -> str:
    """Return "Burgers" for the key 'burgers' and "Helmholtz" for any other key."""
    if name == 'burgers':
        return "Burgers"
    else:
        return "Helmholtz"


# Per PDE: relative improvement of Glorot over Default in percent (NaN where
# Glorot is not better or data is missing) and a mask of cells where Default wins.
mats_pct = {}
mats_default_wins = {}
for f in pdes:
    G = mats_glorot[f]
    Df = mats_default[f]
    Mpct = np.full_like(G, np.nan, dtype=float)
    Mwins = np.full_like(G, False, dtype=bool)
    for i, j in np.ndindex(G.shape):
        g, dft = G[i, j], Df[i, j]
        if np.isfinite(g) and np.isfinite(dft) and dft > 0:
            val = (dft - g) / dft * 100.0
            if val >= 0:
                Mpct[i, j] = val        # keep Glorot improvements
            Mwins[i, j] = (g > dft)    # True if Default wins
    mats_pct[f] = Mpct
    mats_default_wins[f] = Mwins

# colormap
cmap = sns.color_palette("Spectral", as_cmap=True)


# -------- Figure: one heatmap per PDE, side by side --------
fig, axes = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

im_last = None
for ax, f in zip(axes, pdes):
    Mpct = mats_pct[f]
    wins_default = mats_default_wins[f]

    # main heatmap (Glorot improvements, color scale fixed to 0–100 %)
    im_last = ax.imshow(
        Mpct,
        cmap=cmap,
        vmin=0, vmax=100,
        origin="lower",
        aspect="auto",
    )

    # overlay (nearly opaque) black cells where Default wins
    default_mask = np.where(wins_default, 1.0, np.nan)
    ax.imshow(
        default_mask,
        cmap=colors.ListedColormap(["black"]),
        origin="lower",
        aspect="auto",
        alpha=0.85,
        vmin=0.0, vmax=1.0,
    )

    ax.set_xticks(range(len(width_levels)))
    ax.set_xticklabels(width_levels, fontsize=FONT["ticks"])
    ax.set_yticks(range(len(depth_levels)))
    ax.set_yticklabels(depth_levels, fontsize=FONT["ticks"])
    ax.set_xlabel("Hidden Layer Dimension", fontsize=FONT["xlabel"])
    ax.set_ylabel("Hidden Layers", fontsize=FONT["ylabel"])
    ax.set_title(pde_title(f), fontsize=FONT["title"])

# single shared colorbar at the right
cbar = fig.colorbar(im_last, ax=axes, shrink=0.85, location="right", pad=0.02)
cbar.set_label("Initialization improvement\nover Default (%)", fontsize=FONT["cbar_label"])
cbar.ax.tick_params(labelsize=FONT["cbar_ticks"])

plt.savefig(f"{plots_dir}/pdes_heat.pdf", format="pdf", bbox_inches="tight")
plt.show()

## Loss Plots

Given these results, we rerun some experiments to also derive plots for the losses.

In [None]:
# `D` is forwarded to `KAN(..., D=D)` below and must be the integer set in the
# grid-search parameter cell. Other cells of this notebook reuse the name `D`
# for a list of depths, so fail early with a clear message if that happened.
if not isinstance(D, (int, np.integer)):
    raise TypeError(
        f"`D` must be an integer but is {type(D).__name__}; "
        "re-run the grid-search parameter cell before this one."
    )

# loss_dict[pde_name][arch_name][init_type] -> list with one loss curve per run
loss_dict = dict()

# Procedure
for pde_name, pde_res in pde_dict.items():
    print(f"Running Experiments for {pde_name} PDE.")

    loss_dict[pde_name] = dict()

    # Loss for this PDE: mean squared PDE residual + mean squared boundary misfit
    def loss_fn(model, pde_collocs, bc_collocs, bc_data):

        # ------------- PDE ---------------------------- #
        pde_residuals = pde_res(model, pde_collocs)
        pde_loss = jnp.mean(pde_residuals**2)

        # ------------- BC ----------------------------- #
        bc_residuals = model(bc_collocs) - bc_data
        bc_loss = jnp.mean(bc_residuals**2)

        # ------------- Total --------------------------- #
        return pde_loss + bc_loss

    # One jitted optimizer step; returns the loss computed before the update
    @nnx.jit
    def train_step(model, optimizer, pde_collocs, bc_collocs, bc_data):
        loss, grads = nnx.value_and_grad(loss_fn, has_aux = False)(model, pde_collocs, bc_collocs, bc_data)

        optimizer.update(grads)

        return loss

    # Get the reference solution
    refsol, coords = get_ref(pde_name)

    # Get collocation points
    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)

    # Model input/output
    n_in, n_out = pde_collocs.shape[1], bc_data.shape[1]

    # Different architectures, given as (width, depth)
    for arch_name, arch in zip(["small", "big"], [(4, 3), (16, 5)]):

        loss_dict[pde_name][arch_name] = dict()

        width, depth = arch

        # Discern between baseline initialization and Glorot-like initialization
        for init_scheme in [base_init, glorot_init]:

            type_init = init_scheme['type']

            loss_dict[pde_name][arch_name][type_init] = []

            print(f"\tTraining model with depth = {depth} and width = {width} ({type_init} init).")

            for run in range(1, 6):

                model = KAN(n_in = n_in, n_out = n_out, n_hidden = width, num_layers = depth, D = D,
                            init_scheme = init_scheme, period_axes = period_axes, rff_std = rff_std,
                            seed = seed+run)

                optimizer = nnx.Optimizer(model, opt_type)

                # Train. The per-step losses are collected in a Python list
                # instead of being written into a device array with
                # `.at[epoch].set(...)`, which would allocate a new array of
                # length num_epochs on every step.
                epoch_losses = []
                for epoch in range(num_epochs):
                    train_loss = train_step(model, optimizer, pde_collocs, bc_collocs, bc_data)
                    epoch_losses.append(train_loss)

                # Fetch the whole curve to the host once; NumPy arrays also
                # keep the pickled results loadable without JAX.
                train_losses = np.asarray(jax.device_get(epoch_losses))

                # Evaluate
                output = model(coords).reshape(refsol.shape)
                l2error = jnp.linalg.norm(output-refsol)/jnp.linalg.norm(refsol)

                loss_dict[pde_name][arch_name][type_init].append(train_losses)

                print(f"\t\t{run}. Final loss: {train_loss:.2e} \tRel. L2 Error: {l2error:.2e}")

In [None]:
import pickle

# Persist the recorded loss curves so the plots can be rebuilt without retraining
loss_file = os.path.join(results_dir, "pde_losses.pkl")

with open(loss_file, mode="wb") as pkl_out:
    pickle.dump(loss_dict, pkl_out)

In [None]:
import pickle

# Reload the stored loss curves.
# NOTE: unpickling executes code from the file; only load files you created yourself.
with open(os.path.join(results_dir, "pde_losses.pkl"), mode="rb") as pkl_in:
    loss_dict = pickle.load(pkl_in)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import (
    LogLocator, LogFormatterMathtext,
    FixedLocator, NullLocator, NullFormatter
)

# --- plot configuration ---
PDES = pdes
ARCHS = ["small", "big"]

# Font sizes
TITLE_FS, LABEL_FS, TICK_FS, LEGEND_FS = 22, 20, 18, 20

# Two line colors picked from 12 evenly spaced samples of the reversed
# Spectral colormap: the second-to-last sample and the second sample.
cmap = plt.get_cmap("Spectral_r")
cmap_points = np.linspace(0, 1, 12)
color_indices = [-2, 1]
colors = list(map(cmap, cmap_points[color_indices]))

def _stack_runs(runs):
    """Stack a list of 1D arrays to shape (n_runs, n_epochs), trimming to min length if needed."""
    runs = [np.asarray(r).ravel() for r in runs if r is not None]
    if not runs:
        return None
    m = min(map(len, runs))
    return np.stack([r[:m] for r in runs], axis=0).astype(float)

def _set_log_ticks(ax, tick_fs, fixed_ticks=None):
    """Put the y-axis of ``ax`` on a log scale and configure its ticks.

    With ``fixed_ticks=None`` the major ticks come from a base-10 log locator
    and are formatted with a base-10 mathtext formatter. Otherwise the major
    ticks are exactly ``fixed_ticks`` and the major formatter is left untouched.
    Minor ticks and their labels are switched off in both cases, and
    ``tick_fs`` is applied as the y tick label size.
    """
    ax.set_yscale("log")
    yaxis = ax.yaxis

    if fixed_ticks is not None:
        # exactly the requested ticks (e.g., [1e4])
        yaxis.set_major_locator(FixedLocator(fixed_ticks))
    else:
        yaxis.set_major_locator(LogLocator(base=10.0))
        yaxis.set_major_formatter(LogFormatterMathtext(base=10.0))

    # no minor ticks, whichever major ticks were chosen
    yaxis.set_minor_locator(NullLocator())
    yaxis.set_minor_formatter(NullFormatter())

    ax.tick_params(axis="y", which="both", labelsize=tick_fs)

def _plot_mean_with_se(ax, runs, label, color):
    """Draw the mean curve of ``runs`` on ``ax`` with a shaded ±1 standard-error band.

    Returns the line handle of the mean curve (usable for a legend), or
    ``None`` when ``runs`` contains nothing to plot.
    """
    stacked = _stack_runs(runs)
    if stacked is None:
        return None

    n_runs, n_steps = stacked.shape
    steps = np.arange(n_steps)
    center = stacked.mean(axis=0)

    # Sample standard deviation (ddof=1) is undefined for one run: use a zero band
    if n_runs > 1:
        band = stacked.std(axis=0, ddof=1) / np.sqrt(n_runs)
    else:
        band = np.zeros_like(center)

    (curve,) = ax.plot(steps, center, label=label, linewidth=2.0, color=color)
    ax.fill_between(steps, center - band, center + band, alpha=0.25, color=color, linewidth=0)
    return curve

def plot_training_curves(loss_dict):
    """Plot mean training-loss curves for every PDE and architecture.

    The figure has one row per entry of ``ARCHS`` and one column per entry of
    ``PDES``. Each panel shows the mean ± standard error over runs for the
    "default" and the "glorot" initialization on a log-scaled y-axis.

    Args:
        loss_dict: nested mapping ``loss_dict[pde][arch][init_type]`` holding a
            list of per-run loss curves. Missing entries are skipped.

    Returns:
        The tuple ``(fig, axes)``; ``axes`` is always a 2-D array.
    """
    # Text shown to the right of each row, keyed by architecture name
    row_notes = {
        "small": " 3 hidden layers\n(dimension = 4)",
        "big": "  5 hidden layers\n(dimension = 16)",
    }

    # squeeze=False keeps `axes` 2-D even for a single row or column
    fig, axes = plt.subplots(len(ARCHS), len(PDES), figsize=(12, 6), sharex=True, sharey=False,
                             constrained_layout=True, squeeze=False)

    legend_handles = []

    for col, func in enumerate(PDES):
        for row, arch in enumerate(ARCHS):
            ax = axes[row, col]
            _set_log_ticks(ax, TICK_FS)

            # Pull runs safely; skip if missing
            per_init = loss_dict.get(func, {}).get(arch, {})
            runs_default = per_init.get("default", [])
            runs_glorot  = per_init.get("glorot",  [])

            h1 = _plot_mean_with_se(ax, runs_default, "Default Initialization", color=colors[0])
            h2 = _plot_mean_with_se(ax, runs_glorot,  "Proposed Initialization", color=colors[1])

            # Titles only on top row
            if row == 0:
                ax.set_title(pde_title(func), fontsize=TITLE_FS)

            # Axis labels on left and bottom edges
            if col == 0:
                ax.set_ylabel("Training Loss", fontsize=LABEL_FS)
            if row == len(ARCHS) - 1:
                ax.set_xlabel("Training Iteration", fontsize=LABEL_FS)

            ax.tick_params(labelsize=TICK_FS)

            ax.grid(True, which="both", linestyle="--", alpha=0.5)

            # Collect legend handles once, from the first panel that drew at least one curve
            if not legend_handles and (h1 is not None or h2 is not None):
                legend_handles = [h for h in (h1, h2) if h is not None]

    # annotate rows at the right edge of the last column
    for row, arch in enumerate(ARCHS):
        note = row_notes.get(arch)
        if note is None:
            continue
        axes[row, -1].annotate(note,
                               xy=(1.05, 0.5), xycoords="axes fraction",
                               ha="left", va="center", rotation=90, fontsize=LABEL_FS)

    if legend_handles:
        fig.legend(legend_handles, [h.get_label() for h in legend_handles],
                   loc="lower center", ncol=2, frameon=False, fontsize=LEGEND_FS,
                   bbox_to_anchor=(0.5, -0.1))

    return fig, axes

# ---- call it ----
fig, axes = plot_training_curves(loss_dict)
plt.savefig(f"{plots_dir}/pde_loss.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import colors
from matplotlib.ticker import (
    LogLocator, LogFormatterMathtext,
    FixedLocator, NullLocator, NullFormatter
)

# -----------------------------
# Config
# -----------------------------
agg_func = "mean"   # how the runs of one configuration are aggregated (e.g. "mean" or "median")

# -----------------------------
# Font sizes
# -----------------------------
FONT = {
    "title": 20,
    "xlabel": 18,
    "ylabel": 18,
    "ticks": 16,
    "cbar_label": 18,
    "cbar_ticks": 16,
}

TITLE_FS = 22
LABEL_FS = 20
TICK_FS  = 18
LEGEND_FS = 20

ARCHS = ["small", "big"]

# -----------------------------
# Load & reduce data
# -----------------------------
# The results file separates fields with ", ". Splitting on "," and skipping
# the blank that follows gives the same columns without a multi-character
# (regex) separator, which pandas can only handle with its slower python
# engine and a ParserWarning.
df = pd.read_csv(results_file, sep=",", skipinitialspace=True)

df = df[df["depth"] != 1]                      # exclude depth=1

# Aggregate the runs of each (PDE, width, depth, init) configuration
df_red = (
    df.groupby(["PDE", "width", "depth", "init_type"])["l2"]
      .agg(agg_func)
      .reset_index()
)

# One row per configuration, one column per initialization type
df_pivot = df_red.pivot(index=["PDE","width","depth"],
                        columns="init_type", values="l2").reset_index()

df_pivot["width"] = pd.to_numeric(df_pivot["width"])
df_pivot["depth"] = pd.to_numeric(df_pivot["depth"])

# Sorted axis values of the heatmaps.
# NOTE(review): `D` rebinds the name that the training cells pass to
# `KAN(..., D=D)`; re-run the grid-search parameter cell before training again.
W = sorted(df_pivot["width"].unique().tolist())
D = sorted(df_pivot["depth"].unique().tolist())

pdes = sorted(df_pivot["PDE"].unique().tolist())

def build_matrix_for_col(f, col):
    """Return the aggregated L2 errors of PDE ``f`` for init type ``col``.

    The result has one row per entry of ``D`` and one column per entry of
    ``W``; cells without data stay NaN.
    """
    per_pde = df_pivot[df_pivot["PDE"] == f]
    grid = np.full((len(D), len(W)), np.nan)

    # Nothing to fill in if this initialization type has no column at all
    if col not in per_pde.columns:
        return grid

    for r, c in np.ndindex(grid.shape):
        match = per_pde.loc[(per_pde["width"] == W[c]) & (per_pde["depth"] == D[r]), col]
        if not match.empty:
            grid[r, c] = match.iloc[0]
    return grid

# One error matrix per PDE and initialization type
mats_glorot = {}
mats_default = {}
for f in pdes:
    mats_glorot[f] = build_matrix_for_col(f, "glorot")
    mats_default[f] = build_matrix_for_col(f, "default")

# helper: display title for a PDE key
def pde_title(name: str) -> str:
    """Return "Burgers" for the key 'burgers' and "Helmholtz" for any other key."""
    return "Burgers" if name == 'burgers' else "Helmholtz"

# Per PDE: relative improvement of Glorot over Default in percent (NaN where
# Glorot is not better or data is missing) and a mask of cells where Default wins.
mats_pct = {}
mats_default_wins = {}
for f in pdes:
    G = mats_glorot[f]
    Df = mats_default[f]

    # Cells where both errors are usable and the reference error is positive
    comparable = np.isfinite(G) & np.isfinite(Df) & (Df > 0)

    # Cells that are not comparable may contain NaN or divide by zero; their
    # values are discarded below, so the warnings are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        gain = (Df - G) / Df * 100.0
        Mpct = np.where(comparable & (gain >= 0), gain, np.nan)   # keep Glorot improvements
        Mwins = comparable & (G > Df)                             # True if Default wins

    mats_pct[f] = Mpct
    mats_default_wins[f] = Mwins

# colormap for heatmap
cmap = sns.color_palette("Spectral", as_cmap=True)

# line-plot colors: the second and the second-to-last of 12 evenly spaced samples
cmap_points = np.linspace(0, 1, 12)
color_indices = [1, -2]
colors_lp = list(map(cmap, cmap_points[color_indices]))

# -----------------------------
# Helpers for training curves
# -----------------------------
def _stack_runs(runs):
    runs = [np.asarray(r).ravel() for r in runs if r is not None]
    if not runs:
        return None
    m = min(map(len, runs))
    return np.stack([r[:m] for r in runs], axis=0).astype(float)

def _set_log_ticks(ax, tick_fs, fixed_ticks=None):
    """Put the y-axis of ``ax`` on a log scale and configure its ticks.

    With ``fixed_ticks=None`` the major ticks come from a base-10 log locator
    and are formatted with a base-10 mathtext formatter. Otherwise the major
    ticks are exactly ``fixed_ticks`` and the major formatter is left untouched.
    Minor ticks and their labels are switched off in both cases, and
    ``tick_fs`` is applied as the y tick label size.
    """
    ax.set_yscale("log")
    y_axis = ax.yaxis

    if fixed_ticks is not None:
        y_axis.set_major_locator(FixedLocator(fixed_ticks))
    else:
        y_axis.set_major_locator(LogLocator(base=10.0))
        y_axis.set_major_formatter(LogFormatterMathtext(base=10.0))

    # minor ticks are disabled in either mode
    y_axis.set_minor_locator(NullLocator())
    y_axis.set_minor_formatter(NullFormatter())

    ax.tick_params(axis="y", which="both", labelsize=tick_fs)

def _plot_mean_with_se(ax, runs, label, color):
    """Draw the mean curve of ``runs`` on ``ax`` with a shaded ±1 standard-error band.

    Returns the line handle of the mean curve (usable for a legend), or
    ``None`` when ``runs`` contains nothing to plot.
    """
    data = _stack_runs(runs)
    if data is None:
        return None

    n_runs, n_steps = data.shape
    steps = np.arange(n_steps)
    avg = data.mean(axis=0)

    # Sample standard deviation (ddof=1) is undefined for one run: use a zero band
    if n_runs > 1:
        half_width = data.std(axis=0, ddof=1) / np.sqrt(n_runs)
    else:
        half_width = np.zeros_like(avg)

    (mean_line,) = ax.plot(steps, avg, label=label, linewidth=2.0, color=color)
    ax.fill_between(steps, avg - half_width, avg + half_width, alpha=0.25, color=color, linewidth=0)
    return mean_line

def _plot_two_arch_panels_for(func, ax_small, ax_big):
    """Draw the loss curves of PDE ``func`` for both architectures.

    The entries of ``ARCHS`` are paired, in order, with ``ax_small`` and
    ``ax_big``. Each panel gets a log-scaled y-axis, the "default" and
    "glorot" mean curves from ``loss_dict`` (missing entries are skipped),
    and a dashed grid.

    Returns:
        ``(h1, h2)``: a line handle for the default and for the proposed
        initialization. Each is taken from the first panel that actually drew
        that curve and is ``None`` if no panel did, so a legend can still be
        built when the last panel has no data.
    """
    h1 = h2 = None
    for arch, ax in zip(ARCHS, (ax_small, ax_big)):
        _set_log_ticks(ax, TICK_FS)
        per_init = loss_dict.get(func, {}).get(arch, {})
        line_default = _plot_mean_with_se(ax, per_init.get("default", []),
                                          "Default Initialization", color=colors_lp[0])
        line_glorot = _plot_mean_with_se(ax, per_init.get("glorot", []),
                                         "Proposed Initialization", color=colors_lp[1])
        ax.grid(True, which="both", linestyle="--", alpha=0.5)

        # Keep the first usable handle of each kind
        if h1 is None:
            h1 = line_default
        if h2 is None:
            h2 = line_glorot
    return h1, h2

# -----------------------------
# Combined figure: one row per PDE with a heatmap and two loss-curve panels
# -----------------------------
mosaic = [
    ["H_top",  ".", "L_top_small", "L_top_big"],
    ["H_bot",  ".", "L_bot_small", "L_bot_big"],
]
fig, axd = plt.subplot_mosaic(
    mosaic,
    figsize=(16, 8),
    constrained_layout=True,
    width_ratios=[1, 0.1, 1, 1]  # second entry is the empty spacer column; tweak it to widen/narrow the gap
)

# Arrange the named axes as axes[row, col] = [heatmap, small-arch curves, big-arch curves]
axes = np.empty((2, 3), dtype=object)
axes[0, 0] = axd["H_top"]
axes[0, 1] = axd["L_top_small"]
axes[0, 2] = axd["L_top_big"]
axes[1, 0] = axd["H_bot"]
axes[1, 1] = axd["L_bot_small"]
axes[1, 2] = axd["L_bot_big"]

im_last = None
legend_handles = []

for r, f in enumerate(pdes):
    # heatmap in first column
    ax = axes[r, 0]
    Mpct = mats_pct[f]
    wins_default = mats_default_wins[f]

    im_last = ax.imshow(Mpct, cmap=cmap, vmin=0, vmax=100, origin="lower", aspect="auto")

    # overlay (nearly opaque) black cells where Default wins
    default_mask = np.where(wins_default, 1.0, np.nan)
    ax.imshow(default_mask, cmap=colors.ListedColormap(["black"]),
              origin="lower", aspect="auto", alpha=0.85, vmin=0.0, vmax=1.0)

    ax.set_xticks(range(len(W)))
    ax.set_xticklabels(W, fontsize=TICK_FS)
    ax.set_yticks(range(len(D)))
    ax.set_yticklabels(D, fontsize=TICK_FS)
    if r == 1:
        ax.set_xlabel("Hidden Layer Dimension", fontsize=LABEL_FS)
    ax.set_ylabel("Hidden Layers", fontsize=LABEL_FS)

    # line plots in columns 1 and 2
    h1, h2 = _plot_two_arch_panels_for(f, axes[r, 1], axes[r, 2])
    axes[r, 1].tick_params(axis="x", labelsize=TICK_FS)
    axes[r, 2].tick_params(axis="x", labelsize=TICK_FS)

    axes[r, 1].set_xticks([0, 2500, 5000])
    axes[r, 2].set_xticks([0, 2500, 5000])

    if r == 0:
        axes[r, 1].set_title("depth = 3, width = 4", fontsize=TITLE_FS)
        axes[r, 2].set_title("depth = 5, width = 16",   fontsize=TITLE_FS)
        axes[r, 1].set_ylabel("Training Loss", fontsize=LABEL_FS)

    if r == 1:
        axes[r, 1].set_xlabel("Training Iteration", fontsize=LABEL_FS)
        axes[r, 2].set_xlabel("Training Iteration", fontsize=LABEL_FS)
        axes[r, 1].set_ylabel("Training Loss", fontsize=LABEL_FS, labelpad=15)

    # Far-right row label, derived from the PDE actually plotted in this row
    # so that it cannot get out of sync with the order of `pdes`
    axes[r, -1].annotate(pde_title(f),
                         xy=(1.08, 0.5), xycoords="axes fraction",
                         ha="left", va="center", rotation=90, fontsize=TITLE_FS)

    if not legend_handles:
        legend_handles = [h for h in (h1, h2) if h is not None]

# Horizontal colorbar below the heatmaps
cbar = fig.colorbar(
    im_last,
    ax=[axes[0, 0], axes[1, 0]],
    orientation="horizontal",
    fraction=0.045,   # adjust if needed
    pad=0.10          # spacing from subplots
)
cbar.set_label("Initialization improvement\nover Default (%)", fontsize=LABEL_FS)
cbar.ax.tick_params(labelsize=TICK_FS)

# Legend centered under the two lineplot columns (right 2/3 of figure)
if legend_handles:
    fig.legend(
        legend_handles,
        [h.get_label() for h in legend_handles],
        loc="lower center",
        ncol=2,
        frameon=False,
        fontsize=LEGEND_FS,
        bbox_to_anchor=(0.67, 0.05)  # ~center of columns 1–2; tweak y if clipping
    )

# save + show
plt.savefig(f"{plots_dir}/small_pdes.pdf", format="pdf", bbox_inches="tight")
plt.show()
