Skip to content

Commit

Permalink
Fix error messages and categorize tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe committed Jul 8, 2023
1 parent 53378a0 commit 70c55be
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 16 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ convention = "numpy"


[tool.pytest.ini_options]
addopts = ["--doctest-modules"]
testpaths = ["src", "tests"]
markers = [
"wip: Tests that are work-in-progress.",
Expand Down
7 changes: 1 addition & 6 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from _pytask.collect_utils import parse_dependencies_from_task_function
from _pytask.collect_utils import parse_nodes
from _pytask.collect_utils import parse_products_from_task_function
from _pytask.collect_utils import produces
from _pytask.config import hookimpl
from _pytask.config import IS_FILE_SYSTEM_CASE_SENSITIVE
from _pytask.console import console
Expand Down Expand Up @@ -178,11 +177,7 @@ def pytask_collect_task(
session, path, name, obj
)

if has_mark(obj, "produces"):
nodes = parse_nodes(session, path, name, obj, produces)
products = {"produces": nodes}
else:
products = parse_products_from_task_function(session, path, name, obj)
products = parse_products_from_task_function(session, path, name, obj)

markers = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []

Expand Down
65 changes: 56 additions & 9 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from _pytask._inspect import get_annotations
from _pytask.exceptions import NodeNotCollectedError
from _pytask.mark_utils import has_mark
from _pytask.mark_utils import remove_marks
from _pytask.nodes import ProductType
from _pytask.nodes import PythonNode
Expand Down Expand Up @@ -231,33 +232,64 @@ def parse_dependencies_from_task_function(
return dependencies


_ERROR_MULTIPLE_PRODUCT_DEFINITIONS = (
"The task uses multiple ways to define products. Products should be defined with "
"either\n\n- 'typing.Annotated[Path(...), Product]' (recommended)\n"
"- '@pytask.mark.task(kwargs={'produces': Path(...)})'\n"
"- as a default argument for 'produces': 'produces = Path(...)'\n"
"- '@pytask.mark.produces(Path(...))' (deprecated).\n\n"
"Read more about products in the documentation: https://tinyurl.com/yrezszr4."
)


def parse_products_from_task_function(
session: Session, path: Path, name: str, obj: Any
) -> dict[str, Any]:
"""Parse products from task function."""
"""Parse products from task function.
Raises
------
NodeNotCollectedError
If multiple ways were used to specify products.
"""
has_produces_decorator = False
has_task_decorator = False
has_signature_default = False
has_annotation = False
out = {}

if has_mark(obj, "produces"):
has_produces_decorator = True
nodes = parse_nodes(session, path, name, obj, produces)
out = {"produces": nodes}

task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
if "produces" in task_kwargs:
collected_products = tree_map(
lambda x: _collect_product(session, path, name, x, is_string_allowed=True),
task_kwargs["produces"],
)
return {"produces": collected_products}
out = {"produces": collected_products}

parameters = inspect.signature(obj).parameters
if "produces" in parameters:

if not has_mark(obj, "task") and "produces" in parameters:
parameter = parameters["produces"]
if parameter.default is not parameter.empty:
has_signature_default = True
# Use _collect_new_node to not collect strings.
collected_products = tree_map(
lambda x: _collect_product(
session, path, name, x, is_string_allowed=False
),
parameter.default,
)
return {"produces": collected_products}
out = {"produces": collected_products}

parameters_with_product_annot = _find_args_with_product_annotation(obj)
if parameters_with_product_annot:
has_annotation = True
for parameter_name in parameters_with_product_annot:
parameter = parameters[parameter_name]
if parameter.default is not parameter.empty:
Expand All @@ -268,8 +300,22 @@ def parse_products_from_task_function(
),
parameter.default,
)
return {parameter_name: collected_products}
return {}
out = {parameter_name: collected_products}

if (
sum(
(
has_produces_decorator,
has_task_decorator,
has_signature_default,
has_annotation,
)
)
>= 2 # noqa: PLR2004
):
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)

return out


def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]:
Expand Down Expand Up @@ -373,9 +419,10 @@ def _collect_product(
# The parameter defaults only support Path objects.
if not isinstance(node, Path) and not is_string_allowed:
raise ValueError(
"If you use 'produces' as an argument of a task, it can only accept values "
"of type 'pathlib.Path' or the same value nested in "
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
"If you use 'produces' as a function argument of a task and pass values as "
"function defaults, it can only accept values of type 'pathlib.Path' or "
"the same value nested in tuples, lists, and dictionaries. Here, "
f"{node!r} has type {type(node)}."
)

if isinstance(node, str):
Expand Down
7 changes: 6 additions & 1 deletion src/_pytask/traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import _pytask
import pluggy
import pybaum
from rich.traceback import Traceback


Expand All @@ -22,6 +23,7 @@


_PLUGGY_DIRECTORY = Path(pluggy.__file__).parent
_PYBAUM_DIRECTORY = Path(pybaum.__file__).parent
_PYTASK_DIRECTORY = Path(_pytask.__file__).parent


Expand Down Expand Up @@ -89,7 +91,10 @@ def _is_internal_or_hidden_traceback_frame(
return True

path = Path(frame.tb_frame.f_code.co_filename)
return any(root in path.parents for root in (_PLUGGY_DIRECTORY, _PYTASK_DIRECTORY))
return any(
root in path.parents
for root in (_PLUGGY_DIRECTORY, _PYBAUM_DIRECTORY, _PYTASK_DIRECTORY)
)


def _filter_internal_traceback_frames(
Expand Down
40 changes: 40 additions & 0 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def test_find_shortest_uniquely_identifiable_names_for_tasks(tmp_path):
assert result == expected


@pytest.mark.end_to_end()
def test_collect_dependencies_from_args_if_depends_on_is_missing(tmp_path):
source = """
from pathlib import Path
Expand Down Expand Up @@ -306,6 +307,21 @@ def task_my_task():
assert outcome == CollectionOutcome.SUCCESS


@pytest.mark.end_to_end()
def test_collect_string_product_with_task_decorator(tmp_path):
source = """
import pytask
@pytask.mark.task
def task_write_text(produces="out.txt"):
produces.touch()
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
session = main({"paths": tmp_path})
assert session.exit_code == ExitCode.OK
assert tmp_path.joinpath("out.txt").exists()


@pytest.mark.end_to_end()
def test_collect_string_product_as_function_default_fails(tmp_path):
source = """
Expand All @@ -319,3 +335,27 @@ def task_write_text(produces="out.txt"):
report = session.collection_reports[0]
assert report.outcome == CollectionOutcome.FAIL
assert "If you use 'produces'" in str(report.exc_info[1])


@pytest.mark.end_to_end()
def test_product_cannot_mix_different_product_types(tmp_path):
source = """
import pytask
from typing_extensions import Annotated
from pytask import Product
from pathlib import Path
@pytask.mark.produces("out_deco.txt")
def task_example(
path: Annotated[Path, Product], produces: Path = Path("out_sig.txt")
):
...
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
session = main({"paths": tmp_path})

assert session.exit_code == ExitCode.COLLECTION_FAILED
assert len(session.tasks) == 0
report = session.collection_reports[0]
assert report.outcome == CollectionOutcome.FAIL
assert "The task uses multiple ways" in str(report.exc_info[1])
1 change: 1 addition & 0 deletions tests/test_collect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def task_example(): # pragma: no cover
list(_extract_nodes_from_function_markers(task_example, parser))


@pytest.mark.unit()
def test_find_args_with_product_annotation():
def func(a: Annotated[int, Product], b: float, c, d: Annotated[int, float]):
return a, b, c, d
Expand Down
1 change: 1 addition & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def task_example():
assert "Collected 1 task" in result.output


@pytest.mark.end_to_end()
def test_task_executed_with_force_although_unchanged(tmp_path):
tmp_path.joinpath("task_module.py").write_text("def task_example(): pass")
session = main({"paths": tmp_path})
Expand Down
8 changes: 8 additions & 0 deletions tests/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def simple_module(tmp_path: Path) -> Path:
return fn


@pytest.mark.unit()
def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None:
"""`importlib` mode does not change sys.path."""
module = import_path(simple_module, root=tmp_path)
Expand All @@ -144,6 +145,7 @@ def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None:
assert "_src.project" in sys.modules


@pytest.mark.unit()
def test_importmode_twice_is_different_module(
simple_module: Path, tmp_path: Path
) -> None:
Expand All @@ -153,6 +155,7 @@ def test_importmode_twice_is_different_module(
assert module1 is not module2


@pytest.mark.unit()
def test_no_meta_path_found(
simple_module: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
Expand All @@ -171,6 +174,7 @@ def test_no_meta_path_found(
import_path(simple_module, root=tmp_path)


@pytest.mark.unit()
def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None:
"""
Ensure that importlib mode works with a module containing dataclasses (#373,
Expand All @@ -197,6 +201,7 @@ class Data:
assert data.__module__ == "_src.project.task_dataclass"


@pytest.mark.unit()
def test_importmode_importlib_with_pickle(tmp_path: Path) -> None:
"""Ensure that importlib mode works with pickle (#373, pytest#7859)."""
fn = tmp_path.joinpath("_src/project/task_pickle.py")
Expand All @@ -222,6 +227,7 @@ def round_trip():
assert action() == 42


@pytest.mark.unit()
def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None:
"""
Ensure that importlib mode works can load pickles that look similar but are
Expand Down Expand Up @@ -275,6 +281,7 @@ def round_trip(obj):
assert Data2.__module__ == "_src.m2.project.task"


@pytest.mark.unit()
def test_module_name_from_path(tmp_path: Path) -> None:
result = _module_name_from_path(tmp_path / "src/project/task_foo.py", tmp_path)
assert result == "src.project.task_foo"
Expand All @@ -284,6 +291,7 @@ def test_module_name_from_path(tmp_path: Path) -> None:
assert result == "home.foo.task_foo"


@pytest.mark.unit()
def test_insert_missing_modules(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def task_func(i=i):
assert isinstance(session.collection_reports[0].exc_info[1], ValueError)


@pytest.mark.end_to_end()
def test_task_receives_unknown_kwarg(runner, tmp_path):
source = """
import pytask
Expand Down

0 comments on commit 70c55be

Please sign in to comment.