From 66ad46a915a2838fd9a37e8051248d6a4fa3b9ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Jun 2024 13:41:11 +0200 Subject: [PATCH 01/16] Add StaticInspector --- src/docstub/_analysis.py | 122 +++++++++++++++++++++++++++++++++++-- src/docstub/_cli.py | 32 ++++++---- src/docstub/_docstrings.py | 35 +++++------ src/docstub/_stubs.py | 13 ++-- 4 files changed, 163 insertions(+), 39 deletions(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index ee851b0..ff73e21 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -2,8 +2,12 @@ import builtins import collections.abc +import itertools import typing from dataclasses import dataclass +from pathlib import Path + +import libcst as cst @dataclass(slots=True, frozen=True) @@ -56,6 +60,20 @@ def format_import(self): def has_import(self): return not self.is_builtin + def __repr__(self): + classname = type(self).__name__ + if self.has_import: + info = f"{self.import_name}" + if self.import_path: + info = f"{self.import_path}.{info}" + if self.import_alias: + info = f"{info} as {self.import_alias}" + if self.use_name not in info: + info = f"{info}; {self.use_name}" + else: + info = f"{self.use_name} (builtin)" + return f"{classname}: {info}" + def _is_type(value) -> bool: """Check if value is a type.""" @@ -139,7 +157,103 @@ def common_docnames(): return docnames -def find_module_paths_using_search(): - # TODO use from mypy.stubgen import find_module_paths_using_search ? 
- # https://github.com/python/mypy/blob/66b48cbe97bf9c7660525766afe6d7089a984769/mypy/stubgen.py#L1526 - pass +class DocNameCollector(cst.CSTVisitor): + + @classmethod + def collect(cls, file, module_name): + file = Path(file) + with file.open("r") as fo: + source = fo.read() + + tree = cst.parse_module(source) + collector = cls(module_name=module_name) + tree.visit(collector) + return collector.docnames + + def __init__(self, *, module_name): + self.module_name = module_name + self._stack = [] + self.docnames = {} + + def visit_ClassDef(self, node): + self._stack.append(node.name.value) + + use_name = ".".join(self._stack[:1]) + qualname = f"{self.module_name}.{'.'.join(self._stack)}" + docname = DocName( + use_name=use_name, import_name=use_name, import_path=self.module_name + ) + self.docnames[qualname] = docname + + return True + + def leave_ClassDef(self, original_node): + self._stack.pop() + + def visit_FunctionDef(self, node): + self._stack.append(node.name.value) + return True + + def leave_FunctionDef(self, original_node): + self._stack.pop() + + +class StaticInspector: + """Try to find docnames when requested.""" + + def __init__(self, *, source_pkgs=None, docnames=None): + if source_pkgs is None: + source_pkgs = [] + if docnames is None: + docnames = {} + + self.source_pkgs: list[Path] = source_pkgs + self.docnames = docnames + self._inspected = {} + + @staticmethod + def _accumulate_module_name(qualname): + fragments = qualname.split(".") + yield from itertools.accumulate(fragments, lambda x, y: f"{x}.{y}") + + def _find_modules(self, qualname): + for source in self.source_pkgs: + for module_name in self._accumulate_module_name(qualname): + module_path = module_name.replace(".", "/") + # Return PYI files last, so their content overwrites + files = [ + source / f"{module_path}.py", + source / f"{module_path}.pyi", + source / f"{module_path}/__init__.py", + source / f"{module_path}/__init__.pyi", + ] + for file in files: + if file.is_file(): + yield file, 
module_name + + def inspect_module(self, file, module_name): + """Collect docnames from the given file. + + Parameters + ---------- + file : Path + + Returns + ------- + collected : set[DocName] + """ + if file in self._inspected: + return self._inspected[file] + + docnames = DocNameCollector.collect(file, module_name) + self._inspected[file] = docnames + self.docnames.update(docnames) + return docnames + + def query(self, qualname): + out = self.docnames.get(qualname) + if out is None: + for file, module_name in self._find_modules(qualname): + self.inspect_module(file, module_name) + out = self.docnames.get(qualname) + return out diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 1038daf..1dcc478 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -5,8 +5,8 @@ import click from . import _config -from ._analysis import DocName, common_docnames -from ._stubs import Py2StubTransformer, walk_python_package +from ._analysis import DocName, StaticInspector, common_docnames +from ._stubs import Py2StubTransformer, walk_source_and_targets from ._version import __version__ logger = logging.getLogger(__name__) @@ -33,14 +33,21 @@ def main(source_dir, out_dir, config_path, verbose): source_dir = Path(source_dir) # Handle configuration - if config_path is None: - config_path = source_dir.parent / "docstub.toml" - else: - config_path = Path(config_path) config = _config.default_config() - if config_path.exists(): - _user_config = _config.load_config_file(config_path) - config = _config.merge_config(config, _user_config) + pyproject_toml = source_dir.parent / "pyproject.toml" + docstub_toml = source_dir.parent / "docstub.toml" + if pyproject_toml.is_file(): + logger.info("using %s", pyproject_toml) + add_config = _config.load_config_file(pyproject_toml) + config = _config.merge_config(config, add_config) + if docstub_toml.is_file(): + logger.info("using %s", docstub_toml) + add_config = _config.load_config_file(docstub_toml) + config = 
_config.merge_config(config, add_config) + if config_path: + logger.info("using %s", config_path) + add_config = _config.load_config_file(config_path) + config = _config.merge_config(config, add_config) # Build docname map docnames = common_docnames() @@ -50,15 +57,18 @@ def main(source_dir, out_dir, config_path, verbose): for name, spec in config["docnames"].items() } ) + inspector = StaticInspector( + source_pkgs=[source_dir.parent.resolve()], docnames=docnames + ) # and the stub transformer - stub_transformer = Py2StubTransformer(docnames=docnames) + stub_transformer = Py2StubTransformer(inspector=inspector) if not out_dir: out_dir = source_dir.parent out_dir = Path(out_dir) / (source_dir.name + "-stubs") out_dir.mkdir(parents=True, exist_ok=True) - for source_path, stub_path in walk_python_package(source_dir, out_dir): + for source_path, stub_path in walk_source_and_targets(source_dir, out_dir): if source_path.suffix.lower() == ".pyi": logger.debug("using existing stub file %s", source_path) with source_path.open() as fo: diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 215f760..d935dc5 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -87,8 +87,8 @@ class DoctypeTransformer(lark.visitors.Transformer): Keyword arguments passed to the init of the parent class. 
""" - def __init__(self, *, docnames, **kwargs): - self.docnames = docnames + def __init__(self, *, inspector, **kwargs): + self.inspector = inspector self._collected_imports = None super().__init__(**kwargs) @@ -169,7 +169,12 @@ def qualname(self, tree): out.append(child) out = "".join(out) if matched is False: - logger.warning("unmatched name %r", out) + docname = self.inspector.query(out) + if docname: + out = docname.use_name + self._collected_imports.add(docname) + else: + logger.warning("unmatched name %r", out) return out def NAME(self, token): @@ -208,14 +213,14 @@ def contains(self, tree): def literals(self, tree): out = " , ".join(tree.children) out = f"Literal[{out}]" - self._collected_imports.add(self.docnames["Literal"]) + self._collected_imports.add(self.inspector.query("Literal")) return out def _match_n_record_name(self, token): """Match type names to known imports.""" assert "." not in token - if token in self.docnames: - docname = self.docnames[token] + docname = self.inspector.query(token) + if docname: token = MatchedName(token.type, value=docname.use_name) if docname.has_import: self._collected_imports.add(docname) @@ -228,7 +233,7 @@ def _match_n_record_name(self, token): _lark = lark.Lark(_grammar) -def doc2pytype(doctype, *, docnames): +def doc2pytype(doctype, *, inspector): """Convert a type description to a Python-ready type. Parameters @@ -236,9 +241,7 @@ def doc2pytype(doctype, *, docnames): doctype : str The type description of a parameter or return value, as extracted from a docstring. - docnames : dict[str, DocName] - A dictionary mapping atomic names used in doctypes to information such - as where to import from or how to replace the name itself. + inspector : docstub._analysis.StaticInspector Returns ------- @@ -247,7 +250,7 @@ def doc2pytype(doctype, *, docnames): necessary imports attached. 
""" try: - transformer = DoctypeTransformer(docnames=docnames) + transformer = DoctypeTransformer(inspector=inspector) tree = _lark.parse(doctype) pytype = transformer.transform(tree) return pytype @@ -268,16 +271,14 @@ class ReturnKey: ReturnKey = ReturnKey() -def collect_pytypes(docstring, *, docnames): +def collect_pytypes(docstring, *, inspector): """Collect PyTypes from a docstring. Parameters ---------- docstring : str The docstring to collect from. - docnames : dict[str, DocName] - A dictionary mapping atomic names used in doctypes to information such - as where to import from or how to replace the name itself. + inspector : docstub._analysis.StaticInspector Returns ------- @@ -295,13 +296,13 @@ def collect_pytypes(docstring, *, docnames): params.update(other) pytypes = { - name: doc2pytype(param.type, docnames=docnames) + name: doc2pytype(param.type, inspector=inspector) for name, param in params.items() if param.type } returns = [ - doc2pytype(param.type, docnames=docnames) + doc2pytype(param.type, inspector=inspector) for param in np_docstring["Returns"] if param.type ] diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index eb5beb7..ae7019b 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -13,8 +13,8 @@ logger = logging.getLogger(__name__) -def walk_python_package(root_dir, target_dir): - """Iterate modules in a Python package and it's target stub files. +def walk_source_and_targets(root_dir, target_dir): + """Iterate modules in a Python package and its target stub files. 
Parameters ---------- @@ -51,7 +51,6 @@ def walk_python_package(root_dir, target_dir): ): # Stub file already exists and takes precedence continue - stub_path = target_dir / source_path.with_suffix(".pyi").relative_to( root_dir ) @@ -109,8 +108,8 @@ class Py2StubTransformer(cst.CSTTransformer): _Annotation_Any = cst.Annotation(cst.Name("Any")) _Annotation_None = cst.Annotation(cst.Name("None")) - def __init__(self, *, docnames): - self.docnames = docnames + def __init__(self, *, inspector): + self.inspector = inspector # Relevant docstring for the current context self._scope_stack = None # Store current class or function scope self._pytypes_stack = None # Store current parameter types @@ -161,7 +160,7 @@ def visit_FunctionDef(self, node): pytypes = None if docstring: try: - pytypes = collect_pytypes(docstring, docnames=self.docnames) + pytypes = collect_pytypes(docstring, inspector=self.inspector) except Exception as e: logger.exception( "error while parsing docstring of `%s`:\n\n%s", node.name.value, e @@ -210,7 +209,7 @@ def leave_Param(self, original_node, updated_node): # Potentially use "Any" except for first param in (class)methods elif not is_self_or_cls and updated_node.annotation is None: node_changes["annotation"] = self._Annotation_Any - self._required_imports.add(self.docnames["Any"]) + self._required_imports.add(self.inspector.query("Any")) if updated_node.default is not None: node_changes["default"] = cst.Ellipsis() From 637cd5f8cdd2e83de66449ef358f360619ecaa4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 3 Jun 2024 20:38:01 +0200 Subject: [PATCH 02/16] WIP --- examples/example_pkg-stubs/__init__.pyi | 6 +- examples/example_pkg-stubs/_basic.pyi | 8 + examples/example_pkg/__init__.py | 7 +- examples/example_pkg/_basic.py | 34 +++++ pyproject.toml | 5 + src/docstub/_analysis.py | 86 ++++++++--- src/docstub/_cli.py | 52 ++++--- src/docstub/_docstrings.py | 87 ++++++++--- src/docstub/_stubs.py | 190 +++++++++++++++++------- 9 
files changed, 358 insertions(+), 117 deletions(-) diff --git a/examples/example_pkg-stubs/__init__.pyi b/examples/example_pkg-stubs/__init__.pyi index 5797999..f1a41c0 100644 --- a/examples/example_pkg-stubs/__init__.pyi +++ b/examples/example_pkg-stubs/__init__.pyi @@ -1,8 +1,10 @@ import _numpy as np_ -from _basic import func_comment, func_contains +from _basic import func_contains __all__ = [ - "func_comment", "func_contains", "np_", ] + +class CustomException(Exception): + pass diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 1d7e2db..06534ca 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -3,6 +3,9 @@ import logging from collections.abc import Sequence from typing import Any, Literal, Self, Union +from example_pkg import CustomException +from example_pkg._basic import ExampleClass + logger = logging.getLogger(__name__) __all__ = [ @@ -21,6 +24,7 @@ def func_contains( def func_literals( a1: Literal[1, 3, "foo"], a2: Literal["uno", 2, "drei", "four"] = ... ) -> None: ... +def func_use_from_elsewhere(a1: CustomException, a2: ExampleClass) -> None: ... class ExampleClass: def __init__(self, a1: int, a2: float | None = ...) -> None: ... @@ -29,5 +33,9 @@ class ExampleClass: def some_staticmethod(a1: float, a2: float | None = ...) -> dict[str, Any]: ... @property def some_property(self) -> str: ... + @some_property.setter + def some_property(self, value: str) -> None: ... @classmethod def method_returning_cls(cls, config: configparser.ConfigParser) -> Self: ... + @classmethod() + def method_returning_cls2(cls, config: configparser.ConfigParser) -> Self: ... 
diff --git a/examples/example_pkg/__init__.py b/examples/example_pkg/__init__.py index 2d40e24..ac61e3d 100644 --- a/examples/example_pkg/__init__.py +++ b/examples/example_pkg/__init__.py @@ -1,10 +1,13 @@ """Example of an init file.""" import _numpy as np_ -from _basic import func_comment, func_contains +from _basic import func_contains __all__ = [ - "func_comment", "func_contains", "np_", ] + + +class CustomException(Exception): + pass diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 3933b94..b1bb434 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -53,6 +53,16 @@ def func_literals(a1, a2="uno"): """ +def func_use_from_elsewhere(a1, a2): + """Check if types with full import names are matched. + + Parameters + ---------- + a1 : example_pkg.CustomException + a2 : example_pkg._basic.ExampleClass + """ + + class ExampleClass: # TODO also take into account class level docstring @@ -101,6 +111,15 @@ def some_property(self): """ return str(self) + @some_property.setter + def some_property(self, value): + """Dummy + + Parameters + ---------- + value : str + """ + @classmethod def method_returning_cls(cls, config): """Using `Self` in context of classmethods is supported. @@ -115,3 +134,18 @@ def method_returning_cls(cls, config): out : Self New class. """ + + @classmethod() + def method_returning_cls2(cls, config): + """Using `Self` in context of classmethods is supported. + + Parameters + ---------- + config : configparser.ConfigParser + Configuation. + + Returns + ------- + out : Self + New class. 
+ """ diff --git a/pyproject.toml b/pyproject.toml index 58943e1..04e824a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,3 +87,8 @@ ignore = [ "RET504", # Assignment before `return` statement facilitates debugging "PTH123", # Using builtin open() instead of Path.open() is fine ] + + +[tool.docstub.docnames] +cst = {import = "cst"} +lark = {import = "lark"} diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index ff73e21..72b5cec 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -21,30 +21,59 @@ class DocName: is_builtin: bool = False @classmethod - def from_cfg(cls, docname: str, spec: dict): + def one_from_config(cls, docname, *, info): + """Create one DocName from the configuration format. + + Parameters + ---------- + docname : str + info : dict[{"use", "from", "import", "as", "is_builtin"}, str] + + Returns + ------- + docname : Self + """ use_name = docname - if "import" in spec: - use_name = spec["import"] - if "as" in spec: - use_name = spec["as"] - if "use" in spec: - use_name = spec["use"] + if "import" in info: + use_name = info["import"] + if "as" in info: + use_name = info["as"] + if "use" in info: + use_name = info["use"] import_name = docname - if "use" in spec: - import_name = spec["use"] - if "import" in spec: - import_name = spec["import"] + if "use" in info: + import_name = info["use"] + if "import" in info: + import_name = info["import"] docname = cls( use_name=use_name, import_name=import_name, - import_path=spec.get("from"), - import_alias=spec.get("as"), - is_builtin=spec.get("builtin", False), + import_path=info.get("from"), + import_alias=info.get("as"), + is_builtin=info.get("builtin", False), ) return docname + @classmethod + def many_from_config(cls, mapping): + """Create many DocNames from the configuration format. 
+ + Parameters + ---------- + mapping : dict[str, dict[{"use", "from", "import", "as", "is_builtin"}, str]] + + Returns + ------- + docnames : dict[str, Self] + """ + docnames = { + docname: cls.one_from_config(docname, info=info) + for docname, info in mapping.items() + } + return docnames + def format_import(self): if self.is_builtin: msg = "cannot import builtin" @@ -75,6 +104,14 @@ def __repr__(self): return f"{classname}: {info}" +@dataclass(slots=True, frozen=True) +class InspectionContext: + """Currently inspected module and other information.""" + + file_path: Path + in_package_path: str + + def _is_type(value) -> bool: """Check if value is a type.""" # Checking for isinstance(..., type) isn't enough, some types such as @@ -119,7 +156,7 @@ def _typing_docnames(): value = getattr(typing, name) if not _is_type(value): continue - docnames[name] = DocName.from_cfg(name, spec={"from": "typing"}) + docnames[name] = DocName.one_from_config(name, info={"from": "typing"}) return docnames @@ -137,7 +174,7 @@ def _collections_abc_docnames(): value = getattr(collections.abc, name) if not _is_type(value): continue - docnames[name] = DocName.from_cfg(name, spec={"from": "collections.abc"}) + docnames[name] = DocName.one_from_config(name, info={"from": "collections.abc"}) return docnames @@ -199,15 +236,26 @@ def leave_FunctionDef(self, original_node): class StaticInspector: - """Try to find docnames when requested.""" + """Static analysis of Python packages. 
+ + Parameters + ---------- + source_pkgs: list[Path] + docnames: dict[str, DocName] + """ - def __init__(self, *, source_pkgs=None, docnames=None): + def __init__( + self, + *, + source_pkgs=None, + docnames=None, + ): if source_pkgs is None: source_pkgs = [] if docnames is None: docnames = {} - self.source_pkgs: list[Path] = source_pkgs + self.source_pkgs = source_pkgs self.docnames = docnames self._inspected = {} diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 1dcc478..b57aef7 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -15,23 +15,18 @@ _VERBOSITY_LEVEL = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -@click.command() -@click.version_option(__version__) -@click.argument("source_dir", type=click.Path(exists=True, file_okay=False)) -@click.option("-o", "--out-dir", type=click.Path(file_okay=False)) -@click.option("--config", "config_path", type=click.Path(exists=True, dir_okay=False)) -@click.option("-v", "--verbose", count=True, help="Log more details") -@click.help_option("-h", "--help") -def main(source_dir, out_dir, config_path, verbose): - verbose = min(2, max(0, verbose)) # Limit to range [0, 2] - logging.basicConfig( - level=_VERBOSITY_LEVEL[verbose], - format="%(levelname)s: %(filename)s, line %(lineno)d, in %(funcName)s: %(message)s", - stream=sys.stderr, - ) +def _find_configuration(source_dir, config_path): + """Find and load configuration from multiple possible sources. 
- source_dir = Path(source_dir) + Parameters + ---------- + source_dir : Path + config_path : Path + Returns + ------- + config : dict[str, Any] + """ # Handle configuration config = _config.default_config() pyproject_toml = source_dir.parent / "pyproject.toml" @@ -48,15 +43,30 @@ def main(source_dir, out_dir, config_path, verbose): logger.info("using %s", config_path) add_config = _config.load_config_file(config_path) config = _config.merge_config(config, add_config) + return config + + +@click.command() +@click.version_option(__version__) +@click.argument("source_dir", type=click.Path(exists=True, file_okay=False)) +@click.option("-o", "--out-dir", type=click.Path(file_okay=False)) +@click.option("--config", "config_path", type=click.Path(exists=True, dir_okay=False)) +@click.option("-v", "--verbose", count=True, help="Log more details") +@click.help_option("-h", "--help") +def main(source_dir, out_dir, config_path, verbose): + verbose = min(2, max(0, verbose)) # Limit to range [0, 2] + logging.basicConfig( + level=_VERBOSITY_LEVEL[verbose], + format="%(levelname)s: %(filename)s#L%(lineno)d::%(funcName)s: %(message)s", + stream=sys.stderr, + ) + + source_dir = Path(source_dir) + config = _find_configuration(source_dir, config_path) # Build docname map docnames = common_docnames() - docnames.update( - { - name: DocName.from_cfg(docname=name, spec=spec) - for name, spec in config["docnames"].items() - } - ) + docnames.update(DocName.many_from_config(config["docnames"])) inspector = StaticInspector( source_pkgs=[source_dir.parent.resolve()], docnames=docnames ) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index d935dc5..1cd9920 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -2,6 +2,7 @@ """ +import enum import logging from dataclasses import dataclass, field from pathlib import Path @@ -42,15 +43,15 @@ def __str__(self) -> str: return self.value @classmethod - def from_concatenated(cls, pytypes): - """Concatenate 
multiple PyTypes in a tuple. + def as_return_tuple(cls, return_types): + """Concatenate multiple PyTypes and wrap in tuple if more than one. Useful to combine multiple returned types for a function into a single PyType. Parameters ---------- - pytypes : Iterable[PyType] + return_types : Iterable[PyType] The types to combine. Returns @@ -58,16 +59,51 @@ def from_concatenated(cls, pytypes): concatenated : PyType The concatenated types. """ + values, imports = cls._aggregate_pytypes(*return_types) + value = " , ".join(values) + if len(values) > 1: + value = f"tuple[{value}]" + concatenated = cls(value=value, imports=imports) + return concatenated + + @classmethod + def as_yields_generator(cls, yield_types, receive_types=()): + """Create new iterator type from yield and receive types. + + Parameters + ---------- + yield_types : Iterable[PyType] + The types to yield. + receive_types : Iterable[PyType], optional + The types the generator receives. + + Returns + ------- + iterator : PyType + The yielded and received types wrapped in a generator. + """ + # TODO + raise NotImplementedError() + + @staticmethod + def _aggregate_pytypes(*types): + """Aggregate values and imports of given PyTypes. + + Parameters + ---------- + types : Iterable[PyType] + + Returns + ------- + values : list[str] + imports : set[DocName] + """ values = [] imports = set() - for p in pytypes: + for p in types: values.append(p.value) imports.update(p.imports) - value = " , ".join(values) - if len(values) > 1: - value = f"tuple[{value}]" - joined = cls(value=value, imports=imports) - return joined + return values, imports class MatchedName(lark.Token): @@ -261,14 +297,9 @@ def doc2pytype(doctype, *, inspector): ) -class ReturnKey: - """Simple "singleton" key to access the return PyType in a dictionary. - - See :func:`collect_pytypes` for more. 
- """ - - -ReturnKey = ReturnKey() +class NPDocSection(enum.Enum): + RETURNS = enum.auto() + YIELDS = enum.auto() def collect_pytypes(docstring, *, inspector): @@ -290,6 +321,7 @@ def collect_pytypes(docstring, *, inspector): params = {p.name: p for p in np_docstring["Parameters"]} other = {p.name: p for p in np_docstring["Other Parameters"]} + duplicate_params = params.keys() & other.keys() if duplicate_params: raise ValueError(f"{duplicate_params=}") @@ -306,7 +338,26 @@ def collect_pytypes(docstring, *, inspector): for param in np_docstring["Returns"] if param.type ] + yields = [ + doc2pytype(param.type, inspector=inspector) + for param in np_docstring["Yields"] + if param.type + ] + receives = [ + doc2pytype(param.type, inspector=inspector) + for param in np_docstring["Receives"] + if param.type + ] + if returns and yields: + logger.warning( + "found 'Returns' and 'Yields' section in docstring, ignoring 'Yields'" + ) + if receives and not yields: + logger.warning("found 'Receives' section in docstring without 'Yields' section") + if returns: - pytypes[ReturnKey] = PyType.from_concatenated(returns) + pytypes[NPDocSection.RETURNS] = PyType.as_return_tuple(returns) + elif yields: + logger.error("yields is not supported yet, ignoring") return pytypes diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index ae7019b..b9f5e04 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -2,17 +2,97 @@ """ +import enum import logging from dataclasses import dataclass -from typing import Literal +from pathlib import Path import libcst as cst -from ._docstrings import ReturnKey, collect_pytypes +from ._docstrings import NPDocSection, collect_pytypes logger = logging.getLogger(__name__) +class PythonFile(Path): + + def __init__(self, *args, package_root): + self.package_root = package_root + super().__init__(*args) + if not self.is_file(): + raise ValueError("must be a file") + if not self.is_relative_to(self.package_root): + raise ValueError("path must be 
relative to package_root") + + @property + def import_name(self): + relative_to_root = self.relative_to(self.package_root) + parts = relative_to_root.with_suffix("").parts + if parts[-1] == "__init__": + parts = parts[:-1] + import_name = ".".join(parts) + return import_name + + def with_segments(self, *args): + return Path(*args) + + +def _is_python_package(path): + """ + Parameters + ---------- + path : Path + + Returns + ------- + is_package : bool + """ + is_package = (path / "__init__.py").is_file() or (path / "__init__.pyi").is_file() + return is_package + + +def walk_source(root_dir): + """Iterate modules in a Python package and its target stub files. + + Parameters + ---------- + root_dir : Path + Root directory of a Python package. + target_dir : Path + Root directory in which a matching stub package will be created. + + Yields + ------ + source_path : PythonFile + Either a Python file or a stub file that takes precedence. + + Notes + ----- + Files starting with "test_" are skipped entirely for now. + """ + queue = [root_dir] + while queue: + path = queue.pop(0) + + if path.is_dir(): + if _is_python_package(path): + queue.extend(path.iterdir()) + else: + logger.debug("skipping directory %s", path) + continue + + assert path.is_file() + + suffix = path.suffix.lower() + if suffix not in {".py", ".pyi"}: + continue + if suffix == ".py" and path.with_suffix(".pyi").exists(): + continue # Stub file already exists and takes precedence + + python_file = PythonFile(path, package_root=root_dir) + yield python_file + + def walk_source_and_targets(root_dir, target_dir): """Iterate modules in a Python package and its target stub files. @@ -25,36 +105,19 @@ def walk_source_and_targets(root_dir, target_dir): Returns ------- - source_path : Path + source_path : PythonFile Either a Python file or a stub file that takes precedence. - stub_path : Path + stub_path : PythonFile Target stub file. Notes ----- Files starting with "test_" are skipped entirely for now. 
""" - for root, _, files in root_dir.walk(top_down=True): - for name in files: - source_path = root / name - - if name.startswith("test_"): - logger.debug("skipping %s", name) - continue - - if source_path.suffix.lower() not in {".py", ".pyi"}: - continue - - if ( - source_path.suffix.lower() == ".py" - and source_path.with_suffix(".pyi").exists() - ): - # Stub file already exists and takes precedence - continue - stub_path = target_dir / source_path.with_suffix(".pyi").relative_to( - root_dir - ) - yield source_path, stub_path + for source_path in walk_source(root_dir): + stub_path = target_dir / source_path.with_suffix(".pyi").relative_to(root_dir) + stub_path = PythonFile(stub_path, package_root=target_dir) + yield source_path, stub_path def try_format_stub(stub: str) -> str: @@ -74,27 +137,23 @@ def try_format_stub(stub: str) -> str: return stub +class FuncType(enum.Enum): + MODULE = enum.auto() + CLASS = enum.auto() + FUNC = enum.auto() + METHOD = enum.auto() + CLASSMETHOD = enum.auto() + STATICMETHOD = enum.auto() + + @dataclass(slots=True, frozen=True) class _Scope: - type: Literal["module", "class", "func", "method", "classmethod", "staticmethod"] + type: FuncType node: cst.CSTNode = None - def __post_init__(self): - allowed_types = { - "module", - "class", - "func", - "method", - "classmethod", - "staticmethod", - } - if self.type not in allowed_types: - msg = f"type {self.type!r} is not allowed, allowed are {allowed_types!r}" - raise ValueError(msg) - @property def has_self_or_cls(self): - return self.type in {"method", "classmethod"} + return self.type in {FuncType.METHOD, FuncType.CLASSMETHOD} class Py2StubTransformer(cst.CSTTransformer): @@ -115,12 +174,14 @@ def __init__(self, *, inspector): self._pytypes_stack = None # Store current parameter types self._required_imports = None # Collect imports for used types - def python_to_stub(self, source: str) -> str: + def python_to_stub(self, source: str, module_path=None) -> str: """Convert Python source 
code to stub-file ready code.""" try: self._scope_stack = [] self._pytypes_stack = [] self._required_imports = set() + if module_path: + self.inspector.set_current_module(module_path) source_tree = cst.parse_module(source) stub_tree = source_tree.visit(self) @@ -128,8 +189,6 @@ def python_to_stub(self, source: str) -> str: stub = try_format_stub(stub) return stub finally: - assert len(self._scope_stack) == 0 - assert len(self._pytypes_stack) == 0 self._scope_stack = None self._pytypes_stack = None self._required_imports = None @@ -143,17 +202,7 @@ def leave_ClassDef(self, original_node, updated_node): return updated_node def visit_FunctionDef(self, node): - func_type = "func" - if self._scope_stack[-1].type == "class": - func_type = "method" - for decorator in node.decorators: - assert func_type in {"func", "method"} - if decorator.decorator.value == "classmethod": - func_type = "classmethod" - break - if decorator.decorator.value == "staticmethod": - func_type = "staticmethod" - break + func_type = self._function_type(node) self._scope_stack.append(_Scope(type=func_type, node=node)) docstring = node.get_docstring() @@ -176,8 +225,9 @@ def leave_FunctionDef(self, original_node, updated_node): pytypes = self._pytypes_stack.pop() if pytypes: - return_pytype = pytypes.get(ReturnKey) + return_pytype = pytypes.get(NPDocSection.RETURNS) if return_pytype: + assert return_pytype.value node_changes["returns"] = cst.Annotation( cst.parse_expression(return_pytype.value) ) @@ -234,6 +284,10 @@ def leave_Module(self, original_node, updated_node): self._scope_stack.pop() return updated_node + def visit_Lambda(self, node): + # Skip visiting parameters of lambda which can't have an annotation. + return False + @staticmethod def _parse_imports(imports): """Create nodes to include in the module tree from given imports. 
@@ -249,3 +303,29 @@ def _parse_imports(imports): lines = {imp.format_import() for imp in imports} import_nodes = tuple(cst.parse_statement(line) for line in lines) return import_nodes + + def _function_type(self, func_def): + """Determine if a function is a method, property, staticmethod, ... + + Parameters + ---------- + func_def : cst.FunctionDef + + Returns + ------- + func_type : FuncType + """ + func_type = FuncType.FUNC + if self._scope_stack[-1].type == FuncType.CLASS: + func_type = FuncType.METHOD + for decorator in func_def.decorators: + if not hasattr(decorator.decorator, "value"): + continue + if decorator.decorator.value == "classmethod": + func_type = FuncType.CLASSMETHOD + break + if decorator.decorator.value == "staticmethod": + assert func_type == FuncType.METHOD + func_type = FuncType.STATICMETHOD + break + return func_type From f5e750b63cde4dc08dd4ca8a29e89bb019272599 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 12 Jun 2024 10:09:56 -0400 Subject: [PATCH 03/16] WIP --- src/docstub/_analysis.py | 38 ++++++++++++++++++++++++++++++++++++-- src/docstub/_cli.py | 4 +++- src/docstub/_docstrings.py | 3 ++- src/docstub/_stubs.py | 26 +++++++++++++++++++++----- 4 files changed, 62 insertions(+), 9 deletions(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 72b5cec..298c3b9 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -242,6 +242,13 @@ class StaticInspector: ---------- source_pkgs: list[Path] docnames: dict[str, DocName] + + Examples + -------- + >>> from docstub._analysis import StaticInspector, common_docnames + >>> inspector = StaticInspector(docnames=common_docnames()) + >>> inspector.query("Any") + """ def __init__( @@ -255,9 +262,9 @@ def __init__( if docnames is None: docnames = {} + self.current_module = None self.source_pkgs = source_pkgs - self.docnames = docnames - self._inspected = {} + self._inspected = {"initial": docnames} @staticmethod def 
_accumulate_module_name(qualname): @@ -299,9 +306,36 @@ def inspect_module(self, file, module_name): return docnames def query(self, qualname): + """ + Parameters + ---------- + qualname + + Returns + ------- + + """ out = self.docnames.get(qualname) if out is None: for file, module_name in self._find_modules(qualname): self.inspect_module(file, module_name) out = self.docnames.get(qualname) + + *prefix, name = qualname.split(".") + if out is None and not prefix and self.current_module: + out = self.query(f"{self.current_module.import_name}.{qualname}") + return out + + @property + def docnames(self): + current_docnames = {} + + for _, docnames in self._inspected.items(): + current_docnames.update(docnames) + + return current_docnames + + def __repr__(self): + repr = f"{type(self).__name__}({self.source_pkgs})" + return repr diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index b57aef7..383ddf3 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -88,7 +88,9 @@ def main(source_dir, out_dir, config_path, verbose): py_content = fo.read() logger.debug("creating stub from %s", source_path) try: - stub_content = stub_transformer.python_to_stub(py_content) + stub_content = stub_transformer.python_to_stub( + py_content, module_path=source_path + ) except Exception as e: logger.exception("failed creating stub for %s:\n\n%s", source_path, e) continue diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 1cd9920..52903c0 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -293,7 +293,8 @@ def doc2pytype(doctype, *, inspector): except Exception: logger.exception("couldn't parse docstring %r:", doctype) return PyType( - value="Any", imports={DocName.from_cfg("Any", {"from": "typing"})} + value="Any", + imports={DocName.one_from_config("Any", info={"from": "typing"})}, ) diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index b9f5e04..03947da 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ 
-19,8 +19,8 @@ class PythonFile(Path): def __init__(self, *args, package_root): self.package_root = package_root super().__init__(*args) - if not self.is_file(): - raise ValueError("must be a file") + if self.is_dir(): + raise ValueError("mustn't be a directory") if not self.is_relative_to(self.package_root): raise ValueError("path must be relative to package_root") @@ -28,6 +28,7 @@ def __init__(self, *args, package_root): def import_name(self): relative_to_root = self.relative_to(self.package_root) parts = relative_to_root.with_suffix("").parts + parts = (self.package_root.name, *parts) if parts[-1] == "__init__": parts = parts[:-1] import_name = ".".join(parts) @@ -173,15 +174,29 @@ def __init__(self, *, inspector): self._scope_stack = None # Store current class or function scope self._pytypes_stack = None # Store current parameter types self._required_imports = None # Collect imports for used types + self._current_module = None - def python_to_stub(self, source: str, module_path=None) -> str: - """Convert Python source code to stub-file ready code.""" + def python_to_stub(self, source, *, module_path=None): + """Convert Python source code to stub-file ready code. + + Parameters + ---------- + source : str + module_path : PythonFile, optional + The location of the source that is transformed into a stub file. + If given, used to enhance logging & error messages with more + context information. 
+ + Returns + ------- + stub : str + """ try: self._scope_stack = [] self._pytypes_stack = [] self._required_imports = set() if module_path: - self.inspector.set_current_module(module_path) + self.inspector.current_module = module_path source_tree = cst.parse_module(source) stub_tree = source_tree.visit(self) @@ -192,6 +207,7 @@ def python_to_stub(self, source: str, module_path=None) -> str: self._scope_stack = None self._pytypes_stack = None self._required_imports = None + self.inspector.current_module = None def visit_ClassDef(self, node): self._scope_stack.append(_Scope(type="class", node=node)) From d038ea8a035831f729860f343a5c2a565eeae6d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 12 Jun 2024 12:05:58 -0400 Subject: [PATCH 04/16] WIP --- examples/example_pkg-stubs/_basic.pyi | 5 ++-- examples/example_pkg/_basic.py | 6 ++++- src/docstub/_analysis.py | 38 +++++++++++++++++++++------ src/docstub/_cli.py | 10 +++++-- src/docstub/_docstrings.py | 10 ++++--- src/docstub/_stubs.py | 9 +++++-- src/docstub/doctype.lark | 2 +- 7 files changed, 60 insertions(+), 20 deletions(-) diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 06534ca..15be338 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -4,7 +4,6 @@ from collections.abc import Sequence from typing import Any, Literal, Self, Union from example_pkg import CustomException -from example_pkg._basic import ExampleClass logger = logging.getLogger(__name__) @@ -24,7 +23,9 @@ def func_contains( def func_literals( a1: Literal[1, 3, "foo"], a2: Literal["uno", 2, "drei", "four"] = ... ) -> None: ... -def func_use_from_elsewhere(a1: CustomException, a2: ExampleClass) -> None: ... +def func_use_from_elsewhere( + a1: CustomException, a2: ExampleClass +) -> CustomException: ... class ExampleClass: def __init__(self, a1: int, a2: float | None = ...) -> None: ... 
diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index b1bb434..20caf26 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -59,7 +59,11 @@ def func_use_from_elsewhere(a1, a2): Parameters ---------- a1 : example_pkg.CustomException - a2 : example_pkg._basic.ExampleClass + a2 : ExampleClass + + Returns + ------- + r1 : ~.CustomException """ diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 298c3b9..7bf6fa4 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -3,12 +3,16 @@ import builtins import collections.abc import itertools +import logging +import re import typing from dataclasses import dataclass from pathlib import Path import libcst as cst +logger = logging.getLogger(__name__) + @dataclass(slots=True, frozen=True) class DocName: @@ -309,21 +313,39 @@ def query(self, qualname): """ Parameters ---------- - qualname + qualname : str Returns ------- - + out : DocName | None """ out = self.docnames.get(qualname) - if out is None: - for file, module_name in self._find_modules(qualname): - self.inspect_module(file, module_name) - out = self.docnames.get(qualname) *prefix, name = qualname.split(".") - if out is None and not prefix and self.current_module: - out = self.query(f"{self.current_module.import_name}.{qualname}") + if not out and "~" in prefix: + pattern = qualname.replace(".", r"\.") + pattern = pattern.replace("~", ".*") + pattern = re.compile(pattern + "$") + matches = { + key: value + for key, value in self.docnames.items() + if re.match(pattern, key) + } + if len(matches) > 1: + shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0] + out = matches[shortest_key] + logger.warning( + "%s matches multiple types %s, using %s", + qualname, + matches.keys(), + shortest_key, + ) + elif len(matches) == 1: + _, out = matches.popitem() + + elif not out and self.current_module: + try_qualname = f"{self.current_module.import_name}.{qualname}" + out = 
self.docnames.get(try_qualname) return out diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 383ddf3..5c448fc 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -5,8 +5,8 @@ import click from . import _config -from ._analysis import DocName, StaticInspector, common_docnames -from ._stubs import Py2StubTransformer, walk_source_and_targets +from ._analysis import DocName, DocNameCollector, StaticInspector, common_docnames +from ._stubs import Py2StubTransformer, walk_source, walk_source_and_targets from ._version import __version__ logger = logging.getLogger(__name__) @@ -66,7 +66,13 @@ def main(source_dir, out_dir, config_path, verbose): # Build docname map docnames = common_docnames() + for source_path in walk_source(source_dir): + docnames_in_source = DocNameCollector.collect( + source_path, module_name=source_path.import_name + ) + docnames.update(docnames_in_source) docnames.update(DocName.many_from_config(config["docnames"])) + inspector = StaticInspector( source_pkgs=[source_dir.parent.resolve()], docnames=docnames ) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 52903c0..73f131d 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -96,7 +96,7 @@ def _aggregate_pytypes(*types): Returns ------- values : list[str] - imports : set[DocName] + imports : set[~.DocName] """ values = [] imports = set() @@ -116,7 +116,7 @@ class DoctypeTransformer(lark.visitors.Transformer): Parameters ---------- - docnames : dict[str, DocName] + docnames : dict[str, ~.DocName] A dictionary mapping atomic names used in doctypes to information such as where to import from or how to replace the name itself. 
kwargs : dict[Any, Any] @@ -210,7 +210,9 @@ def qualname(self, tree): out = docname.use_name self._collected_imports.add(docname) else: - logger.warning("unmatched name %r", out) + logger.warning( + "unmatched name %r in %s", out, self.inspector.current_module + ) return out def NAME(self, token): @@ -314,7 +316,7 @@ def collect_pytypes(docstring, *, inspector): Returns ------- - pytypes : dict[str | Literal[ReturnKey], PyType] + pytypes : dict[str | NPDocSection, PyType] The collected PyType for each parameter. If a return type is documented it's saved under the special key :class:`ReturnKey`. """ diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 03947da..3b7605c 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -295,7 +295,12 @@ def visit_Module(self, node): return True def leave_Module(self, original_node, updated_node): - import_nodes = self._parse_imports(self._required_imports) + required_imports = [ + imp + for imp in self._required_imports + if imp.import_path != self.inspector.current_module.import_name + ] + import_nodes = self._parse_imports(required_imports) updated_node = updated_node.with_changes(body=import_nodes + updated_node.body) self._scope_stack.pop() return updated_node @@ -310,7 +315,7 @@ def _parse_imports(imports): Parameters ---------- - imports : set[DocName] + imports : set[~.DocName] Returns ------- diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 7cc4f95..59f3adb 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -15,7 +15,7 @@ optional : "optional" extra_info : /[^\r\n]+/ // Name with leading dot separated path -qualname : (NAME ".")* NAME contains? +qualname : (/~/ ".")? (NAME ".")* NAME contains? 
contains: "[" type_or ("," type_or)* "]" From 599d262824dda6ce1bcd6d638092f3e531e3ee26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 12 Jun 2024 13:26:30 -0400 Subject: [PATCH 05/16] WIP --- src/docstub/_cli.py | 1 + src/docstub/_docstrings.py | 25 +++++++++++++++---------- src/docstub/doctype.lark | 3 +++ 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 5c448fc..bf2e4fe 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -67,6 +67,7 @@ def main(source_dir, out_dir, config_path, verbose): # Build docname map docnames = common_docnames() for source_path in walk_source(source_dir): + logger.info("collecting types in %s", source_path) docnames_in_source = DocNameCollector.collect( source_path, module_name=source_path.import_name ) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 73f131d..4dcf31b 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -112,18 +112,18 @@ class MatchedName(lark.Token): @lark.visitors.v_args(tree=True) class DoctypeTransformer(lark.visitors.Transformer): - """Transformer for docstring type descriptions (doctypes). - - Parameters - ---------- - docnames : dict[str, ~.DocName] - A dictionary mapping atomic names used in doctypes to information such - as where to import from or how to replace the name itself. - kwargs : dict[Any, Any] - Keyword arguments passed to the init of the parent class. - """ + """Transformer for docstring type descriptions (doctypes).""" def __init__(self, *, inspector, **kwargs): + """ + Parameters + ---------- + inspector : ~.StaticInspector + A dictionary mapping atomic names used in doctypes to information such + as where to import from or how to replace the name itself. + kwargs : dict[Any, Any] + Keyword arguments passed to the init of the parent class. 
+ """ self.inspector = inspector self._collected_imports = None super().__init__(**kwargs) @@ -194,6 +194,10 @@ def extra_info(self, tree): logger.debug("dropping extra info") return lark.Discard + def sphinx_ref(self, tree): + qualname = _find_one_token(tree, name="QUALNAME") + return qualname + def qualname(self, tree): matched = False out = [] @@ -213,6 +217,7 @@ def qualname(self, tree): logger.warning( "unmatched name %r in %s", out, self.inspector.current_module ) + out = lark.Token("QUALNAME", out) return out def NAME(self, token): diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 59f3adb..209d18e 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -5,6 +5,7 @@ doctype : type_or ("," optional)? ("," extra_info)? type_or : type (("or" | "|") type)* ?type : qualname + | sphinx_ref | "{" literal ("," literal)* "}" -> literals | container_of | shape_n_dtype @@ -14,6 +15,8 @@ optional : "optional" extra_info : /[^\r\n]+/ +sphinx_ref : ":" (NAME ":")? NAME ":`" qualname "`" + // Name with leading dot separated path qualname : (/~/ ".")? (NAME ".")* NAME contains? From 324754ea57db8cda0b3ecb3dd16891fcc48e19b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 12 Jun 2024 14:37:23 -0400 Subject: [PATCH 06/16] Support relative imports for types in stubs --- examples/example_pkg-stubs/_basic.pyi | 2 +- pyproject.toml | 2 +- src/docstub/_analysis.py | 38 ++++++++++++++++++++++++--- src/docstub/_stubs.py | 9 ++++--- 4 files changed, 43 insertions(+), 8 deletions(-) diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 15be338..95c81f6 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -3,7 +3,7 @@ import logging from collections.abc import Sequence from typing import Any, Literal, Self, Union -from example_pkg import CustomException +from . 
import CustomException logger = logging.getLogger(__name__) diff --git a/pyproject.toml b/pyproject.toml index 04e824a..754f815 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,5 +90,5 @@ ignore = [ [tool.docstub.docnames] -cst = {import = "cst"} +cst = {import = "libcst", as="cst"} lark = {import = "lark"} diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 7bf6fa4..40fc41d 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -14,6 +14,29 @@ logger = logging.getLogger(__name__) +def _shared_leading_path(*paths): + """Identify the common leading parts between import paths. + + Parameters + ---------- + *paths : tuple[str] + + Returns + ------- + shared : str + """ + if len(paths) < 2: + raise ValueError("need more than two paths") + splits = (p.split(".") for p in paths) + shared = [] + for paths in zip(*splits, strict=False): + if all(paths[0] == p for p in paths): + shared.append(paths[0]) + else: + break + return ".".join(shared) + + @dataclass(slots=True, frozen=True) class DocName: """An atomic name (without ".") in a docstring type with import info.""" @@ -78,13 +101,22 @@ def many_from_config(cls, mapping): } return docnames - def format_import(self): + def format_import(self, relative_to=None): if self.is_builtin: msg = "cannot import builtin" raise RuntimeError(msg) out = f"import {self.import_name}" - if self.import_path: - out = f"from {self.import_path} {out}" + + import_path = self.import_path + if import_path: + if relative_to: + shared = _shared_leading_path(relative_to, import_path) + if shared == import_path: + import_path = "." 
+ else: + import_path = self.import_path.replace(shared, "") + + out = f"from {import_path} {out}" if self.import_alias: out = f"{out} as {self.import_alias}" return out diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 3b7605c..5126666 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -300,7 +300,9 @@ def leave_Module(self, original_node, updated_node): for imp in self._required_imports if imp.import_path != self.inspector.current_module.import_name ] - import_nodes = self._parse_imports(required_imports) + import_nodes = self._parse_imports( + required_imports, current_module=self.inspector.current_module.import_name + ) updated_node = updated_node.with_changes(body=import_nodes + updated_node.body) self._scope_stack.pop() return updated_node @@ -310,18 +312,19 @@ def visit_Lambda(self, node): return False @staticmethod - def _parse_imports(imports): + def _parse_imports(imports, *, current_module=None): """Create nodes to include in the module tree from given imports. Parameters ---------- imports : set[~.DocName] + current_module : str, optional Returns ------- import_nodes : tuple[cst.SimpleStatementLine, ...] 
""" - lines = {imp.format_import() for imp in imports} + lines = {imp.format_import(relative_to=current_module) for imp in imports} import_nodes = tuple(cst.parse_statement(line) for line in lines) return import_nodes From e300cf6c1431a5b9729e502a752da11cc834bf74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Fri, 14 Jun 2024 11:58:40 -0400 Subject: [PATCH 07/16] WIP --- examples/example_pkg-stubs/_basic.pyi | 2 +- examples/example_pkg/_basic.py | 15 +- src/docstub/_analysis.py | 18 +- src/docstub/_cli.py | 2 +- src/docstub/_docstrings.py | 2 +- src/docstub/_stubs.py | 232 ++++++++++++++++++++++---- 6 files changed, 226 insertions(+), 45 deletions(-) diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 95c81f6..ce8132a 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -28,7 +28,7 @@ def func_use_from_elsewhere( ) -> CustomException: ... class ExampleClass: - def __init__(self, a1: int, a2: float | None = ...) -> None: ... + def __init__(self, a1: str, a2: bool | None = ...) -> None: ... def method(self, a1: float, a2: float | None) -> list[float]: ... @staticmethod def some_staticmethod(a1: float, a2: float | None = ...) -> dict[str, Any]: ... diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 20caf26..90fa25f 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -68,15 +68,16 @@ def func_use_from_elsewhere(a1, a2): class ExampleClass: - # TODO also take into account class level docstring + """Dummy. + + Parameters + ---------- + a1 : str + a2 : float, optional + """ def __init__(self, a1, a2=None): - """ - Parameters - ---------- - a1 : int - a2 : float, optional - """ + pass def method(self, a1, a2): """Dummy. 
diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 40fc41d..aa80f3b 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -274,17 +274,15 @@ def leave_FunctionDef(self, original_node): class StaticInspector: """Static analysis of Python packages. - Parameters + Attributes ---------- - source_pkgs: list[Path] - docnames: dict[str, DocName] + current_source : ~.PackageFile | None Examples -------- >>> from docstub._analysis import StaticInspector, common_docnames >>> inspector = StaticInspector(docnames=common_docnames()) >>> inspector.query("Any") - """ def __init__( @@ -293,12 +291,18 @@ def __init__( source_pkgs=None, docnames=None, ): + """ + Parameters + ---------- + source_pkgs: list[Path], optional + docnames: dict[str, DocName], optional + """ if source_pkgs is None: source_pkgs = [] if docnames is None: docnames = {} - self.current_module = None + self.current_source = None self.source_pkgs = source_pkgs self._inspected = {"initial": docnames} @@ -375,8 +379,8 @@ def query(self, qualname): elif len(matches) == 1: _, out = matches.popitem() - elif not out and self.current_module: - try_qualname = f"{self.current_module.import_name}.{qualname}" + elif not out and self.current_source: + try_qualname = f"{self.current_source.import_path}.{qualname}" out = self.docnames.get(try_qualname) return out diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index bf2e4fe..bc21cd7 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -69,7 +69,7 @@ def main(source_dir, out_dir, config_path, verbose): for source_path in walk_source(source_dir): logger.info("collecting types in %s", source_path) docnames_in_source = DocNameCollector.collect( - source_path, module_name=source_path.import_name + source_path, module_name=source_path.import_path ) docnames.update(docnames_in_source) docnames.update(DocName.many_from_config(config["docnames"])) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 
4dcf31b..ca81c56 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -215,7 +215,7 @@ def qualname(self, tree): self._collected_imports.add(docname) else: logger.warning( - "unmatched name %r in %s", out, self.inspector.current_module + "unmatched name %r in %s", out, self.inspector.current_source ) out = lark.Token("QUALNAME", out) return out diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 5126666..b92c65c 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -14,9 +14,16 @@ logger = logging.getLogger(__name__) -class PythonFile(Path): +class PackageFile(Path): + """File in a Python package.""" def __init__(self, *args, package_root): + """ + Parameters + ---------- + args : tuple[Any, ...] + package_root : Path + """ self.package_root = package_root super().__init__(*args) if self.is_dir(): @@ -25,7 +32,12 @@ def __init__(self, *args, package_root): raise ValueError("path must be relative to package_root") @property - def import_name(self): + def import_path(self): + """ + Returns + ------- + str + """ relative_to_root = self.relative_to(self.package_root) parts = relative_to_root.with_suffix("").parts parts = (self.package_root.name, *parts) @@ -64,7 +76,7 @@ def walk_source(root_dir): Yields ------ - source_path : PythonFile + source_path : PackageFile Either a Python file or a stub file that takes precedence. Notes @@ -90,7 +102,7 @@ def walk_source(root_dir): if suffix == ".py" and path.with_suffix(".pyi").exists(): continue # Stub file already exists and takes precedence - python_file = PythonFile(path, package_root=root_dir) + python_file = PackageFile(path, package_root=root_dir) yield python_file @@ -106,9 +118,9 @@ def walk_source_and_targets(root_dir, target_dir): Returns ------- - source_path : PythonFile + source_path : PackageFile Either a Python file or a stub file that takes precedence. - stub_path : PythonFile + stub_path : PackageFile Target stub file. 
Notes @@ -117,7 +129,7 @@ def walk_source_and_targets(root_dir, target_dir): """ for source_path in walk_source(root_dir): stub_path = target_dir / source_path.with_suffix(".pyi").relative_to(root_dir) - stub_path = PythonFile(stub_path, package_root=target_dir) + stub_path = PackageFile(stub_path, package_root=target_dir) yield source_path, stub_path @@ -149,6 +161,8 @@ class FuncType(enum.Enum): @dataclass(slots=True, frozen=True) class _Scope: + """""" + type: FuncType node: cst.CSTNode = None @@ -156,9 +170,27 @@ class _Scope: def has_self_or_cls(self): return self.type in {FuncType.METHOD, FuncType.CLASSMETHOD} + @property + def is_method(self): + return self.type in { + FuncType.METHOD, + FuncType.CLASSMETHOD, + FuncType.STATICMETHOD, + } + + @property + def is_class_init(self): + out = self.is_method and self.node.name.value == "__init__" + return out + class Py2StubTransformer(cst.CSTTransformer): - """Transform syntax tree of a Python file into the tree of a stub file.""" + """Transform syntax tree of a Python file into the tree of a stub file. + + Attributes + ---------- + inspector : ~._analysis.StaticInspector + """ # Equivalent to ` ...`, to replace the body of callables with _body_replacement = cst.SimpleStatementSuite( @@ -169,6 +201,11 @@ class Py2StubTransformer(cst.CSTTransformer): _Annotation_None = cst.Annotation(cst.Name("None")) def __init__(self, *, inspector): + """ + Parameters + ---------- + inspector : ~._analysis.StaticInspector + """ self.inspector = inspector # Relevant docstring for the current context self._scope_stack = None # Store current class or function scope @@ -182,7 +219,7 @@ def python_to_stub(self, source, *, module_path=None): Parameters ---------- source : str - module_path : PythonFile, optional + module_path : PackageFile, optional The location of the source that is transformed into a stub file. If given, used to enhance logging & error messages with more context information. 
@@ -196,7 +233,7 @@ def python_to_stub(self, source, *, module_path=None): self._pytypes_stack = [] self._required_imports = set() if module_path: - self.inspector.current_module = module_path + self.inspector.current_source = module_path source_tree = cst.parse_module(source) stub_tree = source_tree.visit(self) @@ -207,33 +244,68 @@ def python_to_stub(self, source, *, module_path=None): self._scope_stack = None self._pytypes_stack = None self._required_imports = None - self.inspector.current_module = None + self.inspector.current_source = None def visit_ClassDef(self, node): - self._scope_stack.append(_Scope(type="class", node=node)) + """Collect pytypes from class docstring and add scope to stack. + + Parameters + ---------- + node : cst.ClassDef + + Returns + ------- + out : Literal[True] + """ + self._scope_stack.append(_Scope(type=FuncType.CLASS, node=node)) + pytypes = self._pytypes_from_func(node) + self._pytypes_stack.append(pytypes) return True def leave_ClassDef(self, original_node, updated_node): + """Drop class scope from the stack. + + Parameters + ---------- + original_node : cst.ClassDef + updated_node : cst.ClassDef + + Returns + ------- + updated_node : cst.ClassDef + """ self._scope_stack.pop() return updated_node def visit_FunctionDef(self, node): + """Collect pytypes from function docstring and add scope to stack. 
+ + Parameters + ---------- + node : cst.FunctionDef + + Returns + ------- + out : Literal[True] + """ func_type = self._function_type(node) self._scope_stack.append(_Scope(type=func_type, node=node)) - - docstring = node.get_docstring() - pytypes = None - if docstring: - try: - pytypes = collect_pytypes(docstring, inspector=self.inspector) - except Exception as e: - logger.exception( - "error while parsing docstring of `%s`:\n\n%s", node.name.value, e - ) + pytypes = self._pytypes_from_func(node) self._pytypes_stack.append(pytypes) return True def leave_FunctionDef(self, original_node, updated_node): + """Add type annotation for return to function. + + Parameters + ---------- + original_node : cst.FunctionDef + updated_node : cst.FunctionDef + + Returns + ------- + updated_node : cst.FunctionDef + """ node_changes = { "body": self._body_replacement, "returns": self._Annotation_None, @@ -254,6 +326,17 @@ def leave_FunctionDef(self, original_node, updated_node): return updated_node def leave_Param(self, original_node, updated_node): + """Add type annotation to parameter. + + Parameters + ---------- + original_node : cst.Param + updated_node : cst.Param + + Returns + ------- + updated_node : cst.Param + """ node_changes = {} scope = self._scope_stack[-1] @@ -264,6 +347,8 @@ def leave_Param(self, original_node, updated_node): name = original_node.name.value pytypes = self._pytypes_stack[-1] + if not pytypes and scope.is_class_init: + pytypes = self._pytypes_stack[-2] if pytypes: pytype = pytypes.get(name) if pytype: @@ -284,31 +369,98 @@ def leave_Param(self, original_node, updated_node): updated_node = updated_node.with_changes(**node_changes) return updated_node - def leave_Expr(self, original_node, upated_node): + def leave_Expr(self, original_node, updated_node): + """Drop expression from stub file. 
+ + Parameters + ---------- + original_node : cst.Expr + updated_node : cst.Expr + + Returns + ------- + cst.RemovalSentinel + """ return cst.RemovalSentinel.REMOVE def leave_Comment(self, original_node, updated_node): + """Drop comment from stub file. + + Parameters + ---------- + original_node : cst.Comment + updated_node : cst.Comment + + Returns + ------- + cst.RemovalSentinel + """ return cst.RemovalSentinel.REMOVE + def leave_Assign(self, original_node, updated_node): + """Drop value of assign statements from stub files. + + Parameters + ---------- + original_node : cst.Assign + updated_node : cst.Assign + + Returns + ------- + updated_node : cst.Assign + """ + # TODO replace with AnnAssign if possible / figure out assign type? + updated_node = updated_node.with_changes(value=self._body_replacement) + return updated_node + def visit_Module(self, node): + """Add module scope to stack. + + Parameters + ---------- + node : cst.Module + + Returns + ------- + Literal[True] + """ self._scope_stack.append(_Scope(type="module", node=node)) return True def leave_Module(self, original_node, updated_node): + """Add required type imports to module + + Parameters + ---------- + original_node : cst.Module + updated_node : cst.Module + + Returns + ------- + updated_node : cst.Module + """ + current_module = self.inspector.current_source.import_path required_imports = [ - imp - for imp in self._required_imports - if imp.import_path != self.inspector.current_module.import_name + imp for imp in self._required_imports if imp.import_path != current_module ] import_nodes = self._parse_imports( - required_imports, current_module=self.inspector.current_module.import_name + required_imports, current_module=current_module ) updated_node = updated_node.with_changes(body=import_nodes + updated_node.body) self._scope_stack.pop() return updated_node def visit_Lambda(self, node): - # Skip visiting parameters of lambda which can't have an annotation. 
+ """Don't visit parameters fo lambda which can't have an annotation. + + Parameters + ---------- + node : cst.Lambda + + Returns + ------- + Literal[False] + """ return False @staticmethod @@ -353,3 +505,27 @@ def _function_type(self, func_def): func_type = FuncType.STATICMETHOD break return func_type + + def _pytypes_from_func(self, node): + """Extract types from function or class docstrings. + + Parameters + ---------- + node : cst.FunctionDef | cst.ClassDef + + Returns + ------- + pytypes : dict[str, ~._docstrings.PyType] + """ + pytypes = None + docstring = node.get_docstring() + if docstring: + try: + pytypes = collect_pytypes(docstring, inspector=self.inspector) + except Exception as e: + logger.exception( + "error while parsing docstring of `%s`:\n\n%s", + node.name.value, + e, + ) + return pytypes From 5c1e629f96b164ae89aae62d2011432f86358573 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 16 Jun 2024 02:32:19 -0400 Subject: [PATCH 08/16] WIP --- examples/example_pkg-stubs/_basic.pyi | 6 +- src/docstub/_docstrings.py | 39 ++++++++----- src/docstub/_stubs.py | 81 ++++++++++++++++++++------- 3 files changed, 89 insertions(+), 37 deletions(-) diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index ce8132a..509c8ba 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -5,14 +5,14 @@ from typing import Any, Literal, Self, Union from . import CustomException -logger = logging.getLogger(__name__) +logger = ... __all__ = [ "func_empty", "ExampleClass", ] -def func_empty(a1: Any, a2: Any, a3: Any) -> None: ... +def func_empty(a1, a2, a3) -> None: ... def func_contains( self, a1: list[float], @@ -28,7 +28,7 @@ def func_use_from_elsewhere( ) -> CustomException: ... class ExampleClass: - def __init__(self, a1: str, a2: bool | None = ...) -> None: ... + def __init__(self, a1: str, a2: float | None = ...) -> None: ... 
def method(self, a1: float, a2: float | None) -> list[float]: ... @staticmethod def some_staticmethod(a1: float, a2: float | None = ...) -> dict[str, Any]: ... diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index ca81c56..f218a5b 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -2,7 +2,6 @@ """ -import enum import logging from dataclasses import dataclass, field from pathlib import Path @@ -305,9 +304,14 @@ def doc2pytype(doctype, *, inspector): ) -class NPDocSection(enum.Enum): - RETURNS = enum.auto() - YIELDS = enum.auto() +@dataclass(frozen=True, slots=True) +class DocstringPyTypes: + """Groups Pytypes in a docstring.""" + + parameters: dict[str, PyType] + attributes: dict[str, PyType] + returns: PyType | None + yields: PyType | None def collect_pytypes(docstring, *, inspector): @@ -321,9 +325,9 @@ def collect_pytypes(docstring, *, inspector): Returns ------- - pytypes : dict[str | NPDocSection, PyType] - The collected PyType for each parameter. If a return type is documented - it's saved under the special key :class:`ReturnKey`. + pytypes : DocstringPyTypes + The collected PyTypes grouped by parameters, attributes, returns, and + yields. 
""" np_docstring = NumpyDocString(docstring) @@ -335,7 +339,7 @@ def collect_pytypes(docstring, *, inspector): raise ValueError(f"{duplicate_params=}") params.update(other) - pytypes = { + parameters = { name: doc2pytype(param.type, inspector=inspector) for name, param in params.items() if param.type @@ -346,6 +350,8 @@ def collect_pytypes(docstring, *, inspector): for param in np_docstring["Returns"] if param.type ] + returns = PyType.as_return_tuple(returns) if returns else None + yields = [ doc2pytype(param.type, inspector=inspector) for param in np_docstring["Yields"] @@ -356,16 +362,21 @@ def collect_pytypes(docstring, *, inspector): for param in np_docstring["Receives"] if param.type ] + attributes = [ + doc2pytype(param.type, inspector=inspector) + for param in np_docstring["Attributes"] + if param.type + ] if returns and yields: logger.warning( "found 'Returns' and 'Yields' section in docstring, ignoring 'Yields'" ) if receives and not yields: logger.warning("found 'Receives' section in docstring without 'Yields' section") + if yields: + logger.warning("yields is not supported yet") - if returns: - pytypes[NPDocSection.RETURNS] = PyType.as_return_tuple(returns) - elif yields: - logger.error("yields is not supported yet, ignoring") - - return pytypes + ds_pytypes = DocstringPyTypes( + parameters=parameters, attributes=attributes, returns=returns, yields=None + ) + return ds_pytypes diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index b92c65c..97fb50a 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -8,8 +8,9 @@ from pathlib import Path import libcst as cst +import libcst.matchers as cstm -from ._docstrings import NPDocSection, collect_pytypes +from ._docstrings import collect_pytypes logger = logging.getLogger(__name__) @@ -208,8 +209,8 @@ def __init__(self, *, inspector): """ self.inspector = inspector # Relevant docstring for the current context - self._scope_stack = None # Store current class or function scope - 
self._pytypes_stack = None # Store current parameter types + self._scope_stack = None # Entered module, class or function scopes + self._pytypes_stack = None # Collected pytypes for each stack self._required_imports = None # Collect imports for used types self._current_module = None @@ -258,7 +259,7 @@ def visit_ClassDef(self, node): out : Literal[True] """ self._scope_stack.append(_Scope(type=FuncType.CLASS, node=node)) - pytypes = self._pytypes_from_func(node) + pytypes = self._pytypes_from_node(node) self._pytypes_stack.append(pytypes) return True @@ -275,6 +276,7 @@ def leave_ClassDef(self, original_node, updated_node): updated_node : cst.ClassDef """ self._scope_stack.pop() + self._pytypes_stack.pop() return updated_node def visit_FunctionDef(self, node): @@ -290,7 +292,7 @@ def visit_FunctionDef(self, node): """ func_type = self._function_type(node) self._scope_stack.append(_Scope(type=func_type, node=node)) - pytypes = self._pytypes_from_func(node) + pytypes = self._pytypes_from_node(node) self._pytypes_stack.append(pytypes) return True @@ -312,14 +314,12 @@ def leave_FunctionDef(self, original_node, updated_node): } pytypes = self._pytypes_stack.pop() - if pytypes: - return_pytype = pytypes.get(NPDocSection.RETURNS) - if return_pytype: - assert return_pytype.value - node_changes["returns"] = cst.Annotation( - cst.parse_expression(return_pytype.value) - ) - self._required_imports |= return_pytype.imports + if pytypes and pytypes.returns: + assert pytypes.returns.value + node_changes["returns"] = cst.Annotation( + cst.parse_expression(pytypes.returns.value) + ) + self._required_imports |= pytypes.returns.imports updated_node = updated_node.with_changes(**node_changes) self._scope_stack.pop() @@ -350,7 +350,7 @@ def leave_Param(self, original_node, updated_node): if not pytypes and scope.is_class_init: pytypes = self._pytypes_stack[-2] if pytypes: - pytype = pytypes.get(name) + pytype = pytypes.parameters.get(name) if pytype: annotation = 
cst.Annotation(cst.parse_expression(pytype.value)) node_changes["annotation"] = annotation @@ -409,8 +409,15 @@ def leave_Assign(self, original_node, updated_node): ------- updated_node : cst.Assign """ - # TODO replace with AnnAssign if possible / figure out assign type? - updated_node = updated_node.with_changes(value=self._body_replacement) + targets = cstm.findall(updated_node, cstm.AssignTarget()) + names_are__all__ = [ + name + for target in targets + for name in cstm.findall(target, cst.Name(value="__all__")) + ] + if not names_are__all__: + # TODO replace with AnnAssign if possible / figure out assign type? + updated_node = updated_node.with_changes(value=self._body_replacement) return updated_node def visit_Module(self, node): @@ -424,7 +431,9 @@ def visit_Module(self, node): ------- Literal[True] """ - self._scope_stack.append(_Scope(type="module", node=node)) + self._scope_stack.append(_Scope(type=FuncType.MODULE, node=node)) + pytypes = self._pytypes_from_node(node) + self._pytypes_stack.append(pytypes) return True def leave_Module(self, original_node, updated_node): @@ -448,6 +457,7 @@ def leave_Module(self, original_node, updated_node): ) updated_node = updated_node.with_changes(body=import_nodes + updated_node.body) self._scope_stack.pop() + self._pytypes_stack.pop() return updated_node def visit_Lambda(self, node): @@ -463,6 +473,37 @@ def visit_Lambda(self, node): """ return False + def leave_Decorator(self, original_node, updated_node): + """Drop decorators except for a few out of the SDL. 
+ + Parameters + ---------- + original_node : cst.Decorator + updated_node : cst.Decorator + + Returns + ------- + cst.Decorator | cst.RemovalSentinel + """ + names = cstm.findall(original_node, cstm.Name()) + names = ".".join(name.value for name in names) + + allowlist = ( + "classmethod", + "staticmethod", + "property", + ".setter", + "abstractmethod", + "dataclass", + "coroutine", + ) + out = cst.RemovalSentinel.REMOVE + # TODO add decorators in typing module + for allowed in allowlist: + if allowed in names: + out = updated_node + return out + @staticmethod def _parse_imports(imports, *, current_module=None): """Create nodes to include in the module tree from given imports. @@ -506,12 +547,12 @@ def _function_type(self, func_def): break return func_type - def _pytypes_from_func(self, node): - """Extract types from function or class docstrings. + def _pytypes_from_node(self, node): + """Extract types from function, class or module docstrings. Parameters ---------- - node : cst.FunctionDef | cst.ClassDef + node : cst.FunctionDef | cst.ClassDef | cst.Module Returns ------- From 3908f3fcae381ba70bc9c3efd71daa77ea6a71ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 23 Jun 2024 09:26:45 +0200 Subject: [PATCH 09/16] Remove container_of and support "of {key: value}" syntax This syntax is used by Pandas [1]. I'm not entirely sold yet, but let's add it for now. To avoid confusion, literals should probably be made to only be accepted on the top-level. Otherwise something like `dict of {{"a", "b"}: int}` becomes possible. 
Co-authored-by: Oriol Abril-Pla --- examples/example_pkg-stubs/_basic.pyi | 6 ++++-- examples/example_pkg/_basic.py | 10 ++++------ src/docstub/_docstrings.py | 8 -------- src/docstub/doctype.lark | 12 ++++-------- 4 files changed, 12 insertions(+), 24 deletions(-) diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 509c8ba..a2fd34d 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -14,12 +14,14 @@ __all__ = [ def func_empty(a1, a2, a3) -> None: ... def func_contains( - self, a1: list[float], a2: dict[str, Union[int, str]], a3: Sequence[int | float], a4: frozenset[bytes], -) -> tuple[tuple[int, ...], list[int]]: ... + a5: tuple[int], + a6: list[int, str], + a7: dict[str, int], +) -> None: ... def func_literals( a1: Literal[1, 3, "foo"], a2: Literal["uno", 2, "drei", "four"] = ... ) -> None: ... diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 90fa25f..659cf9e 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -26,7 +26,7 @@ def func_empty(a1, a2, a3): """ -def func_contains(self, a1, a2, a3, a4): +def func_contains(a1, a2, a3, a4, a5, a6, a7): """Dummy. Parameters @@ -35,11 +35,9 @@ def func_contains(self, a1, a2, a3, a4): a2 : dict[str, Union[int, str]] a3 : Sequence[int | float] a4 : frozenset[bytes] - - Returns - ------- - r1 : tuple of int - r2 : list of int + a5 : tuple of int + a6 : list of (int, str) + a7 : dict of {str: int} """ diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index f218a5b..c859188 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -239,14 +239,6 @@ def shape_n_dtype(self, tree): name = f"{name}[{', '.join(children)}]" return name - def container_of(self, tree): - assert len(tree.children) == 2 - container_name, item_type = tree.children - if container_name == "tuple": - item_type += ", ..." 
- out = f"{container_name}[{item_type}]" - return out - def contains(self, tree): out = ", ".join(tree.children) out = f"[{out}]" diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 209d18e..b51c849 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -7,7 +7,6 @@ type_or : type (("or" | "|") type)* ?type : qualname | sphinx_ref | "{" literal ("," literal)* "}" -> literals - | container_of | shape_n_dtype optional : "optional" @@ -20,14 +19,11 @@ sphinx_ref : ":" (NAME ":")? NAME ":`" qualname "`" // Name with leading dot separated path qualname : (/~/ ".")? (NAME ".")* NAME contains? - contains: "[" type_or ("," type_or)* "]" | "[" type_or "," PY_ELLIPSES "]" - - -// Container-of -container_of : NAME "of" type_or - + | "of" type + | "of" "(" type_or ("," type_or)* ")" + | "of" "{" type_or ":" type_or "}" // Array-like form with dtype or shape information shape_n_dtype : shape? ARRAY_NAME ("of" dtype)? @@ -40,7 +36,7 @@ ARRAY_NAME : "array" | "ndarray" | "array-like" | "array_like" -dtype : NAME +dtype : qualname shape : "(" dim ",)" | "(" leading_optional? dim (("," dim | insert_optional))* ")" | NUMBER "-"? 
"D" From 5a288281acd715ce7c1d92acb4ba935fd928d57e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 23 Jun 2024 12:03:00 +0200 Subject: [PATCH 10/16] Only allow literals on top-level which avoids potentially confusing constructs like `dict of {{"a", "b"}: int}` Co-authored-by: Oriol Abril-Pla --- src/docstub/_docstrings.py | 8 ++++---- src/docstub/doctype.lark | 17 +++++++++-------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index c859188..1b1e8a8 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -173,11 +173,11 @@ def transform(self, tree): finally: self._collected_imports = None - def doctype(self, tree): + def annotation(self, tree): out = " | ".join(tree.children) return out - def type_or(self, tree): + def types_or(self, tree): out = " | ".join(tree.children) return out @@ -185,8 +185,8 @@ def optional(self, tree): out = "None" literal = [child for child in tree.children if child.type == "LITERAL"] assert len(literal) <= 1 - if len(literal) == 1: - out = lark.Discard # Should be covered by doctype + if literal: + out = lark.Discard # Type should cover the default return out def extra_info(self, tree): diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index b51c849..154741e 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -1,12 +1,13 @@ -?start : doctype +?start : annotation -doctype : type_or ("," optional)? ("," extra_info)? +annotation : (literals | types_or) ("," optional)? ("," extra_info)? -type_or : type (("or" | "|") type)* +literals : "{" literal ("," literal)* "}" + +types_or : type (("or" | "|") type)* ?type : qualname | sphinx_ref - | "{" literal ("," literal)* "}" -> literals | shape_n_dtype optional : "optional" @@ -19,11 +20,11 @@ sphinx_ref : ":" (NAME ":")? NAME ":`" qualname "`" // Name with leading dot separated path qualname : (/~/ ".")? (NAME ".")* NAME contains? 
-contains: "[" type_or ("," type_or)* "]" - | "[" type_or "," PY_ELLIPSES "]" +contains: "[" types_or ("," types_or)* "]" + | "[" types_or "," PY_ELLIPSES "]" | "of" type - | "of" "(" type_or ("," type_or)* ")" - | "of" "{" type_or ":" type_or "}" + | "of" "(" types_or ("," types_or)* ")" + | "of" "{" types_or ":" types_or "}" // Array-like form with dtype or shape information shape_n_dtype : shape? ARRAY_NAME ("of" dtype)? From deefbd9fd8caf9e70461436e6b9f159e7b556087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 23 Jun 2024 16:51:36 +0200 Subject: [PATCH 11/16] Make "= | :" optional in default syntax since this is what NumPyDoc recommends as well [1]. [1] https://numpydoc.readthedocs.io/en/latest/format.html#parameters Co-authored-by: Oriol Abril-Pla --- examples/example_pkg-stubs/_basic.pyi | 2 +- examples/example_pkg/_basic.py | 4 ++-- src/docstub/doctype.lark | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index a2fd34d..0de6e89 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -30,7 +30,7 @@ def func_use_from_elsewhere( ) -> CustomException: ... class ExampleClass: - def __init__(self, a1: str, a2: float | None = ...) -> None: ... + def __init__(self, a1: str, a2: float = ...) -> None: ... def method(self, a1: float, a2: float | None) -> list[float]: ... @staticmethod def some_staticmethod(a1: float, a2: float | None = ...) -> dict[str, Any]: ... 
diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py
index 659cf9e..1004bb8 100644
--- a/examples/example_pkg/_basic.py
+++ b/examples/example_pkg/_basic.py
@@ -71,10 +71,10 @@ class ExampleClass:
     Parameters
     ----------
     a1 : str
-    a2 : float, optional
+    a2 : float, default 0
     """
 
-    def __init__(self, a1, a2=None):
+    def __init__(self, a1, a2=0):
         pass
 
     def method(self, a1, a2):
diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark
index 154741e..8ae49ba 100644
--- a/src/docstub/doctype.lark
+++ b/src/docstub/doctype.lark
@@ -11,7 +11,7 @@ types_or : type (("or" | "|") type)*
          | shape_n_dtype
 
 optional : "optional"
-         | "default" ("=" | ":") literal
+         | "default" ("=" | ":")? literal
 
 extra_info : /[^\r\n]+/
 
From ad1e9e240872e2e0dc4ee836f1c6c720779c7a3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lars=20Gr=C3=BCter?=
Date: Thu, 22 Aug 2024 11:06:46 +0200
Subject: [PATCH 12/16] Add ipython as dev dependency

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 754f815..e258613 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ optional = [
 ]
 dev = [
     "pre-commit >=3.7",
+    "ipython",
 ]
 test = [
     "pytest >=5.0.0",
From 587400ca0797d7921b422274b6e81019ee22d77a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lars=20Gr=C3=BCter?=
Date: Thu, 22 Aug 2024 11:50:52 +0200
Subject: [PATCH 13/16] Move matching to imports to qualname level

Doing this on the NAME level meant that stuff like "np.int16" would be
turned into "np.np.int16" because both "np" and "int16" were matched.
--- examples/example_pkg-stubs/_numpy.pyi | 2 +- examples/example_pkg/_numpy.py | 2 +- src/docstub/_docstrings.py | 22 ++++++++++++---------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/example_pkg-stubs/_numpy.pyi b/examples/example_pkg-stubs/_numpy.pyi index be3c721..ec1a695 100644 --- a/examples/example_pkg-stubs/_numpy.pyi +++ b/examples/example_pkg-stubs/_numpy.pyi @@ -2,7 +2,7 @@ import numpy as np from numpy.typing import ArrayLike, NDArray def func_object_with_numpy_objects( - a1: np.np.int8, a2: np.np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike + a1: np.int8, a2: np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike ) -> None: ... def func_ndarray( a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] = ... diff --git a/examples/example_pkg/_numpy.py b/examples/example_pkg/_numpy.py index 524af84..33fa9ca 100644 --- a/examples/example_pkg/_numpy.py +++ b/examples/example_pkg/_numpy.py @@ -19,7 +19,7 @@ def func_ndarray(a1, a2, a3, a4=None): Parameters ---------- a1 : ndarray - a2 : np.ndarray + a2 : np.NDArray a3 : (N, 3) ndarray of float a4 : ndarray of shape (1,) and dtype uint8 diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 1b1e8a8..c7465c7 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -198,14 +198,19 @@ def sphinx_ref(self, tree): return qualname def qualname(self, tree): - matched = False - out = [] - for i, child in enumerate(tree.children): - if i != 0 and not child.startswith("["): + children = tree.children + + # Try to match only first child to known imports + children[0] = self._match_n_record_name(children[0]) + matched = isinstance(children[0], MatchedName) + + # Insert dots except for containers (when child starts with "[") + out = [children[0]] + for child in children[1:]: + if not child.startswith("["): out.append(".") - if isinstance(child, MatchedName): - matched = True out.append(child) + out = "".join(out) if matched 
is False: docname = self.inspector.query(out) @@ -216,13 +221,10 @@ def qualname(self, tree): logger.warning( "unmatched name %r in %s", out, self.inspector.current_source ) + out = lark.Token("QUALNAME", out) return out - def NAME(self, token): - new_token = self._match_n_record_name(token) - return new_token - def ARRAY_NAME(self, token): new_token = self._match_n_record_name(token) new_token.type = "ARRAY_NAME" From 748b69338defbc7618d5492b5e1385ca20f5716f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 25 Aug 2024 11:20:06 +0200 Subject: [PATCH 14/16] Split docname into KnownImport and replace Reduce responsibility of the former DocName class. Replacing docstring specific type description should be handled separately. --- src/docstub/_analysis.py | 197 ++++++++++++++++---------------- src/docstub/_cli.py | 42 ++++--- src/docstub/_config.py | 70 ++++++++---- src/docstub/default_config.toml | 148 ++++++++++++------------ 4 files changed, 248 insertions(+), 209 deletions(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index aa80f3b..f911c98 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -38,68 +38,71 @@ def _shared_leading_path(*paths): @dataclass(slots=True, frozen=True) -class DocName: - """An atomic name (without ".") in a docstring type with import info.""" +class KnownImport: + """Import information associated with a single known type annotation. - use_name: str + Parameters + ---------- + annotation_name : str + Name of the type annotation + import_name : + Dotted names after "import". + import_path : + Dotted names after "from". + import_alias : + Name (without ".") after "as". + """ + + annotation_name: str import_name: str = None import_path: str = None import_alias: str = None is_builtin: bool = False @classmethod - def one_from_config(cls, docname, *, info): - """Create one DocName from the configuration format. 
+ def one_from_config(cls, name, *, info): + """Create one KnownImport from the configuration format. Parameters ---------- - docname : str - info : dict[{"use", "from", "import", "as", "is_builtin"}, str] + name : str + info : dict[{"from", "import", "as", "is_builtin"}, str] Returns ------- - docname : Self + TypeImport : Self """ - use_name = docname - if "import" in info: - use_name = info["import"] - if "as" in info: - use_name = info["as"] - if "use" in info: - use_name = info["use"] - - import_name = docname - if "use" in info: - import_name = info["use"] + assert not (info.keys() - {"from", "import", "as", "is_builtin"}) + + import_name = name if "import" in info: import_name = info["import"] - docname = cls( - use_name=use_name, + known_import = cls( + annotation_name=name, import_name=import_name, import_path=info.get("from"), import_alias=info.get("as"), - is_builtin=info.get("builtin", False), + is_builtin=info.get("is_builtin", False), ) - return docname + return known_import @classmethod def many_from_config(cls, mapping): - """Create many DocNames from the configuration format. + """Create many KnownImports from the configuration format. 
Parameters ---------- - mapping : dict[str, dict[{"use", "from", "import", "as", "is_builtin"}, str]] + mapping : dict[str, dict[{"from", "import", "as", "is_builtin"}, str]] Returns ------- - docnames : dict[str, Self] + known_imports : dict[str, Self] """ - docnames = { - docname: cls.one_from_config(docname, info=info) - for docname, info in mapping.items() + known_imports = { + name: cls.one_from_config(name, info=info) for name, info in mapping.items() } - return docnames + return known_imports def format_import(self, relative_to=None): if self.is_builtin: @@ -125,19 +128,20 @@ def format_import(self, relative_to=None): def has_import(self): return not self.is_builtin - def __repr__(self): - classname = type(self).__name__ - if self.has_import: - info = f"{self.import_name}" - if self.import_path: - info = f"{self.import_path}.{info}" - if self.import_alias: - info = f"{info} as {self.import_alias}" - if self.use_name not in info: - info = f"{info}; {self.use_name}" - else: - info = f"{self.use_name} (builtin)" - return f"{classname}: {info}" + def __post_init__(self): + if "." in self.annotation_name: + raise ValueError("'.' in the annotation name aren't yet supported") + + if self.import_alias and self.import_alias != self.annotation_name: + raise ValueError( + f"annotation name must match given import alias: " + f"{self.annotation_name} != {self.import_alias}" + ) + elif self.import_name != self.annotation_name: + raise ValueError( + f"annotation name must match import name if no alias is given: " + f"{self.annotation_name} != {self.import_name}" + ) @dataclass(slots=True, frozen=True) @@ -157,81 +161,82 @@ def _is_type(value) -> bool: return is_type -def _builtin_docnames(): - """Return docnames for all builtins (in the current runtime). +def _builtin_imports(): + """Return known imports for all builtins (in the current runtime). 
Returns ------- - docnames : dict[str, DocName] + known_imports : dict[str, KnownImport] """ known_builtins = set(dir(builtins)) - docnames = {} + known_imports = {} for name in known_builtins: if name.startswith("_"): continue value = getattr(builtins, name) if not _is_type(value): continue - docnames[name] = DocName(use_name=name, is_builtin=True) + known_imports[name] = KnownImport(annotation_name=name, is_builtin=True) - return docnames + return known_imports -def _typing_docnames(): - """Return docnames for public types in the `typing` module. +def _typing_imports(): + """Return known imports for public types in the `typing` module. Returns ------- - docnames : dict[str, DocName] + known_imports : dict[str, KnownImport] """ - docnames = {} + known_imports = {} for name in typing.__all__: if name.startswith("_"): continue value = getattr(typing, name) if not _is_type(value): continue - docnames[name] = DocName.one_from_config(name, info={"from": "typing"}) - return docnames + known_imports[name] = KnownImport.one_from_config(name, info={"from": "typing"}) + return known_imports -def _collections_abc_docnames(): - """Return docnames for public types in the `collections.abc` module. +def _collections_abc_imports(): + """Return known imports for public types in the `collections.abc` module. Returns ------- - docnames : dict[str, DocName] + known_imports : dict[str, KnownImport] """ - docnames = {} + known_imports = {} for name in collections.abc.__all__: if name.startswith("_"): continue value = getattr(collections.abc, name) if not _is_type(value): continue - docnames[name] = DocName.one_from_config(name, info={"from": "collections.abc"}) - return docnames + known_imports[name] = KnownImport.one_from_config( + name, info={"from": "collections.abc"} + ) + return known_imports -def common_docnames(): - """Return docnames for commonly supported types. +def common_known_imports(): + """Return known imports for commonly supported types. 
This includes builtin types, and types from the `typing` or `collections.abc` module. Returns ------- - docnames : dict[str, DocName] + known_imports : dict[str, KnownImport] """ - docnames = _builtin_docnames() - docnames |= _typing_docnames() - docnames |= _collections_abc_docnames() # Overrides containers from typing - return docnames - + known_imports = _builtin_imports() + known_imports |= _typing_imports() + known_imports |= _collections_abc_imports() # Overrides containers from typing + return known_imports -class DocNameCollector(cst.CSTVisitor): +class KnownImportCollector(cst.CSTVisitor): @classmethod def collect(cls, file, module_name): file = Path(file) @@ -241,22 +246,22 @@ def collect(cls, file, module_name): tree = cst.parse_module(source) collector = cls(module_name=module_name) tree.visit(collector) - return collector.docnames + return collector.known_imports def __init__(self, *, module_name): self.module_name = module_name self._stack = [] - self.docnames = {} + self.known_imports = {} def visit_ClassDef(self, node): self._stack.append(node.name.value) use_name = ".".join(self._stack[:1]) qualname = f"{self.module_name}.{'.'.join(self._stack)}" - docname = DocName( + known_import = KnownImport( use_name=use_name, import_name=use_name, import_path=self.module_name ) - self.docnames[qualname] = docname + self.known_imports[qualname] = known_import return True @@ -280,8 +285,8 @@ class StaticInspector: Examples -------- - >>> from docstub._analysis import StaticInspector, common_docnames - >>> inspector = StaticInspector(docnames=common_docnames()) + >>> from docstub._analysis import StaticInspector, common_known_imports + >>> inspector = StaticInspector(known_imports=common_known_imports()) >>> inspector.query("Any") """ @@ -289,22 +294,22 @@ def __init__( self, *, source_pkgs=None, - docnames=None, + known_imports=None, ): """ Parameters ---------- source_pkgs: list[Path], optional - docnames: dict[str, DocName], optional + known_imports: 
dict[str, KnownImport], optional """ if source_pkgs is None: source_pkgs = [] - if docnames is None: - docnames = {} + if known_imports is None: + known_imports = {} self.current_source = None self.source_pkgs = source_pkgs - self._inspected = {"initial": docnames} + self._inspected = {"initial": known_imports} @staticmethod def _accumulate_module_name(qualname): @@ -327,7 +332,7 @@ def _find_modules(self, qualname): yield file, module_name def inspect_module(self, file, module_name): - """Collect docnames from the given file. + """Collect known imports from the given file. Parameters ---------- @@ -335,15 +340,15 @@ def inspect_module(self, file, module_name): Returns ------- - collected : set[DocName] + collected : set[KnownImport] """ if file in self._inspected: return self._inspected[file] - docnames = DocNameCollector.collect(file, module_name) - self._inspected[file] = docnames - self.docnames.update(docnames) - return docnames + known_imports = KnownImportCollector.collect(file, module_name) + self._inspected[file] = known_imports + self.known_imports.update(known_imports) + return known_imports def query(self, qualname): """ @@ -353,9 +358,9 @@ def query(self, qualname): Returns ------- - out : DocName | None + out : KnownImport | None """ - out = self.docnames.get(qualname) + out = self.known_imports.get(qualname) *prefix, name = qualname.split(".") if not out and "~" in prefix: @@ -364,7 +369,7 @@ def query(self, qualname): pattern = re.compile(pattern + "$") matches = { key: value - for key, value in self.docnames.items() + for key, value in self.known_imports.items() if re.match(pattern, key) } if len(matches) > 1: @@ -381,18 +386,18 @@ def query(self, qualname): elif not out and self.current_source: try_qualname = f"{self.current_source.import_path}.{qualname}" - out = self.docnames.get(try_qualname) + out = self.known_imports.get(try_qualname) return out @property - def docnames(self): - current_docnames = {} + def known_imports(self): + 
current_known_imports = {} - for _, docnames in self._inspected.items(): - current_docnames.update(docnames) + for _, known_imports in self._inspected.items(): + current_known_imports.update(known_imports) - return current_docnames + return current_known_imports def __repr__(self): repr = f"{type(self).__name__}({self.source_pkgs})" diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index bc21cd7..ea29373 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -4,8 +4,13 @@ import click -from . import _config -from ._analysis import DocName, DocNameCollector, StaticInspector, common_docnames +from ._config import Config +from ._analysis import ( + KnownImport, + KnownImportCollector, + StaticInspector, + common_known_imports, +) from ._stubs import Py2StubTransformer, walk_source, walk_source_and_targets from ._version import __version__ @@ -27,22 +32,25 @@ def _find_configuration(source_dir, config_path): ------- config : dict[str, Any] """ - # Handle configuration - config = _config.default_config() + config = Config.from_toml(Config.DEFAULT_CONFIG_PATH) + pyproject_toml = source_dir.parent / "pyproject.toml" - docstub_toml = source_dir.parent / "docstub.toml" if pyproject_toml.is_file(): logger.info("using %s", pyproject_toml) - add_config = _config.load_config_file(pyproject_toml) - config = _config.merge_config(config, add_config) + add_config = Config.from_toml(pyproject_toml) + config = config.merge(add_config) + + docstub_toml = source_dir.parent / "docstub.toml" if docstub_toml.is_file(): logger.info("using %s", docstub_toml) - add_config = _config.load_config_file(docstub_toml) - config = _config.merge_config(config, add_config) + add_config = Config.from_toml(docstub_toml) + config = config.merge(add_config) + if config_path: logger.info("using %s", config_path) - add_config = _config.load_config_file(config_path) - config = _config.merge_config(config, add_config) + add_config = Config.from_toml(config_path) + config = config.merge(add_config) 
+ return config @@ -64,18 +72,18 @@ def main(source_dir, out_dir, config_path, verbose): source_dir = Path(source_dir) config = _find_configuration(source_dir, config_path) - # Build docname map - docnames = common_docnames() + # Build map of known imports + known_imports = common_known_imports() for source_path in walk_source(source_dir): logger.info("collecting types in %s", source_path) - docnames_in_source = DocNameCollector.collect( + known_imports_in_source = KnownImportCollector.collect( source_path, module_name=source_path.import_path ) - docnames.update(docnames_in_source) - docnames.update(DocName.many_from_config(config["docnames"])) + known_imports.update(known_imports_in_source) + known_imports.update(KnownImport.many_from_config(config["known_imports"])) inspector = StaticInspector( - source_pkgs=[source_dir.parent.resolve()], docnames=docnames + source_pkgs=[source_dir.parent.resolve()], known_imports=known_imports ) # and the stub transformer stub_transformer = Py2StubTransformer(inspector=inspector) diff --git a/src/docstub/_config.py b/src/docstub/_config.py index 8f790ed..676c6df 100644 --- a/src/docstub/_config.py +++ b/src/docstub/_config.py @@ -1,5 +1,7 @@ import logging +import dataclasses from pathlib import Path +from typing import ClassVar try: import tomllib @@ -10,30 +12,56 @@ logger = logging.getLogger(__name__) -DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.toml" +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Config: + DEFAULT_CONFIG_PATH: ClassVar[Path] = Path(__file__).parent / "default_config.toml" + extend_grammar: str = "" + known_imports: dict[str, dict[str, str]] = dataclasses.field(default_factory=dict) + replace: dict[str, str] = dataclasses.field(default_factory=dict) -def load_config_file(path: Path | str) -> dict: - """Return configuration options in local TOML file if they exist.""" - with open(path, "rb") as fp: - config = tomllib.load(fp) - config = config.get("tool", 
{}).get("docstub", {}) - return config + _source: tuple[Path, ...] = tuple() + @classmethod + def from_toml(cls, path: Path | str) -> "Config": + """Return configuration options in local TOML file if they exist.""" + path = Path(path) + with open(path, "rb") as fp: + raw = tomllib.load(fp) + config = cls(**raw.get("tool", {}).get("docstub", {}), _source=(path,)) + logger.debug("created Config from %s", path) + return config -def default_config(): - config = load_config_file(DEFAULT_CONFIG_PATH) - return config + @classmethod + def from_default(cls): + config = cls.from_toml(cls.DEFAULT_CONFIG_PATH) + return config + def merge(self, other): + """Merge contents with other and return a new Config instance.""" + if not isinstance(other, type(self)): + return NotImplemented + new = Config( + extend_grammar=self.extend_grammar + other.extend_grammar, + known_imports=self.known_imports | other.known_imports, + replace=self.replace | other.replace, + _source=self._source + other._source, + ) + logger.debug("merged Config from %s", new._source) + return new -def merge_config(*configurations): - merged = {} - merged["extended_grammar"] = "\n".join( - cfg.get("extended_grammar", "") for cfg in configurations - ) - merged["docnames"] = {} - for cfg in configurations: - docnames = cfg.get("docnames") - if docnames and isinstance(docnames, dict): - merged["docnames"].update(docnames) - return merged + def to_dict(self): + return dataclasses.asdict(self) + + def __post_init__(self): + if not isinstance(self.extend_grammar, str): + raise TypeError("extended_grammar must be a string") + if not isinstance(self.known_imports, dict): + raise TypeError("known_imports must be a dict") + if not isinstance(self.replace, dict): + raise TypeError("replace must be a string") + + def __repr__(self): + sources = " | ".join(str(s) for s in self._source) + formatted = f"<{type(self).__name__}: {sources}>" + return formatted diff --git a/src/docstub/default_config.toml 
b/src/docstub/default_config.toml index 1391b2a..95e3023 100644 --- a/src/docstub/default_config.toml +++ b/src/docstub/default_config.toml @@ -5,85 +5,83 @@ extend_grammar = """ """ -# A mapping of docnames to import information. Each item maps a docname on the -# left side to a dictionary on the right side, which supports the following -# fields: -# use : A string to replace the docname with, defaults to the docname. -# from : Indicate that the docname can be imported from this path. -# import : Import this object, defaults to the docname. +# Import information for type annotations, declared ahead of time. +# +# Each item maps an annotation name on the left side to a dictionary on the +# right side. +# +# Import information can be declared with the following fields: +# from : Indicate that the DocType can be imported from this path. +# import : Import this object, defaults to the DocType. # as : Use this alias for the imported object -# is_builtin : Indicate that this docname doesn't need to be imported, +# is_builtin : Indicate that this DocType doesn't need to be imported, # defaults to "false" -[tool.docstub.docnames] - +[tool.docstub.known_imports] Path = { from = "pathlib" } - -function = { use = "Callable", from = "typing" } -func = { use = "Callable", from = "typing" } -callable = { use = "Callable", from = "typing" } - +Callable = { from = "typing" } np = { import = "numpy", as = "np" } -numpy = { use = "np", import = "numpy", as = "np" } - -scalar = { use = "np.ScalarType", import = "numpy", as = "np" } - -integer = { use = "np.integer", import = "numpy", as = "np" } -signedinteger = { use = "np.signedinteger", import = "numpy", as = "np" } -byte = { use = "np.byte", import = "numpy", as = "np" } -short = { use = "np.short", import = "numpy", as = "np" } -intc = { use = "np.intc", import = "numpy", as = "np" } -int_ = { use = "np.int_", import = "numpy", as = "np" } -longlong = { use = "np.longlong", import = "numpy", as = "np" } -int8 = { use = 
"np.int8", import = "numpy", as = "np" } -int16 = { use = "np.int16", import = "numpy", as = "np" } -int32 = { use = "np.int32", import = "numpy", as = "np" } -int64 = { use = "np.int64", import = "numpy", as = "np" } -intp = { use = "np.intp", import = "numpy", as = "np" } - -unsignedinteger = { use = "np.unsignedinteger", import = "numpy", as = "np" } -ushort = { use = "np.ushort", import = "numpy", as = "np" } -uintc = { use = "np.uintc", import = "numpy", as = "np" } -uint = { use = "np.uint", import = "numpy", as = "np" } -ulonglong = { use = "np.ulonglong", import = "numpy", as = "np" } -uint8 = { use = "np.uint8", import = "numpy", as = "np" } -uint16 = { use = "np.uint16", import = "numpy", as = "np" } -uint32 = { use = "np.uint32", import = "numpy", as = "np" } -uint64 = { use = "np.uint64", import = "numpy", as = "np" } -uintp = { use = "np.uintp", import = "numpy", as = "np" } - -floating = { use = "np.floating", import = "numpy", as = "np" } -#half = { use = "np.half", import = "numpy", as = "np" } -#single = { use = "np.single", import = "numpy", as = "np" } -double = { use = "np.double", import = "numpy", as = "np" } -longdouble = { use = "np.longdouble", import = "numpy", as = "np" } -float16 = { use = "np.float16", import = "numpy", as = "np" } -float32 = { use = "np.float32", import = "numpy", as = "np" } -float64 = { use = "np.float64", import = "numpy", as = "np" } -float96 = { use = "np.float96", import = "numpy", as = "np" } -float128 = { use = "np.float128", import = "numpy", as = "np" } +NDArray = { from = "numpy.typing" } +ArrayLike = { from = "numpy.typing" } -complexfloating = { use = "np.complexfloating", import = "numpy", as = "np" } -csingle = { use = "np.csingle", import = "numpy", as = "np" } -cdouble = { use = "np.cdouble", import = "numpy", as = "np" } -clongdouble = { use = "np.clongdouble", import = "numpy", as = "np" } -complex64 = { use = "np.complex64", import = "numpy", as = "np" } -complex128 = { use = "np.complex128", import 
= "numpy", as = "np" } -complex192 = { use = "np.complex192", import = "numpy", as = "np" } -complex256 = { use = "np.complex256", import = "numpy", as = "np" } -bool_ = { use = "np.bool_", import = "numpy", as = "np" } -datetime64 = { use = "np.datetime64", import = "numpy", as = "np" } -timedelta64 = { use = "np.timedelta64", import = "numpy", as = "np" } -object_ = { use = "np.object_", import = "numpy", as = "np" } -#flexible = { use = "np.flexible", import = "numpy", as = "np" } -#character = { use = "np.character", import = "numpy", as = "np" } -bytes_ = { use = "np.bytes_", import = "numpy", as = "np" } -#str_ = { use = "np.str_", import = "numpy", as = "np" } -#void = { use = "np.void", import = "numpy", as = "np" } +# Replace human-readable expressions with actual types. +[tool.docstub.replace] +function = "Callable" +func = "Callable" +callable = "Callable" -NDArray = { from = "numpy.typing" } -ndarray = { use = "NDArray", from = "numpy.typing" } -array = { use = "NDArray", from = "numpy.typing" } -ArrayLike = { from = "numpy.typing" } -array-like = { use = "ArrayLike", from = "numpy.typing" } -array_like = { use = "ArrayLike", from = "numpy.typing" } +numpy = "np" +scalar = "np.ScalarType" +integer = "np.integer" +signedinteger = "np.signedinteger" +byte ="np.byte" +short = "np.short" +intc = "np.intc" +int_ = "np.int_" +longlong = "np.longlong" +int8 = "np.int8" +int16 = "np.int16" +int32 = "np.int32" +int64 = "np.int64" +intp = "np.intp" +unsignedinteger = "np.unsignedinteger" +ushort = "np.ushort" +uintc = "np.uintc" +uint = "np.uint" +ulonglong ="np.ulonglong" +uint8 = "np.uint8" +uint16 = "np.uint16" +uint32 = "np.uint32" +uint64 = "np.uint64" +uintp = "np.uintp" +floating = "np.floating" +#half = "np.half" +#single = "np.single" +double = "np.double" +longdouble = "np.longdouble" +float16 = "np.float16" +float32 = "np.float32" +float64 = "np.float64" +float96 = "np.float96" +float128 = "np.float128" +complexfloating = "np.complexfloating" +csingle 
= "np.csingle" +cdouble = "np.cdouble" +clongdouble = "np.clongdouble" +complex64 = "np.complex64" +complex128 = "np.complex128" +complex192 = "np.complex192" +complex256 = "np.complex256" +bool_ = "np.bool_" +datetime64 = "np.datetime64" +timedelta64 = "np.timedelta64" +object_ = "np.object_" +#flexible = "np.flexible" +#character = "np.character" +bytes_ = "np.bytes_" +#str_ = "np.str_" +#void = "np.void" +ndarray = "NDArray" +array = "NDArray" +array-like = "ArrayLike" +array_like = "ArrayLike" From a981d2fa1a9612e26b8674222229d3eccb65bd06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 4 Sep 2024 12:26:19 +0200 Subject: [PATCH 15/16] Test correctness when nesting classes in classes --- examples/example_pkg-stubs/_basic.pyi | 12 +++++++++--- examples/example_pkg/_basic.py | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 0de6e89..58801b5 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -26,10 +26,16 @@ def func_literals( a1: Literal[1, 3, "foo"], a2: Literal["uno", 2, "drei", "four"] = ... ) -> None: ... def func_use_from_elsewhere( - a1: CustomException, a2: ExampleClass -) -> CustomException: ... + a1: CustomException, + a2: ExampleClass, + a3: CustomException.NestedClass, + a4: ExampleClass.NestedClass, +) -> tuple[CustomException, ExampleClass.NestedClass]: ... class ExampleClass: + class NestedClass: + def method_in_nested_class(self, a1: complex) -> None: ... + def __init__(self, a1: str, a2: float = ...) -> None: ... def method(self, a1: float, a2: float | None) -> list[float]: ... @staticmethod @@ -40,5 +46,5 @@ class ExampleClass: def some_property(self, value: str) -> None: ... @classmethod def method_returning_cls(cls, config: configparser.ConfigParser) -> Self: ... 
- @classmethod() + @classmethod def method_returning_cls2(cls, config: configparser.ConfigParser) -> Self: ... diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 1004bb8..cdaa5a2 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -51,17 +51,20 @@ def func_literals(a1, a2="uno"): """ -def func_use_from_elsewhere(a1, a2): +def func_use_from_elsewhere(a1, a2, a3, a4): """Check if types with full import names are matched. Parameters ---------- a1 : example_pkg.CustomException a2 : ExampleClass + a3 : example_pkg.CustomException.NestedClass + a4 : ExampleClass.NestedClass Returns ------- r1 : ~.CustomException + r2 : ~.NestedClass """ @@ -74,6 +77,16 @@ class ExampleClass: a2 : float, default 0 """ + class NestedClass: + + def method_in_nested_class(self, a1): + """ + + Parameters + ---------- + a1 : complex + """ + def __init__(self, a1, a2=0): pass @@ -138,7 +151,7 @@ def method_returning_cls(cls, config): New class. """ - @classmethod() + @classmethod def method_returning_cls2(cls, config): """Using `Self` in context of classmethods is supported. From 8588ebb241673b506813a0005bd092b9c914d668 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 4 Sep 2024 15:08:49 +0200 Subject: [PATCH 16/16] Rework analysis and other major changes While not perfect or anywhere near to finished the refactored code uses clearer separation of responsiblities and is one step closer to an architecture that feels right. 
:) --- examples/docstub.toml | 18 +- src/docstub/_analysis.py | 207 ++++++++++---------- src/docstub/_cli.py | 8 +- src/docstub/_config.py | 12 +- src/docstub/_docstrings.py | 324 ++++++++++++++++---------------- src/docstub/_stubs.py | 31 +-- src/docstub/_utils.py | 23 +++ src/docstub/default_config.toml | 7 +- src/docstub/doctype.lark | 15 +- tests/test_analysis.py | 62 ++++++ 10 files changed, 406 insertions(+), 301 deletions(-) create mode 100644 tests/test_analysis.py diff --git a/examples/docstub.toml b/examples/docstub.toml index 8850bdc..a16862e 100644 --- a/examples/docstub.toml +++ b/examples/docstub.toml @@ -5,14 +5,16 @@ extend_grammar = """ """ -# A mapping of docnames to import information. Each item maps a docname on the -# left side to a dictionary on the right side, which supports the following -# fields: -# use : A string to replace the docname with, defaults to the docname. -# from : Indicate that the docname can be imported from this path. -# import : Import this object, defaults to the docname. +# Import information for type annotations, declared ahead of time. +# +# Each item maps an annotation name on the left side to a dictionary on the +# right side. +# +# Import information can be declared with the following fields: +# from : Indicate that the DocType can be imported from this path. +# import : Import this object, defaults to the DocType. 
# as : Use this alias for the imported object -# is_builtin : Indicate that this docname doesn't need to be imported, +# is_builtin : Indicate that this DocType doesn't need to be imported, # defaults to "false" -[tool.docstub.docnames] +[tool.docstub.known_imports] configparser = {import = "configparser"} diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index f911c98..750ebba 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -2,7 +2,6 @@ import builtins import collections.abc -import itertools import logging import re import typing @@ -11,6 +10,8 @@ import libcst as cst +from ._utils import accumulate_qualname + logger = logging.getLogger(__name__) @@ -43,21 +44,20 @@ class KnownImport: Parameters ---------- - annotation_name : str - Name of the type annotation import_name : Dotted names after "import". import_path : Dotted names after "from". import_alias : Name (without ".") after "as". + builtin_name : + Names an object that's builtin and doesn't need an import. 
""" - annotation_name: str import_name: str = None import_path: str = None import_alias: str = None - is_builtin: bool = False + builtin_name: str = None @classmethod def one_from_config(cls, name, *, info): @@ -74,17 +74,23 @@ def one_from_config(cls, name, *, info): """ assert not (info.keys() - {"from", "import", "as", "is_builtin"}) - import_name = name - if "import" in info: - import_name = info["import"] + if info.get("is_builtin"): + known_import = cls(builtin_name=name) + else: + import_name = name + if "import" in info: + import_name = info["import"] + + known_import = cls( + import_name=import_name, + import_path=info.get("from"), + import_alias=info.get("as"), + ) + if not name.startswith(known_import.target): + raise ValueError( + f"{name!r} doesn't start with {known_import.target!r}", + ) - known_import = cls( - annotation_name=name, - import_name=import_name, - import_path=info.get("from"), - import_alias=info.get("as"), - is_builtin=info.get("is_builtin", False), - ) return known_import @classmethod @@ -105,7 +111,7 @@ def many_from_config(cls, mapping): return known_imports def format_import(self, relative_to=None): - if self.is_builtin: + if self.builtin_name: msg = "cannot import builtin" raise RuntimeError(msg) out = f"import {self.import_name}" @@ -124,24 +130,44 @@ def format_import(self, relative_to=None): out = f"{out} as {self.import_alias}" return out + @property + def target(self) -> str: + if self.import_alias: + out = self.import_alias + elif self.import_name: + out = self.import_name + elif self.builtin_name: + out = self.builtin_name + else: + raise RuntimeError("cannot determine import target") + return out + @property def has_import(self): - return not self.is_builtin + return self.builtin_name is None def __post_init__(self): - if "." in self.annotation_name: - raise ValueError("'.' 
in the annotation name aren't yet supported") + if self.builtin_name is not None: + if ( + self.import_name is not None + or self.import_alias is not None + or self.import_path is not None + ): + raise ValueError("builtin cannot contain import information") + elif self.import_name is None: + raise ValueError("non bultin must at least define an `import_name`") - if self.import_alias and self.import_alias != self.annotation_name: - raise ValueError( - f"annotation name must match given import alias: " - f"{self.annotation_name} != {self.import_alias}" - ) - elif self.import_name != self.annotation_name: - raise ValueError( - f"annotation name must match import name if no alias is given: " - f"{self.annotation_name} != {self.import_name}" - ) + def __repr__(self): + if self.builtin_name: + info = f"{self.target} (builtin)" + else: + info = f"{self.format_import()!r}" + out = f"<{type(self).__name__} {info}>" + return out + + def __str__(self): + out = self.format_import() + return out @dataclass(slots=True, frozen=True) @@ -177,7 +203,7 @@ def _builtin_imports(): value = getattr(builtins, name) if not _is_type(value): continue - known_imports[name] = KnownImport(annotation_name=name, is_builtin=True) + known_imports[name] = KnownImport(builtin_name=name) return known_imports @@ -256,10 +282,12 @@ def __init__(self, *, module_name): def visit_ClassDef(self, node): self._stack.append(node.name.value) - use_name = ".".join(self._stack[:1]) + class_name = ".".join(self._stack[:1]) qualname = f"{self.module_name}.{'.'.join(self._stack)}" + known_import = KnownImport( - use_name=use_name, import_name=use_name, import_path=self.module_name + import_name=class_name, + import_path=self.module_name, ) self.known_imports[qualname] = known_import @@ -288,6 +316,7 @@ class StaticInspector: >>> from docstub._analysis import StaticInspector, common_known_imports >>> inspector = StaticInspector(known_imports=common_known_imports()) >>> inspector.query("Any") + ('Any', ) """ def 
__init__( @@ -309,64 +338,32 @@ def __init__( self.current_source = None self.source_pkgs = source_pkgs - self._inspected = {"initial": known_imports} - - @staticmethod - def _accumulate_module_name(qualname): - fragments = qualname.split(".") - yield from itertools.accumulate(fragments, lambda x, y: f"{x}.{y}") - - def _find_modules(self, qualname): - for source in self.source_pkgs: - for module_name in self._accumulate_module_name(qualname): - module_path = module_name.replace(".", "/") - # Return PYI files last, so their content overwrites - files = [ - source / f"{module_path}.py", - source / f"{module_path}.pyi", - source / f"{module_path}/__init__.py", - source / f"{module_path}/__init__.pyi", - ] - for file in files: - if file.is_file(): - yield file, module_name - - def inspect_module(self, file, module_name): - """Collect known imports from the given file. - - Parameters - ---------- - file : Path - Returns - ------- - collected : set[KnownImport] - """ - if file in self._inspected: - return self._inspected[file] + self.known_imports = known_imports - known_imports = KnownImportCollector.collect(file, module_name) - self._inspected[file] = known_imports - self.known_imports.update(known_imports) - return known_imports + def query(self, search_name): + """Search for a known annotation name. - def query(self, qualname): - """ Parameters ---------- - qualname : str + search_name : str Returns ------- - out : KnownImport | None + annotation_name : str | None + If it was found, the name of the annotation that matches the `known_import`. + known_import : KnownImport | None + If it was found, import information matching the `annotation_name`. 
""" - out = self.known_imports.get(qualname) + annotation_name = None + known_import = None - *prefix, name = qualname.split(".") - if not out and "~" in prefix: - pattern = qualname.replace(".", r"\.") + if search_name.startswith("~."): + # Sphinx like matching with abbreviated name + pattern = search_name.replace(".", r"\.") pattern = pattern.replace("~", ".*") pattern = re.compile(pattern + "$") + # Might be slow, but works for now matches = { key: value for key, value in self.known_imports.items() @@ -374,30 +371,48 @@ def query(self, qualname): } if len(matches) > 1: shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0] - out = matches[shortest_key] + known_import = matches[shortest_key] + annotation_name = shortest_key logger.warning( - "%s matches multiple types %s, using %s", - qualname, + "%r in %s matches multiple types %r, using %r", + search_name, + self.current_source, matches.keys(), shortest_key, ) elif len(matches) == 1: - _, out = matches.popitem() - - elif not out and self.current_source: - try_qualname = f"{self.current_source.import_path}.{qualname}" - out = self.known_imports.get(try_qualname) - - return out - - @property - def known_imports(self): - current_known_imports = {} - - for _, known_imports in self._inspected.items(): - current_known_imports.update(known_imports) + annotation_name, known_import = matches.popitem() + else: + logger.debug( + "couldn't match %r in %s", search_name, self.current_source + ) - return current_known_imports + if known_import is None and self.current_source: + # Try scope of current module + try_qualname = f"{self.current_source.import_path}.{search_name}" + known_import = self.known_imports.get(try_qualname) + annotation_name = search_name + + if known_import is None: + # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a') + for partial_qualname in reversed(accumulate_qualname(search_name)): + known_import = self.known_imports.get(partial_qualname) + if known_import: + annotation_name = 
search_name + break + + if ( + known_import is not None + and annotation_name is not None + and annotation_name != known_import.target + and annotation_name.startswith(known_import.target) + ): + # Ensure that the annotation matches the import target + annotation_name = annotation_name[ + annotation_name.find(known_import.target) : + ] + + return annotation_name, known_import def __repr__(self): repr = f"{type(self).__name__}({self.source_pkgs})" diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index ea29373..651d9f2 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -4,13 +4,13 @@ import click -from ._config import Config from ._analysis import ( KnownImport, KnownImportCollector, StaticInspector, common_known_imports, ) +from ._config import Config from ._stubs import Py2StubTransformer, walk_source, walk_source_and_targets from ._version import __version__ @@ -80,13 +80,15 @@ def main(source_dir, out_dir, config_path, verbose): source_path, module_name=source_path.import_path ) known_imports.update(known_imports_in_source) - known_imports.update(KnownImport.many_from_config(config["known_imports"])) + known_imports.update(KnownImport.many_from_config(config.known_imports)) inspector = StaticInspector( source_pkgs=[source_dir.parent.resolve()], known_imports=known_imports ) # and the stub transformer - stub_transformer = Py2StubTransformer(inspector=inspector) + stub_transformer = Py2StubTransformer( + inspector=inspector, replace_doctypes=config.replace_doctypes + ) if not out_dir: out_dir = source_dir.parent diff --git a/src/docstub/_config.py b/src/docstub/_config.py index 676c6df..c1535ed 100644 --- a/src/docstub/_config.py +++ b/src/docstub/_config.py @@ -1,5 +1,5 @@ -import logging import dataclasses +import logging from pathlib import Path from typing import ClassVar @@ -18,9 +18,9 @@ class Config: extend_grammar: str = "" known_imports: dict[str, dict[str, str]] = dataclasses.field(default_factory=dict) - replace: dict[str, str] = 
dataclasses.field(default_factory=dict) + replace_doctypes: dict[str, str] = dataclasses.field(default_factory=dict) - _source: tuple[Path, ...] = tuple() + _source: tuple[Path, ...] = () @classmethod def from_toml(cls, path: Path | str) -> "Config": @@ -44,7 +44,7 @@ def merge(self, other): new = Config( extend_grammar=self.extend_grammar + other.extend_grammar, known_imports=self.known_imports | other.known_imports, - replace=self.replace | other.replace, + replace_doctypes=self.replace_doctypes | other.replace_doctypes, _source=self._source + other._source, ) logger.debug("merged Config from %s", new._source) @@ -58,8 +58,8 @@ def __post_init__(self): raise TypeError("extended_grammar must be a string") if not isinstance(self.known_imports, dict): raise TypeError("known_imports must be a dict") - if not isinstance(self.replace, dict): - raise TypeError("replace must be a string") + if not isinstance(self.replace_doctypes, dict): + raise TypeError("replace_doctypes must be a string") def __repr__(self): sources = " | ".join(str(s) for s in self._source) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index c7465c7..221d25c 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -1,6 +1,4 @@ -"""Transform types defined in docstrings to Python parsable types. 
- -""" +"""Transform types defined in docstrings to Python parsable types.""" import logging from dataclasses import dataclass, field @@ -10,7 +8,8 @@ import lark.visitors from numpydoc.docscrape import NumpyDocString -from ._analysis import DocName +from ._analysis import KnownImport +from ._utils import accumulate_qualname logger = logging.getLogger(__name__) @@ -19,6 +18,12 @@ grammar_path = here / "doctype.lark" +with grammar_path.open() as file: + _grammar = file.read() + +_lark = lark.Lark(_grammar) + + def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token: """Find token with a specific type name in tree.""" tokens = [child for child in tree.children if child.type == name] @@ -29,11 +34,11 @@ def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token: @dataclass(frozen=True, slots=True) -class PyType: - """Python-ready type with attached import information.""" +class Annotation: + """Python-ready type annotation with attached import information.""" value: str - imports: set[DocName] | frozenset[DocName] = field(default_factory=frozenset) + imports: frozenset[KnownImport] = field(default_factory=frozenset) def __post_init__(self): object.__setattr__(self, "imports", frozenset(self.imports)) @@ -43,22 +48,22 @@ def __str__(self) -> str: @classmethod def as_return_tuple(cls, return_types): - """Concatenate multiple PyTypes and wrap in tuple if more than one. + """Concatenate multiple annotations and wrap in tuple if more than one. Useful to combine multiple returned types for a function into a single - PyType. + annotation. Parameters ---------- - return_types : Iterable[PyType] + return_types : Iterable[Annotation] The types to combine. Returns ------- - concatenated : PyType + concatenated : Annotation The concatenated types. 
""" - values, imports = cls._aggregate_pytypes(*return_types) + values, imports = cls._aggregate_annotations(*return_types) value = " , ".join(values) if len(values) > 1: value = f"tuple[{value}]" @@ -71,31 +76,31 @@ def as_yields_generator(cls, yield_types, receive_types=()): Parameters ---------- - yield_types : Iterable[PyType] + yield_types : Iterable[Annotation] The types to yield. - receive_types : Iterable[PyType], optional + receive_types : Iterable[Annotation], optional The types the generator receives. Returns ------- - iterator : PyType + iterator : Annotation The yielded and received types wrapped in a generator. """ # TODO raise NotImplementedError() @staticmethod - def _aggregate_pytypes(*types): - """Aggregate values and imports of given PyTypes. + def _aggregate_annotations(*types): + """Aggregate values and imports of given Annotations. Parameters ---------- - types : Iterable[PyType] + types : Iterable[Annotation] Returns ------- values : list[str] - imports : set[~.DocName] + imports : set[~.KnownImport] """ values = [] imports = set() @@ -105,25 +110,41 @@ def _aggregate_pytypes(*types): return values, imports -class MatchedName(lark.Token): - pass +ErrorFallbackAnnotation = Annotation( + value="ErrorFallback", + imports=frozenset( + ( + KnownImport( + import_name="Any", + import_path="typing", + import_alias="ErrorFallback", + ), + ) + ), +) + + +class KnownName(lark.Token): + """Wrapper token signaling that a type name was matched to a known import.""" @lark.visitors.v_args(tree=True) class DoctypeTransformer(lark.visitors.Transformer): """Transformer for docstring type descriptions (doctypes).""" - def __init__(self, *, inspector, **kwargs): + def __init__(self, *, inspector, replace_doctypes, **kwargs): """ Parameters ---------- inspector : ~.StaticInspector A dictionary mapping atomic names used in doctypes to information such as where to import from or how to replace the name itself. 
+ replace_doctypes : dict[str, str] kwargs : dict[Any, Any] Keyword arguments passed to the init of the parent class. """ self.inspector = inspector + self.replace_doctypes = replace_doctypes self._collected_imports = None super().__init__(**kwargs) @@ -161,15 +182,17 @@ def transform(self, tree): Returns ------- - pytype : PyType + annotation : Annotation The doctype formatted as a stub-file compatible string with necessary imports attached. """ try: self._collected_imports = set() value = super().transform(tree=tree) - pytype = PyType(value=value, imports=frozenset(self._collected_imports)) - return pytype + annotation = Annotation( + value=value, imports=frozenset(self._collected_imports) + ) + return annotation finally: self._collected_imports = None @@ -197,37 +220,33 @@ def sphinx_ref(self, tree): qualname = _find_one_token(tree, name="QUALNAME") return qualname + def container(self, tree): + _container, *_content = tree.children + _content = ", ".join(_content) + assert _content + out = f"{_container}[{_content}]" + return out + def qualname(self, tree): children = tree.children + _qualname = ".".join(children) - # Try to match only first child to known imports - children[0] = self._match_n_record_name(children[0]) - matched = isinstance(children[0], MatchedName) - - # Insert dots except for containers (when child starts with "[") - out = [children[0]] - for child in children[1:]: - if not child.startswith("["): - out.append(".") - out.append(child) - - out = "".join(out) - if matched is False: - docname = self.inspector.query(out) - if docname: - out = docname.use_name - self._collected_imports.add(docname) - else: - logger.warning( - "unmatched name %r in %s", out, self.inspector.current_source - ) + for partial_qualname in accumulate_qualname(_qualname): + replacement = self.replace_doctypes.get(partial_qualname) + if replacement: + _qualname = _qualname.replace(partial_qualname, replacement) + break - out = lark.Token("QUALNAME", out) - return out + 
_qualname = self._find_import(_qualname) + + _qualname = lark.Token(type="QUALNAME", value=_qualname) + return _qualname def ARRAY_NAME(self, token): - new_token = self._match_n_record_name(token) - new_token.type = "ARRAY_NAME" + assert "." not in token + new_token = self.replace_doctypes.get(str(token), str(token)) + new_token = self._find_import(new_token) + new_token = lark.Token(type="ARRAY_NAME", value=new_token) return new_token def shape(self, tree): @@ -249,128 +268,101 @@ def contains(self, tree): def literals(self, tree): out = " , ".join(tree.children) out = f"Literal[{out}]" - self._collected_imports.add(self.inspector.query("Literal")) + _, known_import = self.inspector.query("Literal") + self._collected_imports.add(known_import) return out - def _match_n_record_name(self, token): + def _find_import(self, qualname): """Match type names to known imports.""" - assert "." not in token - docname = self.inspector.query(token) - if docname: - token = MatchedName(token.type, value=docname.use_name) - if docname.has_import: - self._collected_imports.add(docname) - return token - + try: + qualname, known_import = self.inspector.query(qualname) + if known_import: + if known_import.has_import: + self._collected_imports.add(known_import) + else: + logger.warning( + "unknown import for %r in %s", + qualname, + self.inspector.current_source, + ) + return qualname + except Exception as error: + raise error -with grammar_path.open() as file: - _grammar = file.read() -_lark = lark.Lark(_grammar) +class DocstringAnnotations: + def __init__(self, docstring, *, transformer): + self.docstring = docstring + self.np_docstring = NumpyDocString(docstring) + self.transformer = transformer + def _doctype_to_annotation(self, doctype): + """Convert a type description to a Python-ready type. -def doc2pytype(doctype, *, inspector): - """Convert a type description to a Python-ready type. 
- - Parameters - ---------- - doctype : str - The type description of a parameter or return value, as extracted from - a docstring. - inspector : docstub._analysis.StaticInspector - - Returns - ------- - pytype : PyType - The transformed type, ready to be inserted into a stub file, with - necessary imports attached. - """ - try: - transformer = DoctypeTransformer(inspector=inspector) - tree = _lark.parse(doctype) - pytype = transformer.transform(tree) - return pytype - except Exception: - logger.exception("couldn't parse docstring %r:", doctype) - return PyType( - value="Any", - imports={DocName.one_from_config("Any", info={"from": "typing"})}, - ) + Parameters + ---------- + doctype : str + The type description of a parameter or return value, as extracted from + a docstring. + inspector : docstub._analysis.StaticInspector + replace_doctypes : dict[str, str] + Returns + ------- + annotation : Annotation + The transformed type, ready to be inserted into a stub file, with + necessary imports attached. 
+ """ + try: + tree = _lark.parse(doctype) + annotation = self.transformer.transform(tree) + return annotation + except lark.visitors.VisitError as e: + logger.exception("couldn't parse doctype: %r", doctype, exc_info=e.orig_exc) + return ErrorFallbackAnnotation + except Exception: + logger.exception("couldn't parse doctype: %r", doctype) + return ErrorFallbackAnnotation + + @property + def parameters(self) -> dict[str, Annotation]: + def name_and_type(numpydoc_section): + name_type = { + param.name: param.type + for param in self.np_docstring[numpydoc_section] + if param.type + } + return name_type + + params = name_and_type("Parameters") + other = name_and_type("Other Parameters") + + duplicate_params = params.keys() & other.keys() + if duplicate_params: + raise ValueError(f"{duplicate_params=}") + params.update(other) + + annotations = { + name: self._doctype_to_annotation(type_) for name, type_ in params.items() + } + return annotations + + @property + def returns(self) -> Annotation | None: + out = [ + self._doctype_to_annotation(param.type) + for param in self.np_docstring["Returns"] + if param.type + ] + out = Annotation.as_return_tuple(out) if out else None + return out -@dataclass(frozen=True, slots=True) -class DocstringPyTypes: - """Groups Pytypes in a docstring.""" - - parameters: dict[str, PyType] - attributes: dict[str, PyType] - returns: PyType | None - yields: PyType | None - - -def collect_pytypes(docstring, *, inspector): - """Collect PyTypes from a docstring. - - Parameters - ---------- - docstring : str - The docstring to collect from. - inspector : docstub._analysis.StaticInspector - - Returns - ------- - pytypes : DocstringPyTypes - The collected PyTypes grouped by parameters, attributes, returns, and - yields. 
- """ - np_docstring = NumpyDocString(docstring) - - params = {p.name: p for p in np_docstring["Parameters"]} - other = {p.name: p for p in np_docstring["Other Parameters"]} - - duplicate_params = params.keys() & other.keys() - if duplicate_params: - raise ValueError(f"{duplicate_params=}") - params.update(other) - - parameters = { - name: doc2pytype(param.type, inspector=inspector) - for name, param in params.items() - if param.type - } - - returns = [ - doc2pytype(param.type, inspector=inspector) - for param in np_docstring["Returns"] - if param.type - ] - returns = PyType.as_return_tuple(returns) if returns else None - - yields = [ - doc2pytype(param.type, inspector=inspector) - for param in np_docstring["Yields"] - if param.type - ] - receives = [ - doc2pytype(param.type, inspector=inspector) - for param in np_docstring["Receives"] - if param.type - ] - attributes = [ - doc2pytype(param.type, inspector=inspector) - for param in np_docstring["Attributes"] - if param.type - ] - if returns and yields: - logger.warning( - "found 'Returns' and 'Yields' section in docstring, ignoring 'Yields'" - ) - if receives and not yields: - logger.warning("found 'Receives' section in docstring without 'Yields' section") - if yields: - logger.warning("yields is not supported yet") - - ds_pytypes = DocstringPyTypes( - parameters=parameters, attributes=attributes, returns=returns, yields=None - ) - return ds_pytypes + @property + def yields(self) -> Annotation | None: + out = { + self._doctype_to_annotation(param.type) + for param in self.np_docstring["Yields"] + if param.type + } + out = Annotation.as_return_tuple(out) if out else None + return out diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 97fb50a..95a45a1 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -1,6 +1,4 @@ -"""Transform Python source files to typed stub files. 
- -""" +"""Transform Python source files to typed stub files.""" import enum import logging @@ -10,7 +8,7 @@ import libcst as cst import libcst.matchers as cstm -from ._docstrings import collect_pytypes +from ._docstrings import DocstringAnnotations, DoctypeTransformer logger = logging.getLogger(__name__) @@ -201,13 +199,18 @@ class Py2StubTransformer(cst.CSTTransformer): _Annotation_Any = cst.Annotation(cst.Name("Any")) _Annotation_None = cst.Annotation(cst.Name("None")) - def __init__(self, *, inspector): + def __init__(self, *, inspector, replace_doctypes): """ Parameters ---------- inspector : ~._analysis.StaticInspector + replace_doctypes : dict[str, str] """ self.inspector = inspector + self.replace_doctypes = replace_doctypes + self.transformer = DoctypeTransformer( + inspector=inspector, replace_doctypes=replace_doctypes + ) # Relevant docstring for the current context self._scope_stack = None # Entered module, class or function scopes self._pytypes_stack = None # Collected pytypes for each stack @@ -313,13 +316,13 @@ def leave_FunctionDef(self, original_node, updated_node): "returns": self._Annotation_None, } - pytypes = self._pytypes_stack.pop() - if pytypes and pytypes.returns: - assert pytypes.returns.value + ds_annotations = self._pytypes_stack.pop() + if ds_annotations and ds_annotations.returns: + assert ds_annotations.returns.value node_changes["returns"] = cst.Annotation( - cst.parse_expression(pytypes.returns.value) + cst.parse_expression(ds_annotations.returns.value) ) - self._required_imports |= pytypes.returns.imports + self._required_imports |= ds_annotations.returns.imports updated_node = updated_node.with_changes(**node_changes) self._scope_stack.pop() @@ -360,7 +363,8 @@ def leave_Param(self, original_node, updated_node): # Potentially use "Any" except for first param in (class)methods elif not is_self_or_cls and updated_node.annotation is None: node_changes["annotation"] = self._Annotation_Any - 
self._required_imports.add(self.inspector.query("Any")) + _, known_import = self.inspector.query("Any") + self._required_imports.add(known_import) if updated_node.default is not None: node_changes["default"] = cst.Ellipsis() @@ -562,7 +566,10 @@ def _pytypes_from_node(self, node): docstring = node.get_docstring() if docstring: try: - pytypes = collect_pytypes(docstring, inspector=self.inspector) + pytypes = DocstringAnnotations( + docstring, + transformer=self.transformer, + ) except Exception as e: logger.exception( "error while parsing docstring of `%s`:\n\n%s", diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index e69de29..bcec015 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -0,0 +1,23 @@ +import itertools + + +def accumulate_qualname(qualname, *, start_right=False): + """Return possible partial names from a fully qualified one. + + Examples + -------- + >>> accumulate_qualname("a.b.c") + ('a', 'a.b', 'a.b.c') + >>> accumulate_qualname("a.b.c", start_right=True) + ('c', 'b.c', 'a.b.c') + """ + fragments = qualname.split(".") + if start_right is True: + fragments = reversed(fragments) + template = "{1}.{0}" + else: + template = "{0}.{1}" + out = tuple( + itertools.accumulate(fragments, func=lambda x, y: template.format(x, y)) + ) + return out diff --git a/src/docstub/default_config.toml b/src/docstub/default_config.toml index 95e3023..2c042e2 100644 --- a/src/docstub/default_config.toml +++ b/src/docstub/default_config.toml @@ -23,9 +23,10 @@ np = { import = "numpy", as = "np" } NDArray = { from = "numpy.typing" } ArrayLike = { from = "numpy.typing" } - -# Replace human-readable expressions with actual types. -[tool.docstub.replace] +# Specify human-friendly aliases that can be used instead of Python-parsable +# annotations. 
+# TODO rename to qualname_alias or something +[tool.docstub.replace_doctypes] function = "Callable" func = "Callable" callable = "Callable" diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 8ae49ba..053d6a4 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -8,6 +8,7 @@ types_or : type (("or" | "|") type)* ?type : qualname | sphinx_ref + | container | shape_n_dtype optional : "optional" @@ -17,14 +18,14 @@ extra_info : /[^\r\n]+/ sphinx_ref : ":" (NAME ":")? NAME ":`" qualname "`" -// Name with leading dot separated path -qualname : (/~/ ".")? (NAME ".")* NAME contains? +container: qualname "[" types_or ("," types_or)* "]" + | qualname "[" types_or "," PY_ELLIPSES "]" + | qualname "of" type + | qualname "of" "(" types_or ("," types_or)* ")" + | qualname "of" "{" types_or ":" types_or "}" -contains: "[" types_or ("," types_or)* "]" - | "[" types_or "," PY_ELLIPSES "]" - | "of" type - | "of" "(" types_or ("," types_or)* ")" - | "of" "{" types_or ":" types_or "}" +// Name with leading dot separated path +qualname : (/~/ ".")? (NAME ".")* NAME // Array-like form with dtype or shape information shape_n_dtype : shape? ARRAY_NAME ("of" dtype)? 
diff --git a/tests/test_analysis.py b/tests/test_analysis.py new file mode 100644 index 0000000..8da4d13 --- /dev/null +++ b/tests/test_analysis.py @@ -0,0 +1,62 @@ +import pytest + +from docstub._analysis import KnownImport, StaticInspector + + +class Test_StaticInspector: + known_imports = { # noqa: RUF012 + "dict": KnownImport(builtin_name="dict"), + "np": KnownImport(import_name="numpy", import_alias="np"), + "foo.bar": KnownImport(import_path="foo", import_name="bar"), + "foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz"), + "foo.bar.Baz.Bix": KnownImport(import_path="foo.bar", import_name="Baz"), + "foo.bar.Baz.Qux": KnownImport(import_path="foo", import_name="bar"), + } + + # fmt: off + @pytest.mark.parametrize( + ("name", "exp_annotation", "exp_import_line"), + [ + ("np", "np", "import numpy as np"), + # Finds imports whose import target matches the start of `name` + ("np.doesnt_exist", "np.doesnt_exist", "import numpy as np"), + + ("foo.bar.Baz", "Baz", "from foo.bar import Baz"), + # Finds "Baz" with abbreviated form as well + ( "~.bar.Baz", "Baz", "from foo.bar import Baz"), + ( "~.Baz", "Baz", "from foo.bar import Baz"), + + # Finds nested class "Baz.Bix" + ("foo.bar.Baz.Bix", "Baz.Bix", "from foo.bar import Baz"), + ( "~.bar.Baz.Bix", "Baz.Bix", "from foo.bar import Baz"), + ( "~.Baz.Bix", "Baz.Bix", "from foo.bar import Baz"), + ( "~.Bix", "Baz.Bix", "from foo.bar import Baz"), + + # Finds nested class "Baz.Gul" that's not explicitly defined, but + # whose import target matches "Baz" + ("foo.bar.Baz.Gul", "Baz.Gul", "from foo.bar import Baz"), + # but abbreviated form doesn't work + ( "~.bar.Baz.Gul", None, None), + ( "~.Baz.Gul", None, None), + ( "~.Gul", None, None), + + # Finds nested class "bar.Baz.Qux" (import defines module as target) + ("foo.bar.Baz.Qux", "bar.Baz.Qux", "from foo import bar"), + ( "~.bar.Baz.Qux", "bar.Baz.Qux", "from foo import bar"), + ( "~.Baz.Qux", "bar.Baz.Qux", "from foo import bar"), + ( "~.Qux", 
"bar.Baz.Qux", "from foo import bar"), + ] + ) + def test_query(self, name, exp_annotation, exp_import_line): + inspector = StaticInspector(known_imports=self.known_imports.copy()) + + annotation, known_import = inspector.query(name) + + if exp_annotation is None and exp_import_line is None: + assert exp_annotation is annotation + assert exp_import_line is known_import + else: + assert str(known_import) == exp_import_line + assert annotation.startswith(known_import.target) + assert annotation == exp_annotation + # fmt: on