# Spline base functions tests

Note to self: we should be performing grid augmentation outside of the function call, unlike what they do in pykan.

In [1]:
in_dim=3
out_dim=2
G=10
k=3
grid_range=[-1, 1]

In [2]:
import torch

# pykan implementation
def B_batch(x, grid, k=3):

    grid = grid.unsqueeze(dim=2)
    x = x.unsqueeze(dim=1)

    if k == 0:
        value = (x >= grid[:, :-1]) * (x < grid[:, 1:])
    else:
        B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1)
        value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]
    
    return value

# efficientkan implementation
def b_splines(x, grid, K=3):
    x = x.unsqueeze(-1)
    bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).float()
    for k in range(1, K + 1):
        bases = (
            (x - grid[:, : -(k + 1)])
            / (grid[:, k:-1] - grid[:, : -(k + 1)])
            * bases[:, :, :-1]
        ) + (
            (grid[:, k + 1 :] - x)
            / (grid[:, k + 1 :] - grid[:, 1:(-k)])
            * bases[:, :, 1:]
        )

    return bases


In [3]:
# Sample points
x = torch.normal(0,1,size=(in_dim*out_dim, 100))
# Sample grid
grid = torch.einsum('i,j->ij', torch.ones(in_dim*out_dim), torch.linspace(grid_range[0], grid_range[1], steps=G + 1))
print(grid.shape)
k = 3
# Grid augmentation
h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
for i in range(k):
    grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
    grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
print(grid.shape)

torch.Size([6, 11])
torch.Size([6, 17])


In [4]:
method_1 = B_batch(x,grid,k)
method_2 = b_splines(x.T,grid,k)

print("Shapes before permutation:")
print(method_1.shape)
print(method_2.shape)

# Permute to get same shapes
method_2 = method_2.permute(1, 2, 0)
print("Shape of second tensor after permutation:")
print(method_2.shape)

Shapes before permutation:
torch.Size([6, 13, 100])
torch.Size([100, 6, 13])
Shape of second tensor after permutation:
torch.Size([6, 13, 100])


In [5]:
print(f"The two object have {(method_1 == method_2).sum()} out of {method_2.flatten().shape[0]} values equal.")

The two object have 7800 out of 7800 values equal.


Timing to show why the second case is better, as it does not involve recursive function calls.

In [6]:
import timeit

# Make bigger arrays
x_big = torch.normal(0,1,size=(111, 1000))
grid_big = torch.einsum('i,j->ij', torch.ones(111), torch.linspace(grid_range[0], grid_range[1], steps=15))
kappa = 10
h = (grid_big[:, [-1]] - grid_big[:, [0]]) / (grid_big.shape[1] - 1)
for i in range(kappa):
    grid_big = torch.cat([grid_big[:, [0]] - h, grid_big], dim=1)
    grid_big = torch.cat([grid_big, grid_big[:, [-1]] + h], dim=1)

# Wrappers for timing
def timed_function1():
    return B_batch(x_big, grid_big, kappa)

def timed_function2():
    return b_splines(x_big.T, grid_big, kappa)

elapsed_time_1 = timeit.timeit(timed_function1, number=100)
print(f"pykan implementation: {elapsed_time_1}")
elapsed_time_2 = timeit.timeit(timed_function2, number=100)
print(f"efficientkan implementation: {elapsed_time_2}")


pykan implementation: 7.8689660001546144
efficientkan implementation: 7.646361299790442


For some reason it is not faster. In any case, let's try to write these in JAX.

In [7]:
import jax.numpy as jnp

In [8]:
def jB_batch(x, grid, k=3):
    grid = jnp.expand_dims(grid, axis=2)
    x = jnp.expand_dims(x, axis=1)

    if k == 0:
        value = (x >= grid[:, :-1]) & (x < grid[:, 1:])
    else:
        B_km1 = jB_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1)
        value = ((x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)])) * B_km1[:, :-1] + ((grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)])) * B_km1[:, 1:]
    
    return value.astype(float)

In [9]:
def jb_splines(x, grid, K=3):
    x = jnp.expand_dims(x, axis=-1)
    bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).astype(float)
    
    for k in range(1, K+1):
        left_term = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)])
        right_term = (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)])
        
        bases = left_term * bases[:, :, :-1] + right_term * bases[:, :, 1:]

    return bases

In [10]:
# Convert to jnp arrays
jx_big = jnp.array(x_big.numpy())
jgrid_big = jnp.array(grid_big.numpy())

def jbatch_time():
    return jB_batch(jx_big, jgrid_big, kappa)

jelapsed_time_1 = timeit.timeit(jbatch_time, number=100)
print(f"pykan - jax: {jelapsed_time_1}")

pykan - jax: 17.589296099729836


In [11]:
def jsplines_time():
    return jb_splines(jx_big.T, jgrid_big, kappa)

jelapsed_time_2 = timeit.timeit(jsplines_time, number=100)
print(f"efficientkan - jax: {jelapsed_time_2}")

efficientkan - jax: 16.92286629974842


Well, now we're getting the expected result that the loop is faster than the recursive call, however these times are prohibitive (3 times slower than pytorch on CPU). Let's try compiling the second version for both PyTorch and JAX.

In [12]:
@torch.jit.script
def jit_b_splines(x, grid, K: int = 3):
    x = x.unsqueeze(-1)
    bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).float()
    for k in range(1, K + 1):
        left_term = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)])
        right_term = (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)])
        bases = left_term * bases[:, :, :-1] + right_term * bases[:, :, 1:]

    return bases

In [13]:
from functools import partial
from jax import jit

# Partial applies the jit decorator with static arguments, i.e. which should be kept constant for compilation
# but would require a re-compilation if its value changes
# We don't expect k to change throughout a single run
@partial(jit, static_argnums=(2,))
def jit_jb_splines(x, grid, K=3):
    x = jnp.expand_dims(x, axis=-1)
    bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).astype(float)
    
    for k in range(1, K+1):
        left_term = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)])
        right_term = (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)])
        
        bases = left_term * bases[:, :, :-1] + right_term * bases[:, :, 1:]

    return bases

In [14]:
def torch_jit():
    return jit_b_splines(x_big.T, grid_big, kappa)

def jax_jit():
    return jit_jb_splines(jx_big.T, jgrid_big, kappa)

time_torch = timeit.timeit(torch_jit, number=100)
time_jax = timeit.timeit(jax_jit, number=100)
print(f"Torch time: {time_torch}")
print(f"JAX time: {time_jax}")

Torch time: 7.70453180000186
JAX time: 0.6737136002629995


Oh well, this is why JAX is cool.

## Other random tests

Let's do some similar tests for more basic functions.

In [15]:
def batch_matmul_pytorch(A, B):
    return torch.bmm(A, B)

def batch_matmul_jax(A, B):
    return jnp.matmul(A, B)

batch_matmul_pytorch_jit = torch.jit.script(batch_matmul_pytorch)
batch_matmul_jax_jit = jit(batch_matmul_jax)

# Create large batch matrices
batch_size = 500
dim = 1024

# PyTorch tensors
A_torch = torch.randn(batch_size, dim, dim, dtype=torch.float32)
B_torch = torch.randn(batch_size, dim, dim, dtype=torch.float32)

# JAX arrays
A_jax = jnp.array(A_torch.numpy())
B_jax = jnp.array(B_torch.numpy())

In [16]:
# Setup number of runs
number = 10

# Time non-JIT functions
torch_time = timeit.timeit('batch_matmul_pytorch(A_torch, B_torch)', globals=globals(), number=number)
jax_time = timeit.timeit('batch_matmul_jax(A_jax, B_jax)', globals=globals(), number=number)

# Time JIT functions
torch_jit_time = timeit.timeit('batch_matmul_pytorch_jit(A_torch, B_torch)', globals=globals(), number=number)
jax_jit_time = timeit.timeit('batch_matmul_jax_jit(A_jax, B_jax)', globals=globals(), number=number)

# Output results
print(f"PyTorch time without JIT: {torch_time / number:.5f} seconds per loop")
print(f"JAX time without JIT: {jax_time / number:.5f} seconds per loop")
print(f"PyTorch time with JIT: {torch_jit_time / number:.5f} seconds per loop")
print(f"JAX time with JIT: {jax_jit_time / number:.5f} seconds per loop")

PyTorch time without JIT: 1.77455 seconds per loop
JAX time without JIT: 1.85137 seconds per loop
PyTorch time with JIT: 1.90362 seconds per loop
JAX time with JIT: 1.88739 seconds per loop


Welp, here the performance is comparable. Let's try with for loops, perhaps that's where the big difference is.

In [17]:
@torch.jit.script
def iterative_matmul_pytorch(A, B, num_iters: int = 10):  # Explicitly type the parameter
    result = torch.zeros_like(A)
    for _ in range(num_iters):
        result += torch.bmm(A, B)
        A = A * 0.99  # simulate some decay or transformation
    return result

@jit
def iterative_matmul_jax(A, B, num_iters=10):
    result = jnp.zeros_like(A)
    for _ in range(num_iters):
        result += jnp.matmul(A, B)
        A = A * 0.99  # simulate some decay or transformation
    return result

# Using a smaller size for easier handling on CPUs
batch_size = 100
dim = 512

# PyTorch tensors
A_torch = torch.randn(batch_size, dim, dim, dtype=torch.float32)
B_torch = torch.randn(batch_size, dim, dim, dtype=torch.float32)

# JAX arrays
A_jax = jnp.array(A_torch.numpy())
B_jax = jnp.array(B_torch.numpy())

# Setup number of runs
number = 10

torch_time = timeit.timeit('iterative_matmul_pytorch(A_torch, B_torch)', globals=globals(), number=number)
jax_time = timeit.timeit('iterative_matmul_jax(A_jax, B_jax)', globals=globals(), number=number)

# Output results
print(f"PyTorch time: {torch_time / number:.5f} seconds per loop")
print(f"JAX time: {jax_time / number:.5f} seconds per loop")

PyTorch time: 0.64673 seconds per loop
JAX time: 0.58546 seconds per loop


Yeah, this is probably it.