From 33600d49495a1653d0e7cc31388dceda75042fdb Mon Sep 17 00:00:00 2001 From: Soof Golan <83900570+soof-golan@users.noreply.github.com> Date: Sun, 28 Aug 2022 11:55:31 +0300 Subject: [PATCH 1/3] Fix `tensor.stride()` type hint `tensor.stride()` now hints at tuple of variable length instead of tuple with constant length of 1 --- tools/pyi/gen_pyi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 79f97c4e9f30..417d73f829a6 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -597,7 +597,7 @@ def gen_pyi( "def size(self, dim: _int) -> _int: ...", ], "stride": [ - "def stride(self) -> Tuple[_int]: ...", + "def stride(self) -> Tuple[_int, ...]: ...", "def stride(self, _int) -> _int: ...", ], "new_ones": [ From d618aa32cff079dee960ebcd415e4785d95550a4 Mon Sep 17 00:00:00 2001 From: Soof Golan <83900570+soof-golan@users.noreply.github.com> Date: Mon, 29 Aug 2022 23:09:36 +0300 Subject: [PATCH 2/3] Fix stride type hints in shape_props.py --- torch/fx/passes/shape_prop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index 9c3a036e90bf..2be996f714ce 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -17,7 +17,7 @@ class TensorMetadata(NamedTuple): shape : torch.Size dtype : torch.dtype requires_grad : bool - stride : Tuple[int] + stride : Tuple[int, ...] memory_format : Optional[torch.memory_format] # Quantization metadata From 3d7bd8d307392683521f424301c6990ac2ea18f5 Mon Sep 17 00:00:00 2001 From: Soof Golan <83900570+soof-golan@users.noreply.github.com> Date: Tue, 30 Aug 2022 08:49:02 +0300 Subject: [PATCH 3/3] Ignore assignment type --- torch/_prims/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index d724ac50e283..eae38612a223 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1273,7 +1273,7 @@ def _collapse_view_helper( strides = (1,) else: shape = a.shape # type: ignore[assignment] - strides = a.stride() + strides = a.stride() # type: ignore[assignment] utils.validate_idx(len(shape), start) utils.validate_exclusive_idx(len(shape), end)