__all__ = ["get_annotations"]


if sys.version_info >= (3, 10):
    from inspect import get_annotations
else:

    def get_annotations(  # noqa: C901, PLR0912, PLR0915
        obj: Callable[..., object] | type[Any] | types.ModuleType,
        *,
        globals: Mapping[str, Any] | None = None,  # noqa: A002
        locals: Mapping[str, Any] | None = None,  # noqa: A002
        eval_str: bool = False,
    ) -> dict[str, Any]:
        """Return a fresh dictionary with the annotations of ``obj``.

        This is the fallback used on interpreters older than Python 3.10, where
        :func:`inspect.get_annotations` is not available.

        Parameters
        ----------
        obj
            A class, a module, or any callable.
        globals
            Global namespace handed to :func:`eval` when ``eval_str`` is true. If
            ``None``, a default derived from ``obj`` is used: the ``__dict__`` of
            a module, the ``__dict__`` of the module a class was defined in, or
            the ``__globals__`` of a callable after following ``__wrapped__`` and
            :class:`functools.partial` chains.
        locals
            Local namespace handed to :func:`eval` when ``eval_str`` is true. If
            ``None``, a copy of the class namespace is used for classes and
            ``None`` for everything else.
        eval_str
            If true, every annotation value that is a string is replaced by the
            result of evaluating it. Other values are passed through unchanged.

        Returns
        -------
        dict[str, Any]
            A newly created dictionary. It is empty if ``obj`` has no annotations.
            For classes only the annotations defined on the class itself are
            considered, inherited ones are ignored.

        Raises
        ------
        TypeError
            If ``obj`` is neither a class, a module, nor a callable.
        ValueError
            If ``obj.__annotations__`` exists but is neither a dict nor ``None``.

        """
        context_globals = None
        context_locals = None
        # The object whose wrapper chain is followed to find the globals.
        # Modules have nothing to unwrap.
        unwrappable = obj

        if isinstance(obj, type):
            # Read from the class' own namespace so that annotations of base
            # classes are not picked up.
            raw = None
            namespace = getattr(obj, "__dict__", None)
            if namespace and hasattr(namespace, "get"):
                raw = namespace.get("__annotations__", None)
                if isinstance(raw, types.GetSetDescriptorType):
                    raw = None

            owner_name = getattr(obj, "__module__", None)
            owner = sys.modules.get(owner_name, None) if owner_name else None
            if owner:
                context_globals = getattr(owner, "__dict__", None)
            context_locals = dict(vars(obj))
        elif isinstance(obj, types.ModuleType):
            raw = getattr(obj, "__annotations__", None)
            context_globals = obj.__dict__
            unwrappable = None
        elif callable(obj):
            # Functions, builtins, partials, instances with ``__call__``, ...
            raw = getattr(obj, "__annotations__", None)
            context_globals = getattr(obj, "__globals__", None)
        else:
            raise TypeError(f"{obj!r} is not a module, class, or callable.")

        if raw is None:
            return {}
        if not isinstance(raw, dict):
            raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
        if not raw or not eval_str:
            return dict(raw)

        if unwrappable is not None:
            # Peel off decorators and partials to reach the innermost function.
            while True:
                if hasattr(unwrappable, "__wrapped__"):
                    unwrappable = unwrappable.__wrapped__
                elif isinstance(unwrappable, functools.partial):
                    unwrappable = unwrappable.func
                else:
                    break
            if hasattr(unwrappable, "__globals__"):
                context_globals = unwrappable.__globals__

        # Explicitly passed namespaces win over the derived defaults.
        if globals is None:
            globals = context_globals  # noqa: A001
        if locals is None:
            locals = context_locals  # noqa: A001

        # Keep the evaluation inside a comprehension that closes over exactly
        # these names: when ``globals`` is still ``None``, eval() falls back to
        # the namespaces of the frame it is called from.
        eval_func = eval
        return {
            key: eval_func(value, globals, locals) if isinstance(value, str) else value
            for key, value in raw.items()
        }
_ERROR_MULTIPLE_PRODUCT_DEFINITIONS = (
    "The task uses multiple ways to define products. Products should be defined with "
    "either\n\n- 'typing.Annotated[Path(...), Product]' (recommended)\n"
    "- '@pytask.mark.task(kwargs={'produces': Path(...)})'\n"
    "- as a default argument for 'produces': 'produces = Path(...)'\n"
    "- '@pytask.mark.produces(Path(...))' (deprecated).\n\n"
    "Read more about products in the documentation: https://tinyurl.com/yrezszr4."
)


def parse_products_from_task_function(
    session: Session, path: Path, name: str, obj: Any
) -> dict[str, Any]:
    """Parse products from task function.

    Products can be declared in four mutually exclusive ways:

    1. With the ``@pytask.mark.produces`` decorator.
    2. With a ``"produces"`` entry in the keyword arguments stored on
       ``obj.pytask_meta.kwargs``.
    3. With a default value for a function argument called ``produces``.
    4. With a default value for an argument annotated with
       ``Annotated[..., Product]``.

    Products from the first three ways are stored under the key ``"produces"``.
    Annotated arguments are stored under their own argument name.

    Parameters
    ----------
    session
        The session, passed on to the node collection helpers.
    path
        The path passed on to the node collection helpers.
    name
        The task name passed on to the node collection helpers.
    obj
        The task function.

    Returns
    -------
    dict[str, Any]
        A mapping from argument names to the collected products. It is empty if
        the task function does not declare any products.

    Raises
    ------
    NodeNotCollectedError
        If multiple ways were used to specify products.

    """
    has_produces_decorator = False
    has_task_decorator = False
    has_signature_default = False
    has_annotation = False
    out = {}

    if has_mark(obj, "produces"):
        has_produces_decorator = True
        nodes = parse_nodes(session, path, name, obj, produces)
        out = {"produces": nodes}

    task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
    if "produces" in task_kwargs:
        # Record this way as well. Otherwise combining it with another way of
        # defining products would never trigger the error below.
        has_task_decorator = True
        collected_products = tree_map(
            lambda x: _collect_product(session, path, name, x, is_string_allowed=True),
            task_kwargs["produces"],
        )
        out = {"produces": collected_products}

    parameters = inspect.signature(obj).parameters

    # NOTE(review): presumably tasks with the 'task' marker are skipped because
    # their signature defaults already ended up in ``pytask_meta.kwargs`` --
    # confirm against the task decorator.
    if not has_mark(obj, "task") and "produces" in parameters:
        parameter = parameters["produces"]
        if parameter.default is not parameter.empty:
            has_signature_default = True
            # Strings are not accepted as function defaults, only paths.
            collected_products = tree_map(
                lambda x: _collect_product(
                    session, path, name, x, is_string_allowed=False
                ),
                parameter.default,
            )
            out = {"produces": collected_products}

    parameters_with_product_annot = _find_args_with_product_annotation(obj)
    if parameters_with_product_annot:
        has_annotation = True
        for parameter_name in parameters_with_product_annot:
            parameter = parameters[parameter_name]
            if parameter.default is not parameter.empty:
                # Strings are not accepted as function defaults, only paths.
                collected_products = tree_map(
                    lambda x: _collect_product(
                        session, path, name, x, is_string_allowed=False
                    ),
                    parameter.default,
                )
                # Add to the mapping instead of replacing it so that a task
                # with several annotated products keeps all of them.
                out[parameter_name] = collected_products

    number_of_ways = sum(
        (
            has_produces_decorator,
            has_task_decorator,
            has_signature_default,
            has_annotation,
        )
    )
    if number_of_ways >= 2:  # noqa: PLR2004
        raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)

    return out


def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]:
    """Find args with product annotation.

    Parameters
    ----------
    func
        The function whose annotations are inspected.

    Returns
    -------
    list[str]
        The names of all arguments annotated with ``Annotated[..., Product]``,
        in the order in which the annotations are defined.

    """
    annotations = get_annotations(func, eval_str=True)
    return [
        name
        for name, annotation in annotations.items()
        # The annotations of a function also hold the return annotation under
        # the key "return", which is not an argument of the function.
        if name != "return"
        and get_origin(annotation) is Annotated
        and any(isinstance(i, ProductType) for i in annotation.__metadata__)
    ]
) if isinstance(node, str): diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 40e8e4b4..cce8a054 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -149,8 +149,9 @@ def pytask_execute_task(session: Session, task: Task) -> bool: for name, value in task.depends_on.items(): kwargs[name] = tree_map(lambda x: x.value, value) - if task.produces and "produces" in parameters: - kwargs["produces"] = tree_map(lambda x: x.value, task.produces) + for name, value in task.produces.items(): + if name in parameters: + kwargs[name] = tree_map(lambda x: x.value, value) task.execute(**kwargs) return True diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 47f7b9f1..1af06904 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -17,7 +17,15 @@ from _pytask.mark import Mark -__all__ = ["FilePathNode", "MetaNode", "Task"] +__all__ = ["FilePathNode", "MetaNode", "Product", "Task"] + + +@define(frozen=True) +class ProductType: + """A class to mark products.""" + + +Product = ProductType() class MetaNode(metaclass=ABCMeta): diff --git a/src/_pytask/traceback.py b/src/_pytask/traceback.py index 38fdeb88..cc14c181 100644 --- a/src/_pytask/traceback.py +++ b/src/_pytask/traceback.py @@ -10,6 +10,7 @@ import _pytask import pluggy +import pybaum from rich.traceback import Traceback @@ -22,6 +23,7 @@ _PLUGGY_DIRECTORY = Path(pluggy.__file__).parent +_PYBAUM_DIRECTORY = Path(pybaum.__file__).parent _PYTASK_DIRECTORY = Path(_pytask.__file__).parent @@ -89,7 +91,10 @@ def _is_internal_or_hidden_traceback_frame( return True path = Path(frame.tb_frame.f_code.co_filename) - return any(root in path.parents for root in (_PLUGGY_DIRECTORY, _PYTASK_DIRECTORY)) + return any( + root in path.parents + for root in (_PLUGGY_DIRECTORY, _PYBAUM_DIRECTORY, _PYTASK_DIRECTORY) + ) def _filter_internal_traceback_frames( diff --git a/src/pytask/__init__.py b/src/pytask/__init__.py index f589e91a..d84ec0f1 100644 --- a/src/pytask/__init__.py +++ 
b/src/pytask/__init__.py @@ -37,6 +37,7 @@ from _pytask.models import CollectionMetadata from _pytask.nodes import FilePathNode from _pytask.nodes import MetaNode +from _pytask.nodes import Product from _pytask.nodes import Task from _pytask.outcomes import CollectionOutcome from _pytask.outcomes import count_outcomes @@ -89,6 +90,7 @@ "NodeNotCollectedError", "NodeNotFoundError", "Persisted", + "Product", "PytaskError", "ResolvingDependenciesError", "Runtime", diff --git a/tests/test_collect.py b/tests/test_collect.py index 95c2ef22..6728e16e 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -249,6 +249,7 @@ def test_find_shortest_uniquely_identifiable_names_for_tasks(tmp_path): assert result == expected +@pytest.mark.end_to_end() def test_collect_dependencies_from_args_if_depends_on_is_missing(tmp_path): source = """ from pathlib import Path @@ -306,6 +307,21 @@ def task_my_task(): assert outcome == CollectionOutcome.SUCCESS +@pytest.mark.end_to_end() +def test_collect_string_product_with_task_decorator(tmp_path): + source = """ + import pytask + + @pytask.mark.task + def task_write_text(produces="out.txt"): + produces.touch() + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + session = main({"paths": tmp_path}) + assert session.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").exists() + + @pytest.mark.end_to_end() def test_collect_string_product_as_function_default_fails(tmp_path): source = """ @@ -319,3 +335,27 @@ def task_write_text(produces="out.txt"): report = session.collection_reports[0] assert report.outcome == CollectionOutcome.FAIL assert "If you use 'produces'" in str(report.exc_info[1]) + + +@pytest.mark.end_to_end() +def test_product_cannot_mix_different_product_types(tmp_path): + source = """ + import pytask + from typing_extensions import Annotated + from pytask import Product + from pathlib import Path + + @pytask.mark.produces("out_deco.txt") + def task_example( + path: Annotated[Path, 
Product], produces: Path = Path("out_sig.txt") + ): + ... + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + session = main({"paths": tmp_path}) + + assert session.exit_code == ExitCode.COLLECTION_FAILED + assert len(session.tasks) == 0 + report = session.collection_reports[0] + assert report.outcome == CollectionOutcome.FAIL + assert "The task uses multiple ways" in str(report.exc_info[1]) diff --git a/tests/test_collect_utils.py b/tests/test_collect_utils.py index 89b995e0..e4dbc4c1 100644 --- a/tests/test_collect_utils.py +++ b/tests/test_collect_utils.py @@ -9,10 +9,13 @@ from _pytask.collect_utils import _convert_objects_to_node_dictionary from _pytask.collect_utils import _convert_to_dict from _pytask.collect_utils import _extract_nodes_from_function_markers +from _pytask.collect_utils import _find_args_with_product_annotation from _pytask.collect_utils import _merge_dictionaries from _pytask.collect_utils import _Placeholder from pytask import depends_on from pytask import produces +from pytask import Product +from typing_extensions import Annotated ERROR = "'@pytask.mark.depends_on' has nodes with the same name:" @@ -159,3 +162,12 @@ def task_example(): # pragma: no cover parser = depends_on if decorator.name == "depends_on" else produces with pytest.raises(TypeError): list(_extract_nodes_from_function_markers(task_example, parser)) + + +@pytest.mark.unit() +def test_find_args_with_product_annotation(): + def func(a: Annotated[int, Product], b: float, c, d: Annotated[int, float]): + return a, b, c, d + + result = _find_args_with_product_annotation(func) + assert result == ["a"] diff --git a/tests/test_execute.py b/tests/test_execute.py index 9e9e5faa..4c0f2a7c 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -8,6 +8,7 @@ from pathlib import Path import pytest +from _pytask.capture import _CaptureMethod from _pytask.exceptions import NodeNotFoundError from pytask import cli from pytask import ExitCode @@ -385,6 
+386,7 @@ def task_example(): assert "Collected 1 task" in result.output +@pytest.mark.end_to_end() def test_task_executed_with_force_although_unchanged(tmp_path): tmp_path.joinpath("task_module.py").write_text("def task_example(): pass") session = main({"paths": tmp_path}) @@ -419,3 +421,46 @@ def test_task_is_not_reexecuted_when_modification_changed_file_not(runner, tmp_p result = runner.invoke(cli, [tmp_path.as_posix()]) assert result.exit_code == ExitCode.OK assert "1 Skipped" in result.output + + +@pytest.mark.end_to_end() +def test_task_with_product_annotation(tmp_path): + source = """ + from pathlib import Path + from typing_extensions import Annotated + from pytask import Product + + def task_example(path_to_file: Annotated[Path, Product] = Path("out.txt")) -> None: + path_to_file.touch() + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + session = main({"paths": tmp_path, "capture": _CaptureMethod.NO}) + + assert session.exit_code == ExitCode.OK + assert len(session.tasks) == 1 + task = session.tasks[0] + assert "path_to_file" in task.produces + + +@pytest.mark.end_to_end() +@pytest.mark.xfail(reason="Nested annotations are not parsed.", raises=AssertionError) +def test_task_with_nested_product_annotation(tmp_path): + source = """ + from pathlib import Path + from typing_extensions import Annotated + from pytask import Product + + def task_example( + paths_to_file: dict[str, Annotated[Path, Product]] = {"a": Path("out.txt")} + ) -> None: + pass + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + session = main({"paths": tmp_path, "capture": _CaptureMethod.NO}) + + assert session.exit_code == ExitCode.OK + assert len(session.tasks) == 1 + task = session.tasks[0] + assert "paths_to_file" in task.produces diff --git a/tests/test_path.py b/tests/test_path.py index 58737a47..e55be2f7 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -133,6 +133,7 @@ def simple_module(tmp_path: Path) -> 
Path: return fn +@pytest.mark.unit() def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None: """`importlib` mode does not change sys.path.""" module = import_path(simple_module, root=tmp_path) @@ -144,6 +145,7 @@ def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None: assert "_src.project" in sys.modules +@pytest.mark.unit() def test_importmode_twice_is_different_module( simple_module: Path, tmp_path: Path ) -> None: @@ -153,6 +155,7 @@ def test_importmode_twice_is_different_module( assert module1 is not module2 +@pytest.mark.unit() def test_no_meta_path_found( simple_module: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: @@ -171,6 +174,7 @@ def test_no_meta_path_found( import_path(simple_module, root=tmp_path) +@pytest.mark.unit() def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None: """ Ensure that importlib mode works with a module containing dataclasses (#373, @@ -197,6 +201,7 @@ class Data: assert data.__module__ == "_src.project.task_dataclass" +@pytest.mark.unit() def test_importmode_importlib_with_pickle(tmp_path: Path) -> None: """Ensure that importlib mode works with pickle (#373, pytest#7859).""" fn = tmp_path.joinpath("_src/project/task_pickle.py") @@ -222,6 +227,7 @@ def round_trip(): assert action() == 42 +@pytest.mark.unit() def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None: """ Ensure that importlib mode works can load pickles that look similar but are @@ -275,6 +281,7 @@ def round_trip(obj): assert Data2.__module__ == "_src.m2.project.task" +@pytest.mark.unit() def test_module_name_from_path(tmp_path: Path) -> None: result = _module_name_from_path(tmp_path / "src/project/task_foo.py", tmp_path) assert result == "src.project.task_foo" @@ -284,6 +291,7 @@ def test_module_name_from_path(tmp_path: Path) -> None: assert result == "home.foo.task_foo" +@pytest.mark.unit() def test_insert_missing_modules( monkeypatch: pytest.MonkeyPatch, tmp_path: 
Path ) -> None: diff --git a/tests/test_pybaum.py b/tests/test_pybaum.py index 66ea24e6..b6889fbd 100644 --- a/tests/test_pybaum.py +++ b/tests/test_pybaum.py @@ -50,8 +50,7 @@ def task_example(): 2: {0: tmp_path / "list_out.txt"}, 3: {"a": tmp_path / "dict_out.txt", "b": {"c": tmp_path / "dict_out_2.txt"}}, } - if decorator_name == "depends_on": - expected = {"depends_on": expected} + expected = {decorator_name: expected} assert products == expected diff --git a/tests/test_task.py b/tests/test_task.py index bf4c16a3..ccbadfae 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -440,6 +440,7 @@ def task_func(i=i): assert isinstance(session.collection_reports[0].exc_info[1], ValueError) +@pytest.mark.end_to_end() def test_task_receives_unknown_kwarg(runner, tmp_path): source = """ import pytask