In [1]:
from math import ceil

import numpy as np
import plotly.express as px
import torch
from plotly import graph_objects as go

from mcmc_experiments.distributions import GaussianMixtureModel
from mcmc_experiments.vae import VAEConfig, VAE, elbo_loss_vae

%load_ext autoreload
%autoreload 2

In [55]:
def plot_density_surface(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Create a surface plot of a 2D density function using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = data_dist.likelihood(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create surface plot
    fig = go.Figure(
        data=[go.Surface(x=X, y=Y, z=Z, colorscale="Viridis", showscale=True)]
    )

    # Update layout
    fig.update_layout(
        title="Density Function",
        scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Likelihood"),
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_contour(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Create a contour plot of a 2D density function using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = data_dist.likelihood(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create contour plot
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale="Viridis",
                contours=dict(
                    showlabels=True,  # show labels on contours
                    labelfont=dict(size=12, color="white"),
                ),
                colorbar=dict(title="Likelihood", titleside="right"),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        title="Density Function",
        xaxis_title="X",
        yaxis_title="Y",
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_heatmap(density_fn, x_range=(-2, 2), y_range=(-2, 2), n_points=100):
    """
    Create a single-color density heatmap using Plotly.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point
    with torch.no_grad():
        Z = density_fn(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create heatmap
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale=[
                    [0, "rgba(255,255,255,0.8)"],
                    [1.0, "rgba(0,0,255,1.0)"],
                ],
                showscale=False,  # hide colorbar
                contours=dict(
                    coloring="fill",
                    showlines=False,  # no contour lines
                    size=0.025,
                    start=0.0,
                    end=1.0,
                ),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        width=800,
        height=800,
        autosize=False,
        plot_bgcolor="white",
        paper_bgcolor="white",
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False, title="X"),
        yaxis=dict(showgrid=False, zeroline=False, title="Y"),
    )

    return fig

In [56]:
# Create Gaussian mixture model for sampling etc.
data_dist = GaussianMixtureModel(
    torch.tensor(
        [
            [0.0, 0.0],
            [1.0, 1.0],
        ]
    ),
    torch.stack([0.1 * torch.eye(2), 0.1 * torch.eye(2)]),
    torch.tensor([0.5, 0.5]),
)

# Draw samples from distribution.
samples = data_dist.sample(100)

# Plot distributions and the samples.
fig = plot_density_heatmap(data_dist.likelihood)
fig.add_trace(go.Scatter(x=samples[:, 0], y=samples[:, 1], mode="markers"))
fig.show()

In [4]:
def train_vae(model, data, val_data, num_epochs, batch_size, learning_rate=3e-4):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    n_samples = len(data)

    for epoch in range(num_epochs):
        # Shuffle indices
        indices = np.random.permutation(n_samples)
        data = data[indices]

        total_loss = 0
        num_batches = ceil(n_samples / batch_size)

        for i in range(num_batches):
            # Get batch
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            x_batch = data[start_idx:end_idx]

            loss = elbo_loss_vae(model, x_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / n_samples
        val_loss = evaluate(model, val_data, batch_size)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}, Val: {val_loss:.4f}"
        )


def evaluate(model, val_data, batch_size):
    model.eval()
    total_loss = 0
    n_samples = len(val_data)
    num_batches = ceil(n_samples / batch_size)

    with torch.no_grad():
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            x_batch = val_data[start_idx:end_idx]

            loss = elbo_loss_vae(model, x_batch)
            total_loss += loss.item()

    return total_loss / n_samples

In [51]:
# Split data.
train_frac = 0.8
num_train = int(train_frac * len(samples))
train_data = samples[:num_train]
val_data = samples[num_train:]

num_epochs = 2500
batch_size = 32

config = VAEConfig(
    input_dim=2, latent_dim=1, encoder_layers=(64, 64), decoder_layers=(64, 64)
)

model = VAE(config, train_data.mean(0), train_data.var(0))

train_vae(model, train_data, val_data, num_epochs, batch_size)

Epoch [1/2500] Loss: 0.1098, Val: 0.1539
Epoch [2/2500] Loss: 0.1099, Val: 0.1549
Epoch [3/2500] Loss: 0.1082, Val: 0.1523
Epoch [4/2500] Loss: 0.1075, Val: 0.1517
Epoch [5/2500] Loss: 0.1074, Val: 0.1503
Epoch [6/2500] Loss: 0.1073, Val: 0.1499
Epoch [7/2500] Loss: 0.1074, Val: 0.1525
Epoch [8/2500] Loss: 0.1077, Val: 0.1513
Epoch [9/2500] Loss: 0.1070, Val: 0.1506
Epoch [10/2500] Loss: 0.1056, Val: 0.1513
Epoch [11/2500] Loss: 0.1070, Val: 0.1507
Epoch [12/2500] Loss: 0.1055, Val: 0.1492
Epoch [13/2500] Loss: 0.1064, Val: 0.1505
Epoch [14/2500] Loss: 0.1060, Val: 0.1511
Epoch [15/2500] Loss: 0.1056, Val: 0.1492
Epoch [16/2500] Loss: 0.1070, Val: 0.1501
Epoch [17/2500] Loss: 0.1073, Val: 0.1505
Epoch [18/2500] Loss: 0.1070, Val: 0.1494
Epoch [19/2500] Loss: 0.1065, Val: 0.1496
Epoch [20/2500] Loss: 0.1056, Val: 0.1509
Epoch [21/2500] Loss: 0.1062, Val: 0.1503
Epoch [22/2500] Loss: 0.1063, Val: 0.1504
Epoch [23/2500] Loss: 0.1071, Val: 0.1493
Epoch [24/2500] Loss: 0.1051, Val: 0.1504
E

In [57]:
# Define density function of VAE.
def create_vae_density(model: VAE, num_latent_samples: int = 1000) -> torch.Tensor:
    z = torch.randn(num_latent_samples, model.latent_dim)
    mean, var = model.decode(z, unnormalize=False)
    weights = (1 / num_latent_samples) * torch.ones(num_latent_samples)
    return GaussianMixtureModel(mean, torch.diag_embed(var), weights)


posterior_dist = create_vae_density(model)
fig = plot_density_heatmap(posterior_dist.likelihood)
samples_normalized = model.output_normalizer.normalize(samples)
fig.add_trace(
    go.Scatter(x=samples_normalized[:, 0], y=samples_normalized[:, 1], mode="markers")
)
# fig.add_trace(
#     go.Scatter(
#         x=train_data[:, 0],
#         y=train_data[:, 1],
#         mode="markers",
#         marker=dict(color="Orange"),
#     )
# )
# fig.add_trace(
#     go.Scatter(
#         x=val_data[:, 0], y=val_data[:, 1], mode="markers", marker=dict(color="Purple")
#     )
# )
fig.show()

In [58]:
posterior_dist.likelihood(samples_normalized)

tensor([0.1311, 0.0878, 0.1201, 0.3166, 0.1485, 0.1149, 0.2198, 0.0749, 0.1386,
        0.1384, 0.1476, 0.1841, 0.2600, 0.0757, 0.2874, 0.1211, 0.0164, 0.0512,
        0.0164, 0.0409, 0.1219, 0.0693, 0.1346, 0.0608, 0.0956, 0.1271, 0.0202,
        0.0050, 0.1162, 0.0942, 0.1827, 0.1885, 0.1322, 0.1025, 0.0322, 0.1054,
        0.1088, 0.1493, 0.0383, 0.1511, 0.1245, 0.0857, 0.0044, 0.0643, 0.2286,
        0.0743, 0.1065, 0.1435, 0.2749, 0.0678, 0.0719, 0.3140, 0.2145, 0.2896,
        0.2011, 0.1797, 0.1123, 0.1032, 0.0589, 0.0892, 0.1499, 0.1787, 0.2951,
        0.1107, 0.1520, 0.1072, 0.1111, 0.0346, 0.2466, 0.0414, 0.2632, 0.2409,
        0.0542, 0.0917, 0.1224, 0.1287, 0.1592, 0.0270, 0.3550, 0.1686, 0.1280,
        0.2216, 0.2075, 0.2811, 0.1172, 0.0195, 0.0302, 0.1772, 0.1784, 0.1205,
        0.0822, 0.1011, 0.2376, 0.1207, 0.2810, 0.2078, 0.1241, 0.0755, 0.1490,
        0.0977], grad_fn=<ExpBackward0>)

In [59]:
elbo_loss_vae(model, train_data) / len(train_data)

tensor(0.0277, grad_fn=<DivBackward0>)

In [60]:
z = torch.randn(100, model.latent_dim)
mean, var = model.decode(z, unnormalize=False)
mean

z_mean, z_var = model.encode(samples_normalized)
x_mean, x_var = model.decode(z_mean, unnormalize=False)
weights = 1 / len(x_mean) * torch.ones(len(x_mean))
emp_dist = GaussianMixtureModel(x_mean, torch.diag_embed(x_var), weights)
fig = plot_density_heatmap(emp_dist.likelihood)
fig.add_trace(
    go.Scatter(x=samples_normalized[:, 0], y=samples_normalized[:, 1], mode="markers")
)
fig.show()