In [3]:
!python3 -m pip install jax
!pip install git+https://github.com/deepmind/dm-haiku

Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-mdkbn96m
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-mdkbn96m
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
  Created wheel for dm-haiku: filename=dm_haiku-0.0.7.dev0-py3-none-any.whl size=576219 sha256=76f4579b3d7f8b4daea30a80355bcbc842e8bc0d010b760ef30a55dd3f3563cd
  Stored in directory: /tmp/pip-ephem-wheel-cache-emyscybh/wheels/06/28/69/ebaac5b2435641427299f29d88d005fb4e2627f4a108f0bdbc
Successfully built dm-haiku
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.7.dev0 jmp-0.0.2


In [4]:
import jax
import jax.numpy as jnp
import haiku as hk

Load Data

In [5]:
from sklearn import datasets
from sklearn.model_selection import train_test_split

def _load_boston_xy():
    """Return the Boston housing data as ``(features, targets)`` arrays.

    ``datasets.load_boston`` was removed in scikit-learn 1.2, so on newer
    versions we fall back to the original CMU StatLib source. The fallback uses
    the recipe from scikit-learn's own deprecation message and needs network
    access. NOTE: scikit-learn's maintainers discourage this dataset outside of
    teaching about its ethical issues.
    """
    try:
        return datasets.load_boston(return_X_y=True)
    except (ImportError, AttributeError):
        # sklearn>=1.2 raises ImportError on access; catch AttributeError too
        # in case the shim is dropped entirely in a later release.
        import numpy as np
        import pandas as pd

        data_url = "http://lib.stat.cmu.edu/datasets/boston"
        raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
        # Each record spans two physical lines in the raw file.
        data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
        target = raw_df.values[1::2, 2]
        return data, target


X, Y = _load_boston_xy()

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

# Convert every split to a float32 JAX array.
X_train, X_test, Y_train, Y_test = (
    jnp.array(split, dtype=jnp.float32) for split in (X_train, X_test, Y_train, Y_test)
)

samples, features = X_train.shape

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape


    The Boston housing prices dataset has an ethical problem. You can refer to
    the documentation of this function for further details.

    The scikit-learn maintainers therefore strongly discourage the use of this
    dataset unless the purpose of the code is to study and educate about
    ethical issues in data science and machine learning.

    In this special case, you can fetch the dataset from the original
    source::

        import pandas as pd
        import numpy as np


        data_url = "http://lib.stat.cmu.edu/datasets/boston"
        raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
        data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
        target = raw_df.values[1::2, 2]

    Alternative datasets include the California housing dataset (i.e.
    :func:`~sklearn.datasets.fetch_california_housing`) and the Ames housing
    dataset. You can load the datasets as follows::

        from sklearn.datasets import fetch_california_h

((404, 13), (102, 13), (404,), (102,))

Normalize Data

In [6]:
# Standardise features with statistics computed on the training split only,
# so nothing from the test split leaks into preprocessing.
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
# Constant (zero-variance) columns would otherwise divide by zero and produce
# NaN/inf. Leave them centred but unscaled, as scikit-learn's StandardScaler does.
std = jnp.where(std > 0, std, 1.0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

Define Neural Network

In [7]:
def FeedForward(x):
  """Haiku forward function: an MLP with layer widths 5, 10, 15 and a single output unit.

  Haiku modules may only be built inside a transformed function, so this is
  wrapped with `hk.transform` right below.
  """
  layer_widths = (5, 10, 15, 1)
  return hk.nets.MLP(output_sizes=layer_widths)(x)

model = hk.transform(FeedForward)

# Sanity-check the untrained network: initialise parameters, then run one
# forward pass over the whole training set.
rng = jax.random.PRNGKey(seed=42)
# `init` only needs an example batch to infer parameter shapes; five rows suffice.
params = model.init(rng, X_train[0:5])
preds = model.apply(params, rng, X_train)
preds[0:5]

DeviceArray([[-0.7874111 ],
             [-0.2776872 ],
             [-0.01174069],
             [-0.01407542],
             [-0.38728935]], dtype=float32)

Define Loss Function

In [8]:
def MeanSquaredErrorLoss(weights, input_data, actual):
  """Mean squared error of the global `model`'s predictions against `actual`.

  Args:
    weights: Parameters passed to `model.apply` (together with the global `rng`).
    input_data: Features fed to the model.
    actual: Targets, compared with the model output after `.squeeze()`.

  Returns:
    The mean of the squared residuals.
  """
  residuals = actual - model.apply(weights, rng, input_data).squeeze()
  return jnp.mean(residuals ** 2)

Train Neural Network

In [9]:
def UpdateWeights(weights, gradients):
    """Apply one plain gradient-descent step to a single parameter leaf.

    The step size is read from the global `learning_rate` at call time.
    """
    step = learning_rate * gradients
    return weights - step

In [10]:
from jax import value_and_grad

# Full-batch gradient descent on the training split.
rng = jax.random.PRNGKey(42)  # Reproducibility: re-initialise with the same weights on every run.

params = model.init(rng, X_train[:5])
epochs = 1000
learning_rate = jnp.array(0.001)

# Build (and compile) the loss-and-gradient function once, rather than
# re-wrapping MeanSquaredErrorLoss with value_and_grad on every epoch.
loss_and_grads = jax.jit(value_and_grad(MeanSquaredErrorLoss))

for i in range(1, epochs + 1):
    loss, param_grads = loss_and_grads(params, X_train, Y_train)
    # `jax.tree_map` is deprecated and removed in newer JAX releases;
    # `jax.tree_util.tree_map` is the stable spelling.
    params = jax.tree_util.tree_map(UpdateWeights, params, param_grads)

    if i % 100 == 0:
        print("MSE : {:.2f}".format(loss))

MSE : 17.04
MSE : 12.32
MSE : 10.59
MSE : 9.66
MSE : 9.09
MSE : 8.76
MSE : 8.55
MSE : 8.38
MSE : 8.23
MSE : 8.13


Make Predictions

In [11]:
# Predictions of the trained network on the (normalised) training features.
train_preds = model.apply(params, rng, X_train)
train_preds[:5, :]

DeviceArray([[48.474415],
             [11.665481],
             [21.027842],
             [26.18423 ],
             [15.279279]], dtype=float32)

In [12]:
# Predictions of the trained network on the held-out test features.
test_preds = model.apply(params, rng, X_test)
test_preds[:5, :]

DeviceArray([[20.905386],
             [25.025269],
             [44.169964],
             [21.290577],
             [29.036097]], dtype=float32)

In [13]:
# Report the final loss on both splits (test first, then train).
score_reports = (
    ("Test  MSE Score : {:.2f}", X_test, Y_test),
    ("Train MSE Score : {:.2f}", X_train, Y_train),
)
for template, split_inputs, split_targets in score_reports:
    print(template.format(MeanSquaredErrorLoss(params, split_inputs, split_targets)))

Test  MSE Score : 18.38
Train MSE Score : 8.13
