# Stein-Variational-Gradient-Descent
Qiang Liu and Dilin Wang, [*Stein Variational Gradient Descent (SVGD): A General Purpose Bayesian Inference Algorithm*](https://arxiv.org/pdf/1608.04471.pdf), NIPS, 2016.

In [None]:
# imports
import SVGD

from scipy.stats import multivariate_normal as mvn
import numpy as np
from functools import partial

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
plt.style.use('dark_background')

import os
from IPython.display import HTML


In [None]:
# Data
# Ground truth: a zero-mean bivariate Gaussian. `rv` is the frozen
# distribution whose pdf is drawn as the background of the plots below.
mu = np.array([0., 0.])
cov = np.array([[0.333, 0.357],
                [0.357, 0.666]])
rv = mvn(mu, cov)

# Evaluation grid for the contour plots: [-3, 3) in steps of 0.01 on both
# axes. `pos` stacks the two coordinate grids along a trailing axis so that
# rv.pdf(pos) gives one density value per grid point.
xx, yy = np.mgrid[-3.:3.:.01, -3.:3.:.01]
pos = np.dstack((xx, yy))

# heart samples
# Two curves sampled at the same 25 abscissae: y1 is the upper outline
# (two semicircular arcs), y2 the lower outline ending in the tip at x = 0.
x = np.linspace(-2, 2, 25)
y1 = np.sqrt(1 - (abs(x) - 1) ** 2)
y2 = -3 * np.sqrt(1 - (abs(x) / 2) ** 0.5)

# Stack both curves into 50 particles, one (x, y) row each; every particle is
# shifted up by 1 in y.
particle_x = np.concatenate((x, x))
particle_y = np.concatenate((y1, y2)) + 1
init_particles = np.column_stack((particle_x, particle_y))


In [None]:
# Initial state: target density as filled contours, with the heart-shaped
# initial particles drawn on top in red.
plt.figure(figsize=(6, 6))
target_density = rv.pdf(pos)
plt.contourf(xx, yy, target_density)
# init_particles has exactly two columns, so its transpose unpacks into the
# x and y coordinate arrays.
plt.scatter(*init_particles.T, c='r')
for axis_option in ('off', 'equal'):
    plt.axis(axis_option)
plt.show()


In [None]:
# Bind mu and cov as the first two positional arguments of
# SVGD.d_log_pdf_mvn; the resulting callable is what SVGD.update receives.
# (By its name this is the gradient of the Gaussian log-density -- confirm
# against the SVGD module.)
d_log_pdf_mvn_eval = partial(SVGD.d_log_pdf_mvn, mu, cov)

# Run SVGD starting from the heart-shaped particles.
svgd_options = dict(n_iter=1600, stepsize=0.002)
transformed_particles = SVGD.update(init_particles, d_log_pdf_mvn_eval, **svgd_options)


In [None]:
# testing the results
# Compare the particles' empirical mean and covariance with the parameters
# of the target Gaussian. Particles are stored one per row, hence axis=0
# for the mean and the transpose for np.cov (which expects variables in rows).
mu_particles = np.mean(transformed_particles, axis=0)
cov_particles = np.cov(transformed_particles.T)

for label, value in (
    ("particles' mean =", mu_particles),
    ("true mean =", mu),
    ("particles' COV=\n", cov_particles),
    ("true covariance =\n", cov),
):
    print(label, value)


In [None]:
# Final state: target density as filled contours, with the particles
# returned by SVGD.update drawn on top in red.
plt.figure(figsize=(6, 6))
target_density = rv.pdf(pos)
plt.contourf(xx, yy, target_density)
final_x = transformed_particles[:, 0]
final_y = transformed_particles[:, 1]
plt.scatter(final_x, final_y, c='r')
for axis_option in ('off', 'equal'):
    plt.axis(axis_option)
plt.show()


In [None]:
# update and record
# Same kind of run as SVGD.update, but the call also returns two recorded
# histories that drive the animations below.
trans_parts, parts_evol, grads_rec = SVGD.update_record(
    init_particles,
    d_log_pdf_mvn_eval,
    n_iter=1200,
    stepsize=0.002,
)
# Convert both histories to arrays. The animation cells index them as
# parts_evol[frame, particle, coord] and grads_rec[frame, 0 or 1, particle, coord].
parts_evol, grads_rec = map(np.array, (parts_evol, grads_rec))


In [None]:
# simple Animation
# Axis 0 of parts_evol is the frame axis (it is indexed as
# parts_evol[frame, :, coord] below), so count frames from that axis rather
# than from whichever dimension happens to be the largest.
num_frames = parts_evol.shape[0] - 1
# Subsample the recording so the animation has on the order of 400 frames.
# Clamp to 1: with fewer than 400 recorded steps the integer division gives
# 0, which is not a valid step for the np.arange call that builds the frames.
frame_step = max(1, num_frames // 400)

# Borderless axes filling the whole figure: density in the background,
# particles of the first recorded frame on top in red.
fig = plt.figure(figsize=(7, 7))
ax = fig.add_axes([0, 0, 1, 1], frameon=False)
im = ax.contourf(xx, yy, rv.pdf(pos))
scat = ax.scatter(parts_evol[0, :, 0],
                  parts_evol[0, :, 1], c='r')

ax.axis('off')
ax.axis('equal')

def update(frame_number):
    """Move the scatter markers to the particle positions of one frame.

    Animation callback: ``frame_number`` indexes the first axis of
    ``parts_evol``. Both ``parts_evol`` and ``scat`` are module-level names
    looked up when the callback runs, not when it is defined.
    """
    frame = parts_evol[frame_number]
    # One (x, y) row per particle, as set_offsets expects.
    offsets = np.column_stack((frame[:, 0], frame[:, 1]))
    scat.set_offsets(offsets)

# Build the animation: update() is called once per selected frame index,
# with a 40 ms delay between frames.
simple_animation = FuncAnimation(
    fig,
    update,
    frames=np.arange(0, num_frames, frame_step),
    interval=40,
)
# plt.show()
# Close the figure; the animation object keeps its own reference to it
# (presumably this avoids a static copy being rendered under the cell).
plt.close(fig)


In [None]:
# Play the animation in the notebook: it is rendered to an HTML5 video and
# shown as this cell's output (comment the line out to skip rendering).
# NOTE(review): to_html5_video presumably needs a movie writer such as ffmpeg
# to be installed -- confirm for your environment.
HTML(simple_animation.to_html5_video())


In [None]:
# uncomment to create a gif copy of the animation
# statics = os.path.join(os.path.abspath(os.getcwd() + "/../../"), 'statics')
# simple_animation.save(filename=str(statics + '/SVGD_2D_MVN_simple.gif'), fps=24, dpi=200)


In [None]:
# force (gradients) animation
# Axis 0 of parts_evol is the frame axis (it is indexed as
# parts_evol[frame, :, coord] below), so count frames from that axis rather
# than from whichever dimension happens to be the largest.
num_frames = parts_evol.shape[0] - 1
# Subsample the recording so the animation has on the order of 400 frames.
# Clamp to 1: with fewer than 400 recorded steps the integer division gives
# 0, which is not a valid step for the np.arange call that builds the frames.
frame_step = max(1, num_frames // 400)

fig = plt.figure(figsize=(7, 7))
ax = fig.add_axes([0, 0, 1, 1], frameon=False)

# Draw on `ax` explicitly (rather than through pyplot's current-axes state),
# matching how the other artists in this cell are created.
cont = ax.contourf(xx, yy, rv.pdf(pos), cmap='BuGn')
# Two arrow fields anchored at the particles of the first frame:
# grads_rec[frame, 0] in red and grads_rec[frame, 1] in blue.
# (Judging by the variable names these are the attractive and repulsive
# parts of the SVGD update -- confirm against SVGD.update_record.)
guiv_att = ax.quiver(parts_evol[0, :, 0], parts_evol[0, :, 1],
                     grads_rec[0, 0, :, 0], grads_rec[0, 0, :, 1], color='r')
guiv_rep = ax.quiver(parts_evol[0, :, 0], parts_evol[0, :, 1],
                     grads_rec[0, 1, :, 0], grads_rec[0, 1, :, 1], color='b')
# The particles themselves, in black.
scat = ax.scatter(parts_evol[0, :, 0], parts_evol[0, :, 1], c='k')

ax.axis('off')
ax.axis('equal')


def update(f):
    """Advance the quiver animation to recorded frame ``f``.

    Animation callback: moves both arrow fields and the scatter markers to
    the particle positions of frame ``f`` and refreshes the arrow components
    from ``grads_rec``. ``parts_evol``, ``grads_rec``, ``guiv_att``,
    ``guiv_rep`` and ``scat`` are module-level names looked up when the
    callback runs.
    """
    # One (x, y) row per particle; arrows and markers share these anchors.
    positions = np.column_stack((parts_evol[f, :, 0], parts_evol[f, :, 1]))
    for artist in (guiv_att, guiv_rep, scat):
        artist.set_offsets(positions)

    # Arrow components: grads_rec[f, 0] drives the red field, grads_rec[f, 1]
    # the blue one.
    for field_index, arrows in enumerate((guiv_att, guiv_rep)):
        arrows.set_UVC(grads_rec[f, field_index, :, 0],
                       grads_rec[f, field_index, :, 1])

# Build the animation: update() is called once per selected frame index,
# with a 40 ms delay between frames.
quiver_animation = FuncAnimation(
    fig,
    update,
    frames=np.arange(0, num_frames, frame_step),
    interval=40,
)
# plt.show()
# Close the figure; the animation object keeps its own reference to it
# (presumably this avoids a static copy being rendered under the cell).
plt.close(fig)


In [None]:
# Play the animation in the notebook: it is rendered to an HTML5 video and
# shown as this cell's output (comment the line out to skip rendering).
# NOTE(review): to_html5_video presumably needs a movie writer such as ffmpeg
# to be installed -- confirm for your environment.
HTML(quiver_animation.to_html5_video())


In [None]:
# uncomment to create a gif copy of the animation
# statics = os.path.join(os.path.abspath(os.getcwd() + "/../../"), 'statics')
# quiver_animation.save(filename=str(statics + '/SVGD_2D_MVN_quiver.gif'), fps=24, dpi=200)
