From ed074f9c83979a6447ba1e42f61ca30f0833f907 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Redzy=C5=84ski?= Date: Tue, 7 Dec 2021 11:30:18 +0100 Subject: [PATCH] plots: make templates generation explicit Fixes: #6389 --- dvc/commands/plots.py | 65 +++++++++++++++++++++++-- dvc/repo/init.py | 3 -- dvc/repo/plots/template.py | 50 +++++++++++++++---- tests/unit/command/test_plots.py | 38 ++++++++++++++- tests/unit/repo/plots/test_templates.py | 40 ++++++++++++--- 5 files changed, 173 insertions(+), 23 deletions(-) diff --git a/dvc/commands/plots.py b/dvc/commands/plots.py index d73dd31f21..13aa9f5c33 100644 --- a/dvc/commands/plots.py +++ b/dvc/commands/plots.py @@ -140,6 +140,37 @@ def run(self): return 0 +class CmdPlotsTemplates(CmdBase): + TEMPLATES_CHOICES = [ + "simple", + "linear", + "confusion", + "confusion_normalized", + "scatter", + "smooth", + ] + + def run(self): + import os + + try: + out = ( + os.path.join(os.getcwd(), self.args.out) + if self.args.out + else self.repo.plots.templates.templates_dir + ) + + targets = [self.args.target] if self.args.target else None + self.repo.plots.templates.init(output=out, targets=targets) + templates_path = os.path.relpath(out, os.getcwd()) + ui.write(f"Templates have been written into '{templates_path}'.") + + return 0 + except DvcException: + logger.exception("") + return 1 + + def add_parser(subparsers, parent_parser): PLOTS_HELP = ( "Commands to visualize and compare plot metrics in structured files " @@ -176,7 +207,8 @@ def add_parser(subparsers, parent_parser): "Shows all plots by default.", ).complete = completion.FILE _add_props_arguments(plots_show_parser) - _add_output_arguments(plots_show_parser) + _add_output_argument(plots_show_parser) + _add_ui_arguments(plots_show_parser) plots_show_parser.set_defaults(func=CmdPlotsShow) PLOTS_DIFF_HELP = ( @@ -211,7 +243,8 @@ def add_parser(subparsers, parent_parser): "revisions", nargs="*", default=None, help="Git commits to plot from" ) _add_props_arguments(plots_diff_parser) - _add_output_arguments(plots_diff_parser) + _add_output_argument(plots_diff_parser) + _add_ui_arguments(plots_diff_parser) plots_diff_parser.set_defaults(func=CmdPlotsDiff) PLOTS_MODIFY_HELP = ( @@ -237,6 +270,27 @@ def add_parser(subparsers, parent_parser): ) plots_modify_parser.set_defaults(func=CmdPlotsModify) + TEMPLATES_HELP = ( + "Write built-in plots templates to a directory " + "(.dvc/plots by default)." + ) + plots_templates_parser = plots_subparsers.add_parser( + "templates", + parents=[parent_parser], + description=append_doc_link(TEMPLATES_HELP, "plots/templates"), + help=TEMPLATES_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + plots_templates_parser.add_argument( + "target", + default=None, + nargs="?", + choices=CmdPlotsTemplates.TEMPLATES_CHOICES, + help="Template to write. Writes all templates by default.", + ) + _add_output_argument(plots_templates_parser, typ="templates") + plots_templates_parser.set_defaults(func=CmdPlotsTemplates) + def _add_props_arguments(parser): parser.add_argument( @@ -276,14 +330,17 @@ def _add_props_arguments(parser): ) -def _add_output_arguments(parser): +def _add_output_argument(parser, typ="plots"): parser.add_argument( "-o", "--out", default=None, - help="Destination path to save plots to", + help=f"Directory to save {typ} to.", metavar="", ).complete = completion.DIR + + +def _add_ui_arguments(parser): parser.add_argument( "--show-vega", action="store_true", diff --git a/dvc/repo/init.py b/dvc/repo/init.py index 1bed86a460..cb7bc624ab 100644 --- a/dvc/repo/init.py +++ b/dvc/repo/init.py @@ -75,13 +75,10 @@ def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False): proj = Repo(root_dir) - proj.plots.templates.init() - with proj.scm_context(autostage=True) as context: files = [ config.files["repo"], dvcignore, - proj.plots.templates.templates_dir, ] ignore_file = context.scm.ignore_file if ignore_file: diff --git a/dvc/repo/plots/template.py b/dvc/repo/plots/template.py index 69b94e43f1..e8b8509685 100644 --- a/dvc/repo/plots/template.py +++ b/dvc/repo/plots/template.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from funcy import cached_property @@ -19,6 +19,15 @@ def __init__(self, field_name): ) +class TemplateContentDoesNotMatch(DvcException): + def __init__(self, template_name: str, path: str): + super().__init__( + f"Template '{path}' already exists " + f"and its content is different than '{template_name}' content. " + "Remove it manually if you want to recreate it." + ) + + class Template: INDENT = 4 SEPARATORS = (",", ": ") @@ -533,17 +542,40 @@ def get_template(self, template_name: str): def __init__(self, dvc_dir): self.dvc_dir = dvc_dir - def init(self): + def init( + self, output: Optional[str] = None, targets: Optional[List] = None + ): from dvc.utils.fs import makedirs - makedirs(self.templates_dir, exist_ok=True) - for t in self.TEMPLATES: - self._dump(t()) + output = output or self.templates_dir - def _dump(self, template): - path = os.path.join(self.templates_dir, template.filename) - with open(path, "w", encoding="utf-8") as fd: - fd.write(template.content) + makedirs(output, exist_ok=True) + + if targets: + templates = [ + template + for template in self.TEMPLATES + if template.DEFAULT_NAME in targets + ] + else: + templates = self.TEMPLATES + + for template in templates: + self._dump(template(), output) + + def _dump(self, template: Template, output: str): + path = os.path.join(output, template.filename) + + if os.path.exists(path): + with open(path, "r", encoding="utf-8") as fd: + content = fd.read() + if content != template.content: + raise TemplateContentDoesNotMatch( + template.DEFAULT_NAME or "", path + ) + else: + with open(path, "w", encoding="utf-8") as fd: + fd.write(template.content) def load(self, template_name=None): if not template_name: diff --git a/tests/unit/command/test_plots.py b/tests/unit/command/test_plots.py index a9454329b4..6b855b184f 100644 --- a/tests/unit/command/test_plots.py +++ b/tests/unit/command/test_plots.py @@ -4,9 +4,10 @@ from pathlib import Path import pytest +from funcy import pluck_attr from dvc.cli import parse_args -from dvc.commands.plots import CmdPlotsDiff, CmdPlotsShow +from dvc.commands.plots import CmdPlotsDiff, CmdPlotsShow, CmdPlotsTemplates @pytest.fixture @@ -342,3 +343,38 @@ def test_show_json_requires_out(dvc, mocker, capsys): ) cmd = cli_args.func(cli_args) assert cmd.run() == 0 + + +@pytest.mark.parametrize("target", (("t1"), (None))) +def test_plots_templates(tmp_dir, dvc, mocker, capsys, target): + assert not os.path.exists(dvc.plots.templates.templates_dir) + mocker.patch( + "dvc.commands.plots.CmdPlotsTemplates.TEMPLATES_CHOICES", + ["t1", "t2"], + ) + + arguments = ["plots", "templates", "--out", "output"] + if target: + arguments += [target] + + cli_args = parse_args(arguments) + assert cli_args.func == CmdPlotsTemplates + + init_mock = mocker.patch("dvc.repo.plots.template.PlotTemplates.init") + cmd = cli_args.func(cli_args) + + assert cmd.run() == 0 + out, _ = capsys.readouterr() + + init_mock.assert_called_once_with( + output=os.path.abspath("output"), targets=[target] if target else None + ) + assert "Templates have been written into 'output'." in out + + +def test_plots_templates_choices(tmp_dir, dvc): + from dvc.repo.plots.template import PlotTemplates + + assert CmdPlotsTemplates.TEMPLATES_CHOICES == list( + pluck_attr("DEFAULT_NAME", PlotTemplates.TEMPLATES) + ) diff --git a/tests/unit/repo/plots/test_templates.py b/tests/unit/repo/plots/test_templates.py index 6a54dfa688..7bc9ffba40 100644 --- a/tests/unit/repo/plots/test_templates.py +++ b/tests/unit/repo/plots/test_templates.py @@ -2,7 +2,13 @@ import pytest -from dvc.repo.plots.template import TemplateNotFoundError +from dvc.repo.plots.template import ( + LinearTemplate, + PlotTemplates, + ScatterTemplate, + TemplateContentDoesNotMatch, + TemplateNotFoundError, +) def test_raise_on_no_template(tmp_dir, dvc): @@ -35,11 +41,33 @@ def test_load_template(tmp_dir, dvc, template_path, target_name): def test_load_default_template(tmp_dir, dvc): + assert dvc.plots.templates.load(None).content == LinearTemplate().content + + +@pytest.mark.parametrize("output", ("output", None)) +@pytest.mark.parametrize( + "targets,expected_templates", + ( + ([None, PlotTemplates.TEMPLATES]), + (["linear", "scatter"], [ScatterTemplate, LinearTemplate]), + ), +) +def test_init(tmp_dir, dvc, output, targets, expected_templates): + output = output or dvc.plots.templates.templates_dir + dvc.plots.templates.init(output, targets) + + assert set(os.listdir(output)) == { + cls.DEFAULT_NAME + ".json" for cls in expected_templates + } + + +def test_raise_on_init_modified(tmp_dir, dvc): + dvc.plots.templates.init(output=None, targets=["linear"]) + with open( - os.path.join(dvc.plots.templates.templates_dir, "linear.json"), - "r", - encoding="utf-8", + tmp_dir / ".dvc" / "plots" / "linear.json", "a", encoding="utf-8" ) as fd: - content = fd.read() + fd.write("modification") - assert dvc.plots.templates.load(None).content == content + with pytest.raises(TemplateContentDoesNotMatch): + dvc.plots.templates.init(output=None, targets=["linear"])