# Tests Notebook

This notebook is reserved for tests during development.

In [1]:
# Smoke test for the JAX install: if this cell prints SiLU of 0..4, JAX works.
import jax.numpy as jnp
from jax.nn import silu

x = jnp.arange(5.0)
print(silu(x))

[0.        0.7310586 1.761594  2.8577223 3.928055 ]


## pykan & efficientkan tests

We debug and test the pykan and efficientkan codebases to better understand the modules they use. We also installed PyTorch for this purpose; it will not be needed in the final version.

In [2]:
import torch

import torch.nn as nn
import numpy as np

### Spline base functions tests

Note to self: we should perform grid augmentation outside the basis-function call, unlike pykan.

In [3]:
# Fix the PyTorch RNG so the sampled points below are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Spline configuration used by the cells below.
in_dim = 3            # input dimension; in_dim * out_dim splines are sampled
out_dim = 2           # output dimension
G = 10                # number of grid intervals (G + 1 knots before augmentation)
k = 3                 # spline degree (k = 0 gives piecewise-constant indicators)
grid_range = [-1, 1]  # [min, max] of the uniform knot grid

In [4]:
import torch

# pykan implementation
def B_batch(x, grid, k=3):
    """Evaluate degree-k B-spline bases recursively (pykan implementation).

    As used in this notebook, ``x`` is (n_splines, n_samples) and ``grid`` is
    (n_splines, n_knots); the result is (n_splines, n_knots - k - 1, n_samples).
    The grid is expected to be augmented already; no extension happens here.
    The k == 0 result is the product of two comparison masks and is not
    explicitly cast to a floating dtype.
    """
    knots = grid.unsqueeze(dim=2)
    pts = x.unsqueeze(dim=1)

    if k != 0:
        # Cox-de Boor step: combine the degree-(k-1) bases from the recursive call.
        lower = B_batch(pts[:, 0], grid=knots[:, :, 0], k=k - 1)
        rising = (pts - knots[:, :-(k + 1)]) / (knots[:, k:-1] - knots[:, :-(k + 1)]) * lower[:, :-1]
        falling = (knots[:, k + 1:] - pts) / (knots[:, k + 1:] - knots[:, 1:(-k)]) * lower[:, 1:]
        return rising + falling

    # Degree 0: indicator of the half-open knot interval [t_i, t_{i+1}).
    return (pts >= knots[:, :-1]) * (pts < knots[:, 1:])

# efficientkan implementation
def b_splines(x, grid, K=3):
    """Evaluate degree-K B-spline bases iteratively (efficientkan implementation).

    As used in this notebook, ``x`` is (n_samples, n_splines) and ``grid`` is
    (n_splines, n_knots); the result is (n_samples, n_splines, n_knots - K - 1).
    """
    pts = x.unsqueeze(-1)
    # Degree 0: float indicator of each half-open knot interval.
    bases = ((pts >= grid[:, :-1]) & (pts < grid[:, 1:])).float()

    order = 1
    while order <= K:
        lo = grid[:, : -(order + 1)]
        hi = grid[:, order + 1 :]
        rising = (pts - lo) / (grid[:, order:-1] - lo) * bases[:, :, :-1]
        falling = (hi - pts) / (hi - grid[:, 1:(-order)]) * bases[:, :, 1:]
        bases = rising + falling
        order += 1

    return bases


In [6]:
# Sample points: 100 standard-normal samples for each of the in_dim * out_dim splines
x = torch.normal(0,1,size=(in_dim*out_dim, 100))
# Sample grid: the same G + 1 uniform knots on grid_range, one row per spline
grid = torch.linspace(grid_range[0], grid_range[1], steps=G + 1).repeat(in_dim * out_dim, 1)
print(grid.shape)
k = 3
# Grid augmentation: add k knots on each side, spaced like the interior knots
h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
for i in range(k):
    grid = torch.cat([grid[:, [0]] - h, grid, grid[:, [-1]] + h], dim=1)
print(grid.shape)

torch.Size([6, 11])
torch.Size([6, 17])


In [7]:
method_1 = B_batch(x, grid, k)
method_2 = b_splines(x.T, grid, k)

print("Shapes before permutation:", method_1.shape, method_2.shape, sep="\n")

# Move the sample axis from first to last so both tensors share pykan's layout
method_2 = method_2.movedim(0, -1)
print("Shape of second tensor after permutation:", method_2.shape, sep="\n")

Shapes before permutation:
torch.Size([6, 13, 100])
torch.Size([100, 6, 13])
Shape of second tensor after permutation:
torch.Size([6, 13, 100])


In [8]:
# Both implementations are expected to apply the same elementwise operations in the
# same order, so exact agreement is asserted rather than only printed.
n_equal = (method_1 == method_2).sum()
print(f"The two object have {n_equal} out of {method_2.numel()} values equal.")
assert torch.equal(method_1, method_2), "pykan and efficientkan B-spline bases differ"

The two object have 7800 out of 7800 values equal.


Next, we time both implementations to check whether the second one (efficientkan) is faster, since it avoids recursive function calls.

In [15]:
import timeit

# Larger problem for timing: 111 splines x 1000 samples on a 15-knot grid,
# extended by kappa knots on each side (kappa is also the spline degree used).
x_big = torch.normal(0,1,size=(111, 1000))
grid_big = torch.linspace(grid_range[0], grid_range[1], steps=15).repeat(111, 1)
kappa = 10
h = (grid_big[:, [-1]] - grid_big[:, [0]]) / (grid_big.shape[1] - 1)
for i in range(kappa):
    grid_big = torch.cat([grid_big[:, [0]] - h, grid_big, grid_big[:, [-1]] + h], dim=1)


def timed_function1():
    """Recursive (pykan) basis evaluation on the large problem."""
    return B_batch(x_big, grid_big, kappa)


def timed_function2():
    """Loop-based (efficientkan) basis evaluation on the large problem."""
    return b_splines(x_big.T, grid_big, kappa)


elapsed_time_1 = timeit.timeit(timed_function1, number=100)
print(f"pykan implementation: {elapsed_time_1}")
elapsed_time_2 = timeit.timeit(timed_function2, number=100)
print(f"efficientkan implementation: {elapsed_time_2}")


pykan implementation: 7.190586400218308
efficientkan implementation: 7.479282299987972


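Surprisingly, the efficientkan version is not faster; it is slightly slower here. A plausible reason is that the recursion depth is only `kappa`, so there are just `kappa + 1` Python calls. Both versions then perform the same tensor operations, which dominate the run time. In any case, let's try to write these in JAX.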

In [48]:
def jB_batch(x, grid, k=3):
    """JAX port of the recursive pykan B-spline basis evaluation.

    Expects the same layout as the PyTorch ``B_batch``: ``x`` is
    (n_splines, n_samples) and ``grid`` is (n_splines, n_knots). The result,
    (n_splines, n_knots - k - 1, n_samples), is cast with ``astype(float)``
    (float32 unless JAX x64 mode is enabled).
    """
    knots = grid[:, :, None]
    pts = x[:, None]

    if k != 0:
        # Cox-de Boor step on top of the degree-(k-1) bases.
        lower = jB_batch(pts[:, 0], grid=knots[:, :, 0], k=k - 1)
        rising = (pts - knots[:, :-(k + 1)]) / (knots[:, k:-1] - knots[:, :-(k + 1)])
        falling = (knots[:, k + 1:] - pts) / (knots[:, k + 1:] - knots[:, 1:(-k)])
        value = rising * lower[:, :-1] + falling * lower[:, 1:]
    else:
        # Degree 0: indicator of the half-open knot interval [t_i, t_{i+1}).
        value = jnp.logical_and(pts >= knots[:, :-1], pts < knots[:, 1:])

    return value.astype(float)

In [68]:
def jb_splines(x, grid, K=3):
    """JAX port of the loop-based efficientkan B-spline basis evaluation.

    Expects the same layout as the PyTorch ``b_splines``: ``x`` is
    (n_samples, n_splines) and ``grid`` is (n_splines, n_knots). Returns
    (n_samples, n_splines, n_knots - K - 1).
    """
    pts = x[..., None]
    # Degree 0: indicator of each half-open knot interval, cast to float.
    bases = jnp.logical_and(pts >= grid[:, :-1], pts < grid[:, 1:]).astype(float)

    for order in range(1, K + 1):
        lo = grid[:, :-(order + 1)]
        hi = grid[:, order + 1:]
        bases = (pts - lo) / (grid[:, order:-1] - lo) * bases[:, :, :-1] + (hi - pts) / (hi - grid[:, 1:(-order)]) * bases[:, :, 1:]

    return bases

In [53]:
# Convert to jnp arrays
jx_big = jnp.array(x_big.numpy())
jgrid_big = jnp.array(grid_big.numpy())

def jbatch_time():
    """Time one recursive JAX basis evaluation on the large problem.

    JAX dispatches work asynchronously, so block until the result is actually
    computed; otherwise timeit can stop the clock before the work finishes.
    NOTE: jB_batch is not wrapped in jax.jit here, so this measures eager
    op-by-op execution.
    """
    return jB_batch(jx_big, jgrid_big, kappa).block_until_ready()

jelapsed_time_1 = timeit.timeit(jbatch_time, number=100)
print(f"pykan - jax: {jelapsed_time_1}")

pykan - jax: 22.595613399986178


In [69]:
def jsplines_time():
    """Time one loop-based JAX basis evaluation on the large problem.

    JAX dispatches work asynchronously, so block until the result is actually
    computed; otherwise timeit can stop the clock before the work finishes.
    NOTE: jb_splines is not wrapped in jax.jit here, so this measures eager
    op-by-op execution.
    """
    return jb_splines(jx_big.T, jgrid_big, kappa).block_until_ready()

jelapsed_time_2 = timeit.timeit(jsplines_time, number=100)
print(f"efficientkan - jax: {jelapsed_time_2}")

efficientkan - jax: 20.94289659988135


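Now we get the expected result that the loop is faster than the recursive call. However, these times are prohibitive: roughly 3 times slower than PyTorch on CPU. Note that the JAX functions above run eagerly, without `jax.jit`, so every operation is dispatched individually. Jitting them, with `k`/`K` as static arguments, is the natural next step before drawing conclusions about JAX performance.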