# Deep Set Flow Matching Posterior Estimation

This notebook demonstrates how to perform deep set flow matching posterior estimation (DeepSet FMPE) with LAMPE, focusing on handling sets of parameters and observations in a permutation-invariant manner.

In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import zuko

from itertools import islice
from lampe.data import JointLoader
from inference_multisets import DeepSetFMPE, DeepSetFMPELoss
from lampe.plots import corner, mark_point, nice_rc
from lampe.utils import GDStep
from tqdm import trange

## Simulator

We define a simple simulator function that generates observations based on sets of parameters. This function is designed to be permutation-invariant with respect to the input sets.

In [2]:
LABELS = [r'$\theta_1$', r'$\theta_2$', r'$\theta_3$']
LOWER = -torch.ones(3)
UPPER = torch.ones(3)

prior = zuko.distributions.BoxUniform(LOWER, UPPER)

def simulator(theta: torch.Tensor) -> torch.Tensor:
    x = torch.stack([
        theta[..., 0] + theta[..., 1] * theta[..., 2],
        theta[..., 0] * theta[..., 1] + theta[..., 2],
    ], dim=-1)

    return x + 0.05 * torch.randn_like(x)

theta = prior.sample()
x = simulator(theta)

print(theta, x, sep='\n')

tensor([-0.7399,  0.9650, -0.2231])
tensor([-0.9622, -0.9301])


## Training

We train our DeepSetFMPE model using a standard neural network training routine. The goal is to learn a regression network that approximates a vector field for permutation-invariant sets of parameters and observations.

In [3]:
loader = JointLoader(prior, simulator, batch_size=256, vectorized=True)

estimator = DeepSetFMPE(3, 2, hidden_features=64, num_hidden_layers=5, activation=nn.ELU)
print(estimator)

loss = DeepSetFMPELoss(estimator)
optimizer = optim.Adam(estimator.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 128)
step = GDStep(optimizer, clip=1.0)  # gradient descent step with gradient clipping

estimator.train()

for epoch in (bar := trange(128, unit='epoch')):
    losses = []

    for theta, x in islice(loader, 256):  # 256 batches per epoch
        losses.append(step(loss(theta, x)))

    bar.set_postfix(loss=torch.stack(losses).mean().item())

DeepSetFMPE(
  (net): Sequential()
  (activation): ELU(alpha=1.0)
  (phi_theta): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=64, out_features=64, bias=True)
    (7): ELU(alpha=1.0)
    (8): Linear(in_features=64, out_features=64, bias=True)
    (9): ELU(alpha=1.0)
  )
  (phi_x): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=64, out_features=64, bias=True)
    (7): ELU(alpha=1.0)
    (8): Linear(in_features=64, out_features=64, bias=True)
    (9): ELU(alpha=1.0)
  )
  (time_embedding): Linear(in_features

100%|██████████| 128/128 [01:33<00:00,  1.36epoch/s, loss=0.351]


## Inference

Now that we have an estimator of the vector field, we can sample from the normalizing flow it induces for sets of observations.

In [4]:
theta_star = prior.sample()
x_star = simulator(theta_star)

estimator.eval()

with torch.no_grad():
    log_p = estimator.flow(x_star).log_prob(theta_star)
    samples = estimator.flow(x_star).sample((2**14,))

RuntimeError: got 2 tensors and 1 gradients

## Visualization

We visualize the posterior samples to assess the quality of our DeepSetFMPE model.

In [None]:
plt.rcParams.update(nice_rc(latex=True))  # nicer plot settings

fig = corner(
    samples,
    smooth=2,
    domain=(LOWER, UPPER),
    labels=LABELS,
    legend=r'$p_\phi(\theta | x^*)$',
    figsize=(4.8, 4.8),
)

mark_point(fig, theta_star)