<a href="https://colab.research.google.com/github/yifdai/PM-520-repo/blob/main/HW/PM520_HW3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Homework 3. Variational Inference

## 1. Evidence Lower Bound
$\newcommand{\bX}{\mathbf{X}}\newcommand{\by}{\mathbf{y}}\newcommand{\bI}{\mathbf{I}}$
Recall from Lab 8, our example of variational inference for a Bayesian linear regression model. Namely,
$$\begin{align*}
\by | \bX, \beta &\sim N(\bX\beta, \bI_n \sigma^2) \\
\beta &\sim N(0, \bI_p \sigma^2_b).
\end{align*}$$

We assumed a mean-field model that $Q$ factorizes as $$Q(\beta) = \prod_{j=1}^P Q_j(\beta_j).$$

### 1.1
Consulting the results in Lab 8 on parameter definitions for each $Q_j$, please derive the *evidence lower bound* or ELBO for this model.


\begin{align*}
ELBO &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\left[\|\mathbf{y}-\mathbf{X}\boldsymbol{\mu}\|^2 + \sigma^2\sum_{j=1}^p \|\mathbf{X}_j\|^2\right] \\
&\quad -\frac{p}{2}\log(2\pi\sigma_b^2) - \frac{1}{2\sigma_b^2}\left[\|\boldsymbol{\mu}\|^2 + p\,\sigma^2\right] \\
&\quad -\frac{p}{2}\log(2\pi \sigma^2) - \frac{p}{2}.
\end{align*}


### 1.2
Consult lab 8 for the implementation of a CAVI algorithm for the model above, but rather than evaluate the mean squared error (MSE), evaluate the ELBO. The ELBO should *increase* with each iteration, otherwise there is likely a bug.

In [1]:
# Let's code up the CAVI algorithm for bayesian linear regression

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as rdm

jax.config.update("jax_enable_x64", True)

MAX_ITER = 100

N = 500
P = 250
sigma_sq = 0.8
sigma_sq_b = 0.1

seed = 0
key = rdm.PRNGKey(seed)
key, x_key, b_key, y_key = rdm.split(key, 4)

X = rdm.normal(x_key, shape=(N, P))
beta = jnp.sqrt(sigma_sq_b) * rdm.normal(b_key, shape=(P,))
y = X @ beta + jnp.sqrt(sigma_sq) * rdm.normal(y_key, shape=(N,))

post_means = jnp.zeros((P,))
post_vars = jnp.ones((P,)) * sigma_sq_b

def elbo(y, X, post_means, post_vars, sigma_sq):
  Xsq = jnp.sum(X * X, axis=0)
  pred = X @ post_means
  term1 = -0.5 * N * jnp.log(sigma_sq)
  term2 = jnp.sum(y ** 2) - 2 * y.T @ pred + jnp.sum(Xsq * post_vars) + pred @ pred

  # expected log likelihood
  exp_ll = term1 - term2 / (2 * sigma_sq)

  # KL for each Q_j
  kls = 0.5 * ((post_means ** 2 + post_vars) / sigma_sq_b - jnp.log(post_vars / sigma_sq_b) - 1.)

  return exp_ll - jnp.sum(kls)

def _inner(j, carry):
  r, post_means, post_vars = carry

  # update residual
  Xj = X[:,j]
  rj = r + Xj * post_means[j]

  # update variational parameters for jth distribution
  post_var_j = 1. / (jnp.sum(Xj ** 2) / sigma_sq + 1 / sigma_sq_b)
  mu_j = rj @ X[:,j] * post_var_j / sigma_sq
  post_vars = post_vars.at[j].set(post_var_j)
  post_means = post_means.at[j].set(mu_j)

  # remove the updated mean term from global residual
  r = rj - Xj * mu_j
  return r, post_means, post_vars


last = -100000000
for epoch in range(MAX_ITER):
  r = y - X @ post_means
  r, post_means, post_vars = lax.fori_loop(0, P, _inner, (r, post_means, post_vars))
  value = elbo(y, X, post_means, post_vars, sigma_sq)
  print(f"ELBO[{epoch}] = {value}")
  diff = value - last
  if diff < 0:
    print(f"something went wrong {diff}")
    break
  if diff < 1e-3:
    print("all done")
    break
  last = value



ELBO[0] = -1509.9724636196884
ELBO[1] = -868.4248784058789
ELBO[2] = -764.261884492558
ELBO[3] = -732.4241063447057
ELBO[4] = -720.6285640613584
ELBO[5] = -715.676069397937
ELBO[6] = -713.4637318254369
ELBO[7] = -712.4200619999491
ELBO[8] = -711.9027873415175
ELBO[9] = -711.6357059442223
ELBO[10] = -711.4933224560815
ELBO[11] = -711.4155484047511
ELBO[12] = -711.3722767462873
ELBO[13] = -711.3478586499151
ELBO[14] = -711.3339252435566
ELBO[15] = -711.3259023431558
ELBO[16] = -711.3212475981998
ELBO[17] = -711.3185292972977
ELBO[18] = -711.316932642986
ELBO[19] = -711.3159898860399
all done


## 2. Bayesian Linear Regression Pt II
Here we assume a slightly different linear model, which is given by, $$\begin{align*}
\by | \bX, \beta &\sim N(\bX\beta, \bI_n \sigma^2) \\
\beta_j &\sim \text{Laplace}(0, b).
\end{align*}$$

We assumed a mean-field model that $Q$ factorizes as $$Q(\beta) = \prod_{j=1}^P Q_j(\beta_j).$$ Rather than identify optimal $Q_j$ through CAVI, we will first assume $Q_j := \text{Laplace}(\mu_j, b_j)$. Next, to identify updates for each $\mu_j, b_j$, we take the derivative of the ELBO with respect to each; however the gradient of the ELBO requires knowing $\mu_j, b_j$, which causes challenges.

### 2.1
Re-write the ELBO as a deterministic transformation of $\beta_j$ using location-scale rules (i.e. reparameterization trick)

We apply the reparameterization trick here to first introduce two random variables that follows uniform distribution, $U_j, V_j \sim \text{Unif}(0, 1)$ independently. Then we have $\log (\frac{U_j}{V_j}) \sim \text{Laplace} (0, 1)$. Here since the laplace distribution belongs to the location-scale family, then we could represent $\beta_j$ by a deterministic transformation: $\beta_j (\mu_j, b_j, U_j, V_j) = \mu_j + b_j * \log(\frac{U_j}{V_j})$, where $\beta_j \sim \text{Laplace} (\mu_j, b_j)$

\begin{align*}
f(\beta_j | \mu_j, b_j, U_j, V_j) &= \frac{1}{2b_j} \exp(- \frac{|\log(\frac{U_j}{V_j}) - \mu_j|}{b_j})\\
\frac{|\beta_j - \mu_j|}{b_j} &= |\log(\frac{U_j}{V_j})| \\
\end{align*}

And to be more specific we re-define the prior as:
\begin{align*}
\beta \sim \text{Laplace} (0, b_0)
\end{align*}

Then the ELBO is

\begin{align*}
\mathrm{ELBO} = \; \mathbb{E}_{q(U, V)} \Biggl\{\, &-\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\Bigl\| \mathbf{y} - \mathbf{X}\left[\mathbf{\mu} + \mathbf{b} \circ \left(\log(\mathbf{U}) - \log(\mathbf{V})\right)\right] \Bigl\|^2\\
&- \sum_{j=1}^{P}\left[\log(2b_0) + \frac{\Bigl|\mu_j + b_j\,\log\!\left(\frac{U_j}{V_j}\right)\Bigr|}{b_0}\right]\\[1mm]
&+\sum_{j=1}^{P}\left[\log(2b_j) + \log(U_j) - \log(V_j)\right]
\Biggr\}
\end{align*}

Based on transformation, we have
$$
U, V \sim \text{Unif}(0, 1)\\
-\log(U), -\log(V) \sim \text{Exp}(1)
$$
so we have $\mathbb{E}_{q(U)}(\log(U)) = \mathbb{E}_{q(V)}(\log(V)) = -1$, the last term is then $\sum_{j=1}^{P}\log(2b_j)$.

For the second term, we could transform it back to the laplace distribution and calculate the $\mathbb{E}_q(\beta_j) (|\beta_j|)$

\begin{align*}
\mathbb{E}_q(\beta_j) (|\beta_j|) &= \int^{\inf}_{\mu_j} |\beta_j| \frac{1}{2b_j} \exp(\frac{\beta_j - \mu_j}{b_j})d\beta_j + \int^{0}_{-\inf} |\beta_j| \frac{1}{2b_j} \exp(\frac{\mu_j - \beta_j}{b_j})d\beta_j
\end{align*}

Notice that it might be hard to directly work with it, thus we consider a transformation to center the laplace distribution by define
$$\gamma_j = \beta_j - \mu_j$$\
$$\gamma_j \sim \text{Laplace}(0, b_j)$$
Now that $\gamma_j$ is centered around zero, then we have
\begin{align*}
\mathbb{E}_q(|\beta_j|)
 &= \int_{-\mu_j}^{\infty} (\gamma_j+\mu_j) f(\gamma_j) d\gamma_j
    + \int_{-\infty}^{-\mu_j} -(\gamma_j+\mu_j) f(\gamma_j) d\gamma_j \\[1mm]
 &= \frac{1}{2b_j} \left[ \int_{-\mu_j}^{\infty} (\gamma_j+\mu_j) \exp\left(-\frac{\gamma_j}{b_j}\right) d\gamma_j
    + \int_{-\infty}^{-\mu_j} -(\gamma_j+\mu_j) \exp\left(\frac{\gamma_j}{b_j}\right) d\gamma_j \right] \\[1mm]
 &= |\mu_j| + b_j \exp\left(-\frac{|\mu_j|}{b_j}\right).
\end{align*}


Let $\boldsymbol{\epsilon} = \log(\mathbf{U}) - \log(\mathbf{V})$,  with $\mathbb{E}(\boldsymbol{\epsilon}) = \mathbf{0}$ and $\mathrm{Var}(\epsilon_j) = 2$. Then, we have $\mathbf{z} = \boldsymbol{\mu} + \mathbf{b} \circ \boldsymbol{\epsilon}$, and $\mathbf{X}\mathbf{z} = \mathbf{X}\boldsymbol{\mu} + \mathbf{X}(\mathbf{b} \circ \boldsymbol{\epsilon})$.

Thus we would have
\begin{align*}
\Bigl\|\mathbf{y} - \mathbf{X}\mathbf{z}\Bigr\|^2
&= \Bigl\|\mathbf{y} - \mathbf{X}\boldsymbol{\mu} - \mathbf{X}(\mathbf{b}\circ\boldsymbol{\epsilon})\Bigr\|^2 \\[1mm]
&= \Bigl\|\mathbf{y} - \mathbf{X}\boldsymbol{\mu}\Bigr\|^2 - 2\left(\mathbf{y}-\mathbf{X}\boldsymbol{\mu}\right)^T\mathbf{X}(\mathbf{b}\circ\boldsymbol{\epsilon}) \\
&\quad + \Bigl\|\mathbf{X}(\mathbf{b}\circ\boldsymbol{\epsilon})\Bigr\|^2.
\end{align*}
Taking expectation over $\boldsymbol{\epsilon}$ gives
\begin{align*}
\mathbb{E}\Bigl[\Bigl\|\mathbf{y} - \mathbf{X}\mathbf{z}\Bigr\|^2\Bigr]
&= \Bigl\|\mathbf{y} - \mathbf{X}\boldsymbol{\mu}\Bigr\|^2 - 2\left(\mathbf{y}-\mathbf{X}\boldsymbol{\mu}\right)^T\mathbf{X}\, \mathbb{E}\left[\mathbf{b}\circ\boldsymbol{\epsilon}\right] \\
&\quad + \mathbb{E}\Bigl[\Bigl\|\mathbf{X}(\mathbf{b}\circ\boldsymbol{\epsilon})\Bigr\|^2\Bigr].
\end{align*}
Since $\mathbb{E}[\boldsymbol{\epsilon}] = \mathbf{0}$, the middle term vanishes. Moreover,
\begin{align*}
\mathbb{E}\Bigl[\Bigl\|\mathbf{X}(\mathbf{b}\circ\boldsymbol{\epsilon})\Bigr\|^2\Bigr]
&= \mathbb{E}\left[\sum_{j=1}^{P} \left(b_j\,\epsilon_j\right)^2 \Bigl\|\mathbf{X}_{\cdot j}\Bigr\|^2\right] \\
&= \sum_{j=1}^{P} b_j^2\,\mathbb{E}(\epsilon_j^2)\,\Bigl\|\mathbf{X}_{\cdot j}\Bigr\|^2 \\
&= 2\sum_{j=1}^{P} b_j^2\,\Bigl\|\mathbf{X}_{\cdot j}\Bigr\|^2.
\end{align*}
Thus, the expected squared error is
\begin{align*}
\mathbb{E}\Bigl[\Bigl\|\mathbf{y} - \mathbf{X}\mathbf{z}\Bigr\|^2\Bigr]
&= \Bigl\|\mathbf{y} - \mathbf{X}\boldsymbol{\mu}\Bigr\|^2 + 2\sum_{j=1}^{P} b_j^2\,\Bigl\|\mathbf{X}_{\cdot j}\Bigr\|^2.
\end{align*}

Thus, replacing the corresponding expectations, the full ELBO becomes
\begin{align*}
\mathrm{ELBO} =\; & -\frac{n}{2}\log(2\pi\sigma^2)
-\frac{1}{2\sigma^2}\left\{ \Bigl\| \mathbf{y} - \mathbf{X}\boldsymbol{\mu} \Bigr\|^2 + 2\sum_{j=1}^{P} b_j^2 \Bigl\| \mathbf{X}_{\cdot j} \Bigr\|^2 \right\}\\[1mm]
& -\sum_{j=1}^{P}\left[\log(2b_0) + \frac{|\mu_j| + b_j\,\exp\!\left(-\frac{|\mu_j|}{b_j}\right)}{b_0}\right]
+ \sum_{j=1}^{P}\log(2b_j).
\end{align*}

### 2.2
Implement the above by performing stochastic VI to optimize the ELBO by sampling.



In [24]:
## Problematic, the result for the scale parameter b is fine, but the location parameter mu is not so good ##

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import softmax

def compute_elbo(mu, b, b0, sigma_sq, y_batch, X_batch, n_total):
    batch_size = y_batch.shape[0]
    n, p = X_batch.shape

    # Compute squared error term
    squared_error = jnp.sum((y_batch - X_batch @ mu) ** 2)

    # Compute variance term
    X_sq_norms = jnp.sum(X_batch ** 2, axis=0)  # ||X_j||^2 for each j
    variance_term = 2 * jnp.sum(b ** 2 * X_sq_norms)

    # Compute KL divergence terms
    kl_terms = jnp.sum(jnp.log(2 * b0) + (jnp.abs(mu) + b * jnp.exp(-jnp.abs(mu) / b)) / b0)

    # Compute entropy terms
    entropy_terms = jnp.sum(jnp.log(2 * b))

    # Combine all terms
    elbo = (
        -0.5 * n * jnp.log(2 * jnp.pi * sigma_sq)
        - 0.5 / sigma_sq * (squared_error + variance_term)
        - kl_terms
        + entropy_terms
    )

    return elbo

def compute_gradients(mu, b, b0, sigma_sq, y_batch, X_batch, n_total):
    batch_size = y_batch.shape[0]
    n, p = X_batch.shape

    # Gradient with respect to mu
    grad_mu = (
        X_batch.T @ (y_batch - X_batch @ mu) / sigma_sq
        - jnp.sign(mu) / b0
        + jnp.sign(mu) * jnp.exp(-jnp.abs(mu) / b) / b0
    )

    # Gradient with respect to b
    X_sq_norms = jnp.sum(X_batch ** 2, axis=0)
    grad_b = (
        4 * b * X_sq_norms / sigma_sq
        - jnp.exp(-jnp.abs(mu) / b) / b0
        + 1 / b
    )

    return grad_mu, grad_b

def stochastic_vi(y, X, b0=1.0, sigma_sq=1.0, step_size=1e-4, max_iter=2000, tol=1e-3, batch_size=20, seed=0):
    key = random.PRNGKey(seed)
    n, p = X.shape

    # Initialize variational parameters
    mu = jnp.zeros(p)
    b = jnp.ones(p) * b0

    elbo_val = compute_elbo(mu, b, b0, sigma_sq, y, X, n)
    print("Initial ELBO (full data) =", elbo_val)

    for epoch in range(max_iter):
        key, subkey = random.split(key)
        indices = random.choice(subkey, n, shape=(batch_size,), replace=False)
        X_batch = X[indices]
        y_batch = y[indices]

        # Compute gradients
        grad_mu, grad_b = compute_gradients(mu, b, b0, sigma_sq, y_batch, X_batch, n)

        # Scale gradients by batch size
        grad_mu = grad_mu * (n / batch_size)
        grad_b = grad_b * (n / batch_size)

        # weirdly if we softmax the gradient here it works fine than not
        grad_mu = softmax(grad_mu)
        grad_b = softmax(grad_b)

        mu = mu + step_size * grad_mu
        b = b + step_size * grad_b

        # Compute new ELBO on full dataset
        new_elbo = compute_elbo(mu, b, b0, sigma_sq, y, X, n)
        delta = new_elbo - elbo_val  # Note: we want this to be positive

        if epoch % 10 == 0:
            print(f"Iteration {epoch}, ELBO (full data) = {new_elbo}, Δ = {delta}")

        # Check for convergence
        if delta < -tol and epoch > 100:  # Only stop if ELBO decreases
            print("Warning: ELBO is decreasing, stopping early")
            break

        elbo_val = new_elbo

    return mu, b, elbo_val

def simulate_data(n=200, p=10, sigma_sq=1.0, mu0=1.0, b0=1.0, seed=42):
    key = random.PRNGKey(seed)
    key, xkey, Ukey, Vkey, ykey = random.split(key, num=5)

    X = random.normal(xkey, shape=(n, p))
    U = random.uniform(Ukey, shape=(p,))
    V = random.uniform(Vkey, shape=(p,))
    true_coeff = mu0 + b0 * jnp.log(U / V)  # Laplace(mu, b)
    y = X @ true_coeff + random.normal(ykey, shape=(n,)) * jnp.sqrt(sigma_sq)

    return y, X, mu0, b0

# Test the implementation
n = 200
p = 10
b0 = jnp.ones(p) * 0.5  # scale parameter for Laplace prior
mu0 = jnp.ones(p) * 2 # location parameter for Laplace prior
sigma_sq = 1.0

y, X, mu0, b0 = simulate_data(n=n, p=p, sigma_sq=sigma_sq, mu0=mu0, b0=b0, seed=0)
print("True location:", mu0)
print("True scale:", b0)

mu_opt, b_opt, final_elbo = stochastic_vi(
    y, X,
    b0=b0,
    sigma_sq=sigma_sq,
    step_size=1e-4,
    max_iter=5000,
    tol=1e-3,
    batch_size=100,
    seed=0
)

print("\nOptimized variational parameters:")
print("mu =", mu_opt)
print("b =", b_opt)
print("Final ELBO =", final_elbo)

True location: [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
True scale: [0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]
Initial ELBO (full data) = -5535.9131849967225
Iteration 0, ELBO (full data) = -5535.8828414906, Δ = 0.030343506122335384
Iteration 10, ELBO (full data) = -5535.595914243887, Δ = 0.027715311660358566
Iteration 20, ELBO (full data) = -5535.313337224928, Δ = 0.02561376678659144
Iteration 30, ELBO (full data) = -5535.04553804438, Δ = 0.029957043574540876
Iteration 40, ELBO (full data) = -5534.765423597446, Δ = 0.031276040449483844
Iteration 50, ELBO (full data) = -5534.476337469256, Δ = 0.02552988774732512
Iteration 60, ELBO (full data) = -5534.195522878821, Δ = 0.02987026383743796
Iteration 70, ELBO (full data) = -5533.900979010077, Δ = 0.0307843489881634
Iteration 80, ELBO (full data) = -5533.621401548104, Δ = 0.02798906568204984
Iteration 90, ELBO (full data) = -5533.36506339509, Δ = 0.015272520075086504
Iteration 100, ELBO (full data) = -5533.1108995154045, Δ = 0.02682415862545895
Iter