Skip to content

Commit

Permalink
[complex] Enable complex autograd tests for diag (#51268)
Browse files Browse the repository at this point in the history
Summary:
Reference: #33152

Pull Request resolved: #51268

Reviewed By: pbelevich

Differential Revision: D26179236

Pulled By: anjali411

fbshipit-source-id: e9756136eaaced5a8692228a158965f77505e7b9
  • Loading branch information
kshitij12345 authored and facebook-github-bot committed Feb 2, 2021
1 parent 43084d7 commit c39fb97
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
3 changes: 1 addition & 2 deletions tools/autograd/gen_variable_type.py
Expand Up @@ -89,8 +89,7 @@
'reflection_pad1d_backward', 'reflection_pad2d_backward',
'replication_pad1d', 'replication_pad2d', 'replication_pad3d',
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
'masked_scatter', 'masked_select',
'index_fill',
'diag', 'masked_scatter', 'masked_select', 'index_fill'
}

# Some operators invalidate the grad_accumulator. Let's reset it.
Expand Down
33 changes: 23 additions & 10 deletions torch/testing/_internal/common_methods_invocations.py
Expand Up @@ -1114,6 +1114,23 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
)
return [SampleInput(tensor) for tensor in tensors]

def sample_inputs_diag(op_info, device, dtype, requires_grad):
    """Generate SampleInputs for torch.diag.

    Covers the 1-D case (vector -> diagonal matrix) and 2-D cases
    (square M x M, wide 3 x 5, and tall 5 x 3 matrices), each paired
    with the default diagonal and positive/negative offsets.
    """
    # 1-D input: diag builds a matrix whose main diagonal is the vector.
    vec_sample = SampleInput(make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad))

    # 2-D inputs: square, wide (rows < cols), and tall (rows > cols).
    tensors = (
        make_tensor((M, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
        make_tensor((3, 5), device, dtype, low=None, high=None, requires_grad=requires_grad),
        make_tensor((5, 3), device, dtype, low=None, high=None, requires_grad=requires_grad),
    )

    # Diagonal offsets: default (), above (+) and below (-) the main diagonal.
    # NOTE: the original tuple listed (2,) twice, producing duplicate samples
    # for every tensor; deduplicated here.
    args = ((), (2,), (-2,), (1,))

    samples = [SampleInput(tensor, args=arg) for tensor, arg in product(tensors, args)]

    return samples + [vec_sample]

def sample_inputs_logit(op_info, device, dtype, requires_grad):
low, high = op_info.domain

Expand Down Expand Up @@ -1432,6 +1449,12 @@ def sample_inputs_masked_select(op_info, device, dtype, requires_grad):
),
assert_autodiffed=True,
safe_casts_outputs=True),
OpInfo('diag',
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCPU=all_types_and_complex_and(torch.bool),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
sample_inputs_func=sample_inputs_diag,
test_inplace_grad=False),
SpectralFuncInfo('fft.fft',
aten_name='fft_fft',
ref=np.fft.fft,
Expand Down Expand Up @@ -2865,16 +2888,6 @@ def method_tests():
('dist', (), ((), 4), 'scalar_4'),
('dist', (S, S, S), ((), 4), 'scalar_4_broadcast_rhs'),
('dist', (), ((S, S, S), 4), 'scalar_4_broadcast_lhs'),
('diag', (M, M), NO_ARGS, '2d'),
('diag', (3, 5), NO_ARGS, '2d_wide'),
('diag', (3, 5), (2,), '2d_wide_pos'),
('diag', (3, 5), (-2,), '2d_wide_neg'),
('diag', (5, 3), NO_ARGS, '2d_tall'),
('diag', (5, 3), (2,), '2d_tall_pos'),
('diag', (5, 3), (-2,), '2d_tall_neg'),
('diag', (M,), NO_ARGS, '1d'),
('diag', (M, M), (1,), '2d_1'),
('diag', (M, M), (2,), '2d_2'),
('diag_embed', (S, S), NO_ARGS),
('diagonal', (M, M), NO_ARGS, '2d'),
('diagonal', (3, 5), NO_ARGS, '2d_wide'),
Expand Down

0 comments on commit c39fb97

Please sign in to comment.