# Variational Inference in 5 minutes #

(copied with a few changes from [2 posts](http://davmre.github.io/blog/inference/2015/11/13/elbo-in-5min) by Dave Moore)

Let $p(x,z)$ be a probability model with observed variables $x$ and latent variables $z$. We want to infer the posterior $p(z|x)$, but in general this won’t have any nice form that we can write down. Instead, we’ll pick some approximating family $q(z;\lambda)$, with parameters $\lambda$, and then try to find the distribution within this family that best approximates the posterior. For example, if we model each latent variable independently (a “mean field” approximation) using a scalar Gaussian, the parameters $\lambda$ are just the means and standard deviations of these Gaussians.

A natural approach to fitting the approximation parameters $\lambda$ is to minimize the [KL divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence) between our approximation $q(z;\lambda)$ and the posterior $p(z|x)$.$^1$ Writing this out,

\begin{equation}
KL[q(z ; \lambda) \| p(z | x)]=\int q(z ; \lambda) \log \frac{q(z ; \lambda)}{p(z | x)} d z ,
\end{equation}
 
we see that it depends on the posterior density $p(z|x)$, which we don’t know. However, we do have access to the joint distribution $p(x,z)$, which is proportional to the posterior as a function of $z$, so a little algebra pulls out the normalizing constant $p(x)$:

\begin{equation}
\begin{aligned}
KL[q(z ; \lambda) \| p(z | x)] &=\int q(z ; \lambda) \log \frac{q(z ; \lambda)}{p(z | x)} d z \\
&=\int q(z ; \lambda)[\log q(z ; \lambda)-\log p(z | x)] d z \\
&=\int q(z ; \lambda)\left[\log q(z ; \lambda)-\log \frac{p(x, z)}{p(x)}\right] d z \\
&=\log p(x)+\int q(z ; \lambda)[\log q(z ; \lambda)-\log p(x, z)] d z \\
&=\log p(x)-\mathcal{F}(\lambda ; x)
\end{aligned}
\end{equation} 

This shows that the KL divergence is equal to the log model evidence $\log{p(x)}$, which is the log of an (unknown) normalizing constant, minus a term $\mathcal{F}$ given by

\begin{equation}
\mathcal{F}(\lambda ; x)=\int q(z ; \lambda)[\log p(x, z)-\log q(z ; \lambda)] d z
\end{equation}

This term is referred to either as the (negative) variational free energy or as the evidence lower bound (ELBO). It is a lower bound on $\log{p(x)}$ because we can write $\log{p(x)}=\mathcal{F}(\lambda ; x) + KL[q(z ; \lambda) \| p(z | x)]$ and the KL divergence is nonnegative. Since the model evidence does not depend on $\lambda$, maximizing $\mathcal{F}$ over $\lambda$ minimizes the KL divergence.
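The decomposition $\log{p(x)}=\mathcal{F} + KL$ can be checked numerically in a model small enough that every term has a closed form. The sketch below assumes a toy model that is not part of the original post: $z \sim N(0,1)$, $x \mid z \sim N(z,1)$, with a Gaussian approximation $q(z) = N(m, s^2)$. In this model the exact posterior is $N(x/2, 1/2)$ and marginally $x \sim N(0, 2)$. The function names are illustrative.

```python
import math

def elbo(m, s, x):
    """Closed-form ELBO for q(z) = N(m, s^2) in the assumed toy model
    z ~ N(0, 1), x | z ~ N(z, 1)."""
    expected_log_joint = (
        -math.log(2 * math.pi)
        - 0.5 * (m**2 + s**2)          # E_q[z^2] / 2, from the prior
        - 0.5 * ((x - m)**2 + s**2)    # E_q[(x - z)^2] / 2, from the likelihood
    )
    entropy = 0.5 * math.log(2 * math.pi * math.e * s**2)
    return expected_log_joint + entropy

def kl_to_posterior(m, s, x):
    """KL[q || p(z|x)], using the exact toy-model posterior N(x/2, 1/2)."""
    post_mean, post_var = x / 2, 0.5
    return (0.5 * math.log(post_var / s**2)
            + (s**2 + (m - post_mean)**2) / (2 * post_var)
            - 0.5)

def log_evidence(x):
    """log p(x) for the toy model, where marginally x ~ N(0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2) - x**2 / 4

x = 1.3
for m, s in [(0.0, 1.0), (0.5, 0.3), (x / 2, math.sqrt(0.5))]:
    f, kl = elbo(m, s, x), kl_to_posterior(m, s, x)
    print(f"m={m:.2f} s={s:.2f}  ELBO={f:.4f}  KL={kl:.4f}  "
          f"ELBO+KL={f + kl:.4f}  log p(x)={log_evidence(x):.4f}")
```

For every choice of $(m, s)$ the sum of the ELBO and the KL divergence equals $\log{p(x)}$, and the gap closes only when $q$ is the exact posterior (the last row).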

This is the core of variational inference: pick an approximating family and minimize KL divergence between your approximation and the true posterior. 

The practical difficulty tends to be that $\mathcal{F}$ involves an expectation, so evaluating and optimizing it requires either model-specific math$^2$ or Monte Carlo techniques. 

One approach is to note that $\mathcal{F}$ is really just an expectation with respect to our approximating distribution $q$:

\begin{equation}
\begin{aligned}
\mathcal{F}(\lambda ; x) &=E_{z \sim q}[\log p(x, z)-\log q(z ; \lambda)] \\
&=E_{z \sim q}[\log p(x, z)]+H(q ; \lambda)
\end{aligned}
\end{equation}

where we’ve made the simplifying assumption that the entropy $H(q;\lambda)$ is available in closed form. This is true for Gaussian approximating families, but if we’re using some other weird family we can always move the entropy back into the Monte Carlo approximation. The expectation over $\log{p(x,z)}$ might not have a closed form, but we can approximate it by drawing $n$ samples $z_i \sim q(z;\lambda)$ and evaluating the empirical expectation

\begin{equation}
\hat{\mathcal{F}}(\lambda ; x)=\frac{1}{n} \sum_{i=1}^{n} \log p\left(x, z_{i}\right)+H(q ; \lambda)
\end{equation}
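As a rough illustration of this estimator, the sketch below evaluates $\hat{\mathcal{F}}$ by sampling and compares it with the exact ELBO. It assumes a toy model chosen for this example rather than taken from the post ($z \sim N(0,1)$, $x \mid z \sim N(z,1)$, Gaussian $q$), because there the exact value is available in closed form. All function names are hypothetical.

```python
import math
import random

def log_joint(x, z):
    """log p(x, z) for the assumed toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    return -math.log(2 * math.pi) - 0.5 * z**2 - 0.5 * (x - z)**2

def gaussian_entropy(sigma):
    """Closed-form entropy H(q) of q = N(mu, sigma^2)."""
    return 0.5 * math.log(2 * math.pi * math.e * sigma**2)

def mc_elbo(mu, sigma, x, n, rng):
    """(1/n) * sum_i log p(x, z_i) + H(q), with z_i drawn from q."""
    samples = [rng.gauss(mu, sigma) for _ in range(n)]
    return sum(log_joint(x, z) for z in samples) / n + gaussian_entropy(sigma)

def exact_elbo(mu, sigma, x):
    """Closed-form ELBO, available only because the toy model is Gaussian."""
    expected_log_joint = (-math.log(2 * math.pi)
                          - 0.5 * (mu**2 + sigma**2)
                          - 0.5 * ((x - mu)**2 + sigma**2))
    return expected_log_joint + gaussian_entropy(sigma)

rng = random.Random(0)  # fixed seed so the run is repeatable
x, mu, sigma = 1.3, 0.5, 0.8
estimate = mc_elbo(mu, sigma, x, n=20000, rng=rng)
exact = exact_elbo(mu, sigma, x)
print(f"Monte Carlo ELBO: {estimate:.4f}   exact ELBO: {exact:.4f}")
```

The estimate is unbiased, and its error shrinks like $1/\sqrt{n}$, so a handful of samples is often enough when the estimate only feeds a stochastic optimizer.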

Our approach will be to do gradient ascent on this Monte Carlo approximation. But wait, you might object, $\lambda$ doesn’t appear anywhere in (the Monte Carlo part of) this expression, so how can we compute a gradient? The answer is that $\lambda$ was a parameter of the distribution that produced $z$, so we just have to differentiate through the sampling algorithm, holding fixed the random seed (this is the “reparameterization trick”$^3$). In many cases this is straightforward to do.

For example, if $q$ is Gaussian parameterized by a mean and standard deviation $\lambda=(\mu,\sigma)$, a typical sampling procedure would first sample a standard Gaussian variable $\varepsilon \sim N(0,1)$ and then compute the transform $z = \sigma \varepsilon + \mu$. Rewriting our Monte Carlo ELBO in terms of these “base variables” $\varepsilon_i$,

\begin{equation}
\hat{\mathcal{F}}(\lambda ; x)=\frac{1}{n} \sum_{i=1}^{n} \log p\left(x, \sigma \varepsilon_{i}+\mu\right)+H(q ; \lambda)
\end{equation}

we can now easily differentiate this expression with respect to $\mu$ and $\sigma$ (by the chain rule, this will involve the model gradient $\nabla_z \log{p(x,z)}$). The result is a stochastic estimate of the gradient of the ELBO, which you can plug into your favorite stochastic optimization algorithm (SGD, Adagrad, etc.).
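The whole procedure can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not how any particular library implements it: the model is an assumed toy ($z \sim N(0,1)$, $x \mid z \sim N(z,1)$), its gradient $\nabla_z \log{p(x,z)} = x - 2z$ is written by hand instead of coming from autodiff, and $\sigma$ is optimized through $\log \sigma$ to keep it positive. Step size, sample count, and function names are arbitrary choices.

```python
import math
import random

def grad_z_log_joint(x, z):
    """Model gradient d/dz log p(x, z) for the assumed toy model
    z ~ N(0, 1), x | z ~ N(z, 1): (-z) from the prior, (x - z) from the likelihood."""
    return -z + (x - z)

def reparam_gradient(mu, log_sigma, x, n, rng):
    """Stochastic ELBO gradient with respect to (mu, log_sigma)."""
    sigma = math.exp(log_sigma)
    g_mu = 0.0
    g_log_sigma = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)        # base variable, independent of lambda
        z = mu + sigma * eps             # reparameterized sample from q
        g = grad_z_log_joint(x, z)
        g_mu += g                        # chain rule: dz/dmu = 1
        g_log_sigma += g * sigma * eps   # chain rule: dz/dlog_sigma = sigma * eps
    # Entropy term: H(q) = log_sigma + const, so dH/dlog_sigma = 1 and dH/dmu = 0.
    return g_mu / n, g_log_sigma / n + 1.0

def fit(x, steps=1500, n=50, lr=0.02, seed=0):
    """Plain stochastic gradient ascent on the Monte Carlo ELBO."""
    rng = random.Random(seed)
    mu, log_sigma = 0.0, 0.0
    for _ in range(steps):
        g_mu, g_log_sigma = reparam_gradient(mu, log_sigma, x, n, rng)
        mu += lr * g_mu
        log_sigma += lr * g_log_sigma
    return mu, math.exp(log_sigma)

mu_hat, sigma_hat = fit(x=1.3)
print(f"fitted q: mean={mu_hat:.3f}, std={sigma_hat:.3f}")
```

In this toy model the exact posterior is $N(x/2, 1/2)$, so with $x = 1.3$ the fitted mean should settle near $0.65$ and the standard deviation near $\sqrt{1/2} \approx 0.707$, up to the noise of the stochastic updates.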

Note that the only assumption we’ve made about the model is that we have access to gradients $\nabla_z \log{p(x,z)}$, which is nearly always the case thanks to automatic differentiation. This is how Stan implements variational inference for arbitrary models (more details in [their paper](https://arxiv.org/abs/1506.03431)), and many other tools now support autodiff as well, such as the Python libraries [autograd](https://github.com/HIPS/autograd) and [JAX](https://github.com/google/jax).

If model gradients are not available, it’s still possible to estimate the ELBO gradient using a trick from reinforcement learning (the score-function, or REINFORCE, estimator), described in the paper [Black Box Variational Inference](https://arxiv.org/abs/1401.0118). However, this estimate has higher variance, so optimization will converge much more slowly than when model gradients are available.
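To make the variance comparison concrete, the sketch below estimates $\partial \mathcal{F} / \partial \mu$ in two ways. The score-function estimator uses the identity $\nabla_\lambda \mathcal{F} = E_{z \sim q}[\nabla_\lambda \log q(z;\lambda)\,(\log p(x,z) - \log q(z;\lambda))]$ and only evaluates $\log p(x,z)$. The reparameterized estimator uses the model gradient. The toy model ($z \sim N(0,1)$, $x \mid z \sim N(z,1)$) and all function names are assumptions made for this example, and no variance-reduction techniques are applied.

```python
import math
import random
import statistics

def log_joint(x, z):
    """log p(x, z) for the assumed toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    return -math.log(2 * math.pi) - 0.5 * z**2 - 0.5 * (x - z)**2

def log_q(z, mu, sigma):
    """Log density of the Gaussian approximation q = N(mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((z - mu) / sigma)**2

def score_function_terms(mu, sigma, x, n, rng):
    """Per-sample score-function estimates of dF/dmu (no model gradient needed)."""
    terms = []
    for _ in range(n):
        z = rng.gauss(mu, sigma)
        score = (z - mu) / sigma**2      # d/dmu log q(z; mu, sigma)
        terms.append(score * (log_joint(x, z) - log_q(z, mu, sigma)))
    return terms

def reparam_terms(mu, sigma, x, n, rng):
    """Per-sample reparameterized estimates of dF/dmu (uses d/dz log p(x, z))."""
    terms = []
    for _ in range(n):
        z = mu + sigma * rng.gauss(0.0, 1.0)
        terms.append(x - 2 * z)          # d/dz log p(x, z) times dz/dmu = 1
    return terms

rng = random.Random(0)
mu, sigma, x, n = 0.0, 1.0, 2.0, 20000
score_est = score_function_terms(mu, sigma, x, n, rng)
reparam_est = reparam_terms(mu, sigma, x, n, rng)
print("score-function: mean", statistics.mean(score_est),
      "variance", statistics.variance(score_est))
print("reparameterized: mean", statistics.mean(reparam_est),
      "variance", statistics.variance(reparam_est))
```

Both estimators target the same gradient (here $x - 2\mu = 2$), but in this setting the per-sample variance of the score-function estimate is several times larger, which is the behavior the paragraph above describes.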

$^1$There are other approaches, including the alternate divergence $KL[p \| q]$ which leads to [expectation propagation](https://tminka.github.io/papers/ep/), or the Laplace approximation which locally matches the curvature at the mode.

$^2$For certain classes of models, e.g., exponential families with conjugate priors, the math is well understood and essentially automatable. This is the idea behind [variational message passing](http://www.jmlr.org/papers/volume6/winn05a/winn05a.pdf) as implemented in, e.g., [Infer.NET](https://dotnet.github.io/infer/).

$^3$This trick was introduced by Kingma and Welling in the context of variational autoencoders, though it was also independently proposed by several others around the same time. Shakir Mohamed has a nice post that goes into more depth on the history and applicability of this trick.