In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

def set_seed(seed: int = 42):
    """Seed PyTorch's random generators and fix the cuDNN flags.

    Calls ``torch.manual_seed`` and ``torch.cuda.manual_seed_all`` with
    *seed*, then sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``.
    """
    # cuDNN flags first; they are independent of the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Apply the same seed through both seeding entry points.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

class SimpleModel(nn.Module):
    """Normal distribution with a learnable ``mean`` and ``std``.

    The raw ``std`` parameter is unconstrained; whenever a distribution is
    built, its scale is ``abs(std) + 1e-6`` so it is strictly positive.
    """

    # Added to |std| so the scale handed to Normal is never exactly zero.
    _STD_EPS = 1e-6

    def __init__(self, mean=0.0, std=1.0):
        """Create the two scalar parameters.

        Args:
            mean: Initial mean. Integer/bool inputs are converted to the
                default float dtype, since parameters must be floating point.
            std: Initial raw standard deviation, converted the same way.
        """
        super().__init__()
        self.mean = nn.Parameter(self._as_float_tensor(mean))
        self.std = nn.Parameter(self._as_float_tensor(std))

    @staticmethod
    def _as_float_tensor(value):
        """Return an independent copy of *value* as a tensor that can require grad."""
        tensor = torch.as_tensor(value).detach().clone()
        # nn.Parameter rejects integer/bool tensors (they cannot require grad),
        # so e.g. SimpleModel(mean=0, std=1) must be promoted to float here.
        if not (tensor.is_floating_point() or tensor.is_complex()):
            tensor = tensor.to(torch.get_default_dtype())
        return tensor

    @classmethod
    def _positive_scale(cls, raw_std):
        """Map an unconstrained std value to a strictly positive scale."""
        return torch.abs(raw_std) + cls._STD_EPS

    def sample(self, shape_TMS):
        """Draw reparameterized samples of shape *shape_TMS* from the model.

        ``rsample`` is used, so the result stays differentiable with respect
        to ``mean`` and ``std``.
        """
        distribution = Normal(self.mean, self._positive_scale(self.std))
        return distribution.rsample(shape_TMS)

    def build_from_single_tensor(self, param_P):
        """Build a ``Normal`` from a packed tensor: index 0 is mean, index 1 is raw std."""
        mean, raw_std = param_P[0], param_P[1]
        return Normal(mean, self._positive_scale(raw_std))

    def params_to_single_tensor(self):
        """Pack the parameters as ``stack([mean, std])`` (inverse layout of ``build_from_single_tensor``)."""
        return torch.stack([self.mean, self.std])

def train_epoch(model_class, optimizer, train_T, sample_shape_MS, seed=42, compute_jacobian_together=True):
    """Run one demo step: sample, differentiate log-probs, back-propagate a loss.

    Args:
        model_class: Zero-argument callable returning a model that provides
            ``sample``, ``build_from_single_tensor`` and
            ``params_to_single_tensor``.
        optimizer: Optimizer to zero before and step after the backward pass,
            or ``None`` to skip both. NOTE: a new model is created on every
            call, so the step only changes that model if the optimizer was
            constructed over that same instance's parameters.
        train_T: Size of the leading sample dimension.
        sample_shape_MS: Remaining sample dimensions, unpacked after ``train_T``.
        seed: Value passed to ``set_seed`` before anything is sampled.
        compute_jacobian_together: If true, take one Jacobian with respect to
            the packed parameter tensor; otherwise take one Jacobian per
            parameter entry and stack them on a new last axis.

    Returns:
        ``(loss, jacobian)`` where ``loss`` is a Python float. When
        ``compute_jacobian_together`` is true, ``jacobian`` is a one-element
        tuple (the inputs are passed to ``jacobian`` as a tuple) holding the
        tensor; otherwise it is the stacked tensor itself.
    """
    set_seed(seed)  # Set the random seed for reproducibility

    # The optimizer is optional: only touch it when one was supplied.
    if optimizer is not None:
        optimizer.zero_grad()

    # Initialize model
    model = model_class()

    # Sample data from the model
    y_sample_TMS = model.sample((train_T, *sample_shape_MS))

    # Normalize samples along the last dimension, then average over dim 1.
    ratio_rating_TMS = y_sample_TMS / y_sample_TMS.sum(dim=-1, keepdim=True)
    ratio_rating_TS = ratio_rating_TMS.mean(dim=1)
    # Guarantees backward() below has something to differentiate even if the
    # sample itself did not require grad.
    ratio_rating_TS.requires_grad_(True)

    def get_log_probs_baked(param_P):
        # Log-probability of the already-drawn samples under the distribution
        # rebuilt from the packed parameters.
        distribution = model.build_from_single_tensor(param_P)
        return distribution.log_prob(y_sample_TMS)

    if compute_jacobian_together:
        # Compute the Jacobian of log probabilities with respect to all parameters at once
        jac_TMSP = torch.autograd.functional.jacobian(
            get_log_probs_baked,
            (model.params_to_single_tensor(),),
            strategy='forward-mode',
            vectorize=True
        )
    else:
        # One Jacobian call per parameter entry, with the other entries held
        # at their current values; mainly useful as a cross-check of the
        # joint computation above.
        params = model.params_to_single_tensor()

        def jacobian_for_entry(index):
            def single_param_log_prob(param_scalar):
                param_copy = params.clone()
                param_copy[index] = param_scalar
                return model.build_from_single_tensor(param_copy).log_prob(y_sample_TMS)

            return torch.autograd.functional.jacobian(
                single_param_log_prob,
                params[index],
                strategy='forward-mode',
                vectorize=True
            )

        jac_TMSP = torch.stack(
            [jacobian_for_entry(index) for index in range(len(params))],
            dim=-1,
        )

    # Example loss and backward pass
    loss = -ratio_rating_TS.sum()
    loss.backward()

    # Update model parameters if optimizer is provided
    if optimizer is not None:
        optimizer.step()

    return loss.item(), jac_TMSP

# Example usage: compute the Jacobian both ways with the same seed and compare.
if __name__ == "__main__":
    def model_class():
        return SimpleModel(mean=0.0, std=1.0)

    # NOTE: this optimizer holds the parameters of its own SimpleModel
    # instance; train_epoch builds a separate instance on each call.
    optimizer = optim.Adam(model_class().parameters(), lr=0.01)

    # Settings shared by both runs so their results are comparable.
    shared_kwargs = {"train_T": 10, "sample_shape_MS": (5, 10), "seed": 123}

    print("Computing Jacobian for all parameters at once:")
    loss_all, jac_all_TMSP = train_epoch(
        model_class, optimizer, compute_jacobian_together=True, **shared_kwargs
    )
    # The joint computation returns a one-element tuple; unwrap the tensor.
    jac_joint_TMSP = jac_all_TMSP[0]
    print("Loss:", loss_all)
    print("Jacobian shape (all together):", jac_joint_TMSP.shape)

    print("\nComputing Jacobian for each parameter separately:")
    loss_sep, jac_sep_TMSP = train_epoch(
        model_class, optimizer, compute_jacobian_together=False, **shared_kwargs
    )
    print("Loss:", loss_sep)
    print("Jacobian shape (separate):", jac_sep_TMSP.shape)

    # Both strategies must agree up to numerical tolerance.
    assert torch.allclose(jac_joint_TMSP, jac_sep_TMSP, atol=1e-5)

Computing Jacobian for all parameters at once:
Loss: -9.999999046325684
Jacobian shape (all together): torch.Size([10, 5, 10, 2])

Computing Jacobian for each parameter separately:
Loss: -9.999999046325684
Jacobian shape (separate): torch.Size([10, 5, 10, 2])


torch.Size([10, 5, 10, 2])