Skip to content

Commit

Permalink
Fix stubgen regressions
Browse files Browse the repository at this point in the history
Fix pybind11 stubgen tests so that they report failures
  • Loading branch information
chadrik committed Nov 17, 2023
1 parent 8c8aa10 commit d54189c
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 35 deletions.
2 changes: 1 addition & 1 deletion misc/test-stubgenc.sh
Expand Up @@ -24,7 +24,7 @@ function stubgenc_test() {
# Compare generated stubs to expected ones
if ! git diff --exit-code "$STUBGEN_OUTPUT_FOLDER";
then
EXIT=$?
EXIT=1
fi
}

Expand Down
3 changes: 2 additions & 1 deletion mypy/stubdoc.py
Expand Up @@ -374,7 +374,8 @@ def infer_ret_type_sig_from_docstring(docstr: str, name: str) -> str | None:

def infer_ret_type_sig_from_anon_docstring(docstr: str) -> str | None:
"""Convert signature in form of "(self: TestClass, arg0) -> int" to their return type."""
return infer_ret_type_sig_from_docstring("stub" + docstr.strip(), "stub")
lines = ["stub" + line.strip() for line in docstr.splitlines() if line.strip().startswith("(")]
return infer_ret_type_sig_from_docstring("".join(lines), "stub")


def parse_signature(sig: str) -> tuple[str, list[str], list[str]] | None:
Expand Down
1 change: 1 addition & 0 deletions mypy/stubgen.py
Expand Up @@ -1629,6 +1629,7 @@ def generate_stubs(options: Options) -> None:
doc_dir=options.doc_dir,
include_private=options.include_private,
export_less=options.export_less,
include_docstrings=options.include_docstrings,
)
num_modules = len(all_modules)
if not options.quiet and num_modules > 0:
Expand Down
52 changes: 47 additions & 5 deletions mypy/stubgenc.py
Expand Up @@ -126,10 +126,12 @@ def get_property_type(self, default_type: str | None, ctx: FunctionContext) -> s
"""Infer property type from docstring or docstring signature."""
if ctx.docstring is not None:
inferred = infer_ret_type_sig_from_anon_docstring(ctx.docstring)
if not inferred:
inferred = infer_ret_type_sig_from_docstring(ctx.docstring, ctx.name)
if not inferred:
inferred = infer_prop_type_from_docstring(ctx.docstring)
if inferred:
return inferred
inferred = infer_ret_type_sig_from_docstring(ctx.docstring, ctx.name)
if inferred:
return inferred
inferred = infer_prop_type_from_docstring(ctx.docstring)
return inferred
else:
return None
Expand Down Expand Up @@ -237,6 +239,26 @@ def __init__(
self.resort_members = self.is_c_module
super().__init__(_all_, include_private, export_less, include_docstrings)
self.module_name = module_name
if self.is_c_module:
# Add additional implicit imports.
# C-extensions are given more lattitude since they do not import the typing module.
self.known_imports.update(
{
"typing": [
"Any",
"Callable",
"ClassVar",
"Dict",
"Iterable",
"Iterator",
"List",
"NamedTuple",
"Optional",
"Tuple",
"Union",
]
}
)

def get_default_function_sig(self, func: object, ctx: FunctionContext) -> FunctionSig:
argspec = None
Expand Down Expand Up @@ -590,9 +612,29 @@ def generate_function_stub(
if inferred[0].args and inferred[0].args[0].name == "cls":
decorators.append("@classmethod")

if docstring:
docstring = self._indent_docstring(docstring)
output.extend(self.format_func_def(inferred, decorators=decorators, docstring=docstring))
self._fix_iter(ctx, inferred, output)

def _indent_docstring(self, docstring: str) -> str:
"""Fix indentation of docstring extracted from pybind11 or other binding generators."""
lines = docstring.splitlines(keepends=True)
indent = self._indent + " "
if len(lines) > 1:
if not all(line.startswith(indent) or not line.strip() for line in lines):
# if the docstring is not indented, then indent all but the first line
for i, line in enumerate(lines[1:]):
if line.strip():
lines[i + 1] = indent + line
# if there's a trailing newline, add a final line to visually indent the quoted docstring
if lines[-1].endswith("\n"):
if len(lines) > 1:
lines.append(indent)
else:
lines[-1] = lines[-1][:-1]
return "".join(lines)

def _fix_iter(
self, ctx: FunctionContext, inferred: list[FunctionSig], output: list[str]
) -> None:
Expand Down Expand Up @@ -640,7 +682,7 @@ def generate_property_stub(
if fget:
alt_docstr = getattr(fget, "__doc__", None)
if alt_docstr and docstring:
docstring += alt_docstr
docstring += "\n" + alt_docstr
elif alt_docstr:
docstring = alt_docstr

Expand Down
18 changes: 9 additions & 9 deletions mypy/stubutil.py
Expand Up @@ -576,6 +576,14 @@ def __init__(
self.sig_generators = self.get_sig_generators()
# populated by visit_mypy_file
self.module_name: str = ""
# These are "soft" imports for objects which might appear in annotations but not have
# a corresponding import statement.
self.known_imports = {
"_typeshed": ["Incomplete"],
"typing": ["Any", "TypeVar", "NamedTuple"],
"collections.abc": ["Generator"],
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
}

def get_sig_generators(self) -> list[SignatureGenerator]:
return []
Expand Down Expand Up @@ -667,15 +675,7 @@ def set_defined_names(self, defined_names: set[str]) -> None:
for name in self._all_ or ():
self.import_tracker.reexport(name)

# These are "soft" imports for objects which might appear in annotations but not have
# a corresponding import statement.
known_imports = {
"_typeshed": ["Incomplete"],
"typing": ["Any", "TypeVar", "NamedTuple"],
"collections.abc": ["Generator"],
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
}
for pkg, imports in known_imports.items():
for pkg, imports in self.known_imports.items():
for t in imports:
# require=False means that the import won't be added unless require_name() is called
# for the object during generation.
Expand Down
11 changes: 9 additions & 2 deletions test-data/pybind11_mypy_demo/src/main.cpp
Expand Up @@ -44,6 +44,7 @@

#include <cmath>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

Expand Down Expand Up @@ -102,6 +103,11 @@ struct Point {
return distance_to(other.x, other.y);
}

std::vector<double> as_vector()
{
return std::vector<double>{x, y};
}

double x, y;
};

Expand Down Expand Up @@ -134,14 +140,15 @@ void bind_basics(py::module& basics) {
.def(py::init<double, double>(), py::arg("x"), py::arg("y"))
.def("distance_to", py::overload_cast<double, double>(&Point::distance_to, py::const_), py::arg("x"), py::arg("y"))
.def("distance_to", py::overload_cast<const Point&>(&Point::distance_to, py::const_), py::arg("other"))
.def_readwrite("x", &Point::x)
.def("as_list", &Point::as_vector)
.def_readwrite("x", &Point::x, "some docstring")
.def_property("y",
[](Point& self){ return self.y; },
[](Point& self, double value){ self.y = value; }
)
.def_property_readonly("length", &Point::length)
.def_property_readonly_static("x_axis", [](py::object cls){return Point::x_axis;})
.def_property_readonly_static("y_axis", [](py::object cls){return Point::y_axis;})
.def_property_readonly_static("y_axis", [](py::object cls){return Point::y_axis;}, "another docstring")
.def_readwrite_static("length_unit", &Point::length_unit)
.def_property_static("angle_unit",
[](py::object& /*cls*/){ return Point::angle_unit; },
Expand Down
@@ -0,0 +1 @@
from . import basics as basics
@@ -1,7 +1,7 @@
from typing import ClassVar
from typing import ClassVar, List, overload

from typing import overload
PI: float
__version__: str

class Point:
class AngleUnit:
Expand All @@ -13,8 +13,6 @@ class Point:
"""__init__(self: pybind11_mypy_demo.basics.Point.AngleUnit, value: int) -> None"""
def __eq__(self, other: object) -> bool:
"""__eq__(self: object, other: object) -> bool"""
def __getstate__(self) -> int:
"""__getstate__(self: object) -> int"""
def __hash__(self) -> int:
"""__hash__(self: object) -> int"""
def __index__(self) -> int:
Expand All @@ -23,8 +21,6 @@ class Point:
"""__int__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int"""
def __ne__(self, other: object) -> bool:
"""__ne__(self: object, other: object) -> bool"""
def __setstate__(self, state: int) -> None:
"""__setstate__(self: pybind11_mypy_demo.basics.Point.AngleUnit, state: int) -> None"""
@property
def name(self) -> str: ...
@property
Expand All @@ -40,8 +36,6 @@ class Point:
"""__init__(self: pybind11_mypy_demo.basics.Point.LengthUnit, value: int) -> None"""
def __eq__(self, other: object) -> bool:
"""__eq__(self: object, other: object) -> bool"""
def __getstate__(self) -> int:
"""__getstate__(self: object) -> int"""
def __hash__(self) -> int:
"""__hash__(self: object) -> int"""
def __index__(self) -> int:
Expand All @@ -50,8 +44,6 @@ class Point:
"""__int__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int"""
def __ne__(self, other: object) -> bool:
"""__ne__(self: object, other: object) -> bool"""
def __setstate__(self, state: int) -> None:
"""__setstate__(self: pybind11_mypy_demo.basics.Point.LengthUnit, state: int) -> None"""
@property
def name(self) -> str: ...
@property
Expand All @@ -70,43 +62,51 @@ class Point:
1. __init__(self: pybind11_mypy_demo.basics.Point) -> None
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None"""
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
"""
@overload
def __init__(self, x: float, y: float) -> None:
"""__init__(*args, **kwargs)
Overloaded function.
1. __init__(self: pybind11_mypy_demo.basics.Point) -> None
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None"""
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
"""
def as_list(self) -> List[float]:
"""as_list(self: pybind11_mypy_demo.basics.Point) -> List[float]"""
@overload
def distance_to(self, x: float, y: float) -> float:
"""distance_to(*args, **kwargs)
Overloaded function.
1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float"""
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
"""
@overload
def distance_to(self, other: Point) -> float:
"""distance_to(*args, **kwargs)
Overloaded function.
1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float"""
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
"""
@property
def length(self) -> float: ...

def answer() -> int:
'''answer() -> int
answer docstring, with end quote"'''
answer docstring, with end quote"
'''
def midpoint(left: float, right: float) -> float:
"""midpoint(left: float, right: float) -> float"""
def sum(arg0: int, arg1: int) -> int:
'''sum(arg0: int, arg1: int) -> int
multiline docstring test, edge case quotes """\'\'\''''
multiline docstring test, edge case quotes """\'\'\'
'''
def weighted_midpoint(left: float, right: float, alpha: float = ...) -> float:
"""weighted_midpoint(left: float, right: float, alpha: float = 0.5) -> float"""
@@ -1,4 +1,4 @@
from typing import ClassVar, overload
from typing import ClassVar, List, overload

PI: float
__version__: str
Expand Down Expand Up @@ -47,6 +47,7 @@ class Point:
def __init__(self) -> None: ...
@overload
def __init__(self, x: float, y: float) -> None: ...
def as_list(self) -> List[float]: ...
@overload
def distance_to(self, x: float, y: float) -> float: ...
@overload
Expand Down

0 comments on commit d54189c

Please sign in to comment.