diff --git a/test/test_meta.py b/test/test_meta.py
index 99c2049473c6d..6569775fdadbe 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -288,6 +288,10 @@ def test_tensor_outlives_converter(self):
     torch.Tensor.__getitem__,
 }
 
+CHECK_ALL_STRIDES = {
+    aten.unsqueeze.default
+}
+
 CHECK_STRIDES_SKIPS = {
     aten._conj_physical.default,
     aten._fft_c2c.default,
@@ -319,22 +323,29 @@ def test_tensor_outlives_converter(self):
     # aten.view.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_unflatten_cuda_float32
 }
 
+class CheckStrides(Enum):
+    NONE = 0
+    SIGNIFICANT = 1
+    ALL = 2
+
 def should_check_strides(func):
+    if func in CHECK_ALL_STRIDES:
+        return CheckStrides.ALL
     if func in CHECK_STRIDES:
-        return True
+        return CheckStrides.SIGNIFICANT
     if func in CHECK_STRIDES_SKIPS:
-        return False
+        return CheckStrides.NONE
     if not isinstance(func, torch._ops.OpOverload):
-        return False
+        return CheckStrides.NONE
     # Prims are expected to model strides correctly
     if func.namespace == "prims":
-        return True
+        return CheckStrides.SIGNIFICANT
     # Check if it's a view, by testing if any of the returns have
     # a non-empty alias set
     if any(r.alias_info.before_set for r in func._schema.returns if r.alias_info):
-        return True
+        return CheckStrides.SIGNIFICANT
     # TODO: check for TensorIterator
-    return True
+    return CheckStrides.SIGNIFICANT
 
 def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
     flat_meta_rs, _ = tree_flatten(meta_rs)
@@ -350,7 +361,10 @@ def test_assert(cond, msg):
         test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
         test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
         # See https://github.com/pytorch/pytorch/issues/78050
-        if should_check_strides(func):
+        if should_check_strides(func) == CheckStrides.ALL:
+            same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
+            test_assert(same_strides, f"but real stride was {r.stride()}")
+        elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
             same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
             test_assert(same_strides, f"but real stride was {r.stride()}")
         test_assert(
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index e764229b8bb5c..4a0f3438b6efb 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -1232,7 +1232,12 @@ def _greater_than_reduce(acc, x):
                 new_strides.append(a.stride()[original_idx])
             original_idx = original_idx + 1
         else:
-            new_strides.append(0)
+            if shape[idx] != 1:
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
 
     return a.as_strided(shape, new_strides, a.storage_offset())
 
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index 647b0e66729e2..d4def60261d5d 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -149,8 +149,8 @@ def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType, check_strides=Fals
         raise RuntimeError(msg)
 
 
-def check_significant_strides(
-    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
+def _check_strides_helper(
+    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True
 ) -> Tuple[bool, Optional[int]]:
     # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
     # See https://github.com/pytorch/pytorch/issues/77553
@@ -158,11 +158,22 @@ def check_significant_strides(
     # and for tensors with more than one element
     if (not only_cuda or a.device.type == "cuda" or b.device.type == "cuda") and a.numel() > 0:
         for idx in range(a.ndim):
-            if a.stride()[idx] != b.stride()[idx] and a.shape[idx] > 1:
+            check = not significant_only or a.shape[idx] > 1
+            if a.stride()[idx] != b.stride()[idx] and check:
                 return False, idx
 
     return True, None
 
 
+def check_significant_strides(
+    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
+) -> Tuple[bool, Optional[int]]:
+    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True)
+
+def check_all_strides(
+    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
+) -> Tuple[bool, Optional[int]]:
+    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
+
 # This function is equivalent to compute_contiguous() from TensorImpl.cpp
 def is_contiguous(a: TensorLikeType) -> bool: