From 738284c2304fa613885823c16fb989afedc9bf6a Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 29 Apr 2024 23:25:23 +0000 Subject: [PATCH] Fix: `nn.Parameter` return type identified as `Tensor` instead of `nn.Parameter` (#125106) Fixes #125105 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125106 Approved by: https://github.com/ezyang, https://github.com/albanD --- test/typing/pass/creation_ops.py | 8 ++++++++ tools/pyi/gen_pyi.py | 10 +++++----- torch/_C/__init__.pyi.in | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/test/typing/pass/creation_ops.py b/test/typing/pass/creation_ops.py index c524d56f1971..f866d3a1628f 100644 --- a/test/typing/pass/creation_ops.py +++ b/test/typing/pass/creation_ops.py @@ -2,6 +2,10 @@ # flake8: noqa import torch from torch.testing._internal.common_utils import TEST_NUMPY + +from typing_extensions import assert_type + + if TEST_NUMPY: import numpy as np @@ -117,3 +121,7 @@ inp = torch.tensor([-1.5, 0, 2.0]) values = torch.tensor([0.5]) torch.heaviside(inp, values) + +# Parameter +p = torch.nn.Parameter(torch.empty(1)) +assert_type(p, torch.nn.Parameter) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index f0b9044c6fe9..59498f41f3ef 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1064,14 +1064,14 @@ def replace_special_case(hint: str) -> str: "new_tensor": [ f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..." ], - "__new__": ["def __new__(self, *args, **kwargs) -> Tensor: ..."], + "__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."], # new and __init__ have the same signatures differ only in return type # Adapted from legacy_tensor_ctor and legacy_tensor_new "new": [ - f"def new(self, *args: Any, {DEVICE_PARAM}) -> Tensor: ...", - "def new(self, storage: Storage) -> Tensor: ...", - "def new(self, other: Tensor) -> Tensor: ...", - f"def new(self, size: _size, *, {DEVICE_PARAM}) -> Tensor: ...", + f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...", + "def new(cls, storage: Storage) -> Self: ...", + "def new(cls, other: Tensor) -> Self: ...", + f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...", ], "__init__": [ f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 34e49e15d850..5e20dd31cd85 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -29,7 +29,7 @@ from typing import ( overload, runtime_checkable, ) -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, Self import numpy