Skip to content

Commit

Permalink
take sqrt first
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad committed Dec 2, 2020
1 parent 187019a commit 06efaf9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch/distributions/transforms.py
Expand Up @@ -590,7 +590,7 @@ def _call(self, x):
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
z = r ** 2
z1m_cumprod_sqrt = (1 - z).cumprod(-1).sqrt()
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
# Diagonal elements must be 1.
r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
Expand Down

0 comments on commit 06efaf9

Please sign in to comment.