<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_11_Variational_Inference_PtIV.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Get Lucky, or: Non-Conjugate Variational Inference Pt II

Recall that CAVI derivations assume our surrogate (variational) distributions arise from conditional conjugacy between the expected log likelihood and the prior.

Last week, we _assumed_ that $Q_j(\theta_j)$ is in the _same_ exponential family as its corresponding prior $\Pr(\theta_j)$. We then derived the analytic expectations required for the ELBO and performed gradient ascent.

> _What if the expectations in the ELBO have no closed-form (analytic) solution?_

One solution is to use _stochastic_ gradient ascent, where the necessary gradients are estimated by Monte Carlo sampling.

## Why is this a problem? Lemme differentiate under the integral!
$$\begin{align*}
\text{ELBO}(\theta) &:= \mathbb{E}_Q\left[\log \Pr(\mathbf{X} | \mathbf{z})\right] - \mathsf{D}_{KL}(Q(\mathbf{z})||P(\mathbf{z}))\\
\nabla_{\theta}\text{ELBO}(\theta) &=
  \nabla_{\theta}\mathbb{E}_Q\left[\log \Pr(\mathbf{X} | \mathbf{z})\right]
  -\underbrace{\nabla_{\theta}\mathsf{D}_{KL}(Q(\mathbf{z})||P(\mathbf{z}))}_{\text{typically analytically tractable!}}\\
\nabla_{\theta}\mathbb{E}_Q\left[\log \Pr(\mathbf{X} | \mathbf{z})\right] &= \nabla_{\theta}\int Q_{\theta}(\mathbf{z})\log \Pr(\mathbf{X} | \mathbf{z}) d\mathbf{z}\\
  &= \int \nabla_{\theta}Q_{\theta}(\mathbf{z})\log \Pr(\mathbf{X}|\mathbf{z})d \mathbf{z} \\
  &\neq \mathbb{E}_{Q}\left[\nabla_{\theta} \log \Pr(\mathbf{X} | \mathbf{z})\right]
\end{align*}$$


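The distribution $Q_{\theta}$ that we average over depends on the parameters $\theta$, so differentiating the expectation produces the integral above, which typically has no known closed form. Moreover, the derivation shows that we can't simply average sampled gradients of the integrand: $\int \nabla_{\theta}Q_{\theta}(\mathbf{z})\log \Pr(\mathbf{X}|\mathbf{z})\,d\mathbf{z}$ is not an expectation under $Q$, and $\mathbb{E}_{Q}\left[\nabla_{\theta} \log \Pr(\mathbf{X} | \mathbf{z})\right] = 0$ because $\log \Pr(\mathbf{X} | \mathbf{z})$ does not depend on $\theta$.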
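A toy check of this point, under assumptions of our own: we take $Q = N(\mu, \sigma^2)$ and stand in $h(z) = z^2$ for $\log \Pr(\mathbf{X} | \mathbf{z})$, because $\mathbb{E}_Q[h(z)] = \mu^2 + \sigma^2$ has the known gradient $(2\mu, 2\sigma)$. Sampling $z \sim Q$ and differentiating only the integrand (emulated here with `jax.lax.stop_gradient`) gives an all-zero gradient, not the exact one. The function name `naive_objective` is hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def naive_objective(theta, key, num_samples=10_000):
    """MC estimate of E_Q[z^2] for Q = N(mu, sigma^2), with the samples
    treated as constants w.r.t. theta (toy stand-in for log Pr(X | z))."""
    mu, sigma = theta
    # draw z ~ Q, then cut its dependence on theta: only the integrand
    # z^2 is differentiated, never the sampling distribution
    z = jax.lax.stop_gradient(mu + sigma * rdm.normal(key, (num_samples,)))
    return jnp.mean(z ** 2)


theta = jnp.array([1.5, 0.5])  # (mu, sigma)
naive_grad = jax.grad(naive_objective)(theta, rdm.PRNGKey(0))
# exact gradient of E_Q[z^2] = mu^2 + sigma^2 w.r.t. (mu, sigma)
exact_grad = jnp.array([2 * theta[0], 2 * theta[1]])

print(naive_grad)  # [0. 0.]
print(exact_grad)  # [3. 1.]
```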

> _What gives?_

## Reparameterization Trick
If we can write $\mathbf{z}$ as a _deterministic_ function $\mathbf{z} = g_{\theta}(ɛ)$ of noise $ɛ$ whose distribution does not depend on $\theta$, then we may be able to circumvent this issue. For _location-scale_ families $f$, this is trivial! Namely, $\mathbf{z} = \mu + \sigma \odot ɛ$, where $ɛ_j \sim f(0, 1)$ and $\theta = \{\mu, \sigma\}$. Now we have gone from $\mathbb{E}_{Q}\left[ \log \Pr(\mathbf{X} | \mathbf{z}) \right]$ to $\mathbb{E}_{ɛ \sim f(0, 1)}\left[ \log \Pr(\mathbf{X} | g_{\theta}(ɛ))\right]$.
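As a small sketch, the map $g_{\theta}$ is the same for any location-scale family; only the distribution of the noise changes. We assume standard normal and standard Laplace noise purely for illustration (`jax.random.laplace` draws Laplace(0, 1) variates), and the helper name `g` is hypothetical.

```python
import jax.numpy as jnp
import jax.random as rdm


def g(theta, eps):
    """Location-scale map z = g_theta(eps) = mu + sigma * eps, elementwise."""
    mu, sigma = theta
    return mu + sigma * eps


key_normal, key_laplace = rdm.split(rdm.PRNGKey(0))
theta = (jnp.array([1.0, -2.0]), jnp.array([0.5, 3.0]))  # (mu, sigma) for 2 coordinates
T = 200_000

# eps_j ~ N(0, 1)        =>  z_j ~ N(mu_j, sigma_j^2)
z_normal = g(theta, rdm.normal(key_normal, (T, 2)))
# eps_j ~ Laplace(0, 1)  =>  z_j ~ Laplace(mu_j, sigma_j), whose std is sqrt(2) * sigma_j
z_laplace = g(theta, rdm.laplace(key_laplace, (T, 2)))

print(z_normal.mean(axis=0), z_normal.std(axis=0))
print(z_laplace.mean(axis=0), z_laplace.std(axis=0))
```

The sample moments should land close to $\mu$ and $\sigma$ for the normal draws, and close to $\mu$ and $\sqrt{2}\,\sigma$ for the Laplace draws.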

> _How does this help us?_

We can use Monte Carlo estimates of the gradient under this reparameterization to approximate the exact gradient. Because the noise density $f$ does not depend on $\theta$, the gradient moves inside the integral:
$$\begin{align*}
\nabla_{\theta}\mathbb{E}_Q\left[\log \Pr(\mathbf{X} | \mathbf{z})\right] &=
  \nabla_{\theta}\mathbb{E}_{ɛ \sim f(0, 1)}\left[ \log \Pr(\mathbf{X} | g_{\theta}(ɛ))\right] \\
&= \nabla_{\theta} \int f(ɛ) \log \Pr(\mathbf{X} | g_{\theta}(ɛ)) dɛ\\
&= \int f(ɛ) \nabla_{\theta}\log \Pr(\mathbf{X} | g_{\theta}(ɛ)) dɛ\\
&= \mathbb{E}_{ɛ \sim f(0, 1)}\left[\nabla_{\theta}\log \Pr(\mathbf{X} | g_{\theta}(ɛ)) \right] \\
&\approx \dfrac{1}{T} \sum_{t=1}^T \nabla_{\theta}\log \Pr(\mathbf{X} | g_{\theta}(ɛ^t)),
\end{align*}$$
where $ɛ^1, \dots, ɛ^T$ are drawn independently from $f(0, 1)$.
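Continuing the toy example with $h(z) = z^2$ standing in for the log likelihood (our assumption, not part of the lab), differentiating a reparameterized Monte Carlo average recovers the exact gradient $(2\mu, 2\sigma)$ up to sampling noise. Note that `jax.grad` of the mean over $T$ draws is exactly the estimator $\frac{1}{T}\sum_t \nabla_\theta h(g_\theta(ɛ^t))$ on the last line above. The name `reparam_objective` is hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def reparam_objective(theta, key, num_samples=100_000):
    """MC estimate of E_{eps ~ N(0, 1)}[h(g_theta(eps))] with toy h(z) = z^2."""
    mu, sigma = theta
    eps = rdm.normal(key, (num_samples,))  # eps^t ~ f(0, 1); does not depend on theta
    z = mu + sigma * eps                   # z^t = g_theta(eps^t), differentiable in theta
    return jnp.mean(z ** 2)


theta = jnp.array([1.5, 0.5])  # (mu, sigma)
reparam_grad = jax.grad(reparam_objective)(theta, rdm.PRNGKey(0))
# exact gradient of mu^2 + sigma^2 w.r.t. (mu, sigma)
exact_grad = jnp.array([2 * theta[0], 2 * theta[1]])

print(reparam_grad, exact_grad)  # close, up to Monte Carlo error
```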


## Lab: Logistic Regression with Normal priors on effects
We would like to perform variational inference under the following model:
$$\begin{align*}
\mathbf{y}_i | \beta &\sim \text{Bernoulli}(\text{sigmoid}(\mathbf{x}_i^T \beta))\\
\beta_j &\sim N(0, \sigma^2_b).
\end{align*}$$
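Evaluating $\log \Pr(\mathbf{y} | \mathbf{X}, \beta)$ through $\text{sigmoid}(\cdot)$ on the probability scale can underflow to $\log 0$ once $|\mathbf{x}_i^T\beta|$ is large; the lab's `complete_log_likelihood` guards against this by clipping the probabilities. A minimal alternative sketch (our own, not the lab's implementation) stays on the logit scale with `jax.nn.log_sigmoid`, using $\log(1 - \text{sigmoid}(\eta)) = \log \text{sigmoid}(-\eta)$. The helper `bernoulli_logit_loglik` and the simulated data are hypothetical.

```python
import jax.nn as nn
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats


def bernoulli_logit_loglik(y, X, beta):
    """log Pr(y | X, beta) for y_i ~ Bernoulli(sigmoid(x_i^T beta)), on the logit scale."""
    eta = X @ beta  # linear predictor x_i^T beta
    # log sigmoid(eta) and log(1 - sigmoid(eta)) = log sigmoid(-eta), both stable
    return jnp.sum(y * nn.log_sigmoid(eta) + (1.0 - y) * nn.log_sigmoid(-eta))


key_x, key_b, key_y = rdm.split(rdm.PRNGKey(0), 3)
X = rdm.normal(key_x, (50, 3))
beta = 0.1 * rdm.normal(key_b, (3,))
y = rdm.bernoulli(key_y, nn.sigmoid(X @ beta)).astype(jnp.float32)

# agrees with the probability-scale formula when eta is moderate
stable = bernoulli_logit_loglik(y, X, beta)
naive = jnp.sum(stats.bernoulli.logpmf(y, nn.sigmoid(X @ beta)))

# stays finite for a saturated linear predictor (eta = 40, y = 0)
extreme = bernoulli_logit_loglik(jnp.array([0.0]), jnp.array([[40.0]]), jnp.array([1.0]))
print(stable, naive, extreme)  # extreme is -40.0
```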


Let's assume a mean-field Gaussian surrogate, $Q(\beta) = \prod_{j=1}^p Q_j(\beta_j) = \prod_{j=1}^p N(\beta_j | \mu_j, \sigma^2_j)$. The expected log likelihood $\mathbb{E}_Q\left[\log \Pr(\mathbf{y} | \mathbf{X}, \beta)\right]$ has no known closed form under this model, so we must rely on stochastic optimization. Let's code the reparameterization trick to optimize the ELBO for this model.
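For reference, here is the objective the code in this lab estimates, as we read it from `elbo`, `sample_log_likelihood`, and `kl_divergence` (this write-up is our own summary). The code parameterizes $\sigma_j^2$ through `log_var_b` $= \log \sigma_j^2$, which keeps the variance positive, and fixes the prior mean at zero. Each call to `elbo` uses a single draw and `fit` averages `num_samples` of them, so $T$ corresponds to `num_samples`; only the first term is stochastic, while the KL term is exact.

$$\begin{align*}
\widehat{\text{ELBO}}(\mu, \sigma) &= \dfrac{1}{T}\sum_{t=1}^T \log \Pr(\mathbf{y} | \mathbf{X}, \beta^t) - \sum_{j=1}^p \mathsf{D}_{KL}\left(N(\mu_j, \sigma^2_j)\,||\,N(0, \sigma^2_b)\right)\\
\log \Pr(\mathbf{y} | \mathbf{X}, \beta) &= \sum_{i=1}^n \left[ y_i \log \text{sigmoid}(\mathbf{x}_i^T \beta) + (1 - y_i) \log\left(1 - \text{sigmoid}(\mathbf{x}_i^T \beta)\right) \right]\\
\mathsf{D}_{KL}\left(N(\mu_j, \sigma^2_j)\,||\,N(0, \sigma^2_b)\right) &= \dfrac{1}{2}\left[\dfrac{\mu_j^2}{\sigma^2_b} + \dfrac{\sigma^2_j}{\sigma^2_b} - \log\dfrac{\sigma^2_j}{\sigma^2_b} - 1\right]
\end{align*}$$

where $\beta^t = \mu + \sigma \odot ɛ^t$ and $ɛ^t_j \sim N(0, 1)$ independently.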

In [None]:
import operator as op

from typing import NamedTuple

import jax
import jax.random as rdm
import jax.nn as nn
import jax.numpy as jnp
import jax.scipy as jsp

from jax import Array
from jax.tree_util import tree_map, tree_reduce
from jax.typing import ArrayLike

from jax import config
config.update("jax_enable_x64", True)


class PriorParams(NamedTuple):
  mean_b: Array
  var_b: Array


class PosteriorParams(NamedTuple):
  mean_b: Array
  log_var_b: Array


def _kl_divergence(post: PosteriorParams, prior: PriorParams) -> Array:
  # elementwise KL(N(mu_j, s_j^2) || N(m_j, v_j))
  #   = 0.5 * [(mu_j - m_j)^2 / v_j + s_j^2 / v_j - log(s_j^2 / v_j) - 1]
  var_b = jnp.exp(post.log_var_b)

  term1 = ((post.mean_b - prior.mean_b) ** 2) / prior.var_b
  term2 = var_b / prior.var_b
  term3 = post.log_var_b - jnp.log(prior.var_b)
  return 0.5 * (term1 + term2 - term3 - 1)


def kl_divergence(post: PosteriorParams, prior: PriorParams) -> float:
  """Total KL divergence: sum of elementwise KLs between univariate normals
  """
  return jnp.sum(_kl_divergence(post, prior))


def complete_log_likelihood(y: ArrayLike, X: ArrayLike, beta: Array) -> float:
  """ log-likelihood of binary outcomes y, given X, beta
  """
  eps = jnp.finfo(beta.dtype).eps
  probs = jnp.clip(nn.sigmoid(X @ beta),
                   eps,
                   1. - eps,
  )
  return jnp.sum(jsp.stats.bernoulli.logpmf(y, probs))


def sample_log_likelihood(post: PosteriorParams, y: ArrayLike, X: ArrayLike, key: ArrayLike) -> float:
  n, p = X.shape
  # reparameterization trick
  std_dev_b = jnp.exp(0.5 * post.log_var_b)
  beta = post.mean_b + std_dev_b * rdm.normal(key, shape=(p,))
  return complete_log_likelihood(y, X, beta)


def elbo(post: PosteriorParams, prior: PriorParams, y: ArrayLike, X: ArrayLike, key: ArrayLike) -> float:
  e_ll = sample_log_likelihood(post, y, X, key)
  kl = kl_divergence(post, prior)
  return e_ll - kl


def fit(y: ArrayLike, X: ArrayLike, prior_var_b: float = 1e-3, num_samples: int = 5, step_size: float = 1e-3, seed: int = 0, max_iter: int = 100) -> PosteriorParams:
  # initialize our random key
  n, p = X.shape
  key = rdm.PRNGKey(seed)

  # split to initialize our variational parameters
  key, init_mean_key = rdm.split(key, 2)
  post = PosteriorParams(
      mean_b = jnp.sqrt(prior_var_b) * rdm.normal(init_mean_key, (p,)),
      log_var_b = jnp.log(prior_var_b * jnp.ones((p,))),
  )
  prior = PriorParams(
      mean_b = jnp.zeros((p,)),
      var_b = prior_var_b * jnp.ones((p,)),
  )

  # one-sample ELBO estimate and its gradient w.r.t. all variational parameters;
  # the expected log likelihood is sampled with the reparameterization trick
  def _step(post, key):
    elbo_val, elbo_grad = jax.value_and_grad(elbo)(post, prior, y, X, key)
    return elbo_val, elbo_grad

  for epoch in range(max_iter):
    # fresh keys: one carried forward, num_samples used for Monte Carlo draws
    key, *skey = rdm.split(key, num_samples + 1)
    skey = jnp.array(skey)

    # evaluate _step once per key; post is shared (None), keys are mapped over (0)
    evals, grads = jax.vmap(_step, (None, 0))(post, skey)

    # Monte Carlo averages over the num_samples draws
    elbo_val = jnp.mean(evals)
    grad = tree_map(lambda g: jnp.mean(g, axis=0), grads)

    print(f"ELBO[{epoch}] ≈ {elbo_val}")

    # gradient *ascent* step, since we are maximizing the ELBO
    post = tree_map(lambda _post, _grad: _post + step_size * _grad, post, grad)

  return post
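
As a quick sanity check on the closed-form Gaussian KL used in `_kl_divergence`, we can compare it with a Monte Carlo estimate of $\mathbb{E}_Q[\log Q(\beta_j) - \log P(\beta_j)]$. This is a standalone sketch: `kl_normal` is a hypothetical helper that re-implements the same formula, and the parameter values are arbitrary.

```python
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats


def kl_normal(mu_q, var_q, mu_p, var_p):
    """Closed-form KL(N(mu_q, var_q) || N(mu_p, var_p)); same formula as _kl_divergence."""
    return 0.5 * ((mu_q - mu_p) ** 2 / var_p + var_q / var_p - jnp.log(var_q / var_p) - 1.0)


mu_q, var_q = 0.3, 0.02  # variational Q_j (arbitrary values)
mu_p, var_p = 0.0, 0.05  # prior P_j = N(0, sigma_b^2)

# Monte Carlo estimate of E_Q[log Q(b) - log P(b)] using draws b ~ Q
b = mu_q + jnp.sqrt(var_q) * rdm.normal(rdm.PRNGKey(1), (200_000,))
mc_kl = jnp.mean(stats.norm.logpdf(b, mu_q, jnp.sqrt(var_q)) - stats.norm.logpdf(b, mu_p, jnp.sqrt(var_p)))
exact_kl = kl_normal(mu_q, var_q, mu_p, var_p)

print(exact_kl, mc_kl)  # should agree to roughly two decimal places
```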

In [None]:
# simulate binary outcome
N, P = 250, 25
prior_var_b = 1e-2

seed = 0
key = rdm.PRNGKey(seed)
key, beta_key, x_key, y_key = rdm.split(key, 4)

beta = jnp.sqrt(prior_var_b) * rdm.normal(beta_key, shape=(P,))
X = rdm.normal(x_key, (N, P))

pred = X @ beta
prob = nn.sigmoid(pred)
y = rdm.bernoulli(y_key, prob)

params = fit(y, X, step_size=0.01, prior_var_b=prior_var_b, num_samples=10)

ELBO[0] ≈ -184.54540376083236
ELBO[1] ≈ -175.7925378237437
ELBO[2] ≈ -176.099753779408
ELBO[3] ≈ -173.94517013973316
ELBO[4] ≈ -173.4134425126356
ELBO[5] ≈ -172.5357129989426
ELBO[6] ≈ -171.5784176964027
ELBO[7] ≈ -172.4656387737173
ELBO[8] ≈ -170.90120011248692
ELBO[9] ≈ -175.01941725066098
ELBO[10] ≈ -174.36018326215716
ELBO[11] ≈ -172.98281305260588
ELBO[12] ≈ -173.9027007199709
ELBO[13] ≈ -174.05918948987508
ELBO[14] ≈ -174.48104346964143
ELBO[15] ≈ -174.82107653110103
ELBO[16] ≈ -175.03129071284138
ELBO[17] ≈ -174.19863748010906
ELBO[18] ≈ -173.98924627599308
ELBO[19] ≈ -172.409280886289
ELBO[20] ≈ -173.36528208857646
ELBO[21] ≈ -173.37780664708146
ELBO[22] ≈ -171.61521277299641
ELBO[23] ≈ -171.24451708687448
ELBO[24] ≈ -172.4470084097542
ELBO[25] ≈ -173.5336888594151
ELBO[26] ≈ -174.4124669760982
ELBO[27] ≈ -172.31565097498822
ELBO[28] ≈ -172.27892303974167
ELBO[29] ≈ -173.52573341865605
ELBO[30] ≈ -174.33788238084207
ELBO[31] ≈ -173.60274941791602
ELBO[32] ≈ -171.61597690682893
