diff --git a/examples/docstub.toml b/examples/docstub.toml index 8850bdc..a16862e 100644 --- a/examples/docstub.toml +++ b/examples/docstub.toml @@ -5,14 +5,16 @@ extend_grammar = """ """ -# A mapping of docnames to import information. Each item maps a docname on the -# left side to a dictionary on the right side, which supports the following -# fields: -# use : A string to replace the docname with, defaults to the docname. -# from : Indicate that the docname can be imported from this path. -# import : Import this object, defaults to the docname. +# Import information for type annotations, declared ahead of time. +# +# Each item maps an annotation name on the left side to a dictionary on the +# right side. +# +# Import information can be declared with the following fields: +# from : Indicate that the DocType can be imported from this path. +# import : Import this object, defaults to the DocType. # as : Use this alias for the imported object -# is_builtin : Indicate that this docname doesn't need to be imported, +# is_builtin : Indicate that this DocType doesn't need to be imported, # defaults to "false" -[tool.docstub.docnames] +[tool.docstub.known_imports] configparser = {import = "configparser"} diff --git a/examples/example_pkg-stubs/__init__.pyi b/examples/example_pkg-stubs/__init__.pyi index 5797999..f1a41c0 100644 --- a/examples/example_pkg-stubs/__init__.pyi +++ b/examples/example_pkg-stubs/__init__.pyi @@ -1,8 +1,10 @@ import _numpy as np_ -from _basic import func_comment, func_contains +from _basic import func_contains __all__ = [ - "func_comment", "func_contains", "np_", ] + +class CustomException(Exception): + pass diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 1d7e2db..58801b5 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -3,31 +3,48 @@ import logging from collections.abc import Sequence from typing import Any, Literal, Self, Union -logger = logging.getLogger(__name__) +from . import CustomException + +logger = ... __all__ = [ "func_empty", "ExampleClass", ] -def func_empty(a1: Any, a2: Any, a3: Any) -> None: ... +def func_empty(a1, a2, a3) -> None: ... def func_contains( - self, a1: list[float], a2: dict[str, Union[int, str]], a3: Sequence[int | float], a4: frozenset[bytes], -) -> tuple[tuple[int, ...], list[int]]: ... + a5: tuple[int], + a6: list[int, str], + a7: dict[str, int], +) -> None: ... def func_literals( a1: Literal[1, 3, "foo"], a2: Literal["uno", 2, "drei", "four"] = ... ) -> None: ... +def func_use_from_elsewhere( + a1: CustomException, + a2: ExampleClass, + a3: CustomException.NestedClass, + a4: ExampleClass.NestedClass, +) -> tuple[CustomException, ExampleClass.NestedClass]: ... class ExampleClass: - def __init__(self, a1: int, a2: float | None = ...) -> None: ... + class NestedClass: + def method_in_nested_class(self, a1: complex) -> None: ... + + def __init__(self, a1: str, a2: float = ...) -> None: ... def method(self, a1: float, a2: float | None) -> list[float]: ... @staticmethod def some_staticmethod(a1: float, a2: float | None = ...) -> dict[str, Any]: ... @property def some_property(self) -> str: ... + @some_property.setter + def some_property(self, value: str) -> None: ... @classmethod def method_returning_cls(cls, config: configparser.ConfigParser) -> Self: ... + @classmethod + def method_returning_cls2(cls, config: configparser.ConfigParser) -> Self: ... diff --git a/examples/example_pkg-stubs/_numpy.pyi b/examples/example_pkg-stubs/_numpy.pyi index be3c721..ec1a695 100644 --- a/examples/example_pkg-stubs/_numpy.pyi +++ b/examples/example_pkg-stubs/_numpy.pyi @@ -2,7 +2,7 @@ import numpy as np from numpy.typing import ArrayLike, NDArray def func_object_with_numpy_objects( - a1: np.np.int8, a2: np.np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike + a1: np.int8, a2: np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike ) -> None: ... def func_ndarray( a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] = ... diff --git a/examples/example_pkg/__init__.py b/examples/example_pkg/__init__.py index 2d40e24..ac61e3d 100644 --- a/examples/example_pkg/__init__.py +++ b/examples/example_pkg/__init__.py @@ -1,10 +1,13 @@ """Example of an init file.""" import _numpy as np_ -from _basic import func_comment, func_contains +from _basic import func_contains __all__ = [ - "func_comment", "func_contains", "np_", ] + + +class CustomException(Exception): + pass diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 3933b94..cdaa5a2 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -26,7 +26,7 @@ def func_empty(a1, a2, a3): """ -def func_contains(self, a1, a2, a3, a4): +def func_contains(a1, a2, a3, a4, a5, a6, a7): """Dummy. Parameters @@ -35,11 +35,9 @@ def func_contains(self, a1, a2, a3, a4): a2 : dict[str, Union[int, str]] a3 : Sequence[int | float] a4 : frozenset[bytes] - - Returns - ------- - r1 : tuple of int - r2 : list of int + a5 : tuple of int + a6 : list of (int, str) + a7 : dict of {str: int} """ @@ -53,16 +51,44 @@ def func_literals(a1, a2="uno"): """ +def func_use_from_elsewhere(a1, a2, a3, a4): + """Check if types with full import names are matched. + + Parameters + ---------- + a1 : example_pkg.CustomException + a2 : ExampleClass + a3 : example_pkg.CustomException.NestedClass + a4 : ExampleClass.NestedClass + + Returns + ------- + r1 : ~.CustomException + r2 : ~.NestedClass + """ + + class ExampleClass: - # TODO also take into account class level docstring + """Dummy. - def __init__(self, a1, a2=None): - """ - Parameters - ---------- - a1 : int - a2 : float, optional - """ + Parameters + ---------- + a1 : str + a2 : float, default 0 + """ + + class NestedClass: + + def method_in_nested_class(self, a1): + """ + + Parameters + ---------- + a1 : complex + """ + + def __init__(self, a1, a2=0): + pass def method(self, a1, a2): """Dummy. @@ -101,6 +127,15 @@ def some_property(self): """ return str(self) + @some_property.setter + def some_property(self, value): + """Dummy + + Parameters + ---------- + value : str + """ + @classmethod def method_returning_cls(cls, config): """Using `Self` in context of classmethods is supported. @@ -115,3 +150,18 @@ def method_returning_cls(cls, config): out : Self New class. """ + + @classmethod + def method_returning_cls2(cls, config): + """Using `Self` in context of classmethods is supported. + + Parameters + ---------- + config : configparser.ConfigParser + Configuation. + + Returns + ------- + out : Self + New class. + """ diff --git a/examples/example_pkg/_numpy.py b/examples/example_pkg/_numpy.py index 524af84..33fa9ca 100644 --- a/examples/example_pkg/_numpy.py +++ b/examples/example_pkg/_numpy.py @@ -19,7 +19,7 @@ def func_ndarray(a1, a2, a3, a4=None): Parameters ---------- a1 : ndarray - a2 : np.ndarray + a2 : np.NDArray a3 : (N, 3) ndarray of float a4 : ndarray of shape (1,) and dtype uint8 diff --git a/pyproject.toml b/pyproject.toml index 58943e1..e258613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ optional = [ ] dev = [ "pre-commit >=3.7", + "ipython", ] test = [ "pytest >=5.0.0", @@ -87,3 +88,8 @@ ignore = [ "RET504", # Assignment before `return` statement facilitates debugging "PTH123", # Using builtin open() instead of Path.open() is fine ] + + +[tool.docstub.docnames] +cst = {import = "libcst", as="cst"} +lark = {import = "lark"} diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index ee851b0..750ebba 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -2,59 +2,180 @@ import builtins import collections.abc +import logging +import re import typing from dataclasses import dataclass +from pathlib import Path + +import libcst as cst + +from ._utils import accumulate_qualname + +logger = logging.getLogger(__name__) + + +def _shared_leading_path(*paths): + """Identify the common leading parts between import paths. + + Parameters + ---------- + *paths : tuple[str] + + Returns + ------- + shared : str + """ + if len(paths) < 2: + raise ValueError("need more than two paths") + splits = (p.split(".") for p in paths) + shared = [] + for paths in zip(*splits, strict=False): + if all(paths[0] == p for p in paths): + shared.append(paths[0]) + else: + break + return ".".join(shared) @dataclass(slots=True, frozen=True) -class DocName: - """An atomic name (without ".") in a docstring type with import info.""" +class KnownImport: + """Import information associated with a single known type annotation. + + Parameters + ---------- + import_name : + Dotted names after "import". + import_path : + Dotted names after "from". + import_alias : + Name (without ".") after "as". + builtin_name : + Names an object that's builtin and doesn't need an import. + """ - use_name: str import_name: str = None import_path: str = None import_alias: str = None - is_builtin: bool = False + builtin_name: str = None @classmethod - def from_cfg(cls, docname: str, spec: dict): - use_name = docname - if "import" in spec: - use_name = spec["import"] - if "as" in spec: - use_name = spec["as"] - if "use" in spec: - use_name = spec["use"] - - import_name = docname - if "use" in spec: - import_name = spec["use"] - if "import" in spec: - import_name = spec["import"] - - docname = cls( - use_name=use_name, - import_name=import_name, - import_path=spec.get("from"), - import_alias=spec.get("as"), - is_builtin=spec.get("builtin", False), - ) - return docname + def one_from_config(cls, name, *, info): + """Create one KnownImport from the configuration format. + + Parameters + ---------- + name : str + info : dict[{"from", "import", "as", "is_builtin"}, str] + + Returns + ------- + TypeImport : Self + """ + assert not (info.keys() - {"from", "import", "as", "is_builtin"}) + + if info.get("is_builtin"): + known_import = cls(builtin_name=name) + else: + import_name = name + if "import" in info: + import_name = info["import"] + + known_import = cls( + import_name=import_name, + import_path=info.get("from"), + import_alias=info.get("as"), + ) + if not name.startswith(known_import.target): + raise ValueError( + f"{name!r} doesn't start with {known_import.target!r}", + ) - def format_import(self): - if self.is_builtin: + return known_import + + @classmethod + def many_from_config(cls, mapping): + """Create many KnownImports from the configuration format. + + Parameters + ---------- + mapping : dict[str, dict[{"from", "import", "as", "is_builtin"}, str]] + + Returns + ------- + known_imports : dict[str, Self] + """ + known_imports = { + name: cls.one_from_config(name, info=info) for name, info in mapping.items() + } + return known_imports + + def format_import(self, relative_to=None): + if self.builtin_name: msg = "cannot import builtin" raise RuntimeError(msg) out = f"import {self.import_name}" - if self.import_path: - out = f"from {self.import_path} {out}" + + import_path = self.import_path + if import_path: + if relative_to: + shared = _shared_leading_path(relative_to, import_path) + if shared == import_path: + import_path = "." + else: + import_path = self.import_path.replace(shared, "") + + out = f"from {import_path} {out}" if self.import_alias: out = f"{out} as {self.import_alias}" return out + @property + def target(self) -> str: + if self.import_alias: + out = self.import_alias + elif self.import_name: + out = self.import_name + elif self.builtin_name: + out = self.builtin_name + else: + raise RuntimeError("cannot determine import target") + return out + @property def has_import(self): - return not self.is_builtin + return self.builtin_name is None + + def __post_init__(self): + if self.builtin_name is not None: + if ( + self.import_name is not None + or self.import_alias is not None + or self.import_path is not None + ): + raise ValueError("builtin cannot contain import information") + elif self.import_name is None: + raise ValueError("non bultin must at least define an `import_name`") + + def __repr__(self): + if self.builtin_name: + info = f"{self.target} (builtin)" + else: + info = f"{self.format_import()!r}" + out = f"<{type(self).__name__} {info}>" + return out + + def __str__(self): + out = self.format_import() + return out + + +@dataclass(slots=True, frozen=True) +class InspectionContext: + """Currently inspected module and other information.""" + + file_path: Path + in_package_path: str def _is_type(value) -> bool: @@ -66,80 +187,233 @@ def _is_type(value) -> bool: return is_type -def _builtin_docnames(): - """Return docnames for all builtins (in the current runtime). +def _builtin_imports(): + """Return known imports for all builtins (in the current runtime). Returns ------- - docnames : dict[str, DocName] + known_imports : dict[str, KnownImport] """ known_builtins = set(dir(builtins)) - docnames = {} + known_imports = {} for name in known_builtins: if name.startswith("_"): continue value = getattr(builtins, name) if not _is_type(value): continue - docnames[name] = DocName(use_name=name, is_builtin=True) + known_imports[name] = KnownImport(builtin_name=name) - return docnames + return known_imports -def _typing_docnames(): - """Return docnames for public types in the `typing` module. +def _typing_imports(): + """Return known imports for public types in the `typing` module. Returns ------- - docnames : dict[str, DocName] + known_imports : dict[str, KnownImport] """ - docnames = {} + known_imports = {} for name in typing.__all__: if name.startswith("_"): continue value = getattr(typing, name) if not _is_type(value): continue - docnames[name] = DocName.from_cfg(name, spec={"from": "typing"}) - return docnames + known_imports[name] = KnownImport.one_from_config(name, info={"from": "typing"}) + return known_imports -def _collections_abc_docnames(): - """Return docnames for public types in the `collections.abc` module. +def _collections_abc_imports(): + """Return known imports for public types in the `collections.abc` module. Returns ------- - docnames : dict[str, DocName] + known_imports : dict[str, KnownImport] """ - docnames = {} + known_imports = {} for name in collections.abc.__all__: if name.startswith("_"): continue value = getattr(collections.abc, name) if not _is_type(value): continue - docnames[name] = DocName.from_cfg(name, spec={"from": "collections.abc"}) - return docnames + known_imports[name] = KnownImport.one_from_config( + name, info={"from": "collections.abc"} + ) + return known_imports -def common_docnames(): - """Return docnames for commonly supported types. +def common_known_imports(): + """Return known imports for commonly supported types. This includes builtin types, and types from the `typing` or `collections.abc` module. Returns ------- - docnames : dict[str, DocName] + known_imports : dict[str, KnownImport] """ - docnames = _builtin_docnames() - docnames |= _typing_docnames() - docnames |= _collections_abc_docnames() # Overrides containers from typing - return docnames + known_imports = _builtin_imports() + known_imports |= _typing_imports() + known_imports |= _collections_abc_imports() # Overrides containers from typing + return known_imports + + +class KnownImportCollector(cst.CSTVisitor): + @classmethod + def collect(cls, file, module_name): + file = Path(file) + with file.open("r") as fo: + source = fo.read() + + tree = cst.parse_module(source) + collector = cls(module_name=module_name) + tree.visit(collector) + return collector.known_imports + + def __init__(self, *, module_name): + self.module_name = module_name + self._stack = [] + self.known_imports = {} + + def visit_ClassDef(self, node): + self._stack.append(node.name.value) + + class_name = ".".join(self._stack[:1]) + qualname = f"{self.module_name}.{'.'.join(self._stack)}" + + known_import = KnownImport( + import_name=class_name, + import_path=self.module_name, + ) + self.known_imports[qualname] = known_import + + return True + + def leave_ClassDef(self, original_node): + self._stack.pop() + + def visit_FunctionDef(self, node): + self._stack.append(node.name.value) + return True + + def leave_FunctionDef(self, original_node): + self._stack.pop() + + +class StaticInspector: + """Static analysis of Python packages. + + Attributes + ---------- + current_source : ~.PackageFile | None + + Examples + -------- + >>> from docstub._analysis import StaticInspector, common_known_imports + >>> inspector = StaticInspector(known_imports=common_known_imports()) + >>> inspector.query("Any") + ('Any', ) + """ + + def __init__( + self, + *, + source_pkgs=None, + known_imports=None, + ): + """ + Parameters + ---------- + source_pkgs: list[Path], optional + known_imports: dict[str, KnownImport], optional + """ + if source_pkgs is None: + source_pkgs = [] + if known_imports is None: + known_imports = {} + + self.current_source = None + self.source_pkgs = source_pkgs + + self.known_imports = known_imports + + def query(self, search_name): + """Search for a known annotation name. + + Parameters + ---------- + search_name : str + + Returns + ------- + annotation_name : str | None + If it was found, the name of the annotation that matches the `known_import`. + known_import : KnownImport | None + If it was found, import information matching the `annotation_name`. + """ + annotation_name = None + known_import = None + + if search_name.startswith("~."): + # Sphinx like matching with abbreviated name + pattern = search_name.replace(".", r"\.") + pattern = pattern.replace("~", ".*") + pattern = re.compile(pattern + "$") + # Might be slow, but works for now + matches = { + key: value + for key, value in self.known_imports.items() + if re.match(pattern, key) + } + if len(matches) > 1: + shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0] + known_import = matches[shortest_key] + annotation_name = shortest_key + logger.warning( + "%r in %s matches multiple types %r, using %r", + search_name, + self.current_source, + matches.keys(), + shortest_key, + ) + elif len(matches) == 1: + annotation_name, known_import = matches.popitem() + else: + logger.debug( + "couldn't match %r in %s", search_name, self.current_source + ) + + if known_import is None and self.current_source: + # Try scope of current module + try_qualname = f"{self.current_source.import_path}.{search_name}" + known_import = self.known_imports.get(try_qualname) + annotation_name = search_name + + if known_import is None: + # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a') + for partial_qualname in reversed(accumulate_qualname(search_name)): + known_import = self.known_imports.get(partial_qualname) + if known_import: + annotation_name = search_name + break + + if ( + known_import is not None + and annotation_name is not None + and annotation_name != known_import.target + and annotation_name.startswith(known_import.target) + ): + # Ensure that the annotation matches the import target + annotation_name = annotation_name[ + annotation_name.find(known_import.target) : + ] + return annotation_name, known_import -def find_module_paths_using_search(): - # TODO use from mypy.stubgen import find_module_paths_using_search ? - # https://github.com/python/mypy/blob/66b48cbe97bf9c7660525766afe6d7089a984769/mypy/stubgen.py#L1526 - pass + def __repr__(self): + repr = f"{type(self).__name__}({self.source_pkgs})" + return repr diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 1038daf..651d9f2 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -4,9 +4,14 @@ import click -from . import _config -from ._analysis import DocName, common_docnames -from ._stubs import Py2StubTransformer, walk_python_package +from ._analysis import ( + KnownImport, + KnownImportCollector, + StaticInspector, + common_known_imports, +) +from ._config import Config +from ._stubs import Py2StubTransformer, walk_source, walk_source_and_targets from ._version import __version__ logger = logging.getLogger(__name__) @@ -15,6 +20,40 @@ _VERBOSITY_LEVEL = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} +def _find_configuration(source_dir, config_path): + """Find and load configuration from multiple possible sources. + + Parameters + ---------- + source_dir : Path + config_path : Path + + Returns + ------- + config : dict[str, Any] + """ + config = Config.from_toml(Config.DEFAULT_CONFIG_PATH) + + pyproject_toml = source_dir.parent / "pyproject.toml" + if pyproject_toml.is_file(): + logger.info("using %s", pyproject_toml) + add_config = Config.from_toml(pyproject_toml) + config = config.merge(add_config) + + docstub_toml = source_dir.parent / "docstub.toml" + if docstub_toml.is_file(): + logger.info("using %s", docstub_toml) + add_config = Config.from_toml(docstub_toml) + config = config.merge(add_config) + + if config_path: + logger.info("using %s", config_path) + add_config = Config.from_toml(config_path) + config = config.merge(add_config) + + return config + + @click.command() @click.version_option(__version__) @click.argument("source_dir", type=click.Path(exists=True, file_okay=False)) @@ -26,39 +65,37 @@ def main(source_dir, out_dir, config_path, verbose): verbose = min(2, max(0, verbose)) # Limit to range [0, 2] logging.basicConfig( level=_VERBOSITY_LEVEL[verbose], - format="%(levelname)s: %(filename)s, line %(lineno)d, in %(funcName)s: %(message)s", + format="%(levelname)s: %(filename)s#L%(lineno)d::%(funcName)s: %(message)s", stream=sys.stderr, ) source_dir = Path(source_dir) + config = _find_configuration(source_dir, config_path) - # Handle configuration - if config_path is None: - config_path = source_dir.parent / "docstub.toml" - else: - config_path = Path(config_path) - config = _config.default_config() - if config_path.exists(): - _user_config = _config.load_config_file(config_path) - config = _config.merge_config(config, _user_config) - - # Build docname map - docnames = common_docnames() - docnames.update( - { - name: DocName.from_cfg(docname=name, spec=spec) - for name, spec in config["docnames"].items() - } + # Build map of known imports + known_imports = common_known_imports() + for source_path in walk_source(source_dir): + logger.info("collecting types in %s", source_path) + known_imports_in_source = KnownImportCollector.collect( + source_path, module_name=source_path.import_path + ) + known_imports.update(known_imports_in_source) + known_imports.update(KnownImport.many_from_config(config.known_imports)) + + inspector = StaticInspector( + source_pkgs=[source_dir.parent.resolve()], known_imports=known_imports ) # and the stub transformer - stub_transformer = Py2StubTransformer(docnames=docnames) + stub_transformer = Py2StubTransformer( + inspector=inspector, replace_doctypes=config.replace_doctypes + ) if not out_dir: out_dir = source_dir.parent out_dir = Path(out_dir) / (source_dir.name + "-stubs") out_dir.mkdir(parents=True, exist_ok=True) - for source_path, stub_path in walk_python_package(source_dir, out_dir): + for source_path, stub_path in walk_source_and_targets(source_dir, out_dir): if source_path.suffix.lower() == ".pyi": logger.debug("using existing stub file %s", source_path) with source_path.open() as fo: @@ -68,7 +105,9 @@ def main(source_dir, out_dir, config_path, verbose): py_content = fo.read() logger.debug("creating stub from %s", source_path) try: - stub_content = stub_transformer.python_to_stub(py_content) + stub_content = stub_transformer.python_to_stub( + py_content, module_path=source_path + ) except Exception as e: logger.exception("failed creating stub for %s:\n\n%s", source_path, e) continue diff --git a/src/docstub/_config.py b/src/docstub/_config.py index 8f790ed..c1535ed 100644 --- a/src/docstub/_config.py +++ b/src/docstub/_config.py @@ -1,5 +1,7 @@ +import dataclasses import logging from pathlib import Path +from typing import ClassVar try: import tomllib @@ -10,30 +12,56 @@ logger = logging.getLogger(__name__) -DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.toml" +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Config: + DEFAULT_CONFIG_PATH: ClassVar[Path] = Path(__file__).parent / "default_config.toml" + extend_grammar: str = "" + known_imports: dict[str, dict[str, str]] = dataclasses.field(default_factory=dict) + replace_doctypes: dict[str, str] = dataclasses.field(default_factory=dict) -def load_config_file(path: Path | str) -> dict: - """Return configuration options in local TOML file if they exist.""" - with open(path, "rb") as fp: - config = tomllib.load(fp) - config = config.get("tool", {}).get("docstub", {}) - return config + _source: tuple[Path, ...] = () + @classmethod + def from_toml(cls, path: Path | str) -> "Config": + """Return configuration options in local TOML file if they exist.""" + path = Path(path) + with open(path, "rb") as fp: + raw = tomllib.load(fp) + config = cls(**raw.get("tool", {}).get("docstub", {}), _source=(path,)) + logger.debug("created Config from %s", path) + return config -def default_config(): - config = load_config_file(DEFAULT_CONFIG_PATH) - return config + @classmethod + def from_default(cls): + config = cls.from_toml(cls.DEFAULT_CONFIG_PATH) + return config + def merge(self, other): + """Merge contents with other and return a new Config instance.""" + if not isinstance(other, type(self)): + return NotImplemented + new = Config( + extend_grammar=self.extend_grammar + other.extend_grammar, + known_imports=self.known_imports | other.known_imports, + replace_doctypes=self.replace_doctypes | other.replace_doctypes, + _source=self._source + other._source, + ) + logger.debug("merged Config from %s", new._source) + return new -def merge_config(*configurations): - merged = {} - merged["extended_grammar"] = "\n".join( - cfg.get("extended_grammar", "") for cfg in configurations - ) - merged["docnames"] = {} - for cfg in configurations: - docnames = cfg.get("docnames") - if docnames and isinstance(docnames, dict): - merged["docnames"].update(docnames) - return merged + def to_dict(self): + return dataclasses.asdict(self) + + def __post_init__(self): + if not isinstance(self.extend_grammar, str): + raise TypeError("extended_grammar must be a string") + if not isinstance(self.known_imports, dict): + raise TypeError("known_imports must be a dict") + if not isinstance(self.replace_doctypes, dict): + raise TypeError("replace_doctypes must be a string") + + def __repr__(self): + sources = " | ".join(str(s) for s in self._source) + formatted = f"<{type(self).__name__}: {sources}>" + return formatted diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 215f760..221d25c 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -1,6 +1,4 @@ -"""Transform types defined in docstrings to Python parsable types. - -""" +"""Transform types defined in docstrings to Python parsable types.""" import logging from dataclasses import dataclass, field @@ -10,7 +8,8 @@ import lark.visitors from numpydoc.docscrape import NumpyDocString -from ._analysis import DocName +from ._analysis import KnownImport +from ._utils import accumulate_qualname logger = logging.getLogger(__name__) @@ -19,6 +18,12 @@ grammar_path = here / "doctype.lark" +with grammar_path.open() as file: + _grammar = file.read() + +_lark = lark.Lark(_grammar) + + def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token: """Find token with a specific type name in tree.""" tokens = [child for child in tree.children if child.type == name] @@ -29,11 +34,11 @@ def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token: @dataclass(frozen=True, slots=True) -class PyType: - """Python-ready type with attached import information.""" +class Annotation: + """Python-ready type annotation with attached import information.""" value: str - imports: set[DocName] | frozenset[DocName] = field(default_factory=frozenset) + imports: frozenset[KnownImport] = field(default_factory=frozenset) def __post_init__(self): object.__setattr__(self, "imports", frozenset(self.imports)) @@ -42,53 +47,104 @@ def __str__(self) -> str: return self.value @classmethod - def from_concatenated(cls, pytypes): - """Concatenate multiple PyTypes in a tuple. + def as_return_tuple(cls, return_types): + """Concatenate multiple annotations and wrap in tuple if more than one. Useful to combine multiple returned types for a function into a single - PyType. + annotation. Parameters ---------- - pytypes : Iterable[PyType] + return_types : Iterable[Annotation] The types to combine. Returns ------- - concatenated : PyType + concatenated : Annotation The concatenated types. """ + values, imports = cls._aggregate_annotations(*return_types) + value = " , ".join(values) + if len(values) > 1: + value = f"tuple[{value}]" + concatenated = cls(value=value, imports=imports) + return concatenated + + @classmethod + def as_yields_generator(cls, yield_types, receive_types=()): + """Create new iterator type from yield and receive types. + + Parameters + ---------- + yield_types : Iterable[Annotation] + The types to yield. + receive_types : Iterable[Annotation], optional + The types the generator receives. + + Returns + ------- + iterator : Annotation + The yielded and received types wrapped in a generator. + """ + # TODO + raise NotImplementedError() + + @staticmethod + def _aggregate_annotations(*types): + """Aggregate values and imports of given Annotations. + + Parameters + ---------- + types : Iterable[Annotation] + + Returns + ------- + values : list[str] + imports : set[~.KnownImport] + """ values = [] imports = set() - for p in pytypes: + for p in types: values.append(p.value) imports.update(p.imports) - value = " , ".join(values) - if len(values) > 1: - value = f"tuple[{value}]" - joined = cls(value=value, imports=imports) - return joined + return values, imports + + +ErrorFallbackAnnotation = Annotation( + value="ErrorFallback", + imports=frozenset( + ( + KnownImport( + import_name="Any", + import_path="typing", + import_alias="ErrorFallback", + ), + ) + ), +) -class MatchedName(lark.Token): - pass +class KnownName(lark.Token): + """Wrapper token signaling that a type name was matched to a known import.""" @lark.visitors.v_args(tree=True) class DoctypeTransformer(lark.visitors.Transformer): - """Transformer for docstring type descriptions (doctypes). - - Parameters - ---------- - docnames : dict[str, DocName] - A dictionary mapping atomic names used in doctypes to information such - as where to import from or how to replace the name itself. - kwargs : dict[Any, Any] - Keyword arguments passed to the init of the parent class. - """ - - def __init__(self, *, docnames, **kwargs): - self.docnames = docnames + """Transformer for docstring type descriptions (doctypes).""" + + def __init__(self, *, inspector, replace_doctypes, **kwargs): + """ + Parameters + ---------- + inspector : ~.StaticInspector + A dictionary mapping atomic names used in doctypes to information such + as where to import from or how to replace the name itself. + replace_doctypes : dict[str, str] + kwargs : dict[Any, Any] + Keyword arguments passed to the init of the parent class. + """ + self.inspector = inspector + self.replace_doctypes = replace_doctypes self._collected_imports = None super().__init__(**kwargs) @@ -126,23 +182,25 @@ def transform(self, tree): Returns ------- - pytype : PyType + annotation : Annotation The doctype formatted as a stub-file compatible string with necessary imports attached. """ try: self._collected_imports = set() value = super().transform(tree=tree) - pytype = PyType(value=value, imports=frozenset(self._collected_imports)) - return pytype + annotation = Annotation( + value=value, imports=frozenset(self._collected_imports) + ) + return annotation finally: self._collected_imports = None - def doctype(self, tree): + def annotation(self, tree): out = " | ".join(tree.children) return out - def type_or(self, tree): + def types_or(self, tree): out = " | ".join(tree.children) return out @@ -150,35 +208,45 @@ def optional(self, tree): out = "None" literal = [child for child in tree.children if child.type == "LITERAL"] assert len(literal) <= 1 - if len(literal) == 1: - out = lark.Discard # Should be covered by doctype + if literal: + out = lark.Discard # Type should cover the default return out def extra_info(self, tree): logger.debug("dropping extra info") return lark.Discard - def qualname(self, tree): - matched = False - out = [] - for i, child in enumerate(tree.children): - if i != 0 and not child.startswith("["): - out.append(".") - if isinstance(child, MatchedName): - matched = True - out.append(child) - out = "".join(out) - if matched is False: - logger.warning("unmatched name %r", out) + def sphinx_ref(self, tree): + qualname = _find_one_token(tree, name="QUALNAME") + return qualname + + def container(self, tree): + _container, *_content = tree.children + _content = ", ".join(_content) + assert _content + out = f"{_container}[{_content}]" return out - def NAME(self, token): - new_token = self._match_n_record_name(token) - return new_token + def qualname(self, tree): + children = tree.children + _qualname = ".".join(children) + + for partial_qualname in accumulate_qualname(_qualname): + replacement = self.replace_doctypes.get(partial_qualname) + if replacement: + _qualname = _qualname.replace(partial_qualname, replacement) + break + + _qualname = self._find_import(_qualname) + + _qualname = lark.Token(type="QUALNAME", value=_qualname) + return _qualname def ARRAY_NAME(self, token): - new_token = self._match_n_record_name(token) - new_token.type = "ARRAY_NAME" + assert "." not in token + new_token = self.replace_doctypes.get(str(token), str(token)) + new_token = self._find_import(new_token) + new_token = lark.Token(type="ARRAY_NAME", value=new_token) return new_token def shape(self, tree): @@ -192,14 +260,6 @@ def shape_n_dtype(self, tree): name = f"{name}[{', '.join(children)}]" return name - def container_of(self, tree): - assert len(tree.children) == 2 - container_name, item_type = tree.children - if container_name == "tuple": - item_type += ", ..." - out = f"{container_name}[{item_type}]" - return out - def contains(self, tree): out = ", ".join(tree.children) out = f"[{out}]" @@ -208,104 +268,101 @@ def contains(self, tree): def literals(self, tree): out = " , ".join(tree.children) out = f"Literal[{out}]" - self._collected_imports.add(self.docnames["Literal"]) + _, known_import = self.inspector.query("Literal") + self._collected_imports.add(known_import) return out - def _match_n_record_name(self, token): + def _find_import(self, qualname): """Match type names to known imports.""" - assert "." not in token - if token in self.docnames: - docname = self.docnames[token] - token = MatchedName(token.type, value=docname.use_name) - if docname.has_import: - self._collected_imports.add(docname) - return token - - -with grammar_path.open() as file: - _grammar = file.read() - -_lark = lark.Lark(_grammar) - - -def doc2pytype(doctype, *, docnames): - """Convert a type description to a Python-ready type. - - Parameters - ---------- - doctype : str - The type description of a parameter or return value, as extracted from - a docstring. - docnames : dict[str, DocName] - A dictionary mapping atomic names used in doctypes to information such - as where to import from or how to replace the name itself. - - Returns - ------- - pytype : PyType - The transformed type, ready to be inserted into a stub file, with - necessary imports attached. - """ - try: - transformer = DoctypeTransformer(docnames=docnames) - tree = _lark.parse(doctype) - pytype = transformer.transform(tree) - return pytype - except Exception: - logger.exception("couldn't parse docstring %r:", doctype) - return PyType( - value="Any", imports={DocName.from_cfg("Any", {"from": "typing"})} - ) - - -class ReturnKey: - """Simple "singleton" key to access the return PyType in a dictionary. - - See :func:`collect_pytypes` for more. - """ - - -ReturnKey = ReturnKey() - - -def collect_pytypes(docstring, *, docnames): - """Collect PyTypes from a docstring. - - Parameters - ---------- - docstring : str - The docstring to collect from. - docnames : dict[str, DocName] - A dictionary mapping atomic names used in doctypes to information such - as where to import from or how to replace the name itself. - - Returns - ------- - pytypes : dict[str | Literal[ReturnKey], PyType] - The collected PyType for each parameter. If a return type is documented - it's saved under the special key :class:`ReturnKey`. - """ - np_docstring = NumpyDocString(docstring) - - params = {p.name: p for p in np_docstring["Parameters"]} - other = {p.name: p for p in np_docstring["Other Parameters"]} - duplicate_params = params.keys() & other.keys() - if duplicate_params: - raise ValueError(f"{duplicate_params=}") - params.update(other) + try: + qualname, known_import = self.inspector.query(qualname) + if known_import: + if known_import.has_import: + self._collected_imports.add(known_import) + else: + logger.warning( + "unknown import for %r in %s", + qualname, + self.inspector.current_source, + ) + return qualname + except Exception as error: + raise error + + +class DocstringAnnotations: + def __init__(self, docstring, *, transformer): + self.docstring = docstring + self.np_docstring = NumpyDocString(docstring) + self.transformer = transformer + + def _doctype_to_annotation(self, doctype): + """Convert a type description to a Python-ready type. - pytypes = { - name: doc2pytype(param.type, docnames=docnames) - for name, param in params.items() - if param.type - } + Parameters + ---------- + doctype : str + The type description of a parameter or return value, as extracted from + a docstring. + inspector : docstub._analysis.StaticInspector + replace_doctypes : dict[str, str] - returns = [ - doc2pytype(param.type, docnames=docnames) - for param in np_docstring["Returns"] - if param.type - ] - if returns: - pytypes[ReturnKey] = PyType.from_concatenated(returns) + Returns + ------- + annotation : Annotation + The transformed type, ready to be inserted into a stub file, with + necessary imports attached. + """ + try: + tree = _lark.parse(doctype) + annotation = self.transformer.transform(tree) + return annotation + except lark.visitors.VisitError as e: + logger.exception("couldn't parse doctype: %r", doctype, exc_info=e.orig_exc) + return ErrorFallbackAnnotation + except Exception: + logger.exception("couldn't parse doctype: %r", doctype) + return ErrorFallbackAnnotation + + @property + def parameters(self) -> dict[str, Annotation]: + def name_and_type(numpydoc_section): + name_type = { + param.name: param.type + for param in self.np_docstring[numpydoc_section] + if param.type + } + return name_type + + params = name_and_type("Parameters") + other = name_and_type("Other Parameters") + + duplicate_params = params.keys() & other.keys() + if duplicate_params: + raise ValueError(f"{duplicate_params=}") + params.update(other) + + annotations = { + name: self._doctype_to_annotation(type_) for name, type_ in params.items() + } + return annotations + + @property + def returns(self) -> Annotation | None: + out = [ + self._doctype_to_annotation(param.type) + for param in self.np_docstring["Returns"] + if param.type + ] + out = Annotation.as_return_tuple(out) if out else None + return out - return pytypes + @property + def yields(self) -> Annotation | None: + out = { + self._doctype_to_annotation(param.type) + for param in self.np_docstring["Yields"] + if param.type + } + out = Annotation.as_return_tuple(out) if out else None + return out diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index eb5beb7..95a45a1 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -1,20 +1,112 @@ -"""Transform Python source files to typed stub files. - -""" +"""Transform Python source files to typed stub files.""" +import enum import logging from dataclasses import dataclass -from typing import Literal +from pathlib import Path import libcst as cst +import libcst.matchers as cstm -from ._docstrings import ReturnKey, collect_pytypes +from ._docstrings import DocstringAnnotations, DoctypeTransformer logger = logging.getLogger(__name__) -def walk_python_package(root_dir, target_dir): - """Iterate modules in a Python package and it's target stub files. +class PackageFile(Path): + """File in a Python package.""" + + def __init__(self, *args, package_root): + """ + Parameters + ---------- + args : tuple[Any, ...] + package_root : Path + """ + self.package_root = package_root + super().__init__(*args) + if self.is_dir(): + raise ValueError("mustn't be a directory") + if not self.is_relative_to(self.package_root): + raise ValueError("path must be relative to package_root") + + @property + def import_path(self): + """ + Returns + ------- + str + """ + relative_to_root = self.relative_to(self.package_root) + parts = relative_to_root.with_suffix("").parts + parts = (self.package_root.name, *parts) + if parts[-1] == "__init__": + parts = parts[:-1] + import_name = ".".join(parts) + return import_name + + def with_segments(self, *args): + return Path(*args) + + +def _is_python_package(path): + """ + Parameters + ---------- + path : Path + + Returns + ------- + is_package : bool + """ + is_package = (path / "__init__.py").is_file() or (path / "__init__.pyi").is_file() + return is_package + + +def walk_source(root_dir): + """Iterate modules in a Python package and its target stub files. + + Parameters + ---------- + root_dir : Path + Root directory of a Python package. + target_dir : Path + Root directory in which a matching stub package will be created. + + Yields + ------ + source_path : PackageFile + Either a Python file or a stub file that takes precedence. + + Notes + ----- + Files starting with "test_" are skipped entirely for now. + """ + queue = [root_dir] + while queue: + path = queue.pop(0) + + if path.is_dir(): + if _is_python_package(path): + queue.extend(path.iterdir()) + else: + logger.debug("skipping directory %s", path) + continue + + assert path.is_file() + + suffix = path.suffix.lower() + if suffix not in {".py", ".pyi"}: + continue + if suffix == ".py" and path.with_suffix(".pyi").exists(): + continue # Stub file already exists and takes precedence + + python_file = PackageFile(path, package_root=root_dir) + yield python_file + + +def walk_source_and_targets(root_dir, target_dir): + """Iterate modules in a Python package and its target stub files. Parameters ---------- @@ -25,37 +117,19 @@ def walk_python_package(root_dir, target_dir): Returns ------- - source_path : Path + source_path : PackageFile Either a Python file or a stub file that takes precedence. - stub_path : Path + stub_path : PackageFile Target stub file. Notes ----- Files starting with "test_" are skipped entirely for now. """ - for root, _, files in root_dir.walk(top_down=True): - for name in files: - source_path = root / name - - if name.startswith("test_"): - logger.debug("skipping %s", name) - continue - - if source_path.suffix.lower() not in {".py", ".pyi"}: - continue - - if ( - source_path.suffix.lower() == ".py" - and source_path.with_suffix(".pyi").exists() - ): - # Stub file already exists and takes precedence - continue - - stub_path = target_dir / source_path.with_suffix(".pyi").relative_to( - root_dir - ) - yield source_path, stub_path + for source_path in walk_source(root_dir): + stub_path = target_dir / source_path.with_suffix(".pyi").relative_to(root_dir) + stub_path = PackageFile(stub_path, package_root=target_dir) + yield source_path, stub_path def try_format_stub(stub: str) -> str: @@ -75,31 +149,47 @@ def try_format_stub(stub: str) -> str: return stub +class FuncType(enum.Enum): + MODULE = enum.auto() + CLASS = enum.auto() + FUNC = enum.auto() + METHOD = enum.auto() + CLASSMETHOD = enum.auto() + STATICMETHOD = enum.auto() + + @dataclass(slots=True, frozen=True) class _Scope: - type: Literal["module", "class", "func", "method", "classmethod", "staticmethod"] + """""" + + type: FuncType node: cst.CSTNode = None - def __post_init__(self): - allowed_types = { - "module", - "class", - "func", - "method", - "classmethod", - "staticmethod", + @property + def has_self_or_cls(self): + return self.type in {FuncType.METHOD, FuncType.CLASSMETHOD} + + @property + def is_method(self): + return self.type in { + FuncType.METHOD, + FuncType.CLASSMETHOD, + FuncType.STATICMETHOD, } - if self.type not in allowed_types: - msg = f"type {self.type!r} is not allowed, allowed are {allowed_types!r}" - raise ValueError(msg) @property - def has_self_or_cls(self): - return self.type in {"method", "classmethod"} + def is_class_init(self): + out = self.is_method and self.node.name.value == "__init__" + return out class Py2StubTransformer(cst.CSTTransformer): - """Transform syntax tree of a Python file into the tree of a stub file.""" + """Transform syntax tree of a Python file into the tree of a stub file. + + Attributes + ---------- + inspector : ~._analysis.StaticInspector + """ # Equivalent to ` ...`, to replace the body of callables with _body_replacement = cst.SimpleStatementSuite( @@ -109,19 +199,45 @@ class Py2StubTransformer(cst.CSTTransformer): _Annotation_Any = cst.Annotation(cst.Name("Any")) _Annotation_None = cst.Annotation(cst.Name("None")) - def __init__(self, *, docnames): - self.docnames = docnames + def __init__(self, *, inspector, replace_doctypes): + """ + Parameters + ---------- + inspector : ~._analysis.StaticInspector + replace_doctypes : dict[str, str] + """ + self.inspector = inspector + self.replace_doctypes = replace_doctypes + self.transformer = DoctypeTransformer( + inspector=inspector, replace_doctypes=replace_doctypes + ) # Relevant docstring for the current context - self._scope_stack = None # Store current class or function scope - self._pytypes_stack = None # Store current parameter types + self._scope_stack = None # Entered module, class or function scopes + self._pytypes_stack = None # Collected pytypes for each stack self._required_imports = None # Collect imports for used types + self._current_module = None - def python_to_stub(self, source: str) -> str: - """Convert Python source code to stub-file ready code.""" + def python_to_stub(self, source, *, module_path=None): + """Convert Python source code to stub-file ready code. + + Parameters + ---------- + source : str + module_path : PackageFile, optional + The location of the source that is transformed into a stub file. + If given, used to enhance logging & error messages with more + context information. + + Returns + ------- + stub : str + """ try: self._scope_stack = [] self._pytypes_stack = [] self._required_imports = set() + if module_path: + self.inspector.current_source = module_path source_tree = cst.parse_module(source) stub_tree = source_tree.visit(self) @@ -129,66 +245,101 @@ def python_to_stub(self, source: str) -> str: stub = try_format_stub(stub) return stub finally: - assert len(self._scope_stack) == 0 - assert len(self._pytypes_stack) == 0 self._scope_stack = None self._pytypes_stack = None self._required_imports = None + self.inspector.current_source = None def visit_ClassDef(self, node): - self._scope_stack.append(_Scope(type="class", node=node)) + """Collect pytypes from class docstring and add scope to stack. + + Parameters + ---------- + node : cst.ClassDef + + Returns + ------- + out : Literal[True] + """ + self._scope_stack.append(_Scope(type=FuncType.CLASS, node=node)) + pytypes = self._pytypes_from_node(node) + self._pytypes_stack.append(pytypes) return True def leave_ClassDef(self, original_node, updated_node): + """Drop class scope from the stack. + + Parameters + ---------- + original_node : cst.ClassDef + updated_node : cst.ClassDef + + Returns + ------- + updated_node : cst.ClassDef + """ self._scope_stack.pop() + self._pytypes_stack.pop() return updated_node def visit_FunctionDef(self, node): - func_type = "func" - if self._scope_stack[-1].type == "class": - func_type = "method" - for decorator in node.decorators: - assert func_type in {"func", "method"} - if decorator.decorator.value == "classmethod": - func_type = "classmethod" - break - if decorator.decorator.value == "staticmethod": - func_type = "staticmethod" - break - self._scope_stack.append(_Scope(type=func_type, node=node)) + """Collect pytypes from function docstring and add scope to stack. - docstring = node.get_docstring() - pytypes = None - if docstring: - try: - pytypes = collect_pytypes(docstring, docnames=self.docnames) - except Exception as e: - logger.exception( - "error while parsing docstring of `%s`:\n\n%s", node.name.value, e - ) + Parameters + ---------- + node : cst.FunctionDef + + Returns + ------- + out : Literal[True] + """ + func_type = self._function_type(node) + self._scope_stack.append(_Scope(type=func_type, node=node)) + pytypes = self._pytypes_from_node(node) self._pytypes_stack.append(pytypes) return True def leave_FunctionDef(self, original_node, updated_node): + """Add type annotation for return to function. + + Parameters + ---------- + original_node : cst.FunctionDef + updated_node : cst.FunctionDef + + Returns + ------- + updated_node : cst.FunctionDef + """ node_changes = { "body": self._body_replacement, "returns": self._Annotation_None, } - pytypes = self._pytypes_stack.pop() - if pytypes: - return_pytype = pytypes.get(ReturnKey) - if return_pytype: - node_changes["returns"] = cst.Annotation( - cst.parse_expression(return_pytype.value) - ) - self._required_imports |= return_pytype.imports + ds_annotations = self._pytypes_stack.pop() + if ds_annotations and ds_annotations.returns: + assert ds_annotations.returns.value + node_changes["returns"] = cst.Annotation( + cst.parse_expression(ds_annotations.returns.value) + ) + self._required_imports |= ds_annotations.returns.imports updated_node = updated_node.with_changes(**node_changes) self._scope_stack.pop() return updated_node def leave_Param(self, original_node, updated_node): + """Add type annotation to parameter. + + Parameters + ---------- + original_node : cst.Param + updated_node : cst.Param + + Returns + ------- + updated_node : cst.Param + """ node_changes = {} scope = self._scope_stack[-1] @@ -199,8 +350,10 @@ def leave_Param(self, original_node, updated_node): name = original_node.name.value pytypes = self._pytypes_stack[-1] + if not pytypes and scope.is_class_init: + pytypes = self._pytypes_stack[-2] if pytypes: - pytype = pytypes.get(name) + pytype = pytypes.parameters.get(name) if pytype: annotation = cst.Annotation(cst.parse_expression(pytype.value)) node_changes["annotation"] = annotation @@ -210,7 +363,8 @@ def leave_Param(self, original_node, updated_node): # Potentially use "Any" except for first param in (class)methods elif not is_self_or_cls and updated_node.annotation is None: node_changes["annotation"] = self._Annotation_Any - self._required_imports.add(self.docnames["Any"]) + _, known_import = self.inspector.query("Any") + self._required_imports.add(known_import) if updated_node.default is not None: node_changes["default"] = cst.Ellipsis() @@ -219,34 +373,207 @@ def leave_Param(self, original_node, updated_node): updated_node = updated_node.with_changes(**node_changes) return updated_node - def leave_Expr(self, original_node, upated_node): + def leave_Expr(self, original_node, updated_node): + """Drop expression from stub file. + + Parameters + ---------- + original_node : cst.Expr + updated_node : cst.Expr + + Returns + ------- + cst.RemovalSentinel + """ return cst.RemovalSentinel.REMOVE def leave_Comment(self, original_node, updated_node): + """Drop comment from stub file. + + Parameters + ---------- + original_node : cst.Comment + updated_node : cst.Comment + + Returns + ------- + cst.RemovalSentinel + """ return cst.RemovalSentinel.REMOVE + def leave_Assign(self, original_node, updated_node): + """Drop value of assign statements from stub files. + + Parameters + ---------- + original_node : cst.Assign + updated_node : cst.Assign + + Returns + ------- + updated_node : cst.Assign + """ + targets = cstm.findall(updated_node, cstm.AssignTarget()) + names_are__all__ = [ + name + for target in targets + for name in cstm.findall(target, cst.Name(value="__all__")) + ] + if not names_are__all__: + # TODO replace with AnnAssign if possible / figure out assign type? + updated_node = updated_node.with_changes(value=self._body_replacement) + return updated_node + def visit_Module(self, node): - self._scope_stack.append(_Scope(type="module", node=node)) + """Add module scope to stack. + + Parameters + ---------- + node : cst.Module + + Returns + ------- + Literal[True] + """ + self._scope_stack.append(_Scope(type=FuncType.MODULE, node=node)) + pytypes = self._pytypes_from_node(node) + self._pytypes_stack.append(pytypes) return True def leave_Module(self, original_node, updated_node): - import_nodes = self._parse_imports(self._required_imports) + """Add required type imports to module + + Parameters + ---------- + original_node : cst.Module + updated_node : cst.Module + + Returns + ------- + updated_node : cst.Module + """ + current_module = self.inspector.current_source.import_path + required_imports = [ + imp for imp in self._required_imports if imp.import_path != current_module + ] + import_nodes = self._parse_imports( + required_imports, current_module=current_module + ) updated_node = updated_node.with_changes(body=import_nodes + updated_node.body) self._scope_stack.pop() + self._pytypes_stack.pop() return updated_node + def visit_Lambda(self, node): + """Don't visit parameters fo lambda which can't have an annotation. + + Parameters + ---------- + node : cst.Lambda + + Returns + ------- + Literal[False] + """ + return False + + def leave_Decorator(self, original_node, updated_node): + """Drop decorators except for a few out of the SDL. + + Parameters + ---------- + original_node : cst.Decorator + updated_node : cst.Decorator + + Returns + ------- + cst.Decorator | cst.RemovalSentinel + """ + names = cstm.findall(original_node, cstm.Name()) + names = ".".join(name.value for name in names) + + allowlist = ( + "classmethod", + "staticmethod", + "property", + ".setter", + "abstractmethod", + "dataclass", + "coroutine", + ) + out = cst.RemovalSentinel.REMOVE + # TODO add decorators in typing module + for allowed in allowlist: + if allowed in names: + out = updated_node + return out + @staticmethod - def _parse_imports(imports): + def _parse_imports(imports, *, current_module=None): """Create nodes to include in the module tree from given imports. Parameters ---------- - imports : set[DocName] + imports : set[~.DocName] + current_module : str, optional Returns ------- import_nodes : tuple[cst.SimpleStatementLine, ...] """ - lines = {imp.format_import() for imp in imports} + lines = {imp.format_import(relative_to=current_module) for imp in imports} import_nodes = tuple(cst.parse_statement(line) for line in lines) return import_nodes + + def _function_type(self, func_def): + """Determine if a function is a method, property, staticmethod, ... + + Parameters + ---------- + func_def : cst.FunctionDef + + Returns + ------- + func_type : FuncType + """ + func_type = FuncType.FUNC + if self._scope_stack[-1].type == FuncType.CLASS: + func_type = FuncType.METHOD + for decorator in func_def.decorators: + if not hasattr(decorator.decorator, "value"): + continue + if decorator.decorator.value == "classmethod": + func_type = FuncType.CLASSMETHOD + break + if decorator.decorator.value == "staticmethod": + assert func_type == FuncType.METHOD + func_type = FuncType.STATICMETHOD + break + return func_type + + def _pytypes_from_node(self, node): + """Extract types from function, class or module docstrings. + + Parameters + ---------- + node : cst.FunctionDef | cst.ClassDef | cst.Module + + Returns + ------- + pytypes : dict[str, ~._docstrings.PyType] + """ + pytypes = None + docstring = node.get_docstring() + if docstring: + try: + pytypes = DocstringAnnotations( + docstring, + transformer=self.transformer, + ) + except Exception as e: + logger.exception( + "error while parsing docstring of `%s`:\n\n%s", + node.name.value, + e, + ) + return pytypes diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index e69de29..bcec015 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -0,0 +1,23 @@ +import itertools + + +def accumulate_qualname(qualname, *, start_right=False): + """Return possible partial names from a fully qualified one. + + Examples + -------- + >>> accumulate_qualname("a.b.c") + ('a', 'a.b', 'a.b.c') + >>> accumulate_qualname("a.b.c", start_right=True) + ('c', 'b.c', 'a.b.c') + """ + fragments = qualname.split(".") + if start_right is True: + fragments = reversed(fragments) + template = "{1}.{0}" + else: + template = "{0}.{1}" + out = tuple( + itertools.accumulate(fragments, func=lambda x, y: template.format(x, y)) + ) + return out diff --git a/src/docstub/default_config.toml b/src/docstub/default_config.toml index 1391b2a..2c042e2 100644 --- a/src/docstub/default_config.toml +++ b/src/docstub/default_config.toml @@ -5,85 +5,84 @@ extend_grammar = """ """ -# A mapping of docnames to import information. Each item maps a docname on the -# left side to a dictionary on the right side, which supports the following -# fields: -# use : A string to replace the docname with, defaults to the docname. -# from : Indicate that the docname can be imported from this path. -# import : Import this object, defaults to the docname. +# Import information for type annotations, declared ahead of time. +# +# Each item maps an annotation name on the left side to a dictionary on the +# right side. +# +# Import information can be declared with the following fields: +# from : Indicate that the DocType can be imported from this path. +# import : Import this object, defaults to the DocType. # as : Use this alias for the imported object -# is_builtin : Indicate that this docname doesn't need to be imported, +# is_builtin : Indicate that this DocType doesn't need to be imported, # defaults to "false" -[tool.docstub.docnames] - +[tool.docstub.known_imports] Path = { from = "pathlib" } - -function = { use = "Callable", from = "typing" } -func = { use = "Callable", from = "typing" } -callable = { use = "Callable", from = "typing" } - +Callable = { from = "typing" } np = { import = "numpy", as = "np" } -numpy = { use = "np", import = "numpy", as = "np" } - -scalar = { use = "np.ScalarType", import = "numpy", as = "np" } - -integer = { use = "np.integer", import = "numpy", as = "np" } -signedinteger = { use = "np.signedinteger", import = "numpy", as = "np" } -byte = { use = "np.byte", import = "numpy", as = "np" } -short = { use = "np.short", import = "numpy", as = "np" } -intc = { use = "np.intc", import = "numpy", as = "np" } -int_ = { use = "np.int_", import = "numpy", as = "np" } -longlong = { use = "np.longlong", import = "numpy", as = "np" } -int8 = { use = "np.int8", import = "numpy", as = "np" } -int16 = { use = "np.int16", import = "numpy", as = "np" } -int32 = { use = "np.int32", import = "numpy", as = "np" } -int64 = { use = "np.int64", import = "numpy", as = "np" } -intp = { use = "np.intp", import = "numpy", as = "np" } - -unsignedinteger = { use = "np.unsignedinteger", import = "numpy", as = "np" } -ushort = { use = "np.ushort", import = "numpy", as = "np" } -uintc = { use = "np.uintc", import = "numpy", as = "np" } -uint = { use = "np.uint", import = "numpy", as = "np" } -ulonglong = { use = "np.ulonglong", import = "numpy", as = "np" } -uint8 = { use = "np.uint8", import = "numpy", as = "np" } -uint16 = { use = "np.uint16", import = "numpy", as = "np" } -uint32 = { use = "np.uint32", import = "numpy", as = "np" } -uint64 = { use = "np.uint64", import = "numpy", as = "np" } -uintp = { use = "np.uintp", import = "numpy", as = "np" } - -floating = { use = "np.floating", import = "numpy", as = "np" } -#half = { use = "np.half", import = "numpy", as = "np" } -#single = { use = "np.single", import = "numpy", as = "np" } -double = { use = "np.double", import = "numpy", as = "np" } -longdouble = { use = "np.longdouble", import = "numpy", as = "np" } -float16 = { use = "np.float16", import = "numpy", as = "np" } -float32 = { use = "np.float32", import = "numpy", as = "np" } -float64 = { use = "np.float64", import = "numpy", as = "np" } -float96 = { use = "np.float96", import = "numpy", as = "np" } -float128 = { use = "np.float128", import = "numpy", as = "np" } - -complexfloating = { use = "np.complexfloating", import = "numpy", as = "np" } -csingle = { use = "np.csingle", import = "numpy", as = "np" } -cdouble = { use = "np.cdouble", import = "numpy", as = "np" } -clongdouble = { use = "np.clongdouble", import = "numpy", as = "np" } -complex64 = { use = "np.complex64", import = "numpy", as = "np" } -complex128 = { use = "np.complex128", import = "numpy", as = "np" } -complex192 = { use = "np.complex192", import = "numpy", as = "np" } -complex256 = { use = "np.complex256", import = "numpy", as = "np" } - -bool_ = { use = "np.bool_", import = "numpy", as = "np" } -datetime64 = { use = "np.datetime64", import = "numpy", as = "np" } -timedelta64 = { use = "np.timedelta64", import = "numpy", as = "np" } -object_ = { use = "np.object_", import = "numpy", as = "np" } -#flexible = { use = "np.flexible", import = "numpy", as = "np" } -#character = { use = "np.character", import = "numpy", as = "np" } -bytes_ = { use = "np.bytes_", import = "numpy", as = "np" } -#str_ = { use = "np.str_", import = "numpy", as = "np" } -#void = { use = "np.void", import = "numpy", as = "np" } - NDArray = { from = "numpy.typing" } -ndarray = { use = "NDArray", from = "numpy.typing" } -array = { use = "NDArray", from = "numpy.typing" } ArrayLike = { from = "numpy.typing" } -array-like = { use = "ArrayLike", from = "numpy.typing" } -array_like = { use = "ArrayLike", from = "numpy.typing" } + +# Specify human-friendly aliases that can be used instead of Python-parsable +# annotations. +# TODO rename to qualname_alias or something +[tool.docstub.replace_doctypes] +function = "Callable" +func = "Callable" +callable = "Callable" + +numpy = "np" +scalar = "np.ScalarType" +integer = "np.integer" +signedinteger = "np.signedinteger" +byte ="np.byte" +short = "np.short" +intc = "np.intc" +int_ = "np.int_" +longlong = "np.longlong" +int8 = "np.int8" +int16 = "np.int16" +int32 = "np.int32" +int64 = "np.int64" +intp = "np.intp" +unsignedinteger = "np.unsignedinteger" +ushort = "np.ushort" +uintc = "np.uintc" +uint = "np.uint" +ulonglong ="np.ulonglong" +uint8 = "np.uint8" +uint16 = "np.uint16" +uint32 = "np.uint32" +uint64 = "np.uint64" +uintp = "np.uintp" +floating = "np.floating" +#half = "np.half" +#single = "np.single" +double = "np.double" +longdouble = "np.longdouble" +float16 = "np.float16" +float32 = "np.float32" +float64 = "np.float64" +float96 = "np.float96" +float128 = "np.float128" +complexfloating = "np.complexfloating" +csingle = "np.csingle" +cdouble = "np.cdouble" +clongdouble = "np.clongdouble" +complex64 = "np.complex64" +complex128 = "np.complex128" +complex192 = "np.complex192" +complex256 = "np.complex256" +bool_ = "np.bool_" +datetime64 = "np.datetime64" +timedelta64 = "np.timedelta64" +object_ = "np.object_" +#flexible = "np.flexible" +#character = "np.character" +bytes_ = "np.bytes_" +#str_ = "np.str_" +#void = "np.void" +ndarray = "NDArray" +array = "NDArray" +array-like = "ArrayLike" +array_like = "ArrayLike" diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 7cc4f95..053d6a4 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -1,30 +1,31 @@ -?start : doctype +?start : annotation -doctype : type_or ("," optional)? ("," extra_info)? +annotation : (literals | types_or) ("," optional)? ("," extra_info)? -type_or : type (("or" | "|") type)* +literals : "{" literal ("," literal)* "}" + +types_or : type (("or" | "|") type)* ?type : qualname - | "{" literal ("," literal)* "}" -> literals - | container_of + | sphinx_ref + | container | shape_n_dtype optional : "optional" - | "default" ("=" | ":") literal + | "default" ("=" | ":")? literal extra_info : /[^\r\n]+/ -// Name with leading dot separated path -qualname : (NAME ".")* NAME contains? - +sphinx_ref : ":" (NAME ":")? NAME ":`" qualname "`" -contains: "[" type_or ("," type_or)* "]" - | "[" type_or "," PY_ELLIPSES "]" - - -// Container-of -container_of : NAME "of" type_or +container: qualname "[" types_or ("," types_or)* "]" + | qualname "[" types_or "," PY_ELLIPSES "]" + | qualname "of" type + | qualname "of" "(" types_or ("," types_or)* ")" + | qualname "of" "{" types_or ":" types_or "}" +// Name with leading dot separated path +qualname : (/~/ ".")? (NAME ".")* NAME // Array-like form with dtype or shape information shape_n_dtype : shape? ARRAY_NAME ("of" dtype)? @@ -37,7 +38,7 @@ ARRAY_NAME : "array" | "ndarray" | "array-like" | "array_like" -dtype : NAME +dtype : qualname shape : "(" dim ",)" | "(" leading_optional? dim (("," dim | insert_optional))* ")" | NUMBER "-"? "D" diff --git a/tests/test_analysis.py b/tests/test_analysis.py new file mode 100644 index 0000000..8da4d13 --- /dev/null +++ b/tests/test_analysis.py @@ -0,0 +1,62 @@ +import pytest + +from docstub._analysis import KnownImport, StaticInspector + + +class Test_StaticInspector: + known_imports = { # noqa: RUF012 + "dict": KnownImport(builtin_name="dict"), + "np": KnownImport(import_name="numpy", import_alias="np"), + "foo.bar": KnownImport(import_path="foo", import_name="bar"), + "foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz"), + "foo.bar.Baz.Bix": KnownImport(import_path="foo.bar", import_name="Baz"), + "foo.bar.Baz.Qux": KnownImport(import_path="foo", import_name="bar"), + } + + # fmt: off + @pytest.mark.parametrize( + ("name", "exp_annotation", "exp_import_line"), + [ + ("np", "np", "import numpy as np"), + # Finds imports whose import target matches the start of `name` + ("np.doesnt_exist", "np.doesnt_exist", "import numpy as np"), + + ("foo.bar.Baz", "Baz", "from foo.bar import Baz"), + # Finds "Baz" with abbreviated form as well + ( "~.bar.Baz", "Baz", "from foo.bar import Baz"), + ( "~.Baz", "Baz", "from foo.bar import Baz"), + + # Finds nested class "Baz.Bix" + ("foo.bar.Baz.Bix", "Baz.Bix", "from foo.bar import Baz"), + ( "~.bar.Baz.Bix", "Baz.Bix", "from foo.bar import Baz"), + ( "~.Baz.Bix", "Baz.Bix", "from foo.bar import Baz"), + ( "~.Bix", "Baz.Bix", "from foo.bar import Baz"), + + # Finds nested class "Baz.Gul" that's not explicitly defined, but + # whose import target matches "Baz" + ("foo.bar.Baz.Gul", "Baz.Gul", "from foo.bar import Baz"), + # but abbreviated form doesn't work + ( "~.bar.Baz.Gul", None, None), + ( "~.Baz.Gul", None, None), + ( "~.Gul", None, None), + + # Finds nested class "bar.Baz.Qux" (import defines module as target) + ("foo.bar.Baz.Qux", "bar.Baz.Qux", "from foo import bar"), + ( "~.bar.Baz.Qux", "bar.Baz.Qux", "from foo import bar"), + ( "~.Baz.Qux", "bar.Baz.Qux", "from foo import bar"), + ( "~.Qux", "bar.Baz.Qux", "from foo import bar"), + ] + ) + def test_query(self, name, exp_annotation, exp_import_line): + inspector = StaticInspector(known_imports=self.known_imports.copy()) + + annotation, known_import = inspector.query(name) + + if exp_annotation is None and exp_import_line is None: + assert exp_annotation is annotation + assert exp_import_line is known_import + else: + assert str(known_import) == exp_import_line + assert annotation.startswith(known_import.target) + assert annotation == exp_annotation + # fmt: on