28 changes: 12 additions & 16 deletions examples/docstub.toml
@@ -1,20 +1,16 @@
[tool.docstub]

# TODO not implemented and used yet
extend_grammar = """

"""

# Import information for type annotations, declared ahead of time.
# Prefixes for external modules to match types in docstrings.
# Docstub can't yet automatically discover where types from other packages
# should be imported from. Instead, you can provide this information explicitly.
# Any type in a docstring whose prefix matches the name given on the left side
# will be associated with the given "module" on the right side.
#
# Each item maps an annotation name on the left side to a dictionary on the
# right side.
# Examples:
# np = "numpy"
# Will match `np.uint8` and `np.typing.NDArray` and use "import numpy as np".
#
# Import information can be declared with the following fields:
# from : Indicate that the DocType can be imported from this path.
# import : Import this object, defaults to the DocType.
# as : Use this alias for the imported object
# is_builtin : Indicate that this DocType doesn't need to be imported,
# defaults to "false"
[tool.docstub.known_imports]
configparser = {import = "configparser"}
# plt = "matplotlib.pyplot
# Will match `plt.Figure` use `import matplotlib.pyplot as plt`.
[tool.docstub.type_prefixes]
configparser = "configparser"
2 changes: 1 addition & 1 deletion examples/example_pkg/_basic.py
@@ -58,7 +58,7 @@ def func_use_from_elsewhere(a1, a2, a3, a4):
----------
a1 : example_pkg.CustomException
a2 : ExampleClass
a3 : example_pkg.CustomException.NestedClass
a3 : example_pkg._basic.ExampleClass.NestedClass
a4 : ExampleClass.NestedClass

Returns
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -105,10 +105,10 @@ testpaths = [
[tool.coverage]
run.source = ["docstub"]

[tool.docstub.known_imports]
cst = {import = "libcst", as="cst"}
lark = {import = "lark"}
numpydoc = {import = "numpydoc"}
[tool.docstub.type_prefixes]
cst = "libcst"
lark = "lark"
numpydoc = "numpydoc"


[tool.mypy]
129 changes: 62 additions & 67 deletions src/docstub/_analysis.py
@@ -410,69 +410,63 @@ def _collect_type_annotation(self, stack):
self.known_imports[qualname] = known_import


class TypesDatabase:
"""A static database of collected types usable as an annotation.
class TypeMatcher:
"""Match strings to collected type information.

Attributes
----------
current_source : Path | None
source_pkgs : list[Path]
known_imports : dict[str, KnownImport]
stats : dict[str, Any]
types : dict[str, KnownImport]
prefixes : dict[str, KnownImport]
aliases : dict[str, str]
successful_queries : int
unknown_qualnames : list
current_module : Path | None

Examples
--------
>>> from docstub._analysis import TypesDatabase, common_known_imports
>>> db = TypesDatabase(known_imports=common_known_imports())
>>> db.query("Any")
>>> from docstub._analysis import TypeMatcher, common_known_imports
>>> db = TypeMatcher()
>>> db.match("Any")
('Any', <KnownImport 'from typing import Any'>)
"""

def __init__(
self,
*,
source_pkgs=None,
known_imports=None,
types=None,
prefixes=None,
aliases=None,
):
"""
Parameters
----------
source_pkgs : list[Path], optional
known_imports : dict[str, KnownImport], optional
If not provided, defaults to imports returned by
:func:`common_known_imports`.
types : dict[str, KnownImport]
prefixes : dict[str, KnownImport]
aliases : dict[str, str]
"""
if source_pkgs is None:
source_pkgs = []
if known_imports is None:
known_imports = common_known_imports()

self.current_source = None
self.source_pkgs = source_pkgs
self.types = types or common_known_imports()
self.prefixes = prefixes or {}
self.aliases = aliases or {}
self.successful_queries = 0
self.unknown_qualnames = []

self.known_imports = known_imports
self.current_module = None

self.stats = {
"successful_queries": 0,
"unknown_doctypes": [],
}

def query(self, search_name):
def match(self, search_name):
"""Search for a known annotation name.

Parameters
----------
search_name : str
current_module : Path, optional

Returns
-------
annotation_name : str | None
If it was found, the name of the annotation that matches the `known_import`.
known_import : KnownImport | None
If it was found, import information matching the `annotation_name`.
type_name : str | None
type_origin : KnownImport | None
"""
annotation_name = None
known_import = None
type_name = None
type_origin = None

if search_name.startswith("~."):
# Sphinx like matching with abbreviated name
@@ -481,63 +475,64 @@ def query(self, search_name):
regex = re.compile(pattern + "$")
# Might be slow, but works for now
matches = {
key: value
for key, value in self.known_imports.items()
if regex.match(key)
key: value for key, value in self.types.items() if regex.match(key)
}
if len(matches) > 1:
shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0]
known_import = matches[shortest_key]
annotation_name = shortest_key
type_origin = matches[shortest_key]
type_name = shortest_key
logger.warning(
"%r in %s matches multiple types %r, using %r",
search_name,
self.current_source,
self.current_module or "<file not known>",
matches.keys(),
shortest_key,
)
elif len(matches) == 1:
annotation_name, known_import = matches.popitem()
type_name, type_origin = matches.popitem()
else:
search_name = search_name[2:]
logger.debug(
"couldn't match %r in %s", search_name, self.current_source
"couldn't match %r in %s",
search_name,
self.current_module or "<file not known>",
)

if known_import is None and self.current_source:
# Replace alias
search_name = self.aliases.get(search_name, search_name)

if type_origin is None and self.current_module:
# Try scope of current module
module_name = module_name_from_path(self.current_source)
module_name = module_name_from_path(self.current_module)
try_qualname = f"{module_name}.{search_name}"
known_import = self.known_imports.get(try_qualname)
if known_import:
annotation_name = search_name
type_origin = self.types.get(try_qualname)
if type_origin:
type_name = search_name

if type_origin is None and search_name in self.types:
type_name = search_name
type_origin = self.types[search_name]

if known_import is None:
if type_origin is None:
# Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
for partial_qualname in reversed(accumulate_qualname(search_name)):
known_import = self.known_imports.get(partial_qualname)
if known_import:
annotation_name = search_name
type_origin = self.prefixes.get(partial_qualname)
if type_origin:
type_name = search_name
break

if (
known_import is not None
and annotation_name is not None
and annotation_name != known_import.target
and not annotation_name.startswith(known_import.target)
type_origin is not None
and type_name is not None
and type_name != type_origin.target
and not type_name.startswith(type_origin.target)
):
# Ensure that the annotation matches the import target
annotation_name = annotation_name[
annotation_name.find(known_import.target) :
]
type_name = type_name[type_name.find(type_origin.target) :]

if annotation_name is not None:
self.stats["successful_queries"] += 1
if type_name is not None:
self.successful_queries += 1
else:
self.stats["unknown_doctypes"].append(search_name)
self.unknown_qualnames.append(search_name)

return annotation_name, known_import

def __repr__(self) -> str:
repr = f"{type(self).__name__}({self.source_pkgs})"
return repr
return type_name, type_origin
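Aside (not part of this diff): the renamed TypeMatcher is used roughly as in the docstring above. The snippet below is a usage sketch assembled only from signatures visible in this PR (the TypeMatcher and KnownImport keyword arguments are taken from the _cli.py hunk below); the concrete prefix and alias values are illustrative assumptions, not documented defaults.

# Usage sketch, not authoritative API documentation.
from docstub._analysis import KnownImport, TypeMatcher

matcher = TypeMatcher(
    # With no `types` given, the matcher falls back to common_known_imports().
    # A "np." prefix on a docstring type maps to "import numpy as np".
    prefixes={"np": KnownImport(import_name="numpy", import_alias="np")},
    # Aliases rewrite a docstring spelling before matching (illustrative value).
    aliases={"array-like": "numpy.typing.ArrayLike"},
)

type_name, type_origin = matcher.match("Any")
print(type_name, type_origin)  # e.g. ('Any', <KnownImport 'from typing import Any'>)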
56 changes: 32 additions & 24 deletions src/docstub/_cli.py
@@ -10,7 +10,7 @@
from ._analysis import (
KnownImport,
TypeCollector,
TypesDatabase,
TypeMatcher,
common_known_imports,
)
from ._cache import FileCache
@@ -76,19 +76,18 @@ def _setup_logging(*, verbose):
)


def _build_import_map(config, root_path):
"""Build a map of known imports.
def _collect_types(root_path):
"""Collect types.

Parameters
----------
config : ~.Config
root_path : Path

Returns
-------
imports : dict[str, ~.KnownImport]
types : dict[str, ~.KnownImport]
"""
known_imports = common_known_imports()
types = common_known_imports()

collect_cached_types = FileCache(
func=TypeCollector.collect,
@@ -99,12 +98,10 @@ def _build_import_map(config, root_path):
if root_path.is_dir():
for source_path in walk_python_package(root_path):
logger.info("collecting types in %s", source_path)
known_imports_in_source = collect_cached_types(source_path)
known_imports.update(known_imports_in_source)

known_imports.update(KnownImport.many_from_config(config.known_imports))
types_in_source = collect_cached_types(source_path)
types.update(types_in_source)

return known_imports
return types


@contextmanager
@@ -195,15 +192,26 @@ def main(root_path, out_dir, config_path, group_errors, allow_errors, verbose):
)

config = _load_configuration(config_path)
known_imports = _build_import_map(config, root_path)

types = common_known_imports()
types |= _collect_types(root_path)
types |= {
type_name: KnownImport(import_path=module, import_name=type_name)
for type_name, module in config.types.items()
}

prefixes = {
prefix: (
KnownImport(import_name=module, import_alias=prefix)
if module != prefix
else KnownImport(import_name=prefix)
)
for prefix, module in config.type_prefixes.items()
}

reporter = GroupedErrorReporter() if group_errors else ErrorReporter()
types_db = TypesDatabase(
source_pkgs=[root_path.parent.resolve()], known_imports=known_imports
)
stub_transformer = Py2StubTransformer(
types_db=types_db, replace_doctypes=config.replace_doctypes, reporter=reporter
)
matcher = TypeMatcher(types=types, prefixes=prefixes, aliases=config.type_aliases)
stub_transformer = Py2StubTransformer(matcher=matcher, reporter=reporter)

if not out_dir:
if root_path.is_file():
@@ -246,22 +254,22 @@ def main(root_path, out_dir, config_path, group_errors, allow_errors, verbose):
reporter.print_grouped()

# Report basic statistics
successful_queries = types_db.stats["successful_queries"]
successful_queries = matcher.successful_queries
click.secho(f"{successful_queries} matched annotations", fg="green")

syntax_error_count = stub_transformer.transformer.stats["syntax_errors"]
if syntax_error_count:
click.secho(f"{syntax_error_count} syntax errors", fg="red")

unknown_doctypes = types_db.stats["unknown_doctypes"]
if unknown_doctypes:
click.secho(f"{len(unknown_doctypes)} unknown doctypes", fg="red")
counter = Counter(unknown_doctypes)
unknown_qualnames = matcher.unknown_qualnames
if unknown_qualnames:
click.secho(f"{len(unknown_qualnames)} unknown type names", fg="red")
counter = Counter(unknown_qualnames)
sorted_item_counts = sorted(counter.items(), key=lambda x: x[1], reverse=True)
for item, count in sorted_item_counts:
click.echo(f" {item} (x{count})")

total_errors = len(unknown_doctypes) + syntax_error_count
total_errors = len(unknown_qualnames) + syntax_error_count
total_msg = f"{total_errors} total errors"
if allow_errors:
total_msg = f"{total_msg} (allowed {allow_errors})"