# Logistic regression with JAX

In [None]:
import numpy as np
import jax
from jax import vmap
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Learning as gradient-based optimization

##  Logistic Regression

In logistic regression, we assume that the data is generated as follows:
\begin{align}
\phi(x) &= wx + b\\
y &\sim Bernoulli( \sigma( \phi(x) ))\\
\end{align}

where $\sigma(x) = \frac{1}{1+e^{-x}}$ is the sigmoid function.

A sample from $Bernoulli(p)$ is a biased coin flip that comes up heads (=1) with probability $p$.

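You have
* data: $\{X, Y\} = \{(x_n, y_n)\}_{n=1}^{N}$
* a linear predictor: $\phi(x) = w x + b$ (the parameters are the weight $w$ and the bias $b$)
* leading to the predicted probability of class 1: $p(y=1|x) = \sigma(\phi(x))$
* a loss, the average negative log likelihood (binary cross-entropy): $l(w, b) = -\frac{1}{N}\sum_n \left[ y_n \log p(y=1|x_n) + (1-y_n)\log\left(1 - p(y=1|x_n)\right) \right]$

Maximizing the likelihood of the data is equivalent to minimizing this loss, which is what gradient descent does. Averaging over the $N$ data points (instead of summing) only rescales the loss. It also keeps the size of the gradient, and therefore a good learning rate, independent of $N$.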

We could compute the gradient manually, but with JAX we don't need to: `jax.grad` computes it for us.



In [None]:
# To be a tiny bit more sophisticated, we'll use a class to hold the parameters

from typing import NamedTuple


class Params(NamedTuple):
    """Container for the parameters of the logistic regression model.

    Create one with ``Params(weight, bias)``. Because it is a NamedTuple,
    JAX treats it as a pytree:
    - ``jax.grad`` of a function of a ``Params`` returns a ``Params`` of gradients.
    - ``jax.tree_util.tree_map`` can combine two ``Params`` field by field.
    """
    weight: jnp.ndarray  # slope w of the linear predictor phi(x) = w*x + b
    bias: jnp.ndarray  # intercept b of the linear predictor

In [None]:
# We create an initialization function for the parameters

def init(rng, scale: float = 3.0) -> Params:
    """Returns randomly initialized model params.

    :param rng: a JAX PRNG key (e.g. from ``jax.random.PRNGKey``); it is split
        internally into one key for the weight and one for the bias
    :param scale: standard deviation of the zero-mean Gaussian from which
        both the weight and the bias are drawn (defaults to 3)
    :return: an instance of the 'Params' class
    """
    weights_key, bias_key = jax.random.split(rng)
    random_weight = jax.random.normal(weights_key, ()) * scale
    random_bias = jax.random.normal(bias_key, ()) * scale
    return Params(random_weight, random_bias)


# we need the sigmoid for the pointwise activations
def sigmoid(x: jnp.ndarray) -> jnp.ndarray:
    """
    The sigmoid (logistic) function x -> 1 / (1 + e^{-x}), applied elementwise.

    Uses ``jax.nn.sigmoid``, which is numerically stable.
    - It saturates cleanly to 0 or 1 for large |x|, and its gradient stays finite.
    - The naive ``1 / (1 + jnp.exp(-x))`` overflows exp(-x) for very negative x,
      and its gradient then becomes NaN.
    """
    return jax.nn.sigmoid(x)


# we compute our prediction of the label being one
def predictions(params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """
    The prediction p(y=1|x) = sigmoid(w * x + b).
    :param params: an instance of the class Params
    :param x: the input; a scalar or an array of inputs (evaluated elementwise)
    :return: probabilities with the same shape as x, in [0, 1]. For large
        |w * x + b| they can round to exactly 0 or 1 in floating point.
    """
    return sigmoid(params.weight * x + params.bias)
    
    
# again we define the loss (the negative log likelihood defined above)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean negative log likelihood (binary cross-entropy) of the
    model's predictions on x against the binary labels y.
    :param params: an instance of the class Params
    :param x: the inputs (a scalar or an array)
    :param y: the binary observations (0 or 1), same shape as x
    :return: the scalar loss
    """
    logits = params.weight * x + params.bias
    # log p = log sigmoid(z) and log(1 - p) = log sigmoid(-z). Using log_sigmoid
    # avoids log(0) = -inf, and the resulting 0 * -inf = NaN, once sigmoid(z)
    # rounds to exactly 0 or 1 in float32 (which happens already for |z| ~ 17).
    log_p = jax.nn.log_sigmoid(logits)
    log_one_minus_p = jax.nn.log_sigmoid(-logits)
    # Average (rather than sum) over data points, so that the gradient scale,
    # and therefore a good learning rate, does not depend on the number of points.
    return -jnp.mean(y * log_p + (1 - y) * log_one_minus_p)


LEARNING_RATE = 0.1

# jax.grad handles any pytree of parameters (here the Params NamedTuple), so the
# gradient comes back with exactly the same structure as the parameters.


@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
    """
    Performs one gradient-descent step on params using the given data.

    θ = θ - η grad l(θ)  (with η the learning rate LEARNING_RATE)

    If (x, y) is a random mini-batch, this is one step of SGD. If it is the
    whole dataset, it is one step of (full-batch) gradient descent.

    Note: LEARNING_RATE is read when ``update`` is traced by ``jax.jit``, so
    changing it afterwards has no effect on already-compiled calls.

    :param params: an instance of the class Params
    :param x: the inputs (a scalar or an array)
    :param y: the binary observations, same shape as x
    :return: an instance of the Params class containing the updated parameters
    """
    grads = jax.grad(loss)(params, x, y)  # a Params holding d loss / d param
    # Apply the step field by field (weight, bias). jax.tree_util.tree_map is
    # the supported spelling of the deprecated alias jax.tree_map.
    return jax.tree_util.tree_map(
        lambda param, grad: param - LEARNING_RATE * grad, params, grads
    )



Let's generate some data!

\begin{align}
x &\sim {\cal N}(0,1)\quad \text{this means x is Gaussian with mean 0 and variance 1}\\
\phi(x) &= wx + b\\
y &\sim Bernoulli( \sigma( \phi(x) ))\\
\end{align}

for some chosen values for $w$ and $b$.

In [None]:
# this is to prepare jax to sample random variable
# see https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb
# for more details
rng = jax.random.PRNGKey(42)
x_rng, noise_rng = jax.random.split(rng)

# Generate data: x ~ N(0, 1), y ~ Bernoulli(p) with p = sigmoid(w*x + b)

# we choose some ground-truth parameters w=2, b=-1
# and create an instance of the Params class for them
true_params = Params(jnp.array(2.0), jnp.array(-1.0))

# we generate 500 data points
num_data = 500
xs = jax.random.normal(x_rng, (num_data,))
# jax.random.bernoulli returns booleans; cast them to float so the labels
# can be used directly in the loss arithmetic
ys = jax.random.bernoulli(noise_rng, predictions(true_params, xs)).astype(jnp.float32)


Let's plot the generated data

In [None]:
# plot y as a function of x, together with the true model
# p(y=1|x) = sigma(phi(x)) for the ground-truth parameters w=2, b=-1
x_grid = jnp.linspace(-4.0, 4.0, 200)
plt.scatter(xs, ys, s=8, alpha=0.3, label="data $(x_n, y_n)$")
plt.plot(
    x_grid,
    predictions(Params(jnp.array(2.0), jnp.array(-1.0)), x_grid),
    color="C1",
    label=r"true $p(y=1|x)$",
)
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()

Now let's train the model

In [None]:
# Train the model with full-batch gradient descent for `num_iterations`
# iterations and store the intermediate values of the parameters and loss

# initializing some arrays for storage
num_iterations = 400
weights = np.zeros((num_iterations,))
biases = np.zeros((num_iterations,))
losses = np.zeros((num_iterations,))

# Use a fresh key. init splits its key, and splitting `rng` again would give
# exactly x_rng and noise_rng, the keys already used to sample the data.
params = init(jax.random.PRNGKey(0))
for it in range(num_iterations):
    # record the state *before* the step, so index 0 is the initialization
    weights[it] = params.weight
    biases[it] = params.bias
    losses[it] = loss(params, xs, ys)
    params = update(params, xs, ys)


Let's plot the learning curve and evolution of the loss

In [None]:
# plot the evolution of the parameters and of the loss during training
# (dashed lines: the ground-truth values w=2, b=-1 used to generate the data)
fig, (ax_w, ax_b, ax_loss) = plt.subplots(1, 3, figsize=(15, 4))
ax_w.plot(weights, label="w")
ax_w.axhline(2.0, color="k", linestyle="--", label="true w")
ax_w.set_xlabel("iteration")
ax_w.legend()
ax_b.plot(biases, label="b")
ax_b.axhline(-1.0, color="k", linestyle="--", label="true b")
ax_b.set_xlabel("iteration")
ax_b.legend()
ax_loss.plot(losses)
ax_loss.set_xlabel("iteration")
ax_loss.set_ylabel("loss")
fig.tight_layout()
plt.show()

## What if there is really a lot of data?


If there is a lot of data (think billions), evaluating the loss on the whole dataset
at every step, as we have done so far, is far too expensive.

Instead, we can evaluate the loss on a randomly sampled subset (a *batch*) of the data.
This gives an unbiased estimator of the loss on the whole dataset
(below, $\theta = (w, b)$ denotes all the parameters):

\begin{align}
\hat{l}(\theta) &= l(\theta;\, x_{batch}, y_{batch})\\
\mathbb{E}[\hat{l}(\theta)] &= l(\theta)
\end{align}

If we replace the gradient of the loss in gradient descent by the gradient of
an estimator of the loss, we get stochastic gradient descent (SGD).



This amounts to taking gradient descent steps on batch approximations of the loss:
\begin{align}
\theta_{t+1} &= \theta_{t} - \eta \frac{d \hat{l}}{d \theta}
\end{align}

In [None]:
# First generate a bit more data (n=5000) from the same ground-truth model
# (w=2, b=-1), with fresh keys so it is independent of the first dataset
num_data = 5000
big_x_rng, big_noise_rng = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(big_x_rng, (num_data,))
ys = jax.random.bernoulli(
    big_noise_rng, predictions(Params(jnp.array(2.0), jnp.array(-1.0)), xs)
).astype(jnp.float32)

# we set the size of the batch
batch_size = 20

# The batching and streaming of data is done for you below

# we compute the number of batches to iterate over (the last one is smaller
# when batch_size does not divide num_data)
num_complete_batches, leftover = divmod(num_data, batch_size)
num_batches = num_complete_batches + bool(leftover)
print('num_batches: ',  num_batches)


# we construct a data stream to easily get access to the batches consecutively
def data_stream(seed: int = 0):
    """Yield (x_batch, y_batch) pairs forever, reshuffling the data every epoch.

    Reads the notebook globals xs, ys, num_data, batch_size and num_batches
    while iterating. Shuffling uses its own seeded NumPy generator rather than
    the global np.random state, so the batch order is reproducible.

    :param seed: seed of the generator used to shuffle the data
    """
    perm_rng = np.random.default_rng(seed)
    while True:
        perm = perm_rng.permutation(num_data)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield xs[batch_idx], ys[batch_idx]


batches = data_stream()


We can now train the model on the batches

In [None]:
 
# Use a fresh key. init splits its key, and splitting `rng` again would give
# exactly x_rng and noise_rng, the keys used for the first dataset.
params = init(jax.random.PRNGKey(0))
weights = np.zeros((num_batches,))
biases = np.zeros((num_batches,))
losses = np.zeros((num_batches,))

# One pass (epoch) over the data: one SGD step per batch. The parameters and
# the batch loss are recorded *before* each step.
for step in range(num_batches):
    x_batch, y_batch = next(batches)
    weights[step] = params.weight
    biases[step] = params.bias
    losses[step] = loss(params, x_batch, y_batch)
    params = update(params, x_batch, y_batch)

# plot the results (evolution of parameters and loss)
# (dashed lines: the ground-truth values w=2, b=-1 used to generate the data)
fig, (ax_w, ax_b, ax_loss) = plt.subplots(1, 3, figsize=(15, 4))
ax_w.plot(weights, label="w")
ax_w.axhline(2.0, color="k", linestyle="--", label="true w")
ax_w.set_xlabel("SGD step")
ax_w.legend()
ax_b.plot(biases, label="b")
ax_b.axhline(-1.0, color="k", linestyle="--", label="true b")
ax_b.set_xlabel("SGD step")
ax_b.legend()
ax_loss.plot(losses)
ax_loss.set_xlabel("SGD step")
ax_loss.set_ylabel("batch loss")
fig.tight_layout()
plt.show()


# BONUS Question
What does the loss landscape look like? (A good exercise for `vmap`, JAX's vectorization function.)

In [None]:
# evaluate the loss on a 2d grid of parameters and plot the loss landscape.
w_grid = jnp.linspace(-2.0, 6.0, 50)
b_grid = jnp.linspace(-5.0, 3.0, 50)


def loss_at(w, b):
    """Loss of the model with parameters (w, b) on the dataset (xs, ys)."""
    return loss(Params(w, b), xs, ys)


# The inner vmap sweeps w and the outer vmap sweeps b. The result therefore
# has shape (len(b_grid), len(w_grid)), the row/column layout that
# plt.contourf(w_grid, b_grid, Z) expects.
loss_grid = vmap(vmap(loss_at, in_axes=(0, None)), in_axes=(None, 0))(w_grid, b_grid)

plt.contourf(w_grid, b_grid, loss_grid, levels=30)
plt.colorbar(label="loss")
plt.plot(2.0, -1.0, "w*", markersize=12, label="true parameters")
plt.xlabel("w")
plt.ylabel("b")
plt.legend()
plt.show()


Can you plot the parameter trajectory during optimization on top of the loss?

In [None]:
# Recompute the loss landscape on a grid that covers the whole SGD trajectory
# (weights, biases) and the true parameters (w=2, b=-1), then overlay the path.
w_lo = min(float(weights.min()), 2.0) - 1.0
w_hi = max(float(weights.max()), 2.0) + 1.0
b_lo = min(float(biases.min()), -1.0) - 1.0
b_hi = max(float(biases.max()), -1.0) + 1.0
traj_w_grid = jnp.linspace(w_lo, w_hi, 50)
traj_b_grid = jnp.linspace(b_lo, b_hi, 50)
# inner vmap over w, outer vmap over b -> shape (len(traj_b_grid), len(traj_w_grid))
traj_loss_grid = vmap(
    vmap(lambda w, b: loss(Params(w, b), xs, ys), in_axes=(0, None)),
    in_axes=(None, 0),
)(traj_w_grid, traj_b_grid)

plt.contourf(traj_w_grid, traj_b_grid, traj_loss_grid, levels=30)
plt.colorbar(label="loss")
plt.plot(weights, biases, "r.-", markersize=3, linewidth=1, label="SGD trajectory")
plt.plot(weights[0], biases[0], "ro", label="start")
plt.plot(2.0, -1.0, "w*", markersize=12, label="true parameters")
plt.xlabel("w")
plt.ylabel("b")
plt.legend()
plt.show()