In [208]:
from math import ceil

import numpy as np
import plotly.express as px
import torch
from plotly import graph_objects as go

from mcmc_experiments.distributions import GaussianMixtureModel
from mcmc_experiments.vae import VAEConfig, VAE

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [209]:
def plot_density_surface(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Render a 2D density as an interactive 3D surface with Plotly.

    The density is evaluated on a regular ``n_points`` x ``n_points`` grid
    spanning ``x_range`` and ``y_range``.

    Parameters:
    -----------
    data_dist : object
        Object exposing ``likelihood(points)``. It is called once with a
        torch.Tensor of shape (n_points**2, 2) and its result is reshaped to
        (n_points, n_points).
    x_range : tuple
        (min, max) extent of the grid along x
    y_range : tuple
        (min, max) extent of the grid along y
    n_points : int
        Grid resolution along each axis

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Axis ticks and the full coordinate mesh.
    x_ticks = np.linspace(x_range[0], x_range[1], n_points)
    y_ticks = np.linspace(y_range[0], y_range[1], n_points)
    grid_x, grid_y = np.meshgrid(x_ticks, y_ticks)

    # One (x, y) row per grid node, flattened in the mesh's row-major order.
    grid_points = torch.tensor(np.column_stack((grid_x.ravel(), grid_y.ravel())))

    # No gradients are needed for plotting.
    with torch.no_grad():
        density = data_dist.likelihood(grid_points).numpy()
    density = density.reshape(n_points, n_points)

    surface = go.Surface(
        x=grid_x, y=grid_y, z=density, colorscale="Viridis", showscale=True
    )
    fig = go.Figure(data=[surface])

    fig.update_layout(
        title="Density Function",
        scene={"xaxis_title": "X", "yaxis_title": "Y", "zaxis_title": "Likelihood"},
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_contour(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Create a labelled contour plot of a 2D density function using Plotly.

    The density is evaluated on a regular ``n_points`` x ``n_points`` grid
    spanning ``x_range`` and ``y_range``.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to a torch tensor with one (x, y) row per grid node
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point; gradients are not needed for plotting
    with torch.no_grad():
        Z = data_dist.likelihood(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create contour plot
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale="Viridis",
                contours=dict(
                    showlabels=True,  # show labels on contours
                    labelfont=dict(size=12, color="white"),
                ),
                # Use the nested title form: the flat `titleside` attribute is
                # deprecated and not accepted by newer Plotly releases.
                colorbar=dict(title=dict(text="Likelihood", side="right")),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        title="Density Function",
        xaxis_title="X",
        yaxis_title="Y",
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_heatmap(
    density_fn,
    x_range=(-2, 2),
    y_range=(-2, 2),
    n_points=100,
    level_start=0.0,
    level_end=1.0,
    level_size=0.025,
):
    """
    Create a single-color (white-to-blue) filled density map using Plotly.

    The density is evaluated on a regular ``n_points`` x ``n_points`` grid
    spanning ``x_range`` and ``y_range`` and drawn as filled contours without
    contour lines or a colorbar.

    Parameters:
    -----------
    density_fn : callable
        Function that takes a torch.Tensor of shape (N, 2) and returns one
        density value per point (e.g. the bound ``likelihood`` method of a
        distribution object)
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid
    level_start : float
        Density value of the first fill level
    level_end : float
        Density value of the last fill level. Densities are not bounded by 1,
        so raise this for sharply peaked distributions.
    level_size : float
        Step between consecutive fill levels

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to a torch tensor with one (x, y) row per grid node
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate the density at each point; gradients are not needed for plotting
    with torch.no_grad():
        Z = density_fn(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create the filled-contour "heatmap"
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale=[
                    [0, "rgba(255,255,255,0.8)"],
                    [1.0, "rgba(0,0,255,1.0)"],
                ],
                showscale=False,  # hide colorbar
                contours=dict(
                    coloring="fill",
                    showlines=False,  # no contour lines
                    size=level_size,
                    start=level_start,
                    end=level_end,
                ),
            )
        ]
    )

    # Update layout: plain white canvas without grid lines
    fig.update_layout(
        width=800,
        height=800,
        autosize=False,
        plot_bgcolor="white",
        paper_bgcolor="white",
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False, title="X"),
        yaxis=dict(showgrid=False, zeroline=False, title="Y"),
    )

    return fig

In [243]:
# Target distribution: a Gaussian mixture built from two centers, two
# 0.1 * identity matrices and equal weights (presumably means, covariances
# and mixture weights, in that order -- confirm against GaussianMixtureModel).
component_means = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
component_covs = torch.stack([0.1 * torch.eye(2), 0.1 * torch.eye(2)])
mixture_weights = torch.tensor([0.5, 0.5])
data_dist = GaussianMixtureModel(component_means, component_covs, mixture_weights)

# Dataset used by every experiment below.
samples = data_dist.sample(200)

# Overlay the drawn samples on the true density.
fig = plot_density_heatmap(data_dist.likelihood)
sample_trace = go.Scatter(x=samples[:, 0], y=samples[:, 1], mode="markers")
fig.add_trace(sample_trace)
fig.show()

In [244]:
def train(model, data, val_data, num_epochs, batch_size, learning_rate=1e-3):
    """
    Train ``model`` with Adam on mini-batches of ``data``.

    After every epoch the sample-weighted average training loss and the
    validation loss (from ``evaluate``) are printed.

    Parameters:
    -----------
    model : object
        Model exposing ``parameters()``, ``train()`` and ``loss(x_batch)``;
        ``loss`` must return a scalar tensor that supports ``backward()``
    data : indexable
        Training samples; must support ``len()``, slicing and indexing with
        an integer index array
    val_data : indexable
        Validation samples passed to ``evaluate`` after each epoch
    num_epochs : int
        Number of passes over ``data``
    batch_size : int
        Maximum number of samples per mini-batch (the last batch may be smaller)
    learning_rate : float
        Adam learning rate

    Returns:
    --------
    None
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    n_samples = len(data)
    num_batches = ceil(n_samples / batch_size)

    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode at the end of each epoch;
        # switch back so every epoch (not only the first) trains in training mode.
        model.train()

        # Shuffle indices
        indices = np.random.permutation(n_samples)
        data = data[indices]

        total_loss = 0

        for i in range(num_batches):
            # Get batch
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            x_batch = data[start_idx:end_idx]

            loss = model.loss(x_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch length so a short final batch is averaged correctly
            total_loss += loss.item() * len(x_batch)

        avg_loss = total_loss / n_samples
        val_loss = evaluate(model, val_data, batch_size)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}, Val: {val_loss:.4f}"
        )


def evaluate(model, val_data, batch_size):
    """
    Compute the sample-weighted mean of ``model.loss`` over ``val_data``.

    The model is switched to eval mode (and left there) and the loss is
    computed batch by batch with gradients disabled.

    Parameters:
    -----------
    model : object
        Model exposing ``eval()`` and ``loss(x_batch)`` returning a scalar tensor
    val_data : indexable
        Validation samples; must support ``len()`` and slicing
    batch_size : int
        Maximum number of samples per batch (the last batch may be smaller)

    Returns:
    --------
    float
        Average loss per validation sample
    """
    model.eval()
    n_samples = len(val_data)
    num_batches = ceil(n_samples / batch_size)

    def _weighted_batch_loss(batch_index):
        # Loss of one batch, scaled by its length so short batches count less.
        offset = batch_index * batch_size
        x_batch = val_data[offset : offset + batch_size]
        return model.loss(x_batch).item() * len(x_batch)

    with torch.no_grad():
        total_loss = sum(_weighted_batch_loss(k) for k in range(num_batches))

    return total_loss / n_samples

In [245]:
# Hold out the last 20% of the drawn samples for validation.
train_frac = 0.8
num_train = int(train_frac * len(samples))
train_data, val_data = samples[:num_train], samples[num_train:]

# Optimisation settings.
num_epochs, batch_size = 100, 32

# VAE with a 1-D latent space and two hidden layers of width 32 on each side.
config = VAEConfig(
    input_dim=2,
    latent_dim=1,
    encoder_layers=(32, 32),
    decoder_layers=(32, 32),
)

# The per-feature mean and variance of the training data are handed to the
# model (presumably for input/output normalisation -- confirm in VAE).
model = VAE(config, train_data.mean(0), train_data.var(0))

train(model, train_data, val_data, num_epochs, batch_size)

Epoch [1/100] Loss: 2.8730, Val: 2.9717
Epoch [2/100] Loss: 2.8576, Val: 2.9330
Epoch [3/100] Loss: 2.8441, Val: 2.9149
Epoch [4/100] Loss: 2.8327, Val: 2.8965
Epoch [5/100] Loss: 2.8296, Val: 2.9129
Epoch [6/100] Loss: 2.8291, Val: 2.9078
Epoch [7/100] Loss: 2.8426, Val: 2.8756
Epoch [8/100] Loss: 2.8343, Val: 2.8924
Epoch [9/100] Loss: 2.8301, Val: 2.8995
Epoch [10/100] Loss: 2.8357, Val: 2.8857
Epoch [11/100] Loss: 2.8334, Val: 2.8966
Epoch [12/100] Loss: 2.8343, Val: 2.8869
Epoch [13/100] Loss: 2.8289, Val: 2.8852
Epoch [14/100] Loss: 2.8223, Val: 2.9042
Epoch [15/100] Loss: 2.8341, Val: 2.8853
Epoch [16/100] Loss: 2.8325, Val: 2.8637
Epoch [17/100] Loss: 2.8340, Val: 2.9021
Epoch [18/100] Loss: 2.8318, Val: 2.8903
Epoch [19/100] Loss: 2.8350, Val: 2.8655
Epoch [20/100] Loss: 2.8321, Val: 2.9032
Epoch [21/100] Loss: 2.8345, Val: 2.9039
Epoch [22/100] Loss: 2.8328, Val: 2.8832
Epoch [23/100] Loss: 2.8346, Val: 2.9063
Epoch [24/100] Loss: 2.8384, Val: 2.9023
Epoch [25/100] Loss: 2.83

In [246]:
# Define density function of VAE.
def create_vae_density(
    model: VAE, num_latent_samples: int = 10000
) -> "GaussianMixtureModel":
    """
    Approximate the VAE's data density by an equally weighted Gaussian mixture.

    ``num_latent_samples`` latent codes are drawn from a standard normal and
    decoded; each decoded (mean, variance) pair becomes one mixture component
    with a diagonal covariance and weight ``1 / num_latent_samples``.

    Parameters:
    -----------
    model : VAE
        Model exposing ``latent_dim`` and ``decode(z, unnormalize=True)``
        returning a ``(mean, var)`` pair
    num_latent_samples : int
        Number of latent draws, i.e. number of mixture components

    Returns:
    --------
    GaussianMixtureModel
        Mixture whose ``likelihood`` can be evaluated or plotted. Its
        parameters are detached from the autograd graph.
    """
    # The mixture is only used for evaluation and plotting, so do not record
    # an autograd graph through the decoder for every component.
    with torch.no_grad():
        z = torch.randn(num_latent_samples, model.latent_dim)
        mean, var = model.decode(z, unnormalize=True)

    weights = torch.full((num_latent_samples,), 1.0 / num_latent_samples)
    # diag_embed turns each variance vector into a diagonal covariance matrix.
    return GaussianMixtureModel(mean, torch.diag_embed(var), weights)


# Density implied by the trained VAE, with the training (orange) and
# validation (purple) samples drawn on top.
posterior_dist = create_vae_density(model)
fig = plot_density_heatmap(posterior_dist.likelihood)

for split_points, marker_color in ((train_data, "Orange"), (val_data, "Purple")):
    fig.add_trace(
        go.Scatter(
            x=split_points[:, 0],
            y=split_points[:, 1],
            mode="markers",
            marker=dict(color=marker_color),
        )
    )

fig.show()

In [None]:
# Evaluate the VAE-derived mixture density at each held-out validation sample.
posterior_dist.likelihood(val_data)

tensor([0.2097, 0.4269, 0.1926, 0.2246, 0.0760, 0.1167, 0.0217, 0.3644, 0.1318,
        0.1708, 0.0904, 0.0929, 0.0834, 0.0612, 0.3517, 0.0429, 0.1105, 0.2997,
        0.0388, 0.4355], grad_fn=<ExpBackward0>)

In [247]:
# NOTE(review): z / mean / var look like leftover scratch work -- nothing in
# this notebook reads them afterwards. Confirm and delete.
z = torch.randn(100, model.latent_dim)
mean, var = model.decode(z, unnormalize=False)

# Reconstruction density in normalized space: encode the validation points,
# decode their latent means, and treat every decoded (mean, variance) pair as
# one equally weighted mixture component.
samples_normalized = model.output_normalizer.normalize(val_data)
z_mean, z_var = model.encode(samples_normalized)
x_mean, x_var = model.decode(z_mean, unnormalize=False)

num_components = len(x_mean)
weights = 1 / num_components * torch.ones(num_components)
emp_dist = GaussianMixtureModel(x_mean, torch.diag_embed(x_var), weights)

fig = plot_density_heatmap(emp_dist.likelihood)
normalized_trace = go.Scatter(
    x=samples_normalized[:, 0], y=samples_normalized[:, 1], mode="markers"
)
fig.add_trace(normalized_trace)
fig.show()

### DDPM

In [250]:
from mcmc_experiments.diffusion.ddpm import DDPMConfig, DDPM

# DDPM with 1000 denoising steps and five hidden layers of width 32.
# NOTE: `config` and `model` deliberately rebind the VAE names used above;
# the cells below operate on the DDPM.
config = DDPMConfig(2, num_denoising_steps=1000, hidden_dims=(32, 32, 32, 32, 32))
model = DDPM(config, train_data[0].shape, train_data.mean(0), train_data.var(0))

num_epochs = 10000
batch_size = 32
learning_rate = 2e-4
# Pass the declared settings instead of repeating the literals, so editing
# batch_size above actually changes the run.
train(model, train_data, val_data, num_epochs, batch_size, learning_rate)

Epoch [1/10000] Loss: 1.0572, Val: 0.8510
Epoch [2/10000] Loss: 0.9143, Val: 0.8527
Epoch [3/10000] Loss: 0.9625, Val: 1.1098
Epoch [4/10000] Loss: 1.0690, Val: 1.1719
Epoch [5/10000] Loss: 1.0285, Val: 0.9678
Epoch [6/10000] Loss: 1.0479, Val: 1.0440
Epoch [7/10000] Loss: 0.8828, Val: 0.9441
Epoch [8/10000] Loss: 1.1165, Val: 1.0772
Epoch [9/10000] Loss: 0.8188, Val: 0.9278
Epoch [10/10000] Loss: 0.9428, Val: 0.9002
Epoch [11/10000] Loss: 0.9518, Val: 1.1125
Epoch [12/10000] Loss: 0.8890, Val: 1.0765
Epoch [13/10000] Loss: 0.9557, Val: 1.0683
Epoch [14/10000] Loss: 0.9540, Val: 0.8935
Epoch [15/10000] Loss: 1.0193, Val: 1.0480
Epoch [16/10000] Loss: 0.9870, Val: 0.7140
Epoch [17/10000] Loss: 0.9728, Val: 0.8959
Epoch [18/10000] Loss: 0.8928, Val: 0.8132
Epoch [19/10000] Loss: 0.9956, Val: 0.8694
Epoch [20/10000] Loss: 0.7708, Val: 0.6916
Epoch [21/10000] Loss: 0.9932, Val: 0.9557
Epoch [22/10000] Loss: 0.9293, Val: 1.0798
Epoch [23/10000] Loss: 0.9175, Val: 1.1741
Epoch [24/10000] Los

In [253]:
# Draw fresh samples from the trained DDPM and compare them with the
# training data in a single scatter plot.
num_generated = 1000
new_samples = model.sample(num_generated).detach().numpy()

fig = go.Figure()
for scatter_points in (new_samples, train_data):
    fig.add_trace(
        go.Scatter(x=scatter_points[:, 0], y=scatter_points[:, 1], mode="markers")
    )

# Same fixed window on both axes.
axis_limits = (-1, 3)
fig.update_xaxes(range=axis_limits)
fig.update_yaxes(range=axis_limits)
fig.update_layout(width=800, height=800)
fig.show()

torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size([1000, 2])
torch.Size([1000, 2]) torch.Size

In [None]:
# Plot the cumulative product of alphas over every diffusion step. The x axis
# is derived from the tensor's length so it always covers the whole schedule
# (a hard-coded arange(100) did not match the 1000-step schedule).
alphas_cumprod = model.noise_scheduler.alphas_cumprod
fig = go.Figure()
fig.add_trace(go.Scatter(x=torch.arange(len(alphas_cumprod)), y=alphas_cumprod))

In [None]:
# Display the `alphas` tensor held by the DDPM's noise scheduler.
model.noise_scheduler.alphas

tensor([0.9999, 0.9999, 0.9999, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9997,
        0.9997, 0.9997, 0.9997, 0.9997, 0.9996, 0.9996, 0.9996, 0.9996, 0.9996,
        0.9995, 0.9995, 0.9995, 0.9995, 0.9995, 0.9994, 0.9994, 0.9994, 0.9994,
        0.9994, 0.9993, 0.9993, 0.9993, 0.9993, 0.9993, 0.9992, 0.9992, 0.9992,
        0.9992, 0.9992, 0.9991, 0.9991, 0.9991, 0.9991, 0.9991, 0.9990, 0.9990,
        0.9990, 0.9990, 0.9990, 0.9989, 0.9989, 0.9989, 0.9989, 0.9989, 0.9988,
        0.9988, 0.9988, 0.9988, 0.9988, 0.9987, 0.9987, 0.9987, 0.9987, 0.9987,
        0.9986, 0.9986, 0.9986, 0.9986, 0.9986, 0.9985, 0.9985, 0.9985, 0.9985,
        0.9985, 0.9984, 0.9984, 0.9984, 0.9984, 0.9984, 0.9983, 0.9983, 0.9983,
        0.9983, 0.9983, 0.9982, 0.9982, 0.9982, 0.9982, 0.9982, 0.9981, 0.9981,
        0.9981, 0.9981, 0.9981, 0.9980, 0.9980, 0.9980, 0.9980, 0.9980, 0.9979,
        0.9979, 0.9979, 0.9979, 0.9979, 0.9978, 0.9978, 0.9978, 0.9978, 0.9978,
        0.9977, 0.9977, 0.9977, 0.9977, 

In [None]:
from torch.nn import functional as F


class TinyNoiseScheduler:
    """
    Precomputes the per-timestep constants of a DDPM noise schedule.

    Parameters:
    -----------
    num_timesteps : int
        Number of diffusion steps; every tensor attribute has this length
    beta_start : float
        Beta value at the first timestep
    beta_end : float
        Beta value at the last timestep
    beta_schedule : str
        "linear" interpolates beta linearly between ``beta_start`` and
        ``beta_end``; "quadratic" interpolates sqrt(beta) linearly and
        squares the result

    Raises:
    -------
    ValueError
        If ``beta_schedule`` is neither "linear" nor "quadratic"
    """

    def __init__(
        self,
        num_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
    ):
        self.num_timesteps = num_timesteps
        if beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_timesteps, dtype=torch.float32
            )
        elif beta_schedule == "quadratic":
            self.betas = (
                torch.linspace(
                    beta_start**0.5, beta_end**0.5, num_timesteps, dtype=torch.float32
                )
                ** 2
            )
        else:
            # Fail here with a clear message instead of an AttributeError on
            # self.betas a few lines further down.
            raise ValueError(
                f"Unsupported beta_schedule {beta_schedule!r}; "
                "expected 'linear' or 'quadratic'."
            )

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative product shifted right by one step, with 1.0 at step 0.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # required for adding noise (forward process)
        self.sqrt_alphas_cumprod = self.alphas_cumprod**0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5

        # required for reconstructing x0 from a noisy sample
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(1 / self.alphas_cumprod - 1)

        # required for the posterior mean:
        #   coef1 = beta_t * sqrt(alphas_cumprod_{t-1}) / (1 - alphas_cumprod_t)
        #   coef2 = (1 - alphas_cumprod_{t-1}) * sqrt(alpha_t) / (1 - alphas_cumprod_t)
        self.posterior_mean_coef1 = (
            self.betas
            * torch.sqrt(self.alphas_cumprod_prev)
            / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )


# Sanity check: the reference scheduler with default settings should
# reproduce the DDPM's cumulative alphas (total absolute difference of zero).
tns = TinyNoiseScheduler()
torch.sum(torch.abs(tns.alphas_cumprod - model.noise_scheduler.alphas_cumprod))

tensor(0.)