diff --git a/examples/docstub.toml b/examples/docstub.toml index a16862e..5234a0d 100644 --- a/examples/docstub.toml +++ b/examples/docstub.toml @@ -1,20 +1,16 @@ [tool.docstub] -# TODO not implemented and used yet -extend_grammar = """ - -""" - -# Import information for type annotations, declared ahead of time. +# Prefixes for external modules to match types in docstrings. +# Docstub can't yet automatically discover where to import types from other +# packages from. Instead, you can provide this information explicitly. +# Any type in a docstring whose prefix matches the name given on the left side, +# will be associated with the given "module" on the right side. # -# Each item maps an annotation name on the left side to a dictionary on the -# right side. +# Examples: +# np = "numpy" +# Will match `np.uint8` and `np.typing.NDarray` and use "import numpy as np". # -# Import information can be declared with the following fields: -# from : Indicate that the DocType can be imported from this path. -# import : Import this object, defaults to the DocType. -# as : Use this alias for the imported object -# is_builtin : Indicate that this DocType doesn't need to be imported, -# defaults to "false" -[tool.docstub.known_imports] -configparser = {import = "configparser"} +# plt = "matplotlib.pyplot +# Will match `plt.Figure` use `import matplotlib.pyplot as plt`. +[tool.docstub.type_prefixes] +configparser = "configparser" diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 4c98b60..b4e0152 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -58,7 +58,7 @@ def func_use_from_elsewhere(a1, a2, a3, a4): ---------- a1 : example_pkg.CustomException a2 : ExampleClass - a3 : example_pkg.CustomException.NestedClass + a3 : example_pkg._basic.ExampleClass.NestedClass a4 : ExampleClass.NestedClass Returns diff --git a/pyproject.toml b/pyproject.toml index e562caf..2efef8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,10 +105,10 @@ testpaths = [ [tool.coverage] run.source = ["docstub"] -[tool.docstub.known_imports] -cst = {import = "libcst", as="cst"} -lark = {import = "lark"} -numpydoc = {import = "numpydoc"} +[tool.docstub.type_prefixes] +cst = "libcst" +lark = "lark" +numpydoc = "numpydoc" [tool.mypy] diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index c39dae2..f196624 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -410,69 +410,63 @@ def _collect_type_annotation(self, stack): self.known_imports[qualname] = known_import -class TypesDatabase: - """A static database of collected types usable as an annotation. +class TypeMatcher: + """Match strings to collected type information. Attributes ---------- - current_source : Path | None - source_pkgs : list[Path] - known_imports : dict[str, KnownImport] - stats : dict[str, Any] + types : dict[str, KnownImport] + prefixes : dict[str, KnownImport] + aliases : dict[str, str] + successful_queries : int + unknown_qualnames : list + current_module : Path | None Examples -------- - >>> from docstub._analysis import TypesDatabase, common_known_imports - >>> db = TypesDatabase(known_imports=common_known_imports()) - >>> db.query("Any") + >>> from docstub._analysis import TypeMatcher, common_known_imports + >>> db = TypeMatcher() + >>> db.match("Any") ('Any', ) """ def __init__( self, *, - source_pkgs=None, - known_imports=None, + types=None, + prefixes=None, + aliases=None, ): """ Parameters ---------- - source_pkgs : list[Path], optional - known_imports : dict[str, KnownImport], optional - If not provided, defaults to imports returned by - :func:`common_known_imports`. + types : dict[str, KnownImport] + prefixes : dict[str, KnownImport] + aliases : dict[str, str] """ - if source_pkgs is None: - source_pkgs = [] - if known_imports is None: - known_imports = common_known_imports() - - self.current_source = None - self.source_pkgs = source_pkgs + self.types = types or common_known_imports() + self.prefixes = prefixes or {} + self.aliases = aliases or {} + self.successful_queries = 0 + self.unknown_qualnames = [] - self.known_imports = known_imports + self.current_module = None - self.stats = { - "successful_queries": 0, - "unknown_doctypes": [], - } - - def query(self, search_name): + def match(self, search_name): """Search for a known annotation name. Parameters ---------- search_name : str + current_module : Path, optional Returns ------- - annotation_name : str | None - If it was found, the name of the annotation that matches the `known_import`. - known_import : KnownImport | None - If it was found, import information matching the `annotation_name`. + type_name : str | None + type_origin : KnownImport | None """ - annotation_name = None - known_import = None + type_name = None + type_origin = None if search_name.startswith("~."): # Sphinx like matching with abbreviated name @@ -481,63 +475,64 @@ def query(self, search_name): regex = re.compile(pattern + "$") # Might be slow, but works for now matches = { - key: value - for key, value in self.known_imports.items() - if regex.match(key) + key: value for key, value in self.types.items() if regex.match(key) } if len(matches) > 1: shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0] - known_import = matches[shortest_key] - annotation_name = shortest_key + type_origin = matches[shortest_key] + type_name = shortest_key logger.warning( "%r in %s matches multiple types %r, using %r", search_name, - self.current_source, + self.current_module or "", matches.keys(), shortest_key, ) elif len(matches) == 1: - annotation_name, known_import = matches.popitem() + type_name, type_origin = matches.popitem() else: search_name = search_name[2:] logger.debug( - "couldn't match %r in %s", search_name, self.current_source + "couldn't match %r in %s", + search_name, + self.current_module or "", ) - if known_import is None and self.current_source: + # Replace alias + search_name = self.aliases.get(search_name, search_name) + + if type_origin is None and self.current_module: # Try scope of current module - module_name = module_name_from_path(self.current_source) + module_name = module_name_from_path(self.current_module) try_qualname = f"{module_name}.{search_name}" - known_import = self.known_imports.get(try_qualname) - if known_import: - annotation_name = search_name + type_origin = self.types.get(try_qualname) + if type_origin: + type_name = search_name + + if type_origin is None and search_name in self.types: + type_name = search_name + type_origin = self.types[search_name] - if known_import is None: + if type_origin is None: # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a') for partial_qualname in reversed(accumulate_qualname(search_name)): - known_import = self.known_imports.get(partial_qualname) - if known_import: - annotation_name = search_name + type_origin = self.prefixes.get(partial_qualname) + if type_origin: + type_name = search_name break if ( - known_import is not None - and annotation_name is not None - and annotation_name != known_import.target - and not annotation_name.startswith(known_import.target) + type_origin is not None + and type_name is not None + and type_name != type_origin.target + and not type_name.startswith(type_origin.target) ): # Ensure that the annotation matches the import target - annotation_name = annotation_name[ - annotation_name.find(known_import.target) : - ] + type_name = type_name[type_name.find(type_origin.target) :] - if annotation_name is not None: - self.stats["successful_queries"] += 1 + if type_name is not None: + self.successful_queries += 1 else: - self.stats["unknown_doctypes"].append(search_name) + self.unknown_qualnames.append(search_name) - return annotation_name, known_import - - def __repr__(self) -> str: - repr = f"{type(self).__name__}({self.source_pkgs})" - return repr + return type_name, type_origin diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index c22820e..7af74a7 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -10,7 +10,7 @@ from ._analysis import ( KnownImport, TypeCollector, - TypesDatabase, + TypeMatcher, common_known_imports, ) from ._cache import FileCache @@ -76,19 +76,18 @@ def _setup_logging(*, verbose): ) -def _build_import_map(config, root_path): - """Build a map of known imports. +def _collect_types(root_path): + """Collect types. Parameters ---------- - config : ~.Config root_path : Path Returns ------- - imports : dict[str, ~.KnownImport] + types : dict[str, ~.KnownImport] """ - known_imports = common_known_imports() + types = common_known_imports() collect_cached_types = FileCache( func=TypeCollector.collect, @@ -99,12 +98,10 @@ def _build_import_map(config, root_path): if root_path.is_dir(): for source_path in walk_python_package(root_path): logger.info("collecting types in %s", source_path) - known_imports_in_source = collect_cached_types(source_path) - known_imports.update(known_imports_in_source) - - known_imports.update(KnownImport.many_from_config(config.known_imports)) + types_in_source = collect_cached_types(source_path) + types.update(types_in_source) - return known_imports + return types @contextmanager @@ -195,15 +192,26 @@ def main(root_path, out_dir, config_path, group_errors, allow_errors, verbose): ) config = _load_configuration(config_path) - known_imports = _build_import_map(config, root_path) + + types = common_known_imports() + types |= _collect_types(root_path) + types |= { + type_name: KnownImport(import_path=module, import_name=type_name) + for type_name, module in config.types.items() + } + + prefixes = { + prefix: ( + KnownImport(import_name=module, import_alias=prefix) + if module != prefix + else KnownImport(import_name=prefix) + ) + for prefix, module in config.type_prefixes.items() + } reporter = GroupedErrorReporter() if group_errors else ErrorReporter() - types_db = TypesDatabase( - source_pkgs=[root_path.parent.resolve()], known_imports=known_imports - ) - stub_transformer = Py2StubTransformer( - types_db=types_db, replace_doctypes=config.replace_doctypes, reporter=reporter - ) + matcher = TypeMatcher(types=types, prefixes=prefixes, aliases=config.type_aliases) + stub_transformer = Py2StubTransformer(matcher=matcher, reporter=reporter) if not out_dir: if root_path.is_file(): @@ -246,22 +254,22 @@ def main(root_path, out_dir, config_path, group_errors, allow_errors, verbose): reporter.print_grouped() # Report basic statistics - successful_queries = types_db.stats["successful_queries"] + successful_queries = matcher.successful_queries click.secho(f"{successful_queries} matched annotations", fg="green") syntax_error_count = stub_transformer.transformer.stats["syntax_errors"] if syntax_error_count: click.secho(f"{syntax_error_count} syntax errors", fg="red") - unknown_doctypes = types_db.stats["unknown_doctypes"] - if unknown_doctypes: - click.secho(f"{len(unknown_doctypes)} unknown doctypes", fg="red") - counter = Counter(unknown_doctypes) + unknown_qualnames = matcher.unknown_qualnames + if unknown_qualnames: + click.secho(f"{len(unknown_qualnames)} unknown type names", fg="red") + counter = Counter(unknown_qualnames) sorted_item_counts = sorted(counter.items(), key=lambda x: x[1], reverse=True) for item, count in sorted_item_counts: click.echo(f" {item} (x{count})") - total_errors = len(unknown_doctypes) + syntax_error_count + total_errors = len(unknown_qualnames) + syntax_error_count total_msg = f"{total_errors} total errors" if allow_errors: total_msg = f"{total_msg} (allowed {allow_errors})" diff --git a/src/docstub/_config.py b/src/docstub/_config.py index c9dff3a..af100eb 100644 --- a/src/docstub/_config.py +++ b/src/docstub/_config.py @@ -11,9 +11,9 @@ class Config: DEFAULT_CONFIG_PATH: ClassVar[Path] = Path(__file__).parent / "default_config.toml" - extend_grammar: str = "" - known_imports: dict[str, dict[str, str]] = dataclasses.field(default_factory=dict) - replace_doctypes: dict[str, str] = dataclasses.field(default_factory=dict) + types: dict[str, str] = dataclasses.field(default_factory=dict) + type_prefixes: dict[str, str] = dataclasses.field(default_factory=dict) + type_aliases: dict[str, str] = dataclasses.field(default_factory=dict) _source: tuple[Path, ...] = () @@ -61,9 +61,9 @@ def merge(self, other): if not isinstance(other, type(self)): return NotImplemented new = Config( - extend_grammar=self.extend_grammar + other.extend_grammar, - known_imports=self.known_imports | other.known_imports, - replace_doctypes=self.replace_doctypes | other.replace_doctypes, + types=self.types | other.types, + type_prefixes=self.type_prefixes | other.type_prefixes, + type_aliases=self.type_aliases | other.type_aliases, _source=self._source + other._source, ) logger.debug("merged Config from %s", new._source) @@ -73,14 +73,19 @@ def to_dict(self): return dataclasses.asdict(self) def __post_init__(self): - if not isinstance(self.extend_grammar, str): - raise TypeError("extended_grammar must be a string") - if not isinstance(self.known_imports, dict): - raise TypeError("known_imports must be a dict") - if not isinstance(self.replace_doctypes, dict): - raise TypeError("replace_doctypes must be a string") + self.validate(self.to_dict()) def __repr__(self) -> str: sources = " | ".join(str(s) for s in self._source) formatted = f"<{type(self).__name__}: {sources}>" return formatted + + @staticmethod + def validate(mapping): + for name in ["types", "type_prefixes", "type_aliases"]: + table = mapping[name] + if not isinstance(table, dict): + raise TypeError(f"{name} must be a dict") + for key, value in table.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise TypeError(f"`{key} = {value}` in {name} must both be a str") diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 9b32e54..003cef5 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -11,8 +11,12 @@ import lark.visitors import numpydoc.docscrape as npds -from ._analysis import KnownImport, TypesDatabase -from ._utils import DocstubError, ErrorReporter, accumulate_qualname, escape_qualname +# TODO Uncouple docstrings & analysis module +# It should be possible to transform docstrings without matching to valid +# types and imports. I think that could very well be done at a higher level, +# e.g. in the stubs module. +from ._analysis import KnownImport, TypeMatcher +from ._utils import DocstubError, ErrorReporter, escape_qualname logger = logging.getLogger(__name__) @@ -171,8 +175,7 @@ class DoctypeTransformer(lark.visitors.Transformer): Attributes ---------- - types_db : ~.TypesDatabase - replace_doctypes : dict[str, str] + matcher : ~.TypeMatcher stats : dict[str, Any] blacklisted_qualnames : ClassVar[frozenset[str]] All Python keywords [1]_ are blacklisted from use in qualnames except for ``True`` @@ -231,26 +234,18 @@ class DoctypeTransformer(lark.visitors.Transformer): } ) - def __init__(self, *, types_db=None, replace_doctypes=None, **kwargs): + def __init__(self, *, matcher=None, **kwargs): """ Parameters ---------- - types_db : ~.TypesDatabase, optional - A static database of collected types usable as an annotation. If - not given, defaults to a database with common types from the - standard library (see :func:`~.common_known_imports`). - replace_doctypes : dict[str, str], optional - Replacements for human-friendly aliases. + matcher : ~.TypeMatcher, optional kwargs : dict[Any, Any], optional Keyword arguments passed to the init of the parent class. """ - if replace_doctypes is None: - replace_doctypes = {} - if types_db is None: - types_db = TypesDatabase() + if matcher is None: + matcher = TypeMatcher() - self.types_db = types_db - self.replace_doctypes = replace_doctypes + self.matcher = matcher self._collected_imports = None self._unknown_qualnames = None @@ -308,12 +303,6 @@ def qualname(self, tree): children = tree.children _qualname = ".".join(children) - for partial_qualname in accumulate_qualname(_qualname): - replacement = self.replace_doctypes.get(partial_qualname) - if replacement: - _qualname = _qualname.replace(partial_qualname, replacement) - break - _qualname = self._match_import(_qualname, meta=tree.meta) if _qualname in self.blacklisted_qualnames: @@ -389,8 +378,8 @@ def natlang_literal(self, tree): out, ) - if self.types_db is not None: - _, known_import = self.types_db.query("Literal") + if self.matcher is not None: + _, known_import = self.matcher.match("Literal") if known_import: self._collected_imports.add(known_import) return out @@ -519,8 +508,8 @@ def _match_import(self, qualname, *, meta): matched_qualname : str Possibly modified or normalized qualname. """ - if self.types_db is not None: - annotation_name, known_import = self.types_db.query(qualname) + if self.matcher is not None: + annotation_name, known_import = self.matcher.match(qualname) else: annotation_name = None known_import = None diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 46854cf..801ffac 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -350,8 +350,6 @@ class Py2StubTransformer(cst.CSTTransformer): Attributes ---------- - types_db : ~.TypesDatabase | None - replace_doctypes : dict[str, str] | None transformer : ~.DoctypeTransformer References @@ -381,22 +379,17 @@ def print_upper(x: Incomplete) -> None: ... ) _Annotation_None: ClassVar[cst.Annotation] = cst.Annotation(cst.Name("None")) - def __init__(self, *, types_db=None, replace_doctypes=None, reporter=None): + def __init__(self, *, matcher=None, reporter=None): """ Parameters ---------- - types_db : ~.TypesDatabase - replace_doctypes : dict[str, str] + matcher : ~.TypeMatcher reporter : ~.ErrorReporter """ if reporter is None: reporter = ErrorReporter() - self.types_db = types_db - self.replace_doctypes = replace_doctypes - self.transformer = DoctypeTransformer( - types_db=types_db, replace_doctypes=replace_doctypes - ) + self.transformer = DoctypeTransformer(matcher=matcher) self.reporter = reporter # Relevant docstring for the current context self._scope_stack = None # Entered module, class or function scopes @@ -423,8 +416,10 @@ def current_source(self, value): value : Path """ self._current_source = value - if self.types_db is not None: - self.types_db.current_source = value + # TODO pass current_source directly when using the transformer / matcher + # instead of assigning it here! + if self.transformer is not None and self.transformer.matcher is not None: + self.transformer.matcher.current_module = value @property def is_inside_function_def(self): @@ -585,8 +580,7 @@ def leave_FunctionDef(self, original_node, updated_node): ) replaced = _inline_node_as_code(original_node.returns.annotation) details = ( - f"{replaced}\n" - f"{reporter.underline(replaced)} -> {annotation_value}" + f"{replaced}\n{reporter.underline(replaced)} -> {annotation_value}" ) reporter.message( short="Replacing existing inline return annotation", diff --git a/src/docstub/default_config.toml b/src/docstub/default_config.toml index c6b0f6c..5f94ff1 100644 --- a/src/docstub/default_config.toml +++ b/src/docstub/default_config.toml @@ -1,32 +1,43 @@ [tool.docstub] -# TODO not implemented and used yet -extend_grammar = """ +# Types and their external modules to use in docstrings. +# Docstub can't yet automatically discover where to import types from other +# packages from. Instead, you can provide this information explicitly. +# Any type on the left side will be associated with the given "module" on the +# right side. +# +# Examples: +# Path = "pathlib" +# Will allow using "Path" and use "from pathlib import Path". +# +# NDArray = "numpy.typing" +# Will allow "NDarray" and use "from numpy.typing import NDArray". +[tool.docstub.types] +Path = "pathlib" +NDArray = "numpy.typing" +ArrayLike = "numpy.typing" -""" -# Import information for type annotations, declared ahead of time. +# Prefixes for external modules to match types in docstrings. +# Docstub can't yet automatically discover where to import types from other +# packages from. Instead, you can provide this information explicitly. +# Any type in a docstring whose prefix matches the name given on the left side, +# will be associated with the given "module" on the right side. # -# Each item maps an annotation name on the left side to a dictionary on the -# right side. +# Examples: +# np = "numpy" +# Will match `np.uint8` and `np.typing.NDarray` and use "import numpy as np". # -# Import information can be declared with the following fields: -# from : Indicate that the DocType can be imported from this path. -# import : Import this object, defaults to the DocType. -# as : Use this alias for the imported object -# is_builtin : Indicate that this DocType doesn't need to be imported, -# defaults to "false" -[tool.docstub.known_imports] -Path = { from = "pathlib" } -Callable = { from = "collections.abc" } -np = { import = "numpy", as = "np" } -NDArray = { from = "numpy.typing" } -ArrayLike = { from = "numpy.typing" } +# plt = "matplotlib.pyplot +# Will match `plt.Figure` use `import matplotlib.pyplot as plt`. +[tool.docstub.type_prefixes] +np = "numpy" +numpy = "numpy" + -# Specify human-friendly aliases that can be used instead of Python-parsable -# annotations. -# TODO rename to qualname_alias or something -[tool.docstub.replace_doctypes] +# Specify human-friendly aliases that can be used in docstrings to describe +# valid Python types or annotations. +[tool.docstub.type_aliases] iterable = "Iterable" callable = "Callable" function = "Callable" @@ -34,14 +45,14 @@ func = "Callable" sequence = "Sequence" mapping = "Mapping" -numpy = "np" +# NumPy scalar = "np.ScalarType" integer = "np.integer" signedinteger = "np.signedinteger" -byte ="np.byte" -short = "np.short" +#byte ="np.byte" +#short = "np.short" intc = "np.intc" -int_ = "np.int_" +#int_ = "np.int_" longlong = "np.longlong" int8 = "np.int8" int16 = "np.int16" @@ -76,13 +87,13 @@ complex64 = "np.complex64" complex128 = "np.complex128" complex192 = "np.complex192" complex256 = "np.complex256" -bool_ = "np.bool_" +#bool_ = "np.bool_" datetime64 = "np.datetime64" timedelta64 = "np.timedelta64" -object_ = "np.object_" +#object_ = "np.object_" #flexible = "np.flexible" #character = "np.character" -bytes_ = "np.bytes_" +#bytes_ = "np.bytes_" #str_ = "np.str_" #void = "np.void" ndarray = "NDArray" diff --git a/tests/test_analysis.py b/tests/test_analysis.py index d55a8e7..91a6f09 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -2,7 +2,7 @@ import pytest -from docstub._analysis import KnownImport, TypeCollector, TypesDatabase +from docstub._analysis import KnownImport, TypeCollector, TypeMatcher @pytest.fixture @@ -91,10 +91,14 @@ def test_ignores_assigns(self, module_factory, src): assert len(imports) == 0 -class Test_TypesDatabase: - known_imports = { # noqa: RUF012 - "dict": KnownImport(builtin_name="dict"), +class Test_TypeMatcher: + type_prefixes = { # noqa: RUF012 "np": KnownImport(import_name="numpy", import_alias="np"), + "foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz"), + } + + types = { # noqa: RUF012 + "dict": KnownImport(builtin_name="dict"), "foo.bar": KnownImport(import_path="foo", import_name="bar"), "foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz"), "foo.bar.Baz.Bix": KnownImport(import_path="foo.bar", import_name="Baz"), @@ -103,12 +107,8 @@ class Test_TypesDatabase: # fmt: off @pytest.mark.parametrize( - ("name", "exp_annotation", "exp_import_line"), + ("search_name", "expected_name", "expected_origin"), [ - ("np", "np", "import numpy as np"), - # Finds imports whose import target matches the start of `name` - ("np.doesnt_exist", "np.doesnt_exist", "import numpy as np"), - ("foo.bar.Baz", "Baz", "from foo.bar import Baz"), # Finds "Baz" with abbreviated form as well ( "~.bar.Baz", "Baz", "from foo.bar import Baz"), @@ -120,10 +120,8 @@ class Test_TypesDatabase: ( "~.Baz.Bix", "Baz.Bix", "from foo.bar import Baz"), ( "~.Bix", "Baz.Bix", "from foo.bar import Baz"), - # Finds nested class "Baz.Gul" that's not explicitly defined, but - # whose import target matches "Baz" - ("foo.bar.Baz.Gul", "Baz.Gul", "from foo.bar import Baz"), - # but abbreviated form doesn't work + # Abbreviated form with not explicitly defined class "Baz.Gul" + # never matches ( "~.bar.Baz.Gul", None, None), ( "~.Baz.Gul", None, None), ( "~.Gul", None, None), @@ -135,16 +133,42 @@ class Test_TypesDatabase: ( "~.Qux", "bar.Baz.Qux", "from foo import bar"), ] ) - def test_query(self, name, exp_annotation, exp_import_line): - db = TypesDatabase(known_imports=self.known_imports.copy()) + def test_query_types(self, search_name, expected_name, expected_origin): + db = TypeMatcher(types=self.types.copy()) + + type_name, type_origin = db.match(search_name) + + if expected_name is None and expected_origin is None: + assert expected_name is type_name + assert expected_origin is type_origin + else: + assert str(type_origin) == expected_origin + assert type_name.startswith(type_origin.target) + assert type_name == expected_name + # fmt: on + + # fmt: off + @pytest.mark.parametrize( + ("search_name", "expected_name", "expected_origin"), + [ + ("np", "np", "import numpy as np"), + # Finds imports whose import target matches the start of `name` + ("np.doesnt_exist", "np.doesnt_exist", "import numpy as np"), + # Finds nested class "Baz.Gul" that's not explicitly defined, but + # whose import target matches "Baz" + ("foo.bar.Baz.Gul", "Baz.Gul", "from foo.bar import Baz"), + ] + ) + def test_query_prefix(self, search_name, expected_name, expected_origin): + db = TypeMatcher(prefixes=self.type_prefixes.copy()) - annotation, known_import = db.query(name) + type_name, type_origin = db.match(search_name) - if exp_annotation is None and exp_import_line is None: - assert exp_annotation is annotation - assert exp_import_line is known_import + if expected_name is None and expected_origin is None: + assert expected_name is type_name + assert expected_origin is type_origin else: - assert str(known_import) == exp_import_line - assert annotation.startswith(known_import.target) - assert annotation == exp_annotation + assert str(type_origin) == expected_origin + assert type_name.startswith(type_origin.target) + assert type_name == expected_name # fmt: on diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..1cb0222 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,9 @@ +from docstub._config import Config + + +class Test_Config: + def test_from_default(self): + config = Config.from_default() + assert len(config.types) > 0 + assert len(config.type_prefixes) > 0 + assert len(config.type_aliases) > 0