diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 6eb549248854..a0412d52df0d 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -590,7 +590,7 @@ def _call(self, x): # Note that y = sign(r) * sqrt(z * z1m_cumprod) # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) z = r ** 2 - z1m_cumprod_sqrt = (1 - z).cumprod(-1).sqrt() + z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) # Diagonal elements must be 1. r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device) y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)