Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ automatically replace invalid enum expressions with corresponding valid expression & import #196

Merged
merged 10 commits into from
Nov 25, 2023
22 changes: 21 additions & 1 deletion pybind11_stubgen/parser/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from pybind11_stubgen.structs import Identifier, QualifiedName
from pybind11_stubgen.structs import Identifier, Import, QualifiedName, Value


class ParserError(Exception):
Expand Down Expand Up @@ -33,3 +33,23 @@ def __init__(self, name: QualifiedName):

def __str__(self):
return f"Can't find/import '{self.name}'"


class AmbiguousEnumError(InvalidExpressionError):
    """Invalid enum expression that matches more than one known definition.

    Args:
        repr_: The enum repr text that could not be resolved; forwarded to the
            parent error's constructor.
        *values_and_imports: Candidate ``(Value, Import)`` pairs that all match
            ``repr_``. At least two are required.

    Raises:
        ValueError: If fewer than two candidates are supplied.
    """

    def __init__(self, repr_: str, *values_and_imports: tuple[Value, Import]):
        super().__init__(repr_)
        self.values_and_imports = values_and_imports

        candidate_count = len(values_and_imports)
        if candidate_count < 2:
            raise ValueError(
                f"Expected at least 2 values_and_imports, got {candidate_count}"
            )

    def __str__(self) -> str:
        # Sort the origins so the message is stable regardless of the order in
        # which the candidates were collected.
        quoted_origins = [
            f"'{origin}'"
            for origin in sorted(
                import_.origin for _, import_ in self.values_and_imports
            )
        ]
        listing = ", ".join(quoted_origins)
        return (
            f"Enum member '{self.expression}' could not be resolved; multiple "
            f"matching definitions found in: {listing}"
        )
4 changes: 4 additions & 0 deletions pybind11_stubgen/parser/mixins/error_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def handle_method(self, path: QualifiedName, class_: type) -> list[Method]:
with self.__new_layer(path):
return super().handle_method(path, class_)

def finalize(self) -> None:
    """Run the parent's ``finalize`` inside a layer named ``finalize``."""
    # NOTE(review): presumably the layer name is what error messages raised
    # during finalization are attributed to -- confirm against __new_layer.
    finalize_layer = self.__new_layer(QualifiedName.from_str("finalize"))
    with finalize_layer:
        return super().finalize()

@property
def current_path(self) -> QualifiedName:
assert len(self.stack) != 0
Expand Down
142 changes: 137 additions & 5 deletions pybind11_stubgen/parser/mixins/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import re
import sys
import types
from collections import defaultdict
from logging import getLogger
from typing import Any, Sequence

from pybind11_stubgen.parser.errors import (
AmbiguousEnumError,
InvalidExpressionError,
NameResolutionError,
ParserError,
Expand Down Expand Up @@ -999,13 +1001,56 @@ def parse_value_str(self, value: str) -> Value | InvalidExpression:


class RewritePybind11EnumValueRepr(IParser):
"""Reformat pybind11-generated invalid enum value reprs.

For example, pybind11 may generate a `__doc__` like this:
>>> "set_color(self, color: <ConsoleForegroundColor.Blue: 34>) -> None:\n"

Which is invalid python syntax. This parser will rewrite the generated stub to:
>>> from demo._bindings.enum import ConsoleForegroundColor
>>> def set_color(self, color: ConsoleForegroundColor.Blue) -> None:
>>> ...

Since `pybind11_stubgen` encounters the values corresponding to these reprs as it
parses the modules, it can automatically replace these invalid expressions with the
corresponding `Value` and `Import` as it encounters them. There are 3 cases for an
`Argument` whose `default` is an enum `InvalidExpression`:

1. The `InvalidExpression` repr corresponds to exactly one enum field definition.
The `InvalidExpression` is simply replaced by the corresponding `Value`.
2. The `InvalidExpression` repr corresponds to multiple enum field definitions. An
`AmbiguousEnumError` is reported.
3. The `InvalidExpression` repr corresponds to no enum field definitions. An
`InvalidExpressionError` is reported.

Attributes:
_pybind11_enum_pattern: Pattern matching pybind11 enum field reprs.
_unknown_enum_classes: Set of the str names of enum classes whose reprs were not
seen.
_invalid_default_arguments: Per module invalid arguments. Used to know which
enum imports to add to the current module.
_repr_to_value_and_import: Saved safe print values of enum field reprs and the
import to add to a module when that repr is seen.
_repr_to_invalid_default_arguments: Groups of arguments whose default values are
`InvalidExpression`s. This is only used until the first time each repr is
seen. Left over groups will raise an error, which may be fixed using
`--enum-class-locations` or suppressed using `--ignore-invalid-expressions`.
_invalid_default_argument_to_module: Maps individual invalid default arguments
to the module containing them. Used to know which enum imports to add to
which module.
"""

_pybind11_enum_pattern = re.compile(r"<(?P<enum>\w+(\.\w+)+): (?P<value>-?\d+)>")
# _pybind11_enum_pattern = re.compile(r"<(?P<enum>\w+(\.\w+)+): (?P<value>\d+)>")
_unknown_enum_classes: set[str] = set()
_invalid_default_arguments: list[Argument] = []
_repr_to_value_and_import: dict[str, set[tuple[Value, Import]]] = defaultdict(set)
_repr_to_invalid_default_arguments: dict[str, set[Argument]] = defaultdict(set)
_invalid_default_argument_to_module: dict[Argument, Module] = {}

def __init__(self):
    """Initialise per-instance state for enum repr rewriting."""
    super().__init__()
    self._pybind11_enum_locations: dict[re.Pattern, str] = {}
    self._is_finalizing = False
    # The class-level defaults for these attributes are mutable objects that
    # would otherwise be shared by every instance. Rebind fresh containers so
    # that state collected by one parser cannot leak into another.
    self._unknown_enum_classes = set()
    self._invalid_default_arguments = []
    self._repr_to_value_and_import = defaultdict(set)
    self._repr_to_invalid_default_arguments = defaultdict(set)
    self._invalid_default_argument_to_module = {}

def set_pybind11_enum_locations(self, locations: dict[re.Pattern, str]):
self._pybind11_enum_locations = locations
Expand All @@ -1024,17 +1069,104 @@ def parse_value_str(self, value: str) -> Value | InvalidExpression:
return Value(repr=f"{enum_class.name}.{entry}", is_print_safe=True)
return super().parse_value_str(value)

def handle_module(
    self, path: QualifiedName, module: types.ModuleType
) -> Module | None:
    """Parse ``module`` and register its invalid default arguments to it.

    Arguments collected in ``_invalid_default_arguments`` while the module is
    being handled are grouped by the text of their invalid default and mapped
    to the resulting ``Module``. If the parent handler returns ``None``,
    nothing is registered.
    """
    # Modules can be handled while another module is still being handled, so
    # park the enclosing module's pending arguments and start a fresh list.
    outer_pending = self._invalid_default_arguments
    self._invalid_default_arguments = []
    parsed_module = super().handle_module(path, module)

    if parsed_module is not None:
        # Walk newest-first, matching the order in which a pop() loop would
        # drain the list.
        for argument in reversed(self._invalid_default_arguments):
            assert isinstance(argument.default, InvalidExpression)
            invalid_text = argument.default.text
            self._repr_to_invalid_default_arguments[invalid_text].add(argument)
            self._invalid_default_argument_to_module[argument] = parsed_module

    # Hand the enclosing module its pending arguments back.
    self._invalid_default_arguments = outer_pending
    return parsed_module

def handle_function(self, path: QualifiedName, func: Any) -> list[Function]:
    """Parse ``func`` and remember arguments whose default is invalid.

    Every argument of every returned ``Function`` whose ``default`` is an
    ``InvalidExpression`` is appended to ``_invalid_default_arguments`` so that
    it can later be registered to the module currently being handled.
    """
    functions = super().handle_function(path, func)

    self._invalid_default_arguments.extend(
        argument
        for function in functions
        for argument in function.args
        if isinstance(argument.default, InvalidExpression)
    )

    return functions

def handle_attribute(self, path: QualifiedName, attr: Any) -> Attribute | None:
    """Record ``attr`` as a possible source of its repr, then parse it.

    A ``(Value, Import)`` candidate is stored under ``repr(attr)`` only when
    ``path`` lies inside the module reported by ``inspect.getmodule(attr)`` and
    the first path component after that module names a class in it.
    """
    source_module = inspect.getmodule(attr)
    attr_repr = repr(attr)

    if source_module is not None:
        origin = QualifiedName.from_str(source_module.__name__)
        depth = len(origin)
        # When the path component directly below the module is not a class,
        # this could be an `.export_values()` alias, which we want to avoid.
        if path[:depth] == origin and inspect.isclass(
            getattr(source_module, path[depth])
        ):
            candidate = (
                Value(repr=".".join(path), is_print_safe=True),
                Import(name=None, origin=origin),
            )
            # register one of the possible sources of this repr
            self._repr_to_value_and_import[attr_repr].add(candidate)

    return super().handle_attribute(path, attr)

def report_error(self, error: ParserError) -> None:
    """Report ``error`` unless it is an enum repr error that must be deferred.

    Before finalization, an ``InvalidExpressionError`` whose expression matches
    ``_pybind11_enum_pattern`` is swallowed here. Every other error, and every
    error raised once ``_is_finalizing`` is set, is forwarded to the parent
    handler.
    """
    # defer reporting invalid enum expressions until finalization
    is_deferred_enum_error = (
        not self._is_finalizing
        and isinstance(error, InvalidExpressionError)
        and self._pybind11_enum_pattern.match(error.expression) is not None
    )
    if is_deferred_enum_error:
        return
    super().report_error(error)

def finalize(self) -> None:
self._is_finalizing = True
for repr_, args in self._repr_to_invalid_default_arguments.items():
match = self._pybind11_enum_pattern.match(repr_)
if match is None:
pass
elif repr_ not in self._repr_to_value_and_import:
enum_qual_name = match.group("enum")
enum_class_str, entry = enum_qual_name.rsplit(".", maxsplit=1)
enum_class_str, _ = enum_qual_name.rsplit(".", maxsplit=1)
self._unknown_enum_classes.add(enum_class_str)
super().report_error(error)
self.report_error(InvalidExpressionError(repr_))
elif len(self._repr_to_value_and_import[repr_]) > 1:
self.report_error(
AmbiguousEnumError(repr_, *self._repr_to_value_and_import[repr_])
)
else:
# fix the invalid enum expressions
value, import_ = self._repr_to_value_and_import[repr_].pop()
for arg in args:
module = self._invalid_default_argument_to_module[arg]
if module.origin == import_.origin:
arg.default = Value(
repr=value.repr[len(str(module.origin)) + 1 :],
is_print_safe=True,
)
else:
arg.default = value
module.imports.add(import_)

def finalize(self):
if self._unknown_enum_classes:
# TODO: does this case still exist in practice? How would pybind11 display
# a repr for an enum field whose definition we did not see while parsing?
logger.warning(
"Enum-like str representations were found with no "
"matching mapping to the enum class location.\n"
Expand Down
2 changes: 1 addition & 1 deletion pybind11_stubgen/parser/mixins/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def handle_class_member(
def handle_module(
self, path: QualifiedName, module: types.ModuleType
) -> Module | None:
result = Module(name=path[-1])
result = Module(name=path[-1], origin=QualifiedName.from_str(module.__name__))
for name, member in inspect.getmembers(module):
obj = self.handle_module_member(
QualifiedName([*path, Identifier(name)]), module, member
Expand Down
5 changes: 3 additions & 2 deletions pybind11_stubgen/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def parent(self) -> QualifiedName:
return QualifiedName(self[:-1])


@dataclass
@dataclass(eq=False)
class Value:
repr: str
is_print_safe: bool = False # `self.repr` is valid python and safe to print as is
Expand Down Expand Up @@ -110,7 +110,7 @@ class Attribute:
annotation: Annotation | None = field_(default=None)


@dataclass
@dataclass(eq=False)
class Argument:
name: Identifier | None
pos_only: bool = field_(default=False)
Expand Down Expand Up @@ -191,6 +191,7 @@ class Import:
@dataclass
class Module:
name: Identifier
origin: QualifiedName
doc: Docstring | None = field_(default=None)
classes: list[Class] = field_(default_factory=list)
functions: list[Function] = field_(default_factory=list)
Expand Down
3 changes: 1 addition & 2 deletions tests/check-demo-stubs-generation.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ run_stubgen() {
demo \
--output-dir=${STUBS_DIR} \
${NUMPY_FORMAT} \
--ignore-invalid-expressions="\(anonymous namespace\)::(Enum|Unbound)|<demo\._bindings\.flawed_bindings\..*" \
--enum-class-locations="ConsoleForegroundColor:demo._bindings.enum" \
--ignore-invalid-expressions="\(anonymous namespace\)::(Enum|Unbound)|<demo\._bindings\.flawed_bindings\..*|<ConsoleForegroundColor\\.Magenta: 35>" \
--print-safe-value-reprs="Foo\(\d+\)" \
--exit-code
}
Expand Down
8 changes: 8 additions & 0 deletions tests/demo-lib/include/demo/sublibA/ConsoleColors.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ enum class ConsoleForegroundColor {
None_ = -1
};

// Test-only enum that repeats the enumerator names and values used by
// ConsoleForegroundColor. NOTE(review): it appears to exist so that a second
// enum can be bound under the same Python name to exercise the ambiguous
// enum repr handling -- confirm against the duplicate_enum bindings.
enum class ConsoleForegroundColorDuplicate {
Green = 32,
Yellow = 33,
Blue = 34,
Magenta = 35,
None_ = -1
};

enum ConsoleBackgroundColor {
Green = 42,
Yellow = 43,
Expand Down
6 changes: 1 addition & 5 deletions tests/demo.errors.stderr.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
pybind11_stubgen - [ ERROR] In demo._bindings.aliases.foreign_enum_default : Invalid expression '<ConsoleForegroundColor.Blue: 34>'
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_c : Can't find/import 'm'
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_c : Can't find/import 'n'
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_r : Can't find/import 'm'
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_r : Can't find/import 'n'
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.four_col_matrix_r : Can't find/import 'm'
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.four_row_matrix_r : Can't find/import 'n'
pybind11_stubgen - [ ERROR] In demo._bindings.enum.accept_defaulted_enum : Invalid expression '<ConsoleForegroundColor.None_: -1>'
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_enum : Invalid expression '(anonymous namespace)::Enum'
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_enum_defaulted : Invalid expression '<demo._bindings.flawed_bindings.Enum object at 0x1234abcd5678>'
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_type : Invalid expression '(anonymous namespace)::Unbound'
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_type_defaulted : Invalid expression '<demo._bindings.flawed_bindings.Unbound object at 0x1234abcd5678>'
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.get_unbound_type : Invalid expression '(anonymous namespace)::Unbound'
pybind11_stubgen - [WARNING] Enum-like str representations were found with no matching mapping to the enum class location.
Use `--enum-class-locations` to specify full path to the following enum(s):
- ConsoleForegroundColor
pybind11_stubgen - [ ERROR] In finalize : Enum member '<ConsoleForegroundColor.Magenta: 35>' could not be resolved; multiple matching definitions found in: 'demo._bindings.duplicate_enum', 'demo._bindings.enum'
pybind11_stubgen - [WARNING] Raw C++ types/values were found in signatures extracted from docstrings.
Please check the corresponding sections of pybind11 documentation to avoid common mistakes in binding code:
- https://pybind11.readthedocs.io/en/latest/advanced/misc.html#avoiding-cpp-types-in-docstrings
Expand Down
1 change: 1 addition & 0 deletions tests/py-demo/bindings/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ PYBIND11_MODULE(_bindings, m) {
bind_classes_module(m.def_submodule("classes"));
bind_eigen_module(m.def_submodule("eigen"));
bind_enum_module(m.def_submodule("enum"));
bind_duplicate_enum_module(m.def_submodule("duplicate_enum"));
bind_aliases_module(m.def_submodule("aliases"));
bind_flawed_bindings_module(m.def_submodule("flawed_bindings"));
bind_functions_module(m.def_submodule("functions"));
Expand Down
1 change: 1 addition & 0 deletions tests/py-demo/bindings/src/modules.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ void bind_aliases_module(py::module&& m);
void bind_classes_module(py::module&& m);
void bind_eigen_module(py::module&& m);
void bind_enum_module(py::module&& m);
void bind_duplicate_enum_module(py::module&& m);
void bind_flawed_bindings_module(py::module&& m);
void bind_functions_module(py::module&& m);
void bind_issues_module(py::module&& m);
Expand Down
15 changes: 15 additions & 0 deletions tests/py-demo/bindings/src/modules/duplicate_enum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "modules.h"

#include <demo/sublibA/ConsoleColors.h>

// Populates the `duplicate_enum` submodule with an enum that is exposed to
// Python as "ConsoleForegroundColor" and a function whose default argument
// is a member of that enum.
void bind_duplicate_enum_module(py::module&&m) {

// Bind ConsoleForegroundColorDuplicate under the Python name
// "ConsoleForegroundColor"; only the Magenta member is registered.
py::enum_<demo::sublibA::ConsoleForegroundColorDuplicate>(m, "ConsoleForegroundColor")
.value("Magenta", demo::sublibA::ConsoleForegroundColorDuplicate::Magenta)
.export_values();

// The lambda body is intentionally empty: only the signature and its
// defaulted `color` argument matter here.
m.def(
"accepts_ambiguous_enum",
[](const demo::sublibA::ConsoleForegroundColorDuplicate &color) {},
py::arg("color") = demo::sublibA::ConsoleForegroundColorDuplicate::Magenta);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from __future__ import annotations
from demo._bindings import (
aliases,
classes,
duplicate_enum,
eigen,
enum,
flawed_bindings,
Expand All @@ -23,6 +24,7 @@ __all__ = [
"aliases",
"classes",
"core",
"duplicate_enum",
"eigen",
"enum",
"flawed_bindings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from __future__ import annotations
from . import (
aliases,
classes,
duplicate_enum,
eigen,
enum,
flawed_bindings,
Expand All @@ -20,6 +21,7 @@ from . import (
__all__ = [
"aliases",
"classes",
"duplicate_enum",
"eigen",
"enum",
"flawed_bindings",
Expand Down
Loading