# Testing functionality of TKM and JAX

In [24]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to the local `tkm` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
import os
# Limit the share of GPU memory that JAX/XLA preallocates for this process to 10%
# (see the JAX GPU memory allocation docs). This is set before the `import jax`
# cell below so the value is in the environment when the backend starts up.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.1'

In [26]:
import jax
# cpu = jax.devices("cpu")[0]
# jax.default_device(cpu)

In [27]:
# from jax.config import config
# config.update("jax_enable_x64", True)

## Load Banana Data

In [28]:
import pandas as pd
from pathlib import Path

In [29]:
# Location of the banana dataset, relative to the notebook's working directory.
data_path = Path("../data/banana.csv")
df_banana = pd.read_csv(filepath_or_buffer=data_path)

In [30]:
df_banana.head()

Unnamed: 0,V1,V2,Class
0,1.14,-0.114,1
1,-1.52,-1.15,2
2,-1.05,0.72,1
3,-0.916,0.397,2
4,-1.09,0.437,2


In [31]:
# The model expects labels in {1, -1}; the raw file encodes the second class as 2.
# Rewrite every 2 as -1 and leave the 1s untouched.
is_second_class = df_banana["Class"] == 2
df_banana.loc[is_second_class, "Class"] = -1

df_banana.head()

Unnamed: 0,V1,V2,Class
0,1.14,-0.114,1
1,-1.52,-1.15,-1
2,-1.05,0.72,1
3,-0.916,0.397,-1
4,-1.09,0.437,-1


In [32]:
# standardize

# df_banana[['V1','V2']] = (df_banana[['V1','V2']] -  df_banana[['V1','V2']].min() )/ (df_banana[['V1','V2']].max() - df_banana[['V1','V2']].min())
# df_banana.head()

In [33]:
# Both selectors are lists, so each selection stays a DataFrame and converts
# to a 2-D NumPy array.
feature_names = ['V1', 'V2']
label_name = ['Class']

x = df_banana.loc[:, feature_names].to_numpy()
y = df_banana.loc[:, label_name].to_numpy()  # single column; squeezed later before scoring

In [34]:
# Min-max rescale: map each feature column onto [0, 1] using its own min and max.
x_min = x.min(axis=0)
x_max = x.max(axis=0)
x = (x - x_min) / (x_max - x_min)

### Data to JAX array

In [35]:
import jax.numpy as jnp

# Copy the NumPy feature and label arrays into JAX arrays for the model.
x, y = (jnp.array(host_array) for host_array in (x, y))

In [36]:
from jax import random

# One fixed-seed PRNG key, passed to every `fit` call in this notebook.
key = random.PRNGKey(seed=42)

## Tensor Kernel Machine

### Polynomial

In [37]:
from tkm.model import TensorizedKernelMachine as TKM
from jax import jit
from tkm.features import polynomial

# Build the polynomial-feature model with M=12 and R=6.
# NOTE(review): this call passes the feature map as `features=`, whereas the
# Fourier and batched cells further down pass it as `feature_map=`. All three
# recorded accuracies are identical, so one of the two keywords may be silently
# ignored -- confirm against the TensorizedKernelMachine constructor.
model = TKM(M=12,R=6,features=polynomial)

In [38]:
# Fit the polynomial model on the full dataset; W is what `predict` consumes below.
W = model.fit(key,x,y) #,M=12,R=6,feature_map=polynomial)

In [39]:
%timeit model.fit(key,x,y) #,M=12,R=6,feature_map=polynomial)

30.8 ms ± 976 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### Predict

In [40]:
from tkm.metrics import accuracy

# Predictions are made on the same inputs the model was fitted on, so the
# number shown is training accuracy.
y_hat = model.predict(x, W)
y_true = y.squeeze()  # drop the singleton label axis before scoring
acc = accuracy(y_true, y_hat)
acc

Array(0.8228302, dtype=float32)

In [41]:
%timeit model.predict(x, W) #, M=12,R=6, feature_map=polynomial)

985 µs ± 6.85 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Fourier features

In [42]:
from tkm.model import TensorizedKernelMachine
from tkm.features import fourier

# Fourier-feature variant of the model (same M and R, lengthscale 0.5).
# Results use their own names so the polynomial model's `W` is not overwritten;
# otherwise re-running the polynomial predict/timing cells after this one would
# silently pair the polynomial model with the Fourier weights.
model_fourier = TensorizedKernelMachine(M=12, R=6, lengthscale=0.5, feature_map=fourier)
W_fourier = model_fourier.fit(key, x, y)
y_hat_fourier = model_fourier.predict(x, W_fourier)

# NOTE(review): the recorded accuracy equals the polynomial model's to every
# printed digit. Confirm that `feature_map` and `lengthscale` are actually
# consumed by the constructor (the polynomial cell passes `features=` instead).
acc_fourier = accuracy(y.squeeze(), y_hat_fourier)
print(acc_fourier)

0.8228302


## Batched dotkron

In [43]:
from tkm.model import TensorizedKernelMachine

# Polynomial model with the same M and R as before, now fitted with batch_size=10.
model_batch = TensorizedKernelMachine(
    M=12,
    R=6,
    feature_map=polynomial,
    batch_size=10,
)
W_batch = model_batch.fit(key, x, y)

In [44]:
# Score the batched model on the inputs it was fitted on (training accuracy).
y_hat = model_batch.predict(x, W_batch)
y_true = y.squeeze()  # drop the singleton label axis before scoring
acc = accuracy(y_true, y_hat)
print(acc)

0.8228302
