# Linear Regression - with and without JAX

### Aims of the notebook
* code a model and perform gradient-based learning via gradient descent
* do it both with NumPy and JAX

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Learning as gradient-based optimization

In a typical model-based analysis, you have a model with parameters $\theta$.

You also have a dataset consisting of conditions $X$ and observed behavior $Y$.

The model provides a `loss` $l(X, Y, \theta)$ which quantifies the mismatch between the model predictions and the data.

To simplify the notation, we drop the dependence on the data, writing $l(\theta) = l(X, Y, \theta)$


Fitting a model can be cast as finding the parameters $\hat{\theta}$ that minimize the loss, i.e. $\hat{\theta} = \arg\min_{\theta} l(X, Y, \theta)$.

Gradient-based optimization is a set of procedures that use the gradient of the loss to guide the search for the optimal parameters $\hat{\theta}$.

The simplest scheme constructs a sequence of parameters $\theta_1, \theta_2, ..., \theta_T$ by 'following the gradient' with a step size (learning rate) $\eta > 0$:
$$ \theta_{t+1} = \theta_{t} - \eta \frac{d l}{d \theta}$$
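
To make the update rule concrete, here is a minimal sketch in plain Python. The helper `gradient_descent` and the toy loss $l(\theta) = (\theta - 3)^2$ are made up for illustration and are not part of the exercises.

```python
def gradient_descent(grad, theta0, eta, n_steps):
    """Iterate theta <- theta - eta * grad(theta) and return the whole trajectory."""
    thetas = [theta0]
    for _ in range(n_steps):
        thetas.append(thetas[-1] - eta * grad(thetas[-1]))
    return thetas


# Hypothetical toy loss l(theta) = (theta - 3)^2, whose gradient is 2 (theta - 3)
def toy_grad(theta):
    return 2.0 * (theta - 3.0)


trajectory = gradient_descent(toy_grad, theta0=0.0, eta=0.1, n_steps=50)
print(trajectory[0], trajectory[-1])  # starts at 0, ends close to the minimizer 3
```

Each step shrinks the distance to the minimizer by a constant factor here, which is why a few dozen iterations are enough for this toy loss.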


### Example of Linear Regression

Introductory example: linear regression $y = 2 x + \epsilon$, with $\epsilon \sim {\cal N}(0, 0.5^2)$ (the noise level used in the code below)

You have 
* data: $\{X, Y\}$ 
* a parametric model: $\rho(x) = w x$ (here $\theta=w$)
* a loss: $ l(w) = \frac{1}{N} \sum_{n=1}^{N} (y_n - w x_n)^2$

In this simple example, the gradient of the loss with respect to $w$ is available in closed form:
$$\frac{d l}{d w} = -\frac{2}{N}  \sum_{n=1}^{N} x_n (y_n - w x_n)$$

This gradient tells you how the loss $l(w)$ changes as you vary $w$.

A typical gradient based optimization scheme 'follows the gradient'
$$ w_{t+1} = w_{t} - \eta \frac{d l}{d w}$$
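
A quick way to gain confidence in a hand-derived gradient is to compare it with a finite-difference estimate. The sketch below does this for the mean-squared loss above; the names `mse_loss` and `mse_grad`, the seed and the step `h` are illustrative assumptions.

```python
import numpy as np

# Small synthetic dataset following y = 2 x + noise (seed chosen arbitrarily)
rng = np.random.default_rng(0)
x = rng.random(100)
y = 2.0 * x + 0.5 * rng.standard_normal(100)


def mse_loss(w):
    # l(w) = 1/N sum_n (y_n - w x_n)^2
    return np.mean((y - w * x) ** 2)


def mse_grad(w):
    # dl/dw = -2/N sum_n x_n (y_n - w x_n)
    return -2.0 * np.mean(x * (y - w * x))


w, h = -1.0, 1e-5
# Central finite difference: (l(w + h) - l(w - h)) / (2 h)
numeric_grad = (mse_loss(w + h) - mse_loss(w - h)) / (2 * h)
print(mse_grad(w), numeric_grad)  # the two values should agree closely
```

The gradient is negative at $w = -1$, so the update rule moves $w$ upward, toward the true slope.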


# Linear Regression from scratch

In [None]:
# Example: Linear regression with manual gradients

# Here is a toy dataset (X and Y) defined by Y = 2 X + noise
N = 100
X_np = np.random.rand(N, 1) 

# for linear regression
W_true = 2.
P_np = W_true * X_np 
Y_np = W_true * X_np + np.random.randn(N, 1) * .5

# plot the data set (Y as a function of X)
# your code here ...

# plot the optimal prediction (2 X as a function of X) for a grid of X values
# your code here ...

Let's now build the model and the loss

In [None]:
# Let's build the model

# linear model
def prediction(W, X):
    """
    Prediction rho(X,W) = W X
    :param W: weights, [batch_shape] + (1,)
    :param X: inputs, with shape (N, 1)
    :return: the predictor, with shape [batch_shape] + (N,)
    """
    # code here the linear predictor p =  X * W
    # your code here ...

# loss
def loss(W, X, Y):
    """
    the loss for linear regression l(W) = 1/N sum_n ||rho(X_n,W) - Y_n||^2
    :param W: weights, [batch_shape] + (1,)
    :param X: inputs, with shape (N, 1)
    :param Y: outputs, with shape (N, 1)
    :return: the loss, with shape [batch_shape] + ()
    """
    # code here the loss function  l = 1/N sum_n (X_n*W - Y_n)^2
    # your code here ...
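
If you get stuck on the shapes, here is one possible sketch that relies on NumPy broadcasting. The names `prediction_sketch` and `loss_sketch` are hypothetical, and the loss is averaged over data points as in the formula for $l(w)$ above; treat it as one option rather than the reference solution.

```python
import numpy as np


def prediction_sketch(W, X):
    # W: [batch_shape] + (1,), X: (N, 1)
    # X[:, 0] has shape (N,), so the product broadcasts to [batch_shape] + (N,)
    return W * X[:, 0]


def loss_sketch(W, X, Y):
    # Squared residuals averaged over the data axis (last axis) -> [batch_shape]
    residual = prediction_sketch(W, X) - Y[:, 0]
    return np.mean(residual ** 2, axis=-1)


# Tiny noiseless dataset, used only to check shapes and values
X_demo = np.array([[0.0], [1.0], [2.0]])
Y_demo = 2.0 * X_demo
W_batch = np.array([[0.0], [2.0]])  # a batch of two candidate weights

print(prediction_sketch(W_batch, X_demo).shape)  # (2, 3)
print(loss_sketch(W_batch, X_demo, Y_demo))      # first entry 20/3, second entry 0
```

Keeping the weight axis first and the data axis last means the same functions work for a single weight of shape `(1,)` and for a whole grid of weights.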


Let's plot the loss as a function of parameters

In [None]:
# create a grid of parameter values on the interval [-5, 5]
# your code here ...

# compute the loss for this grid of parameter values
# your code here ...

# plot the loss l(w) as a function of w
# your code here ...

# can you eyeball the parameter value minimizing the loss?

__Question__: Can you eyeball the parameter value minimizing the loss?
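
To check your eyeballed answer numerically, you can compare the best value on the grid with the closed-form least-squares solution $\hat{w} = \sum_n x_n y_n / \sum_n x_n^2$. The sketch below regenerates a dataset with an arbitrary seed and an assumed grid of 201 points, so the numbers will differ slightly from yours.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((100, 1))
Y = 2.0 * X + 0.5 * rng.standard_normal((100, 1))

# Grid of candidate weights on [-5, 5], shaped (201, 1) so it broadcasts against the data
W_grid = np.linspace(-5.0, 5.0, 201)[:, None]
losses = np.mean((W_grid * X[:, 0] - Y[:, 0]) ** 2, axis=-1)  # shape (201,)

w_best_on_grid = W_grid[np.argmin(losses), 0]
w_closed_form = np.sum(X * Y) / np.sum(X ** 2)
print(w_best_on_grid, w_closed_form)  # the grid minimum sits next to the closed-form value
```

Because the loss is a parabola in $w$, the grid minimizer is simply the grid point closest to $\hat{w}$.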

Now let's code up the gradient of the loss

In [None]:
# loss gradient (manually coded)
def grad_loss(W, X, Y):
    """
    gradient of the loss for linear regression, with respect to W
    :param W: weights, with shape [batch_shape] + (1,)
    :param X: inputs, with shape (N, 1)
    :param Y: outputs, with shape (N, 1)
    :return: gradient of the loss, with shape [batch_shape] + (1,)
    """
    # code here grad = - 2/N sum_n X_n(Y_n - W X_n)
    # your code here ...

# parameter update function
def update(W, X, Y, learning_rate):
    """
    A step of gradient descent
    :param W: weights, with shape [batch_shape] + (1,)
    :param X: inputs, with shape (N, 1)
    :param Y: outputs, with shape (N, 1)
    :param learning_rate: the scalar learning rate
    :return: the updated weights following the gradient descent update rule, with shape [batch_shape] + (1,)
    """
    # your code here ...

# gradient descent

# we choose a learning rate
learning_rate = 1e-1 # the speed at which you follow the gradient
# we choose an initial guess
W0 = -1. # initial guess 


# Code here the script to follow the gradient 100 times
# (storing the intermediate parameter and loss values)


# plot the parameter trajectory (values as a function of the iteration)
# plot the loss as a function of the iteration


__Remark__: the choice of the optimization procedure is important
    
__Question__: What happens if you change the `learning_rate`? Try `0.1`, or `0.0001`
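
To explore this question, here is a minimal sketch that runs 100 steps of gradient descent for the two learning rates mentioned above. It uses a scalar weight and hypothetical helpers `grad_sketch` and `run_gd`; the dataset is regenerated with an arbitrary seed.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((100, 1))
Y = 2.0 * X + 0.5 * rng.standard_normal((100, 1))


def grad_sketch(w, X, Y):
    # dl/dw = -2/N sum_n x_n (y_n - w x_n), for a scalar weight w
    return -2.0 * np.mean(X * (Y - w * X))


def run_gd(w0, X, Y, learning_rate, n_steps=100):
    # Store every intermediate weight so the trajectory can be plotted afterwards
    ws = [w0]
    for _ in range(n_steps):
        ws.append(ws[-1] - learning_rate * grad_sketch(ws[-1], X, Y))
    return np.array(ws)


for lr in (1e-1, 1e-4):
    ws = run_gd(-1.0, X, Y, lr)
    print(f"learning_rate={lr}: final w = {ws[-1]:.4f}")
```

With the larger rate the weight should end up close to the least-squares solution, while with the smaller rate it barely moves away from the initial guess in 100 steps.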


# Linear Regression using JAX

JAX was designed to stay close to NumPy.
As a result, the code is going to be very similar.

You can almost always replace `np.some_operation`
by its JAX equivalent `jnp.some_operation`.

The main advantage of JAX for our purpose is that it can compute the gradient for you
via the `jax.grad` operation.
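
As a small illustration of `jax.grad`, the sketch below differentiates a mean-squared loss and compares the result with the manual formula. The function name `loss_scalar` and the tiny noiseless dataset are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def loss_scalar(w, X, Y):
    # Mean-squared error for a scalar weight w
    return jnp.mean((w * X - Y) ** 2)


# jax.grad differentiates with respect to the first argument by default
grad_fn = jax.grad(loss_scalar)

X = jnp.array([[0.0], [1.0], [2.0]])
Y = 2.0 * X
w = -1.0

auto_grad = grad_fn(w, X, Y)
manual_grad = -2.0 * jnp.mean(X * (Y - w * X))
print(auto_grad, manual_grad)  # both are -10 for this dataset (up to float rounding)
```

Note that `jax.grad` expects the differentiated function to return a scalar, which is why the loss here takes a scalar weight.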

In [None]:
def loss_jax(W, X, Y):
    """
    the loss for linear regression l(W) = 1/N sum_n ||rho(X_n,W) - Y_n||^2
    :param W: weights, scalar
    :param X: inputs, with shape (N, 1)
    :param Y: outputs, with shape (N, 1)
    :return: the loss, scalar
    """
    # code here the loss function  l = 1/N sum_n (X_n*W - Y_n)^2

def grad_loss_jax(W, X, Y):
    """
    gradient of the loss for linear regression, with respect to W
    :param W: weights, scalar
    :param X: inputs, with shape (N, 1)
    :param Y: outputs, with shape (N, 1)
    :return: gradient of the loss, scalar
    """
    # no need to do the gradient manually! use jax.grad

@jax.jit
def update_jax(W, X, Y, learning_rate):
    """
    A step of gradient descent
    :param W: weights, scalar
    :param X: inputs, with shape (N, 1)
    :param Y: outputs, with shape (N, 1)
    :param learning_rate: the scalar learning rate
    :return: the updated weights following the gradient descent update rule, scalar
    """    
    
# Note: all these functions work for scalar weights; to evaluate the loss on a grid of weights,
# you may use the `jax.vmap` function
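
For instance, a scalar-weight loss can be mapped over a grid of weights as sketched below. `loss_scalar` is a hypothetical stand-in for your own `loss_jax`, and the 11-point grid and tiny dataset are arbitrary choices.

```python
import jax
import jax.numpy as jnp


def loss_scalar(w, X, Y):
    # Mean-squared error for a single scalar weight
    return jnp.mean((w * X - Y) ** 2)


# Map over axis 0 of the weights; X and Y are shared by all evaluations (in_axes=None)
loss_on_grid = jax.vmap(loss_scalar, in_axes=(0, None, None))

X = jnp.array([[0.0], [1.0], [2.0]])
Y = 2.0 * X
W_grid = jnp.linspace(-5.0, 5.0, 11)  # -5, -4, ..., 5

losses = loss_on_grid(W_grid, X, Y)
print(losses.shape)                 # (11,)
print(W_grid[jnp.argmin(losses)])   # the grid point w = 2 minimizes this noiseless loss
```
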

Now let's write the main training loop

In [None]:
# code up the main training loop, store the intermediate parameter and loss values to plot their evolution in time


# plot the parameter trajectory (values as a function of the iteration)
# plot the loss as a function of the iteration
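
One possible shape for such a loop is sketched below, leaving the plotting to you. The names `loss_scalar` and `update_sketch`, the random keys and the dataset are illustrative assumptions; your own `loss_jax` and `update_jax` would take their place.

```python
import jax
import jax.numpy as jnp


def loss_scalar(w, X, Y):
    # Mean-squared error for a scalar weight
    return jnp.mean((w * X - Y) ** 2)


@jax.jit
def update_sketch(w, X, Y, learning_rate):
    # One gradient descent step, with the gradient obtained from jax.grad
    return w - learning_rate * jax.grad(loss_scalar)(w, X, Y)


# Toy dataset y = 2 x + noise, generated with JAX's explicit random keys
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.uniform(key_x, (100, 1))
Y = 2.0 * X + 0.5 * jax.random.normal(key_noise, (100, 1))

w = -1.0  # initial guess
ws = [w]
losses = [float(loss_scalar(w, X, Y))]
for _ in range(100):
    w = update_sketch(w, X, Y, 0.1)
    ws.append(float(w))
    losses.append(float(loss_scalar(w, X, Y)))

print(ws[-1], losses[0], losses[-1])  # the loss should decrease over the iterations
```

The lists `ws` and `losses` can then be passed directly to `plt.plot` to show the trajectory and the loss as a function of the iteration.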
