Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 61 additions & 4 deletions dvc/commands/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,37 @@ def run(self):
return 0


class CmdPlotsTemplates(CmdBase):
TEMPLATES_CHOICES = [
"simple",
"linear",
"confusion",
"confusion_normalized",
"scatter",
"smooth",
]

def run(self):
import os

try:
out = (
os.path.join(os.getcwd(), self.args.out)
if self.args.out
else self.repo.plots.templates.templates_dir
)

targets = [self.args.target] if self.args.target else None
self.repo.plots.templates.init(output=out, targets=targets)
templates_path = os.path.relpath(out, os.getcwd())
ui.write(f"Templates have been written into '{templates_path}'.")

return 0
except DvcException:
logger.exception("")
return 1


def add_parser(subparsers, parent_parser):
PLOTS_HELP = (
"Commands to visualize and compare plot metrics in structured files "
Expand Down Expand Up @@ -176,7 +207,8 @@ def add_parser(subparsers, parent_parser):
"Shows all plots by default.",
).complete = completion.FILE
_add_props_arguments(plots_show_parser)
_add_output_arguments(plots_show_parser)
_add_output_argument(plots_show_parser)
_add_ui_arguments(plots_show_parser)
plots_show_parser.set_defaults(func=CmdPlotsShow)

PLOTS_DIFF_HELP = (
Expand Down Expand Up @@ -211,7 +243,8 @@ def add_parser(subparsers, parent_parser):
"revisions", nargs="*", default=None, help="Git commits to plot from"
)
_add_props_arguments(plots_diff_parser)
_add_output_arguments(plots_diff_parser)
_add_output_argument(plots_diff_parser)
_add_ui_arguments(plots_diff_parser)
plots_diff_parser.set_defaults(func=CmdPlotsDiff)

PLOTS_MODIFY_HELP = (
Expand All @@ -237,6 +270,27 @@ def add_parser(subparsers, parent_parser):
)
plots_modify_parser.set_defaults(func=CmdPlotsModify)

TEMPLATES_HELP = (
"Write built-in plots templates to a directory "
"(.dvc/plots by default)."
)
plots_templates_parser = plots_subparsers.add_parser(
"templates",
parents=[parent_parser],
description=append_doc_link(TEMPLATES_HELP, "plots/templates"),
help=TEMPLATES_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
plots_templates_parser.add_argument(
"target",
default=None,
nargs="?",
choices=CmdPlotsTemplates.TEMPLATES_CHOICES,
help="Template to write. Writes all templates by default.",
)
_add_output_argument(plots_templates_parser, typ="templates")
plots_templates_parser.set_defaults(func=CmdPlotsTemplates)


def _add_props_arguments(parser):
parser.add_argument(
Expand Down Expand Up @@ -276,14 +330,17 @@ def _add_props_arguments(parser):
)


def _add_output_arguments(parser):
def _add_output_argument(parser, typ="plots"):
parser.add_argument(
"-o",
"--out",
default=None,
help="Destination path to save plots to",
help=f"Directory to save {typ} to.",
metavar="<path>",
).complete = completion.DIR


def _add_ui_arguments(parser):
parser.add_argument(
"--show-vega",
action="store_true",
Expand Down
3 changes: 0 additions & 3 deletions dvc/repo/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,10 @@ def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False):

proj = Repo(root_dir)

proj.plots.templates.init()

with proj.scm_context(autostage=True) as context:
files = [
config.files["repo"],
dvcignore,
proj.plots.templates.templates_dir,
]
ignore_file = context.scm.ignore_file
if ignore_file:
Expand Down
50 changes: 41 additions & 9 deletions dvc/repo/plots/template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from funcy import cached_property

Expand All @@ -19,6 +19,15 @@ def __init__(self, field_name):
)


class TemplateContentDoesNotMatch(DvcException):
def __init__(self, template_name: str, path: str):
super().__init__(
f"Template '{path}' already exists "
f"and its content is different than '{template_name}' content. "
"Remove it manually if you want to recreate it."
)


class Template:
INDENT = 4
SEPARATORS = (",", ": ")
Expand Down Expand Up @@ -533,17 +542,40 @@ def get_template(self, template_name: str):
def __init__(self, dvc_dir):
self.dvc_dir = dvc_dir

def init(self):
def init(
self, output: Optional[str] = None, targets: Optional[List] = None
):
from dvc.utils.fs import makedirs

makedirs(self.templates_dir, exist_ok=True)
for t in self.TEMPLATES:
self._dump(t())
output = output or self.templates_dir

def _dump(self, template):
path = os.path.join(self.templates_dir, template.filename)
with open(path, "w", encoding="utf-8") as fd:
fd.write(template.content)
makedirs(output, exist_ok=True)

if targets:
templates = [
template
for template in self.TEMPLATES
if template.DEFAULT_NAME in targets
]
else:
templates = self.TEMPLATES

for template in templates:
self._dump(template(), output)

def _dump(self, template: Template, output: str):
path = os.path.join(output, template.filename)

if os.path.exists(path):
with open(path, "r", encoding="utf-8") as fd:
content = fd.read()
if content != template.content:
raise TemplateContentDoesNotMatch(
template.DEFAULT_NAME or "", path
)
else:
with open(path, "w", encoding="utf-8") as fd:
fd.write(template.content)

def load(self, template_name=None):
if not template_name:
Expand Down
38 changes: 37 additions & 1 deletion tests/unit/command/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from pathlib import Path

import pytest
from funcy import pluck_attr

from dvc.cli import parse_args
from dvc.commands.plots import CmdPlotsDiff, CmdPlotsShow
from dvc.commands.plots import CmdPlotsDiff, CmdPlotsShow, CmdPlotsTemplates


@pytest.fixture
Expand Down Expand Up @@ -342,3 +343,38 @@ def test_show_json_requires_out(dvc, mocker, capsys):
)
cmd = cli_args.func(cli_args)
assert cmd.run() == 0


@pytest.mark.parametrize("target", (("t1"), (None)))
def test_plots_templates(tmp_dir, dvc, mocker, capsys, target):
assert not os.path.exists(dvc.plots.templates.templates_dir)
mocker.patch(
"dvc.commands.plots.CmdPlotsTemplates.TEMPLATES_CHOICES",
["t1", "t2"],
)

arguments = ["plots", "templates", "--out", "output"]
if target:
arguments += [target]

cli_args = parse_args(arguments)
assert cli_args.func == CmdPlotsTemplates

init_mock = mocker.patch("dvc.repo.plots.template.PlotTemplates.init")
cmd = cli_args.func(cli_args)

assert cmd.run() == 0
out, _ = capsys.readouterr()

init_mock.assert_called_once_with(
output=os.path.abspath("output"), targets=[target] if target else None
)
assert "Templates have been written into 'output'." in out


def test_plots_templates_choices(tmp_dir, dvc):
from dvc.repo.plots.template import PlotTemplates

assert CmdPlotsTemplates.TEMPLATES_CHOICES == list(
pluck_attr("DEFAULT_NAME", PlotTemplates.TEMPLATES)
)
Comment on lines +375 to +380
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to add the names of the templates manually to avoid importing plots on every cmd use. This test has been created to make sure plots module and templates command are in sync.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, nice idea!

40 changes: 34 additions & 6 deletions tests/unit/repo/plots/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import pytest

from dvc.repo.plots.template import TemplateNotFoundError
from dvc.repo.plots.template import (
LinearTemplate,
PlotTemplates,
ScatterTemplate,
TemplateContentDoesNotMatch,
TemplateNotFoundError,
)


def test_raise_on_no_template(tmp_dir, dvc):
Expand Down Expand Up @@ -35,11 +41,33 @@ def test_load_template(tmp_dir, dvc, template_path, target_name):


def test_load_default_template(tmp_dir, dvc):
assert dvc.plots.templates.load(None).content == LinearTemplate().content


@pytest.mark.parametrize("output", ("output", None))
@pytest.mark.parametrize(
"targets,expected_templates",
(
([None, PlotTemplates.TEMPLATES]),
(["linear", "scatter"], [ScatterTemplate, LinearTemplate]),
),
)
def test_init(tmp_dir, dvc, output, targets, expected_templates):
output = output or dvc.plots.templates.templates_dir
dvc.plots.templates.init(output, targets)

assert set(os.listdir(output)) == {
cls.DEFAULT_NAME + ".json" for cls in expected_templates
}


def test_raise_on_init_modified(tmp_dir, dvc):
dvc.plots.templates.init(output=None, targets=["linear"])

with open(
os.path.join(dvc.plots.templates.templates_dir, "linear.json"),
"r",
encoding="utf-8",
tmp_dir / ".dvc" / "plots" / "linear.json", "a", encoding="utf-8"
) as fd:
content = fd.read()
fd.write("modification")

assert dvc.plots.templates.load(None).content == content
with pytest.raises(TemplateContentDoesNotMatch):
dvc.plots.templates.init(output=None, targets=["linear"])