/
dag_command.py
251 lines (199 loc) · 8.44 KB
/
dag_command.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""Contains the command and code for drawing the DAG."""
from __future__ import annotations
import enum
import sys
from pathlib import Path
from typing import Any
import click
import networkx as nx
from _pytask.click import ColoredCommand
from _pytask.click import EnumChoice
from _pytask.compat import check_for_optional_program
from _pytask.compat import import_optional_dependency
from _pytask.config_utils import find_project_root_and_config
from _pytask.config_utils import read_config
from _pytask.console import console
from _pytask.exceptions import CollectionError
from _pytask.exceptions import ConfigurationError
from _pytask.exceptions import ResolvingDependenciesError
from _pytask.outcomes import ExitCode
from _pytask.pluginmanager import get_plugin_manager
from _pytask.pluginmanager import hookimpl
from _pytask.pluginmanager import storage
from _pytask.session import Session
from _pytask.shared import parse_paths
from _pytask.shared import reduce_names_of_multiple_nodes
from _pytask.shared import to_list
from _pytask.traceback import remove_internal_traceback_frames_from_exc_info
from rich.text import Text
from rich.traceback import Traceback
class _RankDirection(enum.Enum):
    """Directions in which the drawn graph can be ordered.

    The name of the selected member is stored as the ``rankdir`` attribute of the
    graph when the DAG is refined for plotting.
    """

    # Top to bottom.
    TB = "TB"
    # Left to right.
    LR = "LR"
    # Bottom to top.
    BT = "BT"
    # Right to left.
    RL = "RL"
@hookimpl(tryfirst=True)
def pytask_extend_command_line_interface(cli: click.Group) -> None:
    """Extend the command line interface.

    Registers the ``dag`` command on the given click group.

    Parameters
    ----------
    cli : click.Group
        The command group to which the ``dag`` command is added.

    """
    cli.add_command(dag)
# Help text for the ``--layout`` option of the ``dag`` command.
_HELP_TEXT_LAYOUT: str = (
    "The layout determines the structure of the graph. "
    "Here you find an overview of all available layouts: "
    "https://graphviz.org/docs/layouts."
)

# Help text for the ``--output-path`` option of the ``dag`` command.
_HELP_TEXT_OUTPUT: str = (
    "The output path of the visualization. "
    "The format is inferred from the file extension."
)

# Help text for the ``--rank-direction`` option of the ``dag`` command.
_HELP_TEXT_RANK_DIRECTION: str = (
    "The direction of the directed graph. "
    "It can be ordered from top to bottom, TB, left to right, LR, "
    "bottom to top, BT, or right to left, RL."
)
@click.command(cls=ColoredCommand)
@click.option("-l", "--layout", type=str, default="dot", help=_HELP_TEXT_LAYOUT)
@click.option(
    "-o",
    "--output-path",
    type=click.Path(file_okay=True, dir_okay=False, path_type=Path, resolve_path=True),
    default="dag.pdf",
    help=_HELP_TEXT_OUTPUT,
)
@click.option(
    "-r",
    "--rank-direction",
    type=EnumChoice(_RankDirection),
    help=_HELP_TEXT_RANK_DIRECTION,
    default=_RankDirection.TB,
)
def dag(**raw_config: Any) -> int:
    """Create a visualization of the project's directed acyclic graph.

    The command configures a session, collects the tasks, resolves the DAG, and
    writes the refined graph to the configured output path. It always terminates the
    process via :func:`sys.exit` with the exit code stored on the session.
    """
    # Phase 1: configuration. Any failure is printed and recorded on a bare session.
    try:
        plugin_manager = storage.get()
        session = Session.from_config(
            plugin_manager.hook.pytask_configure(
                pm=plugin_manager, raw_config=raw_config
            )
        )

    except (ConfigurationError, Exception):  # pragma: no cover
        console.print_exception()
        session = Session(exit_code=ExitCode.CONFIGURATION_FAILED)

    else:
        # Phase 2: collect the tasks, build the DAG, and draw it.
        try:
            session.hook.pytask_log_session_header(session=session)
            import_optional_dependency("pygraphviz")
            check_for_optional_program(
                session.config["layout"],
                extra="The layout program is part of the graphviz package which you "
                "can install with conda.",
            )
            session.hook.pytask_collect(session=session)
            session.hook.pytask_dag(session=session)
            _write_graph(
                _refine_dag(session),
                session.config["output_path"],
                session.config["layout"],
            )

        except CollectionError:  # pragma: no cover
            session.exit_code = ExitCode.COLLECTION_FAILED

        except ResolvingDependenciesError:  # pragma: no cover
            session.exit_code = ExitCode.DAG_FAILED

        except Exception:  # noqa: BLE001
            # Unexpected errors are reported with a traceback from which internal
            # frames have been removed.
            session.exit_code = ExitCode.FAILED
            trimmed_exc_info = remove_internal_traceback_frames_from_exc_info(
                sys.exc_info()
            )
            console.print()
            console.print(Traceback.from_exception(*trimmed_exc_info))
            console.rule(style="failed")

    session.hook.pytask_unconfigure(session=session)
    sys.exit(session.exit_code)
def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
    """Build the DAG.

    This function is the programmatic interface to ``pytask dag`` and returns a
    preprocessed :class:`networkx.DiGraph` whose node labels are shortened and whose
    nodes carry a ``shape`` attribute.

    To plot the graph, convert it with :func:`networkx.nx_agraph.to_agraph` and draw
    the resulting pygraphviz graph. To change the style of the graph, set attributes
    on the returned networkx graph before converting it.

    Parameters
    ----------
    raw_config : Dict[str, Any]
        The configuration usually received from the CLI. For example, use ``{"paths":
        "example-directory/"}`` to collect tasks from a directory. The dictionary
        passed by the caller is not modified.

    Returns
    -------
    networkx.DiGraph
        A preprocessed graph which can be customized and exported.

        NOTE: If configuring the session fails, the exception is printed to the
        console and the function implicitly returns ``None``.

    Raises
    ------
    Exception
        Any error raised after the session was configured successfully, for example
        while collecting tasks or resolving the DAG, is propagated unchanged.

    """
    try:
        pm = get_plugin_manager()
        storage.store(pm)

        # If someone called the programmatic interface, we need to do some parsing.
        if "command" not in raw_config:
            # Work on a copy so that the caller's dictionary is left untouched.
            raw_config = {**raw_config, "command": "dag"}

            # Add defaults from cli. The import is kept local as in the original
            # code (presumably to avoid an import cycle - TODO confirm).
            from _pytask.cli import DEFAULTS_FROM_CLI

            raw_config = {**DEFAULTS_FROM_CLI, **raw_config}

            raw_config["paths"] = parse_paths(raw_config.get("paths"))

            if raw_config["config"] is not None:
                # An explicit configuration file also determines the project root.
                raw_config["config"] = Path(raw_config["config"]).resolve()
                raw_config["root"] = raw_config["config"].parent
            else:
                if raw_config["paths"] is None:
                    raw_config["paths"] = (Path.cwd(),)

                raw_config["paths"] = parse_paths(raw_config["paths"])
                (
                    raw_config["root"],
                    raw_config["config"],
                ) = find_project_root_and_config(raw_config["paths"])

            if raw_config["config"] is not None:
                config_from_file = read_config(raw_config["config"])

                if "paths" in config_from_file:
                    # Paths from the file are joined onto the directory of the
                    # configuration file and resolved.
                    paths = config_from_file["paths"]
                    paths = [
                        raw_config["config"].parent.joinpath(path).resolve()
                        for path in to_list(paths)
                    ]
                    config_from_file["paths"] = paths

                # Values from the configuration file take precedence.
                raw_config = {**raw_config, **config_from_file}

        config = pm.hook.pytask_configure(pm=pm, raw_config=raw_config)

        session = Session.from_config(config)

    except (ConfigurationError, Exception):  # pragma: no cover
        console.print_exception()
        session = Session(exit_code=ExitCode.CONFIGURATION_FAILED)

    else:
        # Errors raised below propagate to the caller unchanged.
        session.hook.pytask_log_session_header(session=session)
        import_optional_dependency("pygraphviz")
        check_for_optional_program(
            session.config["layout"],
            extra="The layout program is part of the graphviz package that you "
            "can install with conda.",
        )
        session.hook.pytask_collect(session=session)
        session.hook.pytask_dag(session=session)
        session.hook.pytask_unconfigure(session=session)
        return _refine_dag(session)
def _refine_dag(session: Session) -> nx.DiGraph:
    """Prepare the session's DAG for plotting.

    The node labels are shortened, the node attributes are cleared, and the nodes
    are styled. The name of the configured rank direction is stored as ``rankdir``
    in the graph attributes.
    """
    refined = _style_dag(
        _clean_dag(_shorten_node_labels(session.dag, session.config["paths"]))
    )
    rank_direction = session.config["rank_direction"].name
    refined.graph["graph"] = {"rankdir": rank_direction}
    return refined
def _shorten_node_labels(dag: nx.DiGraph, paths: list[Path]) -> nx.DiGraph:
    """Relabel the nodes of the graph with shortened names.

    The shortened names come from ``reduce_names_of_multiple_nodes``. Names that are
    :class:`rich.text.Text` objects are replaced by their plain string.
    """
    original_names = dag.nodes
    reduced_names = reduce_names_of_multiple_nodes(original_names, dag, paths)
    relabelling = {
        old: new.plain if isinstance(new, Text) else new  # type: ignore[attr-defined]
        for old, new in zip(original_names, reduced_names)
    }
    return nx.relabel_nodes(dag, relabelling)
def _clean_dag(dag: nx.DiGraph) -> nx.DiGraph:
    """Remove all attributes from every node and return the same graph."""
    for attributes in dag.nodes.values():
        attributes.clear()
    return dag
def _style_dag(dag: nx.DiGraph) -> nx.DiGraph:
    """Assign a shape to every node and return the same graph.

    Nodes whose name contains ``"::task_"`` are drawn as hexagons, all other nodes
    as boxes.
    """
    shape_by_node = {}
    for node_name in dag.nodes:
        is_task = "::task_" in node_name
        shape_by_node[node_name] = "hexagon" if is_task else "box"
    nx.set_node_attributes(dag, shape_by_node, "shape")
    return dag
def _write_graph(dag: nx.DiGraph, path: Path, layout: str) -> None:
    """Draw the graph with the given layout program and store it at ``path``.

    Missing parent directories of ``path`` are created first.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    nx.nx_agraph.to_agraph(dag).draw(path, prog=layout)