def sample_inputs_diag(op_info, device, dtype, requires_grad):
    """Build the list of ``SampleInput`` objects used to exercise ``diag``.

    Three 2-D tensors (square ``(M, M)``, wide ``(3, 5)`` and tall ``(5, 3)``)
    are each paired with every entry of ``args``; a single 1-D tensor of
    shape ``(M,)`` with no extra arguments is appended at the end.

    Args:
        op_info: Accepted to match the signature of the other sample-input
            functions; not read by this function.
        device: Forwarded to ``make_tensor`` for every tensor created.
        dtype: Forwarded to ``make_tensor`` for every tensor created.
        requires_grad: Forwarded to ``make_tensor`` for every tensor created.

    Returns:
        A list of ``SampleInput``: the 2-D cases first (in ``product`` order),
        followed by the 1-D case.
    """
    def _make(shape):
        # Single place for the construction options shared by every sample.
        return make_tensor(shape, device, dtype, low=None, high=None,
                           requires_grad=requires_grad)

    # Built before the 2-D tensors so the order in which tensors are created
    # stays the same as before.
    vec_sample = SampleInput(_make((M,)))

    tensors = tuple(_make(shape) for shape in ((M, M), (3, 5), (5, 3)))

    # Positional arguments passed after the input tensor; () means "no extra
    # arguments". Presumably the value is the diagonal offset -- confirm
    # against the ``diag`` signature. Each entry is listed once: a repeated
    # entry would only generate duplicate samples.
    args = ((), (2,), (-2,), (1,))

    samples = [SampleInput(tensor, args=arg)
               for tensor, arg in product(tensors, args)]

    return samples + [vec_sample]
dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + sample_inputs_func=sample_inputs_diag, + test_inplace_grad=False), SpectralFuncInfo('fft.fft', aten_name='fft_fft', ref=np.fft.fft, @@ -2865,16 +2888,6 @@ def method_tests(): ('dist', (), ((), 4), 'scalar_4'), ('dist', (S, S, S), ((), 4), 'scalar_4_broadcast_rhs'), ('dist', (), ((S, S, S), 4), 'scalar_4_broadcast_lhs'), - ('diag', (M, M), NO_ARGS, '2d'), - ('diag', (3, 5), NO_ARGS, '2d_wide'), - ('diag', (3, 5), (2,), '2d_wide_pos'), - ('diag', (3, 5), (-2,), '2d_wide_neg'), - ('diag', (5, 3), NO_ARGS, '2d_tall'), - ('diag', (5, 3), (2,), '2d_tall_pos'), - ('diag', (5, 3), (-2,), '2d_tall_neg'), - ('diag', (M,), NO_ARGS, '1d'), - ('diag', (M, M), (1,), '2d_1'), - ('diag', (M, M), (2,), '2d_2'), ('diag_embed', (S, S), NO_ARGS), ('diagonal', (M, M), NO_ARGS, '2d'), ('diagonal', (3, 5), NO_ARGS, '2d_wide'),