# Variational inference

\begin{align*}
\DeclareMathOperator*{\argmax}{argmax}
\DeclareMathOperator*{\EXPSYM}{\mathbf{E}}
\newcommand{\KL}[1]{\mathbf{KL}\left[#1\right]}
\newcommand{\EXP}[2]{\EXPSYM_{#1}\left[#2\right]}
\newcommand{\Kappa}{\boldsymbol{\kappa}}
\newcommand{\Lambda}{\boldsymbol{\lambda}}
\newcommand{\Theta}{\boldsymbol{\theta}}
\newcommand{\Data}{D}
\newcommand{\FFF}{\mathcal{F}}
\newcommand{\LLL}{\mathcal{L}}
\newcommand{\NNN}{\mathcal{N}}
\newcommand{\vec}[1]{\boldsymbol{#1}}
\newcommand{\appropto}{\stackrel{\propto}{\sim}}
\end{align*}

* [Kenneth Tay. Laplace’s approximation for Bayesian posterior distribution](https://statisticaloddsandends.wordpress.com/2019/07/11/laplaces-approximation-for-bayesian-posterior-distribution/)
* [Charles Blundell et al. Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424)
* [Yeming Wen et al. Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches](https://arxiv.org/abs/1803.04386)

<center><h1>Variational inference</h1> </center>
<center><h2>Theoretical overview</h2></center>
<br>
<center><h3>Sven Laur</h3></center>
<center><h3>swen@ut.ee</h3></center>

## Bayesian neural networks

<center>
    <img src='./illustrations/bayesian_neural_network.png' width=20% alt='Bayesian neural network'>
</center>

* Activation functions are kept deterministic.
* Randomness is introduced into the weights of the neural network.
* Equivalently, we can add randomness through perturbations of deterministic parameters. This form is technically more convenient.

\begin{align*}
w_i\sim\NNN(\mu,\sigma)\qquad\Longleftrightarrow\qquad w_i=\mu+\sigma\cdot\epsilon, \quad \epsilon\sim\NNN(0,1)
\end{align*}

* This setup defines an output probability distribution even if all activation functions are deterministic.
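
The perturbation view can be sketched in NumPy. This is a minimal illustration rather than part of the original material: the single tanh neuron, the values of `mu` and `sigma`, and the helper name `sample_output` are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical variational parameters of a single neuron with two inputs.
mu = np.array([0.5, -1.0])      # mean of each weight
sigma = np.array([0.1, 0.3])    # standard deviation of each weight

def sample_output(x, n_samples):
    """Sample outputs of a tanh neuron whose weights are w = mu + sigma * eps."""
    eps = rng.standard_normal((n_samples, mu.size))  # eps ~ N(0, 1)
    w = mu + sigma * eps                             # w ~ N(mu, sigma)
    return np.tanh(w @ x)                            # deterministic activation

x = np.array([1.0, 2.0])
outputs = sample_output(x, n_samples=10_000)
# The outputs form a distribution even though tanh itself is deterministic.
print(outputs.mean(), outputs.std())
```

Repeated forward passes with fresh `eps` give different outputs for the same input, which is what turns the network into a probabilistic model.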


## Priors and regularisation

When we search for the parameters with the highest posterior probability (MAP), there are two commonly used estimates.

* Maximum likelihood estimate for a non-informative prior

\begin{align*}
\hat{\vec{w}}_{MLE}=\argmax_\vec{w} \log p[\Data|\vec{w}]
\end{align*}

* Regularised cost function for more restrictive priors


\begin{align*}
\hat{\vec{w}}_{MAP}=\argmax_\vec{w}\left\{\log p(\vec{w}) + \log p[\Data|\vec{w}]\right\}
\end{align*}
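
As a worked example, assume an isotropic Gaussian prior $p(\vec{w})=\NNN(\vec{0},\tau^2 I)$, where the scale $\tau$ is a symbol introduced here only for illustration. Under this assumption the MAP estimate reduces to maximum likelihood with an $L_2$ penalty:

\begin{align*}
\log p(\vec{w}) = -\frac{1}{2\tau^2}\|\vec{w}\|^2 + \mathrm{const}
\qquad\Longrightarrow\qquad
\hat{\vec{w}}_{MAP}=\argmax_\vec{w}\left\{\log p[\Data|\vec{w}] - \frac{1}{2\tau^2}\|\vec{w}\|^2\right\}
\end{align*}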

Sometimes maximum likelihood and maximum a posteriori estimates are brittle. 
* There is often a large region where the posterior is roughly the same for different parameter values.
* We can overcome this problem by averaging predictions over different parameter values.
* To do that we must be able to sample from the posterior distribution.
* Unfortunately, we know the posterior distribution only up to a normalising constant and thus direct sampling is hard.


## Laplace approximation

Assume that the posterior is quite close to the normal distribution $\NNN(\vec{\mu}, \vec{\Sigma})$:
* There is a single mode and the probability mass is concentrated around it.
* This occurs naturally when the number of samples is large enough -- a consequence of the Bernstein-von Mises theorem, a Bayesian analogue of the central limit theorem.

* Then we can find the center $\vec{\mu}$ by computing the maximum a posteriori estimate $\hat{\vec{w}}_{MAP}$.
* We can find the covariance estimate $\vec{\Sigma}=-H^{-1}$ by computing the Hessian (matrix of second derivatives) $H$ of the unnormalised log-posterior at $\hat{\vec{w}}_{MAP}$

\begin{align*}
h(\vec{w})=\log p(\vec{w}) + \log p[\Data|\vec{w}]
\end{align*}
* Since the gradient vanishes at the mode, the second order Taylor approximation of the unnormalised log-posterior is

\begin{align*}
 h(\hat{\vec{w}}_{MAP}+\Delta \vec{w}) \approx h(\hat{\vec{w}}_{MAP}) + \frac{1}{2} \Delta \vec{w}^T H(\hat{\vec{w}}_{MAP})\Delta \vec{w}
\end{align*}

  

* Thus the posterior can be approximated as

\begin{align*}
 p(\vec{w}|\Data) \propto \exp\left( \frac{1}{2} (\vec{w}-\hat{\vec{w}}_{MAP})^T H(\hat{\vec{w}}_{MAP})(\vec{w}-\hat{\vec{w}}_{MAP})\right)\cdot \left(1+\mathcal{O}\left(\|\vec{w}-\hat{\vec{w}}_{MAP}\|^3\right)\right)
\end{align*}

* The approximation reveals the parameters of the normal distribution, $\vec{\mu}=\hat{\vec{w}}_{MAP}$ and $\vec{\Sigma}=-H(\hat{\vec{w}}_{MAP})^{-1}$, and we can sample the weights of the neural network.
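
The recipe can be sketched on a one-parameter toy model. Everything specific here is an assumption made for illustration: a single logit `w` with prior $\NNN(0,1)$, Bernoulli observations with success probability `sigmoid(w)`, and Newton iterations as the optimiser.

```python
import numpy as np

# Toy model (an assumption for illustration): a single logit w with prior N(0, 1)
# and Bernoulli observations with success probability sigmoid(w).
y = np.array([1, 1, 1, 0, 1, 1, 0, 1])
n, k = y.size, y.sum()

def sigmoid(w):
    return 1.0 / (1.0 + np.exp(-w))

def grad_h(w):
    """First derivative of the unnormalised log-posterior h(w)."""
    return -w + k - n * sigmoid(w)

def hess_h(w):
    """Second derivative of h(w); it is negative everywhere."""
    s = sigmoid(w)
    return -1.0 - n * s * (1.0 - s)

# Newton iterations find the mode, i.e. the MAP estimate.
w_map = 0.0
for _ in range(50):
    w_map -= grad_h(w_map) / hess_h(w_map)

# Laplace approximation: the mean is the mode, the variance is -1 / H.
mu = w_map
var = -1.0 / hess_h(w_map)

rng = np.random.default_rng(0)
samples = rng.normal(mu, np.sqrt(var), size=1000)  # approximate posterior samples
print(mu, var)
```

For a neural network the same steps apply in principle, but the Hessian is a large matrix and usually has to be approximated.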


## Variational inference

<center>
    <img src='./illustrations/variational_inference.png' width=20% alt='Variational inference'>
</center>

* Laplace approximation does not work well for multimodal distributions.
* We need a global distance measure between the approximation and the posterior distribution.
* We still need a family of parametrised distributions $q_\Lambda$ to approximate the posterior.
* Kullback-Leibler divergence is the standard tool for that, but we need to be careful not to get into trouble with the unnormalised posterior.

## Variational inference

Let us define the cost function (**variational free energy**) 

\begin{align*}
\FFF(\Lambda)=\KL{q_\Lambda(\Theta) || p[\Theta|\Data] }
\end{align*}

where the network parameters are now denoted by $\Theta$ instead of $\vec{w}$.

A simple manipulation yields

\begin{align*}
\KL{q_\Lambda(\Theta) || p[\Theta|\Data] }&=\EXP{\Theta\sim q_\Lambda}{\log \frac{q_\Lambda(\Theta)}{p[\Theta|\Data]}}=
\EXP{\Theta\sim q_\Lambda}{\log q_\Lambda(\Theta)-\log p[\Theta] - \log p[\Data|\Theta] +\log p[\Data]}\\
&=\log p[\Data] + \EXP{\Theta\sim q_\Lambda}{\log \frac{q_\Lambda(\Theta)}{p[\Theta]}}-\EXP{\Theta\sim q_\Lambda}{\log p[\Data|\Theta]}\\
&=\log p[\Data] + \KL{q_\Lambda(\Theta) || p[\Theta]}  -\EXP{\Theta\sim q_\Lambda}{\log p[\Data|\Theta]}
\end{align*}

Since $\log p[\Data]$ does not depend on $\Lambda$, we can drop it and simplify the cost function to

\begin{align*}
\FFF(\Lambda)= \KL{q_\Lambda(\Theta)||p(\Theta)} - \EXP{\Theta\sim q_\Lambda}{\log p[\Data|\Theta]}
\end{align*}
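
The decomposition can be checked numerically on a model where everything is computable. The sketch below assumes a hypothetical discrete parameter with three possible values; the numbers in `prior`, `likelihood` and `q` are made up for illustration.

```python
import numpy as np

# Hypothetical discrete model: the parameter takes one of three values.
prior = np.array([0.5, 0.3, 0.2])          # p[Theta]
likelihood = np.array([0.10, 0.40, 0.70])  # p[D | Theta] for the observed data
q = np.array([0.2, 0.3, 0.5])              # variational approximation q(Theta)

evidence = np.sum(prior * likelihood)        # p[D]
posterior = prior * likelihood / evidence    # p[Theta | D]

def kl(a, b):
    """Kullback-Leibler divergence between two discrete distributions."""
    return np.sum(a * np.log(a / b))

kl_to_posterior = kl(q, posterior)
free_energy = kl(q, prior) - np.sum(q * np.log(likelihood))

# KL to the posterior equals the simplified cost plus the constant log p[D].
print(kl_to_posterior, np.log(evidence) + free_energy)
```

The two printed values agree, so minimising the simplified cost over $q$ is the same as minimising the divergence to the true posterior, without ever computing $p[\Data]$.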



## Initial analysis of the final minimisation goal

Our goal is to find the optimal set of parameters $\Lambda$ that minimises the variational free energy

\begin{align*}
\FFF(\Lambda)= \KL{q_\Lambda(\Theta)||p_\Kappa(\Theta)} - \EXP{\Theta\sim q_\Lambda}{\log p[\Data|\Theta]}
\end{align*}

The first term $\KL{q_\Lambda(\Theta)||p_\Kappa(\Theta)}$ in the cost function is easier to handle:
* We control the parametrisation of the prior $p_\Kappa(\Theta)$. 
* We control the parametrisation of the variational approximation $q_\Lambda(\Theta)$.
* We can choose parametrisations so that $\KL{q_\Lambda(\Theta)||p_\Kappa(\Theta)}$ can be found analytically or approximated as $g(\Kappa,\Lambda)$.

The second term is determined by the structure of neural network:
* It is highly non-linear and we have no control over it.
* We cannot evaluate it analytically and we need to rely on Monte-Carlo integration. 
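
For the first term, a common choice is a normal prior and a normal variational distribution, where the divergence has a closed form. The sketch below assumes scalar normals parametrised by mean and standard deviation; the function names `kl_normal` and `log_normal_pdf` are introduced here for illustration. A Monte-Carlo estimate is included only as a sanity check.

```python
import numpy as np

def kl_normal(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL[N(mu_q, sigma_q) || N(mu_p, sigma_p)] for scalar normals."""
    return (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2.0 * sigma_p**2)
            - 0.5)

def log_normal_pdf(x, mu, sigma):
    """Log-density of N(mu, sigma) with sigma as the standard deviation."""
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - (x - mu)**2 / (2.0 * sigma**2)

# Monte-Carlo estimate of the same quantity for comparison.
rng = np.random.default_rng(0)
mu_q, sigma_q, mu_p, sigma_p = 0.5, 0.8, 0.0, 1.0
theta = rng.normal(mu_q, sigma_q, size=200_000)
kl_mc = np.mean(log_normal_pdf(theta, mu_q, sigma_q)
                - log_normal_pdf(theta, mu_p, sigma_p))
kl_exact = kl_normal(mu_q, sigma_q, mu_p, sigma_p)
print(kl_exact, kl_mc)
```

For independent weights the total divergence is the sum of such terms over all weights.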

## Naive hill-climbing algorithm 

To minimise $\FFF(\Lambda)$ we can try several different values $\Lambda_1,\ldots, \Lambda_M$ around the current estimate $\Lambda$:

* We can evaluate $\FFF(\Lambda_1), \ldots, \FFF(\Lambda_M)$ and choose the lowest value as the next step. 
* We can approximate $\FFF(\Lambda_i)$ with a Monte-Carlo integration 

\begin{align*}
\FFF(\Lambda_i)\approx \KL{q_{\Lambda_i}(\Theta)||p_\Kappa(\Theta)} 
- \frac{1}{K}\cdot \sum_{j=1}^K \log p[\Data|\Theta_j],\qquad \Theta_1, \ldots, \Theta_K\sim q_{\Lambda_i}  
\end{align*}

* We can compute a linear or quadratic approximation $\hat{\FFF}(\Lambda)$ for the cost function around $\Lambda$ and find $\Lambda_*$ that minimises $\hat{\FFF}(\Lambda)$.


This is terribly inefficient for several reasons:

* We need to sample $M\times K$ points from the distributions $q_{\Lambda_1},\ldots, q_{\Lambda_M}$.
* The resulting quadratic approximation is a second order method, while sampling usually gives only first order guarantees.
* The Monte-Carlo integral is too imprecise to distinguish cost values in a small neighbourhood.
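
The last point can be illustrated with a toy expectation. The sketch assumes a stand-in cost $\mathbf{E}[\theta^2]$ for $\theta\sim\NNN(\mu,1)$, whose exact value $\mu^2+1$ is known, and compares the noise of the Monte-Carlo estimate with the true difference between two nearby parameter values. The helper `mc_estimate` and all numbers are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def mc_estimate(mu, K):
    """Monte-Carlo estimate of E[theta^2] for theta ~ N(mu, 1); exact value is mu^2 + 1."""
    theta = rng.normal(mu, 1.0, size=K)
    return np.mean(theta**2)

K = 100
mu_a, mu_b = 0.50, 0.51                     # two nearby candidate parameters
true_gap = (mu_b**2 + 1) - (mu_a**2 + 1)    # exact difference, about 0.01

# Spread of the estimator across independent repetitions.
estimates = np.array([mc_estimate(mu_a, K) for _ in range(500)])
mc_std = estimates.std()
print(true_gap, mc_std)
```

The noise of a single estimate is much larger than the gap it is supposed to resolve, so comparing independently sampled candidates tells us very little about which one is better.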

## Stochastic gradient descent as a way out

The standard neural network minimisation algorithm can be viewed as follows:

* We need to minimise 
\begin{align*}
\EXP{(\vec{x},y)\sim\Data}{\LLL_{\vec{w}}(\vec{x},y)}=\frac{1}{N}\cdot\sum_{j=1}^N \LLL_{\vec{w}}(\vec{x}_j, y_j)
\end{align*}

* In the gradient descent algorithm $\vec{w}_{i+1}=\vec{w}_i-\eta\cdot \Delta \vec{w}_i$ we must compute

\begin{align*}
\Delta \vec{w}_i= \frac{\partial}{\partial\vec{w}} \left[
\EXP{(\vec{x},y)\sim\Data}{\LLL_{\vec{w}}(\vec{x},y)}\right]\!\biggl|_{\,\vec{w}=\vec{w}_i}
\end{align*}

* However, we can do stochastic updates $\vec{w}_{i+1}=\vec{w}_i-\eta\cdot \widehat{\Delta \vec{w}}_i(\omega)$ instead, provided that the estimate is unbiased:

\begin{align*}
 \EXP{\omega}{\widehat{\Delta\vec{w}}_i(\omega)}=\Delta \vec{w}_i
\end{align*}
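
The unbiasedness condition can be checked for ordinary minibatch gradients. The sketch below assumes synthetic linear regression data with a mean squared error loss; the data, the helper `grad` and the batch count are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic regression data (an assumption for illustration).
N, d = 120, 3
X = rng.standard_normal((N, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.standard_normal(N)
w = np.zeros(d)

def grad(w, X, y):
    """Gradient of the mean squared error over the given samples."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

full_grad = grad(w, X, y)

# Average the minibatch gradients over a random partition into equal-size batches.
batches = np.split(rng.permutation(N), 10)
batch_grads = np.array([grad(w, X[idx], y[idx]) for idx in batches])
print(full_grad, batch_grads.mean(axis=0))
```

Picking one of the batches uniformly at random therefore gives a gradient estimate whose expectation is the full gradient, which is exactly the condition above.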


## Unbiased estimate for the gradient of variational free energy 

As the cost function is itself an expected value

\begin{align*}
\FFF(\Lambda)= \EXP{\Theta\sim q_\Lambda}{\log \frac{q_\Lambda(\Theta)}{p_\Kappa(\Theta)}- \log p[\Data|\Theta]}
\end{align*}

we need a way to push the partial derivative under the expectation to get the stochastic gradient

\begin{align*}
\widehat{\Delta \Lambda}(\Theta)= \frac{\partial}{\partial\Lambda}\left[ 
\log q_\Lambda(\Theta) -\log p_\Kappa(\Theta)- \log p[\Data|\Theta] \right], \qquad \Theta \sim q_\Lambda 
\end{align*}

Pushing the derivative under the expectation works only if we can reparametrise the distribution

\begin{align*}
\Theta= f(\Lambda, \epsilon),\qquad \epsilon\sim q_*
\end{align*}

where the distribution $q_*$ does not depend on the parameters $\Lambda$. 

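**Important observation:** A normal distribution for the weights of the neural network $\Theta \sim \NNN(\boldsymbol{\mu},\boldsymbol{\Sigma})$ fits the bill, since $\Theta=\boldsymbol{\mu}+\boldsymbol{\Sigma}^{1/2}\epsilon$ with $\epsilon\sim\NNN(\boldsymbol{0}, I)$.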
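
A small numerical check of the reparametrised gradient, under stated assumptions: a scalar $\theta\sim\NNN(\mu,\sigma)$ and the test function $g(\theta)=\theta^2$ standing in for the integrand, chosen because the exact gradients of $\mathbf{E}[\theta^2]=\mu^2+\sigma^2$ are known.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.7

# Reparametrisation: theta = f(lambda, eps) = mu + sigma * eps with eps ~ N(0, 1).
eps = rng.standard_normal(200_000)
theta = mu + sigma * eps

# For the test function g(theta) = theta^2 the chain rule gives
# dg/dmu = 2 * theta * 1 and dg/dsigma = 2 * theta * eps.
grad_mu = np.mean(2.0 * theta)
grad_sigma = np.mean(2.0 * theta * eps)

# Exact gradients of E[theta^2] = mu^2 + sigma^2 are 2 * mu and 2 * sigma.
print(grad_mu, 2 * mu)
print(grad_sigma, 2 * sigma)
```

Each sample of `eps` yields an unbiased gradient estimate, so a single sample per step is already enough for stochastic gradient descent.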


## Why reparametrisation is necessary

<center>
<img src='./illustrations/reparametrisation.png' width=35% alt='Necessity of reparametrisation'>
</center>

Without reparametrisation it is technically impossible to take the derivative

\begin{align*}
\frac{\partial Loss}{\partial \mu_a}=\frac{\partial Loss}{\partial a}\cdot\color{red}{\frac{\partial a}{\partial \mu_a}} 
\end{align*}

* We can sample $a$ from $\NNN(\mu_a, \sigma_a)$ but then it is just a number and there is no dependence on $\mu_a$.
* With reparametrisation $a=\mu_a+\sigma_a\cdot \epsilon$ the dependence between $\mu_a$ and $a$ remains even if $\epsilon$ is sampled.
* Obviously, the reparametrisation must create the same distribution for the $a$ values.


## Bayes by Backprop algorithm [LeCun 1985, ..., Blundell 2015]


* The resulting algorithm is still overly expensive as $\log p[\Data|\Theta]$ sums over the entire data set.
* We can avoid this by splitting the data into minibatches $\Data_1,\ldots, \Data_M$.

* To get the analogue of stochastic gradient descent we need to split the cost function 

\begin{align*}
\FFF(\Lambda)=\FFF_1(\Lambda, \Data_1)+\cdots+ \FFF_M(\Lambda,\Data_M)
\end{align*}

* We need to split the regularisation term $\KL{q_\Lambda(\Theta)||p_\Kappa(\Theta)}$ between the $M$ subfunctions.


* Geometrically decreasing weights $\pi_j$ in front of the regularisation term are good in practice: 

\begin{align*}
\pi_j=\frac{2^{M-j}}{2^M-1}\propto \frac{1}{2^j}
\end{align*}

* The first few minibatches force the distribution close to the prior while the remaining batches fit the parameters according to the data.
  
* There seems to be little difference whether the regularisation term $\KL{q_\Lambda(\Theta)||p_\Kappa(\Theta)}$ is expressed analytically or estimated by sampling.
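
The weighting scheme is easy to compute. The sketch below assumes that the $j$-th subfunction has the form $\pi_j\cdot\mathbf{KL} - \mathbf{E}[\log p[\Data_j|\Theta]]$, as in Blundell et al.; the helper name `kl_weights` is hypothetical.

```python
import numpy as np

def kl_weights(M):
    """Weights pi_j = 2^(M-j) / (2^M - 1) for minibatches j = 1, ..., M."""
    j = np.arange(1, M + 1)
    return 2.0 ** (M - j) / (2.0 ** M - 1.0)

pi = kl_weights(5)
print(pi)        # weights halve from one minibatch to the next
print(pi.sum())  # the weights sum to one, so the KL term is counted once per epoch
```

Because the weights sum to one, adding up the subfunctions over all minibatches recovers the original cost function.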
   


## Flipout optimisation

**Problem:** 
* A minibatch approximation of $\FFF_j(\Lambda, \Data_j)$ shares the same sampled weights between different samples $(\vec{x}_i, y_i)$ in the minibatch.
* Thus the stochastic gradient $\widehat{\Delta \Lambda}$ for the minibatch sum does not consist of independent terms. 
* These correlations increase the variance of $\widehat{\Delta \Lambda}$ and thus slow the convergence of Bayes by Backprop.

**Solution:**

* Manipulate the weights so that the weight distribution is preserved but individual terms in the sum are decorrelated.
* Assume that perturbations of individual weights are independent and that the perturbation distribution is symmetric around zero.
* By flipping the signs of individual weight perturbations we preserve the perturbation distribution.
* With independent random sign flips, the correlation between perturbations of the same weight for different data points is zero.
* As a result the minibatch sum has much smaller variance and the algorithm converges faster.
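
The sign-flipping trick can be sketched for a single linear layer. The sketch follows the construction in Wen et al., where the perturbation for example $n$ is the shared perturbation multiplied elementwise by a rank-one sign matrix $r_n s_n^T$. The layer sizes, variable names and the value of `sigma` are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_batch, n_in, n_out = 4, 3, 2

X = rng.standard_normal((n_batch, n_in))          # minibatch of inputs
W_mean = rng.standard_normal((n_in, n_out))       # mean weights
sigma = 0.1
dW = sigma * rng.standard_normal((n_in, n_out))   # one shared perturbation, symmetric around zero

# Independent random signs for every example in the minibatch.
R = rng.choice([-1.0, 1.0], size=(n_batch, n_in))
S = rng.choice([-1.0, 1.0], size=(n_batch, n_out))

# Efficient form: the per-example perturbation dW * outer(r_n, s_n) is never built.
Y_fast = X @ W_mean + ((X * R) @ dW) * S

# Explicit form for comparison: build a separate weight matrix for each example.
Y_slow = np.stack([
    X[n] @ (W_mean + dW * np.outer(R[n], S[n]))
    for n in range(n_batch)
])
print(np.allclose(Y_fast, Y_slow))
```

The efficient form costs two matrix products per layer instead of one, which is the price paid for decorrelated per-example perturbations.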


# Implementation in TensorFlow Probability

There are special layers with reparametrisation:

* `tfp.layers.DenseReparameterization`: by default it uses a Gaussian variational posterior
* `tfp.layers.DenseFlipout`: the same as the previous layer but with lower variance for the stochastic gradient 
* `tfp.layers.ConvolutionXDReparameterization` for `X` in 1, 2, 3: by default it uses a Gaussian variational posterior for the convolution kernel
* `tfp.layers.ConvolutionXDFlipout` for `X` in 1, 2, 3: the same as the previous layer but with lower variance for the stochastic gradient 

