# Testing functionality of TKM and JAX

In [86]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [87]:
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.1'

In [88]:
import jax
# cpu = jax.devices("cpu")[0]
# jax.default_device(cpu)

In [89]:
# from jax.config import config
# config.update("jax_enable_x64", True)

## Load Banana Data

In [90]:
import pandas as pd
from pathlib import Path

In [91]:
data_path = Path("../data/banana.csv")
df_banana = pd.read_csv(data_path)

In [92]:
df_banana.head()

Unnamed: 0,V1,V2,Class
0,1.14,-0.114,1
1,-1.52,-1.15,2
2,-1.05,0.72,1
3,-0.916,0.397,2
4,-1.09,0.437,2


In [93]:
# expecting labels 1 and -1

# transform label 2 to -1
df_banana.loc[df_banana.Class == 2, 'Class'] = -1
df_banana.head()

Unnamed: 0,V1,V2,Class
0,1.14,-0.114,1
1,-1.52,-1.15,-1
2,-1.05,0.72,1
3,-0.916,0.397,-1
4,-1.09,0.437,-1


In [94]:
# standardize

# df_banana[['V1','V2']] = (df_banana[['V1','V2']] -  df_banana[['V1','V2']].min() )/ (df_banana[['V1','V2']].max() - df_banana[['V1','V2']].min())
# df_banana.head()

In [95]:
feature_names = ['V1', 'V2']
x = df_banana[feature_names].to_numpy()

label_name = ['Class']
y = df_banana[label_name].to_numpy()

In [96]:
# rescale

x = (x - x.min(axis=0)) / (x.max(axis=0) - x.min(axis=0))

### Data to JAX array

In [97]:
from jax import numpy as jnp
# x = jnp.array(x)
# y = jnp.array(y)

In [98]:
from jax import random
key = random.PRNGKey(42)

## Tensor Kernel Machine

### Polynomial

In [99]:
from tkm.model import TensorizedKernelMachine as TKM
from jax import jit
from tkm.features import polynomial

model = TKM(M=12,R=6,features=polynomial, key=key)

In [100]:
model = model.fit(x,y) #,M=12,R=6,feature_map=polynomial)

In [101]:
%timeit model.fit(x,y) #,M=12,R=6,feature_map=polynomial)

39.6 ms ± 1.27 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### Predict

In [102]:
from tkm.metrics import accuracy

y_hat = model.predict(x) #, feature_map=polynomial, M=12,R=6)
acc = accuracy(y.squeeze(), y_hat)
acc

Array(0.8228302, dtype=float32)

In [103]:
%timeit model.predict(x) #, M=12,R=6, feature_map=polynomial)

1.18 ms ± 192 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### fourier features

In [104]:
from tkm.model import TensorizedKernelMachine
from tkm.features import fourier

model_fourier = TensorizedKernelMachine(M=12,R=6,lengthscale=0.5,features=fourier, key=key)
model_fourier = model_fourier.fit(x,y)
y_hat = model_fourier.predict(x)
acc = accuracy(y.squeeze(), y_hat)
print(acc)

0.8896226


## batched dotkron

In [105]:
from tkm.model import TensorizedKernelMachine 

model_batch = TensorizedKernelMachine(M=12,R=6,features=polynomial,batch_size=100, key=key)
model_batch = model_batch.fit(x,y)

In [106]:
y_hat = model_batch.predict(x)
acc = accuracy(y.squeeze(), y_hat)
print(acc)

0.8916981


## Check Sklearn model

In [107]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from tkm.sklearn.model import TKRC
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score

tkrc = TKRC(M=12,R=6,features=polynomial, key=key)
tkrc.fit(x,y)
scaler = StandardScaler()
pipeline = Pipeline([('tkrc', model_batch)])
pipe = pipeline.fit(x,y)

In [108]:
y_hat = jnp.sign(pipe.predict(x))
acc = accuracy_score(y.squeeze(), y_hat)
print(acc)

0.8916981132075472
