In [None]:
!pip install --upgrade coax

In [None]:
pip install jaxlib jax --upgrade

In [1]:
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split


# create our dataset: a synthetic regression problem with 3 input features
X, y = make_regression(n_features=3)
# NOTE: X and y are rebound here to the training portion of the split;
# X_test / y_test hold the held-out portion used for evaluation below.
# No random_state is passed, so the data (and the losses printed during
# training) will differ from run to run.
X, X_test, y, y_test = train_test_split(X, y)


# linear-model parameters: one weight per feature column of X plus a scalar
# bias, all starting at zero
params = dict(
    w=jnp.zeros(X.shape[1:]),
    b=0.,
)


def forward(params, X):
    """Return the linear model's predictions for the inputs ``X``.

    ``params`` is a mapping with a weight entry ``'w'`` and a bias entry
    ``'b'``; the result is ``dot(X, w) + b``.
    """
    weights, bias = params['w'], params['b']
    projected = jnp.dot(X, weights)
    return projected + bias


def loss_fn(params, X, y):
    """Return the mean squared error of the model's predictions on ``X``.

    The predictions come from ``forward(params, X)`` and are compared
    against the targets ``y``.
    """
    residual = forward(params, X) - y
    squared = jnp.square(residual)
    return jnp.mean(squared)  # mse


# grad_fn(params, X, y) differentiates loss_fn with respect to its first
# argument (jax.grad's default), i.e. the params
grad_fn = jax.grad(loss_fn)


def update(params, grads, learning_rate=0.05):
    """Apply one step of plain gradient descent to every parameter.

    Args:
        params: pytree of parameters (here a dict with ``'w'`` and ``'b'``).
        grads: pytree with the same structure as ``params`` holding the
            gradient of the loss with respect to each parameter.
        learning_rate: step size; defaults to 0.05, the value that was
            previously hard-coded.

    Returns:
        A new pytree with the same structure as ``params`` in which every
        leaf ``p`` is replaced by ``p - learning_rate * g``.
    """
    # jax.tree_map is deprecated (and removed in newer JAX releases);
    # jax.tree_util.tree_map is the stable spelling of the same function.
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )


# the main training loop: 50 full-batch gradient-descent steps
for _ in range(50):
    # evaluate on the held-out split *before* this step's update, so the
    # first printed value is the loss of the all-zero initial parameters
    loss = loss_fn(params, X_test, y_test)
    print(loss)

    # gradients are computed on the training split only
    grads = grad_fn(params, X, y)
    params = update(params, grads)


3929.5195


  return jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)


3270.893
2729.2104
2282.706
1913.8132
1608.3345
1354.7747
1143.8125
967.8746
820.79816
697.55756
594.0473
506.90674
433.37927
371.19897
318.49924
273.73923
235.6438
203.15565
175.39609
151.63278
131.25436
113.74911
98.68779
85.709335
74.50959
64.83157
56.457832
49.20384
42.912804
37.45106
32.70467
28.576141
24.981916
21.850405
19.120012
16.737732
14.657859
12.8409395
11.252851
9.864106
8.649089
7.585652
6.654494
5.8388834
5.124228
4.4978404
3.9486606
3.4670675
3.0446293
