Skip to content

Commit

Permalink
Allow task functions to be partialed. (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe committed Dec 19, 2023
1 parent c69ab45 commit 524353b
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`525` enables pytask to work with remote files using universal_pathlib.
- {pull}`528` improves the codecov setup and coverage.
- {pull}`535` reenables and fixes tests with Jupyter.
- {pull}`536` allows partialed functions to be task functions.

## 0.4.4 - 2023-12-04

Expand Down
18 changes: 17 additions & 1 deletion src/_pytask/task_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Contains utilities related to the ``@pytask.mark.task`` decorator."""
from __future__ import annotations

import functools
import inspect
from collections import defaultdict
from types import BuiltinFunctionType
Expand Down Expand Up @@ -115,7 +116,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
path = get_file(unwrapped)

parsed_kwargs = {} if kwargs is None else kwargs
parsed_name = name if isinstance(name, str) else func.__name__
parsed_name = _parse_name(unwrapped, name)
parsed_after = _parse_after(after)

if hasattr(unwrapped, "pytask_meta"):
Expand Down Expand Up @@ -148,6 +149,21 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
return wrapper


def _parse_name(func: Callable[..., Any], name: str | None) -> str:
"""Parse name from task function."""
if name:
return name

if isinstance(func, functools.partial):
func = func.func

if hasattr(func, "__name__"):
return func.__name__

msg = "Cannot infer name for task function."
raise NotImplementedError(msg)


def _parse_after(
after: str | Callable[..., Any] | list[Callable[..., Any]] | None,
) -> str | list[Callable[..., Any]]:
Expand Down
5 changes: 3 additions & 2 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,9 @@ def func(content):

result = runner.invoke(cli, [tmp_path.as_posix()])

assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "1 Collected errors and tasks" in result.output
assert result.exit_code == ExitCode.OK
assert "1 Succeeded" in result.output
assert tmp_path.joinpath("out.txt").read_text() == "hello"


@pytest.mark.end_to_end()
Expand Down
25 changes: 25 additions & 0 deletions tests/test_task_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from contextlib import ExitStack as does_not_raise # noqa: N813
from functools import partial
from typing import NamedTuple

import pytest
from _pytask.task_utils import _arg_value_to_id_component
from _pytask.task_utils import _parse_name
from _pytask.task_utils import _parse_task_kwargs
from attrs import define

Expand Down Expand Up @@ -56,3 +58,26 @@ def test_parse_task_kwargs(kwargs, expectation, expected):
with expectation:
result = _parse_task_kwargs(kwargs)
assert result == expected


def task_func(x): # noqa: ARG001 # pragma: no cover
pass


@pytest.mark.unit()
@pytest.mark.parametrize(
("func", "name", "expectation", "expected"),
[
(task_func, None, does_not_raise(), "task_func"),
(task_func, "name", does_not_raise(), "name"),
(partial(task_func, x=1), None, does_not_raise(), "task_func"),
(partial(task_func, x=1), "name", does_not_raise(), "name"),
(lambda x: None, None, does_not_raise(), "<lambda>"), # noqa: ARG005
(partial(lambda x: None, x=1), None, does_not_raise(), "<lambda>"), # noqa: ARG005
(1, None, pytest.raises(NotImplementedError, match="Cannot"), None),
],
)
def test_parse_name(func, name, expectation, expected):
with expectation:
result = _parse_name(func, name)
assert result == expected

0 comments on commit 524353b

Please sign in to comment.