From 06efaf92c939d4791a5a03cc51d2ec30b9ca98e9 Mon Sep 17 00:00:00 2001 From: neerajprad Date: Tue, 1 Dec 2020 18:14:18 -0800 Subject: [PATCH] take sqrt first --- torch/distributions/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 6eb549248854..a0412d52df0d 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -590,7 +590,7 @@ def _call(self, x): # Note that y = sign(r) * sqrt(z * z1m_cumprod) # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) z = r ** 2 - z1m_cumprod_sqrt = (1 - z).cumprod(-1).sqrt() + z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) # Diagonal elements must be 1. r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device) y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)