# Tests on Grid generation & extension

What follows are JAX experiments, modelled on pykan and efficientkan, that reproduce the grid generation and grid extension steps in isolation, in order to better understand what each step does.

In [4]:
import jax.numpy as jnp

In [2]:
# Input parameters relevant to the grid
k = 3 # spline order
G = 5 # grid size: with the step h = (b - a) / (G - 1) used below, G evenly spaced knots cover grid_range
n_in = 2 # number of the layer's input nodes
n_out = 3 # number of the layer's output nodes

**<span style="color:red">Open Question: Why are we starting with a grid defined only by its range?</span>**

In [3]:
# Grid Initialization - the grid is originally just a knot vector
grid_range = [-1, 1]

**<span style="color:red">Open Question: Why are we augmenting the grid like this instead of appending the first and last entries k times at the start and end?</span>**

In [4]:
# Grid augmentation: build G + 2k + 1 equally spaced knots with step h,
# starting k steps below grid_range[0].
# NOTE(review): because h divides by (G - 1), the knot with index G - 1 already
# sits on grid_range[-1]; the knots therefore run k steps below the range but
# k + 1 steps above it (see the output: -2.5 ... 3.0 around [-1, 1]).
# Dividing by G instead would make the extension symmetric - confirm which is intended.
h = (grid_range[-1] - grid_range[0]) / (G-1)
grid = (jnp.arange(-k, G + k + 1, dtype=jnp.float32) * h + grid_range[0])

grid

Array([-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
        3. ], dtype=float32)

In [5]:
# Expand for broadcasting - the shape becomes (n_in*n_out, G + 2k + 1), i.e. one copy of the
# knot vector per spline of the layer (n_in * n_out in total), so that the grid
# can be passed in get_spline_basis and return a Bi(x) array, where we need only multiply
# with the coefficients to get the full spline(x).
grid = jnp.expand_dims(grid, axis=0)
grid = jnp.tile(grid, (n_in*n_out, 1))

grid

Array([[-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ]], dtype=float32)

This is the initialization of the grid per layer (registered at `self.grid`). Then, we have the `update_grid` routine, which is supposed to perform an update for each layer's grid (**can be trainable or not in the future**).

In pykan, the `update_grid_from_samples` routine simply performs the mix of adaptive and uniform sampling, given the current grid and the batch. It is the `initialize_grid_from_parent` routine that performs the grid extension. On the other hand, in efficientkan both are handled together in the `update_grid` routine. In what follows, we try to follow the approach of efficientkan, however working with grids of shape $(n_{in}\cdot n_{out},\, G+2k+1)$ instead of $(n_{in},\, G+2k+1)$.

**<span style="color:red">IDEA: Keep splines as model param function, same as efficientkan, because the grid is also. Then $k$ can be inherited from self too, and the whole thing can normally be jitted.</span>** 

## efficientkan grid update

TL;DR: It appears that efficientkan does not really extend its grid: the number of grid points stays the same, and only their positions are re-fitted to the data.

In [1]:
import torch

from copy import copy

In [5]:
# Hyper-parameters and dummy data for reproducing efficientkan's grid update.
spline_order = 3

in_features = 2
out_features = 5
batch_size = 300
grid_eps = 1.0  # weight of the uniform grid in the uniform/adaptive blend used further below

grid_size = 10
grid_range = [-1,1]
# NOTE(review): dividing by (grid_size - 1) puts the knot with index grid_size - 1 on
# grid_range[1], so the extension is spline_order knots below the range but
# spline_order + 1 knots above it; dividing by grid_size would make it symmetric - confirm.
h = (grid_range[1] - grid_range[0]) / (grid_size-1)

# One identical knot vector per input feature: shape (in_features, grid_size + 2 * spline_order + 1)
grid = (torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0]).expand(in_features, -1)


# Fixed seed so that the random inputs and coefficients are reproducible
torch.manual_seed(42)
x = torch.randn((batch_size,in_features))
# Random spline coefficients: grid_size + spline_order coefficients per (output, input) pair
spline_weight = torch.randn((out_features, in_features, grid_size + spline_order))


In [6]:
def b_splines(x: torch.Tensor):
    """
    Evaluate every B-spline basis function at the given inputs.

    The knot vectors are read from the module-level ``grid`` (one row per
    input feature) and the spline order from the module-level
    ``spline_order``; neither is passed as an argument.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features).

    Returns:
        torch.Tensor: Basis values of shape
            (batch_size, in_features, n_knots - 1 - spline_order), where
            n_knots is the number of columns of ``grid``.
    """
    pts = x.unsqueeze(-1)

    # Order-0 bases: indicator of the half-open knot interval containing each point
    interval_start = grid[:, :-1]
    interval_end = grid[:, 1:]
    basis = ((pts >= interval_start) & (pts < interval_end)).to(pts.dtype)

    # Raise the order one step at a time by blending neighbouring bases
    for degree in range(1, spline_order + 1):
        left_knots = grid[:, : -(degree + 1)]
        right_knots = grid[:, degree + 1 :]

        left_weight = (pts - left_knots) / (grid[:, degree:-1] - left_knots)
        right_weight = (right_knots - pts) / (right_knots - grid[:, 1:(-degree)])

        basis = (left_weight * basis[:, :, :-1]) + (right_weight * basis[:, :, 1:])

    return basis

def curve2coeff(x: torch.Tensor, y: torch.Tensor):
    """
    Fit spline coefficients to sampled curve values by least squares.

    For every input feature, the B-spline bases evaluated at ``x`` (using the
    current module-level grid, via ``b_splines``) form the design matrix and
    ``y`` provides the targets.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, in_features).
        y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

    Returns:
        torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
    """
    # Put the feature axis first so that lstsq solves one system per input feature
    design = b_splines(x).transpose(0, 1)  # (in_features, batch_size, grid_size + spline_order)
    targets = y.transpose(0, 1)  # (in_features, batch_size, out_features)

    fit = torch.linalg.lstsq(design, targets)

    # (in_features, grid_size + spline_order, out_features)
    #   -> (out_features, in_features, grid_size + spline_order)
    return fit.solution.permute(2, 0, 1)

In [10]:
# Re-fit the grid to the batch x and recompute the spline coefficients so that the
# splines on the new grid reproduce the outputs obtained on the old grid.
# Extra padding added on both sides of the data range when building the uniform grid
margin = 0.01

# Get a copy of the original grid (not used again in this cell)
old_grid = copy(grid)
#print(f"Old grid shape: {old_grid.shape}\n")

# -------------------------------------
# Things done to calculate Sum(ciBi)
# -------------------------------------
batch = x.size(0)

# Gets Sum(ciBi(x)) for old grid
splines = b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
orig_coeff = spline_weight  # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
# Batched matmul over the input-feature axis: one (batch, coeff) @ (coeff, out) product per feature
unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
    1, 0, 2
)  # (batch, in, out)

#print(f"Shape of Sum(ciBi), i.e. y: {unreduced_spline_output.shape}\n")

# ----------------------
# Calculate a new grid
# ----------------------
# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]  # [0] keeps the sorted values and drops the indices; (batch, in)
# Adaptive grid: grid_size + 1 samples taken at evenly spaced ranks of the sorted data, (grid_size + 1, in)
grid_adaptive = x_sorted[
    torch.linspace(
        0, batch - 1, grid_size + 1, dtype=torch.int64
    )
]

# Uniform grid: grid_size equal steps from (min - margin) to (max + margin), per input feature
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / grid_size  # (in,)
grid_uniform = (
    torch.arange(
        grid_size + 1, dtype=torch.float32
    ).unsqueeze(1)
    * uniform_step
    + x_sorted[0]
    - margin
)  # (grid_size + 1, in)

# Blend of the two grids; with grid_eps = 1.0 (as set above) the adaptive grid has zero weight
grid = grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive

#print(f"Grid shape after performing interpolation between adaptive and uniform: {grid.shape}\n")

# Augmentation: prepend and append spline_order extra knots, spaced by uniform_step,
# giving shape (grid_size + 2 * spline_order + 1, in) - the same knot count as the initial grid
grid = torch.concatenate(
    [
        grid[:1]
        - uniform_step
        * torch.arange(spline_order, 0, -1).unsqueeze(1),
        grid,
        grid[-1:]
        + uniform_step
        * torch.arange(1, spline_order + 1).unsqueeze(1),
    ],
    dim=0,
)

#print(f"Grid shape after concatenation: {grid.shape}\n")

# --------------------------------
# Use new grid to calculate cj
# --------------------------------

# Back to the (in, knots) layout that b_splines indexes; since b_splines reads the
# global `grid`, rebinding it here is what makes the next call use the new grid
grid = grid.T
#print(f"Grid shape after transposition: {grid.shape}\n")

# curve2coeff includes b_spline(x), so we get the new
# Bj(x) back and then pass it to lstsq
ccoeff = curve2coeff(x, unreduced_spline_output)

#print(f"Shape of new coeffs: {ccoeff.shape}\n")

torch.Size([300, 2])


AttributeError: 'torch.return_types.sort' object has no attribute 'shape'

In [11]:
torch.sort(x,dim=0)

torch.return_types.sort(
values=tensor([[-2.5850e+00, -2.6475e+00],
        [-2.5095e+00, -2.5668e+00],
        [-2.4801e+00, -2.4885e+00],
        [-2.3065e+00, -2.3184e+00],
        [-2.2064e+00, -2.1268e+00],
        [-2.0717e+00, -2.1055e+00],
        [-2.0487e+00, -1.6457e+00],
        [-1.9776e+00, -1.6107e+00],
        [-1.9267e+00, -1.6047e+00],
        [-1.9153e+00, -1.5910e+00],
        [-1.9006e+00, -1.5824e+00],
        [-1.8058e+00, -1.5013e+00],
        [-1.7907e+00, -1.4534e+00],
        [-1.7899e+00, -1.4040e+00],
        [-1.7809e+00, -1.4036e+00],
        [-1.7735e+00, -1.3246e+00],
        [-1.7376e+00, -1.3129e+00],
        [-1.7237e+00, -1.2922e+00],
        [-1.7223e+00, -1.2869e+00],
        [-1.6470e+00, -1.2531e+00],
        [-1.6022e+00, -1.2415e+00],
        [-1.5988e+00, -1.2345e+00],
        [-1.5576e+00, -1.1808e+00],
        [-1.5469e+00, -1.1702e+00],
        [-1.5072e+00, -1.1426e+00],
        [-1.4790e+00, -1.1299e+00],
        [-1.4575e+00, -1.1209e+0

In [13]:
grid

tensor([[-5.8583, -4.7705, -3.6827, -2.5950, -1.5072, -0.4194,  0.6684,  1.7562,
          2.8440,  3.9318,  5.0195,  6.1073],
        [-5.8313, -4.7733, -3.7154, -2.6575, -1.5996, -0.5416,  0.5163,  1.5742,
          2.6321,  3.6901,  4.7480,  5.8059]])

## pykan grid update

TLDR; TODO

## Personal grid update tests

In [1]:
import math
from typing import Callable, Union

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import initializers
from flax.training import train_state

class KANLayer(nn.Module):
    """Minimal KAN layer used to test how the grid is stored and updated in Flax.

    The knot grid lives in the non-trainable ``'state'`` collection, while the
    spline coefficients (and optionally the scaling factors) are trainable
    parameters. The forward pass is a dummy that only exposes the grid value.

    Attributes:
        n_in: Number of the layer's input nodes.
        n_out: Number of the layer's output nodes.
        G: Grid size; the augmented knot vector holds G + 2k + 1 knots.
        grid_range: (lower, upper) range used to build the initial knots.
        k: Spline order.
        const_spl: A float fixes ``c_spl`` to that constant (non trainable);
            ``False`` registers ``c_spl`` as a trainable parameter.
        const_res: Same as ``const_spl``, but for ``c_res``.
        residual: Residual activation; stored but not used by this test layer.
        noise_std: Std of the normal initializer of the spline coefficients.
        grid_e: Stored but not used by this test layer.
    """

    n_in: int = 2
    n_out: int = 5
    G: int = 5
    grid_range: tuple = (-1, 1)
    k: int = 3

    # `float or bool` evaluates to plain `float`; Union states the intended type
    const_spl: Union[float, bool] = False
    const_res: Union[float, bool] = False
    # nn.swish is a plain function, not an nn.Module instance
    residual: Callable = nn.swish

    noise_std: float = 0.1
    grid_e: float = 0.02

    def setup(self):
        """Create the grid state variable and the layer's parameters."""
        # Calculate the step size for the knot vector based on its end values
        # NOTE(review): dividing by (G - 1) puts the knot with index G - 1 on
        # grid_range[1], so the augmented grid reaches k steps below the range
        # but k + 1 steps above it; dividing by G would be symmetric - confirm.
        h = (self.grid_range[1] - self.grid_range[0]) / (self.G - 1)

        # Create the initial knot vector and perform augmentation
        # Now it is expanded from G+1 points to G+1 + 2k points, because k points are appended at each of its ends
        grid = jnp.arange(-self.k, self.G + self.k + 1, dtype=jnp.float32) * h + self.grid_range[0]

        # Expand for broadcasting - the shape becomes (n_in*n_out, G + 2k + 1), so that the grid
        # can be passed in all n_in*n_out spline basis functions simultaneously
        grid = jnp.expand_dims(grid, axis=0)
        grid = jnp.tile(grid, (self.n_in * self.n_out, 1))

        # Store the grid as a non trainable variable
        self.grid = self.variable('state', 'grid', lambda: grid)

        # Register & initialize the spline basis functions' coefficients as trainable parameters
        # They are drawn from a normal distribution with zero mean and an std of noise_std
        # G is a plain int field (not a Flax variable), so it has no `.value` attribute
        self.c_basis = self.param(
            'c_basis',
            initializers.normal(stddev=self.noise_std),
            (self.n_in * self.n_out, self.G + self.k),
        )

        # If const_spl is set as a float value, treat it as non trainable and pass it to the c_spl array with shape (n_in*n_out)
        # If it is False, register c_spl as a trainable parameter of the same size and initialize it
        # NOTE: any other value (e.g. True or an int) matches neither branch and leaves c_spl unset
        if isinstance(self.const_spl, float):
            self.c_spl = jnp.ones(self.n_in * self.n_out) * self.const_spl
        elif self.const_spl is False:
            self.c_spl = self.param('c_spl', initializers.constant(1.0), (self.n_in * self.n_out,))

        # If const_res is set as a float value, treat it as non trainable and pass it to the c_res array with shape (n_in*n_out)
        # If it is False, register c_res as a trainable parameter of the same size and initialize it
        # NOTE: any other value (e.g. True or an int) matches neither branch and leaves c_res unset
        if isinstance(self.const_res, float):
            self.c_res = jnp.ones(self.n_in * self.n_out) * self.const_res
        elif self.const_res is False:
            self.c_res = self.param('c_res', initializers.constant(1.0), (self.n_in * self.n_out,))

    def update_grid(self):
        """Example update: scale every knot by 2.25.

        Writes to the ``'state'`` collection, so it must be applied with
        ``mutable=['state']``.
        """
        # Example update function
        self.grid.value = self.grid.value * 2.25

    def __call__(self, x):
        """Dummy forward pass: scale ``x`` by the first knot of the first grid row."""
        # Dummy to see correct grid update
        return x * self.grid.value[0][0]


In [3]:
# Instantiate a layer
kan_layer = KANLayer(n_in=2, n_out=3, G=5, grid_range=(-1,1), k=3)

# Initialization
rng = jax.random.PRNGKey(0)
variables = kan_layer.init(rng, jnp.ones((2, 3)))

# Print the trainable parameters and non-trainable state
print("Trainable parameters:")
print(variables['params'])

print("\nNon-trainable state variables:")
print(variables['state'])

Trainable parameters:
{'c_basis': Array([[ 0.02399894,  0.1846351 , -0.00519568, -0.00863923,  0.0585282 ,
        -0.15037541, -0.01022594, -0.02914488],
       [ 0.04935711,  0.03238467, -0.15360688,  0.12719317,  0.04799316,
         0.05530524, -0.02173939,  0.19702375],
       [-0.06749373,  0.0200477 ,  0.23864464, -0.1581935 ,  0.0072664 ,
        -0.10175776,  0.00515934,  0.01749068],
       [-0.00339971, -0.09349852,  0.07724477,  0.04379462,  0.09954414,
         0.04300999,  0.02973044, -0.01073924],
       [ 0.05613207, -0.06796809,  0.07379916, -0.0517869 , -0.0697948 ,
         0.00100837,  0.05920225, -0.07538771],
       [-0.01342505, -0.02973855, -0.00334156, -0.09072405,  0.01592094,
        -0.04852676,  0.02748392, -0.06365392]], dtype=float32), 'c_spl': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'c_res': Array([1., 1., 1., 1., 1., 1.], dtype=float32)}

Non-trainable state variables:
{'grid': Array([[-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. , 

In [4]:
# Simple forward pass
# NOTE(review): without `mutable`, apply() returns only the layer output, so the
# tuple unpacking below splits the (2, 3) result along its first axis and `x` is
# just its first row (hence the 1-D printout). It only works because the batch
# size happens to be 2; `x = kan_layer.apply(...)` would keep the full output.
x, _ = kan_layer.apply(variables, jnp.ones((2,3)))

print(x)

[-2.5 -2.5 -2.5]


In [5]:
# No forward pass, but applying the update_grid function
x, new_state = kan_layer.apply(variables, method=kan_layer.update_grid, mutable=['state'])

print(new_state)

{'state': {'grid': Array([[-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ],
       [-5.625, -4.5  , -3.375, -2.25 , -1.125,  0.   ,  1.125,  2.25 ,
         3.375,  4.5  ,  5.625,  6.75 ]], dtype=float32)}}


In [6]:
# Update variables, pass to the model and retry forward pass with new grid
variables['state'] = new_state['state']

# NOTE(review): apply() is called without `mutable`, so it returns only the output;
# `x, state = ...` unpacks the two rows of the (2, 3) result, and `state` is the
# second output row rather than a variable collection.
x, state = kan_layer.apply(variables, jnp.ones((2, 3)))

print(x)

[-5.625 -5.625 -5.625]


In [7]:
# Apply the update again
x, new_state = kan_layer.apply(variables, method=kan_layer.update_grid, mutable=['state'])

print(new_state)

{'state': {'grid': Array([[-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ],
       [-12.65625, -10.125  ,  -7.59375,  -5.0625 ,  -2.53125,   0.     ,
          2.53125,   5.0625 ,   7.59375,  10.125  ,  12.65625,  15.1875 ]],      dtype=float32)}}


In [8]:
# What happens if we do not pass mutable=['state'] ?
x, new_state = kan_layer.apply(variables, method=kan_layer.update_grid)

ModifyScopeVariableError: Cannot update variable "grid" in "/" because collection "state" is immutable. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ModifyScopeVariableError)

### A very rough example on how this would look during training

```
import optax  # For the optimizer
from flax.training import train_state

class TrainState(train_state.TrainState):
    """Flax TrainState extended with a `grid` field that carries the layer's knot grid."""
    grid: jnp.ndarray

def create_train_state(rng, model, learning_rate):
    """Initialize the model and wrap its parameters and grid in a TrainState.

    Args:
        rng: PRNG key used for parameter initialization.
        model: The KANLayer instance to initialize.
        learning_rate: Learning rate of the Adam optimizer.

    Returns:
        A TrainState holding the trainable parameters, the optimizer state and
        the grid taken from the layer's 'state' collection.
    """
    # init() must be allowed to create every collection ('params' and 'state'),
    # so `mutable` is left at its default, as in the initialization cell above.
    variables = model.init(rng, jnp.ones((1, 2)))
    params = variables['params']
    # KANLayer registers the grid in the 'state' collection
    grid = variables['state']['grid']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx, grid=grid)

def train_model(state, model, epochs, data_loader):
    """Run the training loop, refreshing the grid after every fifth epoch.

    Args:
        state: Initial train state.
        model: The model whose grid gets updated.
        epochs: Number of passes over data_loader.
        data_loader: Iterable of batches passed to train_step.

    Returns:
        The train state after the last epoch.
    """
    grid_update_period = 5

    for epoch in range(epochs):
        # One optimization step per batch
        for batch in data_loader:
            state = train_step(state, batch)

        # Epochs are counted from 1 when deciding whether a grid update is due
        grid_update_due = (epoch + 1) % grid_update_period == 0
        if not grid_update_due:
            continue

        state = update_grid_state(state, model)
        print(f"Grid updated at epoch {epoch + 1}")

    return state

@jax.jit
def train_step(state, batch):
    """Perform one gradient step on the mean squared error of a batch.

    Args:
        state: Current train state (parameters, optimizer state and grid).
        batch: Dict with the inputs under 'x' and the targets under 'y'.

    Returns:
        The train state after applying the gradients.
    """
    def loss_fn(params):
        # The layer reads its grid from the 'state' collection, so the grid has
        # to be passed along with the parameters on every forward pass.
        variables = {'params': params, 'state': {'grid': state.grid}}
        logits = state.apply_fn(variables, batch['x'])
        return jnp.mean((logits - batch['y']) ** 2)

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

# Not jitted: `model` is a Module instance, which jax.jit cannot trace as an array
# argument, and this only runs once every few epochs.
def update_grid_state(state, model):
    """Run the layer's update_grid method and store the new grid in the train state.

    Args:
        state: Current train state holding the parameters and the grid.
        model: The KANLayer instance whose update_grid method is applied.

    Returns:
        A copy of the train state with its grid replaced by the updated one.
    """
    # KANLayer keeps its grid in the 'state' collection
    variables = {'params': state.params, 'state': {'grid': state.grid}}
    # update_grid takes no inputs; 'state' must be mutable for the grid to be rewritten
    _, updated_state = model.apply(variables, method=model.update_grid, mutable=['state'])
    return state.replace(grid=updated_state['state']['grid'])

# Example data loader: ten identical batches, each with inputs 'x' and targets 'y' of shape (1, 2)
data_loader = [{'x': jnp.ones((1, 2)), 'y': jnp.ones((1, 2))} for _ in range(10)]

# Initialize model and state
rng = jax.random.PRNGKey(0)
kan_layer = KANLayer(n_in=2, n_out=3, G=5, grid_range=(-1, 1), k=3)
state = create_train_state(rng, kan_layer, learning_rate=0.001)

# Train the model (with epochs=20 the grid is updated after epochs 5, 10, 15 and 20)
state = train_model(state, kan_layer, epochs=20, data_loader=data_loader)

# Check final grid value
print("Final Grid shape:", state.grid.shape)
print(state.grid)
```