# Testing functionality of TKM and JAX

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
!export XLA_PYTHON_CLIENT_MEM_FRACTION=.1

In [3]:
import jax
cpu = jax.devices("cpu")[0]
jax.default_device(cpu)

<contextlib._GeneratorContextManager at 0x7f30c41f73d0>

In [2]:
# from jax.config import config
# config.update("jax_enable_x64", True)

## Load Banana Data

In [3]:
import pandas as pd
from pathlib import Path

In [4]:
data_path = Path("../data/banana.csv")
df_banana = pd.read_csv(data_path)

In [5]:
df_banana.head()

Unnamed: 0,V1,V2,Class
0,1.14,-0.114,1
1,-1.52,-1.15,2
2,-1.05,0.72,1
3,-0.916,0.397,2
4,-1.09,0.437,2


In [6]:
# expecting labels 1 and -1

# transform label 2 to -1
df_banana.loc[df_banana.Class == 2, 'Class'] = -1
df_banana.head()

Unnamed: 0,V1,V2,Class
0,1.14,-0.114,1
1,-1.52,-1.15,-1
2,-1.05,0.72,1
3,-0.916,0.397,-1
4,-1.09,0.437,-1


In [7]:
# standardize

# df_banana[['V1','V2']] = (df_banana[['V1','V2']] -  df_banana[['V1','V2']].min() )/ (df_banana[['V1','V2']].max() - df_banana[['V1','V2']].min())
# df_banana.head()

In [8]:
feature_names = ['V1', 'V2']
x = df_banana[feature_names].to_numpy()

label_name = ['Class']
y = df_banana[label_name].to_numpy()

In [9]:
# rescale

x = (x - x.min(axis=0)) / (x.max(axis=0) - x.min(axis=0))

### Data to JAX array

In [10]:
from jax import numpy as jnp
x = jnp.array(x)
y = jnp.array(y)



In [11]:
from jax import random
key = random.PRNGKey(42)

## Tensor Kernel Machine

### Polynomial

In [24]:
from tkm.model import TensorizedKernelMachine as TKM
from jax import jit
from tkm.features import polynomial

model = TKM(M=12,R=6,feature_map=polynomial)

In [25]:
W = model.fit(key,x,y) #,M=12,R=6,feature_map=polynomial)

In [14]:
%timeit model.fit(key,x,y) #,M=12,R=6,feature_map=polynomial)

19.3 ms ± 574 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### Predict

In [15]:
from tkm.metrics import accuracy

y_hat = model.predict(x, W) #, feature_map=polynomial, M=12,R=6)
acc = accuracy(y.squeeze(), y_hat)
acc

DeviceArray(0.8918868, dtype=float32)

In [19]:
%timeit model.predict(x, W) #, M=12,R=6, feature_map=polynomial)

1.71 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [19]:
predict_polynomial = partial(predict, feature_map=polynomial, M=12, R=6)
predict_compiled = jit(predict_polynomial)

%timeit predict_compiled(x,W).block_until_ready()

54.7 µs ± 3.76 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
# from jax import jit
from tkm.model import predict_vmap

y_hat = predict_vmap(x, W, feature_map=polynomial, M=12, R=6)
acc = accuracy(y.squeeze(), y_hat)
print(acc)

%timeit predict_vmap(x, W, feature_map=polynomial, M=12, R=6)

0.8922641
2.26 ms ± 68.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
predict_vmap_poly = partial(predict_vmap, feature_map=polynomial, M=12, R=6)
predict_vmap_compiled = jit(predict_vmap_poly)

%timeit predict_vmap_compiled(x,W).block_until_ready()

53.4 µs ± 554 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### fourier features

In [30]:
from tkm.model import TensorizedKernelMachine
from tkm.features import fourier

model_fourier = TensorizedKernelMachine(M=12,R=6,lengthscale=0.5,feature_map=fourier)
W = model.fit(key,x,y)
y_hat = model.predict(x, W)
acc = accuracy(y.squeeze(), y_hat)
print(acc)

0.8918868


In [23]:
jax.devices()

[GpuDevice(id=0, process_index=0)]

GpuDevice(id=0, process_index=0)

## batched dotkron

In [23]:
from tkm.model import BatchedTKM 

model_batch = BatchedTKM(M=12,R=6,feature_map=polynomial,batch_size=10)
W_batch = model_batch.fit(key,x,y) #,M=12,R=6,feature_map=polynomial)

In [32]:
y_hat = model_batch.predict(x,W_batch)