Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong committed Sep 23, 2021
1 parent 9517ae7 commit 3ccbabb
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion scvi/module/_amortizedlda.py
Expand Up @@ -30,9 +30,15 @@ def log_prob(self, value):
def logistic_normal_approximation(
alpha: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns the mean and standard deviation of the Logistic Normal approximation to the Dirichlet.
Uses the Laplace approximation of the Logistic Normal distribution to the Dirichlet distribution
as described in Srivastava et al. https://arxiv.org/pdf/1703.01488.pdf.
"""
K = alpha.shape[0]
mu = torch.log(alpha) - torch.log(alpha).sum() / K
sigma = (1 - 2 / K) / alpha + torch.sum(1 / alpha) / K ** 2
sigma = torch.sqrt((1 - 2 / K) / alpha + torch.sum(1 / alpha) / K ** 2)
return mu, sigma


Expand Down

0 comments on commit 3ccbabb

Please sign in to comment.