Skip to content

Commit

Permalink
Add protocols for tasks. (#412)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe committed Sep 4, 2023
1 parent 4bd60bf commit 802ceca
Show file tree
Hide file tree
Showing 27 changed files with 224 additions and 230 deletions.
1 change: 1 addition & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`408` removes `.value` from `Node` protocol.
- {pull}`409` make `.from_annot` an optional feature of nodes.
- {pull}`410` allows to pass functions to `PythonNode(hash=...)`.
- {pull}`412` adds protocols for tasks.

## 0.3.2 - 2023-06-07

Expand Down
16 changes: 9 additions & 7 deletions src/_pytask/capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from _pytask.click import EnumChoice
from _pytask.config import hookimpl
from _pytask.enums import ShowCapture
from _pytask.nodes import Task
from _pytask.node_protocols import PTask


class _CaptureMethod(enum.Enum):
Expand Down Expand Up @@ -706,7 +706,7 @@ def read(self) -> CaptureResult[str]:
# Helper context managers

@contextlib.contextmanager
def task_capture(self, when: str, task: Task) -> Generator[None, None, None]:
def task_capture(self, when: str, task: PTask) -> Generator[None, None, None]:
"""Pipe captured stdout and stderr into report sections."""
self.resume()

Expand All @@ -716,25 +716,27 @@ def task_capture(self, when: str, task: Task) -> Generator[None, None, None]:
self.suspend(in_=False)

out, err = self.read()
task.add_report_section(when, "stdout", out)
task.add_report_section(when, "stderr", err)
if out:
task.report_sections.append((when, "stdout", out))
if err:
task.report_sections.append((when, "stderr", err))

# Hooks

@hookimpl(hookwrapper=True)
def pytask_execute_task_setup(self, task: Task) -> Generator[None, None, None]:
def pytask_execute_task_setup(self, task: PTask) -> Generator[None, None, None]:
"""Capture output during setup."""
with self.task_capture("setup", task):
yield

@hookimpl(hookwrapper=True)
def pytask_execute_task(self, task: Task) -> Generator[None, None, None]:
def pytask_execute_task(self, task: PTask) -> Generator[None, None, None]:
"""Capture output during execution."""
with self.task_capture("call", task):
yield

@hookimpl(hookwrapper=True)
def pytask_execute_task_teardown(self, task: Task) -> Generator[None, None, None]:
def pytask_execute_task_teardown(self, task: PTask) -> Generator[None, None, None]:
"""Capture output during teardown."""
with self.task_capture("teardown", task):
yield
Expand Down
11 changes: 7 additions & 4 deletions src/_pytask/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from _pytask.git import get_all_files
from _pytask.git import get_root
from _pytask.git import is_git_installed
from _pytask.nodes import Task
from _pytask.node_protocols import PPathNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.outcomes import ExitCode
from _pytask.path import find_common_ancestor
from _pytask.path import relative_to
Expand Down Expand Up @@ -214,12 +216,13 @@ def _collect_all_paths_known_to_pytask(session: Session) -> set[Path]:
return known_paths


def _yield_paths_from_task(task: Task) -> Generator[Path, None, None]:
def _yield_paths_from_task(task: PTask) -> Generator[Path, None, None]:
"""Yield all paths attached to a task."""
yield task.path
if isinstance(task, PTaskWithPath):
yield task.path
for attribute in ("depends_on", "produces"):
for node in tree_leaves(getattr(task, attribute)):
if hasattr(node, "path") and isinstance(node.path, Path):
if isinstance(node, PPathNode):
yield node.path


Expand Down
21 changes: 11 additions & 10 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from _pytask.config import IS_FILE_SYSTEM_CASE_SENSITIVE
from _pytask.console import console
from _pytask.console import create_summary_panel
from _pytask.console import format_task_id
from _pytask.console import format_task_name
from _pytask.exceptions import CollectionError
from _pytask.mark_utils import has_mark
from _pytask.models import NodeInfo
from _pytask.node_protocols import Node
from _pytask.node_protocols import PTask
from _pytask.nodes import PathNode
from _pytask.nodes import PythonNode
from _pytask.nodes import Task
Expand Down Expand Up @@ -270,16 +271,16 @@ def _not_ignored_paths(


@hookimpl(trylast=True)
def pytask_collect_modify_tasks(tasks: list[Task]) -> None:
def pytask_collect_modify_tasks(tasks: list[PTask]) -> None:
"""Given all tasks, assign a short uniquely identifiable name to each task."""
id_to_short_id = _find_shortest_uniquely_identifiable_name_for_tasks(tasks)
for task in tasks:
short_id = id_to_short_id[task.name]
task.short_name = short_id
if task.name in id_to_short_id and isinstance(task, Task):
task.display_name = id_to_short_id[task.name]


def _find_shortest_uniquely_identifiable_name_for_tasks(
tasks: list[Task],
tasks: list[PTask],
) -> dict[str, str]:
"""Find the shortest uniquely identifiable name for tasks.
Expand All @@ -291,7 +292,7 @@ def _find_shortest_uniquely_identifiable_name_for_tasks(
id_to_short_id = {}

# Make attempt to add up to twenty parts of the path to ensure uniqueness.
id_to_task = {task.name: task for task in tasks}
id_to_task = {task.name: task for task in tasks if isinstance(task, Task)}
for n_parts in range(1, 20):
dupl_id_to_short_id = {
id_: "/".join(task.path.parts[-n_parts:]) + "::" + task.base_name
Expand All @@ -313,7 +314,7 @@ def _find_shortest_uniquely_identifiable_name_for_tasks(

@hookimpl
def pytask_collect_log(
session: Session, reports: list[CollectionReport], tasks: list[Task]
session: Session, reports: list[CollectionReport], tasks: list[PTask]
) -> None:
"""Log collection."""
session.collection_end = time.time()
Expand All @@ -334,9 +335,9 @@ def pytask_collect_log(
if report.node is None:
header = "Error"
else:
if isinstance(report.node, Task):
short_name = format_task_id(
report.node, editor_url_scheme="no_link", short_name=True
if isinstance(report.node, PTask):
short_name = format_task_name(
report.node, editor_url_scheme="no_link"
)
else:
short_name = reduce_node_name(report.node, session.config["paths"])
Expand Down
29 changes: 14 additions & 15 deletions src/_pytask/collect_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import sys
from collections import defaultdict
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING
Expand All @@ -12,7 +13,7 @@
from _pytask.console import console
from _pytask.console import create_url_style_for_path
from _pytask.console import FILE_ICON
from _pytask.console import format_task_id
from _pytask.console import format_task_name
from _pytask.console import PYTHON_ICON
from _pytask.console import TASK_ICON
from _pytask.exceptions import CollectionError
Expand All @@ -21,6 +22,8 @@
from _pytask.mark import select_by_keyword
from _pytask.mark import select_by_mark
from _pytask.node_protocols import PPathNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.outcomes import ExitCode
from _pytask.path import find_common_ancestor
from _pytask.path import relative_to
Expand All @@ -33,7 +36,6 @@

if TYPE_CHECKING:
from typing import NoReturn
from _pytask.nodes import Task


@hookimpl(tryfirst=True)
Expand Down Expand Up @@ -76,11 +78,12 @@ def collect(**raw_config: Any | None) -> NoReturn:
session.hook.pytask_dag(session=session)

tasks = _select_tasks_by_expressions_and_marker(session)
task_with_path = [t for t in tasks if isinstance(t, PTaskWithPath)]

common_ancestor = _find_common_ancestor_of_all_nodes(
tasks, session.config["paths"], session.config["nodes"]
task_with_path, session.config["paths"], session.config["nodes"]
)
dictionary = _organize_tasks(tasks)
dictionary = _organize_tasks(task_with_path)
if dictionary:
_print_collected_tasks(
dictionary,
Expand All @@ -106,7 +109,7 @@ def collect(**raw_config: Any | None) -> NoReturn:
sys.exit(session.exit_code)


def _select_tasks_by_expressions_and_marker(session: Session) -> list[Task]:
def _select_tasks_by_expressions_and_marker(session: Session) -> list[PTask]:
"""Select tasks by expressions and marker."""
all_tasks = {task.name for task in session.tasks}
remaining_by_mark = select_by_mark(session, session.dag) or all_tasks
Expand All @@ -117,7 +120,7 @@ def _select_tasks_by_expressions_and_marker(session: Session) -> list[Task]:


def _find_common_ancestor_of_all_nodes(
tasks: list[Task], paths: list[Path], show_nodes: bool
tasks: list[PTaskWithPath], paths: list[Path], show_nodes: bool
) -> Path:
"""Find common ancestor from all nodes and passed paths."""
all_paths = []
Expand All @@ -136,16 +139,15 @@ def _find_common_ancestor_of_all_nodes(
return common_ancestor


def _organize_tasks(tasks: list[Task]) -> dict[Path, list[Task]]:
def _organize_tasks(tasks: list[PTaskWithPath]) -> dict[Path, list[PTaskWithPath]]:
"""Organize tasks in a dictionary.
The dictionary has file names as keys and then a dictionary with task names and
below a dictionary with dependencies and targets.
"""
dictionary: dict[Path, list[Task]] = {}
dictionary: dict[Path, list[PTaskWithPath]] = defaultdict(list)
for task in tasks:
dictionary[task.path] = dictionary.get(task.path, [])
dictionary[task.path].append(task)

sorted_dict = {}
Expand All @@ -156,7 +158,7 @@ def _organize_tasks(tasks: list[Task]) -> dict[Path, list[Task]]:


def _print_collected_tasks(
dictionary: dict[Path, list[Task]],
dictionary: dict[Path, list[PTaskWithPath]],
show_nodes: bool,
editor_url_scheme: str,
common_ancestor: Path,
Expand Down Expand Up @@ -191,11 +193,8 @@ def _print_collected_tasks(
)

for task in tasks:
reduced_task_name = format_task_id(
task,
editor_url_scheme=editor_url_scheme,
short_name=True,
relative_to=common_ancestor,
reduced_task_name = format_task_name(
task, editor_url_scheme=editor_url_scheme
)
task_branch = module_branch.add(
Text.assemble(TASK_ICON, "<Function ", reduced_task_name, ">"),
Expand Down
34 changes: 13 additions & 21 deletions src/_pytask/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from typing import TYPE_CHECKING

import rich
from _pytask.path import relative_to as relative_to_
from _pytask.node_protocols import PTask
from _pytask.nodes import Task
from rich.console import Console
from rich.padding import Padding
from rich.panel import Panel
Expand All @@ -26,7 +27,6 @@


if TYPE_CHECKING:
from _pytask.nodes import Task
from _pytask.outcomes import CollectionOutcome
from _pytask.outcomes import TaskOutcome

Expand All @@ -36,7 +36,7 @@
"create_url_style_for_task",
"create_url_style_for_path",
"console",
"format_task_id",
"format_task_name",
"format_strings_as_flat_tree",
"render_to_string",
"unify_styles",
Expand Down Expand Up @@ -143,37 +143,29 @@ def render_to_string(
return rendered


def format_task_id(
task: Task,
editor_url_scheme: str,
short_name: bool = False,
relative_to: Path | None = None,
) -> Text:
def format_task_name(task: PTask, editor_url_scheme: str) -> Text:
"""Format a task id."""
if short_name:
path, task_name = task.short_name.split("::")
elif relative_to:
path = relative_to_(task.path, relative_to).as_posix()
task_name = task.base_name
else:
path, task_name = task.name.split("::")

if task.function is None:
url_style = Style()
else:
url_style = create_url_style_for_task(task.function, editor_url_scheme)

task_id = Text.assemble(
Text(path + "::", style="dim"), Text(task_name, style=url_style)
)
if isinstance(task, Task):
path, task_name = task.display_name.split("::")
task_id = Text.assemble(
Text(path + "::", style="dim"), Text(task_name, style=url_style)
)
else:
name = getattr(task, "display_name", task.name)
task_id = Text(name, style=url_style)
return task_id


def format_strings_as_flat_tree(strings: Iterable[str], title: str, icon: str) -> str:
"""Format list of strings as flat tree."""
tree = Tree(title)
for name in strings:
tree.add(icon + name)
tree.add(Text.assemble(icon, name))
text = render_to_string(tree, console=console)
return text

Expand Down
14 changes: 8 additions & 6 deletions src/_pytask/dag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module contains code related to resolving dependencies."""
from __future__ import annotations

import hashlib
import itertools
import sys

Expand All @@ -24,7 +25,8 @@
from _pytask.node_protocols import MetaNode
from _pytask.node_protocols import Node
from _pytask.node_protocols import PPathNode
from _pytask.nodes import Task
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.path import find_common_ancestor_of_nodes
from _pytask.report import DagReport
from _pytask.session import Session
Expand Down Expand Up @@ -66,7 +68,7 @@ def pytask_dag(session: Session) -> bool | None:


@hookimpl
def pytask_dag_create_dag(tasks: list[Task]) -> nx.DiGraph:
def pytask_dag_create_dag(tasks: list[PTask]) -> nx.DiGraph:
"""Create the DAG from tasks, dependencies and products."""
dag = nx.DiGraph()

Expand Down Expand Up @@ -112,7 +114,7 @@ def pytask_dag_validate_dag(dag: nx.DiGraph) -> None:


def _have_task_or_neighbors_changed(
session: Session, dag: nx.DiGraph, task: Task
session: Session, dag: nx.DiGraph, task: PTask
) -> bool:
"""Indicate whether dependencies or products of a task have changed."""
return any(
Expand Down Expand Up @@ -141,18 +143,18 @@ def pytask_dag_has_node_changed(node: MetaNode, task_name: str) -> bool:
if db_state is None:
return True

if isinstance(node, (PPathNode, Task)):
if isinstance(node, (PPathNode, PTaskWithPath)):
# If the modification times match, the node has not been changed.
if node_state == db_state.modification_time:
return False

# If the modification time changed, quickly return for non-tasks.
if not isinstance(node, Task):
if not isinstance(node, PTaskWithPath):
return True

# When modification times changed, we are still comparing the hash of the file
# to avoid unnecessary and expensive reexecutions of tasks.
hash_ = node.state(hash=True)
hash_ = hashlib.sha256(node.path.read_bytes()).hexdigest()
return hash_ != db_state.hash_

return node_state != db_state.hash_
Expand Down
Loading

0 comments on commit 802ceca

Please sign in to comment.