diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 7ca1fccfce54..b2bea4570c2a 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -102,13 +102,11 @@ def fn_out():
             d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
         elif ds_dx.is_complex():  # R -> C
             # w_d = conj_w_d = 0.5 * ds_dx
-            dL_dz_conj = 0.5 * (grad_out.conjugate() * ds_dx + grad_out * ds_dx.conj())
-            # The above formula is derived for a C -> C function that's a part of
-            # bigger function with real valued output. From separate calculations,
-            # it can be verified that the gradient for R -> C function
-            # equals to real value of the result obtained from the generic formula for
-            # C -> C functions used above.
-            d[d_idx] = torch.real(dL_dz_conj)
+            # dL_dz_conj = 0.5 * [grad_out.conj() * ds_dx + grad_out * ds_dx.conj()]
+            #            = 0.5 * [grad_out.conj() * ds_dx + (grad_out.conj() * ds_dx).conj()]
+            #            = 0.5 * 2 * real(grad_out.conj() * ds_dx)
+            #            = real(grad_out.conj() * ds_dx)
+            d[d_idx] = torch.real(grad_out.conjugate() * ds_dx)
         else:  # R -> R
             d[d_idx] = ds_dx * grad_out
 
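Not part of the patch itself, but the algebraic step the new comments rely on is easy to check numerically: for any complex `grad_out` and `ds_dx`, the old expression `0.5 * (grad_out.conj() * ds_dx + grad_out * ds_dx.conj())` is purely real and equals `real(grad_out.conj() * ds_dx)`, since `z + z.conj() == 2 * real(z)`. A minimal sketch (the names `g` and `d` are illustrative stand-ins, not identifiers from gradcheck.py):

```python
import torch

# Stand-ins for grad_out and ds_dx in the R -> C branch above.
g = torch.randn(5, dtype=torch.cfloat)
d = torch.randn(5, dtype=torch.cfloat)

# Formula before the change: dL_dz_conj as a complex tensor.
old = 0.5 * (g.conj() * d + g * d.conj())
# Simplified formula after the change.
new = torch.real(g.conj() * d)

# old is z/2 + z.conj()/2 == real(z) for z = g.conj() * d, so its real part
# matches the simplified form and its imaginary part vanishes (up to float error).
assert torch.allclose(torch.real(old), new)
assert torch.allclose(torch.imag(old), torch.zeros_like(new))
```

This identity is why the intermediate `dL_dz_conj` and the separate `torch.real` call could be collapsed into a single expression.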