# Airlines dataset

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to their source are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# `!export ...` runs in a throwaway subshell, so the variable never reaches
# the kernel process and JAX/XLA would not see it. Set it on this process
# instead. This cell must run before JAX is first used for it to take effect.
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".1"

In [3]:
# from jax.config import config
# config.update("jax_enable_x64", True)

## Load Airlines data

In [4]:
import pandas as pd
from pathlib import Path

In [5]:
# Relative path to the airline delay CSV.
data_path = Path("../data/airline.csv")

# The first column of the file is an unnamed saved row index (it shows up as
# "Unnamed: 0" when read as data); use it as the index instead of carrying it
# around as a spurious column.
df = pd.read_csv(data_path, index_col=0)

In [6]:
# Show the first rows to confirm the expected columns were loaded.
df.head()

Unnamed: 0,Month,DayofMonth,DayOfWeek,DepTime,ArrTime,AirTime,Distance,plane_age,ArrDelay
0,1,3,4,1203,1331,116,810,10,-14
1,1,3,4,454,598,314,2283,10,-22
2,1,3,4,652,963,175,1521,10,-17
3,1,3,4,1013,1172,79,577,10,2
4,1,4,5,818,880,48,239,10,10


In [7]:
# Standardize the target column: subtract its mean and divide by its
# (pandas default, sample) standard deviation.
arr_delay = df['ArrDelay']
df['ArrDelay'] = (arr_delay - arr_delay.mean()) / arr_delay.std()

In [8]:
# Input columns, in the order they appear in the feature matrix.
feature_names = [
    'Month', 'DayofMonth', 'DayOfWeek', 'DepTime',
    'ArrTime', 'AirTime', 'Distance', 'plane_age',
]
# Selecting with a one-element list keeps the target two-dimensional.
label_name = ['ArrDelay']

x = df.loc[:, feature_names].to_numpy()
y = df.loc[:, label_name].to_numpy()

In [9]:
# Min-max scale each feature column using its own minimum and maximum.
col_min = x.min(axis=0)
col_max = x.max(axis=0)
x = (x - col_min) / (col_max - col_min)

### Convert data to JAX arrays

In [10]:
import jax.numpy as jnp

# Convert the NumPy feature matrix and target into JAX arrays.
x, y = jnp.array(x), jnp.array(y)



In [11]:
from jax import random

# Fixed seed for the PRNG key handed to the fit calls below.
SEED = 42
key = random.PRNGKey(SEED)

## Tensor Kernel Machine

### Polynomial

In [32]:
from tkm.model import fit
from jax import jit
from tkm.features import polynomial

# Jitted handle, referenced only by the timing comparison commented out below.
fit_compiled = jit(fit)

# Fit with the polynomial feature map; the returned `loss` history is
# displayed in the next cell.
W, loss = fit(
    key,
    x,
    y,
    M=4,
    R=2,
    feature_map=polynomial,
    numberSweeps=5,
)

# %timeit W = fit(key,x,y, feature_map=polynomial)
# %timeit W = fit_compiled(key, x, y, feature_map=polynomial)

In [33]:
# Inspect the loss history returned by the polynomial fit above.
loss

[5887081.5,
 5868684.0,
 5852947.5,
 5786518.5,
 5754037.0,
 5749995.0,
 5741692.5,
 5725896.5,
 5718140.0,
 5708719.0,
 5704233.0,
 5619731.0,
 5542762.5,
 5521153.0,
 5634999.0,
 5672474.5,
 5545400.0,
 5532472.0,
 5531686.0,
 5464710.0,
 5451585.0,
 5438640.5,
 5431995.0,
 5403375.0,
 5401704.5,
 5400475.5,
 5399960.0,
 5397021.5,
 5389956.0,
 5387088.5,
 5378288.0,
 5377906.0,
 5375984.5,
 5375589.0,
 5376113.0,
 5365985.5,
 5356519.5,
 5352387.0,
 5346017.0,
 5355945.0,
 5343421.5,
 5343146.0,
 5345951.0,
 5335261.0,
 5326116.0,
 5321030.5,
 5316943.0,
 5322492.0,
 5313683.5,
 5312603.0,
 5313831.0,
 5306247.0,
 5299004.0,
 5291290.0,
 5291153.5,
 5292189.5,
 5286456.0,
 5285015.0,
 5285916.0,
 5283170.0,
 5281379.5,
 5272585.0,
 5275935.5,
 5273724.0,
 5270499.5,
 5268933.5,
 5269916.0,
 5269837.5,
 5270604.5,
 5262048.0,
 5265740.5,
 5264445.0,
 5260804.5,
 5259120.5,
 5260269.5,
 5261099.0,
 5262292.5,
 5254369.5,
 5257309.0,
 5256725.5]

#### Predict

In [34]:
from tkm.model import predict
from tkm.metrics import rmse

# Jitted handle, referenced only by the timing comparison commented out below.
predict_compiled = jit(predict)

# Predict on the same inputs used for fitting, with the same M and R,
# then compute the RMSE against the standardized targets.
y_hat = predict(
    x,
    W,
    feature_map=polynomial,
    M=4,
    R=2,
)
err = rmse(y.squeeze(), y_hat)
err

# %timeit predict(x, W, feature_map=polynomial)
# %timeit predict_compiled(x,W, feature_map=polynomial)

DeviceArray(2291.404, dtype=float32)

In [27]:
# from jax import jit
# from tkm.model import predict_vmap
# predict_vmap_compiled = jit(predict_vmap)

# %timeit y_hat = predict_vmap(x, W, feature_map=polynomial)
# %timeit predict_vmap_compiled(x,W)

### Fourier features

In [20]:
# Magnitude check: this same expression is passed as the `l` argument of the
# Fourier fit below (presumably a regularization weight — TODO confirm in tkm.model.fit).
100/x.shape[0]

1.6865075851521897e-05

In [23]:
from tkm.model import fit
from tkm.features import fourier
from functools import partial
from jax import jit

# Bind the Fourier-feature hyperparameters up front so the jitted callable
# can be invoked as fit_compiled(key, x, y).
fourier_fit_kwargs = dict(
    feature_map=fourier,
    M=40,
    R=5,
    lengthscale=x.std(axis=0).mean(),
    l=100/x.shape[0],
)
fit_fourier = partial(fit, **fourier_fit_kwargs)
fit_compiled = jit(fit_fourier)

In [24]:
# Train the Fourier-feature model. This rebinds W, replacing the weights from
# the polynomial fit earlier in the notebook.
# NOTE(review): the polynomial cell unpacks `W, loss = fit(...)`, while here the
# result is bound to W alone — confirm what fit returns for this call.
W = fit_compiled(key,x,y)

In [1]:
# Time the jitted Fourier fit; block_until_ready() makes the measurement wait
# for the result instead of returning as soon as the work is dispatched.
# Requires the cell that defines fit_compiled to have been run in this kernel.
%timeit fit_compiled(key,x,y).block_until_ready()

NameError: name 'fit_compiled' is not defined

In [None]:
from tkm.model import predict_vmap

# Bind the same Fourier hyperparameters (M, R, lengthscale) that were used for
# fitting. Named *_fourier (not *_poly) because it binds feature_map=fourier,
# matching the `fit_fourier` naming used for the fit.
predict_vmap_fourier = partial(
    predict_vmap,
    feature_map=fourier,
    M=40,
    R=5,
    lengthscale=x.std(axis=0).mean(),
)
predict_vmap_compiled = jit(predict_vmap_fourier)

In [None]:
# Predict with the Fourier model on the same inputs it was fitted on, wait for
# the result, then report the RMSE against the targets.
y_hat = predict_vmap_compiled(x,W).block_until_ready()
err = rmse(y.squeeze(), y_hat)
print(err)

In [None]:
# Time the jitted Fourier prediction; block_until_ready() makes the measurement
# wait for the result instead of returning as soon as the work is dispatched.
%timeit predict_vmap_compiled(x,W).block_until_ready()