# Tests on Grid generation & extension

What follows are JAX snippets, based on pykan and efficientkan, that reproduce the grid generation and extension process step by step, in order to get a better understanding of what is going on.

In [4]:
import jax.numpy as jnp

In [2]:
# Input parameters relevant to the grid - the following cells read these names directly
k = 3 # spline order
G = 5 # grid size: the step below is range / (G - 1), i.e. G knots span grid_range before augmentation
n_in = 2 # number of the layer's input nodes
n_out = 3 # number of the layer's output nodes

**<span style="color:red">Open Question: Why are we starting with a grid defined only by its range?</span>**

In [3]:
# Grid initialization - at this point the grid is described only by its two end points
# [min, max]; the actual knot vector is generated from this range in the next cell
grid_range = [-1, 1]

**<span style="color:red">Open Question: Why are we augmenting the grid like this instead of appending the first and last entries k times at the start and end?</span>**

In [4]:
# Grid augmentation
# Step between consecutive knots: G knots spanning grid_range leave G - 1 intervals
h = (grid_range[-1] - grid_range[0]) / (G-1)
# Knot indices run from -k to G + k inclusive, i.e. G + 2k + 1 knots in total
# NOTE(review): with the (G - 1) step above, index G - 1 already lands on grid_range[-1],
# so this yields k knots below the range but k + 1 above it (-2.5 ... 3.0 in the output
# below) - confirm whether the step should be range / G instead
grid = (jnp.arange(-k, G + k + 1, dtype=jnp.float32) * h + grid_range[0])

grid

Array([-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
        3. ], dtype=float32)

In [5]:
# Broadcast the knot vector to every spline of the layer: the result has shape
# (n_in * n_out, G + 2k + 1), i.e. one identical row of knots per spline. The intent is to
# pass it to get_spline_basis and get back a Bi(x) array that only needs to be multiplied
# with the coefficients to obtain spline(x) for each of the n_in * n_out splines.
grid = jnp.tile(grid[None, :], (n_in * n_out, 1))

grid

Array([[-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ]], dtype=float32)

This is the initialization of the grid per layer (registered at `self.grid`). Then, we have the `update_grid` routine, which is supposed to perform an update of each layer's grid (**the grid may or may not become trainable in the future**).

In pykan, the `update_grid_from_samples` routine simply performs the mix of adaptive and uniform sampling, given the current grid and the batch. It is the `initialize_grid_from_parent` routine that performs the grid extension. On the other hand, in efficientkan both are handled together in the `update_grid` routine. In what follows, we try to follow the approach of efficientkan, however working with grids of shape $(n_{in}\cdot n_{out},\text{batch})$ instead of $(n_{in},\text{batch})$.

**<span style="color:red">IDEA: Keep splines as model param function, same as efficientkan, because the grid is also. Then $k$ can be inherited from self too, and the whole thing can normally be jitted.</span>** 

## efficientkan grid update

TLDR; It appears that they don't really upgrade their grid in terms of expansion.

In [1]:
import torch

from copy import copy

In [23]:
# Layer / grid hyper-parameters used by the efficientkan-style update below
spline_order = 3

in_features, out_features = 2, 5
batch_size = 300
grid_eps = 1.0

grid_size = 10
grid_range = [-1,1]
h = (grid_range[1] - grid_range[0]) / (grid_size-1)

# Knot indices -spline_order ... grid_size + spline_order (inclusive), scaled by h and
# shifted so that index 0 sits on grid_range[0]; the same row is shared by every input feature
knot_index = torch.arange(-spline_order, grid_size + spline_order + 1)
grid = (knot_index * h + grid_range[0]).expand(in_features, -1)


# Fixed seed, so that the sampled batch and the spline weights are reproducible
torch.manual_seed(42)
x = torch.randn((batch_size,in_features))
spline_weight = torch.randn((out_features, in_features, grid_size + spline_order))


In [24]:
margin = 0.01

# Keep a handle on the grid as it was before the update
old_grid = copy(grid)

# Number of samples in the batch (x holds one row per sample)
batch = x.size(0)

# The steps that evaluate Sum(ciBi) on the old grid and refit the coefficients cj on the
# new grid are omitted from this test; only the grid computation itself is reproduced.

# ----------------------
# Calculate a new grid
# ----------------------
# Sort every input feature independently to expose its empirical distribution
x_sorted, _ = torch.sort(x, dim=0)

# Adaptive grid: grid_size + 1 knots per feature, read at evenly spaced ranks of the sorted batch
rank_idx = torch.linspace(0, batch - 1, grid_size + 1, dtype=torch.int64)
grid_adaptive = x_sorted[rank_idx]

# Uniform grid: grid_size equal steps covering [min - margin, max + margin] of every feature
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / grid_size
step_idx = torch.arange(grid_size + 1, dtype=torch.float32).unsqueeze(1)
grid_uniform = step_idx * uniform_step + x_sorted[0] - margin

# Blend the two grids: grid_eps = 1 is fully uniform, grid_eps = 0 is fully adaptive
grid = grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive

# Augment with spline_order extra knots on each side, spaced by the uniform step;
# the grid grows from grid_size + 1 to grid_size + 1 + 2 * spline_order rows
below = grid[:1] - uniform_step * torch.arange(spline_order, 0, -1).unsqueeze(1)
above = grid[-1:] + uniform_step * torch.arange(1, spline_order + 1).unsqueeze(1)
grid = torch.cat([below, grid, above], dim=0)

# Transpose so that every row holds the knots of one input feature:
# (grid_size + 1 + 2 * spline_order, in_features) -> (in_features, grid_size + 1 + 2 * spline_order)
grid = grid.T

In [29]:
grid

tensor([[-4.2266, -3.6827, -3.1388, -2.5950, -2.0511, -1.5072, -0.9633, -0.4194,
          0.1245,  0.6684,  1.2123,  1.7562,  2.3001,  2.8440,  3.3879,  3.9318,
          4.4756],
        [-4.2444, -3.7154, -3.1865, -2.6575, -2.1285, -1.5996, -1.0706, -0.5416,
         -0.0127,  0.5163,  1.0453,  1.5742,  2.1032,  2.6321,  3.1611,  3.6901,
          4.2190]])

## pykan grid update

TLDR; TODO

In [35]:
import torch
import numpy as np

from copy import copy

In [36]:
# Layer / grid hyper-parameters used by the pykan-style update below
spline_order = 3

in_features, out_features = 3, 5
batch_size = 300
grid_eps = 1.0

# Total number of splines in the layer (one per input/output pair)
size = in_features*out_features
grid_size = 10
grid_range = [-1,1]
h = (grid_range[1] - grid_range[0]) / (grid_size-1)

# Knot indices -spline_order ... grid_size + spline_order (inclusive), scaled by h and
# shifted so that index 0 sits on grid_range[0]; the same row is shared by every input feature
knot_index = torch.arange(-spline_order, grid_size + spline_order + 1)
grid = (knot_index * h + grid_range[0]).expand(in_features, -1)


# Fixed seed, so that the sampled batch and the spline weights are reproducible
torch.manual_seed(42)
x = torch.randn((batch_size,in_features))
spline_weight = torch.randn((out_features, in_features, grid_size + spline_order))

In [32]:
batch = x.shape[0]

# Replicate every input column once per output node, so that each of the `size` splines gets
# its own row of samples: (batch, in_features) -> (batch, size) -> (size, batch).
# Note that this overwrites x, so the cell cannot be re-run without re-creating x first.
ones_out = torch.ones(out_features, )
x = torch.einsum('ij,k->ikj', x, ones_out).reshape(batch, size).T
print(x.shape)
# Sort on the batch dimension
x_pos, _ = torch.sort(x, dim=1)

###y_eval = coef2curve(x_pos, self.grid, self.coef, self.k, device=self.device)

# Number of intervals of the current grid (its number of columns minus one)
# NOTE(review): `grid` is the already augmented grid of the previous cell, so the padded
# intervals are counted as well - confirm that this is intended
n_knots = grid.shape[1]
num_interval = n_knots - 1

# Adaptive grid: evenly spaced ranks of the sorted batch, plus the largest sample
ids = [int(batch / num_interval * i) for i in range(num_interval)]
ids.append(-1)
grid_adaptive = x_pos[:, ids]

margin = 0.01
grid_eps  = 0.05

# Uniform grid: n_knots equally spaced knots covering [min - margin, max + margin] of every row
start = grid_adaptive[:, [0]] - margin
span = grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin
fractions = np.linspace(0, 1, num=n_knots)
grid_uniform = torch.cat([start + span * a for a in fractions], dim=1)

# Blend the two grids: grid_eps = 1 is fully uniform, grid_eps = 0 is fully adaptive
new_grid = grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive

# We must also augment it


torch.Size([15, 300])


In [37]:
grid

tensor([[-1.6667, -1.4444, -1.2222, -1.0000, -0.7778, -0.5556, -0.3333, -0.1111,
          0.1111,  0.3333,  0.5556,  0.7778,  1.0000,  1.2222,  1.4444,  1.6667,
          1.8889],
        [-1.6667, -1.4444, -1.2222, -1.0000, -0.7778, -0.5556, -0.3333, -0.1111,
          0.1111,  0.3333,  0.5556,  0.7778,  1.0000,  1.2222,  1.4444,  1.6667,
          1.8889],
        [-1.6667, -1.4444, -1.2222, -1.0000, -0.7778, -0.5556, -0.3333, -0.1111,
          0.1111,  0.3333,  0.5556,  0.7778,  1.0000,  1.2222,  1.4444,  1.6667,
          1.8889]])

In [40]:
x_pos[:, ids]

tensor([[-2.5095, -1.4861, -1.0020, -0.8212, -0.6902, -0.4949, -0.3389, -0.1722,
          0.0439,  0.1852,  0.3444,  0.4967,  0.7448,  0.9595,  1.2791,  1.7733,
          2.8544],
        [-3.1016, -1.6047, -1.2524, -0.9216, -0.6766, -0.5161, -0.3858, -0.1418,
         -0.0305,  0.0950,  0.2762,  0.4880,  0.6378,  0.8539,  1.1799,  1.4703,
          2.7312],
        [-2.7936, -1.4534, -1.1702, -0.9061, -0.6185, -0.4292, -0.2076, -0.0271,
          0.0803,  0.2320,  0.3892,  0.5636,  0.7118,  0.8436,  1.1914,  1.6192,
          3.0250],
        [-2.5095, -1.4861, -1.0020, -0.8212, -0.6902, -0.4949, -0.3389, -0.1722,
          0.0439,  0.1852,  0.3444,  0.4967,  0.7448,  0.9595,  1.2791,  1.7733,
          2.8544],
        [-3.1016, -1.6047, -1.2524, -0.9216, -0.6766, -0.5161, -0.3858, -0.1418,
         -0.0305,  0.0950,  0.2762,  0.4880,  0.6378,  0.8539,  1.1799,  1.4703,
          2.7312],
        [-2.7936, -1.4534, -1.1702, -0.9061, -0.6185, -0.4292, -0.2076, -0.0271,
          0.08

## Personal grid update tests

In [1]:
import jax
import jax.numpy as jnp
import math
from flax import linen as nn
from flax.training import train_state
from flax.linen import initializers

class KANLayer(nn.Module):
    """Prototype KAN layer used to test how a grid can be stored and updated in Flax.

    The layer keeps one augmented knot vector per spline (``n_in * n_out`` rows) in the
    non-trainable ``'state'`` collection, and registers the spline coefficients and the
    optional scaling factors as trainable parameters. ``__call__`` and ``update_grid`` are
    placeholders whose only purpose is to make a grid update observable from the outside.

    Attributes:
        n_in: number of the layer's input nodes.
        n_out: number of the layer's output nodes.
        G: grid size used to build the initial knot vector.
        grid_range: (min, max) end values of the initial grid.
        k: spline order; also the number of knots appended below the grid.
        const_spl: a float fixes the spline scaling factors to that value (non-trainable),
            ``False`` registers them as trainable parameters.
        const_res: same as ``const_spl``, for the residual scaling factors.
        residual: residual activation function; not used by the placeholder forward pass.
        noise_std: standard deviation used to initialize the spline coefficients.
        grid_e: stored on the layer but not read by any method of this prototype.
    """

    n_in: int = 2
    n_out: int = 5
    G: int = 5
    grid_range: tuple = (-1, 1)
    k: int = 3

    # NOTE(review): `float or bool` evaluates to plain `float`; the fields are meant to accept
    # either a float or False. Likewise nn.swish is a function rather than an nn.Module.
    const_spl: float or bool = False
    const_res: float or bool = False
    residual: nn.Module = nn.swish

    noise_std: float = 0.1
    grid_e: float = 0.02

    def setup(self):
        """Create the grid state and the trainable parameters of the layer."""
        n_splines = self.n_in * self.n_out

        # Calculate the step size for the knot vector based on its end values
        # NOTE(review): the step uses G - 1 intervals while the indices below run up to G + k,
        # which leaves k knots below grid_range but k + 1 above it - confirm the convention
        h = (self.grid_range[1] - self.grid_range[0]) / (self.G - 1)

        # Create the initial knot vector and perform augmentation: the indices -k ... G + k
        # give G + 2k + 1 knots in total
        grid = jnp.arange(-self.k, self.G + self.k + 1, dtype=jnp.float32) * h + self.grid_range[0]

        # Expand for broadcasting - the shape becomes (n_in*n_out, G + 2k + 1), so that the grid
        # can be passed in all n_in*n_out spline basis functions simultaneously
        grid = jnp.expand_dims(grid, axis=0)
        grid = jnp.tile(grid, (n_splines, 1))

        # Store the grid as a non trainable variable
        self.grid = self.variable('state', 'grid', lambda: grid)

        # Register & initialize the spline basis functions' coefficients as trainable parameters
        # They are drawn from a normal distribution with zero mean and an std of noise_std
        # G is a plain int field (not a flax variable), so it is used directly - accessing
        # `.value` on it raises an AttributeError
        self.c_basis = self.param(
            'c_basis',
            initializers.normal(stddev=self.noise_std),
            (n_splines, self.G + self.k),
        )

        # If const_spl is set as a float value, treat it as non trainable and pass it to the
        # c_spl array with shape (n_in*n_out)
        # Otherwise register it as a trainable parameter of the same size and initialize it
        # Any other value (e.g. True or an int) leaves c_spl undefined
        if isinstance(self.const_spl, float):
            self.c_spl = jnp.ones(n_splines) * self.const_spl
        elif self.const_spl is False:
            self.c_spl = self.param('c_spl', initializers.constant(1.0), (n_splines,))

        # If const_res is set as a float value, treat it as non trainable and pass it to the
        # c_res array with shape (n_in*n_out)
        # Otherwise register it as a trainable parameter of the same size and initialize it
        # Any other value (e.g. True or an int) leaves c_res undefined
        if isinstance(self.const_res, float):
            self.c_res = jnp.ones(n_splines) * self.const_res
        elif self.const_res is False:
            self.c_res = self.param('c_res', initializers.constant(1.0), (n_splines,))

    def update_grid(self):
        """Placeholder grid update: scale every knot by 2.25.

        The 'state' collection must be mutable when this method is applied.
        """
        # Example update function
        self.grid.value = self.grid.value*2.25

    def __call__(self, x):
        """Placeholder forward pass: scale x by the first knot of the first grid row."""
        # Dummy to see correct grid update
        return x*self.grid.value[0][0]


In [3]:
# Instantiate a layer
kan_layer = KANLayer(n_in=2, n_out=3, G=5, grid_range=(-1,1), k=3)

# Initialization - returns the layer's variable collections ('params' and 'state'),
# which are printed below
rng = jax.random.PRNGKey(0)
variables = kan_layer.init(rng, jnp.ones((2, 3)))

# Print the trainable parameters and non-trainable state
print("Trainable parameters:")
print(variables['params'])

print("\nNon-trainable state variables:")
print(variables['state'])

Trainable parameters:
{'c_basis': Array([[ 0.02399894,  0.1846351 , -0.00519568, -0.00863923,  0.0585282 ,
        -0.15037541, -0.01022594, -0.02914488],
       [ 0.04935711,  0.03238467, -0.15360688,  0.12719317,  0.04799316,
         0.05530524, -0.02173939,  0.19702375],
       [-0.06749373,  0.0200477 ,  0.23864464, -0.1581935 ,  0.0072664 ,
        -0.10175776,  0.00515934,  0.01749068],
       [-0.00339971, -0.09349852,  0.07724477,  0.04379462,  0.09954414,
         0.04300999,  0.02973044, -0.01073924],
       [ 0.05613207, -0.06796809,  0.07379916, -0.0517869 , -0.0697948 ,
         0.00100837,  0.05920225, -0.07538771],
       [-0.01342505, -0.02973855, -0.00334156, -0.09072405,  0.01592094,
        -0.04852676,  0.02748392, -0.06365392]], dtype=float32), 'c_spl': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'c_res': Array([1., 1., 1., 1., 1., 1.], dtype=float32)}

Non-trainable state variables:
{'grid': Array([[-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. , 

In [4]:
# Simple forward pass
# NOTE(review): __call__ returns an array with the same (2, 3) shape as the input, so the
# tuple unpacking below splits it into its two rows: `x` is only the first row (see the
# printed output) and `_` is the second one. Presumably `x = kan_layer.apply(...)` was meant.
x, _ = kan_layer.apply(variables, jnp.ones((2,3)))

print(x)

[-2.5 -2.5 -2.5]


In [5]:
# No forward pass, but applying the update_grid function
# With mutable=['state'] the call returns a pair: the method's return value (update_grid
# returns nothing) and the updated 'state' collection, which is printed below
x, new_state = kan_layer.apply(variables, method=kan_layer.update_grid, mutable=['state'])

print(new_state)

{'state': {'grid': Array([[-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ]], dtype=float32)}}


In [6]:
# Update variables, pass to the model and retry forward pass with new grid
variables['state'] = new_state['state']

# NOTE(review): no `mutable` argument is passed here and the printed `x` is a single row,
# so `state` appears to hold the second row of the (2, 3) output rather than a state
# dictionary - confirm and drop/rename it
x, state = kan_layer.apply(variables, jnp.ones((2, 3)))

print(x)

[-5.625 -5.625 -5.625]


In [7]:
# Apply the update again - it starts from the grid stored in `variables`, i.e. the
# already updated one, so the knots are scaled by 2.25 a second time
x, new_state = kan_layer.apply(variables, method=kan_layer.update_grid, mutable=['state'])

print(new_state)

{'state': {'grid': Array([[-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ]],      dtype=float32)}}


In [8]:
# What happens if we do not pass mutable=['state'] ?
# The call raises before anything is returned, because update_grid assigns to a variable
# of the 'state' collection, which is immutable here (see the error below)
x, new_state = kan_layer.apply(variables, method=kan_layer.update_grid)

ModifyScopeVariableError: Cannot update variable "grid" in "/" because collection "state" is immutable. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ModifyScopeVariableError)

### A very rough example on how this would look during training

```
import functools

import optax  # For the optimizer
from flax.training import train_state


class TrainState(train_state.TrainState):
    """Train state that also carries the layer's non-trainable grid."""
    grid: jnp.ndarray


def create_train_state(rng, model, learning_rate):
    """Initialize `model` and wrap its parameters, optimizer and grid in a TrainState."""
    # init() has to create every collection, so no `mutable` filter is passed to it
    variables = model.init(rng, jnp.ones((1, 2)))
    params = variables['params']
    # KANLayer registers its grid in the 'state' collection
    grid = variables['state']['grid']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx, grid=grid)


@jax.jit
def train_step(state, batch):
    """Run one optimization step on `batch` and return the updated state."""
    def loss_fn(params):
        # The forward pass reads the grid, so it must be passed along with the parameters
        variables = {'params': params, 'state': {'grid': state.grid}}
        preds = state.apply_fn(variables, batch['x'])
        return jnp.mean((preds - batch['y']) ** 2)

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)


# The model is not an array, so it is marked as a static argument for jit
@functools.partial(jax.jit, static_argnames=('model',))
def update_grid_state(state, model):
    """Apply the model's update_grid method and store the new grid in the state."""
    variables = {'params': state.params, 'state': {'grid': state.grid}}
    # update_grid takes no inputs; 'state' must be mutable for the grid to be overwritten
    _, updated_state = model.apply(variables, method=model.update_grid, mutable=['state'])
    return state.replace(grid=updated_state['state']['grid'])


def train_model(state, model, epochs, data_loader):
    """Train for `epochs` epochs, updating the grid after every fifth epoch."""
    for epoch in range(epochs):
        for batch in data_loader:
            state = train_step(state, batch)

        if (epoch + 1) % 5 == 0:
            state = update_grid_state(state, model)
            print(f"Grid updated at epoch {epoch + 1}")

    return state


# Example data loader
data_loader = [{'x': jnp.ones((1, 2)), 'y': jnp.ones((1, 2))} for _ in range(10)]

# Initialize model and state
rng = jax.random.PRNGKey(0)
kan_layer = KANLayer(n_in=2, n_out=3, G=5, grid_range=(-1, 1), k=3)
state = create_train_state(rng, kan_layer, learning_rate=0.001)

# Train the model
state = train_model(state, kan_layer, epochs=20, data_loader=data_loader)

# Check final grid value
print("Final Grid shape:", state.grid.shape)
print(state.grid)
```