<a href="https://colab.research.google.com/github/probml/pyprobml/blob/master/pml1/ch5_opt/opt_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ch 5. Optimization (JAX version)

In this notebook, we explore various algorithms
for solving optimization problems of the form
$$
x^* = \arg \min_{x \in \mathcal{X}} f(x)
$$
We focus on the case where $f: \mathbb{R}^D \rightarrow \mathbb{R}$ is a differentiable function.

## TOC
* [Automatic differentiation](#AD)
* [Second-order full-batch optimization](#second)
* [Stochastic gradient descent](#SGD)


In [1]:
import sklearn
import scipy
import scipy.optimize
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
import itertools
import time
from functools import partial
import os

import numpy as np
from scipy.special import logsumexp
np.set_printoptions(precision=3)


In [2]:
# https://github.com/google/jax
import jax
import jax.numpy as np
import numpy as onp # original numpy
from jax.scipy.special import logsumexp
from jax import grad, hessian, jacfwd, jacrev, jit, vmap
# In newer JAX releases this module is available as jax.example_libraries.optimizers
from jax.experimental import optimizers
print("jax version {}".format(jax.__version__))



jax version 0.2.7


# Fit a binary logistic regression model using sklearn

We will evaluate the gradient of the NLL at the MLE.

In [5]:
# Fit the model to a dataset, so we have an "interesting" parameter vector to use.

import sklearn.datasets
from sklearn.model_selection import train_test_split

iris = sklearn.datasets.load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
N, D = X.shape # 150, 4

X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42)

from sklearn.linear_model import LogisticRegression

# We set C to a large number to turn off regularization.
# We don't fit the bias term to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = np.ravel(log_reg.coef_)
w = w_mle_sklearn
print(w)

[-4.414 -9.111  6.539 12.686]


# Manual differentiation <a class="anchor" id="AD"></a>

We compute the gradient of the negative log likelihood for binary logistic regression applied to the Iris dataset. 
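
For reference, here is a sketch of the standard derivation behind the hand-coded gradient in `NLL_grad`. The symbols $\mu_n = \sigma(w^\top x_n)$ and $\mu = (\mu_1, \ldots, \mu_N)$ are notation introduced here for illustration; they correspond to `mu` in the code. The average negative log likelihood is
$$
\mathrm{NLL}(w) = -\frac{1}{N} \sum_{n=1}^N \left[ y_n \log \mu_n + (1-y_n) \log (1-\mu_n) \right]
$$
Using $\sigma'(a) = \sigma(a)(1-\sigma(a))$, so that $\nabla_w \mu_n = \mu_n (1-\mu_n) x_n$, we get
$$
\nabla_w \mathrm{NLL}(w) = -\frac{1}{N} \sum_{n=1}^N \left[ y_n (1-\mu_n) - (1-y_n) \mu_n \right] x_n
= \frac{1}{N} \sum_{n=1}^N (\mu_n - y_n) x_n
= \frac{1}{N} X^\top (\mu - y)
$$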

In [4]:
## Compute gradient of loss "by hand" using numpy

def BCE_with_logits(logits, targets):
  # BCE = -(1/N) sum_n [y_n*log(p1_n) + (1-y_n)*log(p0_n)]
  # p1 = 1/(1+exp(-a)), p0 = 1 - p1 = 1/(1+exp(a))
  # log(p1) = log(1) - log(1+exp(-a)) = 0 - logsumexp([0, -a])
  N = logits.shape[0]
  logits = logits.reshape(N,1)
  logits_plus = np.hstack([np.zeros((N,1)), logits]) # e^0=1
  logits_minus = np.hstack([np.zeros((N,1)), -logits])
  logp1 = -logsumexp(logits_minus, axis=1)
  logp0 = -logsumexp(logits_plus, axis=1)
  logprobs = logp1 * targets + logp0 * (1-targets)
  return -np.sum(logprobs)/N

def sigmoid(x): return 0.5 * (np.tanh(x / 2.) + 1)

def predict_logit(weights, inputs):
    return np.dot(inputs, weights) # Already vectorized

def predict_prob(weights, inputs):
    return sigmoid(predict_logit(weights, inputs))

def NLL(weights, batch):
    X, y = batch
    logits = predict_logit(weights, X)
    return BCE_with_logits(logits, y)

def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_prob(weights, X)
    g = np.sum(np.dot(np.diag(mu - y), X), axis=0)/N
    return g

y_pred = predict_prob(w, X_test)
loss = NLL(w, (X_test, y_test))
grad_np = NLL_grad(w, (X_test, y_test))
print("params {}".format(w))
#print("pred {}".format(y_pred))
print("loss {}".format(loss))
print("grad {}".format(grad_np))

params [-4.414 -9.111  6.539 12.686]
loss 0.11824002861976624
grad [-0.235 -0.122 -0.198 -0.064]
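
The `BCE_with_logits` function above works with log-probabilities through `logsumexp` rather than taking the log of a sigmoid directly. The following is a minimal NumPy sketch of why that matters. The helper names `log_sigmoid_naive` and `log_sigmoid_stable` are hypothetical and are not part of the notebook.

```python
import numpy as np

def log_sigmoid_naive(a):
    # Direct formula log(1 / (1 + exp(-a))); exp(-a) overflows for very negative a
    with np.errstate(over="ignore", divide="ignore"):
        return np.log(1.0 / (1.0 + np.exp(-a)))

def log_sigmoid_stable(a):
    # log(sigmoid(a)) = -log(exp(0) + exp(-a)) = -logaddexp(0, -a)
    return -np.logaddexp(0.0, -a)

a = np.array([-1000.0, 0.0, 1000.0])
naive = log_sigmoid_naive(a)
stable = log_sigmoid_stable(a)
print("naive ", naive)
print("stable", stable)
```

For a very negative logit, the naive version returns negative infinity, while the log-sum-exp form returns a finite value close to the logit itself.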


# Automatic differentiation in JAX  <a class="anchor" id="AD-jax"></a>

Below we use JAX to compute the gradient of the NLL for binary logistic regression.
For some examples of using JAX to compute the gradients, Jacobians and Hessians of simple linear and quadratic functions,
see [this notebook](https://github.com/probml/pyprobml/blob/master/notebooks/linear_algebra.ipynb#AD-jax).
More details on JAX's autodiff can be found in the official [autodiff cookbook](https://github.com/google/jax/blob/master/notebooks/autodiff_cookbook.ipynb).


In [6]:
grad_jax = grad(NLL)(w, (X_test, y_test))
print("grad {}".format(grad_jax))
assert np.allclose(grad_np, grad_jax)

grad [-0.235 -0.122 -0.198 -0.064]
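
A third way to sanity-check a gradient, independent of both the hand derivation and autodiff, is a central finite-difference approximation. The sketch below uses plain NumPy on synthetic stand-in data. The names `numerical_grad`, `nll_demo`, `nll_grad_demo`, and the `*_demo` arrays are hypothetical and do not use the Iris split from above.

```python
import numpy as np

def numerical_grad(f, w, eps=1e-6):
    """Central-difference approximation to the gradient of a scalar function f at w."""
    w = np.asarray(w, dtype=np.float64)
    g = np.zeros_like(w)
    for d in range(w.size):
        e = np.zeros_like(w)
        e[d] = eps
        g[d] = (f(w + e) - f(w - e)) / (2 * eps)
    return g

# Synthetic data (an assumption for illustration only)
rng = np.random.RandomState(0)
X_demo = rng.randn(30, 4)
y_demo = (rng.rand(30) > 0.5).astype(np.float64)
w_demo = rng.randn(4)

def nll_demo(w):
    logits = X_demo @ w
    # Per-example binary cross entropy: log(1 + exp(a)) - y * a
    return np.mean(np.logaddexp(0.0, logits) - y_demo * logits)

def nll_grad_demo(w):
    mu = 1.0 / (1.0 + np.exp(-(X_demo @ w)))
    return X_demo.T @ (mu - y_demo) / X_demo.shape[0]

max_err = np.max(np.abs(numerical_grad(nll_demo, w_demo) - nll_grad_demo(w_demo)))
print("max abs difference:", max_err)
```

Finite differences need two function evaluations per parameter, so they are useful for testing but not as a replacement for `grad`.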


# Second-order, full-batch optimization <a class="anchor" id="second"></a>

The "gold standard" of optimization is second-order methods, which leverage Hessian information. Since the Hessian has O(D^2) entries, such methods do not scale to high-dimensional problems. However, we can sometimes approximate the Hessian using low-rank or diagonal approximations. Below we illustrate the BFGS method, which maintains a dense Hessian approximation via low-rank updates and so needs O(D^2) space and time per step, and the limited-memory version (L-BFGS), which uses O(D H) space and O(D H) time per step, where H is the history length.
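
To make the limited-memory idea concrete, here is a minimal NumPy sketch of the standard two-loop recursion, which applies the implicit inverse-Hessian approximation to a gradient using only the stored history. It illustrates the technique and is not SciPy's implementation. The function name `lbfgs_direction` and the quadratic test problem are assumptions made for this example.

```python
import numpy as np

def lbfgs_direction(grad, s_list, y_list):
    """Return H @ grad, where H is the L-BFGS inverse-Hessian approximation.

    s_list[i] = x_{i+1} - x_i and y_list[i] = g_{i+1} - g_i, oldest pair first.
    """
    q = np.array(grad, dtype=np.float64)
    rhos = [1.0 / np.dot(y, s) for s, y in zip(s_list, y_list)]
    alphas = []
    # First loop: newest pair to oldest pair
    for s, y, rho in zip(reversed(s_list), reversed(y_list), reversed(rhos)):
        alpha = rho * np.dot(s, q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial approximation H0 = gamma * I, scaled using the newest pair
    if s_list:
        gamma = np.dot(s_list[-1], y_list[-1]) / np.dot(y_list[-1], y_list[-1])
    else:
        gamma = 1.0
    r = gamma * q
    # Second loop: oldest pair to newest pair
    for s, y, rho, alpha in zip(s_list, y_list, rhos, reversed(alphas)):
        beta = rho * np.dot(y, r)
        r = r + (alpha - beta) * s
    return r

# Hypothetical test problem: a quadratic with SPD Hessian A, so that y = A s
rng = np.random.RandomState(0)
dim, hist = 5, 3
A = rng.randn(dim, dim)
A = A @ A.T + dim * np.eye(dim)
S = [rng.randn(dim) for _ in range(hist)]
Y = [A @ s for s in S]
d = lbfgs_direction(Y[-1], S, Y)
print(np.allclose(d, S[-1]))  # secant condition H y = s for the newest pair
```

Each loop touches every stored pair once and does O(D) work per pair, so no D-by-D matrix is ever formed.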

In general, second-order methods also require exact (rather than noisy) gradients. In the context of ML, this means they are "full batch" methods, since computing the exact gradient requires evaluating the loss on all the datapoints. However, for small data problems, this is feasible (and advisable).
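
The distinction between exact and noisy gradients can be illustrated with a small NumPy sketch. It uses a least-squares loss on synthetic data, both of which are assumptions made for this example rather than the logistic regression problem in this notebook. Each mini-batch gradient differs from the full-batch gradient, but with equal-sized batches their average recovers it.

```python
import numpy as np

rng = np.random.RandomState(0)
X_demo = rng.randn(100, 4)
y_demo = rng.randn(100)
w_demo = rng.randn(4)

def mse_grad(w, X, y):
    # Gradient of 0.5 * mean((X @ w - y)**2) with respect to w
    return X.T @ (X @ w - y) / X.shape[0]

full_grad = mse_grad(w_demo, X_demo, y_demo)
batch_grads = np.array([
    mse_grad(w_demo, X_demo[i:i + 20], y_demo[i:i + 20])
    for i in range(0, 100, 20)
])
max_deviation = np.abs(batch_grads - full_grad).max()
print("largest deviation of a mini-batch gradient:", max_deviation)
print("mean of mini-batch gradients equals full gradient:",
      np.allclose(batch_grads.mean(axis=0), full_grad))
```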

Below we illustrate how to use BFGS and L-BFGS via [scipy.optimize](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html#scipy.optimize.minimize).

In [10]:
import scipy.optimize

# We manually compute gradients, but could use Jax instead
def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_prob(weights, X)
    g = np.sum(np.dot(np.diag(mu - y), X), axis=0)/N
    return g

def training_loss(w):
    return NLL(w, (X_train, y_train))

def training_grad(w):
    return NLL_grad(w, (X_train, y_train))

onp.random.seed(42)
w_init = onp.random.randn(D)

# 'maxfun' is an L-BFGS-B option; BFGS does not recognize it, so we pass only 'maxiter'.
options = {'disp': None, 'maxiter': 1000}
method = 'BFGS'
w_mle_scipy = scipy.optimize.minimize(
    training_loss, w_init, jac=training_grad,
    method=method, options=options).x   

print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from scipy-bfgs {}".format(w_mle_scipy))

parameters from sklearn [-4.414 -9.111  6.539 12.686]
parameters from scipy-bfgs [-4.416 -9.117  6.542 12.695]


In [12]:
# The limited-memory version requires 64-bit floats, since it is implemented in Fortran.

def training_loss_64bit(w):
    l = NLL(w, (X_train, y_train))
    return onp.float64(l)

def training_grad_64bit(w):
    g = NLL_grad(w, (X_train, y_train))
    return onp.asarray(g, dtype=onp.float64)

onp.random.seed(42)
w_init = onp.random.randn(D)                 

memory = 10
options = {'disp': None, 'maxcor': memory, 'maxfun': 1000, 'maxiter': 1000}
# The code also handles bound constraints, hence the name
method = 'L-BFGS-B'
# Pass the options dict so that the history length (maxcor) is actually used
w_mle_scipy = scipy.optimize.minimize(
    training_loss_64bit, w_init, jac=training_grad_64bit,
    method=method, options=options).x


print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from scipy-lbfgs {}".format(w_mle_scipy))

parameters from sklearn [-4.414 -9.111  6.539 12.686]
parameters from scipy-lbfgs [-4.418 -9.112  6.543 12.693]


# Stochastic gradient descent <a class="anchor" id="SGD"></a>

In this section we illustrate how to implement SGD. We apply it to a simple convex problem, namely MLE for binary logistic regression on the small Iris dataset, so we can compare to the exact batch methods we illustrated above.


## Numpy version
We show a minimal implementation of SGD using vanilla NumPy. For convenience, we use `tf.data` together with TFDS to create a stream of mini-batches. We compute gradients by hand, but could use any AD library instead.


In [17]:
import tensorflow as tf
import tensorflow_datasets as tfds

def make_batcher(batch_size, X, y):
  def get_batches():
    # Wrap the NumPy arrays in a tf.data.Dataset
    ds = tf.data.Dataset.from_tensor_slices({"X": X, "y": y})
    ds = ds.batch(batch_size)
    # Convert the dataset into an iterable of dicts of NumPy arrays
    return tfds.as_numpy(ds)
  return get_batches

batcher = make_batcher(20, X_train, y_train)

for epoch in range(2):
  print('epoch {}'.format(epoch))
  for batch in batcher():
    # Use distinct names so the full-dataset variable y is not overwritten
    x_batch, y_batch = batch["X"], batch["y"]
    print(x_batch.shape) # (batch size, num features) = (20, 4)
  

epoch 0
(20, 4)
(20, 4)
(20, 4)
(20, 4)
(20, 4)
epoch 1
(20, 4)
(20, 4)
(20, 4)
(20, 4)
(20, 4)
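
If TensorFlow is not available, the same dict-of-arrays stream can be produced with NumPy alone. The following is a minimal sketch under that assumption. The name `make_numpy_batcher` and the `*_demo` arrays are hypothetical. Like the version above, it does not shuffle and it keeps a final partial batch.

```python
import numpy as np

def make_numpy_batcher(batch_size, X, y):
    """NumPy-only counterpart of make_batcher: returns a function that yields batches."""
    def get_batches():
        for start in range(0, X.shape[0], batch_size):
            stop = start + batch_size
            yield {"X": X[start:stop], "y": y[start:stop]}
    return get_batches

# Stand-in data with the same shape as the training split used above (100 x 4)
X_demo = np.arange(100 * 4, dtype=np.float32).reshape(100, 4)
y_demo = np.arange(100) % 2
batcher_demo = make_numpy_batcher(20, X_demo, y_demo)
shapes = [b["X"].shape for b in batcher_demo()]
print(shapes)
```

Because `get_batches` is a generator function, calling it again at the start of each epoch restarts the stream, which is how `batcher()` is used in the training loops here.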


In [18]:
def sgd(params, loss_fn, grad_loss_fn, get_batches_as_dict, max_epochs, lr):
    print_every = max(1, int(0.1*max_epochs))
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch_dict in get_batches_as_dict():
            x, y = batch_dict["X"], batch_dict["y"]
            batch = (x, y)
            batch_grad = grad_loss_fn(params, batch)
            params = params - lr*batch_grad
            batch_loss = loss_fn(params, batch) # Average loss within this batch
            epoch_loss += batch_loss
        if epoch % print_every == 0:
            print('Epoch {}, Loss {}'.format(epoch, epoch_loss))
    return params


In [19]:
onp.random.seed(42)
w_init = onp.random.randn(D)

max_epochs = 5
lr = 0.1
batch_size = 10
batcher = make_batcher(batch_size, X_train, y_train)
w_mle_sgd = sgd(w_init, NLL, NLL_grad, batcher, max_epochs, lr)
print(w_mle_sgd)

Epoch 0, Loss 4.692370414733887
Epoch 1, Loss 3.270962715148926
Epoch 2, Loss 3.1224915981292725
Epoch 3, Loss 3.0002613067626953
Epoch 4, Loss 2.896099805831909
[-0.538 -0.827  0.613  1.661]


## JAX version <a class="anchor" id="SGD-jax"></a>

JAX has a small optimization library focused on stochastic first-order optimizers. Every optimizer is modeled as an (`init_fun`, `update_fun`, `get_params`) triple of functions. The `init_fun` is used to initialize the optimizer state, which could include things like momentum variables, and the `update_fun` accepts a gradient and an optimizer state to produce a new optimizer state. The `get_params` function extracts the current iterate (i.e. the current parameters) from the optimizer state. The parameters being optimized can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can store your parameters however you'd like.
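
To make the triple pattern concrete, here is a minimal NumPy sketch that mimics the interface for SGD with momentum. It is a stand-in written for illustration, not JAX's implementation. The name `momentum_sketch` is hypothetical, and it handles only a flat array rather than nested parameter structures.

```python
import numpy as np

def momentum_sketch(step_size, mass):
    """Return an (init_fun, update_fun, get_params) triple for SGD with momentum."""
    def init_fun(params):
        x = np.asarray(params, dtype=np.float64)
        return x, np.zeros_like(x)      # state = (parameters, velocity)

    def update_fun(i, g, opt_state):
        x, v = opt_state
        v = mass * v + g                # accumulate momentum
        return x - step_size * v, v     # step index i is unused for a constant step size

    def get_params(opt_state):
        return opt_state[0]

    return init_fun, update_fun, get_params

# Minimize f(x) = 0.5 * ||x||^2, whose gradient is x itself
init_fun, update_fun, get_params = momentum_sketch(step_size=0.1, mass=0.9)
opt_state = init_fun([1.0, -2.0])
for i in range(300):
    g = get_params(opt_state)
    opt_state = update_fun(i, g, opt_state)
x_final = get_params(opt_state)
print(x_final)
```

Keeping the velocity inside the state, rather than in the training loop, is what lets the same loop drive different optimizers unchanged.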

Below we show how to reproduce our numpy code using this library.

In [20]:
# Version that uses JAX optimization library

#@jit
def sgd_jax(params, loss_fn, get_batches, max_epochs, opt_init, opt_update, get_params):
    loss_history = []
    opt_state = opt_init(params)
    
    #@jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        g = grad(loss_fn)(params, batch)
        return opt_update(i, g, opt_state) 
    
    print_every = max(1, int(0.1*max_epochs))
    total_steps = 0
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch_dict in get_batches():
            X, y = batch_dict["X"], batch_dict["y"]
            batch = (X, y)
            total_steps += 1
            opt_state = update(total_steps, opt_state, batch)
        params = get_params(opt_state)
        # Note: this is the loss on the last mini-batch of the epoch, not the full training set
        train_loss = float(loss_fn(params, batch))
        loss_history.append(train_loss)
        if epoch % print_every == 0:
            print('Epoch {}, train NLL {}'.format(epoch, train_loss))
    return params, loss_history

In [22]:
b=list(batcher())
X, y = b[0]["X"], b[0]["y"]
X.shape
batch = (X, y)
params= w_init
float(NLL(params, batch))
g = grad(NLL)(params, batch)
print(g)

[4.182 2.434 2.209 0.586]


In [23]:
# JAX with constant LR should match our minimal version of SGD


schedule = optimizers.constant(step_size=lr)
opt_init, opt_update, get_params = optimizers.sgd(step_size=schedule)

w_mle_sgd2, history = sgd_jax(w_init, NLL, batcher, max_epochs, 
                              opt_init, opt_update, get_params)
print(w_mle_sgd2)
print(history)

Epoch 0, train NLL 0.36490148305892944
Epoch 1, train NLL 0.34500643610954285
Epoch 2, train NLL 0.32851701974868774
Epoch 3, train NLL 0.3143332004547119
Epoch 4, train NLL 0.3018316924571991
[-0.538 -0.827  0.613  1.661]
[0.36490148305892944, 0.34500643610954285, 0.32851701974868774, 0.3143332004547119, 0.3018316924571991]
