From 276113a2b1375ad402edd45f5da1bb4d98b763e5 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 28 Jan 2021 00:00:14 -0600 Subject: [PATCH 1/3] enable diag complex autograd tests --- test/test_autograd.py | 2 +- tools/autograd/gen_variable_type.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index df588f97701a..ca8eb4b7d1d3 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -5037,7 +5037,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'mean', 'inverse', 'triangular_solve', 'solve', 'addcmul', 'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr', 'narrow', 'swapaxes', 'swapdims', 'tensor_split', 'tile', - 'baddbmm', 'addbmm', 'addmv'] + separate_complex_tests + 'baddbmm', 'addbmm', 'addmv', 'diag'] + separate_complex_tests # deny list for batched grad computation EXCLUDE_BATCHED_GRAD_TESTS = set([ diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index c2e776d8c0f4..cc657fdfc080 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -89,6 +89,7 @@ 'reflection_pad1d_backward', 'reflection_pad2d_backward', 'replication_pad1d', 'replication_pad2d', 'replication_pad3d', 'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward', + 'diag' } # Some operators invalidate the grad_accumulator. Let's reset it. From 843a546d9fa4250f94c48ea059d063b56eaee3c7 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 28 Jan 2021 09:43:28 -0600 Subject: [PATCH 2/3] add OpInfo entry for diag --- .../_internal/common_methods_invocations.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 324a9346c45b..97b6b3eaf395 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -993,6 +993,23 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): ) return [SampleInput(tensor) for tensor in tensors] +def sample_inputs_diag(op_info, device, dtype, requires_grad): + vec_sample = SampleInput(make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad)) + + tensors = ( + make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad), + make_tensor((3, 5), device, dtype, low=None, high=None, requires_grad=requires_grad), + make_tensor((5, 3), device, dtype, low=None, high=None, requires_grad=requires_grad), + ) + + args = ((), (2,), (-2,), (1,), (2,)) + + samples = [] + for tensor, arg in product(tensors, args): + samples.append(SampleInput(tensor, args=arg)) + + return samples + [vec_sample] + # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952) @@ -1193,6 +1210,12 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): ), assert_autodiffed=True, safe_casts_outputs=True), + OpInfo('diag', + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + sample_inputs_func=sample_inputs_diag, + test_inplace_grad=False), SpectralFuncInfo('fft.fft', aten_name='fft_fft', ref=np.fft.fft, @@ -2575,16 +2598,6 @@ def method_tests(): ('dist', (), ((), 4), 'scalar_4'), ('dist', (S, S, S), ((), 4), 'scalar_4_broadcast_rhs'), ('dist', (), ((S, S, S), 4), 'scalar_4_broadcast_lhs'), - ('diag', (M, M), NO_ARGS, '2d'), - ('diag', (3, 5), NO_ARGS, '2d_wide'), - ('diag', (3, 5), (2,), '2d_wide_pos'), - ('diag', (3, 5), (-2,), '2d_wide_neg'), - ('diag', (5, 3), NO_ARGS, '2d_tall'), - ('diag', (5, 3), (2,), '2d_tall_pos'), - ('diag', (5, 3), (-2,), '2d_tall_neg'), - ('diag', (M,), NO_ARGS, '1d'), - ('diag', (M, M), (1,), '2d_1'), - ('diag', (M, M), (2,), '2d_2'), ('diag_embed', (S, S), NO_ARGS), ('diagonal', (M, M), NO_ARGS, '2d'), ('diagonal', (3, 5), NO_ARGS, '2d_wide'), From f031d10f7eed799b59f4f2f2a42d6405d7b30fbe Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 28 Jan 2021 09:45:17 -0600 Subject: [PATCH 3/3] remove entry from test_autograd --- test/test_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index ca8eb4b7d1d3..df588f97701a 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -5037,7 +5037,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'mean', 'inverse', 'triangular_solve', 'solve', 'addcmul', 'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr', 'narrow', 'swapaxes', 'swapdims', 'tensor_split', 'tile', - 'baddbmm', 'addbmm', 'addmv', 'diag'] + separate_complex_tests + 'baddbmm', 'addbmm', 'addmv'] + separate_complex_tests # deny list for batched grad computation EXCLUDE_BATCHED_GRAD_TESTS = set([