| 
2 | 2 | 
 
  | 
3 | 3 | from __future__ import annotations  | 
4 | 4 | 
 
  | 
5 |  | -import enum  | 
6 | 5 | import os  | 
7 | 6 | from typing import Any  | 
8 | 7 | 
 
  | 
9 | 8 | from pytask import hookimpl  | 
10 | 9 | 
 
  | 
 | 10 | +from pytask_parallel import execute  | 
 | 11 | +from pytask_parallel import processes  | 
 | 12 | +from pytask_parallel import threads  | 
11 | 13 | from pytask_parallel.backends import ParallelBackend  | 
12 | 14 | 
 
  | 
13 | 15 | 
 
  | 
14 | 16 | @hookimpl  | 
15 | 17 | def pytask_parse_config(config: dict[str, Any]) -> None:  | 
16 | 18 |     """Parse the configuration."""  | 
 | 19 | +    __tracebackhide__ = True  | 
 | 20 | + | 
17 | 21 |     if config["n_workers"] == "auto":  | 
18 | 22 |         config["n_workers"] = max(os.cpu_count() - 1, 1)  | 
19 | 23 | 
 
  | 
20 |  | -    if (  | 
21 |  | -        isinstance(config["parallel_backend"], str)  | 
22 |  | -        and config["parallel_backend"] in ParallelBackend._value2member_map_  # noqa: SLF001  | 
23 |  | -    ):  | 
 | 24 | +    try:  | 
24 | 25 |         config["parallel_backend"] = ParallelBackend(config["parallel_backend"])  | 
25 |  | -    elif (  | 
26 |  | -        isinstance(config["parallel_backend"], enum.Enum)  | 
27 |  | -        and config["parallel_backend"] in ParallelBackend  | 
28 |  | -    ):  | 
29 |  | -        pass  | 
30 |  | -    else:  | 
31 |  | -        msg = f"Invalid value for 'parallel_backend'. Got {config['parallel_backend']}."  | 
32 |  | -        raise ValueError(msg)  | 
 | 26 | +    except ValueError:  | 
 | 27 | +        msg = (  | 
 | 28 | +            f"Invalid value for 'parallel_backend'. Got {config['parallel_backend']}. "  | 
 | 29 | +            f"Choose one of {', '.join([e.value for e in ParallelBackend])}."  | 
 | 30 | +        )  | 
 | 31 | +        raise ValueError(msg) from None  | 
33 | 32 | 
 
  | 
34 | 33 |     config["delay"] = 0.1  | 
35 | 34 | 
 
  | 
36 | 35 | 
 
  | 
37 | 36 | @hookimpl  | 
38 | 37 | def pytask_post_parse(config: dict[str, Any]) -> None:  | 
39 |  | -    """Disable parallelization if debugging is enabled."""  | 
 | 38 | +    """Register the parallel backend if debugging is not enabled."""  | 
40 | 39 |     if config["pdb"] or config["trace"] or config["dry_run"]:  | 
41 | 40 |         config["n_workers"] = 1  | 
 | 41 | + | 
 | 42 | +    if config["n_workers"] > 1:  | 
 | 43 | +        config["pm"].register(execute)  | 
 | 44 | +        if config["parallel_backend"] == ParallelBackend.THREADS:  | 
 | 45 | +            config["pm"].register(threads)  | 
 | 46 | +        else:  | 
 | 47 | +            config["pm"].register(processes)  | 
0 commit comments