Skip to content

Commit

Permalink
Merge 3f6d04d into 580f415
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe committed Nov 14, 2023
2 parents 580f415 + 3f6d04d commit d266a68
Show file tree
Hide file tree
Showing 14 changed files with 135 additions and 173 deletions.
5 changes: 4 additions & 1 deletion docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
when a product annotation is used with the argument name `produces`. And, allow
`produces` to intake any node.
- {pull}`490` refactors and better tests parsing of dependencies.
- {pull}`496` makes pytask even lazier. Now, when a task produces a node whose hash
remains the same, the consecutive tasks are not executed. It remained from when pytask
relied on timestamps.

## 0.4.2 - 2023-11-8
## 0.4.2 - 2023-11-08

- {pull}`449` simplifies the code building the plugin manager.
- {pull}`451` improves `collect_command.py` and renames `graph.py` to `dag_command.py`.
Expand Down
1 change: 0 additions & 1 deletion docs/source/reference_guides/hookspecs.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ your plugin.
```{eval-rst}
.. autofunction:: pytask_dag
.. autofunction:: pytask_dag_create_dag
.. autofunction:: pytask_dag_select_execution_dag
.. autofunction:: pytask_dag_log
```
Expand Down
10 changes: 9 additions & 1 deletion src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from _pytask.console import is_jupyter
from _pytask.exceptions import CollectionError
from _pytask.mark import MarkGenerator
from _pytask.mark_utils import get_all_marks
from _pytask.mark_utils import has_mark
from _pytask.node_protocols import PNode
from _pytask.node_protocols import PPathNode
Expand Down Expand Up @@ -246,6 +247,13 @@ def pytask_collect_task(
"""
if (name.startswith("task_") or has_mark(obj, "task")) and is_task_function(obj):
if has_mark(obj, "try_first") and has_mark(obj, "try_last"):
msg = (

Check warning on line 251 in src/_pytask/collect.py

View check run for this annotation

Codecov / codecov/patch

src/_pytask/collect.py#L251

Added line #L251 was not covered by tests
"The task cannot have mixed priorities. Do not apply "
"'@pytask.mark.try_first' and '@pytask.mark.try_last' at the same time."
)
raise ValueError(msg)

Check warning on line 255 in src/_pytask/collect.py

View check run for this annotation

Codecov / codecov/patch

src/_pytask/collect.py#L255

Added line #L255 was not covered by tests

path_nodes = Path.cwd() if path is None else path.parent
dependencies = parse_dependencies_from_task_function(
session, path, name, path_nodes, obj
Expand All @@ -254,7 +262,7 @@ def pytask_collect_task(
session, path, name, path_nodes, obj
)

markers = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
markers = get_all_marks(obj)

# Get the underlying function to avoid having different states of the function,
# e.g. due to pytask_meta, in different layers of the wrapping.
Expand Down
61 changes: 0 additions & 61 deletions src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,7 @@
from _pytask.console import format_task_name
from _pytask.console import render_to_string
from _pytask.console import TASK_ICON
from _pytask.dag_utils import node_and_neighbors
from _pytask.dag_utils import task_and_descending_tasks
from _pytask.dag_utils import TopologicalSorter
from _pytask.database_utils import DatabaseSession
from _pytask.database_utils import State
from _pytask.exceptions import ResolvingDependenciesError
from _pytask.mark import Mark
from _pytask.node_protocols import PNode
from _pytask.node_protocols import PTask
from _pytask.nodes import PythonNode
Expand All @@ -31,7 +25,6 @@
from rich.tree import Tree

if TYPE_CHECKING:
from _pytask.node_protocols import MetaNode
from pathlib import Path
from _pytask.session import Session

Expand All @@ -44,7 +37,6 @@ def pytask_dag(session: Session) -> bool | None:
session=session, tasks=session.tasks
)
session.hook.pytask_dag_modify_dag(session=session, dag=session.dag)
session.hook.pytask_dag_select_execution_dag(session=session, dag=session.dag)

except Exception: # noqa: BLE001
report = DagReport.from_exception(sys.exc_info())
Expand Down Expand Up @@ -101,59 +93,6 @@ def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None:
return dag


@hookimpl
def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
"""Select the tasks which need to be executed."""
scheduler = TopologicalSorter.from_dag(dag)
visited_nodes: set[str] = set()

while scheduler.is_active():
task_signature = scheduler.get_ready()[0]
if task_signature not in visited_nodes:
task = dag.nodes[task_signature]["task"]
have_changed = _have_task_or_neighbors_changed(session, dag, task)
if have_changed:
visited_nodes.update(task_and_descending_tasks(task_signature, dag))
else:
dag.nodes[task_signature]["task"].markers.append(
Mark("skip_unchanged", (), {})
)
scheduler.done(task_signature)


def _have_task_or_neighbors_changed(
session: Session, dag: nx.DiGraph, task: PTask
) -> bool:
"""Indicate whether dependencies or products of a task have changed."""
return any(
session.hook.pytask_dag_has_node_changed(
session=session,
dag=dag,
task=task,
node=dag.nodes[node_name].get("task") or dag.nodes[node_name].get("node"),
)
for node_name in node_and_neighbors(dag, task.signature)
)


@hookimpl(trylast=True)
def pytask_dag_has_node_changed(task: PTask, node: MetaNode) -> bool:
"""Indicate whether a single dependency or product has changed."""
# If node does not exist, we receive None.
node_state = node.state()
if node_state is None:
return True

with DatabaseSession() as session:
db_state = session.get(State, (task.signature, node.signature))

# If the node is not in the database.
if db_state is None:
return True

return node_state != db_state.hash_


def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:
"""Check if DAG has cycles."""
try:
Expand Down
27 changes: 4 additions & 23 deletions src/_pytask/dag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from typing import TYPE_CHECKING

import networkx as nx
from _pytask.console import format_strings_as_flat_tree
from _pytask.console import format_task_name
from _pytask.console import TASK_ICON
from _pytask.mark_utils import has_mark
from attrs import define
from attrs import field
Expand Down Expand Up @@ -54,8 +51,11 @@ def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]:
We cannot use ``dag.neighbors`` as it only considers successors as neighbors in a
DAG.
The task node needs to be yield in the middle so that first predecessors are checked
and then the rest of the nodes.
"""
return itertools.chain([node], dag.predecessors(node), dag.successors(node))
return itertools.chain(dag.predecessors(node), [node], dag.successors(node))


@define
Expand Down Expand Up @@ -166,25 +166,6 @@ def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
}
for task in tasks
}
tasks_w_mixed_priorities = [
name for name, p in priorities.items() if p["try_first"] and p["try_last"]
]

if tasks_w_mixed_priorities:
name_to_task = {task.signature: task for task in tasks}
reduced_names = []
for name in tasks_w_mixed_priorities:
reduced_name = format_task_name(name_to_task[name], "no_link")
reduced_names.append(reduced_name.plain)

text = format_strings_as_flat_tree(
reduced_names, "Tasks with mixed priorities", TASK_ICON
)
msg = (
f"'try_first' and 'try_last' cannot be applied on the same task. See the "
f"following tasks for errors:\n\n{text}"
)
raise ValueError(msg)

# Recode to numeric values for sorting.
numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1}
Expand Down
19 changes: 19 additions & 0 deletions src/_pytask/database_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from sqlalchemy.orm import sessionmaker

if TYPE_CHECKING:
from _pytask.node_protocols import MetaNode
from _pytask.node_protocols import PTask
from _pytask.session import Session


Expand Down Expand Up @@ -62,3 +64,20 @@ def update_states_in_database(session: Session, task_signature: str) -> None:
node = session.dag.nodes[name].get("task") or session.dag.nodes[name]["node"]
hash_ = node.state()
_create_or_update_state(task_signature, node.signature, hash_)


def has_node_changed(task: PTask, node: MetaNode) -> bool:
"""Indicate whether a single dependency or product has changed."""
# If node does not exist, we receive None.
node_state = node.state()
if node_state is None:
return True

Check warning on line 74 in src/_pytask/database_utils.py

View check run for this annotation

Codecov / codecov/patch

src/_pytask/database_utils.py#L74

Added line #L74 was not covered by tests

with DatabaseSession() as session:
db_state = session.get(State, (task.signature, node.signature))

# If the node is not in the database.
if db_state is None:
return True

return node_state != db_state.hash_
49 changes: 33 additions & 16 deletions src/_pytask/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from _pytask.console import format_strings_as_flat_tree
from _pytask.console import unify_styles
from _pytask.dag_utils import descending_tasks
from _pytask.dag_utils import node_and_neighbors
from _pytask.dag_utils import TopologicalSorter
from _pytask.database_utils import has_node_changed
from _pytask.database_utils import update_states_in_database
from _pytask.exceptions import ExecutionError
from _pytask.exceptions import NodeLoadError
Expand All @@ -28,6 +30,7 @@
from _pytask.node_protocols import PTask
from _pytask.outcomes import count_outcomes
from _pytask.outcomes import Exit
from _pytask.outcomes import SkippedUnchanged
from _pytask.outcomes import TaskOutcome
from _pytask.outcomes import WouldBeExecuted
from _pytask.reports import ExecutionReport
Expand Down Expand Up @@ -124,28 +127,42 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None:
2. Create the directory where the product will be placed.
"""
for dependency in session.dag.predecessors(task.signature):
node = session.dag.nodes[dependency]["node"]
if not node.state():
msg = f"{task.name!r} requires missing node {node.name!r}."
if IS_FILE_SYSTEM_CASE_SENSITIVE:
msg += (
"\n\n(Hint: Your file-system is case-sensitive. Check the paths' "
"capitalization carefully.)"
)
raise NodeNotFoundError(msg)
if has_mark(task, "would_be_executed"):
raise WouldBeExecuted

Check warning on line 131 in src/_pytask/execute.py

View check run for this annotation

Codecov / codecov/patch

src/_pytask/execute.py#L131

Added line #L131 was not covered by tests

dag = session.dag

needs_to_be_executed = session.config["force"]
if not needs_to_be_executed:
predecessors = set(dag.predecessors(task.signature)) | {task.signature}
for node_signature in node_and_neighbors(dag, task.signature):
node = dag.nodes[node_signature].get("task") or dag.nodes[
node_signature
].get("node")
if node_signature in predecessors and not node.state():
msg = f"{task.name!r} requires missing node {node.name!r}."
if IS_FILE_SYSTEM_CASE_SENSITIVE:
msg += (

Check warning on line 145 in src/_pytask/execute.py

View check run for this annotation

Codecov / codecov/patch

src/_pytask/execute.py#L143-L145

Added lines #L143 - L145 were not covered by tests
"\n\n(Hint: Your file-system is case-sensitive. Check the "
"paths' capitalization carefully.)"
)
raise NodeNotFoundError(msg)

Check warning on line 149 in src/_pytask/execute.py

View check run for this annotation

Codecov / codecov/patch

src/_pytask/execute.py#L149

Added line #L149 was not covered by tests

has_changed = has_node_changed(task=task, node=node)
if has_changed:
needs_to_be_executed = True
break

if not needs_to_be_executed:
raise SkippedUnchanged

# Create directory for product if it does not exist. Maybe this should be a `setup`
# method for the node classes.
for product in session.dag.successors(task.signature):
node = session.dag.nodes[product]["node"]
for product in dag.successors(task.signature):
node = dag.nodes[product]["node"]
if isinstance(node, PPathNode):
node.path.parent.mkdir(parents=True, exist_ok=True)

would_be_executed = has_mark(task, "would_be_executed")
if would_be_executed:
raise WouldBeExecuted


def _safe_load(node: PNode, task: PTask, is_product: bool) -> Any:
try:
Expand Down
23 changes: 0 additions & 23 deletions src/_pytask/hookspecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


if TYPE_CHECKING:
from _pytask.node_protocols import MetaNode
from _pytask.models import NodeInfo
from _pytask.node_protocols import PNode
import click
Expand Down Expand Up @@ -245,28 +244,6 @@ def pytask_dag_modify_dag(session: Session, dag: nx.DiGraph) -> None:
"""


@hookspec
def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
"""Select the subgraph which needs to be executed.
This hook determines which of the tasks have to be re-run because something has
changed.
"""


@hookspec(firstresult=True)
def pytask_dag_has_node_changed(
session: Session, dag: nx.DiGraph, task: PTask, node: MetaNode
) -> None:
"""Select the subgraph which needs to be executed.
This hook determines which of the tasks have to be re-run because something has
changed.
"""


@hookspec
def pytask_dag_log(session: Session, report: DagReport) -> None:
"""Log errors during resolving dependencies."""
Expand Down
3 changes: 2 additions & 1 deletion src/_pytask/mark/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable
from typing import Mapping

from _pytask.mark_utils import get_all_marks
from _pytask.models import CollectionMetadata
from _pytask.typing import is_task_function
from attrs import define
Expand Down Expand Up @@ -122,7 +123,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> MarkDecorator:

def get_unpacked_marks(obj: Callable[..., Any]) -> list[Mark]:
"""Obtain the unpacked marks that are stored on an object."""
mark_list = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
mark_list = get_all_marks(obj)
return normalize_mark_list(mark_list)


Expand Down
12 changes: 11 additions & 1 deletion src/_pytask/persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from _pytask.config import hookimpl
from _pytask.dag_utils import node_and_neighbors
from _pytask.database_utils import has_node_changed
from _pytask.database_utils import update_states_in_database
from _pytask.mark_utils import has_mark
from _pytask.outcomes import Persisted
Expand Down Expand Up @@ -46,7 +47,16 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None:
)

if all_nodes_exist:
raise Persisted
any_node_changed = any(

Check warning on line 50 in src/_pytask/persist.py

View check run for this annotation

Codecov / codecov/patch

src/_pytask/persist.py#L50

Added line #L50 was not covered by tests
has_node_changed(
task=task,
node=session.dag.nodes[name].get("task")
or session.dag.nodes[name]["node"],
)
for name in node_and_neighbors(session.dag, task.signature)
)
if any_node_changed:
raise Persisted

Check warning on line 59 in src/_pytask/persist.py

View check run for this annotation

Codecov / codecov/patch

src/_pytask/persist.py#L58-L59

Added lines #L58 - L59 were not covered by tests


@hookimpl
Expand Down
18 changes: 18 additions & 0 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,3 +661,21 @@ def task_example() -> Annotated[int, 1]: ...
result = runner.invoke(cli, [tmp_path.as_posix()])
assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "The return annotation of the task" in result.output


@pytest.mark.end_to_end()
def test_scheduling_w_mixed_priorities(runner, tmp_path):
source = """

Check warning on line 668 in tests/test_collect.py

View check run for this annotation

Codecov / codecov/patch

tests/test_collect.py#L668

Added line #L668 was not covered by tests
import pytask
@pytask.mark.try_last
@pytask.mark.try_first
def task_mixed(): pass
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))

Check warning on line 675 in tests/test_collect.py

View check run for this annotation

Codecov / codecov/patch

tests/test_collect.py#L675

Added line #L675 was not covered by tests

result = runner.invoke(cli, [tmp_path.as_posix()])

Check warning on line 677 in tests/test_collect.py

View check run for this annotation

Codecov / codecov/patch

tests/test_collect.py#L677

Added line #L677 was not covered by tests

assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "Could not collect" in result.output
assert "The task cannot have" in result.output

Check warning on line 681 in tests/test_collect.py

View check run for this annotation

Codecov / codecov/patch

tests/test_collect.py#L679-L681

Added lines #L679 - L681 were not covered by tests

0 comments on commit d266a68

Please sign in to comment.