From a3662fa78c42bc2ae6b70fe6f024fb73fed59bcc Mon Sep 17 00:00:00 2001
From: anjali411
Date: Tue, 6 Oct 2020 13:56:47 -0700
Subject: [PATCH] Minor gradcheck update to reduce computations (#45757)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45757

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D24137143

Pulled By: anjali411

fbshipit-source-id: e0174ec03d93b1fedf27baa72c3542dac0b70058
---
 torch/autograd/gradcheck.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 7ca1fccfce54..b2bea4570c2a 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -102,13 +102,11 @@ def fn_out():
             d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
         elif ds_dx.is_complex():  # R -> C
             # w_d = conj_w_d = 0.5 * ds_dx
-            dL_dz_conj = 0.5 * (grad_out.conjugate() * ds_dx + grad_out * ds_dx.conj())
-            # The above formula is derived for a C -> C function that's a part of
-            # bigger function with real valued output. From separate calculations,
-            # it can be verified that the gradient for R -> C function
-            # equals to real value of the result obtained from the generic formula for
-            # C -> C functions used above.
-            d[d_idx] = torch.real(dL_dz_conj)
+            # dL_dz_conj = 0.5 * [grad_out.conj() * ds_dx + grad_out * ds_dx.conj()]
+            #            = 0.5 * [grad_out.conj() * ds_dx + (grad_out.conj() * ds_dx).conj()]
+            #            = 0.5 * 2 * real(grad_out.conj() * ds_dx)
+            #            = real(grad_out.conj() * ds_dx)
+            d[d_idx] = torch.real(grad_out.conjugate() * ds_dx)
         else:  # R -> R
             d[d_idx] = ds_dx * grad_out
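
Note (not part of the patch): the simplification in the added comment rests on the identity z + z.conj() = 2 * real(z), applied to z = grad_out.conj() * ds_dx. A minimal sketch checking the old and new computations agree numerically; the tensors below are hypothetical stand-ins for grad_out and ds_dx, not values from gradcheck itself:

    import torch

    # Hypothetical stand-ins for grad_out and ds_dx (complex, R -> C case).
    grad_out = torch.randn(5, dtype=torch.cfloat)
    ds_dx = torch.randn(5, dtype=torch.cfloat)

    # Old computation: generic C -> C formula, then take the real part.
    old = torch.real(0.5 * (grad_out.conj() * ds_dx + grad_out * ds_dx.conj()))

    # New computation: one complex product instead of two, plus no 0.5 scaling.
    new = torch.real(grad_out.conj() * ds_dx)

    assert torch.allclose(old, new)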