In [59]:
from math import ceil

import numpy as np
import plotly.express as px
import torch
from plotly import graph_objects as go

from mcmc_experiments.distributions import GaussianMixtureModel
from mcmc_experiments.vae import VAEConfig, VAE, elbo_loss_vae

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
def plot_density_surface(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Draw a 2D density as a 3D surface with Plotly.

    The density is sampled on a regular ``n_points`` x ``n_points`` grid
    covering ``x_range`` x ``y_range`` and shown as a Viridis-coloured surface.

    Parameters:
    -----------
    data_dist : object
        Must provide ``likelihood(points)``. It is called once with a
        torch.Tensor of shape (n_points**2, 2) holding the grid coordinates
        and must return one value per row.
    x_range : tuple
        (min, max) of the grid along x
    y_range : tuple
        (min, max) of the grid along y
    n_points : int
        Grid resolution along each axis

    Returns:
    --------
    plotly.graph_objects.Figure
        An 800x800 figure containing a single Surface trace.
    """
    # Regular grid; the row index follows y and the column index follows x.
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_range[0], x_range[1], n_points),
        np.linspace(y_range[0], y_range[1], n_points),
    )

    # One (x, y) row per grid node, in row-major order.
    coords = torch.tensor(np.column_stack((grid_x.ravel(), grid_y.ravel())))

    # Plotting only needs values, so skip autograd bookkeeping.
    with torch.no_grad():
        density = data_dist.likelihood(coords).numpy()
    density = density.reshape(n_points, n_points)

    surface = go.Surface(
        x=grid_x, y=grid_y, z=density, colorscale="Viridis", showscale=True
    )
    fig = go.Figure(data=[surface])

    fig.update_layout(
        title="Density Function",
        scene={"xaxis_title": "X", "yaxis_title": "Y", "zaxis_title": "Likelihood"},
        width=800,
        height=800,
        autosize=False,
    )
    return fig


def plot_density_contour(data_dist, x_range=(-1, 2), y_range=(-1, 2), n_points=100):
    """
    Create a labelled contour plot of a 2D density function using Plotly.

    The density is evaluated on a regular ``n_points`` x ``n_points`` grid
    covering ``x_range`` x ``y_range``.

    Parameters:
    -----------
    data_dist : object
        Object with a likelihood method that takes a torch.Tensor of shape (N, 2)
        and returns likelihoods for each point
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid

    Returns:
    --------
    plotly.graph_objects.Figure
        An 800x800 figure containing a single Contour trace.
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor, one (x, y) row per grid node
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate likelihood at each point (no gradients needed for plotting)
    with torch.no_grad():
        Z = data_dist.likelihood(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create contour plot
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale="Viridis",
                contours=dict(
                    showlabels=True,  # show labels on contours
                    labelfont=dict(size=12, color="white"),
                ),
                # The flat `titleside` attribute is deprecated and rejected by
                # newer Plotly releases; the nested title form works in both.
                colorbar=dict(title=dict(text="Likelihood", side="right")),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        title="Density Function",
        xaxis_title="X",
        yaxis_title="Y",
        width=800,
        height=800,
        autosize=False,
    )

    return fig


def plot_density_heatmap(
    density_fn,
    x_range=(-2, 2),
    y_range=(-2, 2),
    n_points=100,
    contour_start=0.0,
    contour_end=1.0,
    contour_size=0.025,
):
    """
    Create a single-color density heatmap using Plotly.

    The density is evaluated on a regular ``n_points`` x ``n_points`` grid and
    drawn as filled contours fading from translucent white to blue.

    Parameters:
    -----------
    density_fn : callable
        Function that takes a torch.Tensor of shape (N, 2) and returns the
        density value for each point (e.g. the bound ``likelihood`` method of
        a distribution object)
    x_range : tuple
        (min, max) for x axis
    y_range : tuple
        (min, max) for y axis
    n_points : int
        Number of points along each axis for the grid
    contour_start : float
        Density value mapped to the first contour level
    contour_end : float
        Density value mapped to the last contour level; raise this when the
        density exceeds 1, otherwise the peaks all get the darkest color
    contour_size : float
        Step between consecutive contour levels

    Returns:
    --------
    plotly.graph_objects.Figure
        An 800x800 figure containing a single filled Contour trace.
    """
    # Create grid of points
    x = np.linspace(x_range[0], x_range[1], n_points)
    y = np.linspace(y_range[0], y_range[1], n_points)
    X, Y = np.meshgrid(x, y)

    # Convert grid points to torch tensor, one (x, y) row per grid node
    points = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1))

    # Evaluate density at each point (no gradients needed for plotting)
    with torch.no_grad():
        Z = density_fn(points).numpy()
    Z = Z.reshape(n_points, n_points)

    # Create heatmap
    fig = go.Figure(
        data=[
            go.Contour(
                x=x,
                y=y,
                z=Z,
                colorscale=[
                    [0, "rgba(255,255,255,0.8)"],
                    [1.0, "rgba(0,0,255,1.0)"],
                ],
                showscale=False,  # hide colorbar
                contours=dict(
                    coloring="fill",
                    showlines=False,  # no contour lines
                    size=contour_size,
                    start=contour_start,
                    end=contour_end,
                ),
            )
        ]
    )

    # Update layout
    fig.update_layout(
        width=800,
        height=800,
        autosize=False,
        plot_bgcolor="white",
        paper_bgcolor="white",
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False, title="X"),
        yaxis=dict(showgrid=False, zeroline=False, title="Y"),
    )

    return fig

In [61]:
# Seed the global NumPy and PyTorch RNGs so that "Restart Kernel -> Run All"
# reproduces the same draws. Both are seeded because train_vae shuffles with
# np.random and the sampler's RNG is not visible from this notebook.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create Gaussian mixture model for sampling etc.: two equally weighted
# components centred at (0, 0) and (1, 1), each with covariance 0.1 * I.
data_dist = GaussianMixtureModel(
    torch.tensor(
        [
            [0.0, 0.0],
            [1.0, 1.0],
        ]
    ),
    torch.stack([0.1 * torch.eye(2), 0.1 * torch.eye(2)]),
    torch.tensor([0.5, 0.5]),
)

# Draw samples from distribution.
samples = data_dist.sample(1000)

# Plot the true density with the drawn samples on top.
fig = plot_density_heatmap(data_dist.likelihood)
fig.add_trace(go.Scatter(x=samples[:, 0], y=samples[:, 1], mode="markers"))
fig.show()

In [62]:
def train_vae(model, data, val_data, num_epochs, batch_size, learning_rate=3e-4):
    """
    Train ``model`` with Adam on ``data`` and report the loss after every epoch.

    Each epoch reshuffles the training data, runs one optimizer step per
    mini-batch on ``elbo_loss_vae(model, batch)``, then scores ``val_data``
    with ``evaluate`` and prints both losses.

    Parameters:
    -----------
    model : torch.nn.Module
        Model to optimize in place.
    data : torch.Tensor
        Training samples, indexed along the first dimension. The caller's
        tensor is not modified; shuffling creates new tensors.
    val_data : torch.Tensor
        Validation samples passed to ``evaluate``.
    num_epochs : int
        Number of passes over ``data``.
    batch_size : int
        Mini-batch size; the last batch of an epoch may be smaller.
    learning_rate : float
        Adam learning rate.

    Returns:
    --------
    None
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    n_samples = len(data)
    # Loop-invariant, so compute it once instead of once per epoch.
    num_batches = ceil(n_samples / batch_size)

    for epoch in range(num_epochs):
        # evaluate() puts the model in eval mode at the end of every epoch.
        # Without switching back, every epoch after the first would train
        # with dropout / batch-norm style layers in inference mode.
        model.train()

        # Shuffle indices
        indices = np.random.permutation(n_samples)
        data = data[indices]

        total_loss = 0.0

        for i in range(num_batches):
            # Get batch
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            x_batch = data[start_idx:end_idx]

            loss = elbo_loss_vae(model, x_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Summed batch losses divided by the number of training samples.
        # NOTE(review): this is a per-sample average only if elbo_loss_vae
        # returns a sum over the batch — confirm against its implementation.
        avg_loss = total_loss / n_samples
        val_loss = evaluate(model, val_data, batch_size)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}, Val: {val_loss:.4f}"
        )


def evaluate(model, val_data, batch_size):
    """
    Score ``model`` on ``val_data`` without tracking gradients.

    The model is switched to eval mode and is left in that mode on return.
    ``val_data`` is processed in consecutive slices of ``batch_size`` rows
    (the final slice may be shorter).

    Returns:
    --------
    float
        Sum of ``elbo_loss_vae`` over all batches divided by ``len(val_data)``.
    """
    model.eval()
    n_samples = len(val_data)
    batch_starts = (k * batch_size for k in range(ceil(n_samples / batch_size)))

    with torch.no_grad():
        loss_sum = sum(
            elbo_loss_vae(model, val_data[lo : lo + batch_size]).item()
            for lo in batch_starts
        )

    return loss_sum / n_samples

In [64]:
# Hold out the trailing 20% of the samples for validation (the split is by
# position; no shuffling happens here).
train_frac = 0.8
num_train = int(train_frac * len(samples))
train_data, val_data = samples[:num_train], samples[num_train:]

# Training hyper-parameters.
num_epochs, batch_size = 2500, 32

# 2-D inputs compressed to a single latent dimension.
config = VAEConfig(
    input_dim=2,
    latent_dim=1,
    encoder_layers=(32, 32),
    decoder_layers=(32, 32),
)

# The model receives the per-feature mean and variance of the training split
# (presumably for normalization — see the VAE class for how they are used).
model = VAE(config, train_data.mean(0), train_data.var(0))

train_vae(model, train_data, val_data, num_epochs=num_epochs, batch_size=batch_size)

Epoch [1/2500] Loss: 0.0896, Val: 0.1006
Epoch [2/2500] Loss: 0.0888, Val: 0.1005
Epoch [3/2500] Loss: 0.0886, Val: 0.0998
Epoch [4/2500] Loss: 0.0888, Val: 0.1003
Epoch [5/2500] Loss: 0.0889, Val: 0.0993
Epoch [6/2500] Loss: 0.0884, Val: 0.1002
Epoch [7/2500] Loss: 0.0887, Val: 0.1001
Epoch [8/2500] Loss: 0.0884, Val: 0.0985
Epoch [9/2500] Loss: 0.0878, Val: 0.0976
Epoch [10/2500] Loss: 0.0879, Val: 0.1008
Epoch [11/2500] Loss: 0.0880, Val: 0.0984
Epoch [12/2500] Loss: 0.0875, Val: 0.0988
Epoch [13/2500] Loss: 0.0867, Val: 0.0981
Epoch [14/2500] Loss: 0.0876, Val: 0.0983
Epoch [15/2500] Loss: 0.0850, Val: 0.0950
Epoch [16/2500] Loss: 0.0839, Val: 0.0980
Epoch [17/2500] Loss: 0.0817, Val: 0.0946
Epoch [18/2500] Loss: 0.0795, Val: 0.0900
Epoch [19/2500] Loss: 0.0779, Val: 0.0885
Epoch [20/2500] Loss: 0.0767, Val: 0.0902
Epoch [21/2500] Loss: 0.0759, Val: 0.0904
Epoch [22/2500] Loss: 0.0769, Val: 0.0895
Epoch [23/2500] Loss: 0.0751, Val: 0.0851
Epoch [24/2500] Loss: 0.0766, Val: 0.0851
E

KeyboardInterrupt: 

In [65]:
# Define density function of VAE.
def create_vae_density(
    model: VAE, num_latent_samples: int = 10000
) -> GaussianMixtureModel:
    """
    Build a Monte Carlo mixture approximation of the VAE's output density.

    Draws ``num_latent_samples`` latent vectors from a standard normal,
    decodes each one (un-normalized) into a mean and a variance, and returns
    an equally weighted Gaussian mixture with one diagonal-covariance
    component per latent sample.

    Parameters:
    -----------
    model : VAE
        Model providing ``latent_dim`` and ``decode(z, unnormalize=True)``;
        ``decode`` is expected to return a ``(mean, var)`` pair.
    num_latent_samples : int
        Number of latent draws, i.e. number of mixture components.

    Returns:
    --------
    GaussianMixtureModel
        Mixture whose ``likelihood`` approximates the model density.
    """
    z = torch.randn(num_latent_samples, model.latent_dim)
    mean, var = model.decode(z, unnormalize=True)
    # Uniform weights: every latent draw contributes equally to the estimate.
    weights = (1 / num_latent_samples) * torch.ones(num_latent_samples)
    # diag_embed turns each per-dimension variance vector into a diagonal matrix.
    return GaussianMixtureModel(mean, torch.diag_embed(var), weights)


# Density implied by the trained VAE, in the original (un-normalized) data space.
posterior_dist = create_vae_density(model)
fig = plot_density_heatmap(posterior_dist.likelihood)

# Overlay the training split (orange) and the validation split (purple).
for split_points, marker_color in ((train_data, "Orange"), (val_data, "Purple")):
    fig.add_trace(
        go.Scatter(
            x=split_points[:, 0],
            y=split_points[:, 1],
            mode="markers",
            marker=dict(color=marker_color),
        )
    )
fig.show()

In [None]:
# Likelihood of every validation point under the VAE mixture density.
# NOTE(review): the recorded output is almost entirely exact zeros, with one
# value around 6e-44. Training was interrupted early in the recorded run, so
# this may reflect an unfit model and/or underflow — worth checking in log space.
posterior_dist.likelihood(val_data)

tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 6.0256e-44, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+

In [None]:
# ELBO loss on the whole validation split in a single call, divided by the
# number of validation points.
# NOTE(review): this runs with autograd enabled (the recorded output carries a
# grad_fn), unlike evaluate(), which wraps the same loss in torch.no_grad().
elbo_loss_vae(model, val_data) / len(val_data)

tensor(-0.0448, grad_fn=<DivBackward0>)

In [66]:
# Decode 100 prior samples in normalized space. `z`, `mean` and `var` are kept
# for interactive inspection; the bare `mean` expression that used to follow
# was a no-op, because only a cell's last expression is displayed by default.
z = torch.randn(100, model.latent_dim)
mean, var = model.decode(z, unnormalize=False)

# Reconstruction density in normalized space: encode each validation point,
# decode its latent mean, and treat the decoded (mean, var) pairs as an
# equally weighted Gaussian mixture.
samples_normalized = model.output_normalizer.normalize(val_data)
z_mean, z_var = model.encode(samples_normalized)
x_mean, x_var = model.decode(z_mean, unnormalize=False)
weights = 1 / len(x_mean) * torch.ones(len(x_mean))
emp_dist = GaussianMixtureModel(x_mean, torch.diag_embed(x_var), weights)

# Heatmap of that mixture with the normalized validation points on top.
fig = plot_density_heatmap(emp_dist.likelihood)
fig.add_trace(
    go.Scatter(x=samples_normalized[:, 0], y=samples_normalized[:, 1], mode="markers")
)
fig.show()