diff --git a/.github/scripts/assert-unchanged.sh b/.github/scripts/assert-unchanged.sh index 4a2168d..d14e20e 100755 --- a/.github/scripts/assert-unchanged.sh +++ b/.github/scripts/assert-unchanged.sh @@ -15,8 +15,10 @@ UNTRACKED=$(git ls-files --others --exclude-standard "$CHECK_DIR") echo "$UNTRACKED" | xargs -I _ git --no-pager diff /dev/null _ || true # Display changes in tracked files and capture non-zero exit code if so -git diff --exit-code HEAD "$CHECK_DIR" || true +set +e +git diff --exit-code HEAD "$CHECK_DIR" GIT_DIFF_HEAD_EXIT_CODE=$? +set -e # Display changes in tracked files and capture exit status if [ $GIT_DIFF_HEAD_EXIT_CODE -ne 0 ] || [ -n "$UNTRACKED" ]; then diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4ebec38..895d16d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -89,7 +89,7 @@ jobs: examples/example_pkg .github/scripts/assert-unchanged.sh examples/ - - name: Check docstub-stubs + - name: Check docstub-stubs (single process) # Check that stubs for docstub are up-to-date by regenerating them # with docstub and looking for differences. run: | @@ -97,6 +97,13 @@ jobs: python -m docstub run -v src/docstub -o src/docstub-stubs .github/scripts/assert-unchanged.sh src/docstub-stubs/ + - name: Check docstub-stubs (multiprocess) + # Repeat test with multiprocessing enabled + run: | + rm -rf src/docstub-stubs + python -m docstub run -v src/docstub -o src/docstub-stubs --workers 2 + .github/scripts/assert-unchanged.sh src/docstub-stubs/ + - name: Check with mypy.stubtest run: | python -m mypy.stubtest \ diff --git a/REMINDERS.md b/REMINDERS.md new file mode 100644 index 0000000..b904406 --- /dev/null +++ b/REMINDERS.md @@ -0,0 +1,10 @@ +# Reminders + +## With Python >=3.13 + +Remove vendored `glob_translate` in `docstub._vendored.stdlib`. + + +## With Python >=3.14 + +Remove vendored `ProcessPoolExecutor` in `docstub._vendored.stdlib`. diff --git a/docs/command_line.md b/docs/command_line.md index 1e31c34..0b49668 100644 --- a/docs/command_line.md +++ b/docs/command_line.md @@ -49,10 +49,6 @@ Options: Set output directory explicitly. Stubs will be directly written into that directory while preserving the directory structure under PACKAGE_PATH. Otherwise, stubs are generated inplace. - --config PATH - Set one or more configuration file(s) explicitly. Otherwise, it will - look for a `pyproject.toml` or `docstub.toml` in the current - directory. --ignore GLOB Ignore files matching this glob-style pattern. Can be used multiple times. @@ -67,8 +63,15 @@ Options: -W, --fail-on-warning Return non-zero exit code when a warning is raised. Will add to --allow-errors. + --workers INT + Experimental: Process files in parallel with the desired number of + workers. By default, no multiprocessing is used. [default: 1] --no-cache Ignore pre-existing cache and don't create a new one. + --config PATH + Set one or more configuration file(s) explicitly. Otherwise, it will + look for a `pyproject.toml` or `docstub.toml` in the current + directory. -v, --verbose Print more details. Use once to show information messages. Use -vv to print debug messages. diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 8fb3e9b..500d3d6 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -17,7 +17,7 @@ __all__ = [ "func_empty", ] -def func_empty(a1: Incomplete, a2: Incomplete, a3) -> None: ... +def func_empty(a1: Incomplete, a2: Incomplete, a3: Incomplete) -> None: ... 
def func_contains( a1: list[float], a2: dict[str, Union[int, str]], diff --git a/pyproject.toml b/pyproject.toml index f5ea050..243d8c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,9 @@ xfail_strict = true filterwarnings = ["error"] log_cli_level = "info" testpaths = ["src", "tests"] +markers = [ + "slow: marks tests as slow (deselect with `-m 'not slow'`)", +] [tool.coverage] diff --git a/src/docstub-stubs/_cli.pyi b/src/docstub-stubs/_cli.pyi index cd9690a..ed6077b 100644 --- a/src/docstub-stubs/_cli.pyi +++ b/src/docstub-stubs/_cli.pyi @@ -16,6 +16,7 @@ from _typeshed import Incomplete from ._analysis import PyImport, TypeCollector, TypeMatcher, common_known_types from ._cache import CACHE_DIR_NAME, FileCache, validate_cache from ._cli_help import HelpFormatter +from ._concurrency import LoggingProcessExecutor, guess_concurrency_params from ._config import Config from ._path_utils import ( STUB_HEADER_COMMENT, @@ -25,6 +26,7 @@ from ._path_utils import ( ) from ._report import setup_logging from ._stubs import Py2StubTransformer, try_format_stub +from ._utils import update_with_add_values from ._version import __version__ logger: logging.Logger @@ -45,6 +47,9 @@ click.Context.formatter_class = HelpFormatter @click.group() def cli() -> None: ... def _add_verbosity_options(func: Callable) -> Callable: ... +def _transform_to_stub( + task: tuple[Path, Path, Py2StubTransformer], +) -> dict[str, int | list[str]]: ... @cli.command() def run( *, @@ -55,6 +60,7 @@ def run( group_errors: bool, allow_errors: int, fail_on_warning: bool, + desired_worker_count: int, no_cache: bool, verbose: int, quiet: int, diff --git a/src/docstub-stubs/_concurrency.pyi b/src/docstub-stubs/_concurrency.pyi new file mode 100644 index 0000000..572b649 --- /dev/null +++ b/src/docstub-stubs/_concurrency.pyi @@ -0,0 +1,48 @@ +# File generated with docstub + +import logging +import logging.handlers +import math +import multiprocessing +import os +from collections.abc import Callable, Iterator +from concurrent.futures import Executor +from dataclasses import dataclass +from types import TracebackType +from typing import Any + +from ._vendored.stdlib import ProcessPoolExecutor + +logger: logging.Logger + +class MockPoolExecutor(Executor): + def map[T]( + self, fn: Callable[..., T], *iterables: Any, **__: Any + ) -> Iterator[T]: ... + +@dataclass(kw_only=True) +class LoggingProcessExecutor: + + max_workers: int | None = ... + logging_handlers: tuple[logging.Handler, ...] = ... + initializer: Callable | None = ... + initargs: tuple | None = ... + + @staticmethod + def _initialize_worker( + queue: multiprocessing.Queue, + worker_log_level: int, + initializer: Callable, + initargs: tuple[Any], + ) -> None: ... + def __enter__(self) -> ProcessPoolExecutor | MockPoolExecutor: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType, + ) -> bool: ... + +def guess_concurrency_params( + *, task_count: int, desired_worker_count: int | None = ... +) -> tuple[int, int]: ... diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index 7d2a29e..1515c66 100644 --- a/src/docstub-stubs/_report.pyi +++ b/src/docstub-stubs/_report.pyi @@ -47,6 +47,9 @@ class ContextReporter: def error( self, short: str, *args: Any, details: str | None = ..., **log_kw: Any ) -> None: ... + def critical( + self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + ) -> None: ... def __post_init__(self) -> None: ... 
@staticmethod def underline(line: str, *, char: str = ...) -> str: ... @@ -62,9 +65,17 @@ class ReportHandler(logging.StreamHandler): self, stream: TextIO | None = ..., group_errors: bool = ... ) -> None: ... def format(self, record: logging.LogRecord) -> str: ... - def emit(self, record: logging.LogRecord) -> None: ... + def handle(self, record: logging.LogRecord) -> bool: ... def emit_grouped(self) -> None: ... +class LogCounter(logging.NullHandler): + critical_count: int + error_count: int + warning_count: int + + def __init__(self) -> None: ... + def handle(self, record: logging.LogRecord) -> bool: ... + def setup_logging( *, verbosity: Literal[-2, -1, 0, 1, 2, 3], group_errors: bool -) -> ReportHandler: ... +) -> tuple[ReportHandler, LogCounter]: ... diff --git a/src/docstub-stubs/_stubs.pyi b/src/docstub-stubs/_stubs.pyi index a6521f6..87c2554 100644 --- a/src/docstub-stubs/_stubs.pyi +++ b/src/docstub-stubs/_stubs.pyi @@ -19,7 +19,7 @@ from ._docstrings import ( FallbackAnnotation, ) from ._report import ContextReporter -from ._utils import module_name_from_path +from ._utils import module_name_from_path, update_with_add_values logger: logging.Logger @@ -73,6 +73,9 @@ class Py2StubTransformer(cst.CSTTransformer): @property def is_inside_function_def(self) -> bool: ... def python_to_stub(self, source: str, *, module_path: Path | None = ...) -> str: ... + def collect_stats( + self, *, reset_after: bool = ... + ) -> dict[str, int | list[str]]: ... def visit_ClassDef(self, node: cst.ClassDef) -> Literal[True]: ... def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef diff --git a/src/docstub-stubs/_utils.pyi b/src/docstub-stubs/_utils.pyi index 8b8b4bd..c29d6a8 100644 --- a/src/docstub-stubs/_utils.pyi +++ b/src/docstub-stubs/_utils.pyi @@ -2,7 +2,7 @@ import itertools import re -from collections.abc import Callable +from collections.abc import Callable, Hashable, Mapping, Sequence from functools import lru_cache, wraps from pathlib import Path from zlib import crc32 @@ -12,6 +12,9 @@ def escape_qualname(name: str) -> str: ... def _resolve_path_before_caching(func: Callable) -> Callable: ... def module_name_from_path(path: Path) -> str: ... def pyfile_checksum(path: Path) -> str: ... +def update_with_add_values( + *mappings: Mapping[Hashable, int | Sequence], out: dict | None = ... +) -> dict: ... class DocstubError(Exception): pass diff --git a/src/docstub-stubs/_vendored/stdlib.pyi b/src/docstub-stubs/_vendored/stdlib.pyi index b4ff650..d07831d 100644 --- a/src/docstub-stubs/_vendored/stdlib.pyi +++ b/src/docstub-stubs/_vendored/stdlib.pyi @@ -3,6 +3,7 @@ import os import re from collections.abc import Sequence +from concurrent.futures import ProcessPoolExecutor as _ProcessPoolExecutor def _fnmatch_translate(pat: str, STAR: str, QUESTION_MARK: str) -> str: ... def glob_translate( @@ -12,3 +13,17 @@ def glob_translate( include_hidden: bool = ..., seps: Sequence[str] | None = ..., ) -> str: ... + +if not hasattr(_ProcessPoolExecutor, "terminate_workers"): + _TERMINATE: str + _KILL: str + + _SHUTDOWN_CALLBACK_OPERATION: set[str] + + class ProcessPoolExecutor(_ProcessPoolExecutor): + def _force_shutdown(self, operation: str) -> None: ... + def terminate_workers(self) -> None: ... + def kill_workers(self) -> None: ... 
+ +else: + ProcessPoolExecutor: _ProcessPoolExecutor # type: ignore[no-redef] diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index d5bffe4..32a6daf 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -500,12 +500,14 @@ def __init__( type_prefixes : dict[str, PyImport] type_nicknames : dict[str, str] """ - self.types = common_known_types() | (types or {}) self.type_prefixes = type_prefixes or {} self.type_nicknames = type_nicknames or {} - self.successful_queries = 0 - self.unknown_qualnames = [] + + self.stats = { + "matched_type_names": 0, + "unknown_type_names": [], + } self.current_file = None @@ -621,8 +623,8 @@ def match(self, search): type_name = type_name[type_name.find(py_import.target) :] if type_name is not None: - self.successful_queries += 1 + self.stats["matched_type_names"] += 1 else: - self.unknown_qualnames.append(search) + self.stats["unknown_type_names"].append(search) return type_name, py_import diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 0c64e3a..c741f8d 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -16,6 +16,7 @@ ) from ._cache import CACHE_DIR_NAME, FileCache, validate_cache from ._cli_help import HelpFormatter +from ._concurrency import LoggingProcessExecutor, guess_concurrency_params from ._config import Config from ._path_utils import ( STUB_HEADER_COMMENT, @@ -25,6 +26,7 @@ ) from ._report import setup_logging from ._stubs import Py2StubTransformer, try_format_stub +from ._utils import update_with_add_values from ._version import __version__ logger: logging.Logger = logging.getLogger(__name__) @@ -197,7 +199,7 @@ def log_execution_time(): try: yield except KeyboardInterrupt: - logger.critical("Interrupt!") + logger.critical("Interrupted!") finally: stop = time.time() total_seconds = stop - start @@ -255,6 +257,50 @@ def _add_verbosity_options(func): return func +def _transform_to_stub(task): + """Transform a Python file into a stub file. + + Parameters + ---------- + task : tuple[Path, Path, Py2StubTransformer] + The `source_path` for which to create a stub file at `stub_path` with + the given transformer. + + Returns + ------- + stats : dict of {str: int or list[str]} + Statistics about the transformation. + """ + source_path, stub_path, stub_transformer = task + + if source_path.suffix.lower() == ".pyi": + logger.debug("Using existing stub file %s", source_path) + with source_path.open() as fo: + stub_content = fo.read() + else: + with source_path.open() as fo: + py_content = fo.read() + logger.debug("Transforming %s", source_path) + try: + stub_content = stub_transformer.python_to_stub( + py_content, module_path=source_path + ) + stub_content = f"{STUB_HEADER_COMMENT}\n\n{stub_content}" + stub_content = try_format_stub(stub_content) + except Exception: + logger.exception("Failed creating stub for %s", source_path) + return None + + stub_path.parent.mkdir(parents=True, exist_ok=True) + with stub_path.open("w") as fo: + logger.info("Wrote %s", stub_path) + fo.write(stub_content) + + stats = stub_transformer.collect_stats() + + return stats + + # Preserve click.command below to keep type checker happy # docstub: off @cli.command() @@ -270,16 +316,6 @@ def _add_verbosity_options(func): "structure under PACKAGE_PATH. " "Otherwise, stubs are generated inplace.", ) -@click.option( - "--config", - "config_paths", - type=click.Path(exists=True, dir_okay=False), - metavar="PATH", - multiple=True, - help="Set one or more configuration file(s) explicitly. 
" - "Otherwise, it will look for a `pyproject.toml` or `docstub.toml` in the " - "current directory.", -) @click.option( "--ignore", type=str, @@ -312,11 +348,31 @@ def _add_verbosity_options(func): help="Return non-zero exit code when a warning is raised. " "Will add to --allow-errors.", ) +@click.option( + "--workers", + "desired_worker_count", + type=int, + default=1, + metavar="INT", + help="Experimental: Process files in parallel with the desired number of workers. " + "By default, no multiprocessing is used.", + show_default=True, +) @click.option( "--no-cache", is_flag=True, help="Ignore pre-existing cache and don't create a new one.", ) +@click.option( + "--config", + "config_paths", + type=click.Path(exists=True, dir_okay=False), + metavar="PATH", + multiple=True, + help="Set one or more configuration file(s) explicitly. " + "Otherwise, it will look for a `pyproject.toml` or `docstub.toml` in the " + "current directory.", +) @_add_verbosity_options @click.help_option("-h", "--help") @log_execution_time() @@ -329,6 +385,7 @@ def run( group_errors, allow_errors, fail_on_warning, + desired_worker_count, no_cache, verbose, quiet, @@ -349,6 +406,7 @@ def run( group_errors : bool allow_errors : int fail_on_warning : bool + desired_worker_count : int no_cache : bool verbose : int quiet : int @@ -357,7 +415,9 @@ def run( # Setup ------------------------------------------------------------------- verbosity = _calc_verbosity(verbose=verbose, quiet=quiet) - error_handler = setup_logging(verbosity=verbosity, group_errors=group_errors) + output_handler, error_counter = setup_logging( + verbosity=verbosity, group_errors=group_errors + ) root_path = Path(root_path) if root_path.is_file(): @@ -409,30 +469,32 @@ def run( # Stub generation --------------------------------------------------------- - for source_path, stub_path in walk_source_and_targets( - root_path, out_dir, ignore=config.ignore_files - ): - if source_path.suffix.lower() == ".pyi": - logger.debug("Using existing stub file %s", source_path) - with source_path.open() as fo: - stub_content = fo.read() - else: - with source_path.open() as fo: - py_content = fo.read() - logger.debug("Transforming %s", source_path) - try: - stub_content = stub_transformer.python_to_stub( - py_content, module_path=source_path - ) - stub_content = f"{STUB_HEADER_COMMENT}\n\n{stub_content}" - stub_content = try_format_stub(stub_content) - except Exception: - logger.exception("Failed creating stub for %s", source_path) - continue - stub_path.parent.mkdir(parents=True, exist_ok=True) - with stub_path.open("w") as fo: - logger.info("Wrote %s", stub_path) - fo.write(stub_content) + task_files = walk_source_and_targets(root_path, out_dir, ignore=config.ignore_files) + + # We must pass the `stub_transformer` to each worker, but we want to copy + # only once per worker. Testing suggests, that using a large enough + # `chunksize` of `>= len(task_count) / jobs` for `ProcessPoolExecutor.map`, + # ensures that. + # Using an `initializer` that assigns the transformer as a global variable + # per worker seems like the more robust solution, but naive timing suggests + # it's actually slower (> 1s on skimage). 
+ task_args = [(*files, stub_transformer) for files in task_files] + task_count = len(task_args) + + worker_count, chunk_size = guess_concurrency_params( + task_count=task_count, desired_worker_count=desired_worker_count + ) + + logger.info("Using %i worker(s) to write %i stubs", worker_count, task_count) + logger.debug("Using chunk size of %i", chunk_size) + with LoggingProcessExecutor( + max_workers=worker_count, + logging_handlers=(output_handler, error_counter), + ) as executor: + stats_per_task = executor.map( + _transform_to_stub, task_args, chunksize=chunk_size + ) + stats = update_with_add_values(*stats_per_task) py_typed_out = out_dir / "py.typed" if not py_typed_out.exists(): @@ -442,30 +504,28 @@ def run( # Reporting -------------------------------------------------------------- if group_errors: - error_handler.emit_grouped() - assert error_handler.group_errors is True - error_handler.group_errors = False + output_handler.emit_grouped() + assert output_handler.group_errors is True + output_handler.group_errors = False # Report basic statistics - successful_queries = matcher.successful_queries - transformed_doctypes = stub_transformer.transformer.stats["transformed"] - syntax_error_count = stub_transformer.transformer.stats["syntax_errors"] - unknown_type_names = matcher.unknown_qualnames - total_warnings = error_handler.warning_count - total_errors = error_handler.error_count - - logger.info("Recognized type names: %i", successful_queries) - logger.info("Transformed doctypes: %i", transformed_doctypes) + total_warnings = error_counter.warning_count + total_errors = error_counter.error_count + + logger.info("Recognized type names: %i", stats["matched_type_names"]) + logger.info("Transformed doctypes: %i", stats["transformed_doctypes"]) if total_warnings: logger.warning("Warnings: %i", total_warnings) - if syntax_error_count: - logger.warning("Syntax errors: %i", syntax_error_count) - if unknown_type_names: + if stats["doctype_syntax_errors"]: + assert total_errors + logger.warning("Syntax errors: %i", stats["doctype_syntax_errors"]) + if stats["unknown_type_names"]: + assert total_errors logger.warning( "Unknown type names: %i (locations: %i)", - len(set(unknown_type_names)), - len(unknown_type_names), - extra={"details": _format_unknown_names(unknown_type_names)}, + len(set(stats["unknown_type_names"])), + len(stats["unknown_type_names"]), + extra={"details": _format_unknown_names(stats["unknown_type_names"])}, ) if total_errors: logger.error("Total errors: %i", total_errors) diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py new file mode 100644 index 0000000..5d20ca3 --- /dev/null +++ b/src/docstub/_concurrency.py @@ -0,0 +1,228 @@ +"""Tools for parallel processing.""" + +import logging +import logging.handlers +import math +import multiprocessing +import os +from collections.abc import Callable, Iterator +from concurrent.futures import Executor +from dataclasses import dataclass + +from ._vendored.stdlib import ProcessPoolExecutor + +logger: logging.Logger = logging.getLogger(__name__) + + +class MockPoolExecutor(Executor): + """Mock executor that does not spawn a thread, interpreter, or process. + + Only implements the used part of the API defined by + :class:`concurrent.futures.Executor`. + """ + + def map[T](self, fn: Callable[..., T], *iterables, **__) -> Iterator[T]: + """Returns an iterator equivalent to map(fn, iter). + + Same behavior as :ref:`concurrent.futures.Executor.map` though any + parameters besides `fn` and `iterables` are ignored. 
+ + Parameters + ---------- + *iterables : Any + **__ : Any + """ + return map(fn, *iterables) + + +@dataclass(kw_only=True) +class LoggingProcessExecutor: + """Wrapper around `ProcessPoolExecutor` that forwards logging from workers. + + Parameters + ---------- + max_workers, initializer, initargs: + Refer to the documentation of :class:`concurrent.futures.ProcessPoolExecutor`. + logging_handlers: + Handlers to which logging records of the worker processes will be + forwarded. Worker processes will use the minimal log level of + the given handlers. + + Examples + -------- + >>> with LoggingProcessExecutor() as pool: # doctest: +SKIP + ... # use `pool.submit` or `pool.map` ... + """ + + max_workers: int | None = None + logging_handlers: tuple[logging.Handler, ...] = () + initializer: Callable | None = None + initargs: tuple | None = () + + @staticmethod + def _initialize_worker(queue, worker_log_level, initializer, initargs): + """Initialize logging in workers. + + Parameters + ---------- + queue : multiprocessing.Queue + worker_log_level : int + initializer : Callable + initargs : tuple of Any + """ + queue_handler = logging.handlers.QueueHandler(queue) + queue_handler.setLevel(worker_log_level) + + # Could buffering with MemoryHandler improve performance here? + # memory_handler = logging.handlers.MemoryHandler( + # capacity=100, flushLevel=logging.CRITICAL, target=queue_handler + # ) + + root_logger = logging.getLogger() + root_logger.addHandler(queue_handler) + root_logger.setLevel(worker_log_level) + if initializer: + initializer(*initargs) + logger.debug("Initialized worker") + + def __enter__(self) -> ProcessPoolExecutor | MockPoolExecutor: + if self.max_workers == 1: + logger.debug("Not using concurrency (workers=%i)", self.max_workers) + return MockPoolExecutor() + + # This sets the logging level of worker processes. Use the minimal level + # of all handlers here, so that appropriate records are passed on + worker_log_level = min(*[h.level for h in self.logging_handlers]) + + # Sets the method by which the worker processes are created; anything besides + # "spawn" is apparently "broken" on Windows & macOS + mp_context = multiprocessing.get_context("spawn") + + # A queue, used to pass logging records from worker processes to the + # current and main one. Naive testing suggests that + # `multiprocessing.Queue` is faster than `multiprocessing.Manager.Queue` + self._queue = mp_context.Queue() + + # The actual pool manager that is wrapped here and returned by this + # context manager + self._pool = ProcessPoolExecutor( + max_workers=self.max_workers, + mp_context=mp_context, + initializer=self._initialize_worker, + initargs=( + self._queue, + worker_log_level, + self.initializer, + self.initargs, + ), + ) + + # Forwards logging records from the queue to the given logging handlers + self._listener = logging.handlers.QueueListener( + self._queue, *self.logging_handlers, respect_handler_level=True + ) + logger.debug("Starting queue listener") + self._listener.start() + + return self._pool.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Parameters + ---------- + exc_type : type[BaseException] or None + exc_val : BaseException or None + exc_tb : TracebackType + + Returns + ------- + suppress_exception : bool + """ + if self.max_workers == 1: + return False + + if exc_type and issubclass(exc_type, KeyboardInterrupt): + # We want to exit immediately when the user interrupts, even if we lose + # log records that way.
+ logger.debug("Terminating workers") + self._pool.terminate_workers() + + # Ensure that the queue doesn't block (not sure if necessary) + logger.debug("Calling `cancel_join_thread` on queue") + self._queue.cancel_join_thread() + + else: + # We want to wait for any log record to reach the listener + logger.debug("Shutting down pool") + self._pool.shutdown(wait=True, cancel_futures=True) + + logger.debug("Stopping queue listener") + self._listener.stop() + + if not self._queue.empty(): + logger.error("Expected logging queue to be empty, it is not!") + + logger.debug("Closing queue and joining its thread") + self._queue.close() + self._queue.join_thread() + + self._queue = None + self._pool = None + self._listener = None + + logger.debug("Exiting executor pool") + return False + + +def guess_concurrency_params(*, task_count, desired_worker_count=None): + """Estimate how many workers to use and how to distribute tasks among them. + + Parameters + ---------- + task_count : int + The number of tasks that need to be processed. + desired_worker_count : int, optional + If not set, the number of workers is estimated. Set this explicitly + to force a number of workers. Passing `-1` will also trigger estimation. + + Returns + ------- + worker_count : int + The number of workers that should be used. + chunk_size : int + The chunk size that should be used to split the tasks among the workers. + + Examples + -------- + >>> worker_count, chunk_size = guess_concurrency_params( + ... task_count=9, desired_worker_count=None + ... ) + >>> (worker_count, chunk_size) + (1, 9) + """ + worker_count = desired_worker_count + + # `process_cpu_count` is only available in Python 3.13 and later + cpu_count = getattr(os, "process_cpu_count", os.cpu_count)() + + if worker_count is None or worker_count == -1: + # These crude heuristics were only ever "measured" on one computer + worker_count = cpu_count + # Clip to `worker_count <= task_count // 3` + worker_count = min(worker_count, task_count // 3) + # For a low number of files it may not be worth spinning up any workers + if task_count < 10: + worker_count = 1 + + # Clip to [1, cpu_count] + worker_count = max(1, min(cpu_count, worker_count)) + + # Chunking prevents unnecessary pickling of objects that are shared between + # tasks more than once.
When `worker_count * chunk_size` is slightly + # larger than `task_count`, each worker process only ever receives one chunk + chunk_size = task_count / worker_count + chunk_size = math.ceil(chunk_size) + + assert isinstance(worker_count, int) + assert isinstance(chunk_size, int) + return worker_count, chunk_size diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index d4a11a4..245c3f0 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -273,8 +273,8 @@ def __init__(self, *, matcher=None, **kwargs): super().__init__(**kwargs) self.stats = { - "syntax_errors": 0, - "transformed": 0, + "doctype_syntax_errors": 0, + "transformed_doctypes": 0, } def doctype_to_annotation(self, doctype, *, reporter=None): @@ -303,14 +303,14 @@ def doctype_to_annotation(self, doctype, *, reporter=None): annotation = Annotation( value=value, imports=frozenset(self._collected_imports) ) - self.stats["transformed"] += 1 + self.stats["transformed_doctypes"] += 1 return annotation, self._unknown_qualnames except ( lark.exceptions.LexError, lark.exceptions.ParseError, QualnameIsKeyword, ): - self.stats["syntax_errors"] += 1 + self.stats["doctype_syntax_errors"] += 1 raise finally: self._reporter = None diff --git a/src/docstub/_report.py b/src/docstub/_report.py index 9d3239e..86957c2 100644 --- a/src/docstub/_report.py +++ b/src/docstub/_report.py @@ -174,6 +174,23 @@ def error(self, short, *args, details=None, **log_kw): short, *args, log_level=logging.ERROR, details=details, **log_kw ) + def critical(self, short, *args, details=None, **log_kw): + """Log a critical error with context of the relevant source. + + Parameters + ---------- + short : str + A short summarizing report that shouldn't wrap over multiple lines. + *args : Any + Optional formatting arguments for `short`. + details : str, optional + An optional multiline report with more details. + **log_kw : Any + """ + return self.report( + short, *args, log_level=logging.CRITICAL, details=details, **log_kw + ) + def __post_init__(self): if self.path is not None and not isinstance(self.path, Path): msg = f"expected `path` to be of type `Path`, got {type(self.path)!r}" @@ -229,9 +246,6 @@ def __init__(self, stream=None, group_errors=False): self.group_errors = group_errors self._records = [] - self.error_count = 0 - self.warning_count = 0 - self.strip_ansi = should_strip_ansi(self.stream) def format(self, record): @@ -292,22 +306,22 @@ def format(self, record): return msg - def emit(self, record): + def handle(self, record): """Handle a log record. Parameters ---------- record : logging.LogRecord - """ - if record.levelno >= logging.ERROR: - self.error_count += 1 - elif record.levelno == logging.WARNING: - self.warning_count += 1 + Returns + ------- + out : bool + """ if self.group_errors and logging.WARNING <= record.levelno <= logging.ERROR: self._records.append(record) else: - super().emit(record) + self.emit(record) + return True def emit_grouped(self): """Emit all saved log records in groups. @@ -339,6 +353,42 @@ def emit_grouped(self): self._records = [] +class LogCounter(logging.NullHandler): + """Logging handler that counts warnings, errors and critical records. + + Attributes + ---------- + critical_count : int + error_count : int + warning_count : int + """ + + def __init__(self): + super().__init__() + self.critical_count = 0 + self.error_count = 0 + self.warning_count = 0 + + def handle(self, record): + """Count the log record if is a warning or more severe. 
+ + Parameters + ---------- + record : logging.LogRecord + + Returns + ------- + out : bool + """ + if record.levelno >= logging.CRITICAL: + self.critical_count += 1 + elif record.levelno >= logging.ERROR: + self.error_count += 1 + elif record.levelno >= logging.WARNING: + self.warning_count += 1 + return True + + def setup_logging(*, verbosity, group_errors): """Setup logging to stderr for docstub's main process. @@ -349,10 +399,11 @@ def setup_logging(*, verbosity, group_errors): Returns ------- - handler : ReportHandler + output_handler : ReportHandler + log_counter : LogCounter """ _VERBOSITY_LEVEL = { - -2: logging.CRITICAL + 1, # never print anything + -2: logging.CRITICAL, -1: logging.ERROR, 0: logging.WARNING, 1: logging.INFO, @@ -360,6 +411,9 @@ def setup_logging(*, verbosity, group_errors): 3: logging.DEBUG, } + output_level = _VERBOSITY_LEVEL[verbosity] + report_level = min(logging.WARNING, output_level) + format_ = "%(message)s" if verbosity >= 3: debug_info = ( @@ -373,18 +427,21 @@ def setup_logging(*, verbosity, group_errors): debug_info = indent(",\n".join(debug_info), prefix=" ") format_ = f"{format_}\n [\n{debug_info}\n ]" - formatter = logging.Formatter(format_) - handler = ReportHandler(group_errors=group_errors) - handler.setLevel(_VERBOSITY_LEVEL[verbosity]) - handler.setFormatter(formatter) + reporter = ReportHandler(group_errors=group_errors) + reporter.setLevel(_VERBOSITY_LEVEL[verbosity]) + + log_counter = LogCounter() + log_counter.setLevel(report_level) # Only allow logging by docstub itself - handler.addFilter(logging.Filter("docstub")) + reporter.addFilter(logging.Filter("docstub")) + log_counter.addFilter(logging.Filter("docstub")) logging.basicConfig( - level=_VERBOSITY_LEVEL[verbosity], - handlers=[handler], + format=format_, + level=report_level, + handlers=[reporter, log_counter], ) logging.captureWarnings(True) - return handler + return reporter, log_counter diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index dc50d65..862e40e 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -17,7 +17,7 @@ from ._analysis import PyImport from ._docstrings import DocstringAnnotations, DoctypeTransformer, FallbackAnnotation from ._report import ContextReporter -from ._utils import module_name_from_path +from ._utils import module_name_from_path, update_with_add_values logger: logging.Logger = logging.getLogger(__name__) @@ -385,6 +385,26 @@ def python_to_stub(self, source, *, module_path=None): self._required_imports = None self.current_source = None + def collect_stats(self, *, reset_after=True): + """Return statistics from processing files. + + Parameters + ---------- + reset_after : bool, optional + Whether to reset counters and statistics after returning. + + Returns + ------- + stats : dict of {str: int or list[str]} + """ + collected = [self.transformer.stats, self.transformer.matcher.stats] + merged = update_with_add_values(*collected) + if reset_after is True: + for stats in collected: + for key in stats: + stats[key] = type(stats[key])() + return merged + def visit_ClassDef(self, node): """Collect pytypes from class docstring and add scope to stack. diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index a10c91e..297d881 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -159,5 +159,42 @@ def pyfile_checksum(path): return key +def update_with_add_values(*mappings, out=None): + """Merge mappings while adding together their values. 
+ + Parameters + ---------- + mappings : Mapping[Hashable, int or Sequence] + out : dict, optional + + Returns + ------- + out : dict, optional + + Examples + -------- + >>> stats_1 = {"errors": 2, "warnings": 0, "unknown": ["string", "integer"]} + >>> stats_2 = {"unknown": ["func"], "errors": 1} + >>> update_with_add_values(stats_1, stats_2) + {'errors': 3, 'warnings': 0, 'unknown': ['string', 'integer', 'func']} + + >>> _ = update_with_add_values(stats_1, out=stats_2) + >>> stats_2 + {'unknown': ['func', 'string', 'integer'], 'errors': 3, 'warnings': 0} + + >>> update_with_add_values({"lines": (1, 33)}, {"lines": (42,)}) + {'lines': (1, 33, 42)} + """ + if out is None: + out = {} + for m in mappings: + for key, value in m.items(): + if hasattr(value, "__add__"): + out[key] = out.setdefault(key, type(value)()) + value + else: + raise TypeError(f"Don't know how to 'add' {value!r}") + return out + + class DocstubError(Exception): """An error raised by docstub.""" diff --git a/src/docstub/_vendored/stdlib.py b/src/docstub/_vendored/stdlib.py index 62b09cf..7d4e6c9 100644 --- a/src/docstub/_vendored/stdlib.py +++ b/src/docstub/_vendored/stdlib.py @@ -3,9 +3,15 @@ # # See LICENSE.txt for the full license text +"""Vendored snippets from Python's standard library. + +These are not available yet in all supported Python versions. +""" + import os import re from collections.abc import Sequence +from concurrent.futures import ProcessPoolExecutor as _ProcessPoolExecutor # Vendored `fnmatch._translate` from Python 3.13.4 because it isn't available in @@ -146,3 +152,80 @@ def glob_translate( results.append(any_sep) res = "".join(results) return rf"(?s:{res})\Z" + + +# Vendored `ProcessPoolExecutor.terminate_workers` from Python 3.14 because +# it isn't available in earlier Python versions. Copied from +# https://github.com/python/cpython/blob/02604314ba3e97cc1918520e9ef5c0c4a6e7fe47/Lib/concurrent/futures/process.py#L878-L939 +if not hasattr(_ProcessPoolExecutor, "terminate_workers"): + _TERMINATE: str = "terminate" + _KILL: str = "kill" + + _SHUTDOWN_CALLBACK_OPERATION: set[str] = {_TERMINATE, _KILL} + + class ProcessPoolExecutor(_ProcessPoolExecutor): + def _force_shutdown(self, operation: str) -> None: + """Attempts to terminate or kill the executor's workers based off the + given operation. Iterates through all of the current processes and + performs the relevant task if the process is still alive. + + After terminating workers, the pool will be in a broken state + and no longer usable (for instance, new tasks should not be + submitted). + """ + if operation not in _SHUTDOWN_CALLBACK_OPERATION: + raise ValueError(f"Unsupported operation: {operation!r}") + + processes = {} + if self._processes: + processes = self._processes.copy() + + # shutdown will invalidate ._processes, so we copy it right before + # calling. If we waited here, we would deadlock if a process decides not + # to exit. + self.shutdown(wait=False, cancel_futures=True) + + if not processes: + return + + for proc in processes.values(): + try: + if not proc.is_alive(): + continue + except ValueError: + # The process is already exited/closed out. + continue + + try: + if operation == _TERMINATE: + proc.terminate() + elif operation == _KILL: + proc.kill() + except ProcessLookupError: + # The process just ended before our signal + continue + + def terminate_workers(self) -> None: + """Attempts to terminate the executor's workers. + Iterates through all of the current worker processes and terminates + each one that is still alive. 
+ + After terminating workers, the pool will be in a broken state + and no longer usable (for instance, new tasks should not be + submitted). + """ + return self._force_shutdown(operation=_TERMINATE) + + def kill_workers(self) -> None: + """Attempts to kill the executor's workers. + Iterates through all of the current worker processes and kills + each one that is still alive. + + After killing workers, the pool will be in a broken state + and no longer usable (for instance, new tasks should not be + submitted). + """ + return self._force_shutdown(operation=_KILL) + +else: + ProcessPoolExecutor: _ProcessPoolExecutor = _ProcessPoolExecutor # type: ignore[no-redef] diff --git a/tests/test_cli.py b/tests/test_cli.py index d0f751a..8d74ac9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,10 @@ """Test command line interface.""" import logging +import subprocess +import sys from pathlib import Path +from textwrap import dedent import pytest from click.testing import CliRunner @@ -54,6 +57,83 @@ def test_no_cache(self, tmp_path_cwd, caplog): # Check that at least one collected file was logged as "(cached)" assert "cached" not in "\n".join(caplog.messages) + @pytest.mark.slow + @pytest.mark.parametrize("workers", [1, 2]) + def test_fail_on_warning(self, workers, tmp_path_cwd): + source_with_warning = dedent( + ''' + def foo(x: str): + """ + Parameters + ---------- + x : int + """ + ''' + ) + package = tmp_path_cwd / "src/sample_package" + package.mkdir(parents=True) + init_py = package / "__init__.py" + with init_py.open("x") as io: + io.write(source_with_warning) + + result = subprocess.run( + [ + sys.executable, + "-m", + "docstub", + "run", + "--quiet", + "--quiet", + "--fail-on-warning", + "--workers", + str(workers), + str(package), + ], + check=False, + capture_output=True, + text=True, + ) + assert result.returncode == 1 + + @pytest.mark.slow + @pytest.mark.parametrize("workers", [1, 2]) + def test_no_output_exit_code(self, workers, tmp_path_cwd): + faulty_source = dedent( + ''' + def foo(x): + """ + Parameters + ---------- + x : doctype with syntax error + """ + ''' + ) + package = tmp_path_cwd / "src/sample_package" + package.mkdir(parents=True) + init_py = package / "__init__.py" + with init_py.open("x") as io: + io.write(faulty_source) + + result = subprocess.run( + [ + sys.executable, + "-m", + "docstub", + "run", + "--quiet", + "--quiet", + "--workers", + str(workers), + str(package), + ], + check=False, + capture_output=True, + text=True, + ) + assert result.stdout == "" + assert result.stderr == "" + assert result.returncode == 1 + class Test_clean: @pytest.mark.parametrize("verbosity", [["-v"], ["--verbose"], []]) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..e4b65dc --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,26 @@ +import math +import os + +import pytest + +from docstub._concurrency import guess_concurrency_params + + +class Test_guess_concurrency_params: + @pytest.mark.parametrize("task_count", list(range(9))) + @pytest.mark.parametrize("cpu_count", [1, 8]) + def test_default_below_cutoff(self, task_count, cpu_count, monkeypatch): + monkeypatch.setattr(os, "cpu_count", lambda: cpu_count) + monkeypatch.setattr(os, "process_cpu_count", lambda: cpu_count, raising=False) + worker_count, chunk_size = guess_concurrency_params(task_count=task_count) + assert worker_count == 1 + assert chunk_size == task_count + + @pytest.mark.parametrize("task_count", [10, 15, 50, 100, 1000]) + 
@pytest.mark.parametrize("cpu_count", [1, 8, 16]) + def test_default(self, task_count, cpu_count, monkeypatch): + monkeypatch.setattr(os, "cpu_count", lambda: cpu_count) + monkeypatch.setattr(os, "process_cpu_count", lambda: cpu_count, raising=False) + worker_count, chunk_size = guess_concurrency_params(task_count=task_count) + assert worker_count == min(cpu_count, task_count // 3) + assert chunk_size == math.ceil(task_count / worker_count)
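
Note: the interplay of the worker and chunk-size heuristics in `guess_concurrency_params` is easiest to sanity-check with a concrete case. The following sketch is not part of the patch and assumes a machine whose `os.process_cpu_count()` reports 8; actual numbers vary per machine.

>>> from docstub._concurrency import guess_concurrency_params
>>> # 100 source files, estimation forced with -1 (assumes 8 reported CPUs)
>>> guess_concurrency_params(task_count=100, desired_worker_count=-1)  # doctest: +SKIP
(8, 13)
>>> # ceil(100 / 8) == 13 and 8 workers * chunk size 13 covers all 100 tasks,
>>> # so each worker receives a single chunk and the shared Py2StubTransformer
>>> # bundled with every task is pickled only about once per worker.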