# Fitting a simple linear model with optax to get our feet wet
* Consider the following linear model:
    * $y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_n x_n$
* Here $\boldsymbol{\beta} = \{\beta_0, \beta_1, \dots, \beta_n\}$ are the parameters, abbreviated as "params" from here onwards, $\boldsymbol{x} = \{x_1, x_2, \dots, x_n\}$ are the input features, and $y$ is the output.
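As a small worked example with made-up values, the sum above is an intercept plus a dot product. Equivalently, prepending a constant feature $1$ to $\boldsymbol{x}$ folds $\beta_0$ into a single dot product. A minimal JAX sketch (the variable names are illustrative only):

```python
import jax.numpy as jnp

# Hypothetical example values: intercept beta_0 and weights beta_1..beta_3
beta0 = 1.0
betas = jnp.array([2.0, -1.0, 0.5])
x = jnp.array([1.0, 3.0, 4.0])

# y = beta_0 + beta_1 x_1 + beta_2 x_2 + beta_3 x_3
y_sum = beta0 + jnp.sum(betas * x)

# Same model as one dot product: prepend a constant 1 to x
# and beta_0 to the parameter vector
x_aug = jnp.concatenate([jnp.ones(1), x])
beta_full = jnp.concatenate([jnp.array([beta0]), betas])
y_dot = jnp.dot(beta_full, x_aug)

print(y_sum, y_dot)  # → 2.0 2.0  (1 + 2 - 3 + 2)
```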

In [26]:
import jax
import jax.numpy as jnp
import optax
import functools
import matplotlib.pyplot as plt

In [25]:
# check whether a GPU is detected
test = jnp.arange(5)
print(test.device)

cuda:0
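Another way to check for a GPU is to ask JAX directly which devices and default backend it sees, rather than inspecting a single array. A minimal sketch using the standard `jax.devices()`, `jax.default_backend()` and `Array.devices()` calls; the printed device names depend on the machine:

```python
import jax
import jax.numpy as jnp

# Every device JAX can see; on a CUDA machine these are GPU devices
print(jax.devices())

# Name of the default backend, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()
print(backend)

# New arrays are placed on the default device unless told otherwise
x = jnp.arange(5)
print(x.devices())
```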


## Example function using functools partial and vmap
* `functools.partial` presets some of the arguments of `jax.vmap`; here the `in_axes` argument is fixed to `(None, 0)`
* `jax.vmap()` has other arguments as well (e.g. `out_axes`), but they keep their defaults; only `in_axes` is "frozen"
* The resulting partial is then applied as a decorator to `network`, which takes `params` and `x`. With `in_axes=(None, 0)`, `params` is shared across the batch (not mapped) while `x` is mapped over its first axis, so `network` returns one dot product per row of `x`
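To make the decorator mechanics concrete, the sketch below applies the partial by hand (which is all the `@` syntax does) and compares it with a direct `jax.vmap` call. The helper names `dot`, `batch_over_x` and `network_explicit` are hypothetical, introduced only for this illustration:

```python
import functools
import jax
import jax.numpy as jnp

def dot(params, x):
    # un-batched model: params (n,) and one example x (n,) -> scalar
    return jnp.dot(params, x)

# functools.partial freezes in_axes and returns a callable
# that still expects the function to vectorize
batch_over_x = functools.partial(jax.vmap, in_axes=(None, 0))

# Applying it by hand is exactly what the @ decorator syntax does
network = batch_over_x(dot)

# Equivalent direct call without functools.partial
network_explicit = jax.vmap(dot, in_axes=(None, 0))

params = jnp.array([0.5, -1.0])
xs = jnp.array([[1.0, 2.0],
                [3.0, 4.0],
                [5.0, 6.0]])

# params is not mapped (None); xs is mapped over its rows (axis 0)
print(network(params, xs))  # → [-1.5 -2.5 -3.5]
print(network_explicit(params, xs))
print(xs @ params)  # the same batched dot product written as a matmul
```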

In [17]:
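# in_axes=(None, 0): params is shared across the batch, x is mapped over its first axis
@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
    # prediction for a single example: params . x (no intercept term)
    return jnp.dot(params, x)

def compute_loss(params, x, y):
    # batched predictions, one per row of x
    y_pred = network(params, x)
    # elementwise squared-error loss from optax, averaged over the batch
    loss = jnp.mean(optax.l2_loss(y_pred, y))
    return loss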

# Data
* We generate data from a known linear model with no intercept ($\beta_0 = 0$) and target params $\beta_1 = \beta_2 = 0.5$

In [38]:
key = jax.random.PRNGKey(42)
# the optax tutorial uses the scalar target_params = 0.5, which is a bit confusing;
# here we define an explicit vector with one weight per feature instead
target_params = jnp.array([0.5, 0.5])
# generate 16 samples with 2 features each, and the matching outputs
xs = jax.random.normal(key, (16, 2))
test = xs * target_params
ys = jnp.sum(xs * target_params, axis=-1)
test.shape

(16, 2)
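The scalar target used in the optax tutorial and the vector used here generate the same data, because the scalar `0.5` broadcasts against every column of `xs`. A small sketch checking this with the same key and shapes as above:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
xs = jax.random.normal(key, (16, 2))

# Scalar target, as in the optax tutorial: 0.5 broadcasts to both columns
ys_scalar = jnp.sum(xs * 0.5, axis=-1)

# Vector target, as in this notebook: one weight per feature
target_params = jnp.array([0.5, 0.5])
ys_vector = jnp.sum(xs * target_params, axis=-1)

# Summing over the last axis collapses (16, 2) to one output per sample
print(ys_vector.shape)  # → (16,)
print(jnp.allclose(ys_scalar, ys_vector))
```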

# Define optimizer
* Choose a learning rate, here for the Adam optimizer
* Construct the optimizer object
* Initialize the params and the optimizer state

In [33]:
lr=1e-1
optimizer = optax.adam(lr)

# initialize the params; random values would also work, but here we start from zeros
params = jnp.array([0.0,0.0])
opt_state = optimizer.init(params)

In [34]:
xs.shape

(16, 2)

In [35]:
ys.shape

(16,)

In [36]:
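# forward pass with the initial params; since they are all zeros, every prediction is 0
result = network(params, xs)
result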

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)