In [1]:
from models.samplers import HMC, MALA, ULA, run_chain
import torch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib widget

In [2]:
# Keep the original preference for the second GPU ("cuda:1"), but only when that
# ordinal actually exists: on a single-GPU machine "cuda:1" is an invalid device
# and every later tensor allocation would fail. Fall back to the first GPU, then
# to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [3]:
class GMM():
    """Equally weighted mixture of multivariate normal distributions.

    Every component is a ``torch.distributions.MultivariateNormal``; components
    are chosen uniformly at random when sampling, and ``log_prob`` uses the
    matching uniform mixture weights ``1 / K``.
    """

    def __init__(self, locs, covariance_matrices):
        """Build one ``MultivariateNormal`` per (loc, covariance) pair.

        Args:
            locs: non-empty sequence of mean tensors, one per component.
            covariance_matrices: sequence of covariance tensors, one per
                component, same length as ``locs``.

        Raises:
            ValueError: if ``locs`` is empty or the two sequences differ in
                length.
        """
        if len(locs) == 0:
            raise ValueError("GMM needs at least one component")
        if len(locs) != len(covariance_matrices):
            raise ValueError(
                "locs and covariance_matrices must have the same length, got "
                f"{len(locs)} and {len(covariance_matrices)}")

        self.device = locs[0].device
        self.distributions = [
            torch.distributions.MultivariateNormal(loc=loc, covariance_matrix=cov)
            for loc, cov in zip(locs, covariance_matrices)
        ]
        # Keep the component parameters, stacked along a leading component
        # axis (these attributes used to be left as empty tensors).
        self.locs = torch.stack(list(locs))
        self.covariance_matrices = torch.stack(list(covariance_matrices))

    def log_prob(self, z, x=None):
        """Log-density of the mixture at ``z``.

        Args:
            z: points to evaluate; passed unchanged to each component's
                ``log_prob``.
            x: unused; kept so the signature stays compatible with callers
                that pass it.

        Returns:
            1-D tensor with one log-density per point. The per-component
            log-probabilities are flattened before being combined, so any
            leading batch dimensions of ``z`` end up merged into one.
        """
        # Column k holds the (flattened) log-density of component k.
        per_component = torch.stack(
            [dist.log_prob(z).reshape(-1) for dist in self.distributions], dim=-1)
        # log( (1/K) * sum_k p_k(z) ): the -log K term is the uniform mixture
        # weight; without it the result is the log of an unnormalized density.
        log_n_components = float(np.log(len(self.distributions)))
        return torch.logsumexp(per_component, dim=1) - log_n_components

    def sample(self, shape):
        """Draw samples from the mixture.

        Args:
            shape: number of draws, given either as an int or as a sequence
                whose first entry is the number of draws (remaining entries
                are ignored).

        Returns:
            Tensor with one draw per row along the first dimension; zero
            draws give an empty tensor that still has the event dimension(s).
        """
        n_samples = shape if isinstance(shape, int) else shape[0]
        n_components = len(self.distributions)

        # Component index of every draw, chosen uniformly with NumPy's global RNG.
        component_ids = np.random.choice(a=n_components, size=n_samples)
        counts = np.bincount(component_ids, minlength=n_components)

        # One batched draw per component instead of one tiny draw (and one
        # torch.cat) per sample; `grouped` is ordered by component index.
        grouped = torch.cat([
            dist.sample((int(count),))
            for dist, count in zip(self.distributions, counts)
        ])

        # Scatter the grouped draws back so that row i comes from component
        # component_ids[i]: a stable argsort lists, component by component,
        # the output rows that belong to it.
        positions = torch.as_tensor(
            np.argsort(component_ids, kind="stable"), device=grouped.device)
        samples = torch.empty_like(grouped)
        samples[positions] = grouped
        return samples

In [4]:
# Component means: two modes on the diagonal, at (-1, -1) and (1, 1).
locs = [torch.tensor(center, device=device, dtype=torch.float32)
        for center in ([-1., -1.], [1., 1.])]

# Both components share the same isotropic covariance, 0.1 * I.
covs = [torch.eye(2, device=device, dtype=torch.float32) * 0.1
        for _ in range(2)]

In [5]:
# Sampling target for the chains below: the mixture built from `locs` and `covs`.
target = GMM(locs, covs)

In [6]:
# Close the current figure (if any) before opening a new one, so re-running
# the cell does not accumulate open figures.
plt.close()
plt.figure()

# Reference cloud drawn from the target, moved to the CPU for matplotlib.
target_samples = target.sample((5000, )).cpu()
plt.scatter(target_samples[:, 0], target_samples[:, 1])

# Title and axis labels so the figure can be read on its own.
plt.title("Samples from the target GMM")
plt.xlabel("z[0]")
plt.ylabel("z[1]")

plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [20]:
# MALA

mala = MALA(step_size=0.3, use_barker=False, learnable=False)

# A single starting point at the origin, with a leading batch axis of size 1.
z_init = torch.zeros((1, 2), device=device, dtype=torch.float32)

# NOTE(review): return_trace=True presumably returns every visited state;
# confirm against run_chain in models.samplers.
mala_samples = run_chain(
    mala,
    z_init,
    target=target.log_prob,
    n_steps=1000,
    return_trace=True,
).cpu()

In [25]:
# ULA

ula = ULA(step_size=0.05, learnable=False)

# A single starting point at the origin, with a leading batch axis of size 1.
z_init = torch.zeros((1, 2), device=device, dtype=torch.float32)

# NOTE(review): return_trace=True presumably returns every visited state;
# confirm against run_chain in models.samplers.
ula_samples = run_chain(
    ula,
    z_init,
    target=target.log_prob,
    n_steps=1000,
    return_trace=True,
).cpu()

In [39]:
# HMC

hmc = HMC(step_size=0.05, n_leapfrogs=10, learnable=False)

# A single starting point at the origin, with a leading batch axis of size 1.
z_init = torch.zeros((1, 2), device=device, dtype=torch.float32)

# NOTE(review): return_trace=True presumably returns every visited state;
# confirm against run_chain in models.samplers.
hmc_samples = run_chain(
    hmc,
    z_init,
    target=target.log_prob,
    n_steps=1000,
    return_trace=True,
).cpu()

In [40]:
# Close the current figure (if any) before opening a new one, so re-running
# the cell does not accumulate open figures.
plt.close()
plt.figure()

# Fresh reference cloud from the target; drawn first so the chains sit on top.
target_samples = target.sample((5000, )).cpu()
plt.scatter(target_samples[:, 0], target_samples[:, 1], color='blue', label='target')
plt.scatter(mala_samples[:, 0], mala_samples[:, 1], color='red', label='MALA')
plt.scatter(ula_samples[:, 0], ula_samples[:, 1], color='green', label='ULA')
plt.scatter(hmc_samples[:, 0], hmc_samples[:, 1], color='orange', label='HMC')

# Without a legend the four colours cannot be mapped back to the samplers.
plt.title("Target samples vs. sampler traces")
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.legend()

plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …