# Stochastic Variational Inference
Notes from the Hoffman *et al.* (2013) paper

In probabilistic modelling, we use hidden variables to encode hidden structure in observed data; we articulate the relationship between the hidden and observed variables with a factorized probability distribution (i.e. a graphical model) and we use inference algorithms to estimate the **posterior distribution**, the **conditional distribution** of hidden structure given the observations.

Consider a graphical model of hidden and observed random variables for which we want to compute the posterior. For many models of interest, this posterior is intractable to compute and we must appeal to approximate methods. The two most prominent strategies in statistics and machine learning are Markov chain Monte Carlo (MCMC) sampling and variational inference:

* **MCMC Sampling**:
    We construct a Markov chain over the hidden variables whose stationary distribution is the posterior of interest. We run the chain until it has (hopefully) reached equilibrium and collect samples to approximate the posterior.
* **Variational inference**:
    We define a flexible family of distributions over the hidden variables, indexed by free parameters. We then find the setting of the parameters (i.e. the member of the family) that is closest to the posterior. Thus we solve the inference problem by solving an optimization problem.
    
The aim here is to develop a general variational method that scales to massive data sets.

Stochastic variational inference follows four steps:
1. Subsample one or more data points from the data.
2. Analyze the subsample using the current variational parameters.
3. Implement a closed-form update of the variational parameters.
4. Repeat.

Whereas traditional algorithms must analyze the whole data set before each update of the variational parameters, this algorithm only needs to analyze randomly sampled subsets.
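As a rough illustration of the four steps, here is a toy sketch. It assumes a Beta-Bernoulli model with a single global variable and no local variables, and it borrows the update $\lambda \leftarrow (1-\rho_t)\lambda + \rho_t\hat\lambda$ with step size $\rho_t = (t+\tau)^{-\kappa}$ that the paper develops later, where $\hat\lambda$ is the value the variational parameter would take if the whole data set consisted of $N$ copies of the sampled point. The function name `svi_beta_bernoulli` and the default settings are illustrative choices, not the paper's code. The parameters are stored as Beta pseudo-counts; since the Beta natural parameters are these counts minus one, the convex combination is the same in either parametrization.

```python
import random


def svi_beta_bernoulli(x, a=1.0, b=1.0, n_iters=5000, tau=1.0, kappa=0.7, seed=0):
    """Toy SVI loop for beta ~ Beta(a, b), x_n ~ Bernoulli(beta) (hypothetical example).

    The model has a global variable only (no local z_n), so the "analyze" step
    reduces to computing the intermediate parameter lam_hat.
    q(beta) = Beta(lam[0], lam[1]) is stored as pseudo-counts.
    """
    rng = random.Random(seed)
    n = len(x)
    lam = [a, b]  # initialize q(beta) at the prior
    for t in range(1, n_iters + 1):
        # 1. Subsample one data point uniformly at random.
        x_i = x[rng.randrange(n)]
        # 2. Analyze it: the parameter q would have if the data were N copies of x_i.
        lam_hat = [a + n * x_i, b + n * (1 - x_i)]
        # 3. Closed-form update with a decreasing step size rho_t = (t + tau)^(-kappa).
        rho = (t + tau) ** (-kappa)
        lam = [(1 - rho) * old + rho * new for old, new in zip(lam, lam_hat)]
        # 4. Repeat (next loop iteration).
    return lam


if __name__ == "__main__":
    data_rng = random.Random(1)
    data = [1 if data_rng.random() < 0.3 else 0 for _ in range(1000)]
    exact = (1.0 + sum(data), 1.0 + len(data) - sum(data))  # conjugate posterior counts
    print("SVI:", svi_beta_bernoulli(data), "exact:", exact)
```

With enough iterations the result should land close to the exact conjugate posterior $\mathrm{Beta}(a + \sum_n x_n,\, b + N - \sum_n x_n)$, although any single run remains noisy because each step sees only one data point.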

SVI is a stochastic **optimization algorithm** for mean-field variational inference. It approximates the posterior distribution of a probabilistic model with hidden variables and can handle massive data sets of observations.

![graphical model](figures/classic.png)
A graphical model with observations $x_{1:N}$, local hidden variables $z_{1:N}$ and global hidden variables $\beta$. The distribution of each observation $x_n$ only depends on its corresponding local variable $z_n$ and the global variables $\beta$.

## 1. Define the class of models to which our algorithm applies
We define *local* and *global* hidden variables and state requirements on the conditional distributions within the model.

The joint distribution factorizes into a global term and a product of local terms:
$$
p(x, z, \beta |\alpha) = p(\beta | \alpha)\prod_{n=1}^N p(x_n, z_n |\beta)
$$
Our goal is to approximate the posterior distribution of the hidden variables given the observations, $p(\beta, z|x)$.
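The factorization can be made concrete with a hypothetical toy model (not one from the paper): a two-component Gaussian mixture whose global variable $\beta \in (0, 1)$ is the weight of component 1, with a $\mathrm{Beta}(a, b)$ prior standing in for $p(\beta \mid \alpha)$, local variables $z_n \in \{0, 1\}$, and unit-variance components with fixed, known means `MU`. The sketch evaluates $\log p(x, z, \beta)$ as the global term plus a sum of local terms.

```python
import math

MU = (-2.0, 2.0)  # hypothetical fixed means of components 0 and 1 (unit variance)


def log_beta_pdf(beta, a, b):
    """log Beta(beta; a, b), playing the role of the global term log p(beta | alpha)."""
    return ((a - 1) * math.log(beta) + (b - 1) * math.log1p(-beta)
            + math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b))


def log_normal_pdf(x, mean, var=1.0):
    """log N(x; mean, var)."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def log_local(x_n, z_n, beta):
    """Local term log p(x_n, z_n | beta) = log p(z_n | beta) + log p(x_n | z_n)."""
    log_p_z = math.log(beta) if z_n == 1 else math.log1p(-beta)
    return log_p_z + log_normal_pdf(x_n, MU[z_n])


def log_joint(x, z, beta, a=1.0, b=1.0):
    """log p(x, z, beta) = log p(beta) + sum_n log p(x_n, z_n | beta)."""
    return log_beta_pdf(beta, a, b) + sum(
        log_local(x_n, z_n, beta) for x_n, z_n in zip(x, z))


print(log_joint([0.5, -1.0, 2.5], [1, 0, 1], beta=0.3))
```

Because each local term touches only $(x_n, z_n)$ and $\beta$, changing one observation changes exactly one term of the sum.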

**Assumption 1**: The $n$th observation $x_n$ and the $n$th local variable $z_n$ are conditionally independent, given global variables $\beta$, of all other observations and local hidden variables,
$$
p(x_n, z_n |x_{-n}, z_{-n},\beta,\alpha) = p(x_n, z_n|\beta,\alpha)
$$

**Assumption 2**: Each *complete conditional* in the model is in the **exponential family**. A complete conditional is the conditional distribution of a hidden variable given the other hidden variables and the observations. For the global and local variables these distributions are

\begin{eqnarray}
p(\beta|x, z, \alpha) & = & h(\beta)\exp\{\eta_g(x, z, \alpha)^Tt(\beta)-a_g(\eta_g(x, z, \alpha))\}\\
p(z_{nj}|x_n, z_{n,-j}, \beta) & = & h(z_{nj})\exp\{\eta_l(x_n, z_{n,-j},\beta)^Tt(z_{nj}) - a_l(\eta_l(x_n, z_{n,-j},\beta))\}
\end{eqnarray}

The scalar functions $h(\cdot)$ and $a(\cdot)$ are respectively the *base measure* and *log-normalizer*; the vector functions $\eta(\cdot)$ and $t(\cdot)$ are respectively the *natural parameter* and *sufficient statistics*. **(For details, consult any standard statistics text on exponential families.)**
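As a quick check of these definitions, the Poisson distribution with rate $r$ is a familiar member of the family: $h(x) = 1/x!$, $t(x) = x$, $\eta = \log r$ and $a(\eta) = e^{\eta}$. The sketch below (function names are illustrative) evaluates the pmf in exponential-family form and compares it with the textbook formula $r^x e^{-r}/x!$.

```python
import math


def h(x):
    """Base measure h(x) = 1 / x!."""
    return 1.0 / math.factorial(x)


def t(x):
    """Sufficient statistic t(x) = x."""
    return x


def a(eta):
    """Log-normalizer a(eta) = exp(eta)."""
    return math.exp(eta)


def poisson_expfam(x, rate):
    """Poisson pmf written as h(x) exp{eta t(x) - a(eta)} with eta = log(rate)."""
    eta = math.log(rate)  # natural parameter
    return h(x) * math.exp(eta * t(x) - a(eta))


def poisson_textbook(x, rate):
    """Poisson pmf in its usual form rate^x e^(-rate) / x!."""
    return rate ** x * math.exp(-rate) / math.factorial(x)


print(poisson_expfam(2, 3.5), poisson_textbook(2, 3.5))  # equal up to rounding
```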

These are conditional distributions, so the natural parameter is a function of the variables that are being conditioned on. For the local variables $z_{nj}$, the complete conditional distribution is determined by the global variables $\beta$ and the other local variables in the $n$th context, i.e. the $n$th data point $x_n$ and the local variables $z_{n,-j}$.
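For a hypothetical Gaussian mixture with known unit-variance component means $\mu_k$, global mixture weights $\beta = (\beta_1, \dots, \beta_K)$ and one local variable per data point ($J = 1$), the local complete conditional is categorical. Written in the form above it has $h(z_n) = 1$, $t(z_n)$ equal to the one-hot encoding of $z_n$, natural parameter $\eta_{l,k} = \log\beta_k + \log\mathcal{N}(x_n; \mu_k, 1)$ and $a_l(\eta) = \log\sum_k e^{\eta_k}$. In this sketch $\eta_l$ depends only on $x_n$ and $\beta$, as expected when there are no other local variables $z_{n,-j}$.

```python
import math


def log_normal_pdf(x, mean, var=1.0):
    """log N(x; mean, var)."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def local_complete_conditional(x_n, beta, mu):
    """p(z_n = k | x_n, beta) for a toy Gaussian mixture with known means mu.

    Exponential-family form: h(z_n) = 1, t(z_n) = one-hot(z_n),
    eta_k = log beta_k + log N(x_n; mu_k, 1), a_l(eta) = logsumexp(eta).
    """
    eta = [math.log(b_k) + log_normal_pdf(x_n, mu_k) for b_k, mu_k in zip(beta, mu)]
    a_l = logsumexp(eta)
    # With a one-hot t(z_n), exp{eta^T t(z_n) - a_l} for z_n = k equals exp(eta_k - a_l).
    probs = [math.exp(e - a_l) for e in eta]
    return eta, probs


eta, probs = local_complete_conditional(0.4, beta=[0.3, 0.7], mu=[-2.0, 2.0])
print(probs)
```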

These assumptions on the complete conditionals imply a **conjugacy relationship** between the global variables $\beta$ and the local contexts $(z_n, x_n)$, and this relationship implies that the distribution of the local context given the global variables must be in an exponential family,

\begin{equation}
p(x_n, z_n |\beta) = h(x_n, z_n)\exp\{\beta^Tt(x_n, z_n) - a_l(\beta)\}
\end{equation}

The prior distribution $p(\beta)$ must also be in an exponential family,
$$
p(\beta) = h(\beta)\exp\{\alpha^Tt(\beta)-a_g(\alpha)\}
$$
The sufficient statistics are $t(\beta) = (\beta, -a_l(\beta))$ and thus the hyperparameter $\alpha$ has two components $\alpha = (\alpha_1, \alpha_2)$. The first component $\alpha_1$ is a vector of the same dimension as $\beta$, the second component $\alpha_2$ is a scalar.

The two equations above imply that the complete conditional for the global variable is in the same exponential family as the prior with natural parameter

$$
\eta_g(x, z, \alpha) = (\alpha_1 + \sum_{n=1}^N t(x_n, z_n), \alpha_2+N).
$$
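A small worked instance, under assumptions of our own: take a hypothetical model with no local hidden variables, in which $x_n \in \{0, 1\}$ and $p(x_n \mid \beta) = \exp\{\beta x_n - \log(1 + e^{\beta})\}$, so $h = 1$, $t(x_n) = x_n$ and $a_l(\beta) = \log(1 + e^{\beta})$ (a Bernoulli in its logit parametrization). With prior $p(\beta) \propto \exp\{\alpha_1\beta - \alpha_2 a_l(\beta)\}$ and $\alpha_2 > \alpha_1 > 0$, the change of variables $\mu = 1/(1 + e^{-\beta})$ shows that $\mu \sim \mathrm{Beta}(\alpha_1, \alpha_2 - \alpha_1)$. The sketch computes $\eta_g$ and confirms that it reproduces the familiar Beta-Bernoulli count update $\mathrm{Beta}(a + \sum_n x_n,\, b + N - \sum_n x_n)$.

```python
def global_natural_param(alpha, x):
    """eta_g = (alpha_1 + sum_n t(x_n), alpha_2 + N), with t(x_n) = x_n in this toy model."""
    alpha_1, alpha_2 = alpha
    return (alpha_1 + sum(x), alpha_2 + len(x))


def to_beta_params(eta):
    """Beta(a, b) law of mu = sigmoid(beta) when p(beta) is proportional to
    exp{eta_1 * beta - eta_2 * log(1 + e^beta)}: a = eta_1, b = eta_2 - eta_1."""
    eta_1, eta_2 = eta
    return (eta_1, eta_2 - eta_1)


# A Beta(2, 3) prior on mu corresponds to alpha = (2, 2 + 3).
alpha = (2.0, 5.0)
x = [1, 0, 1, 1, 0]
print(to_beta_params(global_natural_param(alpha, x)))  # prints (5.0, 5.0) = Beta(2 + 3, 3 + 2)
```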

Analysing data with one of the models in this class (e.g. Bayesian mixture models, latent Dirichlet allocation) amounts to computing the posterior distribution of the hidden variables given the observations,

$$
p(z, \beta |x ) = \frac{p(x, z, \beta)}{\int p(x, z, \beta)dz d\beta}.
$$
We then use this posterior to explore the hidden structure of our data or to make predictions about future data.
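Even for a tiny model the normalizer in the denominator is expensive. Consider a hypothetical two-component mixture with fixed, known unit-variance means and a $\mathrm{Beta}(a, b)$ prior on the weight $\beta$ of component 1. The integral over $\beta$ has a closed form, $\int \beta^{k}(1-\beta)^{N-k}\,\mathrm{Beta}(\beta; a, b)\,d\beta = B(a+k, b+N-k)/B(a, b)$ with $k = \sum_n z_n$, but the sum over $z$ still has $2^N$ terms. The sketch computes the exact marginal posterior $p(z \mid x)$ (with $\beta$ integrated out) by brute-force enumeration, which is only feasible for very small $N$.

```python
import itertools
import math

MU = (-2.0, 2.0)  # hypothetical fixed means of components 0 and 1 (unit variance)


def log_normal_pdf(x, mean):
    """log N(x; mean, 1)."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mean) ** 2


def log_beta_fn(a, b):
    """log B(a, b), the log of the Beta function."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_p_x_z(x, z, a=1.0, b=1.0):
    """log p(x, z) with the global weight beta ~ Beta(a, b) integrated out exactly."""
    k = sum(z)  # number of points assigned to component 1
    log_p_z = log_beta_fn(a + k, b + len(z) - k) - log_beta_fn(a, b)
    return log_p_z + sum(log_normal_pdf(x_n, MU[z_n]) for x_n, z_n in zip(x, z))


def posterior_over_z(x, a=1.0, b=1.0):
    """Exact p(z | x) by enumerating all 2^N assignments (feasible only for tiny N)."""
    configs = list(itertools.product((0, 1), repeat=len(x)))
    log_terms = [log_p_x_z(x, z, a, b) for z in configs]
    m = max(log_terms)
    log_evidence = m + math.log(sum(math.exp(v - m) for v in log_terms))  # log p(x)
    posterior = {z: math.exp(v - log_evidence) for z, v in zip(configs, log_terms)}
    return posterior, log_evidence


post, log_px = posterior_over_z([-2.1, 1.9, 2.3, -1.5])
print(len(post), max(post, key=post.get))  # 16 assignments; the MAP one is (0, 1, 1, 0)
```

Adding one data point doubles the number of terms, which is why approximate methods are needed for realistic $N$.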

## 2. Mean field variational inference
Mean-field variational inference is an approximate inference strategy that seeks a tractable distribution over the hidden variables that is close to the posterior distribution, where closeness is measured with the KL divergence. We use the resulting distribution, called the *variational distribution*, to approximate the posterior. Here we derive the traditional variational inference algorithm for our class of models, which is a coordinate ascent algorithm.

### The evidence lower bound
Variational inference minimizes the KL divergence from the variational distribution to the posterior distribution. Equivalently, it maximizes the *evidence lower bound* (ELBO), a lower bound on the logarithm of the marginal probability of the observations, $\log p(x)$. The ELBO is equal to the negative KL divergence up to an additive constant.

We derive the ELBO by introducing a distribution $q(z, \beta)$ over the hidden variables and applying Jensen's inequality: because the logarithm is concave, $\log\mathbb{E}[f(y)]\ge \mathbb{E}[\log f(y)]$ for any positive function $f$ of a random variable $y$.

This gives the following bound on the log marginal,

\begin{eqnarray}
\log p(x) & = & \log\int p(x, z, \beta)dz d\beta\\
& = & \log\int p(x, z, \beta)\frac{q(z, \beta)}{q(z, \beta)}dzd\beta\\
& = & \log\left(\mathbb{E}_q\left[\frac{p(x,z,\beta)}{q(z,\beta)}\right]\right)\\
&\ge & \mathbb{E}_q[\log p(x, z, \beta)]-\mathbb{E}_q[\log q(z, \beta)]\\
&\triangleq &\mathcal{L}(q).
\end{eqnarray}

The ELBO contains two terms. The first term is the expected log joint, $\mathbb{E}_q[\log p(x, z, \beta)]$. The second is the entropy of the variational distribution, $-\mathbb{E}_q[\log q(z, \beta)]$. Both of these terms depend on $q(z, \beta)$, the variational distribution of the hidden variables.
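The bound can be checked numerically on a hypothetical Beta-Bernoulli model with no local variables ($\beta \sim \mathrm{Beta}(a, b)$, $x_n \sim \mathrm{Bernoulli}(\beta)$), where $\log p(x) = \log B(a + k, b + N - k) - \log B(a, b)$ with $k = \sum_n x_n$ is available in closed form. The sketch estimates both ELBO terms together by Monte Carlo, averaging $\log p(x, \beta) - \log q(\beta)$ over draws $\beta \sim q$ for $q(\beta) = \mathrm{Beta}(c, d)$. When $q$ equals the exact posterior $\mathrm{Beta}(a + k, b + N - k)$, every draw gives exactly $\log p(x)$, so the bound is tight; for other choices the estimate falls below $\log p(x)$.

```python
import math
import random


def log_beta_fn(a, b):
    """log B(a, b), the log of the Beta function."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_beta_pdf(v, a, b):
    """log Beta(v; a, b)."""
    return (a - 1) * math.log(v) + (b - 1) * math.log1p(-v) - log_beta_fn(a, b)


def log_joint(x, v, a, b):
    """log p(x, beta = v) for beta ~ Beta(a, b) and x_n ~ Bernoulli(beta)."""
    k = sum(x)
    return log_beta_pdf(v, a, b) + k * math.log(v) + (len(x) - k) * math.log1p(-v)


def log_evidence(x, a=1.0, b=1.0):
    """Closed-form log p(x) for the Beta-Bernoulli model."""
    k = sum(x)
    return log_beta_fn(a + k, b + len(x) - k) - log_beta_fn(a, b)


def elbo_mc(x, c, d, a=1.0, b=1.0, n_samples=2000, seed=0):
    """Monte Carlo estimate of E_q[log p(x, beta)] - E_q[log q(beta)], q = Beta(c, d)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        v = rng.betavariate(c, d)  # draw beta ~ q
        total += log_joint(x, v, a, b) - log_beta_pdf(v, c, d)
    return total / n_samples


x = [1, 1, 0, 1, 1, 0, 1, 1, 0, 1]  # k = 7 successes out of N = 10
print(log_evidence(x), elbo_mc(x, 8.0, 4.0), elbo_mc(x, 2.0, 8.0))
```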

We restrict $q(z, \beta)$ to be in a family that is tractable, one for which the expectations in the ELBO can be efficiently computed. We then try to find the member of the family that maximizes the ELBO. Finally, we use the optimized distribution as a proxy for the posterior.

Solving this maximization problem is equivalent to finding the member of the family that is closest in KL divergence to the posterior:

\begin{eqnarray}
KL(q(z,\beta)||p(z, \beta|x)) & = & \mathbb{E}_q[\log q(z, \beta)] - \mathbb{E}_q[\log p(z, \beta|x)]\\
& = & \mathbb{E}_q[\log q(z, \beta)]-\mathbb{E}_q[\log p(x, z, \beta)] + \log p(x)\\
& = & -\mathcal{L}(q) + \mathrm{const}.
\end{eqnarray}
The term $\log p(x)$ is replaced by a constant because it does not depend on $q$.
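For a single discrete hidden variable every quantity in this identity can be computed exactly, which gives a quick sanity check. The values of $p(x, z = k)$ below are made up for illustration; the sketch confirms that $\mathrm{KL}(q \,\|\, p(z \mid x)) = \log p(x) - \mathcal{L}(q)$, so maximizing the ELBO over $q$ is the same as minimizing the KL divergence.

```python
import math


def elbo(q, p_xz):
    """L(q) = E_q[log p(x, z)] - E_q[log q(z)] for a discrete hidden variable z."""
    return sum(q_k * (math.log(p_k) - math.log(q_k))
               for q_k, p_k in zip(q, p_xz) if q_k > 0)


def kl_to_posterior(q, p_xz):
    """KL(q(z) || p(z | x)), using p(z | x) = p(x, z) / p(x)."""
    p_x = sum(p_xz)
    return sum(q_k * (math.log(q_k) - math.log(p_k / p_x))
               for q_k, p_k in zip(q, p_xz) if q_k > 0)


p_xz = [0.02, 0.05, 0.01]  # made-up values of p(x, z = k) for the observed x
q = [0.5, 0.3, 0.2]  # an arbitrary variational distribution over z
log_p_x = math.log(sum(p_xz))
print(kl_to_posterior(q, p_xz), log_p_x - elbo(q, p_xz))  # equal up to rounding
```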

### The mean-field variational family
The mean-field family is the simplest variational family. In this family, each hidden variable is independent and governed by its own parameter,

$$
q(z, \beta) = q(\beta |\lambda)\prod_{n=1}^N \prod_{j=1}^J q(z_{nj}|\phi_{nj})
$$

The global parameters $\lambda$ govern the global variables; the local parameters $\phi_n$ govern the local variables in the $n$th context. The ELBO is a function of these parameters.
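As a hypothetical illustration, take a model with a global weight $\beta \in (0, 1)$ and binary local variables $z_n$ (so $J = 1$). A mean-field $q$ is then a Beta factor for $\beta$ times an independent Bernoulli factor for each $z_n$. For readability the sketch parametrizes $q(z_n)$ by the probability $\phi_n = q(z_n = 1)$ rather than by a natural parameter. Because the factors are independent, $\log q$ is a sum of per-factor terms, and a joint draw is one independent draw per factor.

```python
import math
import random


def log_beta_pdf(v, a, b):
    """log Beta(v; a, b)."""
    return ((a - 1) * math.log(v) + (b - 1) * math.log1p(-v)
            + math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b))


def log_q(beta, z, lam, phi):
    """Mean-field log q(beta, z) = log q(beta | lam) + sum_n log q(z_n | phi_n)."""
    out = log_beta_pdf(beta, lam[0], lam[1])  # global factor
    for z_n, phi_n in zip(z, phi):  # one independent local factor per z_n
        out += math.log(phi_n) if z_n == 1 else math.log1p(-phi_n)
    return out


def sample_q(lam, phi, rng):
    """Draw (beta, z) from q; each factor is sampled independently of the others."""
    beta = rng.betavariate(lam[0], lam[1])
    z = [1 if rng.random() < phi_n else 0 for phi_n in phi]
    return beta, z


lam, phi = (3.0, 2.0), [0.1, 0.7, 0.4]  # hypothetical variational parameters
beta, z = sample_q(lam, phi, random.Random(0))
print(beta, z, log_q(beta, z, lam, phi))
```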

We set $q(\beta|\lambda)$ and $q(z_{nj}|\phi_{nj})$ to be in the same exponential family as the complete conditional distributions $p(\beta|x, z)$ and $p(z_{nj}|x_n,z_{n,-j},\beta)$. The variational parameters $\lambda$ and $\phi_{nj}$ are the natural parameters of those families,

\begin{eqnarray}
q(\beta|\lambda) & = & h(\beta)\exp\{\lambda^Tt(\beta)-a_g(\lambda)\}\\
q(z_{nj}|\phi_{nj}) & = & h(z_{nj})\exp\{\phi_{nj}^Tt(z_{nj}) - a_{l}(\phi_{nj})\}
\end{eqnarray}
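For a categorical local factor, the form above has $h(z_{nj}) = 1$, $t(z_{nj})$ equal to the one-hot encoding of $z_{nj}$ and $a_l(\phi_{nj}) = \log\sum_k e^{\phi_{nj,k}}$. A standard exponential-family property is that the gradient of the log-normalizer equals the expected sufficient statistics, $\nabla a_l(\phi_{nj}) = \mathbb{E}_q[t(z_{nj})]$, which here is the vector of category probabilities. The sketch (function names are illustrative) checks this identity by central finite differences.

```python
import math


def log_normalizer(phi):
    """a_l(phi) = log sum_k exp(phi_k) for a categorical in natural parameters."""
    m = max(phi)
    return m + math.log(sum(math.exp(p - m) for p in phi))


def mean_params(phi):
    """E_q[t(z)] for one-hot t(z): probabilities q(z = k | phi) = exp(phi_k - a_l(phi))."""
    a = log_normalizer(phi)
    return [math.exp(p - a) for p in phi]


def grad_log_normalizer_fd(phi, eps=1e-6):
    """Central finite-difference approximation of the gradient of a_l at phi."""
    grad = []
    for k in range(len(phi)):
        up = list(phi)
        down = list(phi)
        up[k] += eps
        down[k] -= eps
        grad.append((log_normalizer(up) - log_normalizer(down)) / (2 * eps))
    return grad


phi = [0.3, -1.2, 2.0]  # hypothetical natural parameters phi_nj
print(mean_params(phi))
print(grad_log_normalizer_fd(phi))  # agrees with the line above up to finite-difference error
```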