# Tests Notebook

This notebook is reserved for tests during development.

In [2]:
# Smoke test for the JAX installation: evaluate SiLU on a small float array
# and print the result.
from jax.nn import silu
import jax.numpy as jnp

x = jnp.arange(5.0)
print(silu(x))

[0.        0.7310586 1.761594  2.8577223 3.928055 ]


## Testing Splines

Let's test the generation of spline basis functions.

In [1]:
import sys
import os

path_to_src = os.path.abspath(os.path.join(os.getcwd(), '../src'))
if path_to_src not in sys.path:
    sys.path.append(path_to_src)

from bases.splines import get_spline_basis

In [1]:
import sys
import os

path_to_src = os.path.abspath(os.path.join(os.getcwd(), '../src'))
if path_to_src not in sys.path:
    sys.path.append(path_to_src)

from KANLayer import KANLayer as kan
from KAN import KAN
import jax
from jax import numpy as jnp

In [2]:
def print_dict_hierarchy(d, indent=0):
    """
    Print the nested key structure of a (variables) mapping.

    Nested mappings are printed as a header line and their content is
    indented by four extra spaces. Every other value is treated as a leaf.

    Args:
    -----
        d : Mapping
            possibly nested mapping to display. Any mapping type is accepted,
            not only `dict` (read-only / frozen mapping wrappers included).
        indent : int
            number of leading spaces used for the current level. Default: 0.

    Returns:
    --------
        None. The hierarchy is written to stdout.
    """
    # Local import keeps the cell self-contained regardless of run order.
    from collections.abc import Mapping

    pad = ' ' * indent
    for key, value in d.items():
        if isinstance(value, Mapping):
            print(pad + str(key))
            print_dict_hierarchy(value, indent + 4)
        elif hasattr(value, 'shape'):
            print(pad + f"{key}: shape {value.shape}")
        else:
            # Leaves without a `.shape` (Python scalars, None, ...) would
            # otherwise raise AttributeError; report their type instead.
            print(pad + f"{key}: {type(value).__name__}")

In [3]:
# Build a KAN with layer sizes 4 -> 5 -> 2 -> 1 and initialise its variables
# on a random batch of 50 four-dimensional inputs.
key = jax.random.PRNGKey(0)
layer_dims = [4, 5, 2, 1]
model = KAN(layer_dims=layer_dims, k=3, add_bias=True)

x = jax.random.normal(key, (50, 4))

# NOTE(review): `key` has already been consumed by jax.random.normal above and
# is reused here. JAX recommends deriving independent keys with
# jax.random.split; changing this would alter every value printed below.
variables = model.init(key, x)

In [None]:
# Show both variable collections: the trainable parameters first, then the
# non-trainable state. `sep="\n"` puts each collection on the line after its title.
print("Trainable parameters:", variables['params'], sep="\n")
print("\nNon-trainable state variables:", variables['state'], sep="\n")

In [4]:
print_dict_hierarchy(variables)

params
    bias_0: shape (5,)
    bias_1: shape (2,)
    bias_2: shape (1,)
    layers_0
        c_basis: shape (20, 6)
        c_spl: shape (20,)
        c_res: shape (20,)
    layers_1
        c_basis: shape (10, 6)
        c_spl: shape (10,)
        c_res: shape (10,)
    layers_2
        c_basis: shape (2, 6)
        c_spl: shape (2,)
        c_res: shape (2,)
state
    layers_0
        grid: shape (20, 10)
    layers_1
        grid: shape (10, 10)
    layers_2
        grid: shape (2, 10)


In [5]:
variables['params']['layers_1']['c_basis']

Array([[-0.21964498, -0.03216682, -0.02803302, -0.04404335,  0.10697184,
         0.00917037],
       [-0.07214645, -0.03873394, -0.09372874,  0.03438126,  0.07394429,
        -0.19368024],
       [ 0.144917  ,  0.07064237, -0.09206957,  0.09779473, -0.07791757,
        -0.11313559],
       [-0.08380805, -0.00183191,  0.01103233, -0.13060333, -0.03279752,
         0.00528648],
       [-0.03939278, -0.13313794, -0.02944968, -0.02043018,  0.04978757,
         0.06520142],
       [ 0.05189328, -0.07971673, -0.10149919,  0.08816522,  0.02823404,
         0.14567198],
       [-0.04320967,  0.00129707,  0.06899355,  0.01118089,  0.18311459,
        -0.07600819],
       [-0.07420332,  0.09264632,  0.02440042,  0.03857328,  0.01935199,
        -0.12175804],
       [ 0.01394208,  0.12633352,  0.11282189,  0.07304362, -0.09317457,
        -0.00676571],
       [-0.11424935,  0.07237625,  0.03601482,  0.12565634,  0.04958436,
         0.19709854]], dtype=float32)

In [6]:
# Perform grid updates
new_grid_size = 5
updated_variables = model.apply(variables, x, new_grid_size, method=model.update_grids)

In [7]:
print_dict_hierarchy(updated_variables)

params
    bias_0: shape (5,)
    bias_1: shape (2,)
    bias_2: shape (1,)
    layers_0
        c_basis: shape (20, 8)
        c_spl: shape (20,)
        c_res: shape (20,)
    layers_1
        c_basis: shape (10, 8)
        c_spl: shape (10,)
        c_res: shape (10,)
    layers_2
        c_basis: shape (2, 8)
        c_spl: shape (2,)
        c_res: shape (2,)
state
    layers_0
        grid: shape (20, 12)
    layers_1
        grid: shape (10, 12)
    layers_2
        grid: shape (2, 12)


In [8]:
updated_variables['params']['layers_1']['c_basis']

Array([[-0.02999173, -0.03370977, -0.03347092, -0.02784294, -0.01343204,
         0.02792049,  0.07179645,  0.06490941],
       [-0.07054527, -0.0504249 , -0.02268857,  0.0041829 ,  0.02827478,
         0.05501393,  0.04043134, -0.07338444],
       [-0.04324958, -0.02404464,  0.00513026,  0.03018621,  0.04226551,
         0.01231476, -0.05553922, -0.08770344],
       [-0.00731817, -0.03688883, -0.06182383, -0.08282282, -0.09563003,
        -0.08224212, -0.03881234, -0.017495  ],
       [-0.04516727, -0.03153471, -0.02297354, -0.01604484, -0.00534422,
         0.01625229,  0.04225967,  0.05475741],
       [-0.07709993, -0.03993033, -0.00081709,  0.03031873,  0.05473555,
         0.05968446,  0.0495641 ,  0.07461052],
       [ 0.05258897,  0.04627253,  0.04096211,  0.04042412,  0.05307961,
         0.09784878,  0.1403644 ,  0.03437333],
       [ 0.03864328,  0.03307077,  0.03224307,  0.03306453,  0.0332961 ,
         0.02891284,  0.00801342, -0.04001885],
       [ 0.10859791,  0.10073279

In [9]:
variables['params']['layers_1']['c_basis']

Array([[-0.21964498, -0.03216682, -0.02803302, -0.04404335,  0.10697184,
         0.00917037],
       [-0.07214645, -0.03873394, -0.09372874,  0.03438126,  0.07394429,
        -0.19368024],
       [ 0.144917  ,  0.07064237, -0.09206957,  0.09779473, -0.07791757,
        -0.11313559],
       [-0.08380805, -0.00183191,  0.01103233, -0.13060333, -0.03279752,
         0.00528648],
       [-0.03939278, -0.13313794, -0.02944968, -0.02043018,  0.04978757,
         0.06520142],
       [ 0.05189328, -0.07971673, -0.10149919,  0.08816522,  0.02823404,
         0.14567198],
       [-0.04320967,  0.00129707,  0.06899355,  0.01118089,  0.18311459,
        -0.07600819],
       [-0.07420332,  0.09264632,  0.02440042,  0.03857328,  0.01935199,
        -0.12175804],
       [ 0.01394208,  0.12633352,  0.11282189,  0.07304362, -0.09317457,
        -0.00676571],
       [-0.11424935,  0.07237625,  0.03601482,  0.12565634,  0.04958436,
         0.19709854]], dtype=float32)

In [10]:
y, spl_regs = model.apply(variables, x)

In [21]:
spl_regs

[Array([[0.00683656, 0.00996755, 0.01660573, 0.00698722],
        [0.01234719, 0.00844338, 0.01165444, 0.00359057],
        [0.00864134, 0.00688513, 0.00845095, 0.01252255],
        [0.00697956, 0.00756933, 0.00858223, 0.010777  ],
        [0.01196445, 0.00689442, 0.00360187, 0.01110468]], dtype=float32),
 Array([[0.00437403, 0.00439544, 0.00362973, 0.01215728, 0.00293692],
        [0.00585498, 0.00946443, 0.00511737, 0.00995341, 0.01484072]],      dtype=float32),
 Array([[0.00335365, 0.01805771]], dtype=float32)]

In [2]:
import torch
import numpy as np

def create_dataset(f,
                   n_var=2,
                   ranges = [-1,1],
                   train_num=1000,
                   test_num=1000,
                   normalize_input=False,
                   normalize_label=False,
                   device='cpu',
                   seed=0):
    '''
    Create a synthetic train/test dataset by sampling inputs uniformly and
    evaluating `f` on them.

    Args:
    -----
        f : function
            the symbolic formula used to create the synthetic dataset. It is
            called on a (num_samples, n_var) tensor and must return a tensor.
        n_var : int
            the number of input variables. Default: 2.
        ranges : list or np.array; shape (2,) or (n_var, 2)
            the range of input variables. A single [low, high] pair is applied
            to every variable. Default: [-1,1] (never mutated).
        train_num : int
            the number of training samples. Default: 1000.
        test_num : int
            the number of test samples. Default: 1000.
        normalize_input : bool
            If True, standardize inputs with the training mean/std. Default: False.
        normalize_label : bool
            If True, standardize labels with the training mean/std. Default: False.
        device : str
            device the returned tensors are moved to. Default: 'cpu'.
        seed : int
            random seed (applied to both numpy and torch). Default: 0.

    Returns:
    --------
        dataset : dic
            Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
                        dataset['test_input'], dataset['test_label']

    Raises:
    -------
        ValueError
            if a one-dimensional `ranges` does not hold exactly two values.

    Example
    -------
    >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    >>> dataset = create_dataset(f, n_var=2, train_num=100)
    >>> dataset['train_input'].shape
    torch.Size([100, 2])
    '''

    np.random.seed(seed)
    torch.manual_seed(seed)

    ranges = np.asarray(ranges)
    if ranges.ndim == 1:
        if ranges.shape[0] != 2:
            raise ValueError(
                f"a one-dimensional `ranges` must be [low, high], got shape {ranges.shape}"
            )
        # Repeat the single [low, high] pair for every variable. np.tile works
        # for lists and ndarrays alike, whereas `ranges * n_var` only repeats
        # a list and multiplies an ndarray elementwise.
        ranges = np.tile(ranges, (n_var, 1))

    train_input = torch.zeros(train_num, n_var)
    test_input = torch.zeros(test_num, n_var)
    # The draws stay interleaved per column (train, then test) so that a given
    # seed keeps producing the same samples.
    for i in range(n_var):
        low = float(ranges[i, 0])
        span = float(ranges[i, 1] - ranges[i, 0])
        train_input[:, i] = torch.rand(train_num) * span + low
        test_input[:, i] = torch.rand(test_num) * span + low

    train_label = f(train_input)
    test_label = f(test_input)

    def normalize(data, mean, std):
        return (data - mean) / std

    # Statistics always come from the training split and are reused for the
    # test split.
    if normalize_input:
        mean_input = torch.mean(train_input, dim=0, keepdim=True)
        std_input = torch.std(train_input, dim=0, keepdim=True)
        train_input = normalize(train_input, mean_input, std_input)
        test_input = normalize(test_input, mean_input, std_input)

    if normalize_label:
        mean_label = torch.mean(train_label, dim=0, keepdim=True)
        std_label = torch.std(train_label, dim=0, keepdim=True)
        train_label = normalize(train_label, mean_label, std_label)
        test_label = normalize(test_label, mean_label, std_label)

    dataset = {
        'train_input': train_input.to(device),
        'test_input': test_input.to(device),
        'train_label': train_label.to(device),
        'test_label': test_label.to(device),
    }

    return dataset

In [8]:
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=4)

In [9]:
dataset['train_input']

tensor([[-0.0075,  0.5547,  0.4975,  0.9488],
        [ 0.5364,  0.1791, -0.2652, -0.4505],
        [-0.8230,  0.1526,  0.7550, -0.6354],
        ...,
        [-0.3216, -0.4567, -0.4248,  0.5234],
        [ 0.0036, -0.3966,  0.0332,  0.2307],
        [-0.1923, -0.8376, -0.7736,  0.6485]])

In [7]:
dataset['train_label']

tensor([[1.3287],
        [2.7886],
        [0.6038],
        [0.6248],
        [0.5784],
        [2.4956],
        [0.9510],
        [2.2221],
        [0.9604],
        [5.0185],
        [0.7090],
        [0.6934],
        [1.4567],
        [0.4631],
        [0.6223],
        [1.1396],
        [2.6025],
        [3.8784],
        [0.7173],
        [0.5919],
        [5.1320],
        [2.4276],
        [0.5478],
        [2.5004],
        [0.6470],
        [1.3875],
        [2.6785],
        [0.9634],
        [0.4302],
        [0.7583],
        [0.4161],
        [3.7847],
        [0.4791],
        [0.3708],
        [0.4698],
        [1.0199],
        [0.4410],
        [2.0350],
        [4.1370],
        [2.7261],
        [1.2566],
        [0.3683],
        [2.6944],
        [2.1286],
        [0.5894],
        [0.3798],
        [3.8286],
        [3.1215],
        [0.4911],
        [0.9707],
        [3.7596],
        [1.2254],
        [2.6039],
        [1.6970],
        [2.3876],
        [0

In [18]:
from jax import numpy as jnp

from flax import linen as nn

from KANLayer import KANLayer


layer_dims = [3, 5, 1]

# Hyperparameters shared by every layer
k = 3
const_spl = False
const_res = False
residual = nn.swish
noise_std = 0.1
grid_e = 0.15

# One KANLayer per pair of consecutive sizes in layer_dims
layers = [
    KANLayer(
        n_in=n_in,
        n_out=n_out,
        k=k,
        const_spl=const_spl,
        const_res=const_res,
        residual=residual,
        noise_std=noise_std,
        grid_e=grid_e,
    )
    for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])
]



In [19]:
# NOTE(review): this cell is a pasted draft of a method body, not runnable
# notebook code: it relies on `self`, and a `return` statement at the top
# level of a cell is a SyntaxError. It looks like the body intended for
# `update_grids` - TODO confirm and move it into the module.
updated_params = self.scope.variables()['params']
updated_state = self.scope.variables()['state']

for i, layer in enumerate(self.layers):
    # Extract the variables for the current layer
    layer_variables = {
        'params': updated_params[f'layers_{i}'],
        'state': updated_state[f'layers_{i}']
    }
    
    # Call the update_grid method on the current layer
    # NOTE(review): `x` is never reassigned inside this loop, so every layer
    # receives the same `x`. Presumably each layer should instead see the
    # previous layer's output - confirm.
    coeffs, updated_layer_state = layer.apply(layer_variables, x, new_grid_size, method=layer.update_grid, mutable=['state'])
    
    # Update the state and parameters for the current layer
    updated_state[f'layers_{i}'] = updated_layer_state['state']
    updated_params[f'layers_{i}']['c_basis'] = coeffs

return {'params': updated_params, 'state': updated_state}

[KANLayer(
     # attributes
     n_in = 3
     n_out = 5
     k = 3
     const_spl = False
     const_res = False
     residual = silu
     noise_std = 0.1
     grid_e = 0.15
 ),
 KANLayer(
     # attributes
     n_in = 5
     n_out = 1
     k = 3
     const_spl = False
     const_res = False
     residual = silu
     noise_std = 0.1
     grid_e = 0.15
 )]

In [None]:
def update_grids(self, x, G_new):
    """
    Performs the grid update for each layer of the KAN architecture.

    NOTE: this is still a stub. The body holds nothing but this docstring and
    the comments below, so calling it returns None.

    Args:
    -----
        x (jnp.array): inputs for the first layer
            shape (batch, n_in of the first layer) - presumably; TODO confirm
        G_new (int): Size of the new grid (in terms of intervals)

    """
    # TODO: loop over all layers and perform the update for each layer, while also tweaking the variables dict.
    # Between consecutive layers a forward pass is needed to get the new value of x.


def __call__(self, x):
    """
    Forward pass through every layer in `self.layers`, in order.

    Args:
    -----
        x: input handed to the first layer. Each layer is called as
            `layer(x)` and must return an `(output, spl_reshaped)` pair.

    Returns:
    --------
        x: output of the last layer, with `self.biases[i]` added after layer
            `i` whenever `self.add_bias` is truthy.
        spl_reshaped: second element returned by the LAST layer only, or None
            when `self.layers` is empty.
    """
    # Defined up front so an empty layer stack returns (x, None) instead of
    # raising UnboundLocalError.
    spl_reshaped = None
    for i, layer in enumerate(self.layers):
        x, spl_reshaped = layer(x)
        if self.add_bias:
            # Out-of-place add: never writes into the array the layer returned.
            x = x + self.biases[i]
    return x, spl_reshaped