Skip to content

Commit

Permalink
Merge pull request opencv#24022 from VadimLevin:dev/vlevin/python-typ…
Browse files Browse the repository at this point in the history
…ing-cuda

Fix python typing stubs generation for CUDA modules opencv#24022

resolves opencv#23946
resolves opencv#23945
resolves opencv/opencv-python#871

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
  • Loading branch information
VadimLevin authored and thewoz committed Jan 4, 2024
1 parent c46a2f7 commit de3dfc5
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 42 deletions.
68 changes: 62 additions & 6 deletions modules/python/src2/typing_stubs_generation/api_refinement.py
Expand Up @@ -2,14 +2,17 @@
"apply_manual_api_refinement"
]

from typing import Sequence, Callable
from typing import cast, Sequence, Callable, Iterable

from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode,
ClassProperty, PrimitiveTypeNode)
from .ast_utils import find_function_node, SymbolName
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
AggregatedTypeNode)
from .ast_utils import (find_function_node, SymbolName,
for_each_function_overload)


def apply_manual_api_refinement(root: NamespaceNode) -> None:
refine_cuda_module(root)
export_matrix_type_constants(root)
# Export OpenCV exception class
builtin_exception = root.add_class("Exception")
Expand Down Expand Up @@ -57,13 +60,65 @@ def _make_optional_arg(root_node: NamespaceNode,
continue

overload.arguments[arg_idx].type_node = OptionalTypeNode(
overload.arguments[arg_idx].type_node
cast(TypeNode, overload.arguments[arg_idx].type_node)
)

return _make_optional_arg


def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> int:
def refine_cuda_module(root: NamespaceNode) -> None:
def fix_cudaoptflow_enums_names() -> None:
for class_name in ("NvidiaOpticalFlow_1_0", "NvidiaOpticalFlow_2_0"):
if class_name not in cuda_root.classes:
continue
opt_flow_class = cuda_root.classes[class_name]
_trim_class_name_from_argument_types(
for_each_function_overload(opt_flow_class), class_name
)

def fix_namespace_usage_scope(cuda_ns: NamespaceNode) -> None:
USED_TYPES = ("GpuMat", "Stream")

def fix_type_usage(type_node: TypeNode) -> None:
if isinstance(type_node, AggregatedTypeNode):
for item in type_node.items:
fix_type_usage(item)
if isinstance(type_node, ASTNodeTypeNode):
if type_node._typename in USED_TYPES:
type_node._typename = f"cuda_{type_node._typename}"

for overload in for_each_function_overload(cuda_ns):
if overload.return_type is not None:
fix_type_usage(overload.return_type.type_node)
for type_node in [arg.type_node for arg in overload.arguments
if arg.type_node is not None]:
fix_type_usage(type_node)

if "cuda" not in root.namespaces:
return
cuda_root = root.namespaces["cuda"]
fix_cudaoptflow_enums_names()
for ns in [ns for ns_name, ns in root.namespaces.items()
if ns_name.startswith("cuda")]:
fix_namespace_usage_scope(ns)


def _trim_class_name_from_argument_types(
overloads: Iterable[FunctionNode.Overload],
class_name: str
) -> None:
separator = f"{class_name}_"
for overload in overloads:
for arg in [arg for arg in overload.arguments
if arg.type_node is not None]:
ast_node = cast(ASTNodeTypeNode, arg.type_node)
if class_name in ast_node.ctype_name:
fixed_name = ast_node._typename.split(separator)[-1]
ast_node._typename = fixed_name


def _find_argument_index(arguments: Sequence[FunctionNode.Arg],
name: str) -> int:
for i, arg in enumerate(arguments):
if arg.name == name:
return i
Expand All @@ -76,6 +131,7 @@ def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> in
SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"),
SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"),
}

ERROR_CLASS_PROPERTIES = (
ClassProperty("code", PrimitiveTypeNode.int_(), False),
ClassProperty("err", PrimitiveTypeNode.str_(), False),
Expand Down
29 changes: 28 additions & 1 deletion modules/python/src2/typing_stubs_generation/ast_utils.py
@@ -1,5 +1,5 @@
from typing import (NamedTuple, Sequence, Tuple, Union, List,
Dict, Callable, Optional)
Dict, Callable, Optional, Generator)
import keyword

from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
Expand Down Expand Up @@ -404,6 +404,33 @@ def update_full_export_name(class_node: ClassNode) -> None:
return enum_export_name, namespace_node.full_export_name


def for_each_class(
node: Union[NamespaceNode, ClassNode]
) -> Generator[ClassNode, None, None]:
for cls in node.classes.values():
yield cls
if len(cls.classes):
yield from for_each_class(cls)


def for_each_function(
node: Union[NamespaceNode, ClassNode],
traverse_class_nodes: bool = True
) -> Generator[FunctionNode, None, None]:
yield from node.functions.values()
if traverse_class_nodes:
for cls in for_each_class(node):
yield from for_each_function(cls)


def for_each_function_overload(
node: Union[NamespaceNode, ClassNode],
traverse_class_nodes: bool = True
) -> Generator[FunctionNode.Overload, None, None]:
for func in for_each_function(node, traverse_class_nodes):
yield from func.overloads


if __name__ == '__main__':
import doctest
doctest.testmod()
53 changes: 18 additions & 35 deletions modules/python/src2/typing_stubs_generation/generation.py
Expand Up @@ -3,17 +3,20 @@
from io import StringIO
from pathlib import Path
import re
from typing import (Generator, Type, Callable, NamedTuple, Union, Set, Dict,
from typing import (Type, Callable, NamedTuple, Union, Set, Dict,
Collection, Tuple, List)
import warnings

from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name
from .ast_utils import (get_enclosing_namespace,
get_enum_module_and_export_name,
for_each_function_overload,
for_each_class)

from .predefined_types import PREDEFINED_TYPES
from .api_refinement import apply_manual_api_refinement

from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode,
EnumerationNode, ConstantNode)
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
FunctionNode, EnumerationNode, ConstantNode)

from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
AggregatedTypeNode, ASTNodeTypeNode,
Expand Down Expand Up @@ -105,8 +108,9 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:

# NOTE: Enumerations require special handling, because all enumeration
# constants are exposed as module attributes
has_enums = _generate_section_stub(StubSection("# Enumerations", EnumerationNode),
root, output_stream, 0)
has_enums = _generate_section_stub(
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0
)
# Collect all enums from class level and export them to module level
for class_node in root.classes.values():
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
Expand Down Expand Up @@ -536,30 +540,6 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
return True
return False


def _for_each_class(node: Union[NamespaceNode, ClassNode]) \
-> Generator[ClassNode, None, None]:
for cls in node.classes.values():
yield cls
if len(cls.classes):
yield from _for_each_class(cls)


def _for_each_function(node: Union[NamespaceNode, ClassNode]) \
-> Generator[FunctionNode, None, None]:
for func in node.functions.values():
yield func
for cls in node.classes.values():
yield from _for_each_function(cls)


def _for_each_function_overload(node: Union[NamespaceNode, ClassNode]) \
-> Generator[FunctionNode.Overload, None, None]:
for func in _for_each_function(node):
for overload in func.overloads:
yield overload


def _collect_required_imports(root: NamespaceNode) -> Set[str]:
"""Collects all imports required for classes and functions typing stubs
declarations.
Expand All @@ -582,7 +562,7 @@ def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
has_overload = check_overload_presence(root)
# if there is no module-level functions with overload, check its presence
# during class traversing, including their inner-classes
for cls in _for_each_class(root):
for cls in for_each_class(root):
if not has_overload and check_overload_presence(cls):
has_overload = True
required_imports.add("import typing")
Expand All @@ -600,8 +580,9 @@ def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
if has_overload:
required_imports.add("import typing")
# Importing modules required to resolve functions arguments
for overload in _for_each_function_overload(root):
for arg in filter(lambda a: a.type_node is not None, overload.arguments):
for overload in for_each_function_overload(root):
for arg in filter(lambda a: a.type_node is not None,
overload.arguments):
_add_required_usage_imports(arg.type_node, required_imports) # type: ignore
if overload.return_type is not None:
_add_required_usage_imports(overload.return_type.type_node,
Expand All @@ -625,11 +606,13 @@ def _reexport_submodule(ns: NamespaceNode) -> None:

_reexport_submodule(root)

# Special cases, symbols defined in possible pure Python submodules should be
# Special cases, symbols defined in possible pure Python submodules
# should be
root.reexported_submodules_symbols["mat_wrapper"].append("Mat")


def _write_reexported_symbols_section(module: NamespaceNode, output_stream: StringIO) -> None:
def _write_reexported_symbols_section(module: NamespaceNode,
output_stream: StringIO) -> None:
"""Write re-export section for the given module.
Re-export statements have from `from module_name import smth as smth`.
Expand Down
Expand Up @@ -22,6 +22,10 @@
PrimitiveTypeNode.int_("uchar"),
PrimitiveTypeNode.int_("unsigned"),
PrimitiveTypeNode.int_("int64"),
PrimitiveTypeNode.int_("uint8_t"),
PrimitiveTypeNode.int_("int8_t"),
PrimitiveTypeNode.int_("int32_t"),
PrimitiveTypeNode.int_("uint32_t"),
PrimitiveTypeNode.int_("size_t"),
PrimitiveTypeNode.float_("float"),
PrimitiveTypeNode.float_("double"),
Expand Down

0 comments on commit de3dfc5

Please sign in to comment.