From 8339f88353770f6f30f912814c8fd992e5f3f114 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Mon, 9 Nov 2020 08:21:16 -0800 Subject: [PATCH] Add complex autograd support for torch.mean (#47566) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47566 Test Plan: Imported from OSS Reviewed By: heitorschueroff Differential Revision: D24817013 Pulled By: anjali411 fbshipit-source-id: f2b8411fb9abdc3e2d07c8e4fef3071b76605b12 --- test/test_autograd.py | 2 +- tools/autograd/gen_variable_type.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index c7001ffab4f5..e651bfe477dd 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -5006,7 +5006,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh', 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub', - 'exp'] + separate_complex_tests + 'exp', 'mean'] + separate_complex_tests # this list corresponds to cases that are not currently implemented skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex'] diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 7e2d9bbe641a..df52813e1a40 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -167,7 +167,7 @@ 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger', 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', - 'exp', 'nonzero' + 'exp', 'nonzero', 'mean' } # Some operators invalidate the grad_accumulator. Let's reset it.