<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_4_Optimization_PtII.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ain't no mountain high enough, or: Optimization Pt II

In [None]:
!pip install lineax

Collecting lineax
  Downloading lineax-0.0.7-py3-none-any.whl.metadata (17 kB)
Collecting equinox>=0.11.5 (from lineax)
  Downloading equinox-0.11.11-py3-none-any.whl.metadata (18 kB)
Collecting jaxtyping>=0.2.20 (from lineax)
  Downloading jaxtyping-0.2.37-py3-none-any.whl.metadata (6.6 kB)
Collecting jax>=0.4.26 (from lineax)
  Downloading jax-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.0,>=0.5.0 (from jax>=0.4.26->lineax)
  Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes)
Collecting wadler-lindig>=0.1.3 (from jaxtyping>=0.2.20->lineax)
  Downloading wadler_lindig-0.1.3-py3-none-any.whl.metadata (17 kB)
Downloading lineax-0.0.7-py3-none-any.whl (67 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading equinox-0.11.11-py3-none-any.whl (179 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.2/179.2 kB[0m [31m5.8 MB/s[0m eta [3

## Gradient Descent Redux
Recall under [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) we can iteratively optimize a function $f(\beta)$ by taking steps in the steepest direction,
$$ \hat{\beta} = \beta_t - \rho_t \nabla f(\beta_t).$$

A helpful way to recast gradient descent is that we seek to perform a series of _local_ optimizations,

$$\hat{\beta} = \min_\beta \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t}\|\beta - \beta_t\|_2^2.$$

To see how these are equivalent let's solve the local problem. but using inner product notation,
$$m(\beta) = \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t} (\beta - \beta_t)^T(\beta - \beta_t).$$
Now, using calculus again,
$$\begin{align*}
\nabla m(\beta) &= \nabla [ \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t} (\beta - \beta_t)^T(\beta - \beta_t)] \\
&= \nabla [\nabla f(\beta_t)^T \beta] + \frac{1}{2\rho_t} \nabla [(\beta - \beta_t)^T(\beta - \beta_t)] \\
&= \nabla f(\beta_t) + \frac{1}{\rho_t}(\beta - \beta_t) \Rightarrow \\
\hat{\beta} &= \beta_t - \rho_t \nabla f(\beta_t).
\end{align*}
$$

Neat! However, notice that the original local objective can be thought of as minimizing the directional derivative, but with a distance penalty, where that distance is defined by the geometry of the parameter space.

$$\hat{\beta} = \min_\beta \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t}\text{dist}(\beta, \beta_t).$$

When the natural geometry is $\mathbb{R}^p$ then $\text{dist}(\cdot) = \| \cdot \|_2^2$, however there are many  geometries that can describe the natural parameter space (for future class 😉)

## Newton's Method for Optimization
Can we do better, by considering higher-order information (ie geometry) of
the function $f$?

Let's consider a 2nd-order [Taylor-series approximation](https://en.wikipedia.org/wiki/Taylor_series) to $f$ around $\beta_t$ as,

$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T (\beta - \beta_t) + \frac{1}{2} (\beta - \beta_t)^T H(\beta_t)(\beta - \beta_t),$$ where $H(\beta_t) = \nabla^2 f(\beta_t)$ (i.e. the [Hessian](https://en.wikipedia.org/wiki/Hessian_matrix) of $f$ at $\beta_t$). If we minimize this _local_ approximation, we see

$\nabla_\beta f(\beta) \approx \nabla f(\beta_t) + H(\beta_t)(\beta - \beta_t) = \nabla f(\beta_t) + H(\beta_t)\beta - H(\beta_t)\beta_t ⇒$
$$ H(\beta_t)\beta = H(\beta_t)\beta_t - \nabla f(\beta_t).$$

We can recognize that this is a [system of linear equations](https://en.wikipedia.org/wiki/System_of_linear_equations) $A x = b$ where $A = H(\beta_t)$, $x = \beta$, and $b = H(\beta_t)\beta_t - \nabla f(\beta_t)$. The solution is given by, $\hat{x} = A^{-1}b$, which in this case implies,
$$ \hat{\beta} = H(\beta_t)^{-1}\left(H(\beta_t)\beta_t - \nabla f(\beta_t)\right) = \beta_t - H(\beta_t)^{-1}\nabla f(\beta_t).$$



[Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method_in_optimization) is only guaranteed to converge _locally_, and can diverge even for _strongly_ [convex functions](https://en.wikipedia.org/wiki/Convex_function) (e.g., $f(\beta) = \sqrt{\beta^2 + 1}$). To address this limitation, we can add a dampening parameter, $\rho_t$, which gives us our final update form,
$$ \hat{\beta} = H(\beta_t)^{-1}(H(\beta_t)\beta_t - \nabla f(\beta_t)) = \beta_t - \rho_t H(\beta_t)^{-1}\nabla f(\beta_t).$$

## Quasi-Newton Methods for Optimization
What if computing $H(\beta_t)$ is prohibitive or too costly? Do we need _exact_ second order information to improve on gradient descent's convergence? Given an approximation of $H$, called $B$, i.e. $B(\beta_t) \approx H(\beta_t)$, [_quasi_-Newton methods](https://en.wikipedia.org/wiki/Quasi-Newton_method) optimize for the form
$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T (\beta - \beta_t) + \frac{1}{2} (\beta - \beta_t)^T B(\beta_t)(\beta - \beta_t),$$ where $B(\beta_t) \approx H(\beta_t)$. Optimizing this statement gives us our update rule,
$$ \hat{\beta} = \beta_t - \rho_t B(\beta_t)^{-1}\nabla f(\beta_t).$$

## Poisson Regression

$$y_i | x_i \sim \text{Poi}(\lambda_i)$$ where $\lambda_i := \exp(x_i^T \beta)$, and $\text{Poi}(k | \lambda) := \frac{\lambda^k \exp(-\lambda)}{k!}$ is the [PMF](https://en.wikipedia.org/wiki/Probability_mass_function) of the [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution). Given $\{(y_i, x_i)\}_{i=1}^n$, we would like to identify the [maximum likelihood parameter estimate](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation) for $\beta$. In other words, we would to find a value for $\beta$ such that we maximize the log-likelihood given by,
$$\begin{align*}
\log \ell(\beta) &= \sum_i \log \text{Poi}(y_i | \exp(x_i^T \beta)) \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta) \exp(-\exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))\right] - \log(y_i!) \\
&= \sum_i \left[y_i \cdot x_i^T \beta - \exp(x_i^T \beta) - \log(y_i!)\right] \\
&= y^T X\beta - \exp(X\beta)^T 1_n - O(1) \\
&= y^T X\beta - \lambda^T 1_n - O(1),
\end{align*}$$
where $\lambda = \{\lambda_1, \dotsc, \lambda_n\}.$


$$ \begin{align*}
\nabla_\beta \ell &= \nabla_\beta \left[ y^T X\beta - \lambda^T 1_n \right] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\lambda^T 1_n] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\exp(X\beta)^T 1_n] \\
&= X^T y - X^T \exp(X\beta)  \\
&= X^T y - X^T \lambda  \\
&= X^T(y - \lambda) \\
\nabla^2_{\beta \beta} \ell &= \nabla_{\beta} X^T(y - \lambda) \\
&= \nabla_{\beta} \left[X^T y - X^T \lambda \right] \\
&= - X^T \nabla_{\beta}  \lambda \\
&= -X^T \nabla_{\beta}  \exp(X\beta) \\
&= -X^T \Lambda X,
\end{align*}$$
where $\Lambda = \text{diag}(\lambda)$, i.e. $\Lambda_{ii} = \lambda_i$ and $\Lambda_{ij} = 0$ for $i \neq j$.

To illustrate how $\nabla_{\beta}  \exp(X\beta) = \Lambda X$ (i.e. last step in Hessian calculation), recall that the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of a function $f : \mathbb{R}^n → \mathbb{R}^m$ is the $m \times n$ matrix $J$ such that $J_{ij} = \frac{∂f_i}{∂j}$. In this case we are computing the Jacobian for $\exp(X\beta)$, which is $\mathbb{R}^p → \mathbb{R}^n$, so our final Jacobian for $\exp(X\beta)$ should have shape $n \times p$. Notice that $J_{i,j} = \frac{\partial}{\partial \beta_j} \exp(x_i^T \beta) = x_{ij}\exp(x_i^T \beta)$, thus $J_{i, .} = \exp(x_i^T \beta) x_i^T$. Repeating this for each $i$ we have $$∇_\beta \exp(X \beta) = J(\exp(X \beta)) = \begin{bmatrix} J_{1,.} \\ ⋮ \\ J_{n,.} \end{bmatrix} =
\begin{bmatrix} \exp(x_1^T \beta) x_1^T \\ ⋮ \\ \exp(x_n^T \beta) x_n^T \end{bmatrix}  =
\begin{bmatrix} \lambda_1 x_1^T \\ ⋮ \\ \lambda_n x_n^T\end{bmatrix} = \Lambda X.$$

We can fit using Newton's method. =>
$$\begin{align*}
\beta(t+1) &= \beta(t) - H(\beta(t))^{-1}\nabla \ell(\beta_t) \\
&= \beta(t) + (X^T \Lambda(t) X)^{-1} X^T (y - \lambda) ⇒ \\
&= (X^T \Lambda(t) X)^{-1} X^T \Lambda(t) (\Lambda(t)^{-1}y + X\beta(t) - 1)
\end{align*}$$
where $\Lambda(t) := \text{diag}(\lambda_1, \dotsc, \lambda_n)$.

In [None]:
https://course.ece.cmu.edu/~ece739/lectures/18739-2020-spring-lecture-08-second-order.pdf

https://web.stanford.edu/class/archive/stats/stats200/stats200.1172/Lecture27.pdf


In [None]:
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats

import lineax as lx


@jax.jit
def loglikelihood(beta, y, X):
  """
  Our loglikelihood function for $y_i | x_i ~ \text{Poi}(\exp(eta_i))$.

  beta: beta
  y: poisson-distributed observations
  X: our design matrix

  returns: sum of the logliklihoods of each sample
  """
   #loglikelihood function#
  eta = X.mv(beta)

  #𝑦𝑇𝑋𝛽−𝜆𝑇1𝑛

  return  y @ eta - jnp.sum(jax.numpy.exp(eta))

@jax.jit
def irwls_fit(beta, y, X, step_size):
  """
  Perform MLE estimation for $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  beta: beta
  y: poisson-distributed observations
  X: our design matrix

  returns: updated estimate of $\beta$
  """


  # compute lambda_i := exp(x_i @ beta)
  eta = X.mv(beta)
  d_i = jnp.exp(eta)
  d_sqrt = jnp.sqrt(d_i)

  # compute z_i := Lambda^{1/2}(Lambda^-1 y + X @beta - 1)
  z = (y / d_i + eta - 1) * d_sqrt

  # X* := Lambda^{1/2} X
  # we use linear operators to postpone any computation
  X_star = lx.DiagonalLinearOperator(d_sqrt) @ X

  # lineax can solve normal equations iteratively as (t(X*) @ (X* @ guess)) - z
  solution = lx.linear_solve(X_star, z, solver=lx.NormalCG(atol=1e-4, rtol=1e-3))
  beta = solution.value

  return beta


def poiss_reg(y, X, fit_func, step_size = 1.0, max_iter=100, tol=1e-3):
  """
  Perform MLE estimation for $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  y: poisson-distributed observations
  X: our design matrix
  max_iter: the maximum number of iterations to perform optimization
  tol:

  returns: updated estimate of $\beta$
  """
  # intialize eta := X @ beta
  n, p = X.shape

  # fake bookkeeping
  loglike = -100000
  delta = 10000

  # convert to a linear operator for lineax
  X = lx.MatrixLinearOperator(X)

  # initialize using OLS estimate
  sol = lx.linear_solve(X, (y - jnp.mean(y))/2, solver=lx.NormalCG(atol=1e-4, rtol=1e-3))
  beta = sol.value
  beta = beta / jnp.linalg.norm(beta)
  for epoch in range(max_iter):

    # fit using our function
    beta = fit_func(beta, y, X, step_size)

    # evaluate log likelihood
    newll = loglikelihood(beta, y, X)

    # take delta and check if we can stop
    delta = jnp.fabs(newll - loglike)
    print(f"Epoch[{epoch}] = {newll}")
    if delta < tol:
      break

    # replace old value
    loglike = newll

  return beta

In [None]:
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm
import jax.scipy.linalg as jspla
import lineax as lx
# Let's simulate a poisson regression model with N samples and P variables
# X: (N, P) , beta: (P,), and y: (N,)
N = 1000
P = 5

# initialize PRNG env
seed = 0
key = rdm.PRNGKey(seed)

# TODO: split key for each random call

key, x_key = rdm.split(key)
X = rdm.normal(x_key, shape=(N, P))
key, b_key = rdm.split(key)
beta = rdm.normal(b_key, shape=(P,))

# TODO: compute lambda_i
lamba = jax.numpy.exp(X @ beta)

# TODO: sample y from Poi(lambda_i)
k=3
#y = jax.scipy.stats.poisson.pmf(k, lamba)
y = rdm.poisson(b_key, lamba,(N,))

# estimate beta using our irwls function
step_size =0.001
X.op = lx.MatrixLinearOperator(X)
beta = irwls_fit(beta, y, X.op, step_size)

# fit_func has signature (eta, y, X, step_size)
beta_hat = poiss_reg(y, X, irwls_fit,step_size)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

Epoch[0] = -1414727040.0
Epoch[1] = -522715776.0
Epoch[2] = -193111680.0
Epoch[3] = -71456312.0
Epoch[4] = -26475514.0
Epoch[5] = -9813868.0
Epoch[6] = -3628856.25
Epoch[7] = -1333594.125
Epoch[8] = -472925.78125
Epoch[9] = -151409.265625
Epoch[10] = -31631.16015625
Epoch[11] = 11638.0078125
Epoch[12] = 25922.361328125
Epoch[13] = 29637.5078125
Epoch[14] = 30175.57421875
Epoch[15] = 30195.9375
Epoch[16] = 30195.994140625
Epoch[17] = 30196.01953125
Epoch[18] = 30196.0
Epoch[19] = 30196.0078125
Epoch[20] = 30195.984375
Epoch[21] = 30196.0078125
Epoch[22] = 30196.00390625
Epoch[23] = 30196.001953125
Epoch[24] = 30195.98828125
Epoch[25] = 30196.01171875
Epoch[26] = 30196.00390625
Epoch[27] = 30196.0
Epoch[28] = 30196.001953125
Epoch[29] = 30196.0234375
Epoch[30] = 30195.9921875
Epoch[31] = 30196.001953125
Epoch[32] = 30196.0
Epoch[33] = 30196.01171875
Epoch[34] = 30196.013671875
Epoch[35] = 30195.998046875
Epoch[36] = 30196.03125
Epoch[37] = 30195.998046875
Epoch[38] = 30195.990234375
Epoc

In [1]:
# let's implement poisson regression using _only_ gradient information to perform inference
# and measure how quickly it converges compared with the Newton method
def grad_fit(beta, y, X, step_size):
  eta = X.mv(beta)

  grad = X.mv(y - X @ beta)
  pass

# NB: we can transpose a lx.MatrixLinearOperator (say X) as X.transpose()
# NB: we compute matrix-vector produces using a lx.MatrixLinearOperator as X.mv(v)
step_size = 1e-7
beta_hat = poiss_reg(y, X, grad_fit, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

NameError: name 'poiss_reg' is not defined

## Automatic differentiation
Chain rules, okay! Notes TBD

In [3]:
# let's not worry and use autodiff
#autodiff reverse mode differention


auto_grad_ll = jax.grad(loglikelihood)

def jax_grad_step(beta, y, X, step_size):
  pass

# NB: we can transpose a lx.MatrixLinearOperator (say X) as X.transpose()
# NB: we compute matrix-vector produces using a lx.MatrixLinearOperator as X.mv(v)
step_size = 1e-7
beta_hat = poiss_reg(y, X, jax_grad_step, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

NameError: name 'jax' is not defined

In [2]:
import jax.scipy.linalg as spla

# Great! But can we use 2nd order information?
auto_hess_ll = jax.hessian(loglikelihood)

def jax_newton_step(beta, y, X, step_size):
  grad = auto_grad_ll(beta,y,X)
  return beta + step_size + grad
# NB: we can transpose a lx.MatrixLinearOperator (say X) as X.transpose()
# NB: we compute matrix-vector produces using a lx.MatrixLinearOperator as X.mv(v)
step_size = 1.
beta_hat = poiss_reg(y, X, jax_newton_step, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

NameError: name 'jax' is not defined



 **other**  
we would to find a value for $\beta$ such that we maximize the log-likelihood given by,
$$\begin{align*}
\log \ell(\beta) &= \sum_i \log \text{Poi}(y_i | \exp(x_i^T \beta)) \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta) \exp(-\exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))\right] - \log(y_i!) \\
&= \sum_i \left[y_i \cdot x_i^T \beta - \exp(x_i^T \beta) - \log(y_i!)\right] \\
&= y^T X\beta - \exp(X\beta)^T 1_n - O(1) \\
&= y^T X\beta - \lambda^T 1_n - O(1),
\end{align*}$$
where $\lambda = \{\lambda_1, \dotsc, \lambda_n\}.$

##### Derive the $ \hat{\beta} $ :

First-order gradient:
$$ \begin{align*}
\nabla_\beta \ell &= \nabla_\beta \left[ y^T X\beta - \lambda^T 1_n \right] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\lambda^T 1_n] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\exp(X\beta)^T 1_n] \\
&= X^T y - \exp(X\beta)^T X  \\
&= X^T y - \lambda^T X  \\
&= X^T(y - \lambda) \\
\end{align*}$$
Second -order gradients:  
$$ \begin{align*}
\nabla^2_{\beta \beta} \ell &= \nabla_{\beta} X^T(y - \lambda) \\
&= \nabla_{\beta} \left[X^T y - X^T \lambda \right] \\
&= - X^T \nabla_{\beta}  \lambda \\
&= -X^T \nabla_{\beta}  \exp(X\beta) \\
&= -X^T \Lambda X,
\end{align*}$$

where $\Lambda = \text{diag}(\lambda)$, i.e. $\Lambda_{ii} = \lambda_i$ and $\Lambda_{ij} = 0$ for $i \neq j$.

Newton's method:
$$\begin{align*}
\beta(t+1) &= \beta(t) - H(\beta(t))^{-1}\nabla \ell(\beta_t) \\
&= \beta(t) + (X^T \Lambda(t) X)^{-1} X^T (y - \lambda) ⇒ \\
&= (X^T \Lambda(t) X)^{-1} X^T \Lambda(t) (\Lambda(t)^{-1}y + X\beta(t) - 1)
\end{align*}$$
where $\Lambda(t) := \text{diag}(\lambda_1, \dotsc, \lambda_n)$.



