48 changes: 37 additions & 11 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1993,6 +1993,26 @@ def sample_inputs_masked_select(op_info, device, dtype, requires_grad, **kwargs)
    return samples


+def sample_inputs_matmul(op_info, device, dtype, requires_grad):
+    test_cases = (((L,), (L,)),
+                  ((S, M), (M,)),
+                  ((M,), (M, S)),
+                  ((S, M), (M, S)),
+                  ((S, S, M), (M,)),
+                  ((S, S, M), (M, S)),
+                  ((M,), (S, M, S)),
+                  ((S, M), (S, M, S)),
+                  ((S, S, M, M), (S, S, M, S)),
+                  ((S, S, M, M), (M,)),
+                  ((M,), (S, S, M, S)))
+    sample_inputs = []
+    for lhs_shape, rhs_shape in test_cases:
+        lhs = make_tensor(lhs_shape, device, dtype, low=None, high=None, requires_grad=requires_grad)
+        rhs = make_tensor(rhs_shape, device, dtype, low=None, high=None, requires_grad=requires_grad)
+        sample_inputs.append(SampleInput(lhs, args=(rhs,)))
+    return tuple(sample_inputs)


def sample_inputs_polar(op_info, device, dtype, requires_grad, **kwargs):
    def _make_tensor_helper(shape, low=None, high=None):
        return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad)
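As an aside, the shape pairs in `sample_inputs_matmul` above span every matmul dispatch path: vector-vector dot products, matrix-vector and vector-matrix products, plain matrix multiplies, and batched multiplies with broadcasting. A minimal standalone sketch that checks each pair really composes under `torch.matmul` — the concrete values for `L`, `S`, `M` below are assumed stand-ins, not the test file's own constants:

```python
import torch

# Assumed stand-ins for the L, S, M size constants defined in common_methods_invocations.py.
L, S, M = 20, 5, 10

test_cases = (((L,), (L,)),                  # 1d @ 1d -> scalar dot product
              ((S, M), (M,)),                # 2d @ 1d -> matrix-vector
              ((M,), (M, S)),                # 1d @ 2d -> vector-matrix
              ((S, M), (M, S)),              # 2d @ 2d -> plain matmul
              ((S, S, M), (M,)),             # batched matrix @ vector
              ((S, S, M), (M, S)),           # batch broadcast of a 2d rhs
              ((M,), (S, M, S)),             # 1d lhs broadcast against a batch
              ((S, M), (S, M, S)),
              ((S, S, M, M), (S, S, M, S)),  # fully batched 4d @ 4d
              ((S, S, M, M), (M,)),
              ((M,), (S, S, M, S)))

for lhs_shape, rhs_shape in test_cases:
    out = torch.matmul(torch.randn(lhs_shape), torch.randn(rhs_shape))
    print(f"{lhs_shape} @ {rhs_shape} -> {tuple(out.shape)}")
```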
@@ -3239,6 +3259,23 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
           dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_masked_select),
+    OpInfo('matmul',
+           dtypes=floating_types(),
+           dtypesIfCPU=all_types_and_complex(),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
+                                           *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfROCM=floating_types_and(torch.half),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_matmul,
+           skips=(
+               # matmul does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
Collaborator comment: This is correct and doesn't require an issue (the comments alone are enough to track out= behavior, and many ops do not implement their out= behavior properly).
(A sketch of the out= resize warning this skip refers to follows after this hunk.)
+               # https://github.com/pytorch/pytorch/issues/55754
+               SkipInfo('TestGradients', 'test_fn_grad',
+                        device_type='cpu', dtypes=(torch.complex128,)),
+               # https://github.com/pytorch/pytorch/issues/55755
+               SkipInfo('TestOpInfo', 'test_unsupported_dtypes',
+                        device_type='cpu', dtypes=(torch.float16,)),)),
    OpInfo('max',
           op=torch.max,
           variant_test_name='binary',
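For context on the `test_out` skip above: PyTorch's out= convention is that an op should emit a resize warning when the provided `out` tensor has a nonzero number of elements but the wrong shape, and `TestCommon.test_out` checks exactly that. A hedged sketch of the behavior being tracked — illustrative only, since the skip records that matmul's out= path does not reliably warn yet:

```python
import torch
import warnings

a = torch.randn(2, 3)
b = torch.randn(3, 4)
out = torch.empty(2, 2)  # wrong shape with nonzero elements, so a compliant op should warn before resizing

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.matmul(a, b, out=out)

# A compliant out= implementation warns that `out` was resized; the SkipInfo above
# records that matmul does not reliably emit this warning, so we print rather than assert.
print(out.shape)                            # torch.Size([2, 4])
print([str(w.message) for w in caught])     # resize warning expected here once fixed
```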
@@ -4459,17 +4496,6 @@ def method_tests():
        ('mv', (S, M), ((M,),), '', (True,)),
        ('inner', (S,), ((S,),), "1d_1d", (False,)),
        ('inner', (), ((S, S),), "scalar_2d", (False,)),
-        ('matmul', (L,), ((L,),), '', (True,)),
-        ('matmul', (S, M), ((M,),), "2d_1d", (True,)),
-        ('matmul', (M,), ((M, S),), "1d_2d", (True,)),
-        ('matmul', (S, M), ((M, S),), "2d_2d", (True,)),
-        ('matmul', (S, S, M), ((M,),), "3d_1d", (True,)),
-        ('matmul', (S, S, M), ((M, S),), "3d_2d", (True,)),
-        ('matmul', (M,), ((S, M, S),), "1d_3d", (True,)),
-        ('matmul', (S, M), ((S, M, S),), "2d_3d", (True,)),
-        ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d", (True,)),
-        ('matmul', (S, S, M, M), ((M,),), "4d_1d", (True,)),
-        ('matmul', (M,), ((S, S, M, S),), "1d_4d", (True,)),
        ('matrix_exp', (S, S), NO_ARGS, "single_matrix"),
        ('matrix_exp', (S, S, S), NO_ARGS, "batch_of_matrices"),
        ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),