# **Variational Inference in VAE**

## Prerequisite
Before starting this lesson, students should be familiar with the following concepts:
- Probability Distribution & Bayes Theorem
- Intro of VAE


## **Variational Inference**

Variational inference is a technique for approximating intractable posterior distributions in graphical models by optimizing the parameters of a simpler, tractable distribution.

*Please forget everything you know about deep learning and neural networks for now. Thinking about the following concepts in isolation from neural networks will clarify things. At the very end, we’ll bring back neural nets.*

Graphical models are used in probability and statistics where graphs visually represent conditional dependence between random variables.

Consider a graphical model for a VAE with data $x$ and latent variable $h$.

<center>
<figure>
<p><img src="https://i.postimg.cc/R0GtwfVh/graph-model-for-variational-inference.jpg"></p>
<figcaption align="center">Fig: Graphical Model used for approximating our posterior probability, where solid lines denote the generative model and dotted ones denote the approximate posterior</figcaption>
</figure>
</center>

We assume that the data is generated by a random process involving a continuous latent variable $h$. Each latent variable is drawn from a known prior distribution $P(h)$ (i.e., $h \sim P(h)$), and each data point is then generated from a likelihood $P(x|h)$ (i.e., $x \sim P(x|h)$).
We can write the joint probability of the model as $P(x, h) = P(x|h) \cdot P(h)$.
The prior $P(h)$ and likelihood $P(x|h)$ are chosen to be simple distributions (e.g., Gaussian), from which $h$ and $x$ are generated. A Gaussian is a common choice because it is easy to sample from and to evaluate, it reduces the computational cost of the approximation, and it is a reasonable model for continuous latent attributes. (For a face image, latent features could be pose, amount of smile, etc.)
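The generative process above can be sketched in a few lines of Python. This is only an illustration under assumed settings: a one-dimensional latent, a standard normal prior, and a hypothetical likelihood $P(x|h) = \mathcal{N}(2h, 0.5^2)$. None of these specific choices come from the lesson.

```python
import random

def sample_pair(rng, noise_std=0.5):
    """Draw one (h, x) pair by ancestral sampling from a toy 1-D model."""
    h = rng.gauss(0.0, 1.0)             # h ~ P(h) = N(0, 1): the prior
    x = rng.gauss(2.0 * h, noise_std)   # x ~ P(x|h) = N(2h, noise_std^2): assumed likelihood
    return h, x

rng = random.Random(0)                  # fixed seed so the run is repeatable
pairs = [sample_pair(rng) for _ in range(10000)]

# Each data point x comes with its own latent h (the latent variable is local).
xs = [x for _, x in pairs]
mean_x = sum(xs) / len(xs)
var_x = sum((x - mean_x) ** 2 for x in xs) / len(xs)
print(f"mean of x: {mean_x:.3f}, variance of x: {var_x:.3f}")
```

For this toy model the marginal of $x$ is Gaussian with mean $0$ and variance $2^2 \cdot 1 + 0.5^2 = 4.25$, so the printed statistics should land close to those values.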

> **Analogy**: Imagine generating faces by randomly picking features like smile intensity, angle of head, eye size—those are your $h$'s. You pick these from a “bag” (the prior), and a drawing mechanism (the likelihood) turns those features into actual face images.

Coming from a machine learning background, we are used to the same neural network (the same function) being shared by all data points. But in our graphical VAE model, the latent variable is **local**: each data point has its own latent $h$ and does not share it with any other data point.
This is very important to note when tackling the problem using a graphical model.

> **Analogy:** Suppose you see a face image $x$ and want to figure out which facial features (smile, head tilt, eye size — the latent $h$) created it. Each face has unique features, but directly identifying them all is complicated.


The goal of the model is to infer good values of the latent variable given an observation of the data. This is represented by the posterior $P(h|x)$, shown by the dotted line in the figure.
The idea is that values of $h$ drawn from the posterior $P(h|x)$, rather than from the prior $P(h)$, are the ones most consistent with the observed data. Using Bayes' theorem:

$$
P(h|x) = \frac{P(x|h) \cdot P(h)}{P(x)} \tag{7}
$$

But the evidence $P(x)$ in the above equation is **intractable** when the latent space is high-dimensional:

$$
P(x) = \underbrace{\int P(x|h) \cdot P(h) \, dh}_{\text{intractable}} \tag{8}
$$

This integral is computationally expensive because it requires integrating over **all possible values** of the latent variables, and the cost of a grid-based evaluation grows exponentially with the number of latent dimensions. As a result, the true posterior is also intractable.
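To make the cost concrete, here is a minimal sketch that evaluates the evidence integral of Eqn (8) on a grid. It assumes a toy one-dimensional model, $P(h) = \mathcal{N}(0, 1)$ and $P(x|h) = \mathcal{N}(h, 1)$, chosen only because its evidence has a known closed form to compare against. The grid bounds and resolution are arbitrary.

```python
import math

def normal_pdf(v, mean, std):
    """Density of N(mean, std^2) evaluated at v."""
    return math.exp(-0.5 * ((v - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def evidence_by_grid(x, n_points=2001, lo=-10.0, hi=10.0):
    """Approximate P(x) = integral of P(x|h) P(h) dh with a Riemann sum over h."""
    step = (hi - lo) / (n_points - 1)
    total = 0.0
    for k in range(n_points):
        h = lo + k * step
        total += normal_pdf(x, h, 1.0) * normal_pdf(h, 0.0, 1.0) * step
    return total

x = 1.5
approx = evidence_by_grid(x)
# A closed form exists only because this toy model is linear-Gaussian: x ~ N(0, 2).
exact = normal_pdf(x, 0.0, math.sqrt(2.0))
print(f"grid estimate: {approx:.6f}, exact: {exact:.6f}")

# The same grid resolution in d latent dimensions needs n_points ** d evaluations.
grid_cost = {d: 2001 ** d for d in (1, 2, 10)}
for d, cost in grid_cost.items():
    print(f"d = {d:2d}: {cost:.3e} integrand evaluations")
```

In one dimension the grid is cheap and accurate, but the evaluation count is already astronomically large at ten latent dimensions, which is small compared with typical VAE latent spaces.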

> **Analogy**: Imagine trying to figure out how likely a face image is by **checking every possible combination** of facial features—every smile intensity, head angle, eye size. Since the combinations are endless, this quickly becomes impossible.

This is where **variational inference** comes in.

We approximate the true posterior $P(h|x)$ with a simpler, tractable distribution $q_\lambda(h|x)$, where $\lambda$ are parameters (e.g., mean and variance in a Gaussian).

> **Analogy**: Instead of checking every combination of facial features, you **pick one combination that seems good enough** based on prior info—like guessing the most likely smile and pose that produced the face you see.

Our goal is to **minimize the KL divergence** between $q_\lambda(h|x)$ and $P(h|x)$:

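$$
\min_\lambda \, \mathbb{KL}(q_\lambda(h|x) \,\|\, P(h|x))
$$

where the KL divergence is given by (writing sums over $h$; for a continuous $h$ the sums become integrals)

$$
\mathbb{KL}(q_\lambda(h|x) \,\|\, P(h|x)) = -\sum_h q_\lambda(h|x) \cdot \log\frac{P(h|x)}{q_\lambda(h|x)} \tag{9}
$$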
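As a small numerical illustration of Eqn (9), the sketch below computes the KL divergence for discrete distributions over three latent values. The probability vectors are made-up numbers standing in for $P(h|x)$ and for two candidate approximations $q_\lambda(h|x)$.

```python
import math

def kl_divergence(q, p):
    """KL(q || p) = -sum_h q(h) * log(p(h) / q(h)) for discrete distributions."""
    return -sum(qh * math.log(ph / qh) for qh, ph in zip(q, p) if qh > 0.0)

p = [0.7, 0.2, 0.1]           # stand-in for the true posterior P(h|x)
q_close = [0.6, 0.25, 0.15]   # candidate approximation that resembles p
q_far = [0.1, 0.2, 0.7]       # candidate approximation that does not

kl_close = kl_divergence(q_close, p)
kl_far = kl_divergence(q_far, p)
kl_self = kl_divergence(p, p)

print(f"KL(q_close || p) = {kl_close:.4f}")
print(f"KL(q_far   || p) = {kl_far:.4f}")
print(f"KL(p       || p) = {kl_self:.4f}")
```

The divergence is zero only when the two distributions match, and it grows as the candidate moves away from the target, which is why it serves as the objective to minimize.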

Now, moving on with the derivation. Substituting the true posterior from Eqn (7), written as $P(h|x) = \frac{P(x,h)}{P(x)}$, Eqn (9) becomes

$$
\begin{aligned}
\mathbb{KL}(q_\lambda(h|x) \,\|\, P(h|x))
&= -\sum_h q_\lambda(h|x) \cdot \log \frac{\frac{P(x,h)}{P(x)}}{q_\lambda(h|x)} \\
&= -\sum_h q_\lambda(h|x) \cdot \left[\log \frac{P(x,h)}{q_\lambda(h|x)} - \log P(x)\right] \quad \text{(log property)} \\
&= -\sum_h q_\lambda(h|x) \cdot \log\frac{P(x,h)}{q_\lambda(h|x)} + \sum_h q_\lambda(h|x) \cdot \log P(x)
\end{aligned}
$$

Since $\log P(x)$ does not depend on $h$, it can be pulled out of the sum, and $\sum_h q_\lambda(h|x) = 1$ (probabilities sum to one). Therefore

$$
\mathbb{KL}(q_\lambda(h|x) \,\|\, P(h|x)) = -\sum_h q_\lambda(h|x) \cdot \log\frac{P(x,h)}{q_\lambda(h|x)} + \log P(x) \tag{10}
$$

Eqn (10) can be rewritten in terms of the evidence as

$$
\log P(x) = \mathbb{KL}(q_\lambda(h|x) \,\|\, P(h|x)) + \underbrace{\sum_h q_\lambda(h|x) \cdot \log\frac{P(x,h)}{q_\lambda(h|x)}}_{\mathcal{L}(\lambda)} \tag{11}
$$

In the above equation, the second term, $\mathcal{L}(\lambda)$, is known as the **variational lower bound** or **evidence lower bound (ELBO)**. It is computationally tractable, as we will discuss. The KL-divergence term still contains the pesky posterior $P(h|x)$, which we saw is intractable, so we cannot evaluate it directly. We do know, however, that a KL divergence is always greater than or equal to zero.

Since $P(x)$ has a fixed value for a particular input $x$, $\log P(x)$ is a constant with respect to $\lambda$. So instead of minimizing the KL divergence, we can maximize $\mathcal{L}(\lambda)$; the two operations are equivalent. And because $\mathbb{KL} \geq 0$, the variational lower bound satisfies $\mathcal{L}(\lambda) \leq \log P(x)$.
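The decomposition of the log evidence into a KL term plus the lower bound can be checked numerically on a model small enough to enumerate. The sketch below assumes a latent variable with three possible values and made-up prior and likelihood numbers for one fixed observation $x$. In this tiny case the true posterior is computable, which is exactly what is not possible in a real VAE.

```python
import math

prior = [0.5, 0.3, 0.2]         # P(h) for h in {0, 1, 2} (illustrative values)
likelihood = [0.9, 0.4, 0.1]    # P(x|h) for one fixed observation x (illustrative values)

joint = [lh * ph for lh, ph in zip(likelihood, prior)]   # P(x, h)
evidence = sum(joint)                                     # P(x)
posterior = [j / evidence for j in joint]                 # P(h|x), tractable only in this toy case

def elbo(q):
    """Lower bound L = sum_h q(h) * log(P(x, h) / q(h))."""
    return sum(qh * math.log(j / qh) for qh, j in zip(q, joint))

def kl(q, p):
    """KL(q || p) = -sum_h q(h) * log(p(h) / q(h))."""
    return -sum(qh * math.log(ph / qh) for qh, ph in zip(q, p))

candidates = [[1 / 3, 1 / 3, 1 / 3], [0.6, 0.3, 0.1], posterior]
log_evidence = math.log(evidence)
for q in candidates:
    print(f"ELBO = {elbo(q):.4f}, KL = {kl(q, posterior):.4f}, "
          f"sum = {elbo(q) + kl(q, posterior):.4f}, log P(x) = {log_evidence:.4f}")
```

For every candidate the two terms add up to the same $\log P(x)$, the lower bound never exceeds it, and the gap closes when the candidate equals the true posterior.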

<center>
<figure>


<p><img src="https://i.postimg.cc/L6nMSMZb/kl-divergence-minimization.png"  ></p>
<figcaption align="center">Fig: Graphical representation of KL-divergence minimization, which is achieved by maximizing the variational lower bound, or evidence lower bound (ELBO), so that it approaches the evidence (log P(x)).

We know that $\mathbb{KL} \geq 0$. Our target is to reduce the KL divergence to zero,<br>
i.e., $\mathbb{KL}=0$ or $\log P(x)=\text{ELBO}$</figcaption>
</figure>
</center>



The gradient of the variational lower bound is also computable, which lets us optimize it later. The lower bound can be expanded as:

$$
\begin{aligned}
\mathcal{L}(\lambda) &= \sum_h q_\lambda(h|x) \cdot \log \frac{P(x,h)}{q_\lambda(h|x)} \\
&= \sum_h q_\lambda(h|x) \cdot \log \frac{P(x|h) \cdot P(h)}{q_\lambda(h|x)} \\
&= \sum_h q_\lambda(h|x) \cdot \left[\log P(x|h) + \log\frac{P(h)}{q_\lambda(h|x)}\right] \\
&= \sum_h q_\lambda(h|x) \cdot \log P(x|h) + \sum_h q_\lambda(h|x) \cdot \log\frac{P(h)}{q_\lambda(h|x)}
\end{aligned} \tag{12}
$$

The first term in Eqn (12) has a familiar form: it is the expectation of the log-likelihood $\log P(x|h)$ under $q_\lambda(h|x)$.

$$
\mathbb{E}_{q_\lambda(h|x)} [\log P(x|h)] = \sum_h q_\lambda(h|x) \cdot \log P(x|h)
$$

The second term in Eqn (12) is the negative of the KL divergence of $q_\lambda(h|x)$ with respect to the prior $P(h)$.

$$
-\mathbb{KL}(q_\lambda(h|x) \,\|\, P(h)) = \sum_h q_\lambda(h|x) \cdot \log\frac{P(h)}{q_\lambda(h|x)}
$$

Putting the two together:

$$
\mathcal{L}(\lambda) = \mathbb{E}_{q_\lambda(h|x)} [\log P(x|h)] - \mathbb{KL}(q_\lambda(h|x) \,\|\, P(h)) \tag{13}
$$

For a single data point $x^{(i)}$, the variational lower bound is

$$
\mathcal{L}_i(\lambda) = \mathbb{E}_{q_\lambda(h|x^{(i)})} [\log P(x^{(i)}|h)] - \mathbb{KL}(q_\lambda(h|x^{(i)}) \,\|\, P(h))
$$
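A quick numerical check can confirm that splitting the lower bound into an expected log-likelihood and a KL term to the prior gives the same number as the original definition in terms of the joint. The sketch assumes a three-valued latent and made-up probabilities for a single observation. All values are illustrative.

```python
import math

prior = [0.5, 0.3, 0.2]         # P(h), illustrative values
likelihood = [0.9, 0.4, 0.1]    # P(x|h) for one fixed data point x, illustrative values
q = [0.6, 0.3, 0.1]             # some approximate posterior q(h|x)

# Definition of the lower bound: sum_h q(h) * log(P(x, h) / q(h)).
elbo_from_joint = sum(
    qh * math.log(lh * ph / qh) for qh, lh, ph in zip(q, likelihood, prior)
)

# The two-term form: E_q[log P(x|h)] - KL(q || P(h)).
expected_log_lik = sum(qh * math.log(lh) for qh, lh in zip(q, likelihood))
kl_to_prior = -sum(qh * math.log(ph / qh) for qh, ph in zip(q, prior))
elbo_split = expected_log_lik - kl_to_prior

print(f"from joint:  {elbo_from_joint:.6f}")
print(f"split form:  {elbo_split:.6f}")
print(f"reconstruction term: {expected_log_lik:.6f}, KL to prior: {kl_to_prior:.6f}")
```

The split form is the one that matters in practice because each of its terms can be estimated or computed without ever touching $P(x)$ or $P(h|x)$.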

> **Analogy**: You're trying to find the best "explanation" (latent variable $h$) for a particular observed event (data point $x$), while also making sure that your explanation is not too far-fetched (i.e., it stays close to what you'd expect a priori).

We have to **maximize** the variational lower bound to make the approximate posterior and the true posterior similar to each other. By Eqn (13), this means maximizing the expected log-likelihood $\log P(x^{(i)}|h)$ of the data point given $h$, while keeping the approximate posterior close to the prior $P(h)$.
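The following sketch maximizes the lower bound by plain gradient ascent on a toy model where the answer is known. It assumes $P(h) = \mathcal{N}(0, 1)$, $P(x|h) = \mathcal{N}(h, 1)$, and a Gaussian $q_\lambda(h|x) = \mathcal{N}(m, s^2)$ with $\lambda = (m, \log s)$. For this model both terms of the bound have closed forms (the Gaussian KL formula is a standard result), and the exact posterior is $\mathcal{N}(x/2, 1/2)$. The learning rate and step count are arbitrary.

```python
import math

x = 1.5                          # one observed data point (arbitrary value)

def elbo(m, s):
    """Lower bound for q(h|x) = N(m, s^2), prior N(0, 1), likelihood N(x; h, 1)."""
    expected_log_lik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - m) ** 2 + s ** 2)
    kl_to_prior = 0.5 * (m ** 2 + s ** 2 - 1.0 - math.log(s ** 2))
    return expected_log_lik - kl_to_prior

m, log_s = 0.0, 0.0              # lambda = (m, log s); optimizing log s keeps s positive
learning_rate = 0.1
for _ in range(500):
    s = math.exp(log_s)
    grad_m = x - 2.0 * m                 # derivative of the bound with respect to m
    grad_log_s = 1.0 - 2.0 * s ** 2      # derivative of the bound with respect to log s
    m += learning_rate * grad_m          # ascent step: we are maximizing
    log_s += learning_rate * grad_log_s

s = math.exp(log_s)
final_elbo = elbo(m, s)
log_evidence = -0.5 * math.log(2.0 * math.pi * 2.0) - x ** 2 / 4.0   # log N(x; 0, 2)

print(f"learned mean {m:.4f} (exact posterior mean {x / 2:.4f})")
print(f"learned variance {s ** 2:.4f} (exact posterior variance 0.5)")
print(f"ELBO {final_elbo:.6f} vs log P(x) {log_evidence:.6f}")
```

Because the Gaussian family here contains the true posterior, the bound becomes tight at the optimum. With a neural-network likelihood the family usually does not contain the true posterior, so a gap remains.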

The prior is chosen to be a Gaussian distribution (or a Bernoulli distribution), and the approximate posterior $q_\lambda(h|x)$ is chosen to be Gaussian as well, so the KL term has a neat closed-form solution. This also lets us represent the mean and variance of the approximate posterior as vectors/matrices in our neural network.
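For a Gaussian approximate posterior and a standard normal prior, the closed form of the KL term is $\frac{1}{2}(\mu^2 + \sigma^2 - 1 - \log \sigma^2)$ per latent dimension (this is the result derived in Appendix B of Kingma & Welling). The sketch below cross-checks it against direct numerical integration. The function names, test values, and grid settings are illustrative choices.

```python
import math

def kl_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ) for one latent dimension."""
    return 0.5 * (mu ** 2 + sigma ** 2 - 1.0 - math.log(sigma ** 2))

def log_normal_pdf(h, mean, std):
    """Log-density of N(mean, std^2) evaluated at h."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((h - mean) / std) ** 2

def kl_by_quadrature(mu, sigma, n_points=20001, half_width=12.0):
    """Numerically integrate q(h) * log(q(h) / p(h)) over a wide grid as a cross-check."""
    lo = mu - half_width * sigma
    step = 2.0 * half_width * sigma / (n_points - 1)
    total = 0.0
    for k in range(n_points):
        h = lo + k * step
        log_q = log_normal_pdf(h, mu, sigma)
        log_p = log_normal_pdf(h, 0.0, 1.0)
        total += math.exp(log_q) * (log_q - log_p) * step
    return total

mu, sigma = 0.8, 0.6
closed_form = kl_to_standard_normal(mu, sigma)
numeric = kl_by_quadrature(mu, sigma)
print(f"closed form: {closed_form:.6f}, quadrature: {numeric:.6f}")

# For a diagonal Gaussian, the KL is the sum over latent dimensions,
# so mean and standard deviation can be stored as plain vectors.
mus, sigmas = [0.8, -0.3, 0.0], [0.6, 1.2, 1.0]
kl_total = sum(kl_to_standard_normal(m_k, s_k) for m_k, s_k in zip(mus, sigmas))
print(f"KL for the 3-dimensional diagonal Gaussian: {kl_total:.6f}")
```

The last dimension in the example matches the prior exactly and therefore contributes zero to the total.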

Using variational inference, we have now worked around the problem of not being able to calculate the intractable true posterior. This gives us a mathematically sound and computationally feasible way to proceed, and we are now ready to bring back **neural networks** to parameterize $q_\lambda(h|x)$ and $P(x|h)$, which will be discussed in the next notebook.






### Key Takeaways

1. Variational inference approximates the intractable posterior distribution over the latent variables with a member of a simpler, tractable family of distributions.
2. We minimize the KL divergence between the approximate posterior and the true posterior. This is equivalent to maximizing the evidence lower bound (ELBO).



### Additional Resources

* Papers
   * Kingma, D. P., & Welling, M. (2013). Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114. https://arxiv.org/pdf/1312.6114.pdf
       * Appendix B analytically solves the KL divergence term.
