In [1]:
from models.samplers import HMC, MALA, ULA, run_chain
from models.normflows import NormFlow
import torch
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
%matplotlib widget

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class GMM():
    def __init__(self, locs, covariance_matrices):
        self.locs = torch.tensor([], device=locs[0].device)
        self.covariance_matrices = torch.tensor([], device=covariance_matrices[0].device)
        self.distributions = []
        self.device = locs[0].device
        for i in range(len(locs)):
            self.distributions.append(
                torch.distributions.MultivariateNormal(loc=locs[i], covariance_matrix=covariance_matrices[i]))

    def log_prob(self, z, x=None):
        log_p = torch.tensor([], device=z.device)
        for i in range(len(self.distributions)):
            log_paux = self.distributions[i].log_prob(z).view(-1, 1)
            log_p = torch.cat([log_p, log_paux], dim=-1)
        log_density = torch.logsumexp(log_p, dim=1)
        return log_density

    def sample(self, shape):
        p = np.random.choice(a=len(self.distributions), size=shape[0])
        samples = torch.tensor([], device=self.device)
        for idx in p:
            z = self.distributions[idx].sample((1,))
            samples = torch.cat([samples, z])
        return samples
    
class Gaussian():
    def __init__(self, loc, scale):
        self.distr = torch.distributions.Normal(loc=torch.tensor(loc, device=device), scale=torch.tensor(scale, device=device))
    def log_prob(self, z, x=None):
        return self.distr.log_prob(z).sum(-1)
    def sample(self, shape):
        return self.distr.sample(shape)

In [4]:
locs = [torch.tensor([-3, -3], device=device, dtype=torch.float32),
       torch.tensor([3, 3], device=device, dtype=torch.float32),]

covs = [torch.eye(2, device=device, dtype=torch.float32)*0.1,
       torch.eye(2, device=device, dtype=torch.float32)*0.1]

In [5]:
target = GMM(locs=locs, covariance_matrices=covs)
# target = Gaussian(loc=[10., -10.], scale=[1., 1.])

In [6]:
plt.close()
plt.figure()

target_samples = target.sample((5000, )).cpu()
plt.scatter(target_samples[:, 0], target_samples[:, 1])

plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [7]:
dim = 2
flow = NormFlow(flow_type='RNVP', num_flows=2, hidden_dim=dim, need_permute=True).to(device)
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

prior = torch.distributions.Normal(loc=torch.zeros(dim, device=device, dtype=torch.float32),
                                  scale=torch.ones(dim, device=device, dtype=torch.float32))
for ep in tqdm(range(10000)):
#         import pdb
#         pdb.set_trace()
    eps = prior.sample((100, ))
    z, log_jac = flow(eps)
    loss = (prior.log_prob(eps).sum(-1) - log_jac - target.log_prob(z)).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if ep % 100 == 0:
        print(loss.item())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10000.0), HTML(value='')))

62.95368957519531
40.01558303833008
23.565397262573242
14.11245346069336
11.473953247070312
5.775651931762695
0.8184287548065186
0.16920296847820282
0.1546507030725479
0.13421529531478882
0.04289143532514572
0.05631246417760849
0.06898728013038635
-0.002356430282816291
0.025244303047657013
0.008367604576051235
0.02510571852326393
0.04167329519987106
0.008731777779757977
0.033164724707603455
-0.0005605006008408964
0.030274996533989906
0.010937687940895557
-0.0047379168681800365
-0.002835329622030258
-0.0021473716478794813
0.009011618793010712
0.016713811084628105
0.00662901159375906
0.007869038730859756



KeyboardInterrupt: 

In [8]:
with torch.no_grad():
    eps = prior.sample((5000,))
    z, _ = flow(eps)
    z = z.cpu().numpy()
    
plt.close()
plt.figure()

target_samples = target.sample((5000, )).cpu()
plt.scatter(target_samples[:, 0], target_samples[:, 1], color='blue')
plt.scatter(z[:, 0], z[:, 1], color='red')
# plt.scatter(ula_samples[:, 0], ula_samples[:, 1], color='green')
# plt.scatter(hmc_samples[:, 0], hmc_samples[:, 1], color='orange')

plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [7]:
# MALA

mala = MALA(step_size=0.01, use_barker=False, learnable=False)
z_init = torch.zeros(2, device=device, dtype=torch.float32)[None]

mala_samples = run_chain(mala, z_init, target=target.log_prob, n_steps=1000, return_trace=True,).cpu()

In [8]:
# ULA

ula = ULA(step_size=0.01, learnable=False)
z_init = torch.zeros(2, device=device, dtype=torch.float32)[None]

ula_samples = run_chain(ula, z_init, target=target.log_prob, n_steps=1000, return_trace=True,).cpu()

In [9]:
# HMC

hmc = HMC(step_size=0.05, n_leapfrogs=10, learnable=False)
z_init = torch.zeros(2, device=device, dtype=torch.float32)[None]

hmc_samples = run_chain(hmc, z_init, target=target.log_prob, n_steps=1000, return_trace=True,).cpu()

In [8]:
plt.close()
plt.figure()

target_samples = target.sample((5000, )).cpu()
plt.scatter(target_samples[:, 0], target_samples[:, 1], color='blue')
plt.scatter(mala_samples[:, 0], mala_samples[:, 1], color='red')
# plt.scatter(ula_samples[:, 0], ula_samples[:, 1], color='green')
# plt.scatter(hmc_samples[:, 0], hmc_samples[:, 1], color='orange')

plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …