In [None]:
import sys
from pathlib import Path

# Make the repository's ``src`` directory (one level up from the notebook's
# working directory) importable; presumably this is where the local ``sgld``
# package lives -- confirm against the repo layout.
src_path = str(Path('..', 'src').resolve())
if src_path not in sys.path:
  # Prepend so these sources are found before any installed copy.
  sys.path.insert(0, src_path)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm.auto import tqdm

# Global plot styling: white background with grid lines, enlarged fonts.
# ``sns.set`` is an alias of ``sns.set_theme``; call the preferred name directly.
sns.set_theme(style='whitegrid', font_scale=1.5)

In [None]:
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

class MoG2(nn.Module):
  """Two-component 2-D Gaussian mixture; ``forward`` returns ``log p(x)``.

  The defaults reproduce the toy target used in this notebook: an
  equal-weight mixture of N([2, 2], 0.5 I) and N([-2, -2], I).

  Args:
    loc1, loc2: component means. A scalar is broadcast to both coordinates
      (a length-2 tensor also works, via ``torch.zeros(2) + loc``).
    var1, var2: isotropic variances; each covariance is ``var * I``.
    weight1: mixture weight of the first component, strictly in (0, 1).
      The second component gets ``1 - weight1``.

  Raises:
    ValueError: if ``weight1`` is not strictly between 0 and 1.
  """
  def __init__(self, loc1=2., var1=.5, loc2=-2., var2=1., weight1=.5):
    super().__init__()
    import math

    if not 0. < weight1 < 1.:
      raise ValueError(f'weight1 must lie strictly in (0, 1), got {weight1}')

    self.p1 = MultivariateNormal(torch.zeros(2) + loc1, covariance_matrix=var1 * torch.eye(2))
    self.p2 = MultivariateNormal(torch.zeros(2) + loc2, covariance_matrix=var2 * torch.eye(2))
    # The log mixture weights are constants: compute them once here instead
    # of rebuilding a tensor on every forward call.
    self.log_w1 = math.log(weight1)
    self.log_w2 = math.log(1. - weight1)

  def forward(self, x):
    """Return the mixture log-density at ``x``.

    ``x``'s last dimension holds the 2-D coordinates; the result drops it,
    e.g. ``(N, 2) -> (N,)``.
    """
    v1 = self.p1.log_prob(x) + self.log_w1
    v2 = self.p2.log_prob(x) + self.log_w2

    # logsumexp over the component axis = log(w1 * p1(x) + w2 * p2(x)),
    # computed stably in log space.
    return torch.stack([v1, v2]).logsumexp(0)

In [None]:
from sgld.optim import SGLD
from sgld.optim.lr_scheduler import CosineLR

class LL(nn.Module):
  """A single learnable 2-D point ``theta`` scored under the ``MoG2`` target.

  Calling the module returns the mixture log-density at ``theta``. Because
  ``theta`` has shape (1, 2), the result is a one-element tensor. The sampling
  loop negates it to obtain the potential it descends.
  """
  def __init__(self):
    super().__init__()

    # Random standard-normal start, scaled by 2.
    init = torch.randn(1, 2) * 2.
    self.theta = nn.Parameter(init)
    self.mog = MoG2()

  def forward(self):
    target = self.mog
    return target(self.theta)

# Seed torch's global RNG so the run is reproducible: LL() draws its initial
# theta from torch.randn (and SGLD's injected noise presumably uses torch's RNG
# too -- confirm in sgld.optim).
SEED = 0
torch.manual_seed(SEED)

f = LL()
T = int(1e4)   # total number of optimizer steps (one per loop iteration)
n_cycles = 4
lr = .5

sgld = SGLD(f.parameters(), lr=lr, momentum=.9)
sgld_scheduler = CosineLR(sgld, n_cycles=n_cycles, n_samples=2000, T_max=T)

samples = []
for _ in tqdm(range(T)):
  sgld.zero_grad()

  # Potential U(theta) = -log p(theta); its gradient drives the update.
  v = -f()
  v.backward()

  # NOTE(review): presumably `beta` is the cyclical schedule's threshold
  # between exploration (noise-free steps) and sampling (noisy SGLD steps
  # that may record samples) -- confirm against sgld.optim.lr_scheduler.
  if sgld_scheduler.get_last_beta() <= sgld_scheduler.beta:
    sgld.step(noise=False)
  else:
    sgld.step()

    if sgld_scheduler.should_sample():
      samples.append(f.theta.detach().clone())

  sgld_scheduler.step()

# torch.stack raises an opaque error on an empty list; fail with an
# explanatory message (same exception type) instead.
if not samples:
  raise RuntimeError('SGLD collected no samples; check the CosineLR sampling settings')
samples = torch.stack(samples).squeeze(1)   # (n_samples, 1, 2) -> (n_samples, 2)

In [None]:
# Evaluate the MoG2 target density on a regular grid over [-7, 7) x [-7, 7)
# with step 0.1. The mgrid axes are reversed in NumPy, (2, A, B) -> (B, A, 2):
# calling Tensor.T on a non-2-D tensor is deprecated in PyTorch and slated to
# become an error, while ndarray.T on 3-D arrays is well defined.
grid = torch.from_numpy(np.mgrid[-7:7:.1, -7:7:.1].T).float()
mog = MoG2()
with torch.no_grad():  # pure evaluation; keeps .numpy() valid even if MoG2 gains parameters
  logpgrid = mog(grid)

fig, ax = plt.subplots(figsize=(7,7))

ax.contourf(grid[..., 0].numpy(), grid[..., 1].numpy(), logpgrid.exp().numpy(), levels=10,
            cmap=sns.color_palette("crest_r", as_cmap=True))

# Overlay the SGLD samples; the low alpha makes dense regions stand out.
ax.scatter(samples[:, 0].numpy(), samples[:, 1].numpy(), c='red', alpha=.1)

ax.set_title('SGLD samples over the MoG2 density')
ax.set_axis_off()