From 0436ea125b3557e6855fb2f36de0aeb8623555b2 Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Fri, 22 Jan 2021 09:30:18 -0800
Subject: [PATCH] OpInfo: Remove promotes_integers_to_float and infer it instead (#50279)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50279

This allows different sample inputs to have different behavior for the same
operator. For example, `div(..., rounding_mode='true')` will promote but
other rounding modes don't. The current boolean flag is too restrictive to
allow this.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25950011

Pulled By: mruberry

fbshipit-source-id: 7e82b82bedc626b2b6970d92d5b25676183ec384
---
 test/test_ops.py                                  |  3 +-
 test/test_unary_ufuncs.py                         | 56 ++++----------
 .../_internal/common_methods_invocations.py       | 74 +++++++++++--------
 torch/testing/_internal/common_utils.py           |  2 +-
 4 files changed, 60 insertions(+), 75 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 0c57cae8721f..26a3ee69f95a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -221,8 +221,7 @@ def test_variant_consistency_eager(self, device, dtype, op):
         for variant in variants:
             # Verifies that inplace operations that promote int->float fail
             # on tensors with integer dtypes.
-            if (variant is inplace and op.promotes_integers_to_float and
-                    dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)):
+            if (variant is inplace and not torch.can_cast(expected_forward.dtype, dtype)):
                 try:
                     variant_forward = variant(*(clone_input_helper(input) for input in sample.input),
                                               *sample.args,
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 639e32427060..365c33179206 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -10,8 +10,8 @@
 from torch._six import inf, nan
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, torch_to_numpy_dtype_dict, suppress_warnings,
-    IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy)
+    TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict,
+    suppress_warnings, IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy)
 from torch.testing._internal.common_methods_invocations import (
     unary_ufuncs)
 from torch.testing._internal.common_device_type import (
@@ -19,7 +19,7 @@
     onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU,
     OpDTypes)
 from torch.testing import (
-    floating_types_and, integral_types, all_types_and_complex_and, floating_types)
+    floating_types_and, all_types_and_complex_and, floating_types)
 
 if TEST_SCIPY:
     import scipy
@@ -214,7 +214,7 @@ def _fn(t):
             if alt is None:
                 continue
 
-            if inplace and op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,):
+            if inplace and not torch.can_cast(expected.dtype, dtype):
                 # Assert that RuntimeError is raised
                 # for inplace variant of Operators that
                 # promote integer input to floating dtype.
@@ -285,7 +285,7 @@ def test_reference_numerics(self, device, dtype, op):
             msg = None
             exact_dtype = True
-            if op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,):
+            if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype):
                 exact_dtype = False
 
                 if dtype in [torch.uint8, torch.int8, torch.bool]:
@@ -402,53 +402,29 @@ def test_batch_vs_slicing(self, device, dtype, op):
         self.assertEqual(actual, expected)
 
-    def _test_out_arg(self, op, input, output):
-        dtype = input.dtype
-        out_dtype = output.dtype
-        if dtype is out_dtype:
-            expected = op(input)
-            op(input, out=output)
-            self.assertEqual(output, expected)
+    def _test_out_arg(self, op, input, output, expected):
+        if op.safe_casts_outputs:
+            expect_fail = not torch.can_cast(expected.dtype, output.dtype)
         else:
-            with self.assertRaises(RuntimeError):
-                op(input, out=output)
+            expect_fail = output.dtype != expected.dtype
 
-    def _test_out_promote_int_to_float_op(self, op, input, output):
-        def compare_out(op, input, out):
-            out_dtype = out.dtype
-            expected = op(input)
-            op(input, out=out)
-            self.assertEqual(out, expected.to(out_dtype))
-
-        dtype = input.dtype
-        out_dtype = output.dtype
-        if out_dtype.is_floating_point and not dtype.is_complex:
-            compare_out(op, input, output)
-        elif out_dtype.is_floating_point and dtype.is_complex:
-            if op.supports_complex_to_float:
-                compare_out(op, input, output)
-            else:
-                # Can't cast complex to float
-                with self.assertRaises(RuntimeError):
-                    op(input, out=output)
-        elif out_dtype.is_complex:
-            compare_out(op, input, output)
-        else:
-            # Can't cast to Integral types
+        if expect_fail:
             with self.assertRaises(RuntimeError):
                 op(input, out=output)
+        else:
+            res = op(input, out=output)
+            self.assertTrue(res is output)
+            self.assertEqual(output, expected.to(output.dtype))
 
     @ops(unary_ufuncs, dtypes=OpDTypes.supported)
     def test_out_arg_all_dtypes(self, device, dtype, op):
         input = make_tensor((64, 64), dtype=dtype, device=device,
                             low=op.domain[0], high=op.domain[1])
+        expected = op(input)
 
         for out_dtype in all_types_and_complex_and(torch.bool, torch.half):
            out = torch.empty_like(input, dtype=out_dtype)
-            if op.promotes_integers_to_float:
-                self._test_out_promote_int_to_float_op(op, input, out)
-            else:
-                self._test_out_arg(op, input, out)
+            self._test_out_arg(op, input, out, expected)
 
     @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool] +
               torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 2097ff49c72b..39308bc58e4f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -60,6 +60,16 @@ def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=N
         self.kwargs = kwargs if kwargs is not None else {}
         self.output_process_fn_grad = output_process_fn_grad
 
+    def __repr__(self):
+        arguments = [
+            f'input[{len(self.input)}]',
+            f'args={self.args}' if len(self.args) > 0 else None,
+            f'kwargs={self.kwargs}' if len(self.kwargs) > 0 else None,
+            (f'output_process_fn_grad={self.output_process_fn_grad}'
+             if self.output_process_fn_grad is not None else None)]
+
+        return f'SampleInput({", ".join(a for a in arguments if a is not None)})'
+
 
 _NOTHING = object()  # Unique value to distinguish default from anything else
 
@@ -107,7 +117,7 @@ def __init__(self,
                  supports_tensor_out=True,  # whether the op supports the out kwarg, returning a Tensor
                  skips=tuple(),  # information about which tests to skip
                  decorators=None,  # decorators to apply to generated tests
-                 promotes_integers_to_float=False,  # whether op promotes unary output to float or not
+                 safe_casts_outputs=False,  # whether op allows safe casting when writing to out arguments
                  sample_inputs_func=None,  # function to generate sample inputs
                  aten_name=None,  # name of the corresponding aten:: operator
                  variant_test_name='',  # additional string to include in the test name
@@ -141,7 +151,7 @@ def __init__(self,
         self.test_inplace_grad = test_inplace_grad
         self.test_complex_grad = test_complex_grad
         self.supports_tensor_out = supports_tensor_out
-        self.promotes_integers_to_float = promotes_integers_to_float
+        self.safe_casts_outputs = safe_casts_outputs
 
         self.skips = skips
         self.decorators = decorators
@@ -959,7 +969,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    decorators=(precisionOverride({torch.float16: 1e-2,
                                                   torch.bfloat16: 1e-1,
                                                   torch.complex64: 1e-2}),),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
@@ -979,7 +989,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypes=all_types_and_complex_and(torch.bool),
                    dtypesIfCPU=all_types_and_complex_and(torch.bool),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    test_inplace_grad=False,
                    skips=(
@@ -1026,7 +1036,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    domain=(-1, 1),
                    supports_sparse=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    dtypes=all_types_and_complex_and(torch.bool),
                    dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
@@ -1045,7 +1055,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypes=all_types_and_complex_and(torch.bool),
                    dtypesIfCPU=all_types_and_complex_and(torch.bool),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    test_inplace_grad=False,
                    skips=(
@@ -1066,7 +1076,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
@@ -1080,7 +1090,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypes=all_types_and_complex_and(torch.bool),
                    dtypesIfCPU=all_types_and_complex_and(torch.bool),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    test_inplace_grad=False,
                    skips=(
@@ -1103,7 +1113,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True,
                    handles_large_floats=False,
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
@@ -1115,7 +1125,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh),
                    dtypesIfCPU=all_types_and_complex_and(torch.bool),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    assert_autodiffed=True,
                    skips=(
                        # Reference: https://github.com/pytorch/pytorch/issues/48641
@@ -1141,7 +1151,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
                    ),
                    assert_autodiffed=True,
-                   promotes_integers_to_float=True),
+                   safe_casts_outputs=True),
     SpectralFuncInfo('fft.fft',
                      aten_name='fft_fft',
                      ref=np.fft.fft,
@@ -1288,7 +1298,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True,
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
@@ -1308,7 +1318,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True,
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
@@ -1322,7 +1332,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True),
     UnaryUfuncInfo('log2',
@@ -1333,7 +1343,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True,
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
@@ -1357,7 +1367,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    skip_bfloat16_grad=True,
                    handles_large_floats=False,
                    handles_complex_extremals=False,
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
@@ -1371,7 +1381,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    skip_bfloat16_grad=True,
                    handles_large_floats=False,
                    handles_complex_extremals=False,
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    decorators=(precisionOverride({torch.bfloat16: 1e-2,
                                                   torch.float16: 1e-2}),),
                    skips=(
@@ -1385,7 +1395,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
                    dtypesIfCPU=all_types_and_complex_and(torch.bool),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    assert_autodiffed=True,
                    decorators=(precisionOverride({torch.float16: 1e-2}),),
                    skips=(
@@ -1408,7 +1418,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True,
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
@@ -1429,7 +1439,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True,
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    skips=(
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
@@ -1449,13 +1459,13 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypes=all_types_and(torch.bool, torch.half),
                    dtypesIfCPU=all_types_and(torch.bool, torch.half),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half),
-                   promotes_integers_to_float=True),
+                   safe_casts_outputs=True),
     UnaryUfuncInfo('expm1',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
                    dtypes=all_types_and(torch.bool, torch.half),
                    dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    assert_autodiffed=True,
                    skips=(
                        # Reference: https://github.com/pytorch/pytorch/pull/48926#issuecomment-739734774
@@ -1474,7 +1484,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypesIfCUDA=None,
                    assert_autodiffed=True,
                    skip_bfloat16_grad=True,
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    skips=(
                        # Reference: https://github.com/pytorch/pytorch/issues/45690
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
@@ -1490,7 +1500,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypesIfCPU=all_types_and_complex_and(torch.bool),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
                    decorators=(precisionOverride({torch.half: 5e-2}),),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    assert_autodiffed=True,
                    handles_complex_extremals=False),
     UnaryUfuncInfo('sqrt',
@@ -1511,7 +1521,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                        # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 dtypes=[torch.bfloat16])),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    handles_complex_extremals=False),
     OpInfo('linalg.inv',
            aten_name='linalg_inv',
@@ -1530,7 +1540,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
                    dtypesIfROCM=all_types_and_complex_and(torch.bool),
                    decorators=(precisionOverride({torch.float16: 1e-2,
                                                   torch.bfloat16: 1e-2}),),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    supports_complex_to_float=True,
                    test_inplace_grad=False),
     OpInfo('linalg.solve',
@@ -1709,7 +1719,7 @@ def reference_sigmoid(x):
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    assert_autodiffed=True,
                    test_complex_grad=False),  # Reference: https://github.com/pytorch/pytorch/issues/48552
     UnaryUfuncInfo('digamma',
@@ -1724,7 +1734,7 @@ def reference_sigmoid(x):
                        # in float16 and NaN's can't be tested for equality.
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 device_type='cuda', dtypes=[torch.float16]),),
-                   promotes_integers_to_float=True),
+                   safe_casts_outputs=True),
     UnaryUfuncInfo('erf',
                    ref=scipy.special.erf,
                    decorators=(precisionOverride({torch.float16: 1e-2,
@@ -1737,7 +1747,7 @@ def reference_sigmoid(x):
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 dtypes=[torch.bfloat16]),),
                    assert_autodiffed=True,
-                   promotes_integers_to_float=True),
+                   safe_casts_outputs=True),
     UnaryUfuncInfo('erfc',
                    ref=scipy.special.erfc,
                    decorators=(precisionOverride({torch.float16: 1e-2,
@@ -1750,7 +1760,7 @@ def reference_sigmoid(x):
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 dtypes=[torch.bfloat16]),),
                    assert_autodiffed=True,
-                   promotes_integers_to_float=True),
+                   safe_casts_outputs=True),
     UnaryUfuncInfo('erfinv',
                    ref=scipy.special.erfinv,
                    decorators=(precisionOverride({torch.float16: 1e-2,
@@ -1759,7 +1769,7 @@ def reference_sigmoid(x):
                    dtypes=all_types_and(torch.bool),
                    dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half),
-                   promotes_integers_to_float=True,
+                   safe_casts_outputs=True,
                    domain=(-1, 1),
                    skips=(
                        # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
@@ -1776,7 +1786,7 @@ def reference_sigmoid(x):
           dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
           test_inplace_grad=True,
           supports_tensor_out=True,
-          promotes_integers_to_float=True,
+          safe_casts_outputs=True,
           sample_inputs_func=sample_inputs_xlogy),
 ]
 op_db = op_db + op_db_scipy_reference
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 239355a0bdda..a4ae8b77fe87 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -382,7 +382,7 @@ def _check_module_exists(name):
 
 # Dict of NumPy dtype -> torch dtype (when the correspondence exists)
 numpy_to_torch_dtype_dict = {
-    np.bool : torch.bool,
+    np.bool_ : torch.bool,
     np.uint8 : torch.uint8,
     np.int8 : torch.int8,
     np.int16 : torch.int16,
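
Note (not part of the patch): a minimal sketch of the inference the updated tests rely on, assuming a recent PyTorch build; torch.sqrt is used here only as a stand-in unary ufunc. Whether an op promotes integer inputs is read off the eager result's dtype and checked with torch.can_cast, instead of the removed promotes_integers_to_float flag.

    import torch

    # Hypothetical illustration: infer promotion from the computed result's dtype.
    x = torch.arange(1, 5, dtype=torch.int64)
    expected = torch.sqrt(x)          # integer input promotes to the default float dtype
    print(expected.dtype)             # torch.float32

    # In-place variants must fail when the result can't be cast back to the input dtype.
    print(torch.can_cast(expected.dtype, x.dtype))    # False -> x.sqrt_() raises RuntimeError

    # Writing into a wider out= tensor is a safe cast and is expected to succeed.
    out = torch.empty_like(x, dtype=torch.float64)
    print(torch.can_cast(expected.dtype, out.dtype))  # True -> torch.sqrt(x, out=out) is allowed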