## Mean-field variational inference 

$
\newcommand{\R}{\mathbb{R}}
\newcommand{\I}{\mathbb{I}}
\newcommand{\N}{\mathcal{N}}
\newcommand{\E}{\mathbb{E}}
\newcommand{\D}{\mathcal{D}}
\newcommand{\x}{\boldsymbol{x}}
\newcommand{\z}{\boldsymbol{z}}
\newcommand{\X}{\boldsymbol{X}}
\newcommand{\Z}{\boldsymbol{Z}}
\newcommand{\y}{\boldsymbol{y}}
\newcommand{\w}{\boldsymbol{w}}
\newcommand{\bleta}{\boldsymbol{\eta}}
\newcommand{\blchi}{\boldsymbol{\chi}}
\newcommand{\T}{\top}
\newcommand{\mM}{\mathcal{M}}
\newcommand{\trd}{\triangledown}
$

### The general probabilistic model
<hr>


* Graphical model with global and local latent (hidden) variables $\beta$ and $\Z$, respectively.

<br>
<br>
<center>
<img src='prob_model.png' width=500>
</center>

* Model parameter $\beta$
$$\large{
\ \beta \sim p(\beta | \alpha)
}$$
* Latent variable $z_n$, where $n = \overline{1..N}$
$$\large{
\ \z_n \sim p(z_n | \beta)
}$$

* Posterior distribution is proportional to the joint distribution of the general probabilistic model
$$\large{
p(\beta,\Z| \X, \alpha) \propto  p(\X,\Z, \beta| \alpha) = p(\beta|\alpha)\prod_{n=1}^N p(x_n,z_n|\beta)  
}$$

* Computing the log-marginal ( _evidence_ ) is normally not feasible.

$$\large{
\ln{p(\X|\alpha)} = \ln \int \overbrace{p( \Z, \beta\ |\ \X, \alpha)}^{\text{no closed form}}\ p( \X\ |\ \alpha )\ d\Z d\beta
}$$



* Let $q(\Z,\beta)$ be an arbitrary $\textit{variational}$ distribution over the local and global hidden variables.


$$\large{
\ln{p(\X|\alpha)} = \ln \int \frac{q(\Z,\beta)\ p(\X, \Z, \beta\ |\ \alpha)}{q(\Z,\beta)}\ d\Z d\beta = \ln \mathbb{E}\bigg[\  \frac{p(\X, \Z, \beta\ |\ \alpha)}{q(\Z,\beta)}\ \bigg]
}$$
<br><br>
**Note:** $\ \ \mathbb{E}[f(x)] = \int p(x) f(x) dx $

### Jensen's inequality

<hr>
<center>
<img src='jensens.png' width=700>
</center>

<br>
<br>
$${
g(\mathbb{E}[X]) \leq \mathbb{E}[g(X)]\quad \text{where}\ g(\cdot)\ \text{is convex}
}$$

$$\text{and}$$

$${\boxed{
g(\mathbb{E}[X]) \geq \mathbb{E}[g(X)]\quad \text{where}\ g(\cdot)\ \text{is concave}
}}$$
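
A quick numerical check of the concave case can make the inequality concrete. The sketch below uses $g = \ln$ and an arbitrary set of positive samples as a stand-in for $X$; the sample range and seed are assumptions chosen only for illustration.

```python
import math
import random

random.seed(0)
# Positive samples standing in for the random variable X (arbitrary choice).
xs = [random.uniform(0.1, 5.0) for _ in range(1000)]

mean_x = sum(xs) / len(xs)
log_of_mean = math.log(mean_x)                          # g(E[X]) with g = ln
mean_of_log = sum(math.log(x) for x in xs) / len(xs)    # E[g(X)]

# ln is concave, so g(E[X]) >= E[g(X)] for the empirical distribution.
print(log_of_mean >= mean_of_log)
```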


* Apply Jensen's inequality to the concave function $\ln(\cdot)$.

$$\large{
\ln{p(\X|\alpha)} = \ln \mathbb{E}\bigg[\  \frac{p(\X, \Z, \beta\ |\ \alpha)}{q(\Z,\beta)}\ \bigg] \geq  \underbrace{\mathbb{E}\bigg[\  \ln{\frac{p(\X, \Z, \beta\ |\ \alpha)}{q(\Z,\beta)}}\ \bigg]}_{\text{ELBO}} = \mathbb{E}\bigg[\  \ln{\frac{\overbrace{p(\Z, \beta\ |\  \X, \alpha)}^{\text{true posterior}}}{q(\Z,\beta)}}\ \bigg] + \ln{p(\X |\alpha )}
}$$

**Note:** ELBO - _evidence lower bound_

<br>

* Kullback-Leibler divergence definition:

$$\large{
\mathcal{D}_{KL}\Big[q(x)\ ||\ p(x)\Big] = \mathbb{E}\Big[ \ln\frac{q(x)}{p(x)} \Big] = -\mathbb{E}\Big[ \ln\frac{p(x)}{q(x)} \Big] \geq 0
}$$
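
For discrete distributions the expectation becomes a sum, $\sum_x q(x)\ln\frac{q(x)}{p(x)}$. The following is a minimal sketch under that assumption; the function name `kl_divergence` and the two example distributions are hypothetical.

```python
import math

def kl_divergence(q, p):
    """D_KL[q || p] for discrete distributions given as lists of probabilities."""
    # Terms with q(x) = 0 contribute nothing to the sum.
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0.0)

q = [0.5, 0.3, 0.2]
p = [0.4, 0.4, 0.2]

kl_qp = kl_divergence(q, p)   # non-negative
kl_pq = kl_divergence(p, q)   # generally differs from kl_qp (not symmetric)
kl_qq = kl_divergence(q, q)   # zero when both arguments coincide
print(kl_qp, kl_pq, kl_qq)
```

The asymmetry is why the order of the arguments in $\mathcal{D}_{KL}[q\ ||\ p]$ matters for the bound that follows.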

* Evidence lower bound in terms of the KL divergence

$$\large{
\ln{p(\X|\alpha)} \geq  \ln{p(\X|\alpha)} - \underbrace{\mathcal{D}_{KL}\Big[ \overbrace{q(\Z,\beta)}^{\text{posterior proxy}}\ ||\ \overbrace{p(\Z, \beta\ |\  \X, \alpha)}^{\text{true posterior}}\ \Big]}_{\geq 0} \triangleq \mathcal{L}\Big[q(\Z,\beta)\Big]
}$$

* The bound is tight when $\mathcal{D}_{KL}[\ q(\Z,\beta)\ ||\ p(\Z,\beta | \X, \alpha)\ ] = 0$, i.e. when the posterior proxy equals the true posterior.

$$\boxed{\large{
\ q(\Z,\beta)\  = \ p(\Z,\beta | \X, \alpha)\ 
}}$$
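
The decomposition of the log-evidence into ELBO plus KL divergence can be checked on a toy model. The sketch below assumes a single discrete latent variable with three states and a hypothetical table `joint` holding $p(x, z)$ at the observed $x$; none of these numbers come from the model above.

```python
import math

# Hypothetical values of p(x, z) for z = 0, 1, 2 at the observed x.
joint = [0.10, 0.25, 0.05]
evidence = sum(joint)                       # p(x)
posterior = [j / evidence for j in joint]   # p(z | x)

def elbo(q):
    """E_q[ln p(x, z) / q(z)] for a discrete proxy q."""
    return sum(qz * math.log(jz / qz) for qz, jz in zip(q, joint) if qz > 0.0)

def kl(q, p):
    """D_KL[q || p] for discrete distributions."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, p) if qz > 0.0)

q = [0.2, 0.5, 0.3]                         # arbitrary proxy distribution
log_evidence = math.log(evidence)
gap = log_evidence - elbo(q)                # equals D_KL[q || posterior]
print(gap, kl(q, posterior))
```

Setting `q = posterior` closes the gap, which is the boxed optimality condition.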

### Log-marginal approximation
<hr>

$
\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}
$

* Let $\theta = \{\Z,\beta\}$

* Approximate the log-marginal by solving the following optimization problem for the evidence lower bound $\mathcal{L}[\ q(\Z,\beta)\ ] = \mathcal{L}[\ q(\theta)\ ]$:
<br><br>
$${\large{
\argmin_{q(\theta)}\quad \mathcal{D}_{KL}\Big[\ q(\theta)\ ||\ p(\theta \ |\  \X, \alpha)\ \Big],\quad \text{s.t.}\ \int q(\theta) d\theta = 1
}\ }$$

$$\large{\text{or}}$$

$$\boxed{\ {\large{
\argmax_{q(\theta)}\quad \mathbb{E}\bigg[\  \ln{\frac{p(\X, \theta\ |\ \alpha)}{q(\theta)}}\ \bigg] = \mathbb{E}\bigg[\  \ln{p(\X, \theta \ |\ \alpha)}\ \bigg] + \mathcal{H}\Big[q(\theta)\Big] 
}\ \  \text{s.t.}\ \int q(\theta) d\theta = 1}\ }$$

<br><br>
**Note:** Entropy definition $\ \mathcal{H}[q(x)] = -\mathbb{E}[\ln q(x)]$

### Mean-field assumption
<hr>

* We restrict the family of possible ( _proxy_ ) distributions $q(\theta)$ to those that factorize as follows:

$$\large{
q(\theta) = \prod_{i=1}^M q_i(\theta_i)
}$$



* Then the ELBO can be rewritten as a function of a single factor $q_j(\theta_j)$, using the shorthand $\large{\ q_i = q_i(\theta_i)\ }$:
<br>
<br>
$$\large{
\mathcal{L}[\ q_j \ ] = \int \prod_{i=1}^M q_i \bigg\{ \ln p(\X,\theta|\alpha) - \sum_{k=1}^M \ln q_k \bigg\}\ d\theta = \  \int q_j \bigg\{ \underbrace{\int \ln p(\X,\theta | \alpha ) \prod_{i \neq j} q_i \ d\theta_i}_{\ln\widetilde{p}(\X, \theta_j)} \bigg\}\ d\theta_j - \int q_j \ln{q_j}\ d\theta_j + \text{const}
}$$
<br><br>
* Expectation with respect to the $q$ distributions over all variables $\theta_i$ such that $i \neq j$.

$$\large{
\ \ln\widetilde{p}(\X, \theta_j) = \int \ln p(\X,\theta | \alpha) \prod_{i \neq j} q_i \ d\theta_i = \mathbb{E}_{i \neq j}[\ \ln p{(\X, \theta | \alpha)} \ ]+ \text{const} 
}$$

* The ELBO as a function of $q_j(\theta_j)$, up to an additive constant

$$\large{
\mathcal{L}[\ q_j \ ] = \int q_j \ln{\widetilde{p}(\X, \theta_j)} \ d\theta_j - \int q_j \ln{q_j}\ d\theta_j = \int q_j \ln{\frac{\widetilde{p}(\X, \theta_j)}{q_j}} \ d \theta_j
}$$

$$\large{
\mathcal{L}[\ q_j \ ] = -\int q_j \ln{\frac{q_j}{\widetilde{p}(\X, \theta_j)} } \ d \theta_j = - \mathcal{D}_{KL}\Big[\  q_j \ || \ \widetilde{p}(\X,\theta_j) \ \Big]
}$$

* The KL divergence is minimized, and hence the ELBO maximized, by $q^{*}_j(\theta_j) = \widetilde{p}(\X,\theta_j)$

$$\boxed{\large{\ \  \ln q^{*}_j(\theta_j) = \ln\widetilde{p}(\X,\theta_j) = \mathbb{E}_{i \neq j}[\ \ln p {(\X, \theta | \alpha)} \ ]+ \text{const}  \ \
}}$$

* In practice we work with the log form and then reinstate the normalization factor (where required) by inspecting the form of the proxy distribution.

$${\large{ q^{*}_j(\theta_j) = \frac{\exp\Big\{\mathbb{E}_{i \neq j}[\ \ln p{(\X, \theta | \alpha)} \ ]\Big\}}{\int \exp\Big\{\mathbb{E}_{i \neq j}[\ \ln p{(\X, \theta | \alpha)} \ ]\Big\}\ d\theta_j}
}}$$
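
For discrete variables the normalized update is a softmax of the expected log-joint. The sketch below assumes two discrete variables $\theta_1$ (rows) and $\theta_2$ (columns) and a hypothetical table `joint` of values $p(\X, \theta_1, \theta_2)$ for fixed $\X$; the helper name `update_factor` is illustrative only.

```python
import math

def update_factor(log_joint, q_other):
    """Return q_j* over the rows of log_joint, given the other factor over columns."""
    # E_{i != j}[ln p(X, theta)] for each value of theta_j (one entry per row).
    expected = [sum(q * lp for q, lp in zip(q_other, row)) for row in log_joint]
    # Reinstate the normalization: subtract the max for stability, exponentiate, divide.
    m = max(expected)
    weights = [math.exp(e - m) for e in expected]
    total = sum(weights)
    return [w / total for w in weights]

# Hypothetical table of p(X, theta_1, theta_2) for fixed X.
joint = [[0.10, 0.20, 0.05],
         [0.05, 0.10, 0.30]]
log_joint = [[math.log(v) for v in row] for row in joint]

q2 = [0.3, 0.3, 0.4]                    # current factor over theta_2
q1_star = update_factor(log_joint, q2)  # optimal factor over theta_1
print(q1_star)
```

Subtracting the maximum before exponentiating leaves the result unchanged, because any additive constant in the log form cancels in the normalization.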

### Overall ELBO optimization 



* The optimal solutions $q^{*}_j(\theta_j)$ for $j \in \overline{1..M}$ form a set of coupled conditions for the maximum of the ELBO, since each one depends on the other factors.



* We seek the overall solution by first initializing all of the factors $ q_i(\theta_i)_{ i \neq j}$ appropriately, then finding the optimal $ q^*_j(\theta_j)$



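* Cycling through the factors and replacing each in turn with its latest estimate $q^*_j(\theta_j)$ is guaranteed to converge (to a local optimum in general), because no update can decrease the bound and the bound is concave with respect to each factor $q_i(\theta_i)$ taken individually.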
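
The cycling scheme can be sketched for two discrete factors, tracking the ELBO after every update. This is a minimal illustration under stated assumptions: the 2x3 table `joint` of values $p(\X, \theta_1, \theta_2)$ for fixed $\X$ is hypothetical, as are the function names.

```python
import math

def update_factor(log_joint, q_other, axis):
    """Mean-field update for the factor on `axis` (0 = rows, 1 = columns)."""
    n_rows, n_cols = len(log_joint), len(log_joint[0])
    if axis == 0:
        expected = [sum(q_other[j] * log_joint[i][j] for j in range(n_cols))
                    for i in range(n_rows)]
    else:
        expected = [sum(q_other[i] * log_joint[i][j] for i in range(n_rows))
                    for j in range(n_cols)]
    m = max(expected)
    weights = [math.exp(e - m) for e in expected]
    total = sum(weights)
    return [w / total for w in weights]

def elbo(log_joint, q1, q2):
    """E_q[ln p(X, theta)] plus the entropy of the factorized q."""
    energy = sum(q1[i] * q2[j] * log_joint[i][j]
                 for i in range(len(q1)) for j in range(len(q2)))
    entropy = -sum(q * math.log(q) for q in q1 + q2 if q > 0.0)
    return energy + entropy

# Hypothetical table of p(X, theta_1, theta_2) for fixed X.
joint = [[0.10, 0.20, 0.05],
         [0.05, 0.10, 0.30]]
log_joint = [[math.log(v) for v in row] for row in joint]

q1 = [0.5, 0.5]
q2 = [1.0 / 3.0] * 3
history = [elbo(log_joint, q1, q2)]
for _ in range(20):
    q1 = update_factor(log_joint, q2, axis=0)
    history.append(elbo(log_joint, q1, q2))
    q2 = update_factor(log_joint, q1, axis=1)
    history.append(elbo(log_joint, q1, q2))

print(history[0], history[-1])
```

The recorded values never decrease and stay below the log-evidence, here the log of the sum of all table entries.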

## Example: Probabilistic Matrix Factorization
<hr>

<center>
    <img  src='VMF.png' width=1000/>
</center>
<br>

* Where 
$\newcommand{\and}{\text{and}}$
$$
A = \Big(A_m \Big)_{m = 1}^{M} \in \mathbb{R}^{(M,H)}, \quad A_m \in \mathbb{R}^{ \ H \ }\quad \and \quad B = \Big(B_l \Big)_{l = 1}^{L} \in \mathbb{R}^{(L,H)}, \quad B_l \in \mathbb{R}^{ \ H \ }
$$
<br><br>
$\quad \quad 
C_A = \textbf{diag}(c_a^2),\quad c_a \in \mathbb{R}^{H} \quad \and \quad C_B = \textbf{diag}(c_b^2),\quad c_b \in \mathbb{R}^{H}
$

* Model log-likelihood with $\theta = \{A,B\}$
<br>
$\newcommand{\const}{\text{const}}$
<br>
$$\large{
\ln p(X_{lm} | \theta, C_A,C_B, \sigma^2) = \ln\mathcal{N}(X_{lm} | B_{l}A_{m}^{\top}, \sigma^2) = -\frac{1}{2}\bigg( \ln{\sigma^2} + {\sigma^{-2}\Big( X_{lm} - B_{l}A_{m}^{\top} \Big)^2 }\bigg) + \const
}$$
<br>

* Data log-likelihood $\large{X = \Big(X_{l,m}\Big)}_{l=1,m=1}^{(L,M)}$
<br>
<br><br>
$$\large{
\ln p(X | \theta, C_A, C_B, \sigma^2) = -\frac1 2\bigg(LM\ln\sigma^2  + \sigma^{-2}\Big( ||X - BA^{\top} ||^2_{Fro} \Big) \bigg) + \const
}$$

**Note:** $|| X ||^2_{Fro} = \text{tr}(X^{\top}X)$

###  Log-prior distributions
<hr>
$\newcommand{\tr}{\text{tr}}$

* Matrix A

$$\large{
\ln p(A | C_A) = \sum_{m=1}^{M}\ln\mathcal{N}_H(A_m |0, C_A) = -\frac1 2\bigg(M\ln|C_A| + \tr (A C_A^{-1} A^{\top}) \bigg) + \const
}$$

* Matrix B
<br>
$$\large{
\ln p(B | C_B) = \sum_{l=1}^{L}\ln\mathcal{N}_H(B_l |0, C_B) = -\frac1 2\bigg(L\ln|C_B| + \tr (B C_B^{-1} B^{\top}) \bigg) + \const
}$$

* Log-proxy for Matrix A

$$\large{\begin{aligned}
\ln q(A)^* &= \mathbb{E}_{\neq A}\Big[\ln p(X,A,B|C_A,C_B,\sigma^2)\Big] + \const \\
&= \mathbb{E}_{\neq A}\bigg[ -\frac{1}2\Big(\sigma^{-2}\tr\Big( X^{\top}X - 2X^{\top}BA^{\top} + AB^{\top}BA^{\top} \Big) + \tr\Big( A C_A^{-1}A^{\top}\Big)\Big)\bigg] + \const \\
&= \mathbb{E}_{\neq A}\bigg[-\frac{1}2 \tr\Big( - 2\sigma^{-2}X^{\top}BA^{\top} + A\Big(\sigma^{-2}B^{\top}B + C_A^{-1}\Big)A^{\top} \Big)\bigg] + \const
\end{aligned}}$$

* Linearity of the $\mathbb{E}$ operator

$$\mathbb{E}\Big[ aX + b\Big] = a\mathbb{E}\Big[X\Big] + b$$

$$\large{
\ln q(A)^* = -\frac{1}2\bigg[ \tr\Big( - 2\sigma^{-2}X^{\top}\mathbb{E}\Big[B\Big]A^{\top} + A\underbrace{\Big(\sigma^{-2}\mathbb{E}\Big[B^{\top}B\Big] + C_A^{-1}\Big)}_{\hat{\Sigma}_A^{-1}}A^{\top} \Big)\bigg] + \const
}$$

* After completing the square in the matrix $A$

$$\boxed{\large{
\hat{\Sigma}_A = \Big( \sigma^{-2}\mathbb{E}\Big[B^{\top}B\Big] + C_A^{-1}\Big)^{-1}
}}$$

<br>
$$\boxed{\large{
\hat{A} = \sigma^{-2}X^{\top}\mathbb{E}\Big[B\Big] \hat{\Sigma}_A
}}$$
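
The completing-the-square step can be spelled out as follows. It is a sketch that relies only on $\hat{\Sigma}_A$ being symmetric, so that the two cross terms under the trace coincide:

$$\large{
-\frac{1}2 \text{tr}\Big( (A - \hat{A})\hat{\Sigma}_A^{-1}(A - \hat{A})^{\top} \Big) = -\frac{1}2 \text{tr}\Big( - 2\hat{A}\hat{\Sigma}_A^{-1}A^{\top} + A\hat{\Sigma}_A^{-1}A^{\top} \Big) + \text{const}
}$$

Matching the term linear in $A$ with $-2\sigma^{-2}X^{\top}\mathbb{E}[B]A^{\top}$ gives $\hat{A}\hat{\Sigma}_A^{-1} = \sigma^{-2}X^{\top}\mathbb{E}[B]$, which is the boxed expression for $\hat{A}$ after multiplying by $\hat{\Sigma}_A$ from the right.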

* Proxy for the Matrix A

$$\large{q(A)^* = \prod_{m=1}^{M}\mathcal{N}_H(A_m | \hat{A}_m, \hat\Sigma_A)}$$

$$\large{
\mathbb{E}\Big[A\Big] = \hat{A}, \quad \mathbb{E}\Big[A^{\top}A\Big] = \hat{A}^{\top}\hat{A} + M\hat{\Sigma}_A 
}$$
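
The second moment follows by summing over the rows of $A$. The sketch below assumes that the rows $A_m$ (treated as row vectors) are independent under the proxy and share the covariance $\hat{\Sigma}_A$:

$$\large{
\mathbb{E}\Big[A^{\top}A\Big] = \sum_{m=1}^{M}\mathbb{E}\Big[A_m^{\top}A_m\Big] = \sum_{m=1}^{M}\Big(\hat{A}_m^{\top}\hat{A}_m + \hat{\Sigma}_A\Big) = \hat{A}^{\top}\hat{A} + M\hat{\Sigma}_A
}$$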

* Analogously, after completing the square in the matrix $B$

$$\boxed{\large{
\hat{\Sigma}_B = \Big( \sigma^{-2}\mathbb{E}\Big[A^{\top}A\Big] + C_B^{-1}\Big)^{-1}
}}$$

<br>
$$\boxed{\large{
\hat{B} = \sigma^{-2}X\mathbb{E}\Big[A\Big] \hat{\Sigma}_B
}}$$

* Proxy for the Matrix B

$$\large{q(B)^* = \prod_{l=1}^{L}\mathcal{N}_H(B_l | \hat{B}_l, \hat\Sigma_B)}$$

$$\large{
\mathbb{E}\Big[B\Big] = \hat{B}, \quad \mathbb{E}\Big[B^{\top}B\Big] = \hat{B}^{\top}\hat{B} + L\hat{\Sigma}_B 
}$$

* The variational free energy $F$ is the negative ELBO

$$\large{
2F = 2 \mathbb{E}\bigg[ \ln\frac{q(A)q(B)}{p(X | A, B) p(A) p(B)}\bigg]
}$$

$$\begin{aligned}
2F =\ & LM\ln(2\pi\sigma^2) + \frac{|| X - \hat B\hat A^{\top} ||^2_{Fro}}{\sigma^2} + M\ln\frac{\text{det}(C_A)}{\text{det}(\hat\Sigma_A)}+ L\ln\frac{\text{det}(C_B)}{\text{det}(\hat\Sigma_B)} - (L+M)H \\
& + \text{tr}\Big( C_A^{-1} (\hat{A}^{\top}\hat A + M\hat{\Sigma}_A) + C_B^{-1} (\hat{B}^{\top}\hat B + L\hat \Sigma_{B})\Big) \\
& + \frac{\text{tr}\Big( -\hat{A}^{\top}\hat A\hat{B}^{\top}\hat B + (\hat{A}^{\top}\hat A + M\hat{\Sigma}_A) (\hat{B}^{\top}\hat B + L\hat \Sigma_{B}) \Big)}{\sigma^2}
\end{aligned}$$
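
The alternating updates and the free energy can be sketched together for the rank-one special case $H = 1$, where $\hat{\Sigma}_A$, $\hat{\Sigma}_B$, $C_A$ and $C_B$ reduce to scalars. This is an illustration under stated assumptions, not a reference implementation: the function names, the synthetic data, the fixed hyperparameters and the initialization are all hypothetical, and the free energy uses the log of the determinant ratio, $M\ln(\det C_A / \det\hat\Sigma_A)$.

```python
import math
import random

def two_f(X, a, s_a, b, s_b, ca2, cb2, sigma2):
    """Twice the variational free energy for H = 1 (scalar variances)."""
    L, M = len(X), len(X[0])
    resid = sum((X[l][m] - b[l] * a[m]) ** 2 for l in range(L) for m in range(M))
    aa = sum(v * v for v in a)          # A_hat^T A_hat
    bb = sum(v * v for v in b)          # B_hat^T B_hat
    e_ata = aa + M * s_a                # E[A^T A]
    e_btb = bb + L * s_b                # E[B^T B]
    return (L * M * math.log(2.0 * math.pi * sigma2)
            + resid / sigma2
            + M * math.log(ca2 / s_a) + L * math.log(cb2 / s_b)
            - (L + M)
            + e_ata / ca2 + e_btb / cb2
            + (-aa * bb + e_ata * e_btb) / sigma2)

def vb_rank1(X, ca2=1.0, cb2=1.0, sigma2=0.01, n_iter=30, seed=0):
    """Alternate the q(A) and q(B) updates and record 2F after each one."""
    L, M = len(X), len(X[0])
    rng = random.Random(seed)
    b = [rng.gauss(0.0, 1.0) for _ in range(L)]   # initial mean of q(B)
    s_b = 1.0                                     # initial variance of q(B)
    history = []
    for _ in range(n_iter):
        # q(A) update: Sigma_A = (E[B^T B] / sigma^2 + 1 / c_a^2)^(-1)
        e_btb = sum(v * v for v in b) + L * s_b
        s_a = 1.0 / (e_btb / sigma2 + 1.0 / ca2)
        a = [s_a / sigma2 * sum(X[l][m] * b[l] for l in range(L)) for m in range(M)]
        history.append(two_f(X, a, s_a, b, s_b, ca2, cb2, sigma2))
        # q(B) update: Sigma_B = (E[A^T A] / sigma^2 + 1 / c_b^2)^(-1)
        e_ata = sum(v * v for v in a) + M * s_a
        s_b = 1.0 / (e_ata / sigma2 + 1.0 / cb2)
        b = [s_b / sigma2 * sum(X[l][m] * a[m] for m in range(M)) for l in range(L)]
        history.append(two_f(X, a, s_a, b, s_b, ca2, cb2, sigma2))
    return a, s_a, b, s_b, history

# Synthetic rank-one data with small Gaussian noise (L = 4, M = 5).
data_rng = random.Random(1)
u = [data_rng.gauss(0.0, 1.0) for _ in range(4)]
v = [data_rng.gauss(0.0, 1.0) for _ in range(5)]
X = [[u[l] * v[m] + 0.1 * data_rng.gauss(0.0, 1.0) for m in range(5)] for l in range(4)]

a, s_a, b, s_b, history = vb_rank1(X)
print(history[0], history[-1])
```

Because each update minimizes the free energy over one factor with the other held fixed, the recorded values should not increase from one entry to the next, which gives a convenient check on an implementation.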