In [1]:
from math import ceil

import numpy as np
import plotly.express as px
import torch
from plotly import graph_objects as go

from mcmc_experiments.distributions import GaussianMixtureModel
from mcmc_experiments.autoencoders.vae import VAEConfig, VAE

%load_ext autoreload
%autoreload 2

In [2]:
def plot_density_surface(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Create a surface plot of a 2D density function using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = data_dist.likelihood(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create surface plot
    fig = go.Figure(
        data=[go.Surface(x=X, y=Y, z=Z, colorscale="Viridis", showscale=True)]
    )

    # Update layout
    fig.update_layout(
        title="Density Function",
        scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Likelihood"),
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_contour(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Create a contour plot of a 2D density function using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = data_dist.likelihood(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create contour plot
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale="Viridis",
                contours=dict(
                    showlabels=True,  # show labels on contours
                    labelfont=dict(size=12, color="white"),
                ),
                colorbar=dict(title="Likelihood", titleside="right"),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        title="Density Function",
        xaxis_title="X",
        yaxis_title="Y",
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_heatmap(density_fn, x_range=(-2, 2), y_range=(-2, 2), n_points=100):
    """
    Create a single-color density heatmap using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = density_fn(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create heatmap
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale=[
                    [0, "rgba(255,255,255,0.8)"],
                    [1.0, "rgba(0,0,255,1.0)"],
                ],
                showscale=False,  # hide colorbar
                contours=dict(
                    coloring="fill",
                    showlines=False,  # no contour lines
                    size=0.025,
                    start=0.0,
                    end=1.0,
                ),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        width=800,
        height=800,
        autosize=False,
        plot_bgcolor="white",
        paper_bgcolor="white",
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False, title="X"),
        yaxis=dict(showgrid=False, zeroline=False, title="Y"),
    )

    return fig

In [64]:
# Create Gaussian mixture model for sampling etc.
n_means = 10
# thetas = torch.linspace(0, 2 * torch.pi, n_means)
means = torch.randn(n_means, 2)
# means = torch.stack([torch.sin(thetas), torch.cos(thetas)], dim=-1)
covs = torch.stack(
    [
        # torch.diag(torch.tensor([0.8, 1.0])),
        # torch.diag(torch.tensor([0.7, 0.5])),
        # torch.diag(torch.tensor([1.8, 1.2])),
        1e-1 * np.random.rand() * torch.eye(2)
        for _ in range(len(means))
    ]
)
print(means.shape, covs.shape)
data_dist = GaussianMixtureModel(
    means,
    covs,
    # torch.stack([0.01 * torch.eye(2)] * len(means)),
    torch.tensor([1 / len(means)] * len(means)),
)


# Draw samples from distribution.
samples = data_dist.sample(500)

# Plot distributions and the samples.
fig = plot_density_heatmap(data_dist.likelihood)
fig.add_trace(go.Scatter(x=samples[:, 0], y=samples[:, 1], mode="markers"))
fig.show()

torch.Size([10, 2]) torch.Size([10, 2, 2])


In [101]:
def train(
    model,
    data,
    val_data,
    num_epochs,
    batch_size,
    learning_rate=2e-3,
):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    n_samples = len(data)

    for epoch in range(num_epochs):
        # Shuffle indices
        indices = np.random.permutation(n_samples)
        data = data[indices]

        total_loss = 0
        num_batches = ceil(n_samples / batch_size)

        for i in range(num_batches):
            # Get batch
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            x_batch = data[start_idx:end_idx]

            loss = model.is_loss(x_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e2)
            optimizer.step()

            total_loss += loss.item() * len(x_batch)

        avg_loss = total_loss / n_samples
        val_loss = evaluate(model, val_data, batch_size)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}, Val: {val_loss:.4f}"
        )


def evaluate(model, val_data, batch_size):
    model.eval()
    total_loss = 0
    n_samples = len(val_data)
    num_batches = ceil(n_samples / batch_size)

    with torch.no_grad():
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            x_batch = val_data[start_idx:end_idx]

            loss = model.is_loss(x_batch)
            total_loss += loss.item() * len(x_batch)

    return total_loss / n_samples

In [116]:
# Split data.
train_frac = 0.8
num_train = int(train_frac * len(samples))
train_data = samples[:num_train]
val_data = samples[num_train:]

num_epochs = 2500
batch_size = 1024

config = VAEConfig(
    input_dim=2,
    latent_dim=2,
    encoder_layers=(
        128,
        128,
    ),
    decoder_layers=(
        128,
        128,
    ),
)
vae = VAE(config, train_data.mean(0), train_data.var(0))

train(vae, train_data, val_data, num_epochs, batch_size, 1e-3)

Epoch [1/2500] Loss: 2.9381, Val: 2.7259
Epoch [2/2500] Loss: 2.9516, Val: 2.7156
Epoch [3/2500] Loss: 2.9448, Val: 2.7394
Epoch [4/2500] Loss: 2.9442, Val: 2.7326
Epoch [5/2500] Loss: 2.9430, Val: 2.7781
Epoch [6/2500] Loss: 2.9585, Val: 2.7760
Epoch [7/2500] Loss: 2.9384, Val: 2.7252
Epoch [8/2500] Loss: 2.9436, Val: 2.7937
Epoch [9/2500] Loss: 3.0047, Val: 2.7528
Epoch [10/2500] Loss: 3.0179, Val: 2.8488
Epoch [11/2500] Loss: 3.0496, Val: 2.8593
Epoch [12/2500] Loss: 3.1055, Val: 2.9367
Epoch [13/2500] Loss: 3.1601, Val: 2.9175
Epoch [14/2500] Loss: 3.2476, Val: 2.9186
Epoch [15/2500] Loss: 3.4290, Val: 3.1566
Epoch [16/2500] Loss: 3.6318, Val: 3.4721
Epoch [17/2500] Loss: 3.8594, Val: 3.6352
Epoch [18/2500] Loss: 4.2094, Val: 3.8921
Epoch [19/2500] Loss: 4.4985, Val: 4.7252
Epoch [20/2500] Loss: 4.9883, Val: 4.9336
Epoch [21/2500] Loss: 5.4875, Val: 5.6084
Epoch [22/2500] Loss: 6.2447, Val: 6.4435
Epoch [23/2500] Loss: 6.9380, Val: 6.7872
Epoch [24/2500] Loss: 8.0433, Val: 8.0041
E

ValueError: Expected parameter loc (Tensor of shape (100, 10, 2)) of distribution Normal(loc: torch.Size([100, 10, 2]), scale: torch.Size([100, 10, 2])) to satisfy the constraint Real(), but found invalid values:
tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         ...,
         [nan, nan],
         [nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan],
         [nan, nan],
         ...,
         [nan, nan],
         [nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan],
         [nan, nan],
         ...,
         [nan, nan],
         [nan, nan],
         [nan, nan]],

        ...,

        [[nan, nan],
         [nan, nan],
         [nan, nan],
         ...,
         [nan, nan],
         [nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan],
         [nan, nan],
         ...,
         [nan, nan],
         [nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan],
         [nan, nan],
         ...,
         [nan, nan],
         [nan, nan],
         [nan, nan]]])

In [114]:
# Define density function of VAE.
def create_vae_density(model: VAE, num_latent_samples: int = 10000) -> torch.Tensor:
    z = torch.randn(num_latent_samples, model.latent_dim)
    mean, var = model.decode(z, unnormalize=True)
    weights = (1 / num_latent_samples) * torch.ones(num_latent_samples)
    return GaussianMixtureModel(mean, torch.diag_embed(var), weights)


posterior_dist = create_vae_density(vae)
fig = plot_density_heatmap(posterior_dist.likelihood)
# fig = plot_density_heatmap(data_dist.likelihood)  #
with torch.no_grad():
    vae_samples = vae.sample(250)
# samples_normalized = model.output_normalizer.normalize(samples)
# fig.add_trace(
#     go.Scatter(x=samples_normalized[:, 0], y=samples_normalized[:, 1], mode="markers")
# )
fig.add_trace(
    go.Scatter(
        x=train_data[:, 0],
        y=train_data[:, 1],
        mode="markers",
        marker=dict(color="Orange"),
    )
)
fig.add_trace(
    go.Scatter(
        x=vae_samples[:, 0],
        y=vae_samples[:, 1],
        mode="markers",
        marker=dict(color="Purple"),
    )
)
fig.show()

In [51]:
vae = model

In [79]:
(
    data_dist.log_likelihood(vae_samples).mean(),
    data_dist.log_likelihood(vae_samples).median(),
)

(tensor(-2.1733), tensor(-1.8406))

In [89]:
z = torch.randn(100000, model.latent_dim)
mean, var = model.decode(z, unnormalize=False)
mean

samples_normalized = model.output_normalizer.normalize(train_data)
z_mean, z_var = model.encode(samples_normalized)
# x_mean, x_var = model.decode(z_mean, unnormalize=False)
# weights = 1 / len(x_mean) * torch.ones(len(x_mean))
weights = 1 / len(z_mean) * torch.ones(len(z_mean))
emp_dist = GaussianMixtureModel(z_mean, torch.diag_embed(z_var), weights)
# emp_dist = GaussianMixtureModel(x_mean, torch.diag_embed(x_var), weights)
fig = plot_density_heatmap(emp_dist.likelihood)
fig.show()

ValueError: The right-most size of value must match event_shape: torch.Size([10000, 1, 2]) vs torch.Size([1]).

### DDPM

In [82]:
from mcmc_experiments.diffusion.ddpm import DDPMConfig, DDPM

config = DDPMConfig(2, num_denoising_steps=1000, hidden_dims=(128, 128, 128, 128))
ddpm = DDPM(config, train_data[0].shape, train_data.mean(0), train_data.var(0))

num_epochs = 10000
batch_size = 1024
train(ddpm, train_data, val_data, num_epochs, batch_size, 2e-4)

Epoch [1/10000] Loss: 0.9140, Val: 1.0277
Epoch [2/10000] Loss: 1.0770, Val: 0.8539
Epoch [3/10000] Loss: 0.9159, Val: 0.8607
Epoch [4/10000] Loss: 0.9847, Val: 1.1195
Epoch [5/10000] Loss: 0.9768, Val: 0.8680
Epoch [6/10000] Loss: 0.9281, Val: 0.9234
Epoch [7/10000] Loss: 0.9836, Val: 0.9828
Epoch [8/10000] Loss: 0.9694, Val: 0.9333
Epoch [9/10000] Loss: 1.0068, Val: 0.9518
Epoch [10/10000] Loss: 1.0167, Val: 0.8310
Epoch [11/10000] Loss: 0.9253, Val: 0.9570
Epoch [12/10000] Loss: 1.0907, Val: 1.0133
Epoch [13/10000] Loss: 0.9329, Val: 0.8774
Epoch [14/10000] Loss: 0.9279, Val: 0.9453
Epoch [15/10000] Loss: 0.8715, Val: 1.0819
Epoch [16/10000] Loss: 0.9718, Val: 0.7864
Epoch [17/10000] Loss: 0.9206, Val: 1.0561
Epoch [18/10000] Loss: 0.8703, Val: 0.9437
Epoch [19/10000] Loss: 0.9653, Val: 0.9150
Epoch [20/10000] Loss: 0.8957, Val: 0.6967
Epoch [21/10000] Loss: 0.8583, Val: 0.8832
Epoch [22/10000] Loss: 0.8120, Val: 0.6643
Epoch [23/10000] Loss: 0.8792, Val: 0.9182
Epoch [24/10000] Los

In [54]:
new_samples, X = model.sample(250, return_trajectories=True)
new_samples = new_samples.detach().numpy()
X = X.detach().numpy()

fig = plot_density_heatmap(data_dist.likelihood)  # , x_range=(-6, 6), y_range=(-6, 6))
fig.add_trace(
    go.Scatter(
        x=train_data[:, 0],
        y=train_data[:, 1],
        mode="markers",
        marker=dict(color="Orange"),
    )
)
fig.add_trace(
    go.Scatter(
        x=new_samples[:, 0],
        y=new_samples[:, 1],
        mode="markers",
        marker=dict(color="Purple"),
    )
)


# for i in range(len(new_samples)):
#     fig.add_trace(
#         go.Scatter(x=X[i, :, 0], y=X[i, :, 1], mode="lines")
#     )  # Fixed indexing

fig.update_xaxes(range=(-2.5, 2.5))
fig.update_yaxes(range=(-2.5, 2.5))
# fig.update_layout(width=800, height=800)
fig.show()

In [55]:
ddpm_samples, X = model.sample(250, return_trajectories=True)
(
    data_dist.log_likelihood(ddpm_samples).mean(),
    data_dist.log_likelihood(ddpm_samples).median(),
)

(tensor(0.1447, grad_fn=<MeanBackward0>),
 tensor(1.2871, grad_fn=<MedianBackward0>))

In [None]:
X[:, 0, :], new_samples

(array([[0.1222567, 0.8462161],
        [0.1222567, 0.8462161],
        [0.1222567, 0.8462161],
        [0.1222567, 0.8462161],
        [0.1222567, 0.8462161],
        [0.1222567, 0.8462161],
        [0.1222567, 0.8462161],
        [0.1222567, 0.8462161],
        [0.1222567, 0.8462161],
        [0.1222567, 0.8462161]], dtype=float32),
 array([[ 1.2257783 ,  0.9363651 ],
        [ 0.10902694, -0.07706431],
        [ 0.96844035,  0.14769363],
        [ 1.1357272 ,  0.00247806],
        [-0.01460606,  0.13346389],
        [ 0.94328547,  0.24063401],
        [ 1.0010737 ,  0.91146445],
        [-0.05305451,  0.01498097],
        [ 1.062444  ,  0.05105248],
        [ 0.1222567 ,  0.8462161 ]], dtype=float32))

In [None]:
import plotly.figure_factory as ff

lb = -1
ub = 2
step = (ub - lb) / 20

x, y = torch.meshgrid(
    torch.arange(lb, ub, step), torch.arange(lb, ub, step), indexing="ij"
)
points = torch.stack([x, y], dim=-1).reshape(-1, 2)

t = 0 * torch.ones(points.shape[0])

noise_pred = -model(points, t).detach().numpy()

fig = ff.create_quiver(points[:, 0], points[:, 1], noise_pred[:, 0], noise_pred[:, 1])
fig.add_trace(go.Scatter(x=train_data[:, 0], y=train_data[:, 1], mode="markers"))
fig.update_xaxes(range=(lb, ub))
fig.update_yaxes(range=(lb, ub))
fig.update_layout(width=800, height=800)
fig.show()

NotImplementedError: Module [DDPM] is missing the required "forward" function

In [None]:
data_normalized = model.data_normalizer.normalize(train_data)
t = 0 * torch.ones(data_normalized.shape[0])
noise_pred = model.ddpm_block(data_normalized, t).detach().numpy()

fig = ff.create_quiver(
    data_normalized[:, 0], data_normalized[:, 1], noise_pred[:, 0], noise_pred[:, 1]
)
fig.add_trace(
    go.Scatter(x=data_normalized[:, 0], y=data_normalized[:, 1], mode="markers")
)
fig.update_xaxes(range=(lb, ub))
fig.update_yaxes(range=(lb, ub))
fig.update_layout(width=800, height=800)
fig.show()

In [None]:
lb = -2
ub = 2
step = (ub - lb) / 20
x, y = torch.meshgrid(
    torch.arange(lb, ub, step), torch.arange(lb, ub, step), indexing="ij"
)
points = torch.stack([x, y], dim=-1).reshape(-1, 2)

t = 20 * torch.ones(points.shape[0])
data_normalized = model.data_normalizer.normalize(train_data)
noise_pred = -model.ddpm_block(points, t).detach().numpy()

fig = ff.create_quiver(points[:, 0], points[:, 1], noise_pred[:, 0], noise_pred[:, 1])

fig.add_trace(
    go.Scatter(x=data_normalized[:, 0], y=data_normalized[:, 1], mode="markers")
)
fig.update_xaxes(range=(lb, ub))
fig.update_yaxes(range=(lb, ub))
fig.update_layout(width=800, height=800)
fig.show()

In [None]:
model.ddpm_block(
    model.data_normalizer.normalize(train_data), 0 * torch.ones(train_data.shape[0])
)

tensor([[-1.3727e-02, -9.2505e-02],
        [ 6.2618e-02,  2.6951e-02],
        [-5.3746e-02,  9.5031e-02],
        [-2.0816e-01,  1.1859e-01],
        [-5.6162e-02,  9.0789e-02],
        [-1.0923e-01,  7.7159e-04],
        [-7.2920e-02,  1.3994e-01],
        [ 4.0569e-02,  1.2358e-01],
        [ 2.5908e-02, -6.2648e-04],
        [ 2.9125e-02, -7.1843e-02],
        [-6.5598e-02, -1.2196e-01],
        [ 4.0274e-02, -5.2204e-02],
        [-1.0559e-01, -2.1631e-01],
        [-5.2892e-02,  1.0017e-01],
        [-1.1566e-01, -9.6052e-02],
        [ 1.6169e-01,  1.6975e-01],
        [ 5.3073e-02,  3.7945e-02],
        [ 7.8000e-02,  1.5930e-01],
        [-9.9258e-02,  2.2979e-01],
        [ 1.7214e-01,  7.8054e-02],
        [-2.8296e-02,  3.6664e-02],
        [ 3.2998e-02,  1.3506e-02],
        [ 1.2355e-01,  6.3572e-02],
        [ 8.1577e-02,  1.5658e-01],
        [ 2.7762e-01,  2.0283e-01],
        [-4.3467e-02, -1.4826e-01],
        [ 5.4561e-02,  8.4657e-02],
        [-8.2769e-02,  6.222

In [None]:
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=torch.arange(100), y=model.noise_scheduler.alphas_cumprod[:])
    )

In [None]:
new_samples

array([[  39.24051 ,  -13.038661],
       [   6.459928, -271.2436  ],
       [-136.61372 ,  196.4226  ],
       ...,
       [-130.76932 , -124.14439 ],
       [ -94.151405, -132.7102  ],
       [  24.305527,   23.280384]], dtype=float32)

In [None]:
from torch.nn import functional as F


class TinyNoiseScheduler:
    def __init__(
        self,
        num_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
    ):
        self.num_timesteps = num_timesteps
        if beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_timesteps, dtype=torch.float32
            )
        elif beta_schedule == "quadratic":
            self.betas = (
                torch.linspace(
                    beta_start**0.5, beta_end**0.5, num_timesteps, dtype=torch.float32
                )
                ** 2
            )

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # required for self.add_noise
        self.sqrt_alphas_cumprod = self.alphas_cumprod**0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5

        # required for reconstruct_x0
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(1 / self.alphas_cumprod - 1)

        # required for q_posterior
        self.posterior_mean_coef1 = (
            self.betas
            * torch.sqrt(self.alphas_cumprod_prev)
            / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )


tns = TinyNoiseScheduler()
(tns.alphas_cumprod - model.noise_scheduler.alphas_cumprod).abs().sum()

tensor(0.)