In [208]:
from math import ceil

import numpy as np
import plotly.express as px
import torch
from plotly import graph_objects as go

from mcmc_experiments.distributions import GaussianMixtureModel
from mcmc_experiments.vae import VAEConfig, VAE

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [209]:
def plot_density_surface(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Create a surface plot of a 2D density function using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = data_dist.likelihood(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create surface plot
    fig = go.Figure(
        data=[go.Surface(x=X, y=Y, z=Z, colorscale="Viridis", showscale=True)]
    )

    # Update layout
    fig.update_layout(
        title="Density Function",
        scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Likelihood"),
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_contour(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Create a contour plot of a 2D density function using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = data_dist.likelihood(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create contour plot
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale="Viridis",
                contours=dict(
                    showlabels=True,  # show labels on contours
                    labelfont=dict(size=12, color="white"),
                ),
                colorbar=dict(title="Likelihood", titleside="right"),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        title="Density Function",
        xaxis_title="X",
        yaxis_title="Y",
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_heatmap(density_fn, x_range=(-2, 2), y_range=(-2, 2), n_points=100):
    """
    Create a single-color density heatmap using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = density_fn(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create heatmap
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale=[
                    [0, "rgba(255,255,255,0.8)"],
                    [1.0, "rgba(0,0,255,1.0)"],
                ],
                showscale=False,  # hide colorbar
                contours=dict(
                    coloring="fill",
                    showlines=False,  # no contour lines
                    size=0.025,
                    start=0.0,
                    end=1.0,
                ),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        width=800,
        height=800,
        autosize=False,
        plot_bgcolor="white",
        paper_bgcolor="white",
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False, title="X"),
        yaxis=dict(showgrid=False, zeroline=False, title="Y"),
    )

    return fig

In [285]:
# Create Gaussian mixture model for sampling etc.
data_dist = GaussianMixtureModel(
    torch.tensor(
        [
            [0.0, 0.0],
            [5.0, 5.0],
        ]
    ),
    torch.stack([0.1 * torch.eye(2), 0.1 * torch.eye(2)]),
    torch.tensor([0.5, 0.5]),
)

# Draw samples from distribution.
samples = data_dist.sample(200)

# Plot distributions and the samples.
fig = plot_density_heatmap(data_dist.likelihood)
fig.add_trace(go.Scatter(x=samples[:, 0], y=samples[:, 1], mode="markers"))
fig.show()

In [244]:
def train(model, data, val_data, num_epochs, batch_size, learning_rate=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    n_samples = len(data)

    for epoch in range(num_epochs):
        # Shuffle indices
        indices = np.random.permutation(n_samples)
        data = data[indices]

        total_loss = 0
        num_batches = ceil(n_samples / batch_size)

        for i in range(num_batches):
            # Get batch
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            x_batch = data[start_idx:end_idx]

            loss = model.loss(x_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * len(x_batch)

        avg_loss = total_loss / n_samples
        val_loss = evaluate(model, val_data, batch_size)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}, Val: {val_loss:.4f}"
        )


def evaluate(model, val_data, batch_size):
    model.eval()
    total_loss = 0
    n_samples = len(val_data)
    num_batches = ceil(n_samples / batch_size)

    with torch.no_grad():
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            x_batch = val_data[start_idx:end_idx]

            loss = model.loss(x_batch)
            total_loss += loss.item() * len(x_batch)

    return total_loss / n_samples

In [286]:
# Split data.
train_frac = 0.8
num_train = int(train_frac * len(samples))
train_data = samples[:num_train]
val_data = samples[num_train:]

num_epochs = 100
batch_size = 32

config = VAEConfig(
    input_dim=2, latent_dim=1, encoder_layers=(32, 32), decoder_layers=(32, 32)
)

model = VAE(config, train_data.mean(0), train_data.var(0))

train(model, train_data, val_data, num_epochs, batch_size)

Epoch [1/100] Loss: 2.9496, Val: 2.8856
Epoch [2/100] Loss: 2.8772, Val: 2.8586
Epoch [3/100] Loss: 2.8456, Val: 2.8573
Epoch [4/100] Loss: 2.8342, Val: 2.8165
Epoch [5/100] Loss: 2.8366, Val: 2.8397
Epoch [6/100] Loss: 2.8259, Val: 2.8093
Epoch [7/100] Loss: 2.8388, Val: 2.8199
Epoch [8/100] Loss: 2.8364, Val: 2.8275
Epoch [9/100] Loss: 2.8265, Val: 2.8564
Epoch [10/100] Loss: 2.8210, Val: 2.8191
Epoch [11/100] Loss: 2.8292, Val: 2.7873
Epoch [12/100] Loss: 2.8137, Val: 2.8138
Epoch [13/100] Loss: 2.8243, Val: 2.8668
Epoch [14/100] Loss: 2.8257, Val: 2.7753
Epoch [15/100] Loss: 2.8052, Val: 2.7602
Epoch [16/100] Loss: 2.7879, Val: 2.7238
Epoch [17/100] Loss: 2.7928, Val: 2.7683
Epoch [18/100] Loss: 2.7962, Val: 2.8591
Epoch [19/100] Loss: 2.7990, Val: 2.8031
Epoch [20/100] Loss: 2.8117, Val: 2.7937
Epoch [21/100] Loss: 2.7764, Val: 2.7487
Epoch [22/100] Loss: 2.6899, Val: 2.7239
Epoch [23/100] Loss: 2.7608, Val: 2.7879
Epoch [24/100] Loss: 2.6154, Val: 2.3603
Epoch [25/100] Loss: 2.58

In [287]:
# Define density function of VAE.
def create_vae_density(model: VAE, num_latent_samples: int = 10000) -> torch.Tensor:
    z = torch.randn(num_latent_samples, model.latent_dim)
    mean, var = model.decode(z, unnormalize=True)
    weights = (1 / num_latent_samples) * torch.ones(num_latent_samples)
    return GaussianMixtureModel(mean, torch.diag_embed(var), weights)


posterior_dist = create_vae_density(model)
fig = plot_density_heatmap(posterior_dist.likelihood)
# samples_normalized = model.output_normalizer.normalize(samples)
# fig.add_trace(
#     go.Scatter(x=samples_normalized[:, 0], y=samples_normalized[:, 1], mode="markers")
# )
fig.add_trace(
    go.Scatter(
        x=train_data[:, 0],
        y=train_data[:, 1],
        mode="markers",
        marker=dict(color="Orange"),
    )
)
fig.add_trace(
    go.Scatter(
        x=val_data[:, 0], y=val_data[:, 1], mode="markers", marker=dict(color="Purple")
    )
)
fig.show()

In [288]:
posterior_dist.likelihood(val_data)

tensor([0.3284, 0.3190, 0.1917, 0.5119, 0.2643, 0.2934, 0.1485, 0.5106, 0.2776,
        0.1622, 0.4677, 0.2591, 0.2519, 0.4698, 0.1605, 0.0682, 0.0753, 0.1702,
        0.1419, 0.1205, 0.2871, 0.3083, 0.1156, 0.3982, 0.1947, 0.3767, 0.1633,
        0.3194, 0.3326, 0.2117, 0.2418, 0.0831, 0.3608, 0.0058, 0.4869, 0.0913,
        0.2099, 0.3276, 0.4874, 0.3486], grad_fn=<ExpBackward0>)

In [289]:
z = torch.randn(100, model.latent_dim)
mean, var = model.decode(z, unnormalize=False)
mean

samples_normalized = model.output_normalizer.normalize(val_data)
z_mean, z_var = model.encode(samples_normalized)
x_mean, x_var = model.decode(z_mean, unnormalize=False)
weights = 1 / len(x_mean) * torch.ones(len(x_mean))
emp_dist = GaussianMixtureModel(x_mean, torch.diag_embed(x_var), weights)
fig = plot_density_heatmap(emp_dist.likelihood)
fig.add_trace(
    go.Scatter(x=samples_normalized[:, 0], y=samples_normalized[:, 1], mode="markers")
)
fig.show()

### DDPM

In [290]:
from mcmc_experiments.diffusion.ddpm import DDPMConfig, DDPM

config = DDPMConfig(2, num_denoising_steps=100, hidden_dims=(32, 32))
model = DDPM(config, train_data[0].shape, train_data.mean(0), train_data.var(0))

num_epochs = 2500
batch_size = 32
train(model, train_data, val_data, num_epochs, 32, 2e-4)

Epoch [1/2500] Loss: 1.1210, Val: 0.9156
Epoch [2/2500] Loss: 1.0291, Val: 1.0262
Epoch [3/2500] Loss: 1.0633, Val: 1.1967
Epoch [4/2500] Loss: 1.1815, Val: 0.9047
Epoch [5/2500] Loss: 0.9657, Val: 0.9194
Epoch [6/2500] Loss: 0.9384, Val: 0.7672
Epoch [7/2500] Loss: 0.9710, Val: 1.1727
Epoch [8/2500] Loss: 1.0652, Val: 0.8516
Epoch [9/2500] Loss: 0.9374, Val: 0.9857
Epoch [10/2500] Loss: 0.9516, Val: 0.6934
Epoch [11/2500] Loss: 0.8648, Val: 0.9314
Epoch [12/2500] Loss: 1.1169, Val: 1.0037
Epoch [13/2500] Loss: 1.0813, Val: 1.0909
Epoch [14/2500] Loss: 1.1588, Val: 1.2651
Epoch [15/2500] Loss: 0.9533, Val: 0.8029
Epoch [16/2500] Loss: 1.0199, Val: 0.9956
Epoch [17/2500] Loss: 1.0800, Val: 1.2479
Epoch [18/2500] Loss: 0.9056, Val: 1.0992
Epoch [19/2500] Loss: 0.9340, Val: 0.9603
Epoch [20/2500] Loss: 0.9168, Val: 1.0317
Epoch [21/2500] Loss: 1.1253, Val: 0.7653
Epoch [22/2500] Loss: 0.9990, Val: 0.9191
Epoch [23/2500] Loss: 0.8895, Val: 0.8517
Epoch [24/2500] Loss: 0.9293, Val: 0.9434
E

In [292]:
new_samples = model.sample(500).detach().numpy()
fig = go.Figure()
fig.add_trace(go.Scatter(x=new_samples[:, 0], y=new_samples[:, 1], mode="markers"))
fig.add_trace(go.Scatter(x=train_data[:, 0], y=train_data[:, 1], mode="markers"))

# fig.update_xaxes(range=(-1, 3))
# fig.update_yaxes(range=(-1, 3))
fig.update_layout(width=800, height=800)
fig.show()

torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([500, 2])
torch.Size([500, 2]) torch.Size([5

In [None]:
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=torch.arange(100), y=model.noise_scheduler.alphas_cumprod[:])
)

In [258]:
new_samples

array([[  39.24051 ,  -13.038661],
       [   6.459928, -271.2436  ],
       [-136.61372 ,  196.4226  ],
       ...,
       [-130.76932 , -124.14439 ],
       [ -94.151405, -132.7102  ],
       [  24.305527,   23.280384]], dtype=float32)

In [None]:
from torch.nn import functional as F


class TinyNoiseScheduler:
    def __init__(
        self,
        num_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
    ):
        self.num_timesteps = num_timesteps
        if beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_timesteps, dtype=torch.float32
            )
        elif beta_schedule == "quadratic":
            self.betas = (
                torch.linspace(
                    beta_start**0.5, beta_end**0.5, num_timesteps, dtype=torch.float32
                )
                ** 2
            )

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # required for self.add_noise
        self.sqrt_alphas_cumprod = self.alphas_cumprod**0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5

        # required for reconstruct_x0
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(1 / self.alphas_cumprod - 1)

        # required for q_posterior
        self.posterior_mean_coef1 = (
            self.betas
            * torch.sqrt(self.alphas_cumprod_prev)
            / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )


tns = TinyNoiseScheduler()
(tns.alphas_cumprod - model.noise_scheduler.alphas_cumprod).abs().sum()

tensor(0.)