From 94b701afe24962b1da8fbdc3a354063e9b3a0464 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lars=20Gr=C3=BCter?=
Date: Wed, 21 May 2025 14:11:39 +0200
Subject: [PATCH 1/2] Add TODO: Uncouple docstrings & analysis module

---
 src/docstub/_docstrings.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py
index 9b32e54..4c15586 100644
--- a/src/docstub/_docstrings.py
+++ b/src/docstub/_docstrings.py
@@ -11,6 +11,10 @@
 import lark.visitors
 import numpydoc.docscrape as npds
 
+# TODO Uncouple docstrings & analysis module
+# It should be possible to transform docstrings without matching to valid
+# types and imports. I think that could very well be done at a higher level,
+# e.g. in the stubs module.
 from ._analysis import KnownImport, TypesDatabase
 from ._utils import DocstubError, ErrorReporter, accumulate_qualname, escape_qualname
 

From 8304852b870d5ce25f289478e0686c1e9044d1ff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lars=20Gr=C3=BCter?=
Date: Fri, 23 May 2025 09:20:48 +0200
Subject: [PATCH 2/2] Refactor & simplify names in configuration

Hopefully this will be easier to explain. The previous names and structure
weren't at all intuitive. In the process, also simplify the architecture
somewhat. It still doesn't feel completely clean, but "baby steps".
---
 examples/docstub.toml           |  28 +++----
 examples/example_pkg/_basic.py  |   2 +-
 pyproject.toml                  |   8 +-
 src/docstub/_analysis.py        | 129 +++++++++++++++-----------------
 src/docstub/_cli.py             |  56 ++++++++------
 src/docstub/_config.py          |  29 ++++---
 src/docstub/_docstrings.py      |  39 +++-------
 src/docstub/_stubs.py           |  22 ++----
 src/docstub/default_config.toml |  69 ++++++++++-------
 tests/test_analysis.py          |  68 +++++++++++------
 tests/test_config.py            |   9 +++
 11 files changed, 243 insertions(+), 216 deletions(-)
 create mode 100644 tests/test_config.py

diff --git a/examples/docstub.toml b/examples/docstub.toml
index a16862e..5234a0d 100644
--- a/examples/docstub.toml
+++ b/examples/docstub.toml
@@ -1,20 +1,16 @@
 [tool.docstub]
-# TODO not implemented and used yet
-extend_grammar = """
-
-"""
-
-# Import information for type annotations, declared ahead of time.
+# Prefixes for external modules to match types in docstrings.
+# Docstub can't yet automatically discover where types from other packages
+# should be imported from. Instead, you can provide this information explicitly.
+# Any type in a docstring whose prefix matches the name given on the left side
+# will be associated with the given "module" on the right side.
 #
-# Each item maps an annotation name on the left side to a dictionary on the
-# right side.
+# Examples:
+# np = "numpy"
+# Will match `np.uint8` and `np.typing.NDArray` and use "import numpy as np".
 #
-# Import information can be declared with the following fields:
-# from : Indicate that the DocType can be imported from this path.
-# import : Import this object, defaults to the DocType.
-# as : Use this alias for the imported object
-# is_builtin : Indicate that this DocType doesn't need to be imported,
-#              defaults to "false"
-[tool.docstub.known_imports]
-configparser = {import = "configparser"}
+# plt = "matplotlib.pyplot"
+# Will match `plt.Figure` and use "import matplotlib.pyplot as plt".
+[tool.docstub.type_prefixes] +configparser = "configparser" diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 4c98b60..b4e0152 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -58,7 +58,7 @@ def func_use_from_elsewhere(a1, a2, a3, a4): ---------- a1 : example_pkg.CustomException a2 : ExampleClass - a3 : example_pkg.CustomException.NestedClass + a3 : example_pkg._basic.ExampleClass.NestedClass a4 : ExampleClass.NestedClass Returns diff --git a/pyproject.toml b/pyproject.toml index e562caf..2efef8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,10 +105,10 @@ testpaths = [ [tool.coverage] run.source = ["docstub"] -[tool.docstub.known_imports] -cst = {import = "libcst", as="cst"} -lark = {import = "lark"} -numpydoc = {import = "numpydoc"} +[tool.docstub.type_prefixes] +cst = "libcst" +lark = "lark" +numpydoc = "numpydoc" [tool.mypy] diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index c39dae2..f196624 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -410,69 +410,63 @@ def _collect_type_annotation(self, stack): self.known_imports[qualname] = known_import -class TypesDatabase: - """A static database of collected types usable as an annotation. +class TypeMatcher: + """Match strings to collected type information. Attributes ---------- - current_source : Path | None - source_pkgs : list[Path] - known_imports : dict[str, KnownImport] - stats : dict[str, Any] + types : dict[str, KnownImport] + prefixes : dict[str, KnownImport] + aliases : dict[str, str] + successful_queries : int + unknown_qualnames : list + current_module : Path | None Examples -------- - >>> from docstub._analysis import TypesDatabase, common_known_imports - >>> db = TypesDatabase(known_imports=common_known_imports()) - >>> db.query("Any") + >>> from docstub._analysis import TypeMatcher, common_known_imports + >>> db = TypeMatcher() + >>> db.match("Any") ('Any', ) """ def __init__( self, *, - source_pkgs=None, - known_imports=None, + types=None, + prefixes=None, + aliases=None, ): """ Parameters ---------- - source_pkgs : list[Path], optional - known_imports : dict[str, KnownImport], optional - If not provided, defaults to imports returned by - :func:`common_known_imports`. + types : dict[str, KnownImport] + prefixes : dict[str, KnownImport] + aliases : dict[str, str] """ - if source_pkgs is None: - source_pkgs = [] - if known_imports is None: - known_imports = common_known_imports() - - self.current_source = None - self.source_pkgs = source_pkgs + self.types = types or common_known_imports() + self.prefixes = prefixes or {} + self.aliases = aliases or {} + self.successful_queries = 0 + self.unknown_qualnames = [] - self.known_imports = known_imports + self.current_module = None - self.stats = { - "successful_queries": 0, - "unknown_doctypes": [], - } - - def query(self, search_name): + def match(self, search_name): """Search for a known annotation name. Parameters ---------- search_name : str + current_module : Path, optional Returns ------- - annotation_name : str | None - If it was found, the name of the annotation that matches the `known_import`. - known_import : KnownImport | None - If it was found, import information matching the `annotation_name`. 
+ type_name : str | None + type_origin : KnownImport | None """ - annotation_name = None - known_import = None + type_name = None + type_origin = None if search_name.startswith("~."): # Sphinx like matching with abbreviated name @@ -481,63 +475,64 @@ def query(self, search_name): regex = re.compile(pattern + "$") # Might be slow, but works for now matches = { - key: value - for key, value in self.known_imports.items() - if regex.match(key) + key: value for key, value in self.types.items() if regex.match(key) } if len(matches) > 1: shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0] - known_import = matches[shortest_key] - annotation_name = shortest_key + type_origin = matches[shortest_key] + type_name = shortest_key logger.warning( "%r in %s matches multiple types %r, using %r", search_name, - self.current_source, + self.current_module or "", matches.keys(), shortest_key, ) elif len(matches) == 1: - annotation_name, known_import = matches.popitem() + type_name, type_origin = matches.popitem() else: search_name = search_name[2:] logger.debug( - "couldn't match %r in %s", search_name, self.current_source + "couldn't match %r in %s", + search_name, + self.current_module or "", ) - if known_import is None and self.current_source: + # Replace alias + search_name = self.aliases.get(search_name, search_name) + + if type_origin is None and self.current_module: # Try scope of current module - module_name = module_name_from_path(self.current_source) + module_name = module_name_from_path(self.current_module) try_qualname = f"{module_name}.{search_name}" - known_import = self.known_imports.get(try_qualname) - if known_import: - annotation_name = search_name + type_origin = self.types.get(try_qualname) + if type_origin: + type_name = search_name + + if type_origin is None and search_name in self.types: + type_name = search_name + type_origin = self.types[search_name] - if known_import is None: + if type_origin is None: # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a') for partial_qualname in reversed(accumulate_qualname(search_name)): - known_import = self.known_imports.get(partial_qualname) - if known_import: - annotation_name = search_name + type_origin = self.prefixes.get(partial_qualname) + if type_origin: + type_name = search_name break if ( - known_import is not None - and annotation_name is not None - and annotation_name != known_import.target - and not annotation_name.startswith(known_import.target) + type_origin is not None + and type_name is not None + and type_name != type_origin.target + and not type_name.startswith(type_origin.target) ): # Ensure that the annotation matches the import target - annotation_name = annotation_name[ - annotation_name.find(known_import.target) : - ] + type_name = type_name[type_name.find(type_origin.target) :] - if annotation_name is not None: - self.stats["successful_queries"] += 1 + if type_name is not None: + self.successful_queries += 1 else: - self.stats["unknown_doctypes"].append(search_name) + self.unknown_qualnames.append(search_name) - return annotation_name, known_import - - def __repr__(self) -> str: - repr = f"{type(self).__name__}({self.source_pkgs})" - return repr + return type_name, type_origin diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index c22820e..7af74a7 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -10,7 +10,7 @@ from ._analysis import ( KnownImport, TypeCollector, - TypesDatabase, + TypeMatcher, common_known_imports, ) from ._cache import FileCache @@ -76,19 +76,18 @@ def 
_setup_logging(*, verbose): ) -def _build_import_map(config, root_path): - """Build a map of known imports. +def _collect_types(root_path): + """Collect types. Parameters ---------- - config : ~.Config root_path : Path Returns ------- - imports : dict[str, ~.KnownImport] + types : dict[str, ~.KnownImport] """ - known_imports = common_known_imports() + types = common_known_imports() collect_cached_types = FileCache( func=TypeCollector.collect, @@ -99,12 +98,10 @@ def _build_import_map(config, root_path): if root_path.is_dir(): for source_path in walk_python_package(root_path): logger.info("collecting types in %s", source_path) - known_imports_in_source = collect_cached_types(source_path) - known_imports.update(known_imports_in_source) - - known_imports.update(KnownImport.many_from_config(config.known_imports)) + types_in_source = collect_cached_types(source_path) + types.update(types_in_source) - return known_imports + return types @contextmanager @@ -195,15 +192,26 @@ def main(root_path, out_dir, config_path, group_errors, allow_errors, verbose): ) config = _load_configuration(config_path) - known_imports = _build_import_map(config, root_path) + + types = common_known_imports() + types |= _collect_types(root_path) + types |= { + type_name: KnownImport(import_path=module, import_name=type_name) + for type_name, module in config.types.items() + } + + prefixes = { + prefix: ( + KnownImport(import_name=module, import_alias=prefix) + if module != prefix + else KnownImport(import_name=prefix) + ) + for prefix, module in config.type_prefixes.items() + } reporter = GroupedErrorReporter() if group_errors else ErrorReporter() - types_db = TypesDatabase( - source_pkgs=[root_path.parent.resolve()], known_imports=known_imports - ) - stub_transformer = Py2StubTransformer( - types_db=types_db, replace_doctypes=config.replace_doctypes, reporter=reporter - ) + matcher = TypeMatcher(types=types, prefixes=prefixes, aliases=config.type_aliases) + stub_transformer = Py2StubTransformer(matcher=matcher, reporter=reporter) if not out_dir: if root_path.is_file(): @@ -246,22 +254,22 @@ def main(root_path, out_dir, config_path, group_errors, allow_errors, verbose): reporter.print_grouped() # Report basic statistics - successful_queries = types_db.stats["successful_queries"] + successful_queries = matcher.successful_queries click.secho(f"{successful_queries} matched annotations", fg="green") syntax_error_count = stub_transformer.transformer.stats["syntax_errors"] if syntax_error_count: click.secho(f"{syntax_error_count} syntax errors", fg="red") - unknown_doctypes = types_db.stats["unknown_doctypes"] - if unknown_doctypes: - click.secho(f"{len(unknown_doctypes)} unknown doctypes", fg="red") - counter = Counter(unknown_doctypes) + unknown_qualnames = matcher.unknown_qualnames + if unknown_qualnames: + click.secho(f"{len(unknown_qualnames)} unknown type names", fg="red") + counter = Counter(unknown_qualnames) sorted_item_counts = sorted(counter.items(), key=lambda x: x[1], reverse=True) for item, count in sorted_item_counts: click.echo(f" {item} (x{count})") - total_errors = len(unknown_doctypes) + syntax_error_count + total_errors = len(unknown_qualnames) + syntax_error_count total_msg = f"{total_errors} total errors" if allow_errors: total_msg = f"{total_msg} (allowed {allow_errors})" diff --git a/src/docstub/_config.py b/src/docstub/_config.py index c9dff3a..af100eb 100644 --- a/src/docstub/_config.py +++ b/src/docstub/_config.py @@ -11,9 +11,9 @@ class Config: DEFAULT_CONFIG_PATH: ClassVar[Path] = 
Path(__file__).parent / "default_config.toml" - extend_grammar: str = "" - known_imports: dict[str, dict[str, str]] = dataclasses.field(default_factory=dict) - replace_doctypes: dict[str, str] = dataclasses.field(default_factory=dict) + types: dict[str, str] = dataclasses.field(default_factory=dict) + type_prefixes: dict[str, str] = dataclasses.field(default_factory=dict) + type_aliases: dict[str, str] = dataclasses.field(default_factory=dict) _source: tuple[Path, ...] = () @@ -61,9 +61,9 @@ def merge(self, other): if not isinstance(other, type(self)): return NotImplemented new = Config( - extend_grammar=self.extend_grammar + other.extend_grammar, - known_imports=self.known_imports | other.known_imports, - replace_doctypes=self.replace_doctypes | other.replace_doctypes, + types=self.types | other.types, + type_prefixes=self.type_prefixes | other.type_prefixes, + type_aliases=self.type_aliases | other.type_aliases, _source=self._source + other._source, ) logger.debug("merged Config from %s", new._source) @@ -73,14 +73,19 @@ def to_dict(self): return dataclasses.asdict(self) def __post_init__(self): - if not isinstance(self.extend_grammar, str): - raise TypeError("extended_grammar must be a string") - if not isinstance(self.known_imports, dict): - raise TypeError("known_imports must be a dict") - if not isinstance(self.replace_doctypes, dict): - raise TypeError("replace_doctypes must be a string") + self.validate(self.to_dict()) def __repr__(self) -> str: sources = " | ".join(str(s) for s in self._source) formatted = f"<{type(self).__name__}: {sources}>" return formatted + + @staticmethod + def validate(mapping): + for name in ["types", "type_prefixes", "type_aliases"]: + table = mapping[name] + if not isinstance(table, dict): + raise TypeError(f"{name} must be a dict") + for key, value in table.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise TypeError(f"`{key} = {value}` in {name} must both be a str") diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 4c15586..003cef5 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -15,8 +15,8 @@ # It should be possible to transform docstrings without matching to valid # types and imports. I think that could very well be done at a higher level, # e.g. in the stubs module. -from ._analysis import KnownImport, TypesDatabase -from ._utils import DocstubError, ErrorReporter, accumulate_qualname, escape_qualname +from ._analysis import KnownImport, TypeMatcher +from ._utils import DocstubError, ErrorReporter, escape_qualname logger = logging.getLogger(__name__) @@ -175,8 +175,7 @@ class DoctypeTransformer(lark.visitors.Transformer): Attributes ---------- - types_db : ~.TypesDatabase - replace_doctypes : dict[str, str] + matcher : ~.TypeMatcher stats : dict[str, Any] blacklisted_qualnames : ClassVar[frozenset[str]] All Python keywords [1]_ are blacklisted from use in qualnames except for ``True`` @@ -235,26 +234,18 @@ class DoctypeTransformer(lark.visitors.Transformer): } ) - def __init__(self, *, types_db=None, replace_doctypes=None, **kwargs): + def __init__(self, *, matcher=None, **kwargs): """ Parameters ---------- - types_db : ~.TypesDatabase, optional - A static database of collected types usable as an annotation. If - not given, defaults to a database with common types from the - standard library (see :func:`~.common_known_imports`). - replace_doctypes : dict[str, str], optional - Replacements for human-friendly aliases. 
+ matcher : ~.TypeMatcher, optional kwargs : dict[Any, Any], optional Keyword arguments passed to the init of the parent class. """ - if replace_doctypes is None: - replace_doctypes = {} - if types_db is None: - types_db = TypesDatabase() + if matcher is None: + matcher = TypeMatcher() - self.types_db = types_db - self.replace_doctypes = replace_doctypes + self.matcher = matcher self._collected_imports = None self._unknown_qualnames = None @@ -312,12 +303,6 @@ def qualname(self, tree): children = tree.children _qualname = ".".join(children) - for partial_qualname in accumulate_qualname(_qualname): - replacement = self.replace_doctypes.get(partial_qualname) - if replacement: - _qualname = _qualname.replace(partial_qualname, replacement) - break - _qualname = self._match_import(_qualname, meta=tree.meta) if _qualname in self.blacklisted_qualnames: @@ -393,8 +378,8 @@ def natlang_literal(self, tree): out, ) - if self.types_db is not None: - _, known_import = self.types_db.query("Literal") + if self.matcher is not None: + _, known_import = self.matcher.match("Literal") if known_import: self._collected_imports.add(known_import) return out @@ -523,8 +508,8 @@ def _match_import(self, qualname, *, meta): matched_qualname : str Possibly modified or normalized qualname. """ - if self.types_db is not None: - annotation_name, known_import = self.types_db.query(qualname) + if self.matcher is not None: + annotation_name, known_import = self.matcher.match(qualname) else: annotation_name = None known_import = None diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 46854cf..801ffac 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -350,8 +350,6 @@ class Py2StubTransformer(cst.CSTTransformer): Attributes ---------- - types_db : ~.TypesDatabase | None - replace_doctypes : dict[str, str] | None transformer : ~.DoctypeTransformer References @@ -381,22 +379,17 @@ def print_upper(x: Incomplete) -> None: ... ) _Annotation_None: ClassVar[cst.Annotation] = cst.Annotation(cst.Name("None")) - def __init__(self, *, types_db=None, replace_doctypes=None, reporter=None): + def __init__(self, *, matcher=None, reporter=None): """ Parameters ---------- - types_db : ~.TypesDatabase - replace_doctypes : dict[str, str] + matcher : ~.TypeMatcher reporter : ~.ErrorReporter """ if reporter is None: reporter = ErrorReporter() - self.types_db = types_db - self.replace_doctypes = replace_doctypes - self.transformer = DoctypeTransformer( - types_db=types_db, replace_doctypes=replace_doctypes - ) + self.transformer = DoctypeTransformer(matcher=matcher) self.reporter = reporter # Relevant docstring for the current context self._scope_stack = None # Entered module, class or function scopes @@ -423,8 +416,10 @@ def current_source(self, value): value : Path """ self._current_source = value - if self.types_db is not None: - self.types_db.current_source = value + # TODO pass current_source directly when using the transformer / matcher + # instead of assigning it here! 
+        if self.transformer is not None and self.transformer.matcher is not None:
+            self.transformer.matcher.current_module = value
 
     @property
     def is_inside_function_def(self):
@@ -585,8 +580,7 @@ def leave_FunctionDef(self, original_node, updated_node):
             )
             replaced = _inline_node_as_code(original_node.returns.annotation)
             details = (
-                f"{replaced}\n"
-                f"{reporter.underline(replaced)} -> {annotation_value}"
+                f"{replaced}\n{reporter.underline(replaced)} -> {annotation_value}"
             )
             reporter.message(
                 short="Replacing existing inline return annotation",
diff --git a/src/docstub/default_config.toml b/src/docstub/default_config.toml
index c6b0f6c..5f94ff1 100644
--- a/src/docstub/default_config.toml
+++ b/src/docstub/default_config.toml
@@ -1,32 +1,43 @@
 [tool.docstub]
-# TODO not implemented and used yet
-extend_grammar = """
+# Types and their external modules to use in docstrings.
+# Docstub can't yet automatically discover where types from other packages
+# should be imported from. Instead, you can provide this information explicitly.
+# Any type on the left side will be associated with the given "module" on the
+# right side.
+#
+# Examples:
+# Path = "pathlib"
+# Will allow using "Path" and use "from pathlib import Path".
+#
+# NDArray = "numpy.typing"
+# Will allow "NDArray" and use "from numpy.typing import NDArray".
+[tool.docstub.types]
+Path = "pathlib"
+NDArray = "numpy.typing"
+ArrayLike = "numpy.typing"
 
-"""
-# Import information for type annotations, declared ahead of time.
+# Prefixes for external modules to match types in docstrings.
+# Docstub can't yet automatically discover where types from other packages
+# should be imported from. Instead, you can provide this information explicitly.
+# Any type in a docstring whose prefix matches the name given on the left side
+# will be associated with the given "module" on the right side.
 #
-# Each item maps an annotation name on the left side to a dictionary on the
-# right side.
+# Examples:
+# np = "numpy"
+# Will match `np.uint8` and `np.typing.NDArray` and use "import numpy as np".
 #
-# Import information can be declared with the following fields:
-# from : Indicate that the DocType can be imported from this path.
-# import : Import this object, defaults to the DocType.
-# as : Use this alias for the imported object
-# is_builtin : Indicate that this DocType doesn't need to be imported,
-#              defaults to "false"
-[tool.docstub.known_imports]
-Path = { from = "pathlib" }
-Callable = { from = "collections.abc" }
-np = { import = "numpy", as = "np" }
-NDArray = { from = "numpy.typing" }
-ArrayLike = { from = "numpy.typing" }
+# plt = "matplotlib.pyplot"
+# Will match `plt.Figure` and use "import matplotlib.pyplot as plt".
+[tool.docstub.type_prefixes]
+np = "numpy"
+numpy = "numpy"
+
 
-# Specify human-friendly aliases that can be used instead of Python-parsable
-# annotations.
-# TODO rename to qualname_alias or something
-[tool.docstub.replace_doctypes]
+# Specify human-friendly aliases that can be used in docstrings to describe
+# valid Python types or annotations.
+[tool.docstub.type_aliases] iterable = "Iterable" callable = "Callable" function = "Callable" @@ -34,14 +45,14 @@ func = "Callable" sequence = "Sequence" mapping = "Mapping" -numpy = "np" +# NumPy scalar = "np.ScalarType" integer = "np.integer" signedinteger = "np.signedinteger" -byte ="np.byte" -short = "np.short" +#byte ="np.byte" +#short = "np.short" intc = "np.intc" -int_ = "np.int_" +#int_ = "np.int_" longlong = "np.longlong" int8 = "np.int8" int16 = "np.int16" @@ -76,13 +87,13 @@ complex64 = "np.complex64" complex128 = "np.complex128" complex192 = "np.complex192" complex256 = "np.complex256" -bool_ = "np.bool_" +#bool_ = "np.bool_" datetime64 = "np.datetime64" timedelta64 = "np.timedelta64" -object_ = "np.object_" +#object_ = "np.object_" #flexible = "np.flexible" #character = "np.character" -bytes_ = "np.bytes_" +#bytes_ = "np.bytes_" #str_ = "np.str_" #void = "np.void" ndarray = "NDArray" diff --git a/tests/test_analysis.py b/tests/test_analysis.py index d55a8e7..91a6f09 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -2,7 +2,7 @@ import pytest -from docstub._analysis import KnownImport, TypeCollector, TypesDatabase +from docstub._analysis import KnownImport, TypeCollector, TypeMatcher @pytest.fixture @@ -91,10 +91,14 @@ def test_ignores_assigns(self, module_factory, src): assert len(imports) == 0 -class Test_TypesDatabase: - known_imports = { # noqa: RUF012 - "dict": KnownImport(builtin_name="dict"), +class Test_TypeMatcher: + type_prefixes = { # noqa: RUF012 "np": KnownImport(import_name="numpy", import_alias="np"), + "foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz"), + } + + types = { # noqa: RUF012 + "dict": KnownImport(builtin_name="dict"), "foo.bar": KnownImport(import_path="foo", import_name="bar"), "foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz"), "foo.bar.Baz.Bix": KnownImport(import_path="foo.bar", import_name="Baz"), @@ -103,12 +107,8 @@ class Test_TypesDatabase: # fmt: off @pytest.mark.parametrize( - ("name", "exp_annotation", "exp_import_line"), + ("search_name", "expected_name", "expected_origin"), [ - ("np", "np", "import numpy as np"), - # Finds imports whose import target matches the start of `name` - ("np.doesnt_exist", "np.doesnt_exist", "import numpy as np"), - ("foo.bar.Baz", "Baz", "from foo.bar import Baz"), # Finds "Baz" with abbreviated form as well ( "~.bar.Baz", "Baz", "from foo.bar import Baz"), @@ -120,10 +120,8 @@ class Test_TypesDatabase: ( "~.Baz.Bix", "Baz.Bix", "from foo.bar import Baz"), ( "~.Bix", "Baz.Bix", "from foo.bar import Baz"), - # Finds nested class "Baz.Gul" that's not explicitly defined, but - # whose import target matches "Baz" - ("foo.bar.Baz.Gul", "Baz.Gul", "from foo.bar import Baz"), - # but abbreviated form doesn't work + # Abbreviated form with not explicitly defined class "Baz.Gul" + # never matches ( "~.bar.Baz.Gul", None, None), ( "~.Baz.Gul", None, None), ( "~.Gul", None, None), @@ -135,16 +133,42 @@ class Test_TypesDatabase: ( "~.Qux", "bar.Baz.Qux", "from foo import bar"), ] ) - def test_query(self, name, exp_annotation, exp_import_line): - db = TypesDatabase(known_imports=self.known_imports.copy()) + def test_query_types(self, search_name, expected_name, expected_origin): + db = TypeMatcher(types=self.types.copy()) + + type_name, type_origin = db.match(search_name) + + if expected_name is None and expected_origin is None: + assert expected_name is type_name + assert expected_origin is type_origin + else: + assert str(type_origin) == expected_origin + assert 
type_name.startswith(type_origin.target) + assert type_name == expected_name + # fmt: on + + # fmt: off + @pytest.mark.parametrize( + ("search_name", "expected_name", "expected_origin"), + [ + ("np", "np", "import numpy as np"), + # Finds imports whose import target matches the start of `name` + ("np.doesnt_exist", "np.doesnt_exist", "import numpy as np"), + # Finds nested class "Baz.Gul" that's not explicitly defined, but + # whose import target matches "Baz" + ("foo.bar.Baz.Gul", "Baz.Gul", "from foo.bar import Baz"), + ] + ) + def test_query_prefix(self, search_name, expected_name, expected_origin): + db = TypeMatcher(prefixes=self.type_prefixes.copy()) - annotation, known_import = db.query(name) + type_name, type_origin = db.match(search_name) - if exp_annotation is None and exp_import_line is None: - assert exp_annotation is annotation - assert exp_import_line is known_import + if expected_name is None and expected_origin is None: + assert expected_name is type_name + assert expected_origin is type_origin else: - assert str(known_import) == exp_import_line - assert annotation.startswith(known_import.target) - assert annotation == exp_annotation + assert str(type_origin) == expected_origin + assert type_name.startswith(type_origin.target) + assert type_name == expected_name # fmt: on diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..1cb0222 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,9 @@ +from docstub._config import Config + + +class Test_Config: + def test_from_default(self): + config = Config.from_default() + assert len(config.types) > 0 + assert len(config.type_prefixes) > 0 + assert len(config.type_aliases) > 0
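
A minimal sketch (not itself part of either patch) of how the three new
configuration tables feed a TypeMatcher, mirroring the updated main() in the
_cli.py hunk of PATCH 2/2. The table contents below are made-up illustrations,
not values from a real project.

    from docstub._analysis import KnownImport, TypeMatcher, common_known_imports

    # [tool.docstub.types] -- explicit "TypeName = module" entries
    config_types = {"NDArray": "numpy.typing"}
    # [tool.docstub.type_prefixes] -- "prefix = module" entries
    config_type_prefixes = {"np": "numpy", "numpy": "numpy"}
    # [tool.docstub.type_aliases] -- "docstring alias = annotation" entries
    config_type_aliases = {"ndarray": "NDArray"}

    # Start from the common stdlib types and layer the configured types on top,
    # exactly as the refactored CLI does.
    types = common_known_imports()
    types |= {
        type_name: KnownImport(import_path=module, import_name=type_name)
        for type_name, module in config_types.items()
    }
    prefixes = {
        prefix: (
            KnownImport(import_name=module, import_alias=prefix)
            if module != prefix
            else KnownImport(import_name=prefix)
        )
        for prefix, module in config_type_prefixes.items()
    }
    matcher = TypeMatcher(
        types=types, prefixes=prefixes, aliases=config_type_aliases
    )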
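
A second sketch of the new TypeMatcher.match() API itself, based on its
docstring and the Test_TypeMatcher cases above. The results noted in the
comments are inferred from those tests; KnownImport reprs are abbreviated.

    from docstub._analysis import KnownImport, TypeMatcher

    matcher = TypeMatcher(
        types={"foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz")},
        prefixes={"np": KnownImport(import_name="numpy", import_alias="np")},
    )

    # Exact match on a collected type; the returned name is trimmed to the
    # import target: ("Baz", <from foo.bar import Baz>)
    matcher.match("foo.bar.Baz")

    # Sphinx-style abbreviated form resolves against the same entry
    matcher.match("~.bar.Baz")

    # Prefix match: unknown attributes of a known module prefix still resolve,
    # e.g. ("np.uint8", <import numpy as np>)
    matcher.match("np.uint8")

    # No match: returns (None, None) and records the name in unknown_qualnames
    matcher.match("does.not.Exist")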