From 4afd87b8a2ed2ec2e2caa0b251dcd3e29ae744b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 08:55:30 +0100 Subject: [PATCH 01/22] Draft multiprocessing with PoolExecutor --- src/docstub-stubs/_cli.pyi | 5 + src/docstub-stubs/_concurrency.pyi | 44 +++++ src/docstub-stubs/_vendored/stdlib.pyi | 15 ++ src/docstub/_cli.py | 110 ++++++++---- src/docstub/_concurrency.py | 222 +++++++++++++++++++++++++ src/docstub/_vendored/stdlib.py | 83 +++++++++ 6 files changed, 451 insertions(+), 28 deletions(-) create mode 100644 src/docstub-stubs/_concurrency.pyi create mode 100644 src/docstub/_concurrency.py diff --git a/src/docstub-stubs/_cli.pyi b/src/docstub-stubs/_cli.pyi index cd9690a..8acea45 100644 --- a/src/docstub-stubs/_cli.pyi +++ b/src/docstub-stubs/_cli.pyi @@ -16,6 +16,7 @@ from _typeshed import Incomplete from ._analysis import PyImport, TypeCollector, TypeMatcher, common_known_types from ._cache import CACHE_DIR_NAME, FileCache, validate_cache from ._cli_help import HelpFormatter +from ._concurrency import LoggingProcessExecutor, guess_concurrency_params from ._config import Config from ._path_utils import ( STUB_HEADER_COMMENT, @@ -45,6 +46,9 @@ click.Context.formatter_class = HelpFormatter @click.group() def cli() -> None: ... def _add_verbosity_options(func: Callable) -> Callable: ... +def _transform_to_stub( + source_path: Path, stub_path: Path, stub_transformer: Py2StubTransformer +) -> dict[str, int | list[str]]: ... @cli.command() def run( *, @@ -55,6 +59,7 @@ def run( group_errors: bool, allow_errors: int, fail_on_warning: bool, + jobs: int | None, no_cache: bool, verbose: int, quiet: int, diff --git a/src/docstub-stubs/_concurrency.pyi b/src/docstub-stubs/_concurrency.pyi new file mode 100644 index 0000000..ff14ee7 --- /dev/null +++ b/src/docstub-stubs/_concurrency.pyi @@ -0,0 +1,44 @@ +# File generated with docstub + +import logging +import logging.handlers +import math +import multiprocessing +import os +from collections.abc import Callable, Iterable +from concurrent.futures import Executor +from dataclasses import dataclass +from multiprocessing import Queue +from types import TracebackType +from typing import Any + +from ._vendored.stdlib import ProcessPoolExecutor + +logger: logging.Logger + +class MockPoolExecutor(Executor): + def map(self, fn: Callable, *iterables: Any, **__: Any) -> Iterable: ... + +@dataclass(kw_only=True) +class LoggingProcessExecutor: + + max_workers: int | None = ... + logging_handlers: tuple[logging.Handler, ...] = ... + initializer: Callable | None = ... + initargs: tuple | None = ... + + @staticmethod + def _initialize_worker( + queue: Queue, worker_log_level: int, initializer: Callable, initargs: tuple[Any] + ) -> None: ... + def __enter__(self) -> ProcessPoolExecutor | MockPoolExecutor: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType, + ) -> bool: ... + +def guess_concurrency_params( + *, task_count: int, worker_count: int | None = ... +) -> tuple[int, int]: ... diff --git a/src/docstub-stubs/_vendored/stdlib.pyi b/src/docstub-stubs/_vendored/stdlib.pyi index b4ff650..d589113 100644 --- a/src/docstub-stubs/_vendored/stdlib.pyi +++ b/src/docstub-stubs/_vendored/stdlib.pyi @@ -3,6 +3,7 @@ import os import re from collections.abc import Sequence +from concurrent.futures import ProcessPoolExecutor as _ProcessPoolExecutor def _fnmatch_translate(pat: str, STAR: str, QUESTION_MARK: str) -> str: ... def glob_translate( @@ -12,3 +13,17 @@ def glob_translate( include_hidden: bool = ..., seps: Sequence[str] | None = ..., ) -> str: ... + +if not hasattr(_ProcessPoolExecutor, "terminate_workers"): + _TERMINATE: str + _KILL: str + + _SHUTDOWN_CALLBACK_OPERATION: set[str] + + class ProcessPoolExecutor(_ProcessPoolExecutor): + def _force_shutdown(self, operation: str) -> None: ... + def terminate_workers(self) -> None: ... + def kill_workers(self) -> None: ... + +else: + ProcessPoolExecutor: _ProcessPoolExecutor diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 0c64e3a..432afaf 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -16,6 +16,7 @@ ) from ._cache import CACHE_DIR_NAME, FileCache, validate_cache from ._cli_help import HelpFormatter +from ._concurrency import LoggingProcessExecutor, guess_concurrency_params from ._config import Config from ._path_utils import ( STUB_HEADER_COMMENT, @@ -255,6 +256,46 @@ def _add_verbosity_options(func): return func +def _transform_to_stub(source_path, stub_path, stub_transformer): + """Transform a Python file into a stub file. + + Parameters + ---------- + source_path : Path + stub_path : Path + stub_transformer : Py2StubTransformer + + Returns + ------- + stats : dict of {str: int or list[str]} + """ + if source_path.suffix.lower() == ".pyi": + logger.debug("Using existing stub file %s", source_path) + with source_path.open() as fo: + stub_content = fo.read() + else: + with source_path.open() as fo: + py_content = fo.read() + logger.debug("Transforming %s", source_path) + try: + stub_content = stub_transformer.python_to_stub( + py_content, module_path=source_path + ) + stub_content = f"{STUB_HEADER_COMMENT}\n\n{stub_content}" + stub_content = try_format_stub(stub_content) + except Exception: + logger.exception("Failed creating stub for %s", source_path) + return None + + stub_path.parent.mkdir(parents=True, exist_ok=True) + with stub_path.open("w") as fo: + logger.info("Wrote %s", stub_path) + fo.write(stub_content) + + stats = stub_transformer.collect_stats() + return stats + + # Preserve click.command below to keep type checker happy # docstub: off @cli.command() @@ -312,6 +353,13 @@ def _add_verbosity_options(func): help="Return non-zero exit code when a warning is raised. " "Will add to --allow-errors.", ) +@click.option( + "--jobs", + type=click.IntRange(min=1), + metavar="INT", + help="Set the number of jobs to use in parallel. By default docstub will " + "attempt to choose an appropriate number.", +) @click.option( "--no-cache", is_flag=True, @@ -329,6 +377,7 @@ def run( group_errors, allow_errors, fail_on_warning, + jobs, no_cache, verbose, quiet, @@ -349,6 +398,7 @@ def run( group_errors : bool allow_errors : int fail_on_warning : bool + jobs : int | None no_cache : bool verbose : int quiet : int @@ -357,7 +407,9 @@ def run( # Setup ------------------------------------------------------------------- verbosity = _calc_verbosity(verbose=verbose, quiet=quiet) - error_handler = setup_logging(verbosity=verbosity, group_errors=group_errors) + output_handler, error_counter = setup_logging( + verbosity=verbosity, group_errors=group_errors + ) root_path = Path(root_path) if root_path.is_file(): @@ -409,30 +461,32 @@ def run( # Stub generation --------------------------------------------------------- - for source_path, stub_path in walk_source_and_targets( - root_path, out_dir, ignore=config.ignore_files - ): - if source_path.suffix.lower() == ".pyi": - logger.debug("Using existing stub file %s", source_path) - with source_path.open() as fo: - stub_content = fo.read() - else: - with source_path.open() as fo: - py_content = fo.read() - logger.debug("Transforming %s", source_path) - try: - stub_content = stub_transformer.python_to_stub( - py_content, module_path=source_path - ) - stub_content = f"{STUB_HEADER_COMMENT}\n\n{stub_content}" - stub_content = try_format_stub(stub_content) - except Exception: - logger.exception("Failed creating stub for %s", source_path) - continue - stub_path.parent.mkdir(parents=True, exist_ok=True) - with stub_path.open("w") as fo: - logger.info("Wrote %s", stub_path) - fo.write(stub_content) + tasks = walk_source_and_targets(root_path, out_dir, ignore=config.ignore_files) + + # We must pass the `stub_transformer` to each worker, but we want to copy + # only once per worker. Testing suggests, that using a large enough + # `chunksize` of `>= len(tasks) / jobs` for `ProcessPoolExecutor.map`, + # ensures that. + # Using an `initializer` that assigns the transformer as a global variable + # per worker seems like the more robust solution, but naive timing suggests + # it's actually slower (> 1s on skimage). + tasks = [(*task, stub_transformer) for task in tasks] + + worker_count, chunk_size = guess_concurrency_params( + task_count=len(tasks), worker_count=jobs + ) + logger.info("Using %i parallel jobs to write %i stubs", worker_count, len(tasks)) + logger.debug("Using chunk size of %i", chunk_size) + with LoggingProcessExecutor( + max_workers=worker_count, + logging_handlers=(output_handler, error_counter), + ) as executor: + # Doesn't block + stats = executor.map( + _transform_to_stub, *zip(*tasks, strict=False), chunksize=chunk_size + ) + # Iterate results (which blocks) until all tasks have been processed + stats = update_with_add_values(*stats) py_typed_out = out_dir / "py.typed" if not py_typed_out.exists(): @@ -442,9 +496,9 @@ def run( # Reporting -------------------------------------------------------------- if group_errors: - error_handler.emit_grouped() - assert error_handler.group_errors is True - error_handler.group_errors = False + output_handler.emit_grouped() + assert output_handler.group_errors is True + output_handler.group_errors = False # Report basic statistics successful_queries = matcher.successful_queries diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py new file mode 100644 index 0000000..bede652 --- /dev/null +++ b/src/docstub/_concurrency.py @@ -0,0 +1,222 @@ +"""Tools for parallel processing.""" + +import logging +import logging.handlers +import math +import multiprocessing +import os +from collections.abc import Callable +from concurrent.futures import Executor +from dataclasses import dataclass +from multiprocessing import Queue + +from ._vendored.stdlib import ProcessPoolExecutor + +logger: logging.Logger = logging.getLogger(__name__) + + +class MockPoolExecutor(Executor): + """Mock executor that does not spawn a thread, interpreter, or process. + + Only implements the used part of the API defined by + :class:`concurrent.futures.Executor`. + """ + + def map(self, fn, *iterables, **__): + """Returns an iterator equivalent to map(fn, iter). + + Same behavior as :ref:`concurrent.futures.Executor.map` though any + parameters besides `fn` and `iterables` are ignored. + + Parameters + ---------- + fn : Callable + *iterables : Any + **__ : Any + + Returns + ------- + results : Iterable + """ + tasks = zip(*iterables, strict=False) + for task in tasks: + yield fn(*task) + + +@dataclass(kw_only=True) +class LoggingProcessExecutor: + """Wrapper around `ProcessPoolExecutor` that forwards logging from workers. + + Parameters + ---------- + max_workers, initializer, initargs: + Refer to the documentation :class:`concurrent.futures.ProcessPoolExecutor`. + logging_handlers: + Handlers, to which logging records of the worker processes will be + forwarded too. Worker processes will use the minimal log level of + the given workers. + + Examples + -------- + >>> with LoggingProcessExecutor() as pool: # doctest: +SKIP + ... # use `pool.submit` or `pool.map` ... + """ + + max_workers: int | None = None + logging_handlers: tuple[logging.Handler, ...] = () + initializer: Callable | None = None + initargs: tuple | None = () + + @staticmethod + def _initialize_worker(queue, worker_log_level, initializer, initargs): + """Initialize logging in workers. + + Parameters + ---------- + queue : Queue + worker_log_level : int + initializer : Callable + initargs : tuple of Any + """ + queue_handler = logging.handlers.QueueHandler(queue) + queue_handler.setLevel(worker_log_level) + + # Could buffering with MemoryHandler improve performance here? + # memory_handler = logging.handlers.MemoryHandler( + # capacity=100, flushLevel=logging.CRITICAL, target=queue_handler + # ) + + root_logger = logging.getLogger() + root_logger.addHandler(queue_handler) + root_logger.setLevel(worker_log_level) + if initializer: + initializer(*initargs) + logger.debug("Initialized worker") + + def __enter__(self) -> ProcessPoolExecutor | MockPoolExecutor: + if self.max_workers == 1: + logger.debug("Not using concurrency (workers=%i)", self.max_workers) + return MockPoolExecutor() + + # This sets the logging level of worker processes. Use the minimal level + # of all handlers here, so that appropriate records are passed on + worker_log_level = min(*[h.level for h in self.logging_handlers]) + + # Sets method by which the worker processes are created, anything besides + # "spawn" is apparently "broken" on Windows & macOS + mp_context = multiprocessing.get_context("spawn") + + # A queue, used to pass logging records from worker processes to the + # current and main one. Naive testing suggests that + # `multiprocessing.Queue` is faster than `multiprocessing.Manager.Queue` + self._queue = Queue() + + # The actual pool manager that is wrapped here and returned by this + # context manager + self._pool = ProcessPoolExecutor( + max_workers=self.max_workers, + mp_context=mp_context, + initializer=self._initialize_worker, + initargs=( + self._queue, + worker_log_level, + self.initializer, + self.initargs, + ), + ) + + # Forwards logging records from the queue to the given logging handlers + self._listener = logging.handlers.QueueListener( + self._queue, *self.logging_handlers, respect_handler_level=True + ) + logger.debug("Starting queue listener") + self._listener.start() + + return self._pool.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + + Parameters + ---------- + exc_type : type[BaseException] or None + exc_val : BaseException or None + exc_tb : TracebackType + + Returns + ------- + suppress_exception : bool + """ + if self.max_workers == 1: + return False + + if exc_type and issubclass(exc_type, KeyboardInterrupt): + # We want to exit immediately when user requests a KeyboardInterrupt + # even if we may lose logging records that way + logger.debug("Terminating workers") + self._pool.terminate_workers() + # Calling `self._listener.stop()` here seems to block forever?! + else: + # Graceful shutdown + logger.debug("Shutting down pool") + self._pool.__exit__(exc_type, exc_val, exc_tb) + logger.debug("Stopping queue listener") + self._listener.stop() + + logger.debug("Closing queue and joining its thread") + self._queue.close() + self._queue.join_thread() + + logger.debug("Exiting executor pool") + + return False + + +def guess_concurrency_params(*, task_count, worker_count=None): + """Estimate how tasks should be distributed to how many workers. + + Parameters + ---------- + task_count : int + The number of task that need to be processed. + worker_count : int, optional + If not set, the number of workers is estimated. Set this explicitly + to force a number of workers. + + Returns + ------- + worker_count : int + The number of workers that should be used. + chunk_size : int + The chunk size that should be used to split the tasks among the workers. + + Examples + -------- + >>> _, chunk_size = guess_concurrency_params(task_count=10, worker_count=8) + >>> chunk_size + 2 + """ + # `process_cpu_count` was added in Python 3.13 onwards + cpu_count = getattr(os, "process_cpu_count", os.cpu_count)() + + if worker_count is None: + # These crude heuristics were only ever "measured" on one computer + worker_count = cpu_count + # Clip to `worker_count <= task_count // 3` + worker_count = min(worker_count, task_count // 3) + # For a low number of files it may not be worth spinning up any workers + if task_count < 10: + worker_count = 1 + + # Clip to [1, cpu_count] + worker_count = max(1, min(cpu_count, worker_count)) + + # Chunking prevents unnecessary pickling of objects that are shared between + # each task more than once. When `worker_count * chunk_size` is slightly + # larger than `task_count`, each worker process only ever receives one chunk + chunk_size = task_count / worker_count + chunk_size = math.ceil(chunk_size) + + assert isinstance(worker_count, int) + assert isinstance(chunk_size, int) + return worker_count, chunk_size diff --git a/src/docstub/_vendored/stdlib.py b/src/docstub/_vendored/stdlib.py index 62b09cf..4f4abc3 100644 --- a/src/docstub/_vendored/stdlib.py +++ b/src/docstub/_vendored/stdlib.py @@ -3,9 +3,15 @@ # # See LICENSE.txt for the full license text +"""Vendored snippets from Python's standard library. + +These are not available yet in all supported Python versions. +""" + import os import re from collections.abc import Sequence +from concurrent.futures import ProcessPoolExecutor as _ProcessPoolExecutor # Vendored `fnmatch._translate` from Python 3.13.4 because it isn't available in @@ -146,3 +152,80 @@ def glob_translate( results.append(any_sep) res = "".join(results) return rf"(?s:{res})\Z" + + +# Vendored `ProcessPoolExecutor.terminate_workers` from Python 3.14 because +# it isn't available in earlier Python versions. Copied from +# https://github.com/python/cpython/blob/02604314ba3e97cc1918520e9ef5c0c4a6e7fe47/Lib/concurrent/futures/process.py#L878-L939 +if not hasattr(_ProcessPoolExecutor, "terminate_workers"): + _TERMINATE: str = "terminate" + _KILL: str = "kill" + + _SHUTDOWN_CALLBACK_OPERATION: set[str] = {_TERMINATE, _KILL} + + class ProcessPoolExecutor(_ProcessPoolExecutor): + def _force_shutdown(self, operation: str) -> None: + """Attempts to terminate or kill the executor's workers based off the + given operation. Iterates through all of the current processes and + performs the relevant task if the process is still alive. + + After terminating workers, the pool will be in a broken state + and no longer usable (for instance, new tasks should not be + submitted). + """ + if operation not in _SHUTDOWN_CALLBACK_OPERATION: + raise ValueError(f"Unsupported operation: {operation!r}") + + processes = {} + if self._processes: + processes = self._processes.copy() + + # shutdown will invalidate ._processes, so we copy it right before + # calling. If we waited here, we would deadlock if a process decides not + # to exit. + self.shutdown(wait=False, cancel_futures=True) + + if not processes: + return + + for proc in processes.values(): + try: + if not proc.is_alive(): + continue + except ValueError: + # The process is already exited/closed out. + continue + + try: + if operation == _TERMINATE: + proc.terminate() + elif operation == _KILL: + proc.kill() + except ProcessLookupError: + # The process just ended before our signal + continue + + def terminate_workers(self) -> None: + """Attempts to terminate the executor's workers. + Iterates through all of the current worker processes and terminates + each one that is still alive. + + After terminating workers, the pool will be in a broken state + and no longer usable (for instance, new tasks should not be + submitted). + """ + return self._force_shutdown(operation=_TERMINATE) + + def kill_workers(self) -> None: + """Attempts to kill the executor's workers. + Iterates through all of the current worker processes and kills + each one that is still alive. + + After killing workers, the pool will be in a broken state + and no longer usable (for instance, new tasks should not be + submitted). + """ + return self._force_shutdown(operation=_KILL) + +else: + ProcessPoolExecutor: _ProcessPoolExecutor = _ProcessPoolExecutor From fbc017201e702a0c1380b3c2c941aa437aa73ae6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 08:58:18 +0100 Subject: [PATCH 02/22] Count errors in dedicated handler Splitting these two responsibilities (logging to console and counting) simplifies both in context of multiprocessing. --- src/docstub-stubs/_report.pyi | 12 +++++- src/docstub/_cli.py | 4 +- src/docstub/_report.py | 70 +++++++++++++++++++++++++---------- 3 files changed, 62 insertions(+), 24 deletions(-) diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index 7d2a29e..809d7e3 100644 --- a/src/docstub-stubs/_report.pyi +++ b/src/docstub-stubs/_report.pyi @@ -62,9 +62,17 @@ class ReportHandler(logging.StreamHandler): self, stream: TextIO | None = ..., group_errors: bool = ... ) -> None: ... def format(self, record: logging.LogRecord) -> str: ... - def emit(self, record: logging.LogRecord) -> None: ... + def handle(self, record: logging.LogRecord) -> None: ... def emit_grouped(self) -> None: ... +class LogCounter(logging.NullHandler): + critical_count: int + error_count: int + warning_count: int + + def __init__(self) -> None: ... + def handle(self, record: logging.Record) -> None: ... + def setup_logging( *, verbosity: Literal[-2, -1, 0, 1, 2, 3], group_errors: bool -) -> ReportHandler: ... +) -> tuple[ReportHandler, LogCounter]: ... diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 432afaf..ccd7593 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -505,8 +505,8 @@ def run( transformed_doctypes = stub_transformer.transformer.stats["transformed"] syntax_error_count = stub_transformer.transformer.stats["syntax_errors"] unknown_type_names = matcher.unknown_qualnames - total_warnings = error_handler.warning_count - total_errors = error_handler.error_count + total_warnings = error_counter.warning_count + total_errors = error_counter.error_count logger.info("Recognized type names: %i", successful_queries) logger.info("Transformed doctypes: %i", transformed_doctypes) diff --git a/src/docstub/_report.py b/src/docstub/_report.py index 9d3239e..ba4a095 100644 --- a/src/docstub/_report.py +++ b/src/docstub/_report.py @@ -229,9 +229,6 @@ def __init__(self, stream=None, group_errors=False): self.group_errors = group_errors self._records = [] - self.error_count = 0 - self.warning_count = 0 - self.strip_ansi = should_strip_ansi(self.stream) def format(self, record): @@ -292,22 +289,17 @@ def format(self, record): return msg - def emit(self, record): + def handle(self, record): """Handle a log record. Parameters ---------- record : logging.LogRecord """ - if record.levelno >= logging.ERROR: - self.error_count += 1 - elif record.levelno == logging.WARNING: - self.warning_count += 1 - if self.group_errors and logging.WARNING <= record.levelno <= logging.ERROR: self._records.append(record) else: - super().emit(record) + self.emit(record) def emit_grouped(self): """Emit all saved log records in groups. @@ -339,6 +331,37 @@ def emit_grouped(self): self._records = [] +class LogCounter(logging.NullHandler): + """Logging handler that counts warnings, errors and critical records. + + Attributes + ---------- + critical_count : int + error_count : int + warning_count : int + """ + + def __init__(self): + super().__init__() + self.critical_count = 0 + self.error_count = 0 + self.warning_count = 0 + + def handle(self, record): + """Count the log record if is a warning or more severe. + + Parameters + ---------- + record : logging.Record + """ + if record.levelno >= logging.CRITICAL: + self.critical_count += 1 + elif record.levelno >= logging.ERROR: + self.error_count += 1 + elif record.levelno >= logging.WARNING: + self.warning_count += 1 + + def setup_logging(*, verbosity, group_errors): """Setup logging to stderr for docstub's main process. @@ -349,10 +372,11 @@ def setup_logging(*, verbosity, group_errors): Returns ------- - handler : ReportHandler + output_handler : ReportHandler + log_counter : LogCounter """ _VERBOSITY_LEVEL = { - -2: logging.CRITICAL + 1, # never print anything + -2: logging.CRITICAL, -1: logging.ERROR, 0: logging.WARNING, 1: logging.INFO, @@ -360,6 +384,9 @@ def setup_logging(*, verbosity, group_errors): 3: logging.DEBUG, } + output_level = _VERBOSITY_LEVEL[verbosity] + report_level = min(logging.WARNING, output_level) + format_ = "%(message)s" if verbosity >= 3: debug_info = ( @@ -373,18 +400,21 @@ def setup_logging(*, verbosity, group_errors): debug_info = indent(",\n".join(debug_info), prefix=" ") format_ = f"{format_}\n [\n{debug_info}\n ]" - formatter = logging.Formatter(format_) - handler = ReportHandler(group_errors=group_errors) - handler.setLevel(_VERBOSITY_LEVEL[verbosity]) - handler.setFormatter(formatter) + reporter = ReportHandler(group_errors=group_errors) + reporter.setLevel(_VERBOSITY_LEVEL[verbosity]) + + log_counter = LogCounter() + log_counter.setLevel(report_level) # Only allow logging by docstub itself - handler.addFilter(logging.Filter("docstub")) + reporter.addFilter(logging.Filter("docstub")) + log_counter.addFilter(logging.Filter("docstub")) logging.basicConfig( - level=_VERBOSITY_LEVEL[verbosity], - handlers=[handler], + format=format_, + level=report_level, + handlers=[reporter, log_counter], ) logging.captureWarnings(True) - return handler + return reporter, log_counter From 9228bb0196c9fd84cf8f9fbbb97aed069a4ec322 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 08:59:36 +0100 Subject: [PATCH 03/22] Update how stats are propagated back to the main process --- src/docstub-stubs/_cli.pyi | 1 + src/docstub-stubs/_stubs.pyi | 5 ++++- src/docstub-stubs/_utils.pyi | 5 ++++- src/docstub/_analysis.py | 12 +++++++----- src/docstub/_cli.py | 25 ++++++++++++------------ src/docstub/_docstrings.py | 8 ++++---- src/docstub/_stubs.py | 22 ++++++++++++++++++++- src/docstub/_utils.py | 37 ++++++++++++++++++++++++++++++++++++ 8 files changed, 90 insertions(+), 25 deletions(-) diff --git a/src/docstub-stubs/_cli.pyi b/src/docstub-stubs/_cli.pyi index 8acea45..05173f8 100644 --- a/src/docstub-stubs/_cli.pyi +++ b/src/docstub-stubs/_cli.pyi @@ -26,6 +26,7 @@ from ._path_utils import ( ) from ._report import setup_logging from ._stubs import Py2StubTransformer, try_format_stub +from ._utils import update_with_add_values from ._version import __version__ logger: logging.Logger diff --git a/src/docstub-stubs/_stubs.pyi b/src/docstub-stubs/_stubs.pyi index a6521f6..87c2554 100644 --- a/src/docstub-stubs/_stubs.pyi +++ b/src/docstub-stubs/_stubs.pyi @@ -19,7 +19,7 @@ from ._docstrings import ( FallbackAnnotation, ) from ._report import ContextReporter -from ._utils import module_name_from_path +from ._utils import module_name_from_path, update_with_add_values logger: logging.Logger @@ -73,6 +73,9 @@ class Py2StubTransformer(cst.CSTTransformer): @property def is_inside_function_def(self) -> bool: ... def python_to_stub(self, source: str, *, module_path: Path | None = ...) -> str: ... + def collect_stats( + self, *, reset_after: bool = ... + ) -> dict[str, int | list[str]]: ... def visit_ClassDef(self, node: cst.ClassDef) -> Literal[True]: ... def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef diff --git a/src/docstub-stubs/_utils.pyi b/src/docstub-stubs/_utils.pyi index 8b8b4bd..c29d6a8 100644 --- a/src/docstub-stubs/_utils.pyi +++ b/src/docstub-stubs/_utils.pyi @@ -2,7 +2,7 @@ import itertools import re -from collections.abc import Callable +from collections.abc import Callable, Hashable, Mapping, Sequence from functools import lru_cache, wraps from pathlib import Path from zlib import crc32 @@ -12,6 +12,9 @@ def escape_qualname(name: str) -> str: ... def _resolve_path_before_caching(func: Callable) -> Callable: ... def module_name_from_path(path: Path) -> str: ... def pyfile_checksum(path: Path) -> str: ... +def update_with_add_values( + *mappings: Mapping[Hashable, int | Sequence], out: dict | None = ... +) -> dict: ... class DocstubError(Exception): pass diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index d5bffe4..32a6daf 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -500,12 +500,14 @@ def __init__( type_prefixes : dict[str, PyImport] type_nicknames : dict[str, str] """ - self.types = common_known_types() | (types or {}) self.type_prefixes = type_prefixes or {} self.type_nicknames = type_nicknames or {} - self.successful_queries = 0 - self.unknown_qualnames = [] + + self.stats = { + "matched_type_names": 0, + "unknown_type_names": [], + } self.current_file = None @@ -621,8 +623,8 @@ def match(self, search): type_name = type_name[type_name.find(py_import.target) :] if type_name is not None: - self.successful_queries += 1 + self.stats["matched_type_names"] += 1 else: - self.unknown_qualnames.append(search) + self.stats["unknown_type_names"].append(search) return type_name, py_import diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index ccd7593..4faed86 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -26,6 +26,7 @@ ) from ._report import setup_logging from ._stubs import Py2StubTransformer, try_format_stub +from ._utils import update_with_add_values from ._version import __version__ logger: logging.Logger = logging.getLogger(__name__) @@ -198,7 +199,7 @@ def log_execution_time(): try: yield except KeyboardInterrupt: - logger.critical("Interrupt!") + logger.critical("Interrupted!") finally: stop = time.time() total_seconds = stop - start @@ -501,25 +502,23 @@ def run( output_handler.group_errors = False # Report basic statistics - successful_queries = matcher.successful_queries - transformed_doctypes = stub_transformer.transformer.stats["transformed"] - syntax_error_count = stub_transformer.transformer.stats["syntax_errors"] - unknown_type_names = matcher.unknown_qualnames total_warnings = error_counter.warning_count total_errors = error_counter.error_count - logger.info("Recognized type names: %i", successful_queries) - logger.info("Transformed doctypes: %i", transformed_doctypes) + logger.info("Recognized type names: %i", stats["matched_type_names"]) + logger.info("Transformed doctypes: %i", stats["transformed_doctypes"]) if total_warnings: logger.warning("Warnings: %i", total_warnings) - if syntax_error_count: - logger.warning("Syntax errors: %i", syntax_error_count) - if unknown_type_names: + if stats["doctype_syntax_errors"]: + assert total_errors + logger.warning("Syntax errors: %i", stats["doctype_syntax_errors"]) + if stats["unknown_type_names"]: + assert total_errors logger.warning( "Unknown type names: %i (locations: %i)", - len(set(unknown_type_names)), - len(unknown_type_names), - extra={"details": _format_unknown_names(unknown_type_names)}, + len(set(stats["unknown_type_names"])), + len(stats["unknown_type_names"]), + extra={"details": _format_unknown_names(stats["unknown_type_names"])}, ) if total_errors: logger.error("Total errors: %i", total_errors) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index d4a11a4..245c3f0 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -273,8 +273,8 @@ def __init__(self, *, matcher=None, **kwargs): super().__init__(**kwargs) self.stats = { - "syntax_errors": 0, - "transformed": 0, + "doctype_syntax_errors": 0, + "transformed_doctypes": 0, } def doctype_to_annotation(self, doctype, *, reporter=None): @@ -303,14 +303,14 @@ def doctype_to_annotation(self, doctype, *, reporter=None): annotation = Annotation( value=value, imports=frozenset(self._collected_imports) ) - self.stats["transformed"] += 1 + self.stats["transformed_doctypes"] += 1 return annotation, self._unknown_qualnames except ( lark.exceptions.LexError, lark.exceptions.ParseError, QualnameIsKeyword, ): - self.stats["syntax_errors"] += 1 + self.stats["doctype_syntax_errors"] += 1 raise finally: self._reporter = None diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index dc50d65..862e40e 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -17,7 +17,7 @@ from ._analysis import PyImport from ._docstrings import DocstringAnnotations, DoctypeTransformer, FallbackAnnotation from ._report import ContextReporter -from ._utils import module_name_from_path +from ._utils import module_name_from_path, update_with_add_values logger: logging.Logger = logging.getLogger(__name__) @@ -385,6 +385,26 @@ def python_to_stub(self, source, *, module_path=None): self._required_imports = None self.current_source = None + def collect_stats(self, *, reset_after=True): + """Return statistics from processing files. + + Parameters + ---------- + reset_after : bool, optional + Whether to reset counters and statistics after returning. + + Returns + ------- + stats : dict of {str: int or list[str]} + """ + collected = [self.transformer.stats, self.transformer.matcher.stats] + merged = update_with_add_values(*collected) + if reset_after is True: + for stats in collected: + for key in stats: + stats[key] = type(stats[key])() + return merged + def visit_ClassDef(self, node): """Collect pytypes from class docstring and add scope to stack. diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index a10c91e..297d881 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -159,5 +159,42 @@ def pyfile_checksum(path): return key +def update_with_add_values(*mappings, out=None): + """Merge mappings while adding together their values. + + Parameters + ---------- + mappings : Mapping[Hashable, int or Sequence] + out : dict, optional + + Returns + ------- + out : dict, optional + + Examples + -------- + >>> stats_1 = {"errors": 2, "warnings": 0, "unknown": ["string", "integer"]} + >>> stats_2 = {"unknown": ["func"], "errors": 1} + >>> update_with_add_values(stats_1, stats_2) + {'errors': 3, 'warnings': 0, 'unknown': ['string', 'integer', 'func']} + + >>> _ = update_with_add_values(stats_1, out=stats_2) + >>> stats_2 + {'unknown': ['func', 'string', 'integer'], 'errors': 3, 'warnings': 0} + + >>> update_with_add_values({"lines": (1, 33)}, {"lines": (42,)}) + {'lines': (1, 33, 42)} + """ + if out is None: + out = {} + for m in mappings: + for key, value in m.items(): + if hasattr(value, "__add__"): + out[key] = out.setdefault(key, type(value)()) + value + else: + raise TypeError(f"Don't know how to 'add' {value!r}") + return out + + class DocstubError(Exception): """An error raised by docstub.""" From 46c032ad8f0ada541ac0daca6f0da1c4c4c05fe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 09:01:15 +0100 Subject: [PATCH 04/22] Add critical method to ContextReporter --- src/docstub-stubs/_report.pyi | 3 +++ src/docstub/_report.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index 809d7e3..266b165 100644 --- a/src/docstub-stubs/_report.pyi +++ b/src/docstub-stubs/_report.pyi @@ -47,6 +47,9 @@ class ContextReporter: def error( self, short: str, *args: Any, details: str | None = ..., **log_kw: Any ) -> None: ... + def critical( + self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + ) -> None: ... def __post_init__(self) -> None: ... @staticmethod def underline(line: str, *, char: str = ...) -> str: ... diff --git a/src/docstub/_report.py b/src/docstub/_report.py index ba4a095..fd0592c 100644 --- a/src/docstub/_report.py +++ b/src/docstub/_report.py @@ -174,6 +174,23 @@ def error(self, short, *args, details=None, **log_kw): short, *args, log_level=logging.ERROR, details=details, **log_kw ) + def critical(self, short, *args, details=None, **log_kw): + """Log a critical error with context of the relevant source. + + Parameters + ---------- + short : str + A short summarizing report that shouldn't wrap over multiple lines. + *args : Any + Optional formatting arguments for `short`. + details : str, optional + An optional multiline report with more details. + **log_kw : Any + """ + return self.report( + short, *args, log_level=logging.CRITICAL, details=details, **log_kw + ) + def __post_init__(self): if self.path is not None and not isinstance(self.path, Path): msg = f"expected `path` to be of type `Path`, got {type(self.path)!r}" From 9c3858f66e105f66dacc2e54500dc91acca1d01f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 09:20:43 +0100 Subject: [PATCH 05/22] Use "--workers" instead of "--jobs" The former is what scikit-image agreed on. --- docs/command_line.md | 3 +++ src/docstub/_cli.py | 13 +++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/command_line.md b/docs/command_line.md index 1e31c34..e5719e6 100644 --- a/docs/command_line.md +++ b/docs/command_line.md @@ -67,6 +67,9 @@ Options: -W, --fail-on-warning Return non-zero exit code when a warning is raised. Will add to --allow-errors. + --workers INT + Set the number of workers to process files in parallel. By default + docstub will attempt to choose an appropriate number. [x>=1] --no-cache Ignore pre-existing cache and don't create a new one. -v, --verbose diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 4faed86..815f2e9 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -355,11 +355,12 @@ def _transform_to_stub(source_path, stub_path, stub_transformer): "Will add to --allow-errors.", ) @click.option( - "--jobs", + "--workers", + "desired_worker_count", type=click.IntRange(min=1), metavar="INT", - help="Set the number of jobs to use in parallel. By default docstub will " - "attempt to choose an appropriate number.", + help="Set the number of workers to process files in parallel. " + "By default docstub will attempt to choose an appropriate number.", ) @click.option( "--no-cache", @@ -378,7 +379,7 @@ def run( group_errors, allow_errors, fail_on_warning, - jobs, + desired_worker_count, no_cache, verbose, quiet, @@ -399,7 +400,7 @@ def run( group_errors : bool allow_errors : int fail_on_warning : bool - jobs : int | None + desired_worker_count : int | None no_cache : bool verbose : int quiet : int @@ -474,7 +475,7 @@ def run( tasks = [(*task, stub_transformer) for task in tasks] worker_count, chunk_size = guess_concurrency_params( - task_count=len(tasks), worker_count=jobs + task_count=len(tasks), worker_count=desired_worker_count ) logger.info("Using %i parallel jobs to write %i stubs", worker_count, len(tasks)) logger.debug("Using chunk size of %i", chunk_size) From aa194291c594a21f98ff3749146c71fefeb5f985 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 09:21:28 +0100 Subject: [PATCH 06/22] Test error and warning propagation with -qq Mark these as "slow" (because they are) --- pyproject.toml | 3 ++ tests/test_cli.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f5ea050..243d8c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,9 @@ xfail_strict = true filterwarnings = ["error"] log_cli_level = "info" testpaths = ["src", "tests"] +markers = [ + "slow: marks tests as slow (deselect with `-m 'not slow'`)", +] [tool.coverage] diff --git a/tests/test_cli.py b/tests/test_cli.py index d0f751a..8d74ac9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,10 @@ """Test command line interface.""" import logging +import subprocess +import sys from pathlib import Path +from textwrap import dedent import pytest from click.testing import CliRunner @@ -54,6 +57,83 @@ def test_no_cache(self, tmp_path_cwd, caplog): # Check that at least one collected file was logged as "(cached)" assert "cached" not in "\n".join(caplog.messages) + @pytest.mark.slow + @pytest.mark.parametrize("workers", [1, 2]) + def test_fail_on_warning(self, workers, tmp_path_cwd): + source_with_warning = dedent( + ''' + def foo(x: str): + """ + Parameters + ---------- + x : int + """ + ''' + ) + package = tmp_path_cwd / "src/sample_package" + package.mkdir(parents=True) + init_py = package / "__init__.py" + with init_py.open("x") as io: + io.write(source_with_warning) + + result = subprocess.run( + [ + sys.executable, + "-m", + "docstub", + "run", + "--quiet", + "--quiet", + "--fail-on-warning", + "--workers", + str(workers), + str(package), + ], + check=False, + capture_output=True, + text=True, + ) + assert result.returncode == 1 + + @pytest.mark.slow + @pytest.mark.parametrize("workers", [1, 2]) + def test_no_output_exit_code(self, workers, tmp_path_cwd): + faulty_source = dedent( + ''' + def foo(x): + """ + Parameters + ---------- + x : doctype with syntax error + """ + ''' + ) + package = tmp_path_cwd / "src/sample_package" + package.mkdir(parents=True) + init_py = package / "__init__.py" + with init_py.open("x") as io: + io.write(faulty_source) + + result = subprocess.run( + [ + sys.executable, + "-m", + "docstub", + "run", + "--quiet", + "--quiet", + "--workers", + str(workers), + str(package), + ], + check=False, + capture_output=True, + text=True, + ) + assert result.stdout == "" + assert result.stderr == "" + assert result.returncode == 1 + class Test_clean: @pytest.mark.parametrize("verbosity", [["-v"], ["--verbose"], []]) From f23745c3e93ac624d4a27200f5389db438ae32d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 09:21:51 +0100 Subject: [PATCH 07/22] Add reminders to remove vendored stdlib code eventually --- REMINDERS.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 REMINDERS.md diff --git a/REMINDERS.md b/REMINDERS.md new file mode 100644 index 0000000..b904406 --- /dev/null +++ b/REMINDERS.md @@ -0,0 +1,10 @@ +# Reminders + +## With Python >=3.13 + +Remove vendored `glob_translate` in `docstub._vendored.stdlib`. + + +## With Python >=3.14 + +Remove vendored `ProcessPoolExecutor` in `docstub._vendored.stdlib`. From 6448d46368fe8d03f04e2d7e9fb8c30f37dba344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 09:29:49 +0100 Subject: [PATCH 08/22] Make docstest independent of available CPUs --- src/docstub/_concurrency.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py index bede652..a2d891b 100644 --- a/src/docstub/_concurrency.py +++ b/src/docstub/_concurrency.py @@ -192,9 +192,11 @@ def guess_concurrency_params(*, task_count, worker_count=None): Examples -------- - >>> _, chunk_size = guess_concurrency_params(task_count=10, worker_count=8) - >>> chunk_size - 2 + >>> worker_count, chunk_size = guess_concurrency_params( + ... task_count=9, worker_count=None + ... ) + >>> (worker_count, chunk_size) + (1, 9) """ # `process_cpu_count` was added in Python 3.13 onwards cpu_count = getattr(os, "process_cpu_count", os.cpu_count)() From 217848d4fc93b25cec6d90e980b3b4e5fcf1b6d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 09:44:44 +0100 Subject: [PATCH 09/22] Add tests for `guess_concurrency_params` --- tests/test_concurrency.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_concurrency.py diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..dbc1cb0 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,26 @@ +import math +import os + +import pytest + +from docstub._concurrency import guess_concurrency_params + + +class Test_guess_concurrency_params: + @pytest.mark.parametrize("task_count", list(range(9))) + @pytest.mark.parametrize("cpu_count", [1, 8]) + def test_default_below_cutoff(self, task_count, cpu_count, monkeypatch): + monkeypatch.setattr(os, "cpu_count", lambda: cpu_count) + monkeypatch.setattr(os, "process_cpu_count", lambda: cpu_count) + worker_count, chunk_size = guess_concurrency_params(task_count=task_count) + assert worker_count == 1 + assert chunk_size == task_count + + @pytest.mark.parametrize("task_count", [10, 15, 50, 100, 1000]) + @pytest.mark.parametrize("cpu_count", [1, 8, 16]) + def test_default(self, task_count, cpu_count, monkeypatch): + monkeypatch.setattr(os, "cpu_count", lambda: cpu_count) + monkeypatch.setattr(os, "process_cpu_count", lambda: cpu_count) + worker_count, chunk_size = guess_concurrency_params(task_count=task_count) + assert worker_count == min(cpu_count, task_count // 3) + assert chunk_size == math.ceil(task_count / worker_count) From 623da9efc152710bf0d4de487fc9b6f6303c2b37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 1 Nov 2025 09:53:58 +0100 Subject: [PATCH 10/22] Fix swallowing exit code for diffs --- .github/scripts/assert-unchanged.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/scripts/assert-unchanged.sh b/.github/scripts/assert-unchanged.sh index 4a2168d..d14e20e 100755 --- a/.github/scripts/assert-unchanged.sh +++ b/.github/scripts/assert-unchanged.sh @@ -15,8 +15,10 @@ UNTRACKED=$(git ls-files --others --exclude-standard "$CHECK_DIR") echo "$UNTRACKED" | xargs -I _ git --no-pager diff /dev/null _ || true # Display changes in tracked files and capture non-zero exit code if so -git diff --exit-code HEAD "$CHECK_DIR" || true +set +e +git diff --exit-code HEAD "$CHECK_DIR" GIT_DIFF_HEAD_EXIT_CODE=$? +set -e # Display changes in tracked files and capture exit status if [ $GIT_DIFF_HEAD_EXIT_CODE -ne 0 ] || [ -n "$UNTRACKED" ]; then From 3e3f8cebf371d90a2224fbf377d9f9cdfacdc196 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 17:56:51 +0100 Subject: [PATCH 11/22] Improve exit function of LoggingProcessExecutor --- src/docstub/_concurrency.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py index a2d891b..476103c 100644 --- a/src/docstub/_concurrency.py +++ b/src/docstub/_concurrency.py @@ -151,24 +151,35 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False if exc_type and issubclass(exc_type, KeyboardInterrupt): - # We want to exit immediately when user requests a KeyboardInterrupt - # even if we may lose logging records that way + # We want to exit immediately when user interrupts, even if we lose + # log records that way. logger.debug("Terminating workers") self._pool.terminate_workers() - # Calling `self._listener.stop()` here seems to block forever?! + + # Ensure that the queue doesn't block (not sure if necessary) + logger.debug("Calling `cancel_join_thread` on queue") + self._queue.cancel_join_thread() + else: - # Graceful shutdown + # We want to wait for any log record to reach the listener logger.debug("Shutting down pool") - self._pool.__exit__(exc_type, exc_val, exc_tb) + self._pool.shutdown(wait=True, cancel_futures=True) + logger.debug("Stopping queue listener") self._listener.stop() - logger.debug("Closing queue and joining its thread") - self._queue.close() - self._queue.join_thread() + if not self._queue.empty(): + logger.error("Expected logging queue to be empty, it is not!") - logger.debug("Exiting executor pool") + logger.debug("Closing queue and joining its thread") + self._queue.close() + self._queue.join_thread() + + self._queue = None + self._pool = None + self._listener = None + logger.debug("Exiting executor pool") return False From 00570788a2e84144e6bc96c11b7412cfe9db3122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 17:57:27 +0100 Subject: [PATCH 12/22] Fix stubtest errors --- src/docstub-stubs/_cli.pyi | 4 ++-- src/docstub-stubs/_concurrency.pyi | 6 +++-- src/docstub-stubs/_report.pyi | 4 ++-- src/docstub-stubs/_vendored/stdlib.pyi | 2 +- src/docstub/_cli.py | 32 +++++++++++++++----------- src/docstub/_concurrency.py | 14 +++-------- src/docstub/_report.py | 12 +++++++++- src/docstub/_vendored/stdlib.py | 2 +- 8 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/docstub-stubs/_cli.pyi b/src/docstub-stubs/_cli.pyi index 05173f8..ef9ecfc 100644 --- a/src/docstub-stubs/_cli.pyi +++ b/src/docstub-stubs/_cli.pyi @@ -48,7 +48,7 @@ click.Context.formatter_class = HelpFormatter def cli() -> None: ... def _add_verbosity_options(func: Callable) -> Callable: ... def _transform_to_stub( - source_path: Path, stub_path: Path, stub_transformer: Py2StubTransformer + task: tuple[Path, Path, Py2StubTransformer], ) -> dict[str, int | list[str]]: ... @cli.command() def run( @@ -60,7 +60,7 @@ def run( group_errors: bool, allow_errors: int, fail_on_warning: bool, - jobs: int | None, + desired_worker_count: int | None, no_cache: bool, verbose: int, quiet: int, diff --git a/src/docstub-stubs/_concurrency.pyi b/src/docstub-stubs/_concurrency.pyi index ff14ee7..70fc6c8 100644 --- a/src/docstub-stubs/_concurrency.pyi +++ b/src/docstub-stubs/_concurrency.pyi @@ -5,7 +5,7 @@ import logging.handlers import math import multiprocessing import os -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterator from concurrent.futures import Executor from dataclasses import dataclass from multiprocessing import Queue @@ -17,7 +17,9 @@ from ._vendored.stdlib import ProcessPoolExecutor logger: logging.Logger class MockPoolExecutor(Executor): - def map(self, fn: Callable, *iterables: Any, **__: Any) -> Iterable: ... + def map[_T]( + self, fn: Callable[..., _T], *iterables: Any, **__: Any + ) -> Iterator[_T]: ... @dataclass(kw_only=True) class LoggingProcessExecutor: diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index 266b165..1515c66 100644 --- a/src/docstub-stubs/_report.pyi +++ b/src/docstub-stubs/_report.pyi @@ -65,7 +65,7 @@ class ReportHandler(logging.StreamHandler): self, stream: TextIO | None = ..., group_errors: bool = ... ) -> None: ... def format(self, record: logging.LogRecord) -> str: ... - def handle(self, record: logging.LogRecord) -> None: ... + def handle(self, record: logging.LogRecord) -> bool: ... def emit_grouped(self) -> None: ... class LogCounter(logging.NullHandler): @@ -74,7 +74,7 @@ class LogCounter(logging.NullHandler): warning_count: int def __init__(self) -> None: ... - def handle(self, record: logging.Record) -> None: ... + def handle(self, record: logging.LogRecord) -> bool: ... def setup_logging( *, verbosity: Literal[-2, -1, 0, 1, 2, 3], group_errors: bool diff --git a/src/docstub-stubs/_vendored/stdlib.pyi b/src/docstub-stubs/_vendored/stdlib.pyi index d589113..d07831d 100644 --- a/src/docstub-stubs/_vendored/stdlib.pyi +++ b/src/docstub-stubs/_vendored/stdlib.pyi @@ -26,4 +26,4 @@ if not hasattr(_ProcessPoolExecutor, "terminate_workers"): def kill_workers(self) -> None: ... else: - ProcessPoolExecutor: _ProcessPoolExecutor + ProcessPoolExecutor: _ProcessPoolExecutor # type: ignore[no-redef] diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 815f2e9..c66f707 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -257,19 +257,22 @@ def _add_verbosity_options(func): return func -def _transform_to_stub(source_path, stub_path, stub_transformer): +def _transform_to_stub(task): """Transform a Python file into a stub file. Parameters ---------- - source_path : Path - stub_path : Path - stub_transformer : Py2StubTransformer + task : tuple[Path, Path, Py2StubTransformer] + The `source_path` for which to create a stub file at `stub_path` with + the given transformer. Returns ------- stats : dict of {str: int or list[str]} + Statistics about the transformation. """ + source_path, stub_path, stub_transformer = task + if source_path.suffix.lower() == ".pyi": logger.debug("Using existing stub file %s", source_path) with source_path.open() as fo: @@ -294,6 +297,7 @@ def _transform_to_stub(source_path, stub_path, stub_transformer): fo.write(stub_content) stats = stub_transformer.collect_stats() + return stats @@ -463,32 +467,32 @@ def run( # Stub generation --------------------------------------------------------- - tasks = walk_source_and_targets(root_path, out_dir, ignore=config.ignore_files) + task_files = walk_source_and_targets(root_path, out_dir, ignore=config.ignore_files) # We must pass the `stub_transformer` to each worker, but we want to copy # only once per worker. Testing suggests, that using a large enough - # `chunksize` of `>= len(tasks) / jobs` for `ProcessPoolExecutor.map`, + # `chunksize` of `>= len(task_count) / jobs` for `ProcessPoolExecutor.map`, # ensures that. # Using an `initializer` that assigns the transformer as a global variable # per worker seems like the more robust solution, but naive timing suggests # it's actually slower (> 1s on skimage). - tasks = [(*task, stub_transformer) for task in tasks] + task_args = [(*files, stub_transformer) for files in task_files] + task_count = len(task_args) worker_count, chunk_size = guess_concurrency_params( - task_count=len(tasks), worker_count=desired_worker_count + task_count=task_count, worker_count=desired_worker_count ) - logger.info("Using %i parallel jobs to write %i stubs", worker_count, len(tasks)) + + logger.info("Using %i parallel jobs to write %i stubs", worker_count, task_count) logger.debug("Using chunk size of %i", chunk_size) with LoggingProcessExecutor( max_workers=worker_count, logging_handlers=(output_handler, error_counter), ) as executor: - # Doesn't block - stats = executor.map( - _transform_to_stub, *zip(*tasks, strict=False), chunksize=chunk_size + stats_per_task = executor.map( + _transform_to_stub, task_args, chunksize=chunk_size ) - # Iterate results (which blocks) until all tasks have been processed - stats = update_with_add_values(*stats) + stats = update_with_add_values(*stats_per_task) py_typed_out = out_dir / "py.typed" if not py_typed_out.exists(): diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py index 476103c..f7d3c62 100644 --- a/src/docstub/_concurrency.py +++ b/src/docstub/_concurrency.py @@ -5,7 +5,7 @@ import math import multiprocessing import os -from collections.abc import Callable +from collections.abc import Callable, Iterator from concurrent.futures import Executor from dataclasses import dataclass from multiprocessing import Queue @@ -22,7 +22,7 @@ class MockPoolExecutor(Executor): :class:`concurrent.futures.Executor`. """ - def map(self, fn, *iterables, **__): + def map[T](self, fn: Callable[..., T], *iterables, **__) -> Iterator[T]: """Returns an iterator equivalent to map(fn, iter). Same behavior as :ref:`concurrent.futures.Executor.map` though any @@ -30,17 +30,10 @@ def map(self, fn, *iterables, **__): Parameters ---------- - fn : Callable *iterables : Any **__ : Any - - Returns - ------- - results : Iterable """ - tasks = zip(*iterables, strict=False) - for task in tasks: - yield fn(*task) + return map(fn, *iterables) @dataclass(kw_only=True) @@ -136,7 +129,6 @@ def __enter__(self) -> ProcessPoolExecutor | MockPoolExecutor: def __exit__(self, exc_type, exc_val, exc_tb): """ - Parameters ---------- exc_type : type[BaseException] or None diff --git a/src/docstub/_report.py b/src/docstub/_report.py index fd0592c..86957c2 100644 --- a/src/docstub/_report.py +++ b/src/docstub/_report.py @@ -312,11 +312,16 @@ def handle(self, record): Parameters ---------- record : logging.LogRecord + + Returns + ------- + out : bool """ if self.group_errors and logging.WARNING <= record.levelno <= logging.ERROR: self._records.append(record) else: self.emit(record) + return True def emit_grouped(self): """Emit all saved log records in groups. @@ -369,7 +374,11 @@ def handle(self, record): Parameters ---------- - record : logging.Record + record : logging.LogRecord + + Returns + ------- + out : bool """ if record.levelno >= logging.CRITICAL: self.critical_count += 1 @@ -377,6 +386,7 @@ def handle(self, record): self.error_count += 1 elif record.levelno >= logging.WARNING: self.warning_count += 1 + return True def setup_logging(*, verbosity, group_errors): diff --git a/src/docstub/_vendored/stdlib.py b/src/docstub/_vendored/stdlib.py index 4f4abc3..7d4e6c9 100644 --- a/src/docstub/_vendored/stdlib.py +++ b/src/docstub/_vendored/stdlib.py @@ -228,4 +228,4 @@ def kill_workers(self) -> None: return self._force_shutdown(operation=_KILL) else: - ProcessPoolExecutor: _ProcessPoolExecutor = _ProcessPoolExecutor + ProcessPoolExecutor: _ProcessPoolExecutor = _ProcessPoolExecutor # type: ignore[no-redef] From d24c161ba252783a119c9e56571f8aff522a7b6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 18:08:46 +0100 Subject: [PATCH 13/22] Update example_pkg-stubs --- examples/example_pkg-stubs/_basic.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 8fb3e9b..500d3d6 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -17,7 +17,7 @@ __all__ = [ "func_empty", ] -def func_empty(a1: Incomplete, a2: Incomplete, a3) -> None: ... +def func_empty(a1: Incomplete, a2: Incomplete, a3: Incomplete) -> None: ... def func_contains( a1: list[float], a2: dict[str, Union[int, str]], From 7ca23db65f42b62e8633d7a2926f2c23201fa145 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 18:11:42 +0100 Subject: [PATCH 14/22] Sync _concurrency.pyi --- src/docstub-stubs/_concurrency.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/docstub-stubs/_concurrency.pyi b/src/docstub-stubs/_concurrency.pyi index 70fc6c8..ff4392b 100644 --- a/src/docstub-stubs/_concurrency.pyi +++ b/src/docstub-stubs/_concurrency.pyi @@ -17,9 +17,9 @@ from ._vendored.stdlib import ProcessPoolExecutor logger: logging.Logger class MockPoolExecutor(Executor): - def map[_T]( - self, fn: Callable[..., _T], *iterables: Any, **__: Any - ) -> Iterator[_T]: ... + def map[T]( + self, fn: Callable[..., T], *iterables: Any, **__: Any + ) -> Iterator[T]: ... @dataclass(kw_only=True) class LoggingProcessExecutor: From 0c94caf9747c74c57ff35d6128667655b7c38548 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 22:33:54 +0100 Subject: [PATCH 15/22] Mark --worker option as experimental and default to single process This is probably safest for now. I don't want to break peoples setup the new version of docstub suddenly hangs or fails. This way they users have to opt in explicitly and will now what might have caused it. --- docs/command_line.md | 12 +++++------ src/docstub-stubs/_cli.pyi | 2 +- src/docstub-stubs/_concurrency.pyi | 2 +- src/docstub/_cli.py | 32 ++++++++++++++++-------------- src/docstub/_concurrency.py | 14 +++++++------ 5 files changed, 33 insertions(+), 29 deletions(-) diff --git a/docs/command_line.md b/docs/command_line.md index e5719e6..0b49668 100644 --- a/docs/command_line.md +++ b/docs/command_line.md @@ -49,10 +49,6 @@ Options: Set output directory explicitly. Stubs will be directly written into that directory while preserving the directory structure under PACKAGE_PATH. Otherwise, stubs are generated inplace. - --config PATH - Set one or more configuration file(s) explicitly. Otherwise, it will - look for a `pyproject.toml` or `docstub.toml` in the current - directory. --ignore GLOB Ignore files matching this glob-style pattern. Can be used multiple times. @@ -68,10 +64,14 @@ Options: Return non-zero exit code when a warning is raised. Will add to --allow-errors. --workers INT - Set the number of workers to process files in parallel. By default - docstub will attempt to choose an appropriate number. [x>=1] + Experimental: Process files in parallel with the desired number of + workers. By default, no multiprocessing is used. [default: 1] --no-cache Ignore pre-existing cache and don't create a new one. + --config PATH + Set one or more configuration file(s) explicitly. Otherwise, it will + look for a `pyproject.toml` or `docstub.toml` in the current + directory. -v, --verbose Print more details. Use once to show information messages. Use -vv to print debug messages. diff --git a/src/docstub-stubs/_cli.pyi b/src/docstub-stubs/_cli.pyi index ef9ecfc..ed6077b 100644 --- a/src/docstub-stubs/_cli.pyi +++ b/src/docstub-stubs/_cli.pyi @@ -60,7 +60,7 @@ def run( group_errors: bool, allow_errors: int, fail_on_warning: bool, - desired_worker_count: int | None, + desired_worker_count: int, no_cache: bool, verbose: int, quiet: int, diff --git a/src/docstub-stubs/_concurrency.pyi b/src/docstub-stubs/_concurrency.pyi index ff4392b..8881929 100644 --- a/src/docstub-stubs/_concurrency.pyi +++ b/src/docstub-stubs/_concurrency.pyi @@ -42,5 +42,5 @@ class LoggingProcessExecutor: ) -> bool: ... def guess_concurrency_params( - *, task_count: int, worker_count: int | None = ... + *, task_count: int, desired_worker_count: int | None = ... ) -> tuple[int, int]: ... diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index c66f707..ce1dda0 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -316,16 +316,6 @@ def _transform_to_stub(task): "structure under PACKAGE_PATH. " "Otherwise, stubs are generated inplace.", ) -@click.option( - "--config", - "config_paths", - type=click.Path(exists=True, dir_okay=False), - metavar="PATH", - multiple=True, - help="Set one or more configuration file(s) explicitly. " - "Otherwise, it will look for a `pyproject.toml` or `docstub.toml` in the " - "current directory.", -) @click.option( "--ignore", type=str, @@ -361,16 +351,28 @@ def _transform_to_stub(task): @click.option( "--workers", "desired_worker_count", - type=click.IntRange(min=1), + type=int, + default=1, metavar="INT", - help="Set the number of workers to process files in parallel. " - "By default docstub will attempt to choose an appropriate number.", + help="Experimental: Process files in parallel with the desired number of workers. " + "By default, no multiprocessing is used.", + show_default=True, ) @click.option( "--no-cache", is_flag=True, help="Ignore pre-existing cache and don't create a new one.", ) +@click.option( + "--config", + "config_paths", + type=click.Path(exists=True, dir_okay=False), + metavar="PATH", + multiple=True, + help="Set one or more configuration file(s) explicitly. " + "Otherwise, it will look for a `pyproject.toml` or `docstub.toml` in the " + "current directory.", +) @_add_verbosity_options @click.help_option("-h", "--help") @log_execution_time() @@ -404,7 +406,7 @@ def run( group_errors : bool allow_errors : int fail_on_warning : bool - desired_worker_count : int | None + desired_worker_count : int no_cache : bool verbose : int quiet : int @@ -480,7 +482,7 @@ def run( task_count = len(task_args) worker_count, chunk_size = guess_concurrency_params( - task_count=task_count, worker_count=desired_worker_count + task_count=task_count, desired_worker_count=desired_worker_count ) logger.info("Using %i parallel jobs to write %i stubs", worker_count, task_count) diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py index f7d3c62..dcc5ac1 100644 --- a/src/docstub/_concurrency.py +++ b/src/docstub/_concurrency.py @@ -175,16 +175,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False -def guess_concurrency_params(*, task_count, worker_count=None): +def guess_concurrency_params(*, task_count, desired_worker_count=None): """Estimate how tasks should be distributed to how many workers. Parameters ---------- task_count : int The number of task that need to be processed. - worker_count : int, optional + desired_worker_count : int, optional If not set, the number of workers is estimated. Set this explicitly - to force a number of workers. + to force a number of workers. Passing `-1` will also trigger estimation. Returns ------- @@ -196,15 +196,17 @@ def guess_concurrency_params(*, task_count, worker_count=None): Examples -------- >>> worker_count, chunk_size = guess_concurrency_params( - ... task_count=9, worker_count=None + ... task_count=9, desired_worker_count=None ... ) - >>> (worker_count, chunk_size) + >>> (desired_worker_count, chunk_size) (1, 9) """ + worker_count = desired_worker_count + # `process_cpu_count` was added in Python 3.13 onwards cpu_count = getattr(os, "process_cpu_count", os.cpu_count)() - if worker_count is None: + if worker_count is None or worker_count == -1: # These crude heuristics were only ever "measured" on one computer worker_count = cpu_count # Clip to `worker_count <= task_count // 3` From 33843a791342fb3b82c7801c04f3f5a5a591a77d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 22:36:38 +0100 Subject: [PATCH 16/22] Fix typo in docstest caused by accidental refactor --- src/docstub/_concurrency.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py index dcc5ac1..2189c59 100644 --- a/src/docstub/_concurrency.py +++ b/src/docstub/_concurrency.py @@ -198,7 +198,7 @@ def guess_concurrency_params(*, task_count, desired_worker_count=None): >>> worker_count, chunk_size = guess_concurrency_params( ... task_count=9, desired_worker_count=None ... ) - >>> (desired_worker_count, chunk_size) + >>> (worker_count, chunk_size) (1, 9) """ worker_count = desired_worker_count From 8812b88b471868d32e784cd4ea489e1d2213619a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 22:40:03 +0100 Subject: [PATCH 17/22] Test multiprocessing in CI --- .github/workflows/ci.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4ebec38..88f51ef 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,12 +91,17 @@ jobs: - name: Check docstub-stubs # Check that stubs for docstub are up-to-date by regenerating them - # with docstub and looking for differences. + # with docstub and looking for differences. Repeat test with + # multiprocessing enabled run: | rm -rf src/docstub-stubs python -m docstub run -v src/docstub -o src/docstub-stubs .github/scripts/assert-unchanged.sh src/docstub-stubs/ + rm -rf src/docstub-stubs + python -m docstub run -v src/docstub -o src/docstub-stubs --workers 2 + .github/scripts/assert-unchanged.sh src/docstub-stubs/ + - name: Check with mypy.stubtest run: | python -m mypy.stubtest \ From 48c3b6f2bad9ac397c95e7287e6ee2399ad8dfd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 22:43:25 +0100 Subject: [PATCH 18/22] Factor CI multiprocessing step into its own step --- .github/workflows/ci.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 88f51ef..895d16d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -89,15 +89,17 @@ jobs: examples/example_pkg .github/scripts/assert-unchanged.sh examples/ - - name: Check docstub-stubs + - name: Check docstub-stubs (single process) # Check that stubs for docstub are up-to-date by regenerating them - # with docstub and looking for differences. Repeat test with - # multiprocessing enabled + # with docstub and looking for differences. run: | rm -rf src/docstub-stubs python -m docstub run -v src/docstub -o src/docstub-stubs .github/scripts/assert-unchanged.sh src/docstub-stubs/ + - name: Check docstub-stubs (multiprocess) + # Repeat test with multiprocessing enabled + run: | rm -rf src/docstub-stubs python -m docstub run -v src/docstub -o src/docstub-stubs --workers 2 .github/scripts/assert-unchanged.sh src/docstub-stubs/ From 7d4666225240a7aba20f69f0ecd4f80d410a47f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 22:45:12 +0100 Subject: [PATCH 19/22] Fix test setup for monkeypatching process_cpu_count --- tests/test_concurrency.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index dbc1cb0..e4b65dc 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -11,7 +11,7 @@ class Test_guess_concurrency_params: @pytest.mark.parametrize("cpu_count", [1, 8]) def test_default_below_cutoff(self, task_count, cpu_count, monkeypatch): monkeypatch.setattr(os, "cpu_count", lambda: cpu_count) - monkeypatch.setattr(os, "process_cpu_count", lambda: cpu_count) + monkeypatch.setattr(os, "process_cpu_count", lambda: cpu_count, raising=False) worker_count, chunk_size = guess_concurrency_params(task_count=task_count) assert worker_count == 1 assert chunk_size == task_count @@ -20,7 +20,7 @@ def test_default_below_cutoff(self, task_count, cpu_count, monkeypatch): @pytest.mark.parametrize("cpu_count", [1, 8, 16]) def test_default(self, task_count, cpu_count, monkeypatch): monkeypatch.setattr(os, "cpu_count", lambda: cpu_count) - monkeypatch.setattr(os, "process_cpu_count", lambda: cpu_count) + monkeypatch.setattr(os, "process_cpu_count", lambda: cpu_count, raising=False) worker_count, chunk_size = guess_concurrency_params(task_count=task_count) assert worker_count == min(cpu_count, task_count // 3) assert chunk_size == math.ceil(task_count / worker_count) From fc5b2e914737755defe000b4dd0f67c678b577eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 23:17:00 +0100 Subject: [PATCH 20/22] Skip subprocess tests on Linux 3.12-3.13 --- tests/test_cli.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 8d74ac9..92541e0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,7 @@ """Test command line interface.""" import logging +import multiprocessing import subprocess import sys from pathlib import Path @@ -15,6 +16,24 @@ PROJECT_ROOT = Path(__file__).parent.parent +# Trying to run docstub via subprocess fails on Linux pre 3.14 with the +# following: +# +# RuntimeError: A SemLock created in a fork context is being shared with a +# process in a spawn context. This is not supported. Please use the same +# context to create multiprocessing objects and Process. +# +# I think this is because docstub defaults to the more robust method "spawn" +# to create processes. This seems incompatible with the default "fork" method +# on Linux on Python 3.12 & 3.13. Python 3.14 switched to "forkserver" which +# seems fine. I didn't manage to figure out a good way around this other than +# to skip. +skip_if_process_start_defaults_fork = pytest.mark.skipif( + multiprocessing.get_start_method() == "fork", + reason="incompatible default OS process start method", +) + + class Test_run: def test_no_cache(self, tmp_path_cwd, caplog): caplog.set_level(logging.INFO) @@ -57,6 +76,7 @@ def test_no_cache(self, tmp_path_cwd, caplog): # Check that at least one collected file was logged as "(cached)" assert "cached" not in "\n".join(caplog.messages) + @skip_if_process_start_defaults_fork @pytest.mark.slow @pytest.mark.parametrize("workers", [1, 2]) def test_fail_on_warning(self, workers, tmp_path_cwd): @@ -95,6 +115,7 @@ def foo(x: str): ) assert result.returncode == 1 + @skip_if_process_start_defaults_fork @pytest.mark.slow @pytest.mark.parametrize("workers", [1, 2]) def test_no_output_exit_code(self, workers, tmp_path_cwd): From 2bb21fa0e66efa56747ebecfdc6578fdc6e00d1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 23:35:42 +0100 Subject: [PATCH 21/22] Start logging queue is in same multiprocessing context as pool This fixes the following: RuntimeError: A SemLock created in a fork context is being shared with a process in a spawn context. This is not supported. Please use the same context to create multiprocessing objects and Process. --- src/docstub/_concurrency.py | 3 +-- tests/test_cli.py | 21 --------------------- 2 files changed, 1 insertion(+), 23 deletions(-) diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py index 2189c59..d60e92f 100644 --- a/src/docstub/_concurrency.py +++ b/src/docstub/_concurrency.py @@ -8,7 +8,6 @@ from collections.abc import Callable, Iterator from concurrent.futures import Executor from dataclasses import dataclass -from multiprocessing import Queue from ._vendored.stdlib import ProcessPoolExecutor @@ -102,7 +101,7 @@ def __enter__(self) -> ProcessPoolExecutor | MockPoolExecutor: # A queue, used to pass logging records from worker processes to the # current and main one. Naive testing suggests that # `multiprocessing.Queue` is faster than `multiprocessing.Manager.Queue` - self._queue = Queue() + self._queue = mp_context.Queue() # The actual pool manager that is wrapped here and returned by this # context manager diff --git a/tests/test_cli.py b/tests/test_cli.py index 92541e0..8d74ac9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,6 @@ """Test command line interface.""" import logging -import multiprocessing import subprocess import sys from pathlib import Path @@ -16,24 +15,6 @@ PROJECT_ROOT = Path(__file__).parent.parent -# Trying to run docstub via subprocess fails on Linux pre 3.14 with the -# following: -# -# RuntimeError: A SemLock created in a fork context is being shared with a -# process in a spawn context. This is not supported. Please use the same -# context to create multiprocessing objects and Process. -# -# I think this is because docstub defaults to the more robust method "spawn" -# to create processes. This seems incompatible with the default "fork" method -# on Linux on Python 3.12 & 3.13. Python 3.14 switched to "forkserver" which -# seems fine. I didn't manage to figure out a good way around this other than -# to skip. -skip_if_process_start_defaults_fork = pytest.mark.skipif( - multiprocessing.get_start_method() == "fork", - reason="incompatible default OS process start method", -) - - class Test_run: def test_no_cache(self, tmp_path_cwd, caplog): caplog.set_level(logging.INFO) @@ -76,7 +57,6 @@ def test_no_cache(self, tmp_path_cwd, caplog): # Check that at least one collected file was logged as "(cached)" assert "cached" not in "\n".join(caplog.messages) - @skip_if_process_start_defaults_fork @pytest.mark.slow @pytest.mark.parametrize("workers", [1, 2]) def test_fail_on_warning(self, workers, tmp_path_cwd): @@ -115,7 +95,6 @@ def foo(x: str): ) assert result.returncode == 1 - @skip_if_process_start_defaults_fork @pytest.mark.slow @pytest.mark.parametrize("workers", [1, 2]) def test_no_output_exit_code(self, workers, tmp_path_cwd): From 4c0daff087bebe16c98dc79f310522f9925c2dfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 2 Nov 2025 23:41:34 +0100 Subject: [PATCH 22/22] Fix unkown "Queue" in doctype --- src/docstub-stubs/_concurrency.pyi | 6 ++++-- src/docstub/_cli.py | 2 +- src/docstub/_concurrency.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/docstub-stubs/_concurrency.pyi b/src/docstub-stubs/_concurrency.pyi index 8881929..572b649 100644 --- a/src/docstub-stubs/_concurrency.pyi +++ b/src/docstub-stubs/_concurrency.pyi @@ -8,7 +8,6 @@ import os from collections.abc import Callable, Iterator from concurrent.futures import Executor from dataclasses import dataclass -from multiprocessing import Queue from types import TracebackType from typing import Any @@ -31,7 +30,10 @@ class LoggingProcessExecutor: @staticmethod def _initialize_worker( - queue: Queue, worker_log_level: int, initializer: Callable, initargs: tuple[Any] + queue: multiprocessing.Queue, + worker_log_level: int, + initializer: Callable, + initargs: tuple[Any], ) -> None: ... def __enter__(self) -> ProcessPoolExecutor | MockPoolExecutor: ... def __exit__( diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index ce1dda0..c741f8d 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -485,7 +485,7 @@ def run( task_count=task_count, desired_worker_count=desired_worker_count ) - logger.info("Using %i parallel jobs to write %i stubs", worker_count, task_count) + logger.info("Using %i worker(s) to write %i stubs", worker_count, task_count) logger.debug("Using chunk size of %i", chunk_size) with LoggingProcessExecutor( max_workers=worker_count, diff --git a/src/docstub/_concurrency.py b/src/docstub/_concurrency.py index d60e92f..5d20ca3 100644 --- a/src/docstub/_concurrency.py +++ b/src/docstub/_concurrency.py @@ -65,7 +65,7 @@ def _initialize_worker(queue, worker_log_level, initializer, initargs): Parameters ---------- - queue : Queue + queue : multiprocessing.Queue worker_log_level : int initializer : Callable initargs : tuple of Any