# Training Loss Curves

In [None]:
import pickle
import os

import jax.numpy as jnp

from src.functions import *
from src.pdes import *
from src.utils import *

from jaxkan.KAN import KAN

from flax import nnx
import optax

from sklearn.model_selection import train_test_split

## Function Fitting

We train the two architectures mentioned in the manuscript (a small and a big network) and record the evolution of the training loss for each target function under each of the selected initialization schemes (baseline, Glorot and power law).

In [None]:
# Setup
# Target functions to fit, keyed by the name under which results are stored
func_dict = {"f1": f1, "f2": f2, "f3": f3, "f4": f4, "f5": f5}

# Sample count and base seed handed to the data generator / train-test split
N = 5000
seed = 42

# Number of optimization steps per trained model
num_epochs = 2000

# Optimizer definition shared by every model
opt_type = optax.adam(learning_rate=0.001)

# Exponents used by the 'power' init scheme (pow_b1/pow_b2 and pow_r1/pow_r2)
pow_basis = 1.75
pow_res = 0.25


def _kan_params(G, **init_scheme):
    """Build the `required_parameters` dict for a spline KAN layer.

    Everything except the grid size `G` and the init scheme is identical for
    all models; the keyword arguments are collected into a fresh
    `init_scheme` dict on every call.
    """
    return {
        'k': 3,
        'G': G,
        'grid_range': (-1.0, 1.0),
        'grid_e': 1.0,
        'residual': nnx.silu,
        'external_weights': True,
        'add_bias': True,
        'init_scheme': init_scheme,
    }


# Settings of the power-law init scheme, shared by both architectures
_power_init = dict(type='power', const_b=1.0, const_r=1.0,
                   pow_b1=pow_basis, pow_b2=pow_basis,
                   pow_r1=pow_res, pow_r2=pow_res)

# --------------------------
# Small architecture details
# --------------------------
G_small = 5
hidden_small = [8, 8]

params_small_baseline = _kan_params(G_small, type='default')
params_small_glorot = _kan_params(G_small, type='glorot', gain=None, distribution='uniform')
params_small_power = _kan_params(G_small, **_power_init)

# ------------------------
# Big architecture details
# ------------------------
G_big = 20
hidden_big = [32, 32, 32]

params_big_baseline = _kan_params(G_big, type='default')
params_big_glorot = _kan_params(G_big, type='glorot', gain=None, distribution='uniform')
params_big_power = _kan_params(G_big, **_power_init)

In [None]:
# Initialize results dict
# Layout: results[func_name][architecture][run][init_scheme] -> array of per-epoch training losses
results = dict()

# Experiments to run, in execution order:
# (architecture key, hidden layer widths, [(results key, printed label, KAN parameters), ...])
_ff_setups = [
    ('small', hidden_small, [('baseline', 'Baseline', params_small_baseline),
                             ('glorot', 'Glorot', params_small_glorot),
                             ('power', 'Power-law', params_small_power)]),
    ('big', hidden_big, [('baseline', 'Baseline', params_big_baseline),
                         ('glorot', 'Glorot', params_big_glorot),
                         ('power', 'Power-law', params_big_power)]),
]


def _fit_function_model(layer_dims, params, model_seed, X_train, y_train):
    """Train a freshly initialized spline KAN and record its training loss.

    Args:
        layer_dims: layer widths passed to `KAN`.
        params: `required_parameters` dict passed to `KAN`.
        model_seed: seed passed to `KAN`.
        X_train, y_train: training data forwarded to `func_fit_step`.

    Returns:
        A tuple `(train_losses, loss)`: the array of shape `(num_epochs,)`
        holding the value returned by `func_fit_step` at every epoch, and the
        value returned at the last epoch.
    """
    model = KAN(layer_dims=layer_dims, layer_type='spline', required_parameters=params, seed=model_seed)
    optimizer = nnx.Optimizer(model, opt_type)

    train_losses = jnp.zeros((num_epochs,))
    for epoch in range(num_epochs):
        loss = func_fit_step(model, optimizer, X_train, y_train)
        train_losses = train_losses.at[epoch].set(loss)

    return train_losses, loss


for func_name, function in func_dict.items():
    print(f"Running Experiments for {func_name}.")
    results[func_name] = dict()

    for arch_name, _, _ in _ff_setups:
        results[func_name][arch_name] = dict()

    # Generate data
    x, y = generate_func_data(function, 2, N, seed)

    # Split data (the test split is not used in this cell; the split is kept for consistency)
    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)

    # Model input/output
    n_in, n_out = X_train.shape[1], y_train.shape[1]

    for arch_name, hidden, schemes in _ff_setups:
        layer_dims = [n_in, *hidden, n_out]

        print(f"\tTraining model with dimensions {layer_dims}.")

        # Repeat with different seeds for confidence
        for run in [1, 2, 3, 4, 5]:

            results[func_name][arch_name][run] = dict()

            print(f"\t\tRun No. {run}.")

            # Within a run, every init scheme is built with the same seed (seed + run)
            for scheme_key, scheme_label, params in schemes:
                train_losses, loss = _fit_function_model(layer_dims, params, seed + run, X_train, y_train)

                results[func_name][arch_name][run][scheme_key] = train_losses.copy()

                print(f"\t\t\t{scheme_label} model: Final Loss = {loss:.2e}")

In [None]:
# Save results for further processing
results_dir = 'ff_results/'

# open() does not create missing directories; make sure the output directory
# exists so the results of the (long) training run above are not lost.
os.makedirs(results_dir, exist_ok=True)

with open(os.path.join(results_dir, "losses.pkl"), "wb") as f:
    pickle.dump(results, f)

## PDE

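We repeat the same experiment for the PDE problems (Allen-Cahn, Burgers and Helmholtz), again recording the training loss of both architectures under each initialization scheme.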

In [None]:
# Setup
# PDE residual functions, keyed by the name under which results are stored
pde_dict = {"allen-cahn": ac_res, "burgers": burgers_res, "helmholtz": helmholtz_res}

# Size argument handed to the collocation point generator
N_points = 2**6

# Coefficients of the RBA weight update: new = gamma * old + eta * |r| / max|r|
RBA_gamma = 0.999
RBA_eta = 0.01

seed = 42

# Number of optimization steps per trained model
num_epochs = 5000

# Optimizer definition shared by every model
opt_type = optax.adam(learning_rate=0.001)

# Exponents used by the 'power' init scheme (pow_b1/pow_b2 and pow_r1/pow_r2)
pow_basis = 1.75
pow_res = 0.25


def _kan_params(G, **init_scheme):
    """Build the `required_parameters` dict for a spline KAN layer.

    Everything except the grid size `G` and the init scheme is identical for
    all models; the keyword arguments are collected into a fresh
    `init_scheme` dict on every call.
    """
    return {
        'k': 3,
        'G': G,
        'grid_range': (-1.0, 1.0),
        'grid_e': 1.0,
        'residual': nnx.silu,
        'external_weights': True,
        'add_bias': True,
        'init_scheme': init_scheme,
    }


# Settings of the power-law init scheme, shared by both architectures
_power_init = dict(type='power', const_b=1.0, const_r=1.0,
                   pow_b1=pow_basis, pow_b2=pow_basis,
                   pow_r1=pow_res, pow_r2=pow_res)

# --------------------------
# Small architecture details
# --------------------------
G_small = 5
hidden_small = [8, 8]

params_small_baseline = _kan_params(G_small, type='default')
params_small_glorot = _kan_params(G_small, type='glorot', gain=None, distribution='uniform')
params_small_power = _kan_params(G_small, **_power_init)

# ------------------------
# Big architecture details
# ------------------------
G_big = 20
hidden_big = [32, 32, 32]

params_big_baseline = _kan_params(G_big, type='default')
params_big_glorot = _kan_params(G_big, type='glorot', gain=None, distribution='uniform')
params_big_power = _kan_params(G_big, **_power_init)

In [None]:
# Experiment
# Initialize results dict
# Layout: results[pde_name][architecture][run][init_scheme] -> array of per-epoch training losses
results = dict()

# Experiments to run, in execution order:
# (architecture key, hidden layer widths, [(results key, printed label, KAN parameters), ...])
_pde_setups = [
    ('small', hidden_small, [('baseline', 'Baseline', params_small_baseline),
                             ('glorot', 'Glorot', params_small_glorot),
                             ('power', 'Power-law', params_small_power)]),
    ('big', hidden_big, [('baseline', 'Baseline', params_big_baseline),
                         ('glorot', 'Glorot', params_big_glorot),
                         ('power', 'Power-law', params_big_power)]),
]


def _train_pde_model(train_step, layer_dims, params, model_seed, pde_collocs, bc_collocs, bc_data):
    """Train a freshly initialized spline KAN on a PDE and record its training loss.

    Args:
        train_step: step function called as
            `train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data)`
            and returning `(loss, l_E_new, l_B_new)`.
        layer_dims: layer widths passed to `KAN`.
        params: `required_parameters` dict passed to `KAN`.
        model_seed: seed passed to `KAN`.
        pde_collocs, bc_collocs, bc_data: collocation data forwarded to `train_step`.

    Returns:
        A tuple `(train_losses, loss)`: the array of shape `(num_epochs,)`
        holding the loss of every epoch, and the loss of the last epoch.
    """
    # RBA weights start at one for every collocation point and are carried across epochs
    l_E = jnp.ones((pde_collocs.shape[0], 1))
    l_B = jnp.ones((bc_collocs.shape[0], 1))

    model = KAN(layer_dims=layer_dims, layer_type='spline', required_parameters=params, seed=model_seed)
    optimizer = nnx.Optimizer(model, opt_type)

    train_losses = jnp.zeros((num_epochs,))
    for epoch in range(num_epochs):
        loss, l_E, l_B = train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data)
        train_losses = train_losses.at[epoch].set(loss)

    return train_losses, loss


for pde_name in pde_dict.keys():
    print(f"Running Experiments for {pde_name} equation.")
    pde_res = pde_dict[pde_name]
    results[pde_name] = dict()

    for arch_name, _, _ in _pde_setups:
        results[pde_name][arch_name] = dict()

    # Define the loss function for this PDE
    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):
        """Return the RBA-weighted loss and the updated RBA weights.

        The loss is the sum of the mean squared weighted PDE residual and the
        mean squared weighted boundary residual. The second return value is
        the tuple `(l_E_new, l_B_new)` of updated weights.

        NOTE(review): the updated weights depend on the residuals and are not
        wrapped in a stop-gradient, so they are part of the differentiated
        expression - confirm this is intended.
        NOTE(review): the normalization divides by the largest absolute
        residual, which is 0/0 if all residuals are exactly zero.
        """

        # ------------- PDE ---------------------------- #
        pde_residuals = pde_res(model, pde_collocs)

        # Get new RBA weights
        abs_pde_res = jnp.abs(pde_residuals)
        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))

        # Multiply by RBA weights
        w_resids_pde = l_E_new * pde_residuals

        # Get loss
        pde_loss = jnp.mean(w_resids_pde**2)

        # ------------- BC ----------------------------- #
        bc_residuals = model(bc_collocs) - bc_data

        # Get new RBA weights
        abs_bc_res = jnp.abs(bc_residuals)
        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))

        # Multiply by RBA weights
        w_resids_bc = l_B_new * bc_residuals

        # Loss
        bc_loss = jnp.mean(w_resids_bc**2)

        # ------------- Total --------------------------- #
        total_loss = pde_loss + bc_loss

        return total_loss, (l_E_new, l_B_new)

    # Define the train step
    @nnx.jit
    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):
        """Run one optimizer update and return `(loss, l_E_new, l_B_new)`."""

        (loss, (l_E_new, l_B_new)), grads = nnx.value_and_grad(loss_fn, has_aux = True)(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)

        optimizer.update(grads)

        return loss, l_E_new, l_B_new

    # Get the reference solution (not used further in this cell)
    refsol, coords = get_ref(pde_name)

    # Get collocation points
    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)

    # Model input/output
    n_in, n_out = pde_collocs.shape[1], bc_data.shape[1]

    for arch_name, hidden, schemes in _pde_setups:
        layer_dims = [n_in, *hidden, n_out]

        print(f"\tTraining model with dimensions {layer_dims}.")

        # Repeat with different seeds for confidence
        for run in [1, 2, 3, 4, 5]:

            results[pde_name][arch_name][run] = dict()

            print(f"\t\tRun No. {run}.")

            # Within a run, every init scheme is built with the same seed (seed + run)
            for scheme_key, scheme_label, params in schemes:
                train_losses, loss = _train_pde_model(train_step, layer_dims, params, seed + run,
                                                      pde_collocs, bc_collocs, bc_data)

                results[pde_name][arch_name][run][scheme_key] = train_losses.copy()

                print(f"\t\t\t{scheme_label} model: Final Loss = {loss:.2e}")

In [None]:
# Save results for further processing
results_dir = 'pde_results/'

# open() does not create missing directories; make sure the output directory
# exists so the results of the (long) training run above are not lost.
os.makedirs(results_dir, exist_ok=True)

with open(os.path.join(results_dir, "losses.pkl"), "wb") as f:
    pickle.dump(results, f)