diff --git a/test/typing/pass/creation_ops.py b/test/typing/pass/creation_ops.py index c524d56f1971..f866d3a1628f 100644 --- a/test/typing/pass/creation_ops.py +++ b/test/typing/pass/creation_ops.py @@ -2,6 +2,10 @@ # flake8: noqa import torch from torch.testing._internal.common_utils import TEST_NUMPY + +from typing_extensions import assert_type + + if TEST_NUMPY: import numpy as np @@ -117,3 +121,7 @@ inp = torch.tensor([-1.5, 0, 2.0]) values = torch.tensor([0.5]) torch.heaviside(inp, values) + +# Parameter +p = torch.nn.Parameter(torch.empty(1)) +assert_type(p, torch.nn.Parameter) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index f0b9044c6fe9..59498f41f3ef 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1064,14 +1064,14 @@ def replace_special_case(hint: str) -> str: "new_tensor": [ f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..." ], - "__new__": ["def __new__(self, *args, **kwargs) -> Tensor: ..."], + "__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."], # new and __init__ have the same signatures differ only in return type # Adapted from legacy_tensor_ctor and legacy_tensor_new "new": [ - f"def new(self, *args: Any, {DEVICE_PARAM}) -> Tensor: ...", - "def new(self, storage: Storage) -> Tensor: ...", - "def new(self, other: Tensor) -> Tensor: ...", - f"def new(self, size: _size, *, {DEVICE_PARAM}) -> Tensor: ...", + f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...", + "def new(cls, storage: Storage) -> Self: ...", + "def new(cls, other: Tensor) -> Self: ...", + f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...", ], "__init__": [ f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 34e49e15d850..5e20dd31cd85 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -29,7 +29,7 @@ from typing import ( overload, runtime_checkable, ) -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, Self import numpy