In [74]:
import torch
from torch import nn
import einops as ein

In [75]:
class ParameterizedProb(nn.Module):
    """
    A product of independent categorical distributions with learnable logits.

    A sample x is a chain of `num_sites` integers, each in
    [0, num_categories), with site `s` drawn from
    Categorical(logits=self.logits[s]).  The class exposes what is needed to
    minimize < f(x) > with the score-function (REINFORCE) estimator: sampled
    values of f, the log-probability of each sampled chain, and a surrogate
    loss built from the two.
    """

    def __init__(self, num_sites=7, num_categories=10):
        """
        Args:
            num_sites: number of entries in each sampled chain x.
            num_categories: number of values each entry can take.

        Raises:
            ValueError: if either size is smaller than 1.
        """
        if num_sites < 1 or num_categories < 1:
            raise ValueError(
                "num_sites and num_categories must both be >= 1, got "
                f"{num_sites} and {num_categories}"
            )
        super().__init__()
        # One row of unnormalized log-probabilities per site, initialised
        # uniformly at random in [0, 1).
        self.logits = nn.Parameter(torch.rand(num_sites, num_categories))

    def f(self, samples):
        """
        A function of individual samples: the sum of squared entries.

        Args:
            samples: tensor whose last dimension indexes the entries of one x.

        Returns:
            Tensor with the last dimension of `samples` reduced away.
        """
        return torch.pow(samples, 2).sum(dim=-1)

    def forward(self, num_samples=100):
        """
        Samples of x composed with f, where < f(x) > is being minimized, along with the logs of
        probabilities associated with sampling each x

        Args:
            num_samples: how many chains x to draw.

        Returns:
            (f_values, log_probs), both of shape (num_samples,).  `log_probs`
            is differentiable with respect to `self.logits`; `f_values` is
            computed from the sampled integers only.
        """

        # A probability distribution that's a function of model params
        cat = torch.distributions.Categorical(logits=self.logits)

        samples = cat.sample((num_samples,))  # (num_samples, num_sites)
        log_probs = cat.log_prob(samples)  # (num_samples, num_sites)
        samples = samples.to(dtype=torch.float32)

        # Log probs of sampling each x (sum over log probs of the entries of each x)
        log_probs = ein.reduce(
            log_probs,
            "b s -> b",
            reduction="sum",
        )

        # A probability is <= 1 exactly when its log is <= 0.  Checking in log
        # space avoids the exp(); the comparison is False for NaN, so NaNs are
        # still caught.  (exp(.) >= 0 holds for every non-NaN input, so that
        # check carried no extra information.)
        assert torch.all(
            log_probs <= 0.0
        ), "Probs over whole chains must be leq 1"

        f_values = self.f(samples)  # (num_samples,)
        return f_values, log_probs

    def surrogate_loss(self, f_values, log_probs):
        """
        Score-function surrogate: sum over the batch of f(x_b) * log p(x_b).

        When `f_values` carries no gradient, differentiating this scalar with
        respect to the logits gives sum_b f(x_b) * grad log p(x_b), i.e. the
        Monte Carlo score-function estimate of grad < f(x) > multiplied by the
        batch size (the sum is not divided by the number of samples).

        Args:
            f_values: tensor of shape (b,) with f evaluated on each sample.
            log_probs: tensor of shape (b,) with the log-probability of each sample.

        Returns:
            0-dimensional tensor.
        """
        loss = ein.einsum(
            f_values,
            log_probs,
            "b, b -> ",
        )
        return loss

In [76]:
# Seed the global RNG so the random initial logits and the drawn samples are
# the same on every Restart & Run All (the value 0 is arbitrary).
torch.manual_seed(0)

p = ParameterizedProb()
# Call the module itself rather than .forward() so nn.Module hooks are run.
f, log_probs = p()
f.shape

torch.Size([100])

In [77]:
# Surrogate loss on the batch drawn above: sum over samples of f(x) * log p(x)
p.surrogate_loss(f_values=f, log_probs=log_probs)

tensor(-288449.0938, grad_fn=<ViewBackward0>)

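The cell below minimizes the surrogate loss with Adam. Its gradient with respect to the logits is the score-function (REINFORCE) estimate of the gradient of $\langle f(x) \rangle$, multiplied by the number of samples.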

In [110]:
# Define optimizer
# NOTE: `p` lives in an earlier cell, so re-running this cell keeps training
# the same parameters, only with a freshly initialised Adam state.
optimizer = torch.optim.Adam(p.parameters(), lr=1e-3)

num_epochs = 1000
log_every = 100  # print progress once every `log_every` epochs
for epoch in range(num_epochs):
    optimizer.zero_grad()  # Clear gradients

    # Call the module itself rather than .forward() so nn.Module hooks are run.
    f_values, log_probs = p()
    loss = p.surrogate_loss(f_values, log_probs)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % log_every == 0:
        print(
            f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()} "
            f"f: {f_values.mean().item()} log_probs: {log_probs.mean().item()}"
        )

Epoch 100/1000, Loss: 0.0 f: 0.0 log_probs: -0.0015707015991210938
Epoch 200/1000, Loss: 0.0 f: 0.0 log_probs: -0.0015802383422851562
Epoch 300/1000, Loss: 0.0 f: 0.0 log_probs: -0.0015869140625
Epoch 400/1000, Loss: 0.0 f: 0.0 log_probs: -0.0015802383422851562
Epoch 500/1000, Loss: 0.0 f: 0.0 log_probs: -0.0015716552734375
Epoch 600/1000, Loss: 0.0 f: 0.0 log_probs: -0.0015687942504882812
Epoch 700/1000, Loss: 0.0 f: 0.0 log_probs: -0.001560211181640625
Epoch 800/1000, Loss: -9.958617210388184 f: 0.009999999776482582 log_probs: -0.10111095756292343
Epoch 900/1000, Loss: 0.0 f: 0.0 log_probs: -0.0015306472778320312
Epoch 1000/1000, Loss: 0.0 f: 0.0 log_probs: -0.0015077590942382812
