From 70c55bec8e97736e54fb64876f2180382ba79694 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 8 Jul 2023 21:39:51 +0200 Subject: [PATCH] Fix error messages and categorize tests. --- pyproject.toml | 1 + src/_pytask/collect.py | 7 +--- src/_pytask/collect_utils.py | 65 +++++++++++++++++++++++++++++++----- src/_pytask/traceback.py | 7 +++- tests/test_collect.py | 40 ++++++++++++++++++++++ tests/test_collect_utils.py | 1 + tests/test_execute.py | 1 + tests/test_path.py | 8 +++++ tests/test_task.py | 1 + 9 files changed, 115 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d38fcd9a..dc6b12a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ convention = "numpy" [tool.pytest.ini_options] +addopts = ["--doctest-modules"] testpaths = ["src", "tests"] markers = [ "wip: Tests that are work-in-progress.", diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index a5197e5c..fc65df98 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -15,7 +15,6 @@ from _pytask.collect_utils import parse_dependencies_from_task_function from _pytask.collect_utils import parse_nodes from _pytask.collect_utils import parse_products_from_task_function -from _pytask.collect_utils import produces from _pytask.config import hookimpl from _pytask.config import IS_FILE_SYSTEM_CASE_SENSITIVE from _pytask.console import console @@ -178,11 +177,7 @@ def pytask_collect_task( session, path, name, obj ) - if has_mark(obj, "produces"): - nodes = parse_nodes(session, path, name, obj, produces) - products = {"produces": nodes} - else: - products = parse_products_from_task_function(session, path, name, obj) + products = parse_products_from_task_function(session, path, name, obj) markers = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else [] diff --git a/src/_pytask/collect_utils.py b/src/_pytask/collect_utils.py index c7feb1c1..0b179e8d 100644 --- a/src/_pytask/collect_utils.py +++ b/src/_pytask/collect_utils.py @@ -13,6 +13,7 @@ from _pytask._inspect import get_annotations from _pytask.exceptions import NodeNotCollectedError +from _pytask.mark_utils import has_mark from _pytask.mark_utils import remove_marks from _pytask.nodes import ProductType from _pytask.nodes import PythonNode @@ -231,22 +232,52 @@ def parse_dependencies_from_task_function( return dependencies +_ERROR_MULTIPLE_PRODUCT_DEFINITIONS = ( + "The task uses multiple ways to define products. Products should be defined with " + "either\n\n- 'typing.Annotated[Path(...), Product]' (recommended)\n" + "- '@pytask.mark.task(kwargs={'produces': Path(...)})'\n" + "- as a default argument for 'produces': 'produces = Path(...)'\n" + "- '@pytask.mark.produces(Path(...))' (deprecated).\n\n" + "Read more about products in the documentation: https://tinyurl.com/yrezszr4." +) + + def parse_products_from_task_function( session: Session, path: Path, name: str, obj: Any ) -> dict[str, Any]: - """Parse products from task function.""" + """Parse products from task function. + + Raises + ------ + NodeNotCollectedError + If multiple ways were used to specify products. + + """ + has_produces_decorator = False + has_task_decorator = False + has_signature_default = False + has_annotation = False + out = {} + + if has_mark(obj, "produces"): + has_produces_decorator = True + nodes = parse_nodes(session, path, name, obj, produces) + out = {"produces": nodes} + task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {} if "produces" in task_kwargs: collected_products = tree_map( lambda x: _collect_product(session, path, name, x, is_string_allowed=True), task_kwargs["produces"], ) - return {"produces": collected_products} + out = {"produces": collected_products} parameters = inspect.signature(obj).parameters - if "produces" in parameters: + + if not has_mark(obj, "task") and "produces" in parameters: parameter = parameters["produces"] if parameter.default is not parameter.empty: + has_signature_default = True # Use _collect_new_node to not collect strings. collected_products = tree_map( lambda x: _collect_product( @@ -254,10 +285,11 @@ def parse_products_from_task_function( ), parameter.default, ) - return {"produces": collected_products} + out = {"produces": collected_products} parameters_with_product_annot = _find_args_with_product_annotation(obj) if parameters_with_product_annot: + has_annotation = True for parameter_name in parameters_with_product_annot: parameter = parameters[parameter_name] if parameter.default is not parameter.empty: @@ -268,8 +300,22 @@ def parse_products_from_task_function( ), parameter.default, ) - return {parameter_name: collected_products} - return {} + out = {parameter_name: collected_products} + + if ( + sum( + ( + has_produces_decorator, + has_task_decorator, + has_signature_default, + has_annotation, + ) + ) + >= 2 # noqa: PLR2004 + ): + raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS) + + return out def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]: @@ -373,9 +419,10 @@ def _collect_product( # The parameter defaults only support Path objects. if not isinstance(node, Path) and not is_string_allowed: raise ValueError( - "If you use 'produces' as an argument of a task, it can only accept values " - "of type 'pathlib.Path' or the same value nested in " - f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}." + "If you use 'produces' as a function argument of a task and pass values as " + "function defaults, it can only accept values of type 'pathlib.Path' or " + "the same value nested in tuples, lists, and dictionaries. Here, " + f"{node!r} has type {type(node)}." ) if isinstance(node, str): diff --git a/src/_pytask/traceback.py b/src/_pytask/traceback.py index 38fdeb88..cc14c181 100644 --- a/src/_pytask/traceback.py +++ b/src/_pytask/traceback.py @@ -10,6 +10,7 @@ import _pytask import pluggy +import pybaum from rich.traceback import Traceback @@ -22,6 +23,7 @@ _PLUGGY_DIRECTORY = Path(pluggy.__file__).parent +_PYBAUM_DIRECTORY = Path(pybaum.__file__).parent _PYTASK_DIRECTORY = Path(_pytask.__file__).parent @@ -89,7 +91,10 @@ def _is_internal_or_hidden_traceback_frame( return True path = Path(frame.tb_frame.f_code.co_filename) - return any(root in path.parents for root in (_PLUGGY_DIRECTORY, _PYTASK_DIRECTORY)) + return any( + root in path.parents + for root in (_PLUGGY_DIRECTORY, _PYBAUM_DIRECTORY, _PYTASK_DIRECTORY) + ) def _filter_internal_traceback_frames( diff --git a/tests/test_collect.py b/tests/test_collect.py index 95c2ef22..6728e16e 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -249,6 +249,7 @@ def test_find_shortest_uniquely_identifiable_names_for_tasks(tmp_path): assert result == expected +@pytest.mark.end_to_end() def test_collect_dependencies_from_args_if_depends_on_is_missing(tmp_path): source = """ from pathlib import Path @@ -306,6 +307,21 @@ def task_my_task(): assert outcome == CollectionOutcome.SUCCESS +@pytest.mark.end_to_end() +def test_collect_string_product_with_task_decorator(tmp_path): + source = """ + import pytask + + @pytask.mark.task + def task_write_text(produces="out.txt"): + produces.touch() + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + session = main({"paths": tmp_path}) + assert session.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").exists() + + @pytest.mark.end_to_end() def test_collect_string_product_as_function_default_fails(tmp_path): source = """ @@ -319,3 +335,27 @@ def task_write_text(produces="out.txt"): report = session.collection_reports[0] assert report.outcome == CollectionOutcome.FAIL assert "If you use 'produces'" in str(report.exc_info[1]) + + +@pytest.mark.end_to_end() +def test_product_cannot_mix_different_product_types(tmp_path): + source = """ + import pytask + from typing_extensions import Annotated + from pytask import Product + from pathlib import Path + + @pytask.mark.produces("out_deco.txt") + def task_example( + path: Annotated[Path, Product], produces: Path = Path("out_sig.txt") + ): + ... + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + session = main({"paths": tmp_path}) + + assert session.exit_code == ExitCode.COLLECTION_FAILED + assert len(session.tasks) == 0 + report = session.collection_reports[0] + assert report.outcome == CollectionOutcome.FAIL + assert "The task uses multiple ways" in str(report.exc_info[1]) diff --git a/tests/test_collect_utils.py b/tests/test_collect_utils.py index d2c0af8d..e4dbc4c1 100644 --- a/tests/test_collect_utils.py +++ b/tests/test_collect_utils.py @@ -164,6 +164,7 @@ def task_example(): # pragma: no cover list(_extract_nodes_from_function_markers(task_example, parser)) +@pytest.mark.unit() def test_find_args_with_product_annotation(): def func(a: Annotated[int, Product], b: float, c, d: Annotated[int, float]): return a, b, c, d diff --git a/tests/test_execute.py b/tests/test_execute.py index 2f8b89c7..4c0f2a7c 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -386,6 +386,7 @@ def task_example(): assert "Collected 1 task" in result.output +@pytest.mark.end_to_end() def test_task_executed_with_force_although_unchanged(tmp_path): tmp_path.joinpath("task_module.py").write_text("def task_example(): pass") session = main({"paths": tmp_path}) diff --git a/tests/test_path.py b/tests/test_path.py index 58737a47..e55be2f7 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -133,6 +133,7 @@ def simple_module(tmp_path: Path) -> Path: return fn +@pytest.mark.unit() def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None: """`importlib` mode does not change sys.path.""" module = import_path(simple_module, root=tmp_path) @@ -144,6 +145,7 @@ def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None: assert "_src.project" in sys.modules +@pytest.mark.unit() def test_importmode_twice_is_different_module( simple_module: Path, tmp_path: Path ) -> None: @@ -153,6 +155,7 @@ def test_importmode_twice_is_different_module( assert module1 is not module2 +@pytest.mark.unit() def test_no_meta_path_found( simple_module: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: @@ -171,6 +174,7 @@ def test_no_meta_path_found( import_path(simple_module, root=tmp_path) +@pytest.mark.unit() def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None: """ Ensure that importlib mode works with a module containing dataclasses (#373, @@ -197,6 +201,7 @@ class Data: assert data.__module__ == "_src.project.task_dataclass" +@pytest.mark.unit() def test_importmode_importlib_with_pickle(tmp_path: Path) -> None: """Ensure that importlib mode works with pickle (#373, pytest#7859).""" fn = tmp_path.joinpath("_src/project/task_pickle.py") @@ -222,6 +227,7 @@ def round_trip(): assert action() == 42 +@pytest.mark.unit() def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None: """ Ensure that importlib mode works can load pickles that look similar but are @@ -275,6 +281,7 @@ def round_trip(obj): assert Data2.__module__ == "_src.m2.project.task" +@pytest.mark.unit() def test_module_name_from_path(tmp_path: Path) -> None: result = _module_name_from_path(tmp_path / "src/project/task_foo.py", tmp_path) assert result == "src.project.task_foo" @@ -284,6 +291,7 @@ def test_module_name_from_path(tmp_path: Path) -> None: assert result == "home.foo.task_foo" +@pytest.mark.unit() def test_insert_missing_modules( monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: diff --git a/tests/test_task.py b/tests/test_task.py index bf4c16a3..ccbadfae 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -440,6 +440,7 @@ def task_func(i=i): assert isinstance(session.collection_reports[0].exc_info[1], ValueError) +@pytest.mark.end_to_end() def test_task_receives_unknown_kwarg(runner, tmp_path): source = """ import pytask