From 25de8baccd4ac5c71a10b0b94d57404c736e7dde Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 2 Feb 2021 00:03:40 -0600 Subject: [PATCH] enable complex autograd and jit tests for trace --- tools/autograd/gen_variable_type.py | 2 +- torch/testing/_internal/common_methods_invocations.py | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 43874a31e8cf..f79ce86d0890 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -90,7 +90,7 @@ 'replication_pad1d', 'replication_pad2d', 'replication_pad3d', 'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward', 'masked_scatter', 'masked_select', - 'index_fill', + 'index_fill', 'trace' } # Some operators invalidate the grad_accumulator. Let's reset it. diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 62b46e03af2f..b32a78421b49 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2121,14 +2121,7 @@ def reference_lgamma(x): dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), test_inplace_grad=False, supports_tensor_out=False, - # Reference: https://github.com/pytorch/pytorch/issues/50381 - test_complex_grad=False, - sample_inputs_func=sample_inputs_trace, - skips=( - SkipInfo('TestCommon', 'test_variant_consistency_jit', - dtypes=[torch.complex64, torch.complex128]), - SkipInfo('TestCommon', 'test_variant_consistency_eager', - dtypes=[torch.complex64, torch.complex128]))), + sample_inputs_func=sample_inputs_trace) ] op_db = op_db + op_db_scipy_reference