diff --git a/dvc/cli.py b/dvc/cli.py index f9333aaf0b..b919762e41 100644 --- a/dvc/cli.py +++ b/dvc/cli.py @@ -823,6 +823,10 @@ def parse_args(argv=None): action='store_true', default=False, help='Output DAG as ASCII.') + pipeline_show_parser.add_argument( + '--dot', + help='Write DAG in .dot format.' + ) pipeline_show_parser.add_argument( 'targets', nargs='*', diff --git a/dvc/command/pipeline.py b/dvc/command/pipeline.py index 280b8be940..ac682b2828 100644 --- a/dvc/command/pipeline.py +++ b/dvc/command/pipeline.py @@ -24,7 +24,7 @@ def _show(self, target, commands, outs): else: self.project.logger.info(n) - def _show_ascii(self, target, commands, outs): + def __build_graph(self, target, commands, outs): import networkx from dvc.stage import Stage @@ -51,9 +51,6 @@ def _show_ascii(self, target, commands, outs): else: nodes.append(stage.relpath) - if len(nodes) == 0: - return - edges = [] for e in G.edges(): from_stage = stages[e[0]] @@ -70,9 +67,28 @@ def _show_ascii(self, target, commands, outs): else: edges.append((from_stage.relpath, to_stage.relpath)) + return nodes, edges + + def _show_ascii(self, target, commands, outs): + nodes, edges = self.__build_graph(target, commands, outs) + + if not nodes: + return + d = Dagascii(nodes, edges) d.draw() + def __write_dot(self, target, commands, outs, filename): + import networkx + from networkx.drawing.nx_pydot import write_dot + + _, edges = self.__build_graph(target, commands, outs) + edges = [edge[::-1] for edge in edges] + + simple_g = networkx.DiGraph() + simple_g.add_edges_from(edges) + write_dot(simple_g, filename) + def run(self, unlock=False): for target in self.args.targets: try: @@ -80,6 +96,11 @@ def run(self, unlock=False): self._show_ascii(target, self.args.commands, self.args.outs) + elif self.args.dot: + self.__write_dot(target, + self.args.commands, + self.args.outs, + self.args.dot) else: self._show(target, self.args.commands, diff --git a/requirements.txt b/requirements.txt index 638b4f3634..59a8f335ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ wheel>=0.31.1 futures>=3.2.0; python_version == "2.7" grandalf==0.6 asciicanvas==0.0.3 +pydot>=1.2.4 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 8a6646b031..ec1d829c8c 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -2,12 +2,14 @@ from tests.basic_env import TestDvc from tests.test_repro import TestRepro, TestReproChangedDeepData +import os class TestPipelineShowSingle(TestDvc): def setUp(self): super(TestPipelineShowSingle, self).setUp() self.stage = 'foo.dvc' + self.dotFile = 'graph.dot' ret = main(['add', self.FOO]) self.assertEqual(ret, 0) @@ -27,6 +29,11 @@ def test_ascii(self): ret = main(['pipeline', 'show', '--ascii', self.stage]) self.assertEqual(ret, 0) + def test_dot(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.stage]) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + def test_ascii_commands(self): ret = main(['pipeline', 'show', '--ascii', self.stage, '--commands']) self.assertEqual(ret, 0) @@ -35,6 +42,16 @@ def test_ascii_outs(self): ret = main(['pipeline', 'show', '--ascii', self.stage, '--outs']) self.assertEqual(ret, 0) + def test_dot_commands(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.stage, '--commands']) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + + def test_dot_outs(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.stage, '--outs']) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + def test_not_dvc_file(self): ret = main(['pipeline', 'show', self.FOO]) self.assertNotEqual(ret, 0) @@ -45,6 +62,10 @@ def test_non_existing(self): class TestPipelineShow(TestRepro): + def setUp(self): + super(TestPipelineShow, self).setUp() + self.dotFile = 'graph.dot' + def test(self): ret = main(['pipeline', 'show', self.file1_stage]) self.assertEqual(ret, 0) @@ -61,6 +82,11 @@ def test_ascii(self): ret = main(['pipeline', 'show', '--ascii', self.file1_stage]) self.assertEqual(ret, 0) + def test_dot(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.file1_stage]) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + def test_ascii_commands(self): ret = main(['pipeline', 'show', '--ascii', self.file1_stage, '--commands']) self.assertEqual(ret, 0) @@ -69,6 +95,16 @@ def test_ascii_outs(self): ret = main(['pipeline', 'show', '--ascii', self.file1_stage, '--outs']) self.assertEqual(ret, 0) + def test_dot_commands(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.file1_stage, '--commands']) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + + def test_dot_outs(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.file1_stage, '--outs']) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + def test_not_dvc_file(self): ret = main(['pipeline', 'show', self.file1]) self.assertNotEqual(ret, 0) @@ -79,6 +115,10 @@ def test_non_existing(self): class TestPipelineShowDeep(TestReproChangedDeepData): + def setUp(self): + super(TestPipelineShowDeep, self).setUp() + self.dotFile = 'graph.dot' + def test(self): ret = main(['pipeline', 'show', self.file1_stage]) self.assertEqual(ret, 0) @@ -95,6 +135,11 @@ def test_ascii(self): ret = main(['pipeline', 'show', '--ascii', self.file1_stage]) self.assertEqual(ret, 0) + def test_dot(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.file1_stage]) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + def test_ascii_commands(self): ret = main(['pipeline', 'show', '--ascii', self.file1_stage, '--commands']) self.assertEqual(ret, 0) @@ -103,6 +148,16 @@ def test_ascii_outs(self): ret = main(['pipeline', 'show', '--ascii', self.file1_stage, '--outs']) self.assertEqual(ret, 0) + def test_dot_commands(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.file1_stage, '--commands']) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + + def test_dot_outs(self): + ret = main(['pipeline', 'show', '--dot', self.dotFile, self.file1_stage, '--outs']) + self.assertEqual(ret, 0) + self.assertTrue(os.path.isfile(self.dotFile)) + def test_not_dvc_file(self): ret = main(['pipeline', 'show', self.file1]) self.assertNotEqual(ret, 0)