From 1cc6780eae615765f500c21c6acf6cfa9e9c3536 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 8 Feb 2022 09:03:24 +0100 Subject: [PATCH 01/10] only compare attributes for meta tensors [ghstack-poisoned] --- torch/testing/_comparison.py | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 993b3d1d5cb9f..80437f1f82f42 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -598,29 +598,12 @@ def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None def compare(self) -> None: actual, expected = self.actual, self.expected - with self._handle_meta_tensor_data_access(): - self._compare_attributes(actual, expected) - actual, expected = self._equalize_attributes(actual, expected) - - self._compare_values(actual, expected) - - @contextlib.contextmanager - def _handle_meta_tensor_data_access(self): - """Turns a vanilla :class:`NotImplementedError` stemming from data access on a meta tensor into an expressive - :class:`ErrorMeta`. - - Although it looks like meta tensors could be handled upfront, we need to do it lazily: there are use cases - where a meta tensor wraps a data tensors and dispatches all operator calls to it. Thus, although the tensor is - a meta tensor, it behaves like a regular one. - """ - try: - yield - except NotImplementedError as error: - if "meta" not in str(error).lower(): - raise error + self._compare_attributes(actual, expected) + if any(input.device.type == "meta" for input in (actual, expected)): + return - # TODO: See https://github.com/pytorch/pytorch/issues/68592 - raise self._make_error_meta(NotImplementedError, "Comparing meta tensors is currently not supported.") + actual, expected = self._equalize_attributes(actual, expected) + self._compare_values(actual, expected) def _compare_attributes( self, From d96cadff2c6b68c2be772b3ad3edf545d2520fa4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 8 Feb 2022 14:52:18 +0100 Subject: [PATCH 02/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [ ] add tests [ghstack-poisoned] --- test/test_binary_ufuncs.py | 3 ++- test/test_testing.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 0bd2a9e4d527b..ba37096016fca 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -20,7 +20,7 @@ from torch.testing._internal.common_device_type import ( expectedFailureMeta, instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyNativeDeviceTypes, - skipCUDAIfRocm, skipIf, ops, OpDTypes) + skipCUDAIfRocm, skipIf, ops, OpDTypes, skipMeta) from torch.testing import make_tensor from torch.testing._internal.common_dtype import ( all_types_and_complex_and, integral_types_and, get_all_dtypes, get_all_int_dtypes, get_all_math_dtypes, @@ -3393,6 +3393,7 @@ def test_empty_x(sizes, dim, x, device): TypeError, 'received an invalid combination of arguments'): actual = torch.cumulative_trapezoid(torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3) + @skipMeta @dtypes(torch.double) def test_pow_scalar_overloads_mem_overlap(self, device, dtype): sz = 3 diff --git a/test/test_testing.py b/test/test_testing.py index 3cfef8cee395c..1fe06a2293400 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -574,11 +574,10 @@ def test_unknown_layout(self): def test_meta(self): actual = torch.empty((2, 2), device="meta") - expected = actual.clone() + expected = torch.empty((2, 2), device="meta") for fn in assert_close_with_inputs(actual, expected): - with self.assertRaisesRegex(NotImplementedError, "meta"): - fn() + fn() def test_mismatching_layout(self): strided = torch.empty((2, 2)) From 93c7d028cc0c2352a55d50149b18960110f66ff9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 8 Feb 2022 15:17:58 +0100 Subject: [PATCH 03/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [x] add tests [ghstack-poisoned] --- test/test_binary_ufuncs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index ba37096016fca..b15fb21944f01 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1497,6 +1497,7 @@ def test_complex_scalar_pow_tensor(self, device, dtype): self._test_pow(base, second_exp) @onlyNativeDeviceTypes + @skipMeta def test_pow_scalar_type_promotion(self, device): # Test against a scalar and non-scalar input inputs = [17, [17]] From 52622c06bc272860acf5983ac5718976fd38cedb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 8 Feb 2022 20:03:45 +0100 Subject: [PATCH 04/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [x] add tests [ghstack-poisoned] --- test/test_modules.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_modules.py b/test/test_modules.py index 448f8f5fa7518..b3d658a5bc5dc 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -8,7 +8,7 @@ import torch from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol) + instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta) from torch.testing._internal.common_modules import module_db, modules from torch.testing._internal.common_utils import ( TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck) @@ -233,6 +233,7 @@ def test_pickle(self, device, dtype, module_info): @modules([module_info for module_info in module_db if 'inplace' in signature(module_info.module_cls).parameters]) + @skipMeta def test_check_inplace(self, device, dtype, module_info): # Check if the inplace variant of the module gives the same result as the out of place # variant. From 704252c0e3a422ce42d13b1c0425d3b5795b4a8b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 09:04:24 +0100 Subject: [PATCH 05/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [x] add tests [ghstack-poisoned] --- .../_internal/common_methods_invocations.py | 65 ++++++++++++++----- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index ebb8be16fb36c..d078b2d93d6b7 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14415,9 +14415,21 @@ def ref_pairwise_distance(input1, input2): # These paths have different dtype support. Also JIT supports, # most variants but not all of them. So we split the OpInfo entries, # for `norm` based on the code-paths and JIT support. - OpInfo('norm', - sample_inputs_func=sample_inputs_norm, - dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16)), + OpInfo( + "norm", + sample_inputs_func=sample_inputs_norm, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + skips=( + # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result + # of dtype torch.float32 into an out= with dtype torch.long + DecorateInfo( + unittest.expectedFailure, + "TestCommon", + "test_out", + device_type="meta", + ), + ), + ), OpInfo('norm', variant_test_name='nuc', sample_inputs_func=sample_inputs_norm_nuc, @@ -14452,19 +14464,40 @@ def ref_pairwise_distance(input1, input2): # Arguments for call are not valid. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.complex64, torch.float32,)), # noqa: B950 )), - OpInfo('norm', - variant_test_name='inf', - sample_inputs_func=sample_inputs_norm_inf, - dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), - backward_dtypesIfCPU=floating_and_complex_types_and(torch.float16, torch.bfloat16), - skips=( - # https://github.com/pytorch/pytorch/issues/67517 - DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), - # following 2 tests failed intermittenly - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad', device_type='cpu', dtypes=(torch.complex128,)), # noqa: B950 - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad', device_type='cpu', dtypes=(torch.complex128,)), # noqa: B950 - ) - ), + OpInfo( + "norm", + variant_test_name="inf", + sample_inputs_func=sample_inputs_norm_inf, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + backward_dtypesIfCPU=floating_and_complex_types_and(torch.float16, torch.bfloat16), + skips=( + # https://github.com/pytorch/pytorch/issues/67517 + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples"), + # following 2 tests failed intermittenly + DecorateInfo( + unittest.skip("Skipped!"), + "TestGradients", + "test_fn_grad", + device_type="cpu", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestGradients", + "test_fn_gradgrad", + device_type="cpu", + dtypes=(torch.complex128,), + ), + # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result + # of dtype torch.float32 into an out= with dtype torch.long + DecorateInfo( + unittest.expectedFailure, + "TestCommon", + "test_out", + device_type="meta", + ), + ), + ), OpInfo('t', sample_inputs_func=sample_inputs_t, supports_out=False, From 6cce9614ebb867c97b9ac198b695c99e17335f7c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 14:31:16 +0100 Subject: [PATCH 06/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [x] add tests [ghstack-poisoned] --- test/test_tensor_creation_ops.py | 55 ++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 62d595373b3ac..cdcd693b85838 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import ( TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, torch_to_numpy_dtype_dict, slowTest, - TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS) + TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize) from torch.testing._internal.common_device_type import ( expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes, onlyCPU, largeTensorTest, precisionOverride, dtypes, @@ -2786,36 +2786,43 @@ def test_tensor_ctor_device_inference(self, device): sparse_size, dtype=torch.float64) self.assertEqual(sparse_with_dtype.device, torch.device('cpu')) + def _test_signal_window_functions(self, name, dtype, device, **kwargs): + import scipy.signal as signal + + torch_method = getattr(torch, name + '_window') + if not dtype.is_floating_point: + with self.assertRaisesRegex(RuntimeError, r'floating point'): + torch_method(3, dtype=dtype) + return + for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]: + for periodic in [True, False]: + res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype) + # NB: scipy always returns a float64 result + ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic)) + self.assertEqual(res, ref, exact_dtype=False) + with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'): + torch_method(3, layout=torch.sparse_coo) + self.assertTrue(torch_method(3, requires_grad=True).requires_grad) + self.assertFalse(torch_method(3).requires_grad) + @onlyNativeDeviceTypes @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3}) @unittest.skipIf(not TEST_SCIPY, "Scipy not found") @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long) @dtypes(torch.float, torch.double, torch.long) - def test_signal_window_functions(self, device, dtype): - import scipy.signal as signal - - def test(name, kwargs): - torch_method = getattr(torch, name + '_window') - if not dtype.is_floating_point: - with self.assertRaisesRegex(RuntimeError, r'floating point'): - torch_method(3, dtype=dtype) - return - for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]: - for periodic in [True, False]: - res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype) - # NB: scipy always returns a float64 result - ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic)) - self.assertEqual(res, ref, exact_dtype=False) - with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'): - torch_method(3, layout=torch.sparse_coo) - self.assertTrue(torch_method(3, requires_grad=True).requires_grad) - self.assertFalse(torch_method(3).requires_grad) - - for window in ['hann', 'hamming', 'bartlett', 'blackman']: - test(window, kwargs={}) + @parametrize("window", ['hann', 'hamming', 'bartlett', 'blackman']) + def test_signal_window_functions(self, device, dtype, window): + self._test_signal_window_functions(window, dtype, device) + @onlyNativeDeviceTypes + @expectedFailureMeta + @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3}) + @unittest.skipIf(not TEST_SCIPY, "Scipy not found") + @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long) + @dtypes(torch.float, torch.double, torch.long) + def test_kaiser_window(self, device, dtype): for num_test in range(50): - test('kaiser', kwargs={'beta': random.random() * 30}) + self._test_signal_window_functions('kaiser', dtype, device, beta=random.random() * 30) def test_tensor_factories_empty(self, device): # ensure we can create empty tensors from each factory function From 989e81d07fc7c4146cb292361f9eef05b1467950 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 23:27:35 +0100 Subject: [PATCH 07/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [x] add tests [ghstack-poisoned] --- test/test_tensor_creation_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index cdcd693b85838..68ddec1471189 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -2815,7 +2815,8 @@ def test_signal_window_functions(self, device, dtype, window): self._test_signal_window_functions(window, dtype, device) @onlyNativeDeviceTypes - @expectedFailureMeta + # See https://github.com/pytorch/pytorch/issues/72630 + @skipMeta @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3}) @unittest.skipIf(not TEST_SCIPY, "Scipy not found") @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long) From d7aa62d4868a00fb78ce7e66c347864ef403e34a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 09:02:51 +0100 Subject: [PATCH 08/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [x] add tests [ghstack-poisoned] --- test/test_torch.py | 274 ++++++++++++++++++++++----------------------- 1 file changed, 137 insertions(+), 137 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 164e6585f1642..3c8a74c1d1f9a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -37,7 +37,7 @@ skipCUDAMemoryLeakCheckIf, BytesIOContext, noarchTest, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, - skipIfNotRegistered, bytes_to_scalar) + skipIfNotRegistered, bytes_to_scalar, parametrize) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( expectedFailureMeta, @@ -793,158 +793,158 @@ def test_is_set_to(self, device): self.assertFalse(t1.is_set_to(t2)) self.assertFalse(t2.is_set_to(t1)) - def test_broadcast(self, device): - - # all functions - fns = { - "dist", "atan2", "pow", "lerp", "add", - "sub", "mul", "div", "fmod", "remainder", - "eq", "ge", "gt", "le", "lt", "max", "min", "ne", - "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", - "map", "map2", "copy" - } + # See https://github.com/pytorch/pytorch/issues/72650 + @skipMeta + @parametrize( + "fn", + [ + "dist", "atan2", "pow", "lerp", "add", "sub", "mul", "div", "fmod", "remainder", "eq", "ge", "gt", "le", + "lt", "max", "min", "ne", "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", "map", + "map2", "copy", + ], + ) + def test_broadcast(self, fn, device): # functions with three tensor arguments fns_3_args = {"map2"} fns_value_kwarg = {"addcdiv", "addcmul"} - for fn in fns: - (dims_small, dims_large, dims_full) = self._select_broadcastable_dims() - full1d = torch.randn(*dims_full, device=device).flatten().float() - small = torch.randn(*dims_small, device=device).float() - large = torch.randn(*dims_large, device=device).float() - small_expanded = small.expand(*dims_full) - large_expanded = large.expand(*dims_full) - small2 = None - small2_expanded = None - if fn in fns_3_args or fn in fns_value_kwarg: - # create another smaller tensor - (dims_small2, _, _) = self._select_broadcastable_dims(dims_full) - small2 = torch.randn(*dims_small2, device=device).float() - small2_expanded = small2.expand(*dims_full) - - if small.is_cuda and fn in ['map', 'map2']: - # map and map2 are not implementd on CUDA tensors - continue - - if hasattr(large_expanded, fn): - # run through tensor versions of functions - # and verify fully expanded inputs give same results - expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} - - def tensorfn(myfn, t1, t2): - if fn == "lerp": - return myfn(t1, 0.5) - elif fn == "masked_select": - return myfn(t1 < 0) - elif fn == "masked_scatter": - return myfn(t1 < 0.5, full1d) - elif fn == "masked_fill": - return myfn(t1 < 0.5, 1.0) - elif fn in fns_3_args: - return myfn(1, t1, t2) - elif fn in fns_value_kwarg: - return myfn(t1, t2, value=1) - else: - return myfn(t1) - - # test various orders - for first, second, third in [(large, small, small2), (small, large, small2), - (small2, small, large), (small2, large, small)]: - if first is None: - break # ignore last iter when small2 is None - method_expanded = getattr(expanded[first], fn) - method = getattr(first, fn) - r1 = tensorfn(method_expanded, expanded[second], expanded[third]) - r2 = tensorfn(method, second, third) - self.assertEqual(r1, r2) - - # now for torch. versions of functions - if hasattr(torch, fn): - fntorch = getattr(torch, fn) - expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} - - def torchfn(t1, t2, t3): - if fn == "lerp": - return fntorch(t1, t2, 0.5) - elif fn == "masked_select": - return fntorch(t1, t2 < 0) - elif fn == "masked_scatter": - return fntorch(t1, t2 < 0.5, full1d) - elif fn == "masked_fill": - return fntorch(t1, t2 < 0.5, 1.0) - elif fn in fns_3_args: - return fntorch(t1, 1.0, t2, t3) - elif fn in fns_value_kwarg: - return fntorch(t1, t2, t3, value=1.0) - else: - return fntorch(t1, t2) - - # test various orders - for first, second, third in [(large, small, small2), (small, large, small2), - (small2, small, large), (small2, large, small)]: - if first is None: - break # ignore last iter when small2 is None - r1 = torchfn(expanded[first], expanded[second], expanded[third]) - r2 = torchfn(first, second, third) - self.assertEqual(r1, r2) - - # now for in place functions - # in-place tensor is not broadcastable; test only guaranteed - # to work by broadcasting other argument(s) - if not hasattr(large_expanded, fn + "_"): - continue + (dims_small, dims_large, dims_full) = self._select_broadcastable_dims() + full1d = torch.randn(*dims_full, device=device).flatten().float() + small = torch.randn(*dims_small, device=device).float() + large = torch.randn(*dims_large, device=device).float() + small_expanded = small.expand(*dims_full) + large_expanded = large.expand(*dims_full) + small2 = None + small2_expanded = None + if fn in fns_3_args or fn in fns_value_kwarg: + # create another smaller tensor + (dims_small2, _, _) = self._select_broadcastable_dims(dims_full) + small2 = torch.randn(*dims_small2, device=device).float() + small2_expanded = small2.expand(*dims_full) + + if small.is_cuda and fn in ['map', 'map2']: + # map and map2 are not implementd on CUDA tensors + return - # need to clone largeExpanded so we can reuse, since functions are in-place - large_expanded_clone = large_expanded.clone() + if hasattr(large_expanded, fn): + # run through tensor versions of functions + # and verify fully expanded inputs give same results + expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} - def tensorfn_inplace(t0, t1, t2=None): - t0_fn = getattr(t0, fn + "_") + def tensorfn(myfn, t1, t2): if fn == "lerp": - return t0_fn(t1, 0.5) + return myfn(t1, 0.5) + elif fn == "masked_select": + return myfn(t1 < 0) elif fn == "masked_scatter": - return t0_fn(t1 < 0.5, full1d) + return myfn(t1 < 0.5, full1d) elif fn == "masked_fill": - return t0_fn(t1 < 0.5, 1.0) - elif fn == "map": - return t0_fn(t1, lambda x, y: x + y) - elif fn == "map2": - return t0_fn(t1, t2, lambda x, y, z: x + y + z) + return myfn(t1 < 0.5, 1.0) elif fn in fns_3_args: - return t0_fn(1.0, t1, t2) + return myfn(1, t1, t2) elif fn in fns_value_kwarg: - return t0_fn(t1, t2, value=1.0) + return myfn(t1, t2, value=1) else: - return t0_fn(t1) - # in-place pointwise operations don't actually work if the in-place - # tensor is 0-strided (numpy has the same issue) - if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()): - r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded) - r2 = tensorfn_inplace(large_expanded_clone, small, small2) + return myfn(t1) + + # test various orders + for first, second, third in [(large, small, small2), (small, large, small2), + (small2, small, large), (small2, large, small)]: + if first is None: + break # ignore last iter when small2 is None + method_expanded = getattr(expanded[first], fn) + method = getattr(first, fn) + r1 = tensorfn(method_expanded, expanded[second], expanded[third]) + r2 = tensorfn(method, second, third) self.assertEqual(r1, r2) - def broadcastable(t0, t1, t2=None): - try: - t1.expand_as(t0) - if t2 is not None: - t2.expand_as(t0) - except RuntimeError: - return False - return True - - def _test_in_place_broadcastable(t0, t1, t2=None): - if not broadcastable(t0, t1, t2): - same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True) - if not same_size: - self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2)) + # now for torch. versions of functions + if hasattr(torch, fn): + fntorch = getattr(torch, fn) + expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} + + def torchfn(t1, t2, t3): + if fn == "lerp": + return fntorch(t1, t2, 0.5) + elif fn == "masked_select": + return fntorch(t1, t2 < 0) + elif fn == "masked_scatter": + return fntorch(t1, t2 < 0.5, full1d) + elif fn == "masked_fill": + return fntorch(t1, t2 < 0.5, 1.0) + elif fn in fns_3_args: + return fntorch(t1, 1.0, t2, t3) + elif fn in fns_value_kwarg: + return fntorch(t1, t2, t3, value=1.0) else: - tensorfn_inplace(t0, t1, t2) + return fntorch(t1, t2) + + # test various orders + for first, second, third in [(large, small, small2), (small, large, small2), + (small2, small, large), (small2, large, small)]: + if first is None: + break # ignore last iter when small2 is None + r1 = torchfn(expanded[first], expanded[second], expanded[third]) + r2 = torchfn(first, second, third) + self.assertEqual(r1, r2) + + # now for in place functions + # in-place tensor is not broadcastable; test only guaranteed + # to work by broadcasting other argument(s) + if not hasattr(large_expanded, fn + "_"): + return + + # need to clone largeExpanded so we can reuse, since functions are in-place + large_expanded_clone = large_expanded.clone() + + def tensorfn_inplace(t0, t1, t2=None): + t0_fn = getattr(t0, fn + "_") + if fn == "lerp": + return t0_fn(t1, 0.5) + elif fn == "masked_scatter": + return t0_fn(t1 < 0.5, full1d) + elif fn == "masked_fill": + return t0_fn(t1 < 0.5, 1.0) + elif fn == "map": + return t0_fn(t1, lambda x, y: x + y) + elif fn == "map2": + return t0_fn(t1, t2, lambda x, y, z: x + y + z) + elif fn in fns_3_args: + return t0_fn(1.0, t1, t2) + elif fn in fns_value_kwarg: + return t0_fn(t1, t2, value=1.0) + else: + return t0_fn(t1) + # in-place pointwise operations don't actually work if the in-place + # tensor is 0-strided (numpy has the same issue) + if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()): + r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded) + r2 = tensorfn_inplace(large_expanded_clone, small, small2) + self.assertEqual(r1, r2) + + def broadcastable(t0, t1, t2=None): + try: + t1.expand_as(t0) + if t2 is not None: + t2.expand_as(t0) + except RuntimeError: + return False + return True - if fn not in fns_3_args and fn not in fns_value_kwarg: - _test_in_place_broadcastable(small, large_expanded) - _test_in_place_broadcastable(small, large) + def _test_in_place_broadcastable(t0, t1, t2=None): + if not broadcastable(t0, t1, t2): + same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True) + if not same_size: + self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2)) else: - _test_in_place_broadcastable(small2, small_expanded, large_expanded) - _test_in_place_broadcastable(small2, small, large) + tensorfn_inplace(t0, t1, t2) + + if fn not in fns_3_args and fn not in fns_value_kwarg: + _test_in_place_broadcastable(small, large_expanded) + _test_in_place_broadcastable(small, large) + else: + _test_in_place_broadcastable(small2, small_expanded, large_expanded) + _test_in_place_broadcastable(small2, small, large) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyCUDA @@ -2963,7 +2963,7 @@ def test_index_fill(self, device, dtype): index = torch.tensor([0], device=device) x.index_fill_(1, index, 0) self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device)) - if not x.is_complex(): + if not x.is_complex() and not device == "meta": with self.assertRaisesRegex(RuntimeError, r"Scalar"): x.index_fill_(1, index, 1 + 1j) # Make sure that the result stays 0-dim while applied to From f8b7f73a45f55273028d5a9c476326e104b2d03f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 11:54:53 +0100 Subject: [PATCH 09/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [x] add tests [ghstack-poisoned] --- test/test_type_promotion.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 01e96a3fe1128..f32a89933f088 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -9,7 +9,7 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests, TEST_NUMPY, torch_to_numpy_dtype_dict) from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes, - dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta) + dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta) from torch.testing._internal.common_dtype import ( get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes ) @@ -937,7 +937,11 @@ def test_unary_op_out_casting(self, device, dtypes): elif op in real_only_ops and dtypes[0].is_complex: with self.assertRaises(RuntimeError): op(t, out=out) - elif op in float_only_ops and (not dtypes[0].is_floating_point and not dtypes[0].is_complex): + elif ( + op in float_only_ops + and (not dtypes[0].is_floating_point and not dtypes[0].is_complex) + and device != "meta" + ): with self.assertRaises(RuntimeError): op(t, out=out) else: @@ -947,6 +951,7 @@ def test_unary_op_out_casting(self, device, dtypes): # Verifies that the out= argument doesn't affect the computation, that # is, out = op(...) and op(..., out=out) produce the same result. @onlyNativeDeviceTypes + @skipMeta def test_computation_ignores_out(self, device): t = torch.tensor(33000, dtype=torch.float16, device=device) out = torch.empty(0, dtype=torch.float64, device=device) From e467f0f6fdbed92513282a6bcdf96e69f8e8a1c4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 17:20:05 +0100 Subject: [PATCH 10/10] Update on "only compare attributes for meta tensors" Todo: - [ ] document this behavior - [x] add tests [ghstack-poisoned] --- test/test_view_ops.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 2678db1d74d51..37d08e39e637c 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -14,7 +14,7 @@ torch_to_numpy_dtype_dict, ) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes) + (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta) from torch.testing._internal.common_dtype import ( get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes ) @@ -729,6 +729,7 @@ def test_contiguous_self(self, device): s = t.contiguous() self.assertTrue(s is t) + @skipMeta def test_contiguous_nonview(self, device): t = torch.ones(5, 5, device=device) nv = t.t().contiguous() @@ -754,6 +755,7 @@ def test_reshape_as_view(self, device): v[6] = 0 self.assertEqual(t[1, 1], v[6]) + @skipMeta def test_reshape_nonview(self, device): t = torch.ones(5, 5, device=device) nv = torch.reshape(t.t(), (25,)) @@ -806,7 +808,8 @@ def assert_is_nonview(t, nv): idx_nv = (0,) * nv.ndim self.assertTrue(not nv._is_view()) nv[idx_nv] = 0 - self.assertNotEqual(t[idx_t], nv[idx_nv]) + if device != "meta": + self.assertNotEqual(t[idx_t], nv[idx_nv]) t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3) nv = t.flatten(1, 3) assert_is_nonview(t, nv) @@ -1027,7 +1030,9 @@ def test_reshape(self, device): self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) y = torch.randn(4, 4, 4, device=device)[:, 0, :] - self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) + # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape + if device != "meta": + self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())