Skip to content

Commit

Permalink
Add complex backward support for torch.exp (#47194)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #47194

Test Plan: Imported from OSS

Reviewed By: izdeby

Differential Revision: D24683201

Pulled By: anjali411

fbshipit-source-id: c447dec51cbfe7c09d6943fbaafa94f48130d582
  • Loading branch information
anjali411 authored and facebook-github-bot committed Nov 2, 2020
1 parent c10aa44 commit da26858
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion test/test_autograd.py
Expand Up @@ -4938,7 +4938,8 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul',
-                         'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub'] + separate_complex_tests
+                         'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub',
+                         'exp'] + separate_complex_tests

# this list corresponds to cases that are not currently implemented
skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex']
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
Expand Up @@ -434,7 +434,7 @@
self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad

- name: exp(Tensor self) -> Tensor
-  self: grad * result
+  self: grad * result.conj()

- name: exp2(Tensor self) -> Tensor
self: grad * result * M_LN2
Expand Down
3 changes: 2 additions & 1 deletion tools/autograd/gen_variable_type.py
Expand Up @@ -166,7 +166,8 @@
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
'bmm', 'diagonal', 'cholesky', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
-    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_'
+    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
+    'exp'
}

# Some operators invalidate the grad_accumulator. Let's reset it.
Expand Down

0 comments on commit da26858

Please sign in to comment.