# Full Example

This notebook walks through a full example of training and evaluating a Kolmogorov-Arnold Network (KAN) model. The task at hand is learning a multivariate function, namely

$$ f\left(x_1,x_2\right) = \exp\left(\sin\left(\pi x_1\right) + x_2\right) $$

To do so, we first generate training samples from $f$, i.e. we create a synthetic dataset.

## Dataset Creation

We sample $N = 5000$ points uniformly at random from the square $\left[-2,2\right] \times \left[-2,2\right]$ and evaluate $f$ at each of them. 80% of the samples are retained for training, while the remaining 20% are used for evaluation.
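A quick bound on the targets helps when interpreting the errors later on. On this domain $\sin\left(\pi x_1\right)$ attains its full range $\left[-1,1\right]$ and $x_2 \in \left[-2,2\right]$, so

$$ -3 \le \sin\left(\pi x_1\right) + x_2 \le 3 \quad\Longrightarrow\quad e^{-3} \approx 0.0498 \le f\left(x_1,x_2\right) \le e^{3} \approx 20.09 $$

The targets therefore span a factor of $e^{6} \approx 403$, i.e. more than two orders of magnitude.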

```python
import jax
import jax.numpy as jnp

# Define the true function f(x_1, x_2) = exp(sin(pi * x_1) + x_2)
def true_f(x_1, x_2):
    return jnp.exp(jnp.sin(jnp.pi * x_1) + x_2)

N = 5000

# Sample x_1, x_2 uniformly in [-2, 2) by rescaling draws from [0, 1)
x_1 = 4 * jax.random.uniform(jax.random.PRNGKey(42), (N,)) - 2
x_2 = 4 * jax.random.uniform(jax.random.PRNGKey(43), (N,)) - 2

# Evaluate the true function to obtain the targets
y = true_f(x_1, x_2)
```
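As a side note, the same kind of sample can be drawn from a single root key: `jax.random.split` produces independent subkeys, and `jax.random.uniform` accepts `minval`/`maxval` directly, so no manual rescaling is needed. This is only a sketch of the more idiomatic JAX style; because different keys are used, the draws differ from the cell above. The last lines check that every target lies in $\left[e^{-3}, e^{3}\right]$.

```python
import jax
import jax.numpy as jnp

def true_f(x_1, x_2):
    return jnp.exp(jnp.sin(jnp.pi * x_1) + x_2)

N = 5000

# Split one root key into two independent subkeys, one per coordinate
key_1, key_2 = jax.random.split(jax.random.PRNGKey(42))

# Draw directly from [-2, 2) via minval/maxval instead of rescaling [0, 1)
x_1 = jax.random.uniform(key_1, (N,), minval=-2.0, maxval=2.0)
x_2 = jax.random.uniform(key_2, (N,), minval=-2.0, maxval=2.0)
y = true_f(x_1, x_2)

# The exponent lies in [-3, 3], so all targets should lie in [e^-3, e^3]
print(float(y.min()), float(y.max()))
```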

We then use scikit-learn's `train_test_split` to split the dataset into training and evaluation sets, holding out 20% of the samples for evaluation. Of course this can be handled manually, but we're lazy. The following cell was run with `scikit-learn==1.4.2`.

```python
from sklearn.model_selection import train_test_split

# Combine x_1 and x_2 into a single array
X = jnp.stack((x_1, x_2), axis=-1)

# Split the dataset into training and evaluation sets
X_train, X_eval, y_train, y_eval = train_test_split(X, y, test_size=0.2, random_state=42)
```
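Doing the split by hand is not much more work, though. Below is a minimal sketch using a hypothetical helper `manual_train_test_split` (not part of this repository): it shuffles the indices once with `jax.random.permutation` and holds out the first 20% of the shuffled indices. It reproduces the proportions of the sklearn call above, not its exact split for `random_state=42`.

```python
import jax
import jax.numpy as jnp

def manual_train_test_split(X, y, test_size=0.2, seed=42):
    """Hypothetical helper: shuffle once, hold out a `test_size` fraction for evaluation."""
    n = X.shape[0]
    n_test = int(round(n * test_size))
    perm = jax.random.permutation(jax.random.PRNGKey(seed), n)
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

# Toy data with the same shapes as in this notebook: (N, 2) inputs, (N,) targets
N = 5000
X = jax.random.uniform(jax.random.PRNGKey(0), (N, 2), minval=-2.0, maxval=2.0)
y = jnp.exp(jnp.sin(jnp.pi * X[:, 0]) + X[:, 1])

X_train, X_eval, y_train, y_eval = manual_train_test_split(X, y)
print(X_train.shape, X_eval.shape)  # (4000, 2) (1000, 2)
```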

## Data Loader

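We then define a simple data loader: a generator that shuffles the samples once, using the given seed, and yields mini-batches of `batch_size` samples (the last batch may be smaller). Because it is a generator, each loader can be iterated over only once.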

```python
def data_loader(X, y, batch_size, seed):
    """Yield shuffled (X, y) mini-batches in a single pass over the data."""
    dataset_size = len(X)
    # Shuffle the sample indices once with a seed-dependent permutation
    indices = jax.random.permutation(jax.random.PRNGKey(seed), dataset_size)

    for start_idx in range(0, dataset_size, batch_size):
        batch_indices = indices[start_idx:start_idx + batch_size]
        yield X[batch_indices], y[batch_indices]
```

```python
batch = 500
# Note: these generators are exhausted after one pass; recreate them for every epoch
train_loader = data_loader(X_train, y_train, batch, 42)
eval_loader = data_loader(X_eval, y_eval, batch, 42)
```
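Since a generator is consumed after one pass, a training loop needs a fresh loader per epoch, and ideally a different shuffle each time. A minimal sketch, assuming a hypothetical helper `epoch_batches` (not part of this repository) that derives a per-epoch key with `jax.random.fold_in`, so the order changes across epochs but stays reproducible:

```python
import jax
import jax.numpy as jnp

def epoch_batches(X, y, batch_size, base_key, epoch):
    """Hypothetical helper: yield shuffled mini-batches for one epoch.

    The epoch index is folded into the base key, giving a new but
    reproducible permutation for every epoch.
    """
    key = jax.random.fold_in(base_key, epoch)
    indices = jax.random.permutation(key, len(X))
    for start in range(0, len(X), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield X[batch_idx], y[batch_idx]

# Toy data: 4000 samples, the size of the training split used here
X_train = jnp.arange(8000, dtype=jnp.float32).reshape(4000, 2)
y_train = jnp.arange(4000, dtype=jnp.float32)
base_key = jax.random.PRNGKey(42)

for epoch in range(3):
    n_batches = sum(1 for _ in epoch_batches(X_train, y_train, 500, base_key, epoch))
    print(epoch, n_batches)  # 8 batches of 500 per epoch

# A single generator object cannot be reused
gen = epoch_batches(X_train, y_train, 500, base_key, 0)
print(len(list(gen)), len(list(gen)))  # 8 0
```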

## Initialization

With the data ready, we initialize the model and the optimizer. For optimization we use `optax`, while the `KAN` model is imported from the `../src` directory.

```python
import os
import sys

import optax
from flax import linen as nn

# Make the local KAN implementation in ../src importable from this notebook
path_to_src = os.path.abspath(os.path.join(os.getcwd(), '../src'))
if path_to_src not in sys.path:
    sys.path.append(path_to_src)

from KAN import KAN
```

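```python
# Initialize model
key = jax.random.PRNGKey(0)

# 2 inputs -> hidden layers of width 5 and 2 -> 1 output
layer_dims = [2, 5, 2, 1]
model = KAN(layer_dims=layer_dims, k=3, add_bias=True)
# A dummy input of shape (1, 2) is enough for Flax to infer all parameter shapes
variables = model.init(key, jnp.ones([1, 2]))
```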

```python
# Initialize optimizer
learning_rate = 0.001
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(variables['params'])
```

## Loss Function

Next, we define the loss function. For the prediction term we simply use the mean squared error (MSE), while for regularization we follow the original KAN arXiv preprint and combine the L1 norms of the layers with their entropy.
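For reference, this is our reading of the regularization defined in the KAN preprint (the notation is the preprint's, not taken from this repository's code). The L1 norm of an activation function $\phi$ is its mean magnitude over the $N_p$ input samples, and the L1 norm of a layer $\Phi$ with $n_{\text{in}}$ inputs and $n_{\text{out}}$ outputs sums over all of its activations:

$$ \left|\phi\right|_1 = \frac{1}{N_p}\sum_{s=1}^{N_p}\left|\phi\left(x^{(s)}\right)\right|, \qquad \left|\Phi\right|_1 = \sum_{i=1}^{n_{\text{in}}}\sum_{j=1}^{n_{\text{out}}}\left|\phi_{i,j}\right|_1 $$

The entropy of the layer is computed from the normalized edge norms:

$$ S\left(\Phi\right) = -\sum_{i=1}^{n_{\text{in}}}\sum_{j=1}^{n_{\text{out}}}\frac{\left|\phi_{i,j}\right|_1}{\left|\Phi\right|_1}\log\left(\frac{\left|\phi_{i,j}\right|_1}{\left|\Phi\right|_1}\right) $$

and the total objective over the $L$ layers is

$$ \ell_{\text{total}} = \ell_{\text{pred}} + \lambda\left(\mu_1\sum_{l=0}^{L-1}\left|\Phi_l\right|_1 + \mu_2\sum_{l=0}^{L-1}S\left(\Phi_l\right)\right) $$

where $\ell_{\text{pred}}$ is the MSE here, $\lambda$ sets the overall regularization strength, and the relative weights are usually $\mu_1 = \mu_2 = 1$.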

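```python
def loss_fn(variables, x, y, lamb=0.01, mu_1=1.0, mu_2=1.0, eps=1e-8):
    # Forward pass to acquire predictions and spl_regs
    preds, spl_regs = model.apply(variables, x)

    # Prediction loss (MSE); reshape so (batch, 1) predictions do not broadcast against (batch,) targets
    loss_pred = jnp.mean((preds.reshape(y.shape) - y) ** 2)

    # Regularization loss: per-layer L1 norm plus entropy, as in the KAN preprint.
    # Assumption: spl_regs holds one array per layer with the L1 norms |phi_ij|_1 of its edges;
    # lamb is a hypothetical default, mu_1 = mu_2 = 1 follows the preprint.
    loss_reg = 0.0
    for l1_edges in spl_regs:
        l1_layer = jnp.sum(l1_edges)
        p = l1_edges / (l1_layer + eps)
        loss_reg += mu_1 * l1_layer - mu_2 * jnp.sum(p * jnp.log(p + eps))

    return loss_pred + lamb * loss_reg
```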

```python
# $$ f\left(x_1,x_2,x_3\right) = \left(x_1^2 + x_2^2\right)^3\cdot\exp\left(x_1 \cdot x_3\right) $$
```