# Logistic regression with JAX

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Learning as gradient-based optimization

In a typical model based analysis, you have a model with parameters $\theta$.

A dataset consisting of conditions $X$ and observed behavior $Y$.

The model provides a `loss` $l(X, Y, \theta)$, which quantifies the mismatch between the model predictions and the data.

To simplify the notation, I'll drop the dependence on the data, writing $l(\theta) = l(X, Y, \theta)$


Fitting a model can be cast as finding the parameter $\hat{\theta}$ that minimizes the loss, i.e. $\hat{\theta} = \arg\min_{\theta} l(X, Y, \theta)$.

Gradient-based optimization is a set of procedures that use the gradient of the loss to guide the search for the optimal parameter $\hat{\theta}$.

The simplest scheme constructs a sequence of parameters $\theta_1, \theta_2, ..., \theta_T$ by 'following the gradient':
$$ \theta_{t+1} = \theta_{t} - \eta \frac{d l}{d \theta}$$


##  Logistic Regression

In logistic regression, a binary label $y$ is generated from an input $x$ as follows:
\begin{align}
\phi(x) &= wx + b\\
y &\sim Bernoulli( \sigma( \phi(x) ))\\
\end{align}

where $\sigma(x) = \frac{1}{1+e^{-x}}$ is the sigmoid function

You have 
* data: $\{X, Y\}$ 
* a linear predictor : $\phi(x) = w x + b$ (parameters are the weight $w$ and the bias $b$)
* leading to predicted class probabilities $p(y=1|x) = \sigma(\phi(x))$
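* a loss (log likelihood): $ l(w, b) = \sum_n y_n \,\log\,p(y=1|x_n) + (1-y_n)\log(1 - p(y=1|x_n))$ (a log likelihood is to be *maximized*; to fit the "minimize a loss" picture above, take its negative)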

We could compute the gradient manually, but with jax we don't need to



In [None]:
# To be a tiny bit more sophisticated, we'll use a class to hold the parameters

from typing import NamedTuple

class Params(NamedTuple):
    """
    Immutable container for the two parameters of the model.

    Build it positionally, e.g. ``params = Params(weight, bias)``, and read
    the values back as ``params.weight`` and ``params.bias``.
    """

    # multiplicative coefficient applied to the input x
    weight: jnp.ndarray
    # additive offset
    bias: jnp.ndarray

In [None]:
# We create an initialization function for the parameters

def init(rng) -> Params:
    """Draw random starting values for the model parameters.

    `rng` is split into two sub-keys; the weight and the bias are each a
    scalar standard-normal sample multiplied by 3.

    :param rng: a JAX PRNG key (as produced by `jax.random.PRNGKey`)
    :return: an instance of the 'Params' class
    """
    init_scale = 3
    key_w, key_b = jax.random.split(rng)
    return Params(
        weight=jax.random.normal(key_w, ()) * init_scale,
        bias=jax.random.normal(key_b, ()) * init_scale,
    )


# we need the sigmoid for the pointwise activations
def sigmoid(x: jnp.ndarray):
    """
    The sigmoid (logistic) function x -> 1/(1 + e^{-x}), applied elementwise.

    :param x: the input value(s)
    :return: the sigmoid of x, with the same shape as x
    """
    # your code here!
    # <SOLUTION
    # Same function as 1./(1.+jnp.exp(-x)), but the naive formula overflows
    # exp(-x) for very negative x in float32, which turns its gradient into
    # NaN. The library implementation keeps value and gradient finite.
    from jax.scipy.special import expit

    return expit(x)
    # SOLUTION>
    


# we compute our prediction of the label being one
def predictions(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """
    Model probability p(y=1|x) that the label equals one.

    The sigmoid of the linear predictor weight * x + bias is rescaled into
    the interval [eps, 1 - eps] with eps = 1e-3.

    :param params: an instance of the class Params
    :param x: the scalar input
    :param y: the binary observation (not used in the computation)
    :return: a probability p, eps <= p <= 1 - eps
    """
    # your code here!
    # <SOLUTION
    margin = 1.e-3
    linear_predictor = params.weight * x + params.bias
    return margin + (1.-2 * margin) * sigmoid(linear_predictor)
    # SOLUTION>
    
    
# again we define the loss (as defined above)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean Bernoulli log-likelihood of the observations y under
    the model's predictions on x (a larger value means a better fit).
    :param params: an instance of the class Params
    :param x: the scalar input
    :param y: the binary observation
    :return: the scalar mean log-likelihood
    """
    # your code here!
    # <SOLUTION
    p_one = predictions(params, x, y)
    # per-observation log-likelihood: log p if y == 1, log(1 - p) if y == 0
    log_lik = y * jnp.log(p_one) + (1.-y) * jnp.log(1.-p_one)
    return jnp.mean(log_lik)
    # SOLUTION>


LEARNING_RATE = 0.1

# there is a convenient way to propagate the gradient through to the parameters
# even for the more complex structure we now have.


@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
    """
    Performs one gradient step on params using the given data.

    θ = θ + η grad l(θ)  (with η the learning rate)

    Note the '+' sign: this is an ascent step, i.e. it moves the parameters
    in the direction that increases the value returned by `loss`.

    :param params: an instance of the class Params
    :param x: the scalar input
    :param y: the binary observation
    :return: an instance of the Params class containing the updated parameters
    """
    # jax.grad differentiates with respect to the first argument, so `grad`
    # has the same structure as `params` (one entry per field).
    grad = jax.grad(loss)(params, x, y)
    # Apply the step field by field. `jax.tree_util.tree_map` is used because
    # the top-level alias `jax.tree_map` is deprecated and no longer exists
    # in recent JAX releases.
    new_params = jax.tree_util.tree_map(
        lambda param, g: param + g * LEARNING_RATE, params, grad)

    return new_params


Let's generate some data!

\begin{align}
x &\sim {\cal N}(0,1)\quad \text{this means x is Gaussian with mean 0 and variance 1}\\
\phi(x) &= wx + b\\
y &\sim Bernoulli( \sigma( \phi(x) ))\\
\end{align}

for some chosen values for $w$ and $b$.

In [None]:
# this is to prepare jax to sample random variables
rng = jax.random.PRNGKey(42)
# Split into three keys: two are consumed below (inputs and label noise) and
# `rng` is rebound to a fresh key, so that later uses of `rng` (e.g. for the
# parameter initialization) do not reuse a key that has already been split.
rng, x_rng, noise_rng = jax.random.split(rng, 3)

# Generate true data from p = sigmoid(w*x + b), y ~ Bernoulli(p)

# we choose some ground-truth parameters
true_w, true_b = 2, -1
true_params = Params(true_w, true_b)


# we generate 500 data points
num_data = 500
# your code here
# hint: for x,  you can use `jax.random.normal`
# hint: for y,  you can use `jax.random.bernoulli`
# <SOLUTION
xs = jax.random.normal(x_rng, (num_data, 1))
ps = sigmoid(xs * true_w + true_b)
ys = jax.random.bernoulli(noise_rng, p=ps)
# SOLUTION>



Let's plot the generated data

In [None]:
# plot y as a function of x
# plot the predictions for the true parameters p(x) = \sigma( \phi(x) )

# <SOLUTION
# probability of y=1 under the ground-truth parameters, one value per input
true_probs = predictions(true_params, xs, ys)
plt.plot(xs, ys, '.')          # sampled binary labels
plt.plot(xs, true_probs, '*')  # probabilities they were sampled from
plt.show()
# SOLUTION>


Now let's train the model

In [None]:
# Train the model using the update function for 200 iterations
# and store the intermediate values of the parameters and loss

# initializing some arrays for storage
num_iterations = 200
weights = np.zeros((num_iterations,))
biases = np.zeros((num_iterations,))
losses = np.zeros((num_iterations,))


params = init(rng)
# loop over `num_iterations` (not a literal 200) so that the loop always
# matches the length of the storage arrays above
for it in range(num_iterations):
    # <SOLUTION
    # run the update
    params = update(params, xs, ys)
    # store the current params and loss
    weights[it] = params.weight
    biases[it] = params.bias
    losses[it] = loss(params, xs, ys)
    # SOLUTION>

print('done!')


Let's plot the evolution of the parameters and of the loss (the learning curve)

In [None]:
# <SOLUTION
# two stacked panels sharing the iteration axis
fig, axes = plt.subplots(2, 1, figsize=(6,4), sharex=True)
ax_params, ax_loss = axes

# top panel: stored values of the weight and of the bias
ax_params.plot(weights)
ax_params.plot(biases)
ax_params.set_xlabel('iteration')
ax_params.set_ylabel('parameter $w$ & $b$')

# bottom panel: stored values of the loss
ax_loss.plot(losses)
ax_loss.set_xlabel('iteration')
ax_loss.set_ylabel('loss $l(w)$')

plt.show()
# SOLUTION>


## What if there is really a lot of data?


If there is a lot of data (think billions), there is no way to evaluate the loss
on the whole dataset as we have done so far.

We can evaluate the loss on a randomly sampled subset of the data, 
which leads to an unbiased estimator of the total loss.

\begin{align}
\hat{l}(\theta) &=  l(x_{batch}, y_{batch}, \theta)\\
\mathbb{E}[\hat{l}(\theta)] &= l(\theta)
\end{align}

If we replace the gradient of the loss in gradient descent by the gradient of 
an estimator of the loss, we get stochastic gradient descent (SGD).



This corresponds to iterating gradient descent steps on batch approximations of the loss.
\begin{align}
\theta_{t+1} &= \theta_{t} - \eta \frac{d \hat{l}}{d \theta}
\end{align}

In [None]:
# First we generate a bit more data

rng = jax.random.PRNGKey(0)
# Derive the sampling keys from the new seed. Without this split, the keys
# `x_rng` / `noise_rng` created earlier in the notebook would be silently
# reused and the seed above would have no effect on the data. `rng` itself is
# rebound to a fresh key for later use (e.g. parameter initialization).
rng, x_rng, noise_rng = jax.random.split(rng, 3)

num_data = 5000
xs = jax.random.normal(x_rng, (num_data, 1))
ps = sigmoid(xs * true_w + true_b)
ys = jax.random.bernoulli(noise_rng, p=ps)

# we set the size of the batch
batch_size = 20

# and compute the number of batches
# (an incomplete final batch, if any, counts as one more batch)
num_complete_batches, leftover = divmod(num_data, batch_size)
num_batches = num_complete_batches + bool(leftover)
print('num_batches: ',  num_batches)

# we construct a data stream to easily get access to the batches consecutively
def data_stream():
    """Yield shuffled mini-batches `(x_batch, y_batch)` without end.

    Each pass draws a new random permutation of the `num_data` indices and
    cuts it into `num_batches` consecutive slices of at most `batch_size`
    indices; every slice is used to index the module-level `xs` and `ys`.
    """
    while True:
        shuffled = np.random.permutation(num_data)
        # slice starts 0, batch_size, 2 * batch_size, ... (num_batches of them)
        for start in range(0, num_batches * batch_size, batch_size):
            chosen = shuffled[start:start + batch_size]
            yield xs[chosen], ys[chosen]
batches = data_stream()


We can now train the model on the batches

In [None]:
 
params = init(rng)
# one storage slot per batch for each recorded quantity
weights, biases, losses = (np.zeros((num_batches,)) for _ in range(3))

# now code the iterations (for loop) to run the updates on the batches
# you can get the next batch via the call 'next(batches)' 
# store the loss and parameters for later plotting

# <SOLUTION
for j in range(num_batches):
    # fetch a batch and take one gradient step on it
    x_batch, y_batch = next(batches)
    params = update(params, x_batch, y_batch)
    # record the updated parameters and the loss on this same batch
    weights[j] = params.weight
    biases[j] = params.bias
    losses[j] = loss(params, x_batch, y_batch)
# SOLUTION>
    
# plot the results (evolution of parameters and loss)
# <SOLUTION
# two stacked panels sharing the iteration axis
fig, axes = plt.subplots(2, 1, figsize=(6,4), sharex=True)
ax_params, ax_loss = axes

# top panel: stored values of the weight and of the bias
ax_params.plot(weights)
ax_params.plot(biases)
ax_params.set_xlabel('iteration')
ax_params.set_ylabel('parameter $w$ & $b$')

# bottom panel: stored values of the loss
ax_loss.plot(losses)
ax_loss.set_xlabel('iteration')
ax_loss.set_ylabel('loss $l(w)$')

plt.show()
# SOLUTION>


# BONUS Question
What does the loss look like? (a good exercise for using jax's vectorization function `vmap`)

In [None]:
# <SOLUTION

# plotting the loss
from jax import vmap

# we have 2 parameters so we need to make a grid 
n_grid = 100
weight_grid = np.linspace(-5,10,n_grid).reshape(1, -1)
bias_grid = np.linspace(-10,10,n_grid).reshape(1, -1)
# meshgrid (default 'xy' indexing): rows follow the bias, columns the weight
w, b = np.meshgrid(weight_grid, bias_grid)
# one (weight, bias) pair per row, so that vmap can map over the first axis
w_flat = w.reshape(-1,1)
b_flat = b.reshape(-1,1)
grid_params = Params(w_flat, b_flat)


# conveniently, we can use the vectorization function of jax to evaluate the loss on the whole grid at once
def loss_tmp(p):
    # loss on the full dataset as a function of the parameters only
    return loss(p, xs, ys)


loss_flat = vmap(loss_tmp)(grid_params)
loss_grid = loss_flat.reshape((n_grid, n_grid))


# let's plot the resulting loss
grid_extent = [
    weight_grid.min(), weight_grid.max(),
    bias_grid.min(), bias_grid.max(),
]
plt.contourf(loss_grid, extent=grid_extent, origin='lower', levels=20)
plt.colorbar()
# mark the parameters the data were generated with
plt.plot(true_w, true_b, 'x')

# SOLUTION>


Can you plot the parameter trajectory during optimization on top of the loss?

In [None]:
# <SOLUTION

# Running the update
num_steps = 200
params = init(rng)
# Record the whole trajectory so that it can be drawn with a single plot call
# afterwards, instead of creating one plot artist per optimization step.
traj_w = np.zeros((num_steps,))
traj_b = np.zeros((num_steps,))
for step in range(num_steps):
    params = update(params, xs, ys)
    traj_w[step] = params.weight
    traj_b[step] = params.bias
print('done!')


# Plotting results
# parameter trajectory (black dots) ...
plt.plot(traj_w, traj_b, 'k.')
# ... on top of the loss evaluated on the (weight, bias) grid
plt.contourf(
    loss_grid,
    extent=[weight_grid.min(), weight_grid.max(), bias_grid.min(), bias_grid.max()],
    origin='lower', levels=20
)
plt.colorbar()
# mark the parameters the data were generated with
plt.plot(true_w, true_b, 'x')
plt.show()

# SOLUTION>
