Skip to content

Commit

Permalink
Support type stub generation for staticmethod (#14934)
Browse files Browse the repository at this point in the history
Fixes #13574

This PR fixes the generation of type hints for static methods of
pybind11 classes. The code changes are based on the suggestions in
#13574.

The fix introduces an additional check if the property under inspection
is of type `staticmethod`. If it is, the type information is read from
the staticmethod's `__func__` attribute, instead of the staticmethod
instance itself.

I added a test for C++ classes with static methods bound using pybind11.
Both, an overloaded and a non-overloaded static method are tested.
  • Loading branch information
WeilerMarcel committed Jan 8, 2024
1 parent fbb738a commit 35f402c
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 6 deletions.
16 changes: 10 additions & 6 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,14 @@ def is_classmethod(self, class_info: ClassInfo, name: str, obj: object) -> bool:
return inspect.ismethod(obj)

def is_staticmethod(self, class_info: ClassInfo | None, name: str, obj: object) -> bool:
if self.is_c_module:
if class_info is None:
return False
elif self.is_c_module:
raw_lookup: Mapping[str, Any] = getattr(class_info.cls, "__dict__") # noqa: B009
raw_value = raw_lookup.get(name, obj)
return isinstance(raw_value, staticmethod)
else:
return class_info is not None and isinstance(
inspect.getattr_static(class_info.cls, name), staticmethod
)
return isinstance(inspect.getattr_static(class_info.cls, name), staticmethod)

@staticmethod
def is_abstract_method(obj: object) -> bool:
Expand Down Expand Up @@ -761,7 +763,7 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
The result lines will be appended to 'output'. If necessary, any
required names will be added to 'imports'.
"""
raw_lookup = getattr(cls, "__dict__") # noqa: B009
raw_lookup: Mapping[str, Any] = getattr(cls, "__dict__") # noqa: B009
items = self.get_members(cls)
if self.resort_members:
items = sorted(items, key=lambda x: method_name_sort_key(x[0]))
Expand Down Expand Up @@ -793,7 +795,9 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
continue
attr = "__init__"
# FIXME: make this nicer
if self.is_classmethod(class_info, attr, value):
if self.is_staticmethod(class_info, attr, value):
class_info.self_var = ""
elif self.is_classmethod(class_info, attr, value):
class_info.self_var = "cls"
else:
class_info.self_var = "self"
Expand Down
15 changes: 15 additions & 0 deletions test-data/pybind11_mypy_demo/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ const Point Point::y_axis = Point(0, 1);
Point::LengthUnit Point::length_unit = Point::LengthUnit::mm;
Point::AngleUnit Point::angle_unit = Point::AngleUnit::radian;

struct Foo
{
static int some_static_method(int a, int b) { return a * 42 + b; }
static int overloaded_static_method(int value) { return value * 42; }
static double overloaded_static_method(double value) { return value * 42; }
};

} // namespace: basics

void bind_basics(py::module& basics) {
Expand Down Expand Up @@ -166,6 +173,14 @@ void bind_basics(py::module& basics) {
.value("radian", Point::AngleUnit::radian)
.value("degree", Point::AngleUnit::degree);

// Static methods
py::class_<Foo> pyFoo(basics, "Foo");

pyFoo
.def_static("some_static_method", &Foo::some_static_method, R"#(None)#", py::arg("a"), py::arg("b"))
.def_static("overloaded_static_method", py::overload_cast<int>(&Foo::overloaded_static_method), py::arg("value"))
.def_static("overloaded_static_method", py::overload_cast<double>(&Foo::overloaded_static_method), py::arg("value"));

// Module-level attributes
basics.attr("PI") = std::acos(-1);
basics.attr("__version__") = "0.0.1";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,36 @@ from typing import ClassVar, List, overload
PI: float
__version__: str

class Foo:
def __init__(self, *args, **kwargs) -> None:
"""Initialize self. See help(type(self)) for accurate signature."""
@overload
@staticmethod
def overloaded_static_method(value: int) -> int:
"""overloaded_static_method(*args, **kwargs)
Overloaded function.
1. overloaded_static_method(value: int) -> int
2. overloaded_static_method(value: float) -> float
"""
@overload
@staticmethod
def overloaded_static_method(value: float) -> float:
"""overloaded_static_method(*args, **kwargs)
Overloaded function.
1. overloaded_static_method(value: int) -> int
2. overloaded_static_method(value: float) -> float
"""
@staticmethod
def some_static_method(a: int, b: int) -> int:
"""some_static_method(a: int, b: int) -> int
None
"""

class Point:
class AngleUnit:
__members__: ClassVar[dict] = ... # read-only
Expand Down
11 changes: 11 additions & 0 deletions test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/basics.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ from typing import ClassVar, List, overload
PI: float
__version__: str

class Foo:
def __init__(self, *args, **kwargs) -> None: ...
@overload
@staticmethod
def overloaded_static_method(value: int) -> int: ...
@overload
@staticmethod
def overloaded_static_method(value: float) -> float: ...
@staticmethod
def some_static_method(a: int, b: int) -> int: ...

class Point:
class AngleUnit:
__members__: ClassVar[dict] = ... # read-only
Expand Down

0 comments on commit 35f402c

Please sign in to comment.