Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
releases are available on [PyPI](https://pypi.org/project/pytask) and
[Anaconda.org](https://anaconda.org/conda-forge/pytask).

## Unreleased

- [#889](https://github.com/pytask-dev/pytask/pull/889) improves typing for tree
operations by wrapping optree's pytree utilities with pytask-specific signatures
and requiring optree 0.16.0 or newer.

## 0.6.0 - 2026-05-01

- [#875](https://github.com/pytask-dev/pytask/pull/875) improves the documentation
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"click>=8.1.8,!=8.2.0",
"click-default-group>=1.2.4",
"msgspec>=0.18.6",
"optree>=0.9.0",
"optree>=0.16.0",
"packaging>=23.0.0",
"pluggy>=1.3.0",
"rich>=13.8.0",
Expand Down
4 changes: 2 additions & 2 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from _pytask._inspect import get_annotations
from _pytask.exceptions import NodeNotCollectedError
from _pytask.models import NodeInfo
from _pytask.node_protocols import NodeTree
from _pytask.node_protocols import PNode
from _pytask.node_protocols import PProvisionalNode
from _pytask.nodes import PythonNode
from _pytask.task_utils import parse_keyword_arguments_from_signature_defaults
from _pytask.tree_util import PyTree
from _pytask.tree_util import tree_leaves
from _pytask.tree_util import tree_map_with_path
from _pytask.typing import ProductType
Expand Down Expand Up @@ -254,7 +254,7 @@ def _collect_nodes_and_provisional_nodes( # noqa: PLR0913
task_path: Path | None,
parameter_name: str,
value: Any,
) -> PyTree[PProvisionalNode | PNode]:
) -> NodeTree:
return tree_map_with_path(
lambda p, x: collection_func(
session,
Expand Down
103 changes: 62 additions & 41 deletions src/_pytask/node_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,29 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Protocol
from typing import TypeAlias
from typing import runtime_checkable

from _pytask.tree_util import PyTree

if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path

from _pytask.mark import Mark
from _pytask.tree_util import PyTree
from _pytask.typing import NodePath


__all__ = ["PNode", "PPathNode", "PProvisionalNode", "PTask", "PTaskWithPath"]
__all__ = [
"NodeTree",
"PNode",
"PPathNode",
"PProvisionalNode",
"PTask",
"PTaskWithPath",
"TaskIO",
"TaskNode",
]


@runtime_checkable
Expand Down Expand Up @@ -64,45 +75,6 @@ class PPathNode(PNode, Protocol):
path: NodePath


@runtime_checkable
class PTask(Protocol):
"""Protocol for nodes."""

name: str
depends_on: dict[str, PyTree[PNode | PProvisionalNode]]
produces: dict[str, PyTree[PNode | PProvisionalNode]]
function: Callable[..., Any]
markers: list[Mark]
report_sections: list[tuple[str, str, str]]
attributes: dict[Any, Any]

@property
def signature(self) -> str:
"""Return the signature of the node."""

def state(self) -> str | None:
"""Return the state of the node.

The state can be something like a hash or a last modified timestamp. If the node
does not exist, you can also return ``None``.

"""

def execute(self, **kwargs: Any) -> Any:
"""Return the value of the node that will be injected into the task."""


@runtime_checkable
class PTaskWithPath(PTask, Protocol):
"""Tasks with paths.

Tasks with paths receive special handling when it comes to printing their names.

"""

path: Path


@runtime_checkable
class PProvisionalNode(Protocol):
"""A protocol for provisional nodes.
Expand Down Expand Up @@ -141,3 +113,52 @@ def load(self, is_product: bool = False) -> Any: # pragma: no cover

def collect(self) -> list[Any]:
"""Collect the objects that are defined by the provisional nodes."""


TaskNode: TypeAlias = PNode | PProvisionalNode
"""A concrete or provisional pytask node."""

NodeTree: TypeAlias = PyTree[TaskNode]
"""A pytask tree whose leaves are concrete or provisional nodes."""

TaskIO: TypeAlias = dict[str, NodeTree]
"""The top-level task argument mapping for dependencies and products."""


@runtime_checkable
class PTask(Protocol):
"""Protocol for nodes."""

name: str
depends_on: TaskIO
produces: TaskIO
function: Callable[..., Any]
markers: list[Mark]
report_sections: list[tuple[str, str, str]]
attributes: dict[Any, Any]

@property
def signature(self) -> str:
"""Return the signature of the node."""

def state(self) -> str | None:
"""Return the state of the node.

The state can be something like a hash or a last modified timestamp. If the node
does not exist, you can also return ``None``.

"""

def execute(self, **kwargs: Any) -> Any:
"""Return the value of the node that will be injected into the task."""


@runtime_checkable
class PTaskWithPath(PTask, Protocol):
"""Tasks with paths.

Tasks with paths receive special handling when it comes to printing their names.

"""

path: Path
14 changes: 5 additions & 9 deletions src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from _pytask.node_protocols import PProvisionalNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.node_protocols import TaskIO
from _pytask.path import hash_path
from _pytask.typing import NoDefault
from _pytask.typing import NodePath
Expand All @@ -34,7 +35,6 @@

from _pytask.mark import Mark
from _pytask.models import NodeInfo
from _pytask.tree_util import PyTree


__all__ = [
Expand Down Expand Up @@ -77,10 +77,8 @@ class TaskWithoutPath(PTask):

name: str
function: Callable[..., Any]
depends_on: dict[str, PyTree[PNode | PProvisionalNode]] = field(
default_factory=dict
)
produces: dict[str, PyTree[PNode | PProvisionalNode]] = field(default_factory=dict)
depends_on: TaskIO = field(default_factory=dict)
produces: TaskIO = field(default_factory=dict)
markers: list[Mark] = field(default_factory=list)
report_sections: list[tuple[str, str, str]] = field(default_factory=list)
attributes: dict[Any, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -133,10 +131,8 @@ class Task(PTaskWithPath):
path: Path
function: Callable[..., Any]
name: str = field(default="", init=False)
depends_on: dict[str, PyTree[PNode | PProvisionalNode]] = field(
default_factory=dict
)
produces: dict[str, PyTree[PNode | PProvisionalNode]] = field(default_factory=dict)
depends_on: TaskIO = field(default_factory=dict)
produces: TaskIO = field(default_factory=dict)
markers: list[Mark] = field(default_factory=list)
report_sections: list[tuple[str, str, str]] = field(default_factory=list)
attributes: dict[Any, Any] = field(default_factory=dict)
Expand Down
5 changes: 2 additions & 3 deletions src/_pytask/provisional_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
from _pytask.collect_utils import collect_dependency
from _pytask.dag import create_dag_from_session
from _pytask.models import NodeInfo
from _pytask.node_protocols import PNode
from _pytask.node_protocols import NodeTree
from _pytask.node_protocols import PProvisionalNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.nodes import Task
from _pytask.reports import ExecutionReport
from _pytask.tree_util import PyTree
from _pytask.tree_util import tree_map_with_path
from _pytask.typing import is_task_generator

Expand All @@ -29,7 +28,7 @@

def collect_provisional_nodes(
session: Session, task: PTask, node: Any, path: tuple[Any, ...]
) -> PyTree[PNode | PProvisionalNode]:
) -> NodeTree:
"""Collect provisional nodes.
1. Call the [`pytask.PProvisionalNode.collect`][] to receive the raw nodes.
Expand Down
Loading
Loading