Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpInfos for torch.atleast_{1d, 2d, 3d} #67355

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions test/test_fx_experimental.py
Expand Up @@ -1531,6 +1531,9 @@ def test_normalize_operator_exhaustive(self, device, dtype, op):
'__ror__',
'__rxor__',
"__rmatmul__",
"atleast_1d",
"atleast_2d",
"atleast_3d",
}

# Unsupported input types
Expand Down
49 changes: 49 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Expand Up @@ -6601,6 +6601,16 @@ def generator():

return list(generator())

def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs):
input_list = []
shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),)
make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
samples = []
for shape in shapes:
input_list.append(make_tensor_partial(shape))
samples.append(SampleInput(make_tensor_partial(shape)))
samples.append(SampleInput(input_list, ))
return samples

def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
Expand Down Expand Up @@ -11678,6 +11688,45 @@ def ref_pairwise_distance(input1, input2):
supports_forward_ad=True,
sample_inputs_func=sample_inputs_view_as_reshape_as,
),
OpInfo('atleast_1d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_atleast1d2d3d,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
),
OpInfo('atleast_2d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_atleast1d2d3d,
),
OpInfo('atleast_3d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_atleast1d2d3d,
),
OpInfo('pinverse',
op=torch.pinverse,
dtypes=floating_and_complex_types(),
Expand Down