## Install
Install `jax` (numerics), `dm-haiku` (neural networks; imported as `haiku`) and `optax` (optimisation).

In [None]:
!pip install -U jax jaxlib --upgrade
!pip install -U dm-haiku optax

## Import

In [None]:
import numpy as np
import pandas as pd
import jax
from jax import numpy as jnp
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import haiku as hk
import optax

# Column used to stratify the train/validation split and to fill the submission file.
target_name = 'loss'

# Fraction of the training file held out for validation, and the seed of that split.
test_size = 0.2
random_seed = 0

## Explore Data

In [None]:
# Read the competition training file into a DataFrame.
train_data = pd.read_csv('../input/tabular-playground-series-aug-2021/train.csv')
# Bare expression: the notebook renders the first rows as a table.
train_data.head()

In [None]:
# Grid of histograms, one per column, for columns 1..100 of train_data
# (column 0 is skipped). Columns of dtype float64 are drawn in blue, all others in red.
grid_rows, grid_cols = 10, 10
# The default figure size is far too small for 100 panels, so enlarge it.
hists_fig, hists_ax = plt.subplots(grid_rows, grid_cols, figsize=(20, 20))
for i in range(grid_rows):
    for j in range(grid_cols):
        ind = grid_cols * i + j + 1
        # `dtypes` is indexed by column label; use .iloc for the positional lookup
        # instead of relying on the deprecated integer fallback of Series[...].
        is_float = train_data.dtypes.iloc[ind] == 'float64'
        hists_ax[i, j].hist(train_data.iloc[:, ind], bins=50, color='blue' if is_float else 'red')
        # Axis ticks are hidden: the panels are only meant to show distribution shapes.
        hists_ax[i, j].xaxis.set_visible(False)
        hists_ax[i, j].yaxis.set_visible(False)
hists_fig.tight_layout()

In [None]:
# Summary statistics of every numeric column, rendered as the cell output.
train_data.describe()

## Split and rescale training data

In [None]:
# Hold out `test_size` of the rows for validation, stratified on the target column
# and seeded so the split is reproducible.
tr_data, val_data = train_test_split(
    train_data,
    test_size=test_size,
    random_state=random_seed,
    stratify=train_data[target_name],
)
print(f'Data split with shapes: tr_data = {tr_data.shape}, val_data = {val_data.shape}')

In [None]:
def _features_and_target(frame):
    """Split a frame into (features, target) NumPy arrays.

    Features are every column except the first and the last; the target is the
    last column, kept two-dimensional with shape (rows, 1).
    """
    return frame.iloc[:, 1:-1].to_numpy(), frame.iloc[:, -1:].to_numpy()


tr_x, tr_y = _features_and_target(tr_data)
val_x, val_y = _features_and_target(val_data)

In [None]:
# Standardise the features column-wise. The statistics come from the training split
# only and are applied unchanged to the validation split.
# NOTE: this cell overwrites tr_x / val_x. Re-running it without first re-running the
# cell that builds them recomputes tr_mean / tr_std from already-scaled data.
rescale_data = True
if rescale_data:
    tr_mean = tr_x.mean(0)
    tr_std = tr_x.std(0)
    # A constant column has zero standard deviation; dividing by it would produce
    # NaN/inf, so such columns are only centred (scale factor 1).
    tr_std = np.where(tr_std == 0, 1.0, tr_std)
    tr_x = (tr_x - tr_mean) / tr_std
    val_x = (val_x - tr_mean) / tr_std

## Define Net Structure

In [None]:
def get_net(x: jnp.ndarray) -> jnp.ndarray:
    """Apply a four-layer MLP to a batch of feature rows.

    The hidden widths are 100%, 75% and 25% of the number of input features,
    each followed by ReLU. The single output unit goes through softplus so the
    prediction is never negative.

    Args:
        x: input batch whose last axis holds the features.

    Returns:
        The network output, with a last axis of size 1.
    """
    # Take the width from the batch itself. The previous `tr_data.shape[-1] - 1`
    # still counted the id column (features are `iloc[:, 1:-1]`) and tied the
    # network to a notebook-level global.
    num_inputs = x.shape[-1]

    nn = hk.Sequential([
        hk.Linear(num_inputs), jax.nn.relu,
        hk.Linear(int(num_inputs * 0.75)), jax.nn.relu,
        hk.Linear(int(num_inputs * 0.25)), jax.nn.relu,
        hk.Linear(1), jax.nn.softplus, # y data is positive
    ])
    return nn(x)

## Train Net with Adam

In [None]:
# Root PRNG key for the notebook's JAX randomness.
key = jax.random.PRNGKey(42)

# Rows per mini-batch and the number of optimisation steps.
batch_size = 1_000
training_steps = 10_001

In [None]:
# Draw one fixed random ordering of the training rows; every mini-batch is a
# consecutive window of this ordering.
key, shuffle_key = jax.random.split(key)
shuffled_inds = jax.random.permutation(shuffle_key, jnp.arange(len(tr_x)))


def get_batch_inds(i):
    """Return the training-row indices of the i-th mini-batch.

    The window starts at `i * batch_size` within the shuffled ordering and wraps
    around to the beginning when it runs past the last row.
    """
    num_rows = len(tr_x)
    first = (i * batch_size) % num_rows
    window = jnp.arange(first, first + batch_size) % num_rows
    return shuffled_inds[window]

In [None]:
# Turn `get_net` into an init/apply pair; `apply` takes no RNG argument.
net = hk.without_apply_rng(hk.transform(get_net))
# Adam optimiser with a learning rate of 1e-3.
opt = optax.adam(learning_rate=1e-3)

In [None]:
@jax.jit
def rmse_loss(params: hk.Params, x: jnp.ndarray, y:jnp.ndarray) -> jnp.ndarray:
    """Root-mean-square error between the network output for `x` and the targets `y`."""
    residuals = net.apply(params, x) - y
    mean_squared_error = jnp.mean(jnp.square(residuals))
    return jnp.sqrt(mean_squared_error)

@jax.jit
def poisson_loss(params: hk.Params, x: jnp.ndarray, y:jnp.ndarray) -> jnp.ndarray:
    """Mean Poisson loss, `pred - y * log(pred)`, of the network on a batch.

    Args:
        params: network parameters.
        x: input batch.
        y: targets for the batch.

    Returns:
        Scalar loss averaged over all elements.
    """
    predictions = net.apply(params, x)
    # The softplus output can underflow to exactly 0, which makes log() return -inf
    # and turns `0 * -inf` into NaN for rows with y == 0. A small epsilon inside the
    # log keeps the loss and its gradient finite.
    eps = 1e-7
    return (predictions - y * jnp.log(predictions + eps)).mean()

# Loss minimised during training; RMSE is only reported as a metric.
loss_fn = poisson_loss

In [None]:
@jax.jit
def update(params: hk.Params, opt_state: optax.OptState, batch_x: jnp.ndarray, batch_y: jnp.ndarray):
    """Run one optimiser step on a mini-batch.

    Returns:
        A tuple `(new_params, new_opt_state, loss)` where `loss` is the value of
        `loss_fn` on the batch evaluated with the parameters passed in.
    """
    loss_and_grad = jax.value_and_grad(loss_fn)
    loss_value, grads = loss_and_grad(params, batch_x, batch_y)
    param_updates, next_opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, param_updates), next_opt_state, loss_value

In [None]:
# Initialise the network parameters from the first mini-batch, then build the
# optimiser state for those parameters.
first_batch = tr_x[get_batch_inds(0)]
params = net.init(key, first_batch)
opt_state = opt.init(params)

In [None]:
# Training loop: one optimiser step per mini-batch. The step count comes from the
# `training_steps` setting instead of a duplicated literal, so changing the
# setting actually changes the run length.
for step in range(training_steps):
    batch_inds = get_batch_inds(step)

    params, opt_state, train_loss = update(params, opt_state, tr_x[batch_inds], tr_y[batch_inds])

    # Every 1000 steps report the training loss, plus RMSE on the current batch
    # and on the full validation split (both computed with the updated parameters).
    if step % 1000 == 0:
        train_rmse = rmse_loss(params, tr_x[batch_inds], tr_y[batch_inds])
        val_rmse = rmse_loss(params, val_x, val_y)
        print(f"[Step {step}] Train loss / Train RMSE / Validate RMSE: "
              f"{train_loss:.3f} / {train_rmse:.3f} / {val_rmse:.3f}")

## Run on test data and submit

In [None]:
# Read the competition test file and keep every column except the first as features.
test_data = pd.read_csv('../input/tabular-playground-series-aug-2021/test.csv')
te_features = test_data.iloc[:, 1:].to_numpy()

# Apply the same training-split statistics that were used for the training features.
te_x = (te_features - tr_mean) / tr_std if rescale_data else te_features

In [None]:
# Run the trained network on the test features.
te_predictions = net.apply(params, te_x)
# Bare expression: the notebook displays the prediction array.
te_predictions

In [None]:
# Use the sample submission as a template and overwrite its target column.
samp_sub = pd.read_csv('../input/tabular-playground-series-aug-2021/sample_submission.csv')
# The network output has a trailing axis of size 1; take column 0 to get one value per row.
samp_sub[target_name] = te_predictions[:, 0]
# Write without the DataFrame index so only the template's columns are saved.
samp_sub.to_csv('simple_nn_haiku_submission.csv', index = False)