# Stein variational gradient descent

In [5]:
import autograd.numpy as np
import matplotlib.pyplot as plt

from IPython.display import HTML, set_matplotlib_formats
# Ask the inline backend for PDF and SVG figure output.
set_matplotlib_formats('pdf', 'svg')

# Read the custom stylesheet inside a context manager so the file handle is
# closed as soon as the contents have been read (the previous bare
# open(...).read() left the handle open until garbage collection).
with open('../../../_static/custom_style.css', 'r') as css_file:
    css_style = css_file.read()

# Wrap the stylesheet in a <style> tag; kept as the final top-level expression
# so that the notebook cell displays (and thereby applies) it.
HTML(f'<style>{css_style}</style>')

<div class="theorem">
    
**Theorem (Gradient of KL is the KSD)** Let $x \sim q(x)$, and $T(x) = x + \epsilon \phi(x)$, where $\phi$ is a smooth function. Then
    
$$\begin{align}
\nabla_{\epsilon}\text{KL}(q_{[T]} || p) \big|_{\epsilon = 0} = - \mathbb{E}_{x \sim q}\left[\text{trace} \mathcal{A}_p \phi(x)\right],
\end{align}$$
    
where $q_{[T]}$ is the density of $T(x)$ and
    
$$\begin{align}
\mathcal{A}_p \phi(x) = \nabla_x \log p(x)\phi^\top(x) + \nabla_x \phi(x).
\end{align}$$
    
</div>
<br>

<details class="proof">
<summary>Proof: Gradient of KL</summary>
    
Let $p_{\left[T^{-1}\right]}(x)$ denote the density of $x = T^{-1}(z)$ when $z \sim p(z)$. By changing the variable of integration from $z$ to $x = T^{-1}(z)$, we obtain
    
$$\begin{align}
\text{KL}(q_{[T]} || p) &= \int q_{[T]}(z) \log \frac{q_{[T]}(z)}{p(z)} dz \\
                        &= \int q(x) \left[ \log q(x) - \log p_{\left[T^{-1}\right]}(x) \right] dx.
\end{align}$$
    
This change of variables is convenient because now only one term in the integral depends on $\epsilon$, that is $p_{\left[T^{-1}\right]}(x)$. Now taking the derivative with respect to $\epsilon$ we obtain
    
$$\begin{align}
\nabla_{\epsilon} \text{KL}(q_{[T]} || p) &= \nabla_{\epsilon} \int q(x) \left[ \log q(x) - \log p_{\left[T^{-1}\right]}(x) \right] dx, \\
                                          &= - \int q(x) \nabla_{\epsilon} \log p_{\left[T^{-1}\right]}(x) dx,
\end{align}$$
    
and using the fact that

$$\begin{align}
\log p_{\left[T^{-1}\right]}(x) &= \log p(T(x)) + \log |\det \nabla_x T(x)|,
\end{align}$$
    
we obtain the expression

$$\begin{align}
\nabla_{\epsilon} \log p_{\left[T^{-1}\right]}(x) &= \nabla \log p(T(x))^\top \nabla_\epsilon T(x) + \nabla_\epsilon \log |\det \nabla_x T(x)|, \\
                                                  &= \nabla \log p(T(x))^\top \nabla_\epsilon T(x) + \text{trace}\left[(\nabla_x T(x))^{-1} \nabla_\epsilon \nabla_x T(x)\right],
\end{align}$$
    
where we used the identity
    
$$\begin{align}
\nabla_{\epsilon} \log |\det A| = \text{trace}\left[A^{-1} \nabla_{\epsilon} A\right].
\end{align}$$
    
Substituting this into the gradient of the KL divergence, we arrive at the following expression for the derivative
    
$$\begin{align}
\nabla_{\epsilon} \text{KL}(q_{[T]} || p) &= - \mathbb{E}_{x \sim q} \left[\nabla \log p(T(x))^\top \nabla_\epsilon T(x) + \text{trace} (\nabla_x T(x))^{-1} \nabla_\epsilon \nabla_x T(x)\right].
\end{align}$$
    
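Setting $T(x) = x + \epsilon \phi(x)$ and evaluating at $\epsilon = 0$, where $T(x) = x$, $\nabla_\epsilon T(x) = \phi(x)$, $\nabla_x T(x) = I$ and $\nabla_\epsilon \nabla_x T(x) = \nabla_x \phi(x)$, yields the result
    
$$\begin{align}
\nabla_{\epsilon} \text{KL}(q_{[T]} || p) \big|_{\epsilon = 0} &= - \mathbb{E}_{x \sim q} \left[\nabla \log p(x)^\top \phi(x) + \text{trace}\left[\nabla_x \phi(x) \right]\right], \\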
                                          &= - \mathbb{E}_{x \sim q} \left[\text{trace} \mathcal{A}_p \phi(x) \right].
\end{align}$$
    
</details>
<br>



<div class="theorem">
    
**Theorem (Direction of steepest descent of the KL)** The direction $\phi^* \in \mathcal{H}_D$ of steepest descent of the KL-divergence is given by the expression
    
$$\begin{align}
\phi^*(\cdot) = \mathbb{E}_{x \sim q}\left[ k(x, \cdot) \nabla_x \log p(x) + \nabla_x k(x, \cdot)\right].
\end{align}$$
    
</div>
<br>


<details class="proof">
<summary>Proof: Direction of steepest descent of the KL</summary>
    
For $f \in \mathcal{H}_D$ we have the following equality
    
$$\begin{align}
\langle f, \phi^* \rangle_{\mathcal{H}_D} &= \sum_{d = 1}^D \langle f_d(\cdot), \phi^*_d(\cdot) \rangle_{\mathcal{H}} \\
                                          &= \sum_{d = 1}^D \left \langle f_d(\cdot), \mathbb{E}_{x \sim q}\left[k(x, \cdot) \nabla_{x_d} \log p(x) + \nabla_{x_d} k(x, \cdot)\right] \right\rangle_{\mathcal{H}} \\
                                          &= \sum_{d = 1}^D \mathbb{E}_{x \sim q}\left[\nabla_{x_d} \log p(x) \langle f_d(\cdot), k(x, \cdot) \rangle_{\mathcal{H}} + \langle f_d(\cdot), \nabla_{x_d} k(x, \cdot) \rangle_{\mathcal{H}} \right] \\
                                          &= \sum_{d = 1}^D \mathbb{E}_{x \sim q}\left[\nabla_{x_d} \log p(x) f_d(x) + \nabla_{x_d} f_d(x) \right] \\
                                          &= \mathbb{E}_{x \sim q}\left[\text{trace} \mathcal{A}_p f(x)\right].
\end{align}$$
    
Therefore, among all $f \in \mathcal{H}_D$ of a given norm, the one which maximises $\mathbb{E}_{x \sim q}\left[\text{trace} \mathcal{A}_p f(x)\right]$ is the one which maximises the inner product $\langle f, \phi^* \rangle_{\mathcal{H}_D}$, which occurs when $f$ is parallel to $\phi^*$.
    
</details>
<br>



<div class="definition">
    
**Algorithm (Stein variational gradient descent)** Given a distribution $p(x)$, a positive definite kernel $k(x, x')$ and a set of particles with initial positions $\{x_n\}_{n=1}^N$, Stein variational gradient descent evolves the particles according to
    
$$\begin{align}
\frac{d x_m}{dt} = \sum_{n = 1}^N \left[ k(x_n, x_m) \nabla_x \log p(x)|_{x_n} + \nabla_x k(x_m, x) |_{x_n}\right].
\end{align}$$
    
</div>
<br>

In [6]:
def eq(x, x_, lengthscale):
    """Evaluate the exponentiated-quadratic kernel between two point sets.

    Each entry of the result is ``exp(-0.5 * sum_d ((x[i, d] - x_[j, d]) /
    lengthscale) ** 2)``.

    Args:
        x: 2-D array whose rows are the first set of points.
        x_: 2-D array whose rows are the second set of points; its trailing
            dimension must broadcast against that of ``x``.
        lengthscale: Value the pairwise differences are divided by; anything
            that broadcasts against the trailing axis.

    Returns:
        Array with one row per point in ``x`` and one column per point in
        ``x_``, holding the kernel values.
    """
    # Broadcast to every (row of x, row of x_) pair and rescale in one step.
    scaled_gap = (x[:, None, :] - x_[None, :, :]) / lengthscale

    # Squared scaled distance: reduce over the trailing (coordinate) axis.
    sq_dist = np.sum(scaled_gap ** 2, axis=2)

    return np.exp(-0.5 * sq_dist)

In [7]:
def mog_logprob(locs, scales, weights):
    
    def logprob(x):
        
        diff = x[:, None, :] - locs[None, :, :]
        quad = np.sum((diff / scales) ** 2, axis=2)
        
        coeffs = 
        exp = np.exp(-0.5 * quad)