# Testing functionality of TKM and JAX

In [11]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load Banana Data

In [12]:
import pandas as pd
from pathlib import Path

In [13]:
# Location of the banana dataset, relative to the notebook's working directory.
data_path = Path("..") / "data" / "banana.csv"

# Read the CSV into a DataFrame; later cells select columns from it by name.
df_banana = pd.read_csv(data_path)

In [14]:
# Preview the first five rows to check the columns that were loaded.
df_banana.head(n=5)

Unnamed: 0,V1,V2,Class
0,1.14,-0.114,1
1,-1.52,-1.15,2
2,-1.05,0.72,1
3,-0.916,0.397,2
4,-1.09,0.437,2


In [15]:
# Labels are expected to be 1 and -1, while the loaded frame uses 1 and 2.
# Map label 2 to -1 and leave label 1 untouched.
df_banana.loc[df_banana["Class"] == 2, "Class"] = -1

# Fail fast if the data contains any label other than the two handled above,
# instead of silently passing an unexpected label on to later cells.
unexpected_labels = set(df_banana["Class"].unique()) - {1, -1}
assert not unexpected_labels, f"unexpected class labels: {unexpected_labels}"

df_banana.head()

Unnamed: 0,V1,V2,Class
0,1.14,-0.114,1
1,-1.52,-1.15,-1
2,-1.05,0.72,1
3,-0.916,0.397,-1
4,-1.09,0.437,-1


In [16]:
# Both selections use a *list* of column names, so each one yields a
# DataFrame and therefore a 2-D NumPy array after conversion.
feature_names = ['V1', 'V2']
label_name = ['Class']

# Feature matrix and label column as NumPy arrays.
x = df_banana.loc[:, feature_names].to_numpy()
y = df_banana.loc[:, label_name].to_numpy()

### Data to JAX array

In [17]:
import jax.numpy as jnp

# Re-bind x and y to JAX arrays built from the NumPy arrays created above.
x, y = (jnp.array(host_array) for host_array in (x, y))

In [18]:
from jax import random

# Fixed seed so that every run of the notebook starts from the same PRNG key.
key = random.PRNGKey(seed=42)

## Tensor Kernel Machine

In [19]:
from tkm.model import init

# W, reg, Matd = init(key, x)

In [23]:
from tkm.model import fit
from jax import jit

W = fit(key,x,y)

fit_compiled = jit(fit)

%timeit W = fit(key,x,y)
%timeit W = fit_compiled(key, x, y)

91.8 ms ± 1.91 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.75 µs ± 57.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### Predict

In [25]:
from tkm.model import predict
predict_compiled = jit(predict)

%timeit predict(x, W)
%timeit predict_compiled(x,W)

8.31 ms ± 236 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# polynomial = compile_feature_map(M=M)
# %timeit scores = predict(x, W) 
# 29.9 ms ± 258 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [35]:
from tkm.model import predict_vmap
predict_vmap_compiled = jit(predict_vmap)

%timeit predict_vmap(x, W)
%timeit predict_vmap_compiled(x,W)

31.2 ms ± 74.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.1 ms ± 2.41 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
