Skip to content

Commit

Permalink
OpInfos for torch.atleast_{1d, 2d, 3d} (#67355)
Browse files Browse the repository at this point in the history
Summary:

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D32649416

Pulled By: anjali411

fbshipit-source-id: 1b42e86c7124427880fff52fbe490481059da967

[ghstack-poisoned]
  • Loading branch information
PaliC committed Nov 30, 2021
1 parent 4483b4f commit bac03b7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
3 changes: 3 additions & 0 deletions test/test_fx_experimental.py
Expand Up @@ -1531,6 +1531,9 @@ def test_normalize_operator_exhaustive(self, device, dtype, op):
'__ror__',
'__rxor__',
"__rmatmul__",
"atleast_1d",
"atleast_2d",
"atleast_3d",
}

# Unsupported input types
Expand Down
49 changes: 49 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Expand Up @@ -6601,6 +6601,16 @@ def generator():

return list(generator())

def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs):
input_list = []
shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),)
make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
samples = []
for shape in shapes:
input_list.append(make_tensor_partial(shape))
samples.append(SampleInput(make_tensor_partial(shape)))
samples.append(SampleInput(input_list, ))
return samples

def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
Expand Down Expand Up @@ -11678,6 +11688,45 @@ def ref_pairwise_distance(input1, input2):
supports_forward_ad=True,
sample_inputs_func=sample_inputs_view_as_reshape_as,
),
OpInfo('atleast_1d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_atleast1d2d3d,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
),
OpInfo('atleast_2d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_atleast1d2d3d,
),
OpInfo('atleast_3d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_atleast1d2d3d,
),
OpInfo('pinverse',
op=torch.pinverse,
dtypes=floating_and_complex_types(),
Expand Down

0 comments on commit bac03b7

Please sign in to comment.