Skip to content

Commit

Permalink
Add complex autograd support for torch.mean (#47566)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #47566

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24817013

Pulled By: anjali411

fbshipit-source-id: f2b8411fb9abdc3e2d07c8e4fef3071b76605b12
  • Loading branch information
anjali411 authored and facebook-github-bot committed Nov 9, 2020
1 parent 3d96243 commit 8339f88
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test/test_autograd.py
Expand Up @@ -5006,7 +5006,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul',
'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub',
'exp'] + separate_complex_tests
'exp', 'mean'] + separate_complex_tests

# this list corresponds to cases that are not currently implemented
skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex']
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
Expand Up @@ -167,7 +167,7 @@
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
'exp', 'nonzero'
'exp', 'nonzero', 'mean'
}

# Some operators invalidate the grad_accumulator. Let's reset it.
Expand Down

0 comments on commit 8339f88

Please sign in to comment.