In [1]:
import torch
import sys
sys.path.extend(['..'])
from mvn import MVNIso as MVN
from distributions import log_prob_banana as log_p
from samplers import stams_mvn_hmc, hmc
import matplotlib.pyplot as plt
import time
%matplotlib notebook

Make a discrete grid over the two coordinates of $x$ (called `x` and `y` in the code) and evaluate the density $p(x)$ on it

In [2]:
# 201 evenly spaced points per axis on [-4, 4]; the grid is their outer product.
x = torch.linspace(-4, 4, 201)
xx, yy = torch.meshgrid(x, x)

# Pack every grid point into one (2, 201*201) tensor and evaluate log p there.
xy = torch.stack((xx.reshape(-1), yy.reshape(-1)), dim=0)
log_p_grid = log_p(xy)

# Normalize in log space so the grid values sum to 1, then restore the grid shape.
log_norm = torch.logsumexp(log_p_grid.flatten(), dim=0)
p_grid = torch.exp(log_p_grid - log_norm).reshape(xx.size())

Do classic MCMC on x

In [3]:
# Sample x directly from p(x) with classic HMC and time the run.
n_hmc_samples = 1000  # single source of truth for the sampler call and the report below

# perf_counter is monotonic, so the elapsed time cannot be skewed by
# system-clock adjustments the way a time.time() difference can.
tstart = time.perf_counter()
hmc_results = hmc(log_p, torch.randn(2), n_samples=n_hmc_samples, mass=torch.tensor([0.1, 0.1]))
telapse = time.perf_counter() - tstart
# 'accept' is scaled by 100 and shown as a percentage (presumably an acceptance fraction -- confirm in samplers.hmc).
print(f"Got {n_hmc_samples} samples in {telapse} seconds with {hmc_results['accept']*100:.2f}% acceptance")

# Trace plot of the chain; the trailing ';' suppresses the Line2D repr in the cell output.
plt.figure()
plt.plot(hmc_results['samples']);

Got 1000 samples in 17.684990882873535 seconds with 100.00% acceptance


<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x12d3a8d30>,
 <matplotlib.lines.Line2D at 0x12d33dfa0>]

Do classic VI on $q(x;\theta)$

In [4]:
# Classic VI: fit one q(x; theta) by gradient descent on a sampled KL objective.
n_vi_steps = 2000
q_vi = MVN.new_random(d=2)
q_vi.theta.requires_grad_(True)

kl_vals = torch.zeros(n_vi_steps)
for step in range(n_vi_steps):
    # Objective = -entropy(q) - E_q[log_p], with the expectation taken by
    # monte_carlo_ev using n_samples=500 (presumably draws from q -- confirm in mvn).
    neg_entropy = -q_vi.entropy()
    expected_log_p = q_vi.monte_carlo_ev(log_p, n_samples=500)
    kl = neg_entropy - expected_log_p
    (dkl_dtheta,) = torch.autograd.grad(kl, q_vi.theta)
    with torch.no_grad():
        # Step size is 0.01 for the first 50 steps, then 0.01/2, 0.01/3, ... per block of 50.
        lr = 0.01 / (1 + step // 50)
        q_vi.theta.sub_(lr * dkl_dtheta)
        kl_vals[step] = kl
# Freeze theta again so later cells can use q_vi without tracking gradients.
q_vi.theta.requires_grad_(False)

# Objective value per optimization step.
plt.figure()
plt.plot(kl_vals)
plt.xlabel('VI step')
plt.ylabel('KL(q||p) + log Z')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'KL(q||p) + log Z')

Run our parametric-sampling method, which draws samples of the parameters $\theta$ of $q(x;\theta)$

In [8]:
# Run the parametric sampler: each sample is a parameter vector theta of q(x; theta).
n_mvn_samples = 1000  # single source of truth for the sampler call and the report below

# perf_counter is monotonic, so the elapsed time cannot be skewed by
# system-clock adjustments the way a time.time() difference can.
tstart = time.perf_counter()
mvn_hmc_results = stams_mvn_hmc(log_p, lam_kl=2.0, q_init=MVN.new_random(d=2), n_samples=n_mvn_samples)
telapse = time.perf_counter() - tstart
# The acceptance rate must be read from this run's results, not from hmc_results
# (the classic-HMC cell).
# NOTE(review): assumes stams_mvn_hmc returns an 'accept' entry like hmc does -- confirm.
print(f"Got {n_mvn_samples} samples in {telapse} seconds with {mvn_hmc_results['accept']*100:.2f}% acceptance")

# Trace plot of the sampled theta components; the trailing ';' suppresses the Line2D repr.
plt.figure()
plt.plot(mvn_hmc_results['samples']);

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x12d9a15b0>,
 <matplotlib.lines.Line2D at 0x12d75fb80>,
 <matplotlib.lines.Line2D at 0x12d95d6a0>]

## Figure 1: conceptual introduction to sampling, VI, and our intermediate method

In [7]:
# Three panels sharing the same p(x) background: samples, a single VI fit,
# and a set of q's drawn by the parametric sampler.
fig, ax = plt.subplots(1, 3, figsize=(6.5, 2.5))

# Common background and cosmetics for every panel. The axis limits are read
# once, from the first panel right after its contour plot, and reused for all.
xl = yl = None
for panel, title in zip(ax, ('Sampling', 'VI', "Proposed 'best of both'")):
    panel.contourf(xx, yy, p_grid, origin='lower', vmin=0.)
    if xl is None:
        xl, yl = panel.get_xlim(), panel.get_ylim()
    panel.set_title(title)
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_xlim(xl)
    panel.set_ylim(yl)
    panel.axis('image')

# Panel 0: every 10th HMC sample as a small red dot.
hmc_thinned = hmc_results['samples'][::10]
ax[0].plot(hmc_thinned[:, 0], hmc_thinned[:, 1], '.r', markersize=2)

# Panel 1: the 1-sigma ellipse of the fitted VI distribution.
ax[1].plot(*q_vi.ellipse(nsigma=1.), color='r', linewidth=1)

# Panel 2: one 1-sigma ellipse per 50th sampled theta.
for theta in mvn_hmc_results['samples'][::50]:
    q = MVN(theta=theta)
    ax[2].plot(*q.ellipse(nsigma=1.), color='r', linewidth=1)

fig.tight_layout()

<IPython.core.display.Javascript object>