From da26858c9c96528c3149506f1b594e8f43f2156d Mon Sep 17 00:00:00 2001
From: anjali411
Date: Mon, 2 Nov 2020 09:36:49 -0800
Subject: [PATCH] Add complex backward support for torch.exp (#47194)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47194

Test Plan: Imported from OSS

Reviewed By: izdeby

Differential Revision: D24683201

Pulled By: anjali411

fbshipit-source-id: c447dec51cbfe7c09d6943fbaafa94f48130d582
---
 test/test_autograd.py               | 3 ++-
 tools/autograd/derivatives.yaml     | 2 +-
 tools/autograd/gen_variable_type.py | 3 ++-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index 5dc5c94c3e53..2b7b064bf483 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4938,7 +4938,8 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
                  'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
                  'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
                  'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul',
-                 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub'] + separate_complex_tests
+                 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub',
+                 'exp'] + separate_complex_tests
 
 # this list corresponds to cases that are not currently implemented
 skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex']
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 8cbcab35685e..da80b1a7a124 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -434,7 +434,7 @@
   self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad
 
 - name: exp(Tensor self) -> Tensor
-  self: grad * result
+  self: grad * result.conj()
 
 - name: exp2(Tensor self) -> Tensor
   self: grad * result * M_LN2
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 04b8d144cd79..5baf38536ff7 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -166,7 +166,8 @@
     'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
     'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
     'bmm', 'diagonal', 'cholesky', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
-    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_'
+    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
+    'exp'
 }
 
 # Some operators invalidate the grad_accumulator. Let's reset it.
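Note (not part of the patch): under PyTorch's autograd convention for complex tensors, the vector-Jacobian product of a holomorphic function f multiplies the incoming gradient by conj(f'(z)); since d/dz exp(z) = exp(z) = result, the backward formula becomes grad * result.conj(). For real tensors, .conj() is a no-op, so the change does not affect real backward behavior. Below is a minimal sanity-check sketch, written for this note rather than taken from the PR's test plan, assuming a PyTorch build that includes this patch and supports gradcheck on complex inputs:

    import torch
    from torch.autograd import gradcheck

    # Complex double input; gradcheck expects double precision.
    z = torch.randn(4, dtype=torch.cdouble, requires_grad=True)

    # Numerically verify the analytic backward of exp for complex inputs.
    assert gradcheck(torch.exp, (z,))

    # Check the patched formula directly: with a grad_output of all ones,
    # z.grad should equal exp(z).conj(), i.e. grad * result.conj().
    out = torch.exp(z)
    out.backward(torch.ones_like(out))
    assert torch.allclose(z.grad, torch.exp(z.detach()).conj())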