diff --git a/docs/source/changes.md b/docs/source/changes.md index 18c6c6c6..7a535918 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -17,8 +17,12 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`395` refactors all occurrences of pybaum to {mod}`_pytask.tree_util`. - {pull}`396` replaces pybaum with optree and adds paths to the name of {class}`pytask.PythonNode`'s allowing for better hashing. -- {class}`397` adds support for {class}`typing.NamedTuple` and attrs classes in +- {pull}`397` adds support for {class}`typing.NamedTuple` and attrs classes in `@pytask.mark.task(kwargs=...)`. +- {pull}`398` deprecates the decorators `@pytask.mark.depends_on` and + `@pytask.mark.produces`. +- {pull}`402` replaces ABCs with protocols allowing for more flexibility for users + implementing their own nodes. ## 0.3.2 - 2023-06-07 diff --git a/docs/source/reference_guides/api.md b/docs/source/reference_guides/api.md index 2227723d..bfd922e2 100644 --- a/docs/source/reference_guides/api.md +++ b/docs/source/reference_guides/api.md @@ -246,7 +246,7 @@ from {class}`pytask.MetaNode`. Then, different kinds of nodes can be implemented. ```{eval-rst} -.. autoclass:: pytask.FilePathNode +.. autoclass:: pytask.PathNode :members: ``` diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index d4d135c9..26db423f 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -21,8 +21,8 @@ from _pytask.exceptions import CollectionError from _pytask.mark_utils import has_mark from _pytask.models import NodeInfo -from _pytask.nodes import FilePathNode -from _pytask.nodes import MetaNode +from _pytask.node_protocols import Node +from _pytask.nodes import PathNode from _pytask.nodes import PythonNode from _pytask.nodes import Task from _pytask.outcomes import CollectionOutcome @@ -95,7 +95,7 @@ def pytask_collect_file_protocol( ) flat_reports = list(itertools.chain.from_iterable(new_reports)) except Exception: # noqa: BLE001 - node = FilePathNode.from_path(path) + node = PathNode.from_path(path) flat_reports = [ CollectionReport.from_exception( outcome=CollectionOutcome.FAIL, node=node, exc_info=sys.exc_info() @@ -204,8 +204,8 @@ def pytask_collect_task( @hookimpl(trylast=True) -def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> MetaNode: - """Collect a node of a task as a :class:`pytask.nodes.FilePathNode`. +def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Node: + """Collect a node of a task as a :class:`pytask.nodes.PathNode`. Strings are assumed to be paths. This might be a strict assumption, but since this hook is executed at last and possible errors will be shown, it seems reasonable and @@ -223,7 +223,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Me node.name = node_info.arg_name + suffix return node - if isinstance(node, MetaNode): + if isinstance(node, Node): return node if isinstance(node, Path): @@ -243,7 +243,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Me if str(node) != str(case_sensitive_path): raise ValueError(_TEMPLATE_ERROR.format(node, case_sensitive_path)) - return FilePathNode.from_path(node) + return PathNode.from_path(node) suffix = "-" + "-".join(map(str, node_info.path)) if node_info.path else "" node_name = node_info.arg_name + suffix diff --git a/src/_pytask/collect_command.py b/src/_pytask/collect_command.py index e615a8ea..a8b3fcf7 100644 --- a/src/_pytask/collect_command.py +++ b/src/_pytask/collect_command.py @@ -20,7 +20,7 @@ from _pytask.exceptions import ResolvingDependenciesError from _pytask.mark import select_by_keyword from _pytask.mark import select_by_mark -from _pytask.nodes import FilePathNode +from _pytask.node_protocols import PPathNode from _pytask.outcomes import ExitCode from _pytask.path import find_common_ancestor from _pytask.path import relative_to @@ -125,9 +125,7 @@ def _find_common_ancestor_of_all_nodes( all_paths.append(task.path) if show_nodes: all_paths.extend( - x.path - for x in tree_leaves(task.depends_on) - if isinstance(x, FilePathNode) + x.path for x in tree_leaves(task.depends_on) if isinstance(x, PPathNode) ) all_paths.extend(x.path for x in tree_leaves(task.produces)) @@ -202,24 +200,29 @@ def _print_collected_tasks( ) if show_nodes: - file_path_nodes = list(tree_leaves(task.depends_on)) - sorted_nodes = sorted(file_path_nodes, key=lambda x: x.name) + nodes = list(tree_leaves(task.depends_on)) + sorted_nodes = sorted(nodes, key=lambda x: x.name) for node in sorted_nodes: - if isinstance(node, FilePathNode): - reduced_node_name = relative_to(node.path, common_ancestor) + if isinstance(node, PPathNode): + if node.path.as_posix() in node.name: + reduced_node_name = str( + relative_to(node.path, common_ancestor) + ) + else: + reduced_node_name = node.name url_style = create_url_style_for_path( node.path, editor_url_scheme ) - text = Text(str(reduced_node_name), style=url_style) + text = Text(reduced_node_name, style=url_style) else: text = node.name task_branch.add(Text.assemble(FILE_ICON, "")) for node in sorted(tree_leaves(task.produces), key=lambda x: x.path): - reduced_node_name = relative_to(node.path, common_ancestor) + reduced_node_name = str(relative_to(node.path, common_ancestor)) url_style = create_url_style_for_path(node.path, editor_url_scheme) - text = Text(str(reduced_node_name), style=url_style) + text = Text(reduced_node_name, style=url_style) task_branch.add(Text.assemble(FILE_ICON, "")) console.print(tree) diff --git a/src/_pytask/collect_utils.py b/src/_pytask/collect_utils.py index b333d2c6..a5e6e507 100644 --- a/src/_pytask/collect_utils.py +++ b/src/_pytask/collect_utils.py @@ -1,6 +1,7 @@ """This module provides utility functions for :mod:`_pytask.collect`.""" from __future__ import annotations +import functools import itertools import uuid import warnings @@ -11,13 +12,13 @@ from typing import Iterable from typing import TYPE_CHECKING -import attrs from _pytask._inspect import get_annotations from _pytask.exceptions import NodeNotCollectedError from _pytask.mark_utils import has_mark from _pytask.mark_utils import remove_marks from _pytask.models import NodeInfo -from _pytask.nodes import MetaNode +from _pytask.node_protocols import Node +from _pytask.node_protocols import PPathNode from _pytask.nodes import ProductType from _pytask.nodes import PythonNode from _pytask.shared import find_duplicates @@ -80,7 +81,7 @@ def parse_nodes( objects = _extract_nodes_from_function_markers(obj, parser) nodes = _convert_objects_to_node_dictionary(objects, arg_name) nodes = tree_map( - lambda x: _collect_decorator_nodes( + lambda x: _collect_decorator_node( session, path, name, NodeInfo(arg_name, (), x) ), nodes, @@ -228,7 +229,7 @@ def _merge_dictionaries(list_of_dicts: list[dict[Any, Any]]) -> dict[Any, Any]: """ -def parse_dependencies_from_task_function( # noqa: C901 +def parse_dependencies_from_task_function( session: Session, path: Path, name: str, obj: Any ) -> dict[str, Any]: """Parse dependencies from task function.""" @@ -250,7 +251,7 @@ def parse_dependencies_from_task_function( # noqa: C901 if "depends_on" in kwargs: has_depends_on_argument = True dependencies["depends_on"] = tree_map( - lambda x: _collect_decorator_nodes( + lambda x: _collect_decorator_node( session, path, name, NodeInfo(arg_name="depends_on", path=(), value=x) ), kwargs["depends_on"], @@ -269,23 +270,17 @@ def parse_dependencies_from_task_function( # noqa: C901 if parameter_name == "depends_on": continue - if parameter_name in parameters_with_node_annot: - - def _evolve(x: Any) -> Any: - instance = parameters_with_node_annot[parameter_name] # noqa: B023 - return attrs.evolve(instance, value=x) # type: ignore[misc] - - else: - - def _evolve(x: Any) -> Any: - return x + partialed_evolve = functools.partial( + _evolve_instance, + instance_from_annot=parameters_with_node_annot.get(parameter_name), + ) nodes = tree_map_with_path( - lambda p, x: _collect_dependencies( + lambda p, x: _collect_dependency( session, path, name, - NodeInfo(parameter_name, p, _evolve(x)), # noqa: B023 + NodeInfo(parameter_name, p, partialed_evolve(x)), # noqa: B023 ), value, ) @@ -295,14 +290,14 @@ def _evolve(x: Any) -> Any: are_all_nodes_python_nodes_without_hash = all( isinstance(x, PythonNode) and not x.hash for x in tree_leaves(nodes) ) - if are_all_nodes_python_nodes_without_hash: + if not isinstance(nodes, Node) and are_all_nodes_python_nodes_without_hash: dependencies[parameter_name] = PythonNode(value=value, name=parameter_name) else: dependencies[parameter_name] = nodes return dependencies -def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, MetaNode]: +def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, Node]: """Find args with node annotations.""" annotations = get_annotations(func, eval_str=True) metas = { @@ -314,9 +309,7 @@ def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, MetaN args_with_node_annotation = {} for name, meta in metas.items(): annot = [ - i - for i in meta - if not isinstance(i, ProductType) and isinstance(i, MetaNode) + i for i in meta if not isinstance(i, ProductType) and isinstance(i, Node) ] if len(annot) >= 2: # noqa: PLR2004 raise ValueError( @@ -380,6 +373,7 @@ def parse_products_from_task_function( kwargs = {**signature_defaults, **task_kwargs} parameters_with_product_annot = _find_args_with_product_annotation(obj) + parameters_with_node_annot = _find_args_with_node_annotation(obj) # Parse products from task decorated with @task and that uses produces. if "produces" in kwargs: @@ -404,13 +398,17 @@ def parse_products_from_task_function( has_annotation = True for parameter_name in parameters_with_product_annot: if parameter_name in kwargs: - # Use _collect_new_node to not collect strings. + partialed_evolve = functools.partial( + _evolve_instance, + instance_from_annot=parameters_with_node_annot.get(parameter_name), + ) + collected_products = tree_map_with_path( lambda p, x: _collect_product( session, path, name, - NodeInfo(parameter_name, p, x), # noqa: B023 + NodeInfo(parameter_name, p, partialed_evolve(x)), # noqa: B023 is_string_allowed=False, ), kwargs[parameter_name], @@ -456,9 +454,9 @@ def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]: """ -def _collect_decorator_nodes( +def _collect_decorator_node( session: Session, path: Path, name: str, node_info: NodeInfo -) -> dict[str, MetaNode]: +) -> Node: """Collect nodes for a task. Raises @@ -495,9 +493,9 @@ def _collect_decorator_nodes( return collected_node -def _collect_dependencies( +def _collect_dependency( session: Session, path: Path, name: str, node_info: NodeInfo -) -> dict[str, MetaNode]: +) -> Node: """Collect nodes for a task. Raises @@ -525,7 +523,7 @@ def _collect_product( task_name: str, node_info: NodeInfo, is_string_allowed: bool = False, -) -> dict[str, MetaNode]: +) -> Node: """Collect products for a task. Defining products with strings is only allowed when using the decorator. Parameter @@ -546,7 +544,7 @@ def _collect_product( f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}." ) # The parameter defaults only support Path objects. - if not isinstance(node, Path) and not is_string_allowed: + if not isinstance(node, (Path, PPathNode)) and not is_string_allowed: raise ValueError( "If you declare products with 'Annotated[..., Product]', only values of " "type 'pathlib.Path' optionally nested in tuples, lists, and " @@ -566,3 +564,12 @@ def _collect_product( ) return collected_node + + +def _evolve_instance(x: Any, instance_from_annot: Node | None) -> Any: + """Evolve a value to a node if it is given by annotations.""" + if not instance_from_annot: + return x + + instance_from_annot.value = x + return instance_from_annot diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index bee58b98..050aae21 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -21,8 +21,9 @@ from _pytask.mark import Mark from _pytask.mark_utils import get_marks from _pytask.mark_utils import has_mark -from _pytask.nodes import FilePathNode -from _pytask.nodes import MetaNode +from _pytask.node_protocols import MetaNode +from _pytask.node_protocols import Node +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.path import find_common_ancestor_of_nodes from _pytask.report import DagReport @@ -140,13 +141,13 @@ def pytask_dag_has_node_changed(node: MetaNode, task_name: str) -> bool: if db_state is None: return True - if isinstance(node, (FilePathNode, Task)): + if isinstance(node, (PPathNode, Task)): # If the modification times match, the node has not been changed. if node_state == db_state.modification_time: return False # If the modification time changed, quickly return for non-tasks. - if isinstance(node, FilePathNode): + if not isinstance(node, Task): return True # When modification times changed, we are still comparing the hash of the file @@ -238,7 +239,7 @@ def _check_if_root_nodes_are_available(dag: nx.DiGraph) -> None: def _check_if_tasks_are_skipped( - node: MetaNode, dag: nx.DiGraph, is_task_skipped: dict[str, bool] + node: Node, dag: nx.DiGraph, is_task_skipped: dict[str, bool] ) -> tuple[bool, dict[str, bool]]: """Check for a given node whether it is only used by skipped tasks.""" are_all_tasks_skipped = [] diff --git a/src/_pytask/database_utils.py b/src/_pytask/database_utils.py index 05a09de1..19f5724b 100644 --- a/src/_pytask/database_utils.py +++ b/src/_pytask/database_utils.py @@ -4,7 +4,7 @@ import hashlib from _pytask.dag_utils import node_and_neighbors -from _pytask.nodes import FilePathNode +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.session import Session from sqlalchemy import Column @@ -80,7 +80,7 @@ def update_states_in_database(session: Session, task_name: str) -> None: if isinstance(node, Task): modification_time = node.state() hash_ = hashlib.sha256(node.path.read_bytes()).hexdigest() - elif isinstance(node, FilePathNode): + elif isinstance(node, PPathNode): modification_time = node.state() hash_ = "" else: diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index bcc3fb41..a4d70311 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -21,7 +21,7 @@ from _pytask.exceptions import NodeNotFoundError from _pytask.mark import Mark from _pytask.mark_utils import has_mark -from _pytask.nodes import FilePathNode +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.outcomes import count_outcomes from _pytask.outcomes import Exit @@ -129,7 +129,7 @@ def pytask_execute_task_setup(session: Session, task: Task) -> None: # method for the node classes. for product in session.dag.successors(task.name): node = session.dag.nodes[product]["node"] - if isinstance(node, FilePathNode): + if isinstance(node, PPathNode): node.path.parent.mkdir(parents=True, exist_ok=True) would_be_executed = has_mark(task, "would_be_executed") @@ -159,7 +159,7 @@ def pytask_execute_task(session: Session, task: Task) -> bool: @hookimpl def pytask_execute_task_teardown(session: Session, task: Task) -> None: - """Check if :class:`_pytask.nodes.FilePathNode` are produced by a task.""" + """Check if :class:`_pytask.nodes.PathNode` are produced by a task.""" missing_nodes = [] for product in session.dag.successors(task.name): node = session.dag.nodes[product]["node"] diff --git a/src/_pytask/hookspecs.py b/src/_pytask/hookspecs.py index 3eb56599..dfa83f92 100644 --- a/src/_pytask/hookspecs.py +++ b/src/_pytask/hookspecs.py @@ -14,11 +14,12 @@ import networkx import pluggy from _pytask.models import NodeInfo +from _pytask.node_protocols import MetaNode +from _pytask.node_protocols import Node if TYPE_CHECKING: from _pytask.session import Session - from _pytask.nodes import MetaNode from _pytask.nodes import Task from _pytask.outcomes import CollectionOutcome from _pytask.outcomes import TaskOutcome @@ -196,7 +197,7 @@ def pytask_collect_task_teardown(session: Session, task: Task) -> None: @hookspec(firstresult=True) def pytask_collect_node( session: Session, path: pathlib.Path, node_info: NodeInfo -) -> MetaNode | None: +) -> Node | None: """Collect a node which is a dependency or a product of a task.""" diff --git a/src/_pytask/node_protocols.py b/src/_pytask/node_protocols.py new file mode 100644 index 00000000..f3062cb4 --- /dev/null +++ b/src/_pytask/node_protocols.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from abc import abstractmethod +from pathlib import Path +from typing import Any +from typing import Protocol +from typing import runtime_checkable + + +@runtime_checkable +class MetaNode(Protocol): + """Protocol for an intersection between nodes and tasks.""" + + name: str | None + """The name of node that must be unique.""" + + @abstractmethod + def state(self) -> Any: + ... + + +@runtime_checkable +class Node(MetaNode, Protocol): + """Protocol for nodes.""" + + value: Any + + +@runtime_checkable +class PPathNode(Node, Protocol): + """Nodes with paths.""" + + path: Path diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index b41c906a..b8650310 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -3,13 +3,13 @@ import functools import hashlib -from abc import ABCMeta -from abc import abstractmethod from pathlib import Path from typing import Any from typing import Callable from typing import TYPE_CHECKING +from _pytask.node_protocols import MetaNode +from _pytask.node_protocols import Node from _pytask.tree_util import PyTree from attrs import define from attrs import field @@ -19,7 +19,7 @@ from _pytask.mark import Mark -__all__ = ["FilePathNode", "MetaNode", "Product", "Task"] +__all__ = ["PathNode", "Product", "Task"] @define(frozen=True) @@ -30,17 +30,6 @@ class ProductType: Product = ProductType() -class MetaNode(metaclass=ABCMeta): - """Meta class for nodes.""" - - name: str - """str: The name of node that must be unique.""" - - @abstractmethod - def state(self) -> Any: - ... - - @define(kw_only=True) class Task(MetaNode): """The class for tasks which are Python functions.""" @@ -55,9 +44,9 @@ class Task(MetaNode): """The name of the task.""" short_name: str | None = field(default=None, init=False) """The shortest uniquely identifiable name for task for display.""" - depends_on: PyTree[MetaNode] = field(factory=dict) + depends_on: PyTree[Node] = field(factory=dict) """A list of dependencies of task.""" - produces: PyTree[MetaNode] = field(factory=dict) + produces: PyTree[Node] = field(factory=dict) """A list of products of task.""" markers: list[Mark] = field(factory=list) """A list of markers attached to the task function.""" @@ -92,27 +81,41 @@ def add_report_section(self, when: str, key: str, content: str) -> None: @define(kw_only=True) -class FilePathNode(MetaNode): +class PathNode(Node): """The class for a node which is a path.""" name: str = "" """Name of the node which makes it identifiable in the DAG.""" - value: Path | None = None + _value: Path | None = None """Value passed to the decorator which can be requested inside the function.""" - path: Path | None = None - """Path to the FilePathNode.""" + + @property + def path(self) -> Path: + return self.value + + @property + def value(self) -> Path: + return self._value + + @value.setter + def value(self, value: Path) -> None: + if not isinstance(value, Path): + raise TypeError("'value' must be a 'pathlib.Path'.") + if not self.name: + self.name = value.as_posix() + self._value = value @classmethod @functools.lru_cache - def from_path(cls, path: Path) -> FilePathNode: + def from_path(cls, path: Path) -> PathNode: """Instantiate class from path to file. The `lru_cache` decorator ensures that the same object is not collected twice. """ if not path.is_absolute(): - raise ValueError("FilePathNode must be instantiated from absolute path.") - return cls(name=path.as_posix(), value=path, path=path) + raise ValueError("Node must be instantiated from absolute path.") + return cls(name=path.as_posix(), value=path) def state(self) -> str | None: """Calculate the state of the node. @@ -126,7 +129,7 @@ def state(self) -> str | None: @define(kw_only=True) -class PythonNode(MetaNode): +class PythonNode(Node): """The class for a node which is a Python object.""" name: str = "" @@ -153,4 +156,4 @@ def state(self) -> str | None: if isinstance(self.value, str): return str(hashlib.sha256(self.value.encode()).hexdigest()) return str(hash(self.value)) - return str(0) + return "0" diff --git a/src/_pytask/profile.py b/src/_pytask/profile.py index e0b29d3b..5c8a33e5 100644 --- a/src/_pytask/profile.py +++ b/src/_pytask/profile.py @@ -23,7 +23,7 @@ from _pytask.database_utils import DatabaseSession from _pytask.exceptions import CollectionError from _pytask.exceptions import ConfigurationError -from _pytask.nodes import FilePathNode +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.outcomes import ExitCode from _pytask.outcomes import TaskOutcome @@ -228,7 +228,7 @@ def pytask_profile_add_info_on_task( sum_bytes = 0 for successor in successors: node = session.dag.nodes[successor]["node"] - if isinstance(node, FilePathNode): + if isinstance(node, PPathNode): with suppress(FileNotFoundError): sum_bytes += node.path.stat().st_size diff --git a/src/_pytask/report.py b/src/_pytask/report.py index 0bc0cf03..7df151fc 100644 --- a/src/_pytask/report.py +++ b/src/_pytask/report.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: - from _pytask.nodes import MetaNode + from _pytask.node_protocols import MetaNode from _pytask.nodes import Task diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 4747813d..6059bd1b 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -10,8 +10,8 @@ import click import networkx as nx from _pytask.console import format_task_id -from _pytask.nodes import FilePathNode -from _pytask.nodes import MetaNode +from _pytask.node_protocols import MetaNode +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.path import find_closest_ancestor from _pytask.path import find_common_ancestor @@ -67,7 +67,7 @@ def reduce_node_name(node: MetaNode, paths: Sequence[str | Path]) -> str: path from one path in ``session.config["paths"]`` to the node. """ - if isinstance(node, (Task, FilePathNode)): + if isinstance(node, (PPathNode, Task)): ancestor = find_closest_ancestor(node.path, paths) if ancestor is None: try: @@ -75,10 +75,7 @@ def reduce_node_name(node: MetaNode, paths: Sequence[str | Path]) -> str: except ValueError: ancestor = node.path.parents[-1] - if isinstance(node, MetaNode): - name = relative_to(node.path, ancestor).as_posix() - else: - raise TypeError(f"Unknown node {node} with type {type(node)!r}.") + name = relative_to(node.path, ancestor).as_posix() return name return node.name diff --git a/src/pytask/__init__.py b/src/pytask/__init__.py index 1b5eca52..d626d4dd 100644 --- a/src/pytask/__init__.py +++ b/src/pytask/__init__.py @@ -36,8 +36,8 @@ from _pytask.mark_utils import set_marks from _pytask.models import CollectionMetadata from _pytask.models import NodeInfo -from _pytask.nodes import FilePathNode -from _pytask.nodes import MetaNode +from _pytask.node_protocols import MetaNode +from _pytask.nodes import PathNode from _pytask.nodes import Product from _pytask.nodes import PythonNode from _pytask.nodes import Task @@ -84,7 +84,7 @@ "ExecutionReport", "Exit", "ExitCode", - "FilePathNode", + "PathNode", "Mark", "MarkDecorator", "MarkGenerator", diff --git a/tests/test_collect.py b/tests/test_collect.py index db951b15..3e6c117b 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -422,3 +422,24 @@ def task_write_text(depends_on, produces): assert "FutureWarning" in result.output assert "Using strings to specify a dependency" in result.output assert "Using strings to specify a product" in result.output + + +@pytest.mark.end_to_end() +def test_setting_name_for_path_node_via_annotation(tmp_path): + source = """ + from pathlib import Path + from typing_extensions import Annotated + from pytask import Product, PathNode + from typing import Any + + def task_example( + path: Annotated[Path, Product, PathNode(name="product")] = Path("out.txt"), + ) -> None: + path.write_text("text") + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + session = main({"paths": [tmp_path]}) + assert session.exit_code == ExitCode.OK + product = session.tasks[0].produces["path"] + assert product.name == "product" diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index 6ab33011..4856db53 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -1,17 +1,17 @@ from __future__ import annotations import os +import pickle import textwrap from pathlib import Path import pytest from _pytask.collect_command import _find_common_ancestor_of_all_nodes from _pytask.collect_command import _print_collected_tasks -from _pytask.nodes import FilePathNode +from _pytask.nodes import PathNode from attrs import define from pytask import cli from pytask import ExitCode -from pytask import MetaNode from pytask import Task @@ -343,7 +343,7 @@ def task_example_2(): @define -class MetaNode(MetaNode): +class Node: path: Path def state(self): @@ -362,8 +362,8 @@ def test_print_collected_tasks_without_nodes(capsys): base_name="function", path=Path("task_path.py"), function=function, - depends_on={0: MetaNode("in.txt")}, - produces={0: MetaNode("out.txt")}, + depends_on={0: Node("in.txt")}, + produces={0: Node("out.txt")}, ) ] } @@ -386,15 +386,9 @@ def test_print_collected_tasks_with_nodes(capsys): path=Path("task_path.py"), function=function, depends_on={ - "depends_on": FilePathNode( - name="in.txt", value=Path("in.txt"), path=Path("in.txt") - ) - }, - produces={ - 0: FilePathNode( - name="out.txt", value=Path("out.txt"), path=Path("out.txt") - ) + "depends_on": PathNode(name="in.txt", value=Path("in.txt")) }, + produces={0: PathNode(name="out.txt", value=Path("out.txt"))}, ) ] } @@ -418,10 +412,10 @@ def test_find_common_ancestor_of_all_nodes(show_nodes, expected_add): path=Path.cwd() / "src" / "task_path.py", function=function, depends_on={ - "depends_on": FilePathNode.from_path(Path.cwd() / "src" / "in.txt") + "depends_on": PathNode.from_path(Path.cwd() / "src" / "in.txt") }, produces={ - 0: FilePathNode.from_path( + 0: PathNode.from_path( Path.cwd().joinpath("..", "bld", "out.txt").resolve() ) }, @@ -518,3 +512,92 @@ def task_example( assert "task_example>" in captured assert "" in result.output assert "Product" in captured + + +def test_node_protocol_for_custom_nodes(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import Product + from attrs import define + from pathlib import Path + + @define + class CustomNode: + name: str + value: str + + def state(self): + return self.value + + + def task_example( + data = CustomNode("custom", "text"), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "" in result.output + + +def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import Product + from pathlib import Path + from attrs import define + import pickle + + @define + class PickleFile: + name: str + path: Path + + @property + def value(self): + with self.path.open("rb") as f: + out = pickle.load(f) + return out + + def state(self): + return str(self.path.stat().st_mtime) + + + _PATH = Path(__file__).parent.joinpath("in.pkl") + + def task_example( + data = PickleFile(_PATH.as_posix(), _PATH), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("in.pkl").write_bytes(pickle.dumps("text")) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "in.pkl" in result.output + + +@pytest.mark.end_to_end() +def test_setting_name_for_python_node_via_annotation(runner, tmp_path): + source = """ + from pathlib import Path + from typing_extensions import Annotated + from pytask import Product, PythonNode + from typing import Any + + def task_example( + input: Annotated[str, PythonNode(name="node-name")] = "text", + path: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + path.write_text(input) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "node-name" in result.output diff --git a/tests/test_console.py b/tests/test_console.py index 92cb4fc9..bd0b5eca 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -127,14 +127,7 @@ def test_render_to_string(color_system, text, strip_styles, expected): None, Text( _THIS_FILE.as_posix() + "::task_a", - spans=[ - Span(0, len(_THIS_FILE.as_posix()) + 2, "dim"), - Span( - len(_THIS_FILE.as_posix()) + 2, - len(_THIS_FILE.as_posix()) + 2 + 6, - Style(), - ), - ], + spans=[Span(0, len(_THIS_FILE.as_posix()) + 2, "dim")], ), id="format full id", ), @@ -146,7 +139,7 @@ def test_render_to_string(color_system, text, strip_styles, expected): None, Text( "test_console.py::task_a", - spans=[Span(0, 17, "dim"), Span(17, 23, Style())], + spans=[Span(0, 17, "dim")], ), id="format short id", ), @@ -158,7 +151,7 @@ def test_render_to_string(color_system, text, strip_styles, expected): _THIS_FILE.parent, Text( "tests/test_console.py::task_a", - spans=[Span(0, 23, "dim"), Span(23, 29, Style())], + spans=[Span(0, 23, "dim")], ), id="format relative to id", ), diff --git a/tests/test_dag.py b/tests/test_dag.py index 017374eb..09259abc 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -3,7 +3,6 @@ import textwrap from contextlib import ExitStack as does_not_raise # noqa: N813 from pathlib import Path -from typing import Any import networkx as nx import pytest @@ -14,18 +13,14 @@ from attrs import define from pytask import cli from pytask import ExitCode -from pytask import FilePathNode +from pytask import PathNode from pytask import Task @define -class Node(FilePathNode): +class Node(PathNode): """See https://github.com/python-attrs/attrs/issues/293 for property hack.""" - name: str - value: Any - path: Path - def state(self): if "missing" in self.name: raise NodeNotFoundError diff --git a/tests/test_execute.py b/tests/test_execute.py index a86dc176..21b3909a 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -530,11 +530,11 @@ def test_error_with_multiple_different_dep_annotations(runner, tmp_path): source = """ from pathlib import Path from typing_extensions import Annotated - from pytask import Product, PythonNode, FilePathNode + from pytask import Product, PythonNode, PathNode from typing import Any def task_example( - dependency: Annotated[Any, PythonNode(), FilePathNode()] = "hello", + dependency: Annotated[Any, PythonNode(), PathNode()] = "hello", path: Annotated[Path, Product] = Path("out.txt") ) -> None: path.write_text(dependency) diff --git a/tests/test_node_protocols.py b/tests/test_node_protocols.py new file mode 100644 index 00000000..4cf1944c --- /dev/null +++ b/tests/test_node_protocols.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import pickle +import textwrap + +from pytask import cli +from pytask import ExitCode + + +def test_node_protocol_for_custom_nodes(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import Product + from attrs import define + from pathlib import Path + + @define + class CustomNode: + name: str + value: str + + def state(self): + return self.value + + + def task_example( + data = CustomNode("custom", "text"), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").read_text() == "text" + + +def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import Product + from pathlib import Path + from attrs import define + import pickle + + @define + class PickleFile: + name: str + path: Path + + @property + def value(self): + with self.path.open("rb") as f: + out = pickle.load(f) + return out + + def state(self): + return str(self.path.stat().st_mtime) + + + _PATH = Path(__file__).parent.joinpath("in.pkl") + + def task_example( + data = PickleFile(_PATH.as_posix(), _PATH), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("in.pkl").write_bytes(pickle.dumps("text")) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").read_text() == "text" diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 9165a2f9..a58d7e50 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -5,7 +5,7 @@ import pytest from _pytask.shared import reduce_node_name -from pytask import FilePathNode +from pytask import PathNode _ROOT = Path.cwd() @@ -16,14 +16,14 @@ ("node", "paths", "expectation", "expected"), [ pytest.param( - FilePathNode.from_path(_ROOT.joinpath("src/module.py")), + PathNode.from_path(_ROOT.joinpath("src/module.py")), [_ROOT.joinpath("alternative_src")], does_not_raise(), "pytask/src/module.py", - id="Common path found for FilePathNode not in 'paths' and 'paths'", + id="Common path found for PathNode not in 'paths' and 'paths'", ), pytest.param( - FilePathNode.from_path(_ROOT.joinpath("top/src/module.py")), + PathNode.from_path(_ROOT.joinpath("top/src/module.py")), [_ROOT.joinpath("top/src")], does_not_raise(), "src/module.py",