# Fitting a simple linear model with optax to get our feet wet
* Consider the following linear model:
    * $y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_n x_n$
* Here $\boldsymbol{\beta} = \{\beta_0, \beta_1, \dots, \beta_n\}$ are the parameters, abbreviated as "params" from here onwards, $\boldsymbol{x} = \{x_1, x_2, \dots, x_n\}$ are the input features, and $y$ is the output. The code in this section drops the intercept $\beta_0$ and fits only the weights on the features.
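
As a minimal sketch, the linear model above can be evaluated for a single input vector as follows. The function name linear_model, the variable names beta0 and beta, and the numeric values are illustrative assumptions, not part of the original notebook.

```python
import jax.numpy as jnp

def linear_model(beta0, beta, x):
    # y = beta0 + beta_1 * x_1 + ... + beta_n * x_n for ONE input vector x
    return beta0 + jnp.dot(beta, x)

# Hypothetical values chosen so the result is easy to check by hand
beta0 = 1.0
beta = jnp.array([0.5, -2.0, 3.0])
x = jnp.array([2.0, 1.0, 0.5])

y = linear_model(beta0, beta, x)
print(y)  # 1.0 + (1.0 - 2.0 + 1.5) = 1.5
```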

In [2]:
import jax
import jax.numpy as jnp
import optax
import functools
import matplotlib.pyplot as plt

In [3]:
# Check whether a GPU is detected
test = jnp.arange(5)
print(test.device)

cuda:0


## Example function using functools.partial and vmap
* functools.partial pre-fills some of the arguments of jax.vmap(); here only the "in_axes" keyword argument is passed through the partial decorator
* jax.vmap() accepts other arguments as well, but they keep their defaults here; only "in_axes" is "frozen"
* in_axes=(None, 0) means that params is not mapped (the same params are shared by every example), while x is mapped over its leading (batch) axis
* The resulting decorator is then applied to the function network, which takes params and x

In [4]:
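@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
    # Prediction for a single example; vmap maps it over the batch axis of x
    return jnp.dot(params, x)

def compute_loss(params, x, y):
    y_pred = network(params, x)
    # optax.l2_loss returns 0.5 * (y_pred - y)**2 per example
    loss = jnp.mean(optax.l2_loss(y_pred, y))
    return loss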
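
To make the decorator pattern concrete, the sketch below builds the same kind of batched function in two ways: through functools.partial and through a direct call to jax.vmap. The names single_example, batched_a and batched_b, and the small input arrays, are illustrative assumptions.

```python
import functools
import jax
import jax.numpy as jnp

def single_example(params, x):
    # Prediction for ONE input vector x of shape (n_features,)
    return jnp.dot(params, x)

# Form 1: partial fixes in_axes first, then the result is applied to the function
batched_a = functools.partial(jax.vmap, in_axes=(None, 0))(single_example)

# Form 2: the equivalent direct call to jax.vmap
batched_b = jax.vmap(single_example, in_axes=(None, 0))

params = jnp.array([1.0, 2.0])           # shared by every example (in_axes None)
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])             # mapped over axis 0 (in_axes 0)

preds_a = batched_a(params, xs)
preds_b = batched_b(params, xs)
print(preds_a)  # [1. 2. 3.]
print(preds_a.shape)  # (3,)
```

Both forms return one prediction per row of xs, so the decorator is only a more compact way of writing the second form.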

# Data
* We generate data from a known linear model whose target params are all 0.5 (one weight per input feature)

In [10]:
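key = jax.random.PRNGKey(42)
# The optax tutorial uses a scalar target (target_params = 0.5), which is a bit confusing.
# Here we define a vector instead, with one weight per feature.
target_params = jnp.array([0.5, 0.5])
# Generate the data: 16 samples with 2 features each
xs = jax.random.normal(key, (16, 2))
# Per-feature contributions, shape (16, 2); displayed below for inspection
test = xs * target_params
# Targets: sum the contributions over the feature axis, shape (16,)
ys = jnp.sum(xs * target_params, axis=-1)
test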

Array([[-1.0100551e+00,  4.3174960e-03],
       [-1.0414395e+00,  4.7344890e-01],
       [-3.7348837e-02,  1.0950059e-01],
       [ 8.7003446e-01,  7.2180462e-01],
       [ 8.4831733e-01, -4.9240713e-04],
       [ 9.9366140e-01, -6.8150443e-01],
       [-1.5684669e-01, -2.3161867e-01],
       [ 8.7166107e-01, -4.8429218e-01],
       [ 2.6437920e-01, -5.8230702e-03],
       [ 1.6898832e-01,  4.8616579e-01],
       [ 3.2079124e-01,  3.6636621e-01],
       [ 7.6128572e-01,  4.2890865e-01],
       [ 1.3108808e-01,  1.4217469e+00],
       [ 2.8882191e-01,  2.5156605e-01],
       [ 1.9488384e-01, -6.8622433e-02],
       [-8.5813260e-01, -4.6977270e-01]], dtype=float32)

In [11]:
xs

Array([[-2.0201101e+00,  8.6349919e-03],
       [-2.0828791e+00,  9.4689780e-01],
       [-7.4697673e-02,  2.1900117e-01],
       [ 1.7400689e+00,  1.4436092e+00],
       [ 1.6966347e+00, -9.8481425e-04],
       [ 1.9873228e+00, -1.3630089e+00],
       [-3.1369337e-01, -4.6323735e-01],
       [ 1.7433221e+00, -9.6858436e-01],
       [ 5.2875841e-01, -1.1646140e-02],
       [ 3.3797663e-01,  9.7233158e-01],
       [ 6.4158249e-01,  7.3273242e-01],
       [ 1.5225714e+00,  8.5781729e-01],
       [ 2.6217616e-01,  2.8434937e+00],
       [ 5.7764381e-01,  5.0313210e-01],
       [ 3.8976768e-01, -1.3724487e-01],
       [-1.7162652e+00, -9.3954539e-01]], dtype=float32)
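
The shapes involved in the data generation can be checked with a short sketch. It repeats the same key, sample count and target params as the data cell; the names contributions and first_by_hand are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
target_params = jnp.array([0.5, 0.5])
xs = jax.random.normal(key, (16, 2))

# Broadcasting: (16, 2) * (2,) multiplies every row of xs by target_params
contributions = xs * target_params
print(contributions.shape)  # (16, 2)

# Summing over the last axis gives one target value per sample
ys = jnp.sum(contributions, axis=-1)
print(ys.shape)  # (16,)

# The first target written out by hand: 0.5 * x_1 + 0.5 * x_2
first_by_hand = 0.5 * xs[0, 0] + 0.5 * xs[0, 1]
print(jnp.allclose(ys[0], first_by_hand))
```

Mathematically this is the matrix-vector product of xs with target_params, which is also what the vmapped network computes for a given params vector.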

# Define optimizer
* Define the learning rate for the optimizer (Adam in this example)
* Construct the optimizer object
* Initialize the params and the optimizer state

In [6]:
lr = 1e-1
optimizer = optax.adam(lr)

# Initialize the params; random values would also work, here we go with zeros
params = jnp.array([0.0, 0.0])
opt_state = optimizer.init(params)