In [33]:
import torch
from torch import nn
import einops as ein

In [34]:
# Logits are unnormalized log-probabilities. Each of the 7 rows holds 10 class
# scores, and Categorical normalizes every row (over the last dimension) into
# its own probability distribution.
logits = nn.Parameter(torch.rand(size=(7, 10)))
cat = torch.distributions.Categorical(logits=logits)
# Draw 100 chains; every chain carries one class index per row of logits.
test_samples = cat.sample(sample_shape=torch.Size([100]))
test_samples

tensor([[7, 0, 4, 0, 9, 6, 3],
        [4, 3, 5, 2, 1, 6, 4],
        [3, 6, 8, 3, 5, 1, 9],
        [0, 6, 9, 2, 3, 5, 7],
        [0, 7, 6, 5, 2, 4, 8],
        [5, 4, 4, 7, 0, 1, 3],
        [0, 8, 4, 2, 8, 3, 3],
        [1, 5, 1, 5, 9, 9, 1],
        [2, 1, 7, 4, 7, 7, 8],
        [1, 3, 0, 6, 2, 9, 0],
        [0, 8, 2, 0, 3, 5, 7],
        [3, 4, 0, 3, 5, 3, 7],
        [6, 7, 9, 5, 6, 8, 3],
        [7, 6, 1, 9, 1, 3, 7],
        [8, 7, 0, 0, 2, 8, 1],
        [7, 3, 2, 6, 6, 3, 8],
        [6, 4, 9, 6, 8, 2, 3],
        [0, 2, 9, 5, 3, 7, 8],
        [5, 0, 4, 5, 0, 2, 7],
        [8, 8, 9, 4, 1, 0, 9],
        [8, 4, 9, 0, 4, 7, 3],
        [3, 1, 7, 4, 6, 7, 2],
        [4, 2, 0, 2, 4, 5, 6],
        [3, 2, 5, 1, 3, 9, 6],
        [2, 6, 1, 2, 2, 3, 3],
        [4, 4, 1, 9, 8, 6, 3],
        [0, 6, 9, 1, 3, 2, 4],
        [8, 6, 4, 0, 1, 1, 6],
        [3, 4, 7, 5, 0, 6, 3],
        [8, 2, 2, 9, 7, 7, 5],
        [0, 3, 6, 4, 4, 8, 6],
        [2, 5, 3, 1, 9, 5, 8],
        

In [None]:
class ParameterizedProb(nn.Module):
    """
    A learnable distribution over fixed-length integer chains, trained with a
    score-function (REINFORCE-style) surrogate loss.

    Each chain x has ``num_positions`` entries. Entry i is drawn independently
    from a categorical distribution over ``num_classes`` values whose
    unnormalized log-probabilities are row i of ``self.logits``.
    """

    def __init__(self, num_positions=7, num_classes=10, num_samples=100):
        """
        Args:
            num_positions: number of entries in each sampled chain x.
            num_classes: number of values each entry can take
                (0 .. num_classes - 1).
            num_samples: number of chains drawn on every call to ``forward``.

        Raises:
            ValueError: if any argument is smaller than 1.
        """
        super().__init__()
        if min(num_positions, num_classes, num_samples) < 1:
            raise ValueError(
                "num_positions, num_classes and num_samples must all be >= 1"
            )
        self.num_samples = num_samples
        self.logits = nn.Parameter(torch.rand(num_positions, num_classes))

    def f(self, samples):
        """
        A function of individual samples: the negated sum of squared entries.

        Args:
            samples: tensor whose last dimension indexes the entries of a chain.

        Returns:
            Tensor with the last dimension reduced away; non-positive for
            real-valued input.
        """
        return -torch.pow(samples, 2).sum(dim=-1)

    def forward(self):
        """
        Samples of x composed with f, where < f(x) > is being minimized, along with the logs of
        probabilities associated with sampling each x.

        Returns:
            Tuple ``(f_values, log_probs)``, both of shape (num_samples,).
            ``f_values`` carries no gradient (sampling is not differentiable);
            ``log_probs`` is differentiable with respect to ``self.logits``.
        """

        # A probability distribution that's a function of model params.
        # The class dimension is the last dimension of the logits, not the first.
        cat = torch.distributions.Categorical(logits=self.logits)

        samples = cat.sample((self.num_samples,))  # (num_samples, num_positions)
        log_probs = cat.log_prob(samples)  # (num_samples, num_positions)
        samples = samples.to(dtype=torch.float32)

        # Log probs of sampling each x: entries are independent, so the log prob
        # of a whole chain is the sum of the log probs of its entries.
        log_probs = log_probs.sum(dim=-1)  # (num_samples,)

        # Sanity check on the chain probabilities; also trips on NaN, since any
        # comparison against NaN is False.
        chain_probs = torch.exp(log_probs)
        assert torch.all(chain_probs >= 0.0), (
            "Probs over whole chains must be geq0"
        )
        assert torch.all(chain_probs <= 1.0), (
            "Probs over whole chains must be leq 1"
        )

        f_values = self.f(samples)  # (num_samples,)
        return f_values, log_probs

    def surrogate_loss(self, f_values, log_probs):
        """
        Score-function surrogate: sum over samples of f(x) * log p(x).

        Its gradient with respect to the logits is the (unnormalized, summed
        over the batch) Monte Carlo estimate of the gradient of < f(x) >. The
        value of the surrogate itself is not the objective.

        Args:
            f_values: 1-D tensor of f evaluated on each sample.
            log_probs: 1-D tensor, same length, of chain log-probabilities.

        Returns:
            0-dimensional tensor.
        """
        # torch.dot insists on two 1-D tensors of equal length, so a shape
        # mismatch fails loudly instead of broadcasting.
        return torch.dot(f_values, log_probs)

In [36]:
p = ParameterizedProb()
# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
f, log_probs = p()
# One f value per sampled chain.
f.shape

torch.Size([100])

In [37]:
# Surrogate value for the batch sampled above; only its gradient with respect
# to the logits is meaningful, not the number itself.
p.surrogate_loss(f_values=f, log_probs=log_probs)

tensor(321283.8125, grad_fn=<ViewBackward0>)

This does seem to minimize the expected value $\langle f(x) \rangle$ — the mean of `f` trends downward in the log below — with similar surrogate-loss behavior (becoming less negative per backprop iteration):

In [38]:
# Define optimizer over the module's only parameter, the logits.
optimizer = torch.optim.Adam(p.parameters(), lr=1e-3)

num_epochs = 100000
for epoch in range(num_epochs):
    optimizer.zero_grad()  # Clear gradients left over from the previous step

    # Call the module itself rather than .forward(): nn.Module.__call__ runs
    # any registered hooks, which a direct .forward() call silently skips.
    # Every call draws a fresh batch of samples from the current distribution.
    f_values, log_probs = p()
    loss = p.surrogate_loss(f_values, log_probs)
    loss.backward()
    optimizer.step()

    # Report every 100th epoch. The surrogate loss is only a vehicle for the
    # gradient; the mean of f is the quantity actually being minimized.
    if (epoch + 1) % 100 == 0:
        print(
            f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()} f: {f_values.mean().item()} log_probs: {log_probs.mean().item()}"
        )

Epoch 100/100000, Loss: 313324.25 f: -197.3699951171875 log_probs: -15.908193588256836
Epoch 200/100000, Loss: 357144.40625 f: -225.33999633789062 log_probs: -15.941043853759766
Epoch 300/100000, Loss: 334820.90625 f: -210.9499969482422 log_probs: -15.913761138916016
Epoch 400/100000, Loss: 364514.25 f: -231.44000244140625 log_probs: -15.875053405761719
Epoch 500/100000, Loss: 357564.0625 f: -229.61000061035156 log_probs: -15.693944931030273
Epoch 600/100000, Loss: 363309.75 f: -232.2100067138672 log_probs: -15.908476829528809
Epoch 700/100000, Loss: 382260.625 f: -247.75999450683594 log_probs: -15.595168113708496
Epoch 800/100000, Loss: 396765.75 f: -258.239990234375 log_probs: -15.527678489685059
Epoch 900/100000, Loss: 415146.8125 f: -270.5799865722656 log_probs: -15.559765815734863
Epoch 1000/100000, Loss: 390566.0 f: -255.5399932861328 log_probs: -15.574394226074219
Epoch 1100/100000, Loss: 424035.5 f: -284.489990234375 log_probs: -15.249031066894531
Epoch 1200/100000, Loss: 42063

KeyboardInterrupt: 