diff --git a/scvi/module/_amortizedlda.py b/scvi/module/_amortizedlda.py index d1389f564d..5f3f63b513 100644 --- a/scvi/module/_amortizedlda.py +++ b/scvi/module/_amortizedlda.py @@ -30,9 +30,15 @@ def log_prob(self, value): def logistic_normal_approximation( alpha: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns the mean and standard deviation of the Logistic Normal approximation to the Dirichlet. + + Uses the Laplace approximation of the Logistic Normal distribution to the Dirichlet distribution + as described in Srivastava et al. https://arxiv.org/pdf/1703.01488.pdf. + """ K = alpha.shape[0] mu = torch.log(alpha) - torch.log(alpha).sum() / K - sigma = (1 - 2 / K) / alpha + torch.sum(1 / alpha) / K ** 2 + sigma = torch.sqrt((1 - 2 / K) / alpha + torch.sum(1 / alpha) / K ** 2) return mu, sigma