# Tests on Grid generation & extension

What follows are JAX samples, based on pykan and efficientkan, that walk through the grid generation and extension process to get a better understanding of what is going on.

In [1]:
import jax.numpy as jnp

In [2]:
# Input parameters relevant to grid
k = 3 # spline order
G = 5 # grid size (number of knot vector elements spanning grid_range, before augmentation)
n_in = 2 # number of layer's input nodes
n_out = 3 # number of layer's output nodes

**<span style="color:red">Open Question: Why are we starting with a grid defined only by its range?</span>**

In [3]:
# Grid Initialization - the grid is originally just a knot vector
grid_range = [-1, 1]

**<span style="color:red">Open Question: Why are we augmenting the grid like this instead of appending the first and last entries k times at the start and end?</span>**

In [4]:
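# Grid augmentation - extend the knot vector by uniform steps of size h beyond grid_range.
# Note: h treats G as the number of points on grid_range (G - 1 intervals), while the
# arange below produces G + 2k + 1 knots (efficientkan's convention, where G counts
# intervals), so the result reaches one extra step on the right (up to 3.0 here).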
h = (grid_range[-1] - grid_range[0]) / (G-1)
grid = (jnp.arange(-k, G + k + 1, dtype=jnp.float32) * h + grid_range[0])

grid

Array([-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
        3. ], dtype=float32)
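The two options in the question above can be put side by side. The following is a minimal sketch, not taken from pykan or efficientkan: it assumes `G` knots on the range and a symmetric extension of `k` knots per side (so `G + 2k` knots in total, unlike the cell above), and the variable names are illustrative. It only compares the knot spans of the two constructions.

```python
import jax.numpy as jnp

k, G = 3, 5
lo, hi = -1.0, 1.0
h = (hi - lo) / (G - 1)

# G knots on [lo, hi]
inner = jnp.linspace(lo, hi, G)

# Option 1: uniform extension, k extra knots with spacing h on each side
uniform_ext = jnp.arange(-k, G + k, dtype=jnp.float32) * h + lo

# Option 2: "clamped" knot vector, first and last entries repeated k times
clamped = jnp.concatenate([jnp.full((k,), lo), inner, jnp.full((k,), hi)])

# Distances between consecutive knots
uniform_spans = jnp.diff(uniform_ext)
clamped_spans = jnp.diff(clamped)

print(uniform_ext.shape, clamped.shape)
print(uniform_spans)
print(clamped_spans)
```

Both vectors have the same length, but the clamped one contains zero-length spans at both ends. In the Cox-de Boor recursion these show up as zero denominators, which need the usual 0/0 := 0 special case, whereas the uniform extension keeps every denominator strictly positive. Whether this is the actual motivation in pykan or efficientkan is not confirmed here.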

In [5]:
# Expand for broadcasting - the shape must be (n_in*n_out, G + 2k + 1), so that the grid
# can be passed to get_spline_basis and return a Bi(x) array, which we need only multiply
# with the coefficients to get the full spline(x), for each of the layer's splines (n_in * n_out total).
grid = jnp.expand_dims(grid, axis=0)
grid = jnp.tile(grid, (n_in*n_out, 1))

grid

Array([[-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ],
       [-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,
         3. ]], dtype=float32)
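The comment in the cell above refers to a `get_spline_basis` routine that is not shown. Below is a minimal sketch of what such a routine could look like, assuming the Cox-de Boor recursion (the same recursion used by efficientkan's `b_splines`) and inputs already arranged as `(n_in*n_out, batch)`. The signature, the input layout and the output shape are assumptions made for illustration.

```python
import jax.numpy as jnp

k, G, n_in, n_out = 3, 5, 2, 3

# Same grid as in the cells above: shape (n_in*n_out, G + 2k + 1)
h = (1 - (-1)) / (G - 1)
grid = jnp.arange(-k, G + k + 1, dtype=jnp.float32) * h + (-1)
grid = jnp.tile(grid[None, :], (n_in * n_out, 1))

def get_spline_basis(x, grid, k):
    """Hypothetical helper: B-spline basis values via Cox-de Boor.

    x:    (n_splines, batch) inputs, one row per spline
    grid: (n_splines, n_knots) augmented knot vectors
    Returns an array of shape (n_splines, batch, n_knots - k - 1).
    """
    x = x[:, :, None]      # (n_splines, batch, 1)
    g = grid[:, None, :]   # (n_splines, 1, n_knots)
    # Order 0: indicator of the knot interval that contains x
    basis = ((x >= g[:, :, :-1]) & (x < g[:, :, 1:])).astype(x.dtype)
    # Raise the order one step at a time
    for p in range(1, k + 1):
        left = (x - g[:, :, :-(p + 1)]) / (g[:, :, p:-1] - g[:, :, :-(p + 1)])
        right = (g[:, :, p + 1:] - x) / (g[:, :, p + 1:] - g[:, :, 1:-p])
        basis = left * basis[:, :, :-1] + right * basis[:, :, 1:]
    return basis

# Toy inputs inside the original grid range [-1, 1)
x = jnp.tile(jnp.linspace(-0.9, 0.9, 7)[None, :], (n_in * n_out, 1))
basis = get_spline_basis(x, grid, k)
print(basis.shape)
```

Multiplying `basis` with a coefficient array of shape `(n_in*n_out, G + k + 1)` along the last axis would then give each spline's value per sample. Inside the original range the basis values of each sample sum to 1, which is a quick sanity check for the recursion.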

This is the initialization of the grid per layer (registered at `self.`). Then, we have the `update_grid` routine, which is meant to update each layer's grid (**it may or may not become trainable in the future**).

In pykan, the `update_grid_from_samples` routine only blends adaptive and uniform sampling, given the current grid and the batch; the grid extension itself is performed by the `initialize_grid_from_parent` routine. In efficientkan, by contrast, both are handled together in the `update_grid` routine. In what follows, we try to follow the approach of efficientkan, but work with grids of shape $(n_{in}\cdot n_{out},\text{batch})$ instead of $(n_{in},\text{batch})$.
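Grid extension, in the pykan sense, amounts to moving to a finer grid and refitting the coefficients so that the spline on the new grid reproduces the one on the old grid. The following is a minimal sketch for a single spline, assuming a least-squares refit via `jnp.linalg.lstsq`. The helper names `make_grid` and `bspline_basis` are hypothetical, and `G` counts intervals here (efficientkan's convention), which is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp

k = 3

def make_grid(G, lo=-1.0, hi=1.0):
    # Assumption: G counts intervals on [lo, hi]; k extra knots on each side
    h = (hi - lo) / G
    return jnp.arange(-k, G + k + 1, dtype=jnp.float32) * h + lo

def bspline_basis(x, grid):
    # Cox-de Boor recursion for one spline: x (batch,), grid (n_knots,)
    x = x[:, None]
    basis = ((x >= grid[:-1]) & (x < grid[1:])).astype(x.dtype)
    for p in range(1, k + 1):
        left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)])
        right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p])
        basis = left * basis[:, :-1] + right * basis[:, 1:]
    return basis  # (batch, n_knots - k - 1)

G_old, G_new = 5, 10
grid_old, grid_new = make_grid(G_old), make_grid(G_new)

# Random coefficients stand in for a trained spline on the coarse grid
coef_old = jax.random.normal(jax.random.PRNGKey(0), (G_old + k,))

# Evaluate the old spline on sample points inside the grid range
x = jnp.linspace(-0.95, 0.95, 200)
y_old = bspline_basis(x, grid_old) @ coef_old

# Refit: find coefficients on the fine grid that reproduce y_old
B_new = bspline_basis(x, grid_new)
coef_new = jnp.linalg.lstsq(B_new, y_old)[0]

max_err = jnp.max(jnp.abs(B_new @ coef_new - y_old))
print(coef_new.shape, max_err)
```

Because the fine grid here contains every knot of the coarse grid inside the range, the refit should be exact up to floating-point error. For a full layer the same fit would be repeated per spline, for example with `jax.vmap`.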

**<span style="color:red">IDEA: Keep the splines as a function of the model parameters, as in efficientkan, since the grid is one as well. Then $k$ can be inherited from self too, and the whole thing can be jitted as usual.</span>**

## Updating the Grid

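Steps (following efficientkan's `update_grid`, reproduced below):

* Evaluate the current splines on the batch and keep the per-input outputs, which are needed to refit the coefficients at the end.
* Sort each input channel and pick `grid_size + 1` evenly spaced samples from it (adaptive grid).
* Build a uniform grid over the range of the batch, padded by `margin` on both sides.
* Blend the two grids using `grid_eps`.
* Extend the result by `spline_order` knots on each side, using the uniform step.
* Store the new grid and refit the spline coefficients on it.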

In [7]:
import torch

In [24]:
batch = 300


In [None]:
# efficientkan's update_grid, copied here for reference (a method of the layer class, hence `self`)
def update_grid(self, x, margin=0.01):
    batch = x.size(0)

    splines = self.b_splines(x)  # (batch, in, coeff)
    splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
    orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
    orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
    unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
    unreduced_spline_output = unreduced_spline_output.permute(
        1, 0, 2
    )  # (batch, in, out)

    # sort each channel individually to collect data distribution
    x_sorted = torch.sort(x, dim=0)[0]
    grid_adaptive = x_sorted[
        torch.linspace(
            0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
        )
    ]

    uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
    grid_uniform = (
        torch.arange(
            self.grid_size + 1, dtype=torch.float32, device=x.device
        ).unsqueeze(1)
        * uniform_step
        + x_sorted[0]
        - margin
    )

    grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
    grid = torch.concatenate(
        [
            grid[:1]
            - uniform_step
            * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
            grid,
            grid[-1:]
            + uniform_step
            * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
        ],
        dim=0,
    )

    self.grid.copy_(grid.T)
    self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
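The grid part of the routine above can be ported to JAX as a pure function. The following is a minimal sketch that keeps efficientkan's `(n_in, G + 2k + 1)` grid shape and leaves out the coefficient refit (`curve2coeff`). The standalone signature and the default value of `grid_eps` are assumptions, not something fixed by the notebook.

```python
import jax
import jax.numpy as jnp

def update_grid(x, G, k, grid_eps=0.02, margin=0.01):
    """Sketch of a functional grid update. x has shape (batch, n_in)."""
    batch = x.shape[0]

    # Sort each input channel individually to capture the data distribution
    x_sorted = jnp.sort(x, axis=0)                               # (batch, n_in)

    # Adaptive grid: G + 1 evenly spaced samples of the sorted data
    idx = jnp.linspace(0, batch - 1, G + 1).astype(jnp.int32)
    grid_adaptive = x_sorted[idx]                                # (G + 1, n_in)

    # Uniform grid over the data range, padded by margin on both sides
    uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / G  # (n_in,)
    grid_uniform = (
        jnp.arange(G + 1, dtype=jnp.float32)[:, None] * uniform_step
        + x_sorted[0]
        - margin
    )

    # Blend the two grids
    grid = grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive

    # Extend by k knots on each side, using the uniform step
    left = grid[:1] - uniform_step * jnp.arange(k, 0, -1)[:, None]
    right = grid[-1:] + uniform_step * jnp.arange(1, k + 1)[:, None]
    grid = jnp.concatenate([left, grid, right], axis=0)          # (G + 2k + 1, n_in)

    return grid.T                                                # (n_in, G + 2k + 1)

# G and k determine array shapes, so they are marked static for jit
update_grid_jit = jax.jit(update_grid, static_argnames=("G", "k"))

x = jax.random.uniform(jax.random.PRNGKey(0), (300, 2), minval=-1.0, maxval=1.0)
new_grid = update_grid_jit(x, G=5, k=3)
print(new_grid.shape)
```

Since the function returns the new grid instead of mutating `self.grid`, it fits JAX's functional style. To get one grid per spline rather than per input, the result could be repeated `n_out` times along the first axis, for example with `jnp.repeat`.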