Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use protocols instead of ABCs. #402

Merged
merged 8 commits into from
Jul 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`395` refactors all occurrences of pybaum to {mod}`_pytask.tree_util`.
- {pull}`396` replaces pybaum with optree and adds paths to the names of
  {class}`pytask.PythonNode`s, allowing for better hashing.
- {class}`397` adds support for {class}`typing.NamedTuple` and attrs classes in
- {pull}`397` adds support for {class}`typing.NamedTuple` and attrs classes in
`@pytask.mark.task(kwargs=...)`.
- {pull}`398` deprecates the decorators `@pytask.mark.depends_on` and
`@pytask.mark.produces`.
- {pull}`402` replaces ABCs with protocols, allowing for more flexibility for users
  implementing their own nodes.

## 0.3.2 - 2023-06-07

Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference_guides/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ from {class}`pytask.MetaNode`.
Then, different kinds of nodes can be implemented.

```{eval-rst}
.. autoclass:: pytask.FilePathNode
.. autoclass:: pytask.PathNode
:members:
```

Expand Down
14 changes: 7 additions & 7 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from _pytask.exceptions import CollectionError
from _pytask.mark_utils import has_mark
from _pytask.models import NodeInfo
from _pytask.nodes import FilePathNode
from _pytask.nodes import MetaNode
from _pytask.node_protocols import Node
from _pytask.nodes import PathNode
from _pytask.nodes import PythonNode
from _pytask.nodes import Task
from _pytask.outcomes import CollectionOutcome
Expand Down Expand Up @@ -95,7 +95,7 @@ def pytask_collect_file_protocol(
)
flat_reports = list(itertools.chain.from_iterable(new_reports))
except Exception: # noqa: BLE001
node = FilePathNode.from_path(path)
node = PathNode.from_path(path)
flat_reports = [
CollectionReport.from_exception(
outcome=CollectionOutcome.FAIL, node=node, exc_info=sys.exc_info()
Expand Down Expand Up @@ -204,8 +204,8 @@ def pytask_collect_task(


@hookimpl(trylast=True)
def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> MetaNode:
"""Collect a node of a task as a :class:`pytask.nodes.FilePathNode`.
def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Node:
"""Collect a node of a task as a :class:`pytask.nodes.PathNode`.

Strings are assumed to be paths. This might be a strict assumption, but since this
hook is executed at last and possible errors will be shown, it seems reasonable and
Expand All @@ -223,7 +223,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Me
node.name = node_info.arg_name + suffix
return node

if isinstance(node, MetaNode):
if isinstance(node, Node):
return node

if isinstance(node, Path):
Expand All @@ -243,7 +243,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Me
if str(node) != str(case_sensitive_path):
raise ValueError(_TEMPLATE_ERROR.format(node, case_sensitive_path))

return FilePathNode.from_path(node)
return PathNode.from_path(node)

suffix = "-" + "-".join(map(str, node_info.path)) if node_info.path else ""
node_name = node_info.arg_name + suffix
Expand Down
25 changes: 14 additions & 11 deletions src/_pytask/collect_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from _pytask.exceptions import ResolvingDependenciesError
from _pytask.mark import select_by_keyword
from _pytask.mark import select_by_mark
from _pytask.nodes import FilePathNode
from _pytask.node_protocols import PPathNode
from _pytask.outcomes import ExitCode
from _pytask.path import find_common_ancestor
from _pytask.path import relative_to
Expand Down Expand Up @@ -125,9 +125,7 @@ def _find_common_ancestor_of_all_nodes(
all_paths.append(task.path)
if show_nodes:
all_paths.extend(
x.path
for x in tree_leaves(task.depends_on)
if isinstance(x, FilePathNode)
x.path for x in tree_leaves(task.depends_on) if isinstance(x, PPathNode)
)
all_paths.extend(x.path for x in tree_leaves(task.produces))

Expand Down Expand Up @@ -202,24 +200,29 @@ def _print_collected_tasks(
)

if show_nodes:
file_path_nodes = list(tree_leaves(task.depends_on))
sorted_nodes = sorted(file_path_nodes, key=lambda x: x.name)
nodes = list(tree_leaves(task.depends_on))
sorted_nodes = sorted(nodes, key=lambda x: x.name)
for node in sorted_nodes:
if isinstance(node, FilePathNode):
reduced_node_name = relative_to(node.path, common_ancestor)
if isinstance(node, PPathNode):
if node.path.as_posix() in node.name:
reduced_node_name = str(
relative_to(node.path, common_ancestor)
)
else:
reduced_node_name = node.name
url_style = create_url_style_for_path(
node.path, editor_url_scheme
)
text = Text(str(reduced_node_name), style=url_style)
text = Text(reduced_node_name, style=url_style)
else:
text = node.name

task_branch.add(Text.assemble(FILE_ICON, "<Dependency ", text, ">"))

for node in sorted(tree_leaves(task.produces), key=lambda x: x.path):
reduced_node_name = relative_to(node.path, common_ancestor)
reduced_node_name = str(relative_to(node.path, common_ancestor))
url_style = create_url_style_for_path(node.path, editor_url_scheme)
text = Text(str(reduced_node_name), style=url_style)
text = Text(reduced_node_name, style=url_style)
task_branch.add(Text.assemble(FILE_ICON, "<Product ", text, ">"))

console.print(tree)
67 changes: 37 additions & 30 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module provides utility functions for :mod:`_pytask.collect`."""
from __future__ import annotations

import functools
import itertools
import uuid
import warnings
Expand All @@ -11,13 +12,13 @@
from typing import Iterable
from typing import TYPE_CHECKING

import attrs
from _pytask._inspect import get_annotations
from _pytask.exceptions import NodeNotCollectedError
from _pytask.mark_utils import has_mark
from _pytask.mark_utils import remove_marks
from _pytask.models import NodeInfo
from _pytask.nodes import MetaNode
from _pytask.node_protocols import Node
from _pytask.node_protocols import PPathNode
from _pytask.nodes import ProductType
from _pytask.nodes import PythonNode
from _pytask.shared import find_duplicates
Expand Down Expand Up @@ -80,7 +81,7 @@ def parse_nodes(
objects = _extract_nodes_from_function_markers(obj, parser)
nodes = _convert_objects_to_node_dictionary(objects, arg_name)
nodes = tree_map(
lambda x: _collect_decorator_nodes(
lambda x: _collect_decorator_node(
session, path, name, NodeInfo(arg_name, (), x)
),
nodes,
Expand Down Expand Up @@ -228,7 +229,7 @@ def _merge_dictionaries(list_of_dicts: list[dict[Any, Any]]) -> dict[Any, Any]:
"""


def parse_dependencies_from_task_function( # noqa: C901
def parse_dependencies_from_task_function(
session: Session, path: Path, name: str, obj: Any
) -> dict[str, Any]:
"""Parse dependencies from task function."""
Expand All @@ -250,7 +251,7 @@ def parse_dependencies_from_task_function( # noqa: C901
if "depends_on" in kwargs:
has_depends_on_argument = True
dependencies["depends_on"] = tree_map(
lambda x: _collect_decorator_nodes(
lambda x: _collect_decorator_node(
session, path, name, NodeInfo(arg_name="depends_on", path=(), value=x)
),
kwargs["depends_on"],
Expand All @@ -269,23 +270,17 @@ def parse_dependencies_from_task_function( # noqa: C901
if parameter_name == "depends_on":
continue

if parameter_name in parameters_with_node_annot:

def _evolve(x: Any) -> Any:
instance = parameters_with_node_annot[parameter_name] # noqa: B023
return attrs.evolve(instance, value=x) # type: ignore[misc]

else:

def _evolve(x: Any) -> Any:
return x
partialed_evolve = functools.partial(
_evolve_instance,
instance_from_annot=parameters_with_node_annot.get(parameter_name),
)

nodes = tree_map_with_path(
lambda p, x: _collect_dependencies(
lambda p, x: _collect_dependency(
session,
path,
name,
NodeInfo(parameter_name, p, _evolve(x)), # noqa: B023
NodeInfo(parameter_name, p, partialed_evolve(x)), # noqa: B023
),
value,
)
Expand All @@ -295,14 +290,14 @@ def _evolve(x: Any) -> Any:
are_all_nodes_python_nodes_without_hash = all(
isinstance(x, PythonNode) and not x.hash for x in tree_leaves(nodes)
)
if are_all_nodes_python_nodes_without_hash:
if not isinstance(nodes, Node) and are_all_nodes_python_nodes_without_hash:
dependencies[parameter_name] = PythonNode(value=value, name=parameter_name)
else:
dependencies[parameter_name] = nodes
return dependencies


def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, MetaNode]:
def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, Node]:
"""Find args with node annotations."""
annotations = get_annotations(func, eval_str=True)
metas = {
Expand All @@ -314,9 +309,7 @@ def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, MetaN
args_with_node_annotation = {}
for name, meta in metas.items():
annot = [
i
for i in meta
if not isinstance(i, ProductType) and isinstance(i, MetaNode)
i for i in meta if not isinstance(i, ProductType) and isinstance(i, Node)
]
if len(annot) >= 2: # noqa: PLR2004
raise ValueError(
Expand Down Expand Up @@ -380,6 +373,7 @@ def parse_products_from_task_function(
kwargs = {**signature_defaults, **task_kwargs}

parameters_with_product_annot = _find_args_with_product_annotation(obj)
parameters_with_node_annot = _find_args_with_node_annotation(obj)

# Parse products from task decorated with @task and that uses produces.
if "produces" in kwargs:
Expand All @@ -404,13 +398,17 @@ def parse_products_from_task_function(
has_annotation = True
for parameter_name in parameters_with_product_annot:
if parameter_name in kwargs:
# Use _collect_new_node to not collect strings.
partialed_evolve = functools.partial(
_evolve_instance,
instance_from_annot=parameters_with_node_annot.get(parameter_name),
)

collected_products = tree_map_with_path(
lambda p, x: _collect_product(
session,
path,
name,
NodeInfo(parameter_name, p, x), # noqa: B023
NodeInfo(parameter_name, p, partialed_evolve(x)), # noqa: B023
is_string_allowed=False,
),
kwargs[parameter_name],
Expand Down Expand Up @@ -456,9 +454,9 @@ def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]:
"""


def _collect_decorator_nodes(
def _collect_decorator_node(
session: Session, path: Path, name: str, node_info: NodeInfo
) -> dict[str, MetaNode]:
) -> Node:
"""Collect nodes for a task.

Raises
Expand Down Expand Up @@ -495,9 +493,9 @@ def _collect_decorator_nodes(
return collected_node


def _collect_dependencies(
def _collect_dependency(
session: Session, path: Path, name: str, node_info: NodeInfo
) -> dict[str, MetaNode]:
) -> Node:
"""Collect nodes for a task.

Raises
Expand Down Expand Up @@ -525,7 +523,7 @@ def _collect_product(
task_name: str,
node_info: NodeInfo,
is_string_allowed: bool = False,
) -> dict[str, MetaNode]:
) -> Node:
"""Collect products for a task.

Defining products with strings is only allowed when using the decorator. Parameter
Expand All @@ -546,7 +544,7 @@ def _collect_product(
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
)
# The parameter defaults only support Path objects.
if not isinstance(node, Path) and not is_string_allowed:
if not isinstance(node, (Path, PPathNode)) and not is_string_allowed:
raise ValueError(
"If you declare products with 'Annotated[..., Product]', only values of "
"type 'pathlib.Path' optionally nested in tuples, lists, and "
Expand All @@ -566,3 +564,12 @@ def _collect_product(
)

return collected_node


def _evolve_instance(x: Any, instance_from_annot: Node | None) -> Any:
    """Evolve a value to a node if it is given by annotations.

    Parameters
    ----------
    x
        The raw value of a single leaf, e.g. one entry of a parameter's default.
    instance_from_annot
        The node instance found in the parameter's annotation, or ``None`` when
        the parameter carries no node annotation.

    Returns
    -------
    Any
        ``x`` unchanged when there is no annotated node. Otherwise a shallow copy
        of the annotated node whose ``value`` attribute is set to ``x``.

    """
    import copy

    # Compare against None explicitly so that a node which happens to be falsy
    # (for example because it defines ``__bool__`` or ``__len__``) is still used.
    if instance_from_annot is None:
        return x

    # Work on a copy. Callers invoke this function once per leaf of a possibly
    # nested value with the very same annotated instance; mutating it in place
    # would make every leaf share one node holding only the last value.
    evolved = copy.copy(instance_from_annot)
    evolved.value = x
    return evolved
11 changes: 6 additions & 5 deletions src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from _pytask.mark import Mark
from _pytask.mark_utils import get_marks
from _pytask.mark_utils import has_mark
from _pytask.nodes import FilePathNode
from _pytask.nodes import MetaNode
from _pytask.node_protocols import MetaNode
from _pytask.node_protocols import Node
from _pytask.node_protocols import PPathNode
from _pytask.nodes import Task
from _pytask.path import find_common_ancestor_of_nodes
from _pytask.report import DagReport
Expand Down Expand Up @@ -140,13 +141,13 @@ def pytask_dag_has_node_changed(node: MetaNode, task_name: str) -> bool:
if db_state is None:
return True

if isinstance(node, (FilePathNode, Task)):
if isinstance(node, (PPathNode, Task)):
# If the modification times match, the node has not been changed.
if node_state == db_state.modification_time:
return False

# If the modification time changed, quickly return for non-tasks.
if isinstance(node, FilePathNode):
if not isinstance(node, Task):
return True

# When modification times changed, we are still comparing the hash of the file
Expand Down Expand Up @@ -238,7 +239,7 @@ def _check_if_root_nodes_are_available(dag: nx.DiGraph) -> None:


def _check_if_tasks_are_skipped(
node: MetaNode, dag: nx.DiGraph, is_task_skipped: dict[str, bool]
node: Node, dag: nx.DiGraph, is_task_skipped: dict[str, bool]
) -> tuple[bool, dict[str, bool]]:
"""Check for a given node whether it is only used by skipped tasks."""
are_all_tasks_skipped = []
Expand Down
4 changes: 2 additions & 2 deletions src/_pytask/database_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import hashlib

from _pytask.dag_utils import node_and_neighbors
from _pytask.nodes import FilePathNode
from _pytask.node_protocols import PPathNode
from _pytask.nodes import Task
from _pytask.session import Session
from sqlalchemy import Column
Expand Down Expand Up @@ -80,7 +80,7 @@ def update_states_in_database(session: Session, task_name: str) -> None:
if isinstance(node, Task):
modification_time = node.state()
hash_ = hashlib.sha256(node.path.read_bytes()).hexdigest()
elif isinstance(node, FilePathNode):
elif isinstance(node, PPathNode):
modification_time = node.state()
hash_ = ""
else:
Expand Down
Loading