# Main Plots

This notebook assumes that notebooks 1–6 have all run successfully; it loads their saved outputs and generates every plot shown in the main text.

In [None]:
import pickle
import os

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.lines as mlines

import seaborn as sns

In [None]:
# Every figure below is saved into this folder; create it up front
# (a no-op when it already exists) so the savefig calls cannot fail.
plots_dir = "main_plots"
os.makedirs(name=plots_dir, exist_ok=True)

## Training Curves: Function Fitting

In [None]:
results_dir = 'ff_results/'

# Loss histories produced by earlier notebooks; indexed below as
# results[function][architecture][run][init].
losses_path = os.path.join(results_dir, "losses.pkl")
with open(losses_path, "rb") as f:
    results = pickle.load(f)

In [None]:
# Plotting
# Training-loss curves for the function-fitting experiments: one column per
# target function, one row per architecture, one curve per initialization
# scheme. Each curve is the mean over runs; the band is +/- one standard error.
cmap = sns.color_palette("crest", as_cmap=True)
spectral_points = np.linspace(0, 1, 20)
color_indices = [0, 10, -1]

TITLE_FS = 22
LABEL_FS = 20
TICK_FS  = 18

init_types = ['baseline', 'glorot', 'power']
init_plot_names = {'baseline': 'Baseline', 'glorot': 'Glorot', 'power': 'Power-Law'}
architectures = ['small', 'big']
# Right-margin row annotation for each architecture.
arch_plot_names = {
    'small': r'$G = 5$, depth = 2, width = 8',
    'big': r'$G = 20$, depth = 3, width = 32',
}
func_names = list(results.keys())
func_plot_names = [r'$f_1(x,y)$', r'$f_2(x,y)$', r'$f_3(x,y)$', r'$f_4(x,y)$', r'$f_5(x,y)$']
# Titles are paired with results keys by position, so the counts must agree.
assert len(func_plot_names) == len(func_names), (
    f"{len(func_names)} functions in results but {len(func_plot_names)} plot names"
)

colors = [cmap(spectral_points[i]) for i in color_indices]
custom_colors = dict(zip(init_types, colors))

n_rows, n_cols = len(architectures), len(func_names)
# squeeze=False keeps `axes` 2-D even with a single row or column.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(25, 10), squeeze=False)

for col, func_name in enumerate(func_names):
    for row, arch in enumerate(architectures):
        ax = axes[row, col]

        for init in init_types:
            # One row per run; assumes every run stores a 1-D loss history of
            # the same length for this configuration.
            runs = np.stack([np.asarray(results[func_name][arch][run][init])
                             for run in results[func_name][arch]])
            n_runs = runs.shape[0]

            mean = runs.mean(axis=0)
            # Standard error of the mean uses the sample standard deviation
            # (ddof=1); a single run has no spread to estimate.
            if n_runs > 1:
                stderr = runs.std(axis=0, ddof=1) / np.sqrt(n_runs)
            else:
                stderr = np.zeros_like(mean)

            # The x-axis comes from the data itself, not a hard-coded epoch count,
            # so fill_between cannot mismatch the curve length.
            iterations = np.arange(mean.size)
            ax.plot(iterations, mean, color=custom_colors[init])
            ax.fill_between(iterations, mean - stderr, mean + stderr,
                            alpha=0.3, color=custom_colors[init])

        ax.tick_params(axis='both', labelsize=TICK_FS)

        # Labeling: titles on the top row, y-label on the left column,
        # x-label on the bottom row, architecture annotation on the right.
        if row == 0:
            ax.set_title(func_plot_names[col], fontsize=TITLE_FS)
        if col == 0:
            ax.set_ylabel("Training Loss", fontsize=LABEL_FS, labelpad=10)
        if row == n_rows - 1:
            ax.set_xlabel("Training Iteration", fontsize=LABEL_FS, labelpad=10)
        if col == n_cols - 1:
            ax.text(1.10, 0.5, arch_plot_names[arch], transform=ax.transAxes,
                    fontsize=TICK_FS, rotation=270, va='center', ha='left')

        ax.set_yscale('log')
        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)

# Legend built manually so each initialization appears exactly once.
handles = [
    mlines.Line2D([], [], color=custom_colors[init], label=init_plot_names[init], linewidth=3)
    for init in init_types
]

fig.legend(handles=handles, loc="lower center", ncol=len(handles), fontsize=LABEL_FS,
           frameon=False, bbox_to_anchor=(0.5, -0.05))

plt.subplots_adjust(hspace=0.35, wspace=0.3, bottom=0.1)

fig.savefig(os.path.join(plots_dir, "ff_losses.pdf"), bbox_inches='tight')

plt.show()

## Training Curves: PDEs

In [None]:
results_dir = 'pde_results/'

# Loss histories produced by earlier notebooks; indexed below as
# results[pde][architecture][run][init].
losses_path = os.path.join(results_dir, "losses.pkl")
with open(losses_path, "rb") as f:
    results = pickle.load(f)

In [None]:
# Plotting
# Training-loss curves for the PDE experiments: one column per PDE, one row
# per architecture, one curve per initialization scheme. Each curve is the
# mean over runs; the band is +/- one standard error.
cmap = sns.color_palette("crest", as_cmap=True)
spectral_points = np.linspace(0, 1, 20)
color_indices = [0, 10, -1]

TITLE_FS = 18
LABEL_FS = 16
TICK_FS  = 14

init_types = ['baseline', 'glorot', 'power']
init_plot_names = {'baseline': 'Baseline', 'glorot': 'Glorot', 'power': 'Power-Law'}
architectures = ['small', 'big']
# Right-margin row annotation for each architecture (leading spaces roughly
# center the first line over the second).
arch_plot_names = {
    'small': '            G = 5\ndepth = 2, width = 8',
    'big': '            G = 20\ndepth = 3, width = 32',
}
func_names = list(results.keys())
func_plot_names = ['Allen-Cahn', 'Burgers', 'Helmholtz']
# Titles are paired with results keys by position, so the counts must agree.
assert len(func_plot_names) == len(func_names), (
    f"{len(func_names)} PDEs in results but {len(func_plot_names)} plot names"
)

colors = [cmap(spectral_points[i]) for i in color_indices]
custom_colors = dict(zip(init_types, colors))

n_rows, n_cols = len(architectures), len(func_names)
# squeeze=False keeps `axes` 2-D even with a single row or column.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 7), squeeze=False)

for col, func_name in enumerate(func_names):
    for row, arch in enumerate(architectures):
        ax = axes[row, col]

        for init in init_types:
            # One row per run; assumes every run stores a 1-D loss history of
            # the same length for this configuration.
            runs = np.stack([np.asarray(results[func_name][arch][run][init])
                             for run in results[func_name][arch]])
            n_runs = runs.shape[0]

            mean = runs.mean(axis=0)
            # Standard error of the mean uses the sample standard deviation
            # (ddof=1); a single run has no spread to estimate.
            if n_runs > 1:
                stderr = runs.std(axis=0, ddof=1) / np.sqrt(n_runs)
            else:
                stderr = np.zeros_like(mean)

            # The x-axis comes from the data itself, not a hard-coded epoch count,
            # so fill_between cannot mismatch the curve length.
            iterations = np.arange(mean.size)
            ax.plot(iterations, mean, color=custom_colors[init])
            ax.fill_between(iterations, mean - stderr, mean + stderr,
                            alpha=0.3, color=custom_colors[init])

        ax.tick_params(axis='both', labelsize=TICK_FS)

        # Labeling: titles on the top row, y-label on the left column,
        # x-label on the bottom row, architecture annotation on the right.
        if row == 0:
            ax.set_title(func_plot_names[col], fontsize=TITLE_FS)
        if col == 0:
            ax.set_ylabel("Training Loss", fontsize=LABEL_FS, labelpad=10)
        if row == n_rows - 1:
            ax.set_xlabel("Training Iteration", fontsize=LABEL_FS, labelpad=10)
        if col == n_cols - 1:
            ax.text(1.10, 0.5, arch_plot_names[arch], transform=ax.transAxes,
                    fontsize=LABEL_FS, rotation=270, va='center', ha='left')

        ax.set_yscale('log')
        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)

# Legend built manually so each initialization appears exactly once.
handles = [
    mlines.Line2D([], [], color=custom_colors[init], label=init_plot_names[init], linewidth=3)
    for init in init_types
]

fig.legend(handles=handles, loc="lower center", ncol=len(handles), fontsize=LABEL_FS,
           frameon=False, bbox_to_anchor=(0.5, -0.08))

plt.subplots_adjust(hspace=0.25, wspace=0.2, bottom=0.1)

fig.savefig(os.path.join(plots_dir, "pde_losses.pdf"), bbox_inches='tight')

plt.show()

## NTK Plots: Function Fitting

In [None]:
results_dir = 'ff_results/'

# NTK spectra for the function-fitting runs, saved by an earlier notebook.
ntk_path = os.path.join(results_dir, "ntk.pkl")
with open(ntk_path, "rb") as f:
    results = pickle.load(f)

In [None]:
def plot_ntk_row(results, func_name, arch="big"):
    """Plot NTK eigenvalue spectra for one target function, one panel per initialization.

    Each panel shows, on log-log axes, the spectrum at initialization (solid),
    at any intermediate checkpoints (dashed), and at the final iteration
    (dashed). The figure is saved to ``<plots_dir>/ntk_<func_name>.pdf`` and shown.

    Parameters
    ----------
    results : dict
        Loaded ``ntk.pkl``. ``results[func_name][arch]`` maps the init names
        "Baseline" / "Glorot" / "Power" to a record whose ``"spec_list"`` is a
        sequence of eigenvalue arrays (first = initialization, last = final
        iteration). Missing inits, or empty ``spec_list``s, leave their panel hidden.
    func_name : str
        Key of the target function in ``results`` (e.g. ``"f3"``).
    arch : str, optional
        Architecture key to plot; defaults to ``"big"``.
    """

    TITLE_FS = 14
    LABEL_FS = 12
    TICK_FS  = 10

    palette = sns.color_palette("crest", 20)
    c_init  = palette[-1]
    c_mid   = palette[10]
    c_final = palette[0]

    # Record key -> panel title.
    init_titles = {"Baseline": "Baseline", "Glorot": "Glorot", "Power": "Power-Law"}

    fig, axes = plt.subplots(1, 3, figsize=(16, 3), sharex=True, sharey=False)

    for col, (init, title) in enumerate(init_titles.items()):
        ax = axes[col]
        rec = results[func_name][arch].get(init)
        specs = [] if rec is None else [np.asarray(e) for e in rec["spec_list"]]
        if not specs:
            # Nothing recorded for this initialization.
            ax.set_visible(False)
            continue

        # Eigenvalue indices start at 1 so they can be drawn on a log x-axis.
        idx = np.arange(1, specs[0].size + 1)

        ax.plot(idx, specs[0], color=c_init, lw=2)
        # Intermediate checkpoints: everything strictly between first and last.
        for lam in specs[1:-1]:
            ax.plot(idx, lam, "--", color=c_mid, alpha=0.7)
        ax.plot(idx, specs[-1], "--", color=c_final, lw=2)

        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_title(title, fontsize=TITLE_FS)
        if col == 0:
            ax.set_ylabel("Eigenvalues", fontsize=LABEL_FS)
        ax.set_xlabel("Indices", fontsize=LABEL_FS)
        ax.tick_params(axis='both', labelsize=TICK_FS)
        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)

    # One shared legend below the panels.
    handles = [
        mlines.Line2D([], [], color=c_init,  label="Initialization",          linewidth=2),
        mlines.Line2D([], [], color=c_mid,   label="Intermediate Iterations", linewidth=2, linestyle="--"),
        mlines.Line2D([], [], color=c_final, label="Final Iteration",         linewidth=2, linestyle="--"),
    ]

    fig.legend(handles=handles, loc="lower center", ncol=3, fontsize=LABEL_FS,
               frameon=False, bbox_to_anchor=(0.5, -0.15))

    plt.subplots_adjust(hspace=0.2, wspace=0.2, bottom=0.18)

    fig.savefig(os.path.join(plots_dir, f"ntk_{func_name}.pdf"), bbox_inches='tight')
    plt.show()

In [None]:
# One NTK-spectrum figure per target function.
for target in ("f1", "f2", "f3", "f4", "f5"):
    plot_ntk_row(results, target)

In [None]:
# Standalone re-render of f3 (the loop above also draws it; this rewrites the same PDF).
plot_ntk_row(results=results, func_name="f3")

## NTK Plots: PDEs

In [None]:
results_dir = 'pde_results/'

# NTK spectra for the PDE runs, saved by an earlier notebook.
ntk_path = os.path.join(results_dir, "ntk.pkl")
with open(ntk_path, "rb") as f:
    results = pickle.load(f)

In [None]:
def plot_ntk_rows(results, pde_name, arch="big"):
    """Plot PDE and BC NTK eigenvalue spectra for one PDE, one column per initialization.

    Row 0 draws ``rec["specE_list"]`` (labelled "PDE"), row 1 draws
    ``rec["specB_list"]`` (labelled "BC"). In every panel, on log-log axes,
    the first spectrum (initialization) is solid, intermediate ones are dashed,
    and the last (final iteration) is dashed. The figure is saved to
    ``<plots_dir>/ntk_<pde_name>.pdf`` and shown.

    Parameters
    ----------
    results : dict
        Loaded ``ntk.pkl``. ``results[pde_name][arch]`` maps the init names
        "Baseline" / "Glorot" / "Power" to a record holding ``"specE_list"``
        and ``"specB_list"``. A missing init hides its whole column; an empty
        spectrum list hides just that panel.
    pde_name : str
        Key of the PDE in ``results`` (e.g. ``"burgers"``).
    arch : str, optional
        Architecture key to plot; defaults to ``"big"``.
    """

    TITLE_FS = 14
    LABEL_FS = 12
    TICK_FS  = 10

    palette = sns.color_palette("crest", 20)
    c_init  = palette[-1]
    c_mid   = palette[10]
    c_final = palette[0]

    def draw_spectra(ax, spec_list):
        """Draw init/intermediate/final spectra on ``ax``; hide it and return False if empty."""
        specs = [np.asarray(e) for e in spec_list]
        if not specs:
            ax.set_visible(False)
            return False
        # Eigenvalue indices start at 1 so they can be drawn on a log x-axis.
        idx = np.arange(1, specs[0].size + 1)
        ax.plot(idx, specs[0], color=c_init, lw=2)
        for lam in specs[1:-1]:
            ax.plot(idx, lam, "--", color=c_mid, alpha=0.7)
        # Same width in both rows and in the legend (row 0 previously used lw=3).
        ax.plot(idx, specs[-1], "--", color=c_final, lw=2)
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.tick_params(axis='both', labelsize=TICK_FS)
        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)
        return True

    # Record key -> column title.
    init_titles = {"Baseline": "Baseline", "Glorot": "Glorot", "Power": "Power-Law"}

    fig, axes = plt.subplots(2, 3, figsize=(16, 6), sharex=False, sharey=False)

    for col, (init, title) in enumerate(init_titles.items()):
        ax_pde, ax_bc = axes[0, col], axes[1, col]
        rec = results[pde_name][arch].get(init)
        if rec is None:
            ax_pde.set_visible(False)
            ax_bc.set_visible(False)
            continue

        # ---- Row 0: PDE spectra ----
        if draw_spectra(ax_pde, rec["specE_list"]):
            ax_pde.set_title(title, fontsize=TITLE_FS)
            if col == 0:
                ax_pde.set_ylabel("Eigenvalues (PDE)", fontsize=LABEL_FS)

        # ---- Row 1: BC spectra ----
        if draw_spectra(ax_bc, rec["specB_list"]):
            ax_bc.set_xlabel("Indices", fontsize=LABEL_FS)
            if col == 0:
                ax_bc.set_ylabel("Eigenvalues (BC)", fontsize=LABEL_FS)

    # One shared legend for all panels.
    handles = [
        mlines.Line2D([], [], color=c_init,  label="Initialization",          linewidth=2),
        mlines.Line2D([], [], color=c_mid,   label="Intermediate Iterations", linewidth=2, linestyle="--"),
        mlines.Line2D([], [], color=c_final, label="Final Iteration",         linewidth=2, linestyle="--"),
    ]

    fig.legend(handles=handles, loc="lower center", ncol=3, fontsize=LABEL_FS,
               frameon=False, bbox_to_anchor=(0.5, -0.02))

    plt.subplots_adjust(hspace=0.25, wspace=0.2, bottom=0.16)

    fig.savefig(os.path.join(plots_dir, f"ntk_{pde_name}.pdf"), bbox_inches="tight")
    plt.show()

In [None]:
# One two-row NTK-spectrum figure per PDE.
for pde in ("allen-cahn", "burgers", "helmholtz"):
    plot_ntk_rows(results, pde)

In [None]:
# Standalone re-render of Allen-Cahn (the loop above also draws it; this rewrites the same PDF).
plot_ntk_rows(results=results, pde_name='allen-cahn')