# ---------------------------------------------------------------------------
# dvc/command/plot.py (new file)
# ---------------------------------------------------------------------------
import argparse
import logging
import os

from dvc.command.base import append_doc_link, CmdBase
from dvc.utils import format_link

logger = logging.getLogger(__name__)


class CmdPlot(CmdBase):
    """Command object behind the ``plot`` sub-command."""

    def run(self):
        """Render the target file and log a link to the resulting page.

        Returns:
            0 once the page has been written, 1 when no target was given.
        """
        # ``target`` is declared with nargs="?", so it may legitimately be
        # missing; fail with a message instead of letting None reach open().
        if not self.args.target:
            logger.error("No target file to plot was given.")
            return 1

        # ``Repo.plot`` is ``plot(repo, template_file, revisions=None)``; it
        # has no ``template`` parameter, so forwarding the option would raise
        # TypeError on every invocation.
        if self.args.template:
            logger.warning(
                "'--template' is not supported yet and is ignored; the "
                "template is read from the plot marker in the target file."
            )

        path = self.repo.plot(self.args.target)
        page_url = "file://{}".format(os.path.join(self.repo.root_dir, path))
        logger.info(
            "You can see your plot by opening {} in your "
            "browser!".format(format_link(page_url))
        )
        return 0


def add_parser(subparsers, parent_parser):
    """Register the ``plot`` sub-command on ``subparsers``.

    Args:
        subparsers: sub-parser collection the command is added to.
        parent_parser: parser passed through as ``parents`` of the new one.
    """
    PLOT_HELP = "Visualize target metric file using {}.".format(
        format_link("https://vega.github.io")
    )

    plot_parser = subparsers.add_parser(
        "plot",
        parents=[parent_parser],
        description=append_doc_link(PLOT_HELP, "plot"),
        help=PLOT_HELP,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    plot_parser.add_argument(
        "--template", nargs="?", help="Template file to choose."
    )
    plot_parser.add_argument(
        "target", nargs="?", help="Metric files to visualize."
    )
    plot_parser.set_defaults(func=CmdPlot)


# ---------------------------------------------------------------------------
# dvc/plot.py (new file)
# ---------------------------------------------------------------------------
import json
import logging
import os

from funcy import cached_property

from dvc.exceptions import DvcException
from dvc.utils.fs import makedirs


logger = logging.getLogger(__name__)


class Template:
    """A Vega-Lite specification kept as a JSON file in a templates directory.

    Subclasses supply ``TEMPLATE_NAME`` (the file name) and
    ``DEFAULT_CONTENT`` (the specification written by :meth:`dump`).
    """

    INDENT = 4
    SEPARATORS = (",", ": ")

    # Keys every data point handed to ``fill`` must carry, no more, no less.
    REQUIRED_KEYS = frozenset({"x", "y", "revision"})

    def __init__(self, templates_dir):
        self.plot_templates_dir = templates_dir

    def dump(self):
        """Write ``DEFAULT_CONTENT`` to ``TEMPLATE_NAME`` in the templates dir.

        The directory is created when it does not exist yet.
        """
        makedirs(self.plot_templates_dir, exist_ok=True)

        target = os.path.join(self.plot_templates_dir, self.TEMPLATE_NAME)
        with open(target, "w") as fd:
            json.dump(
                self.DEFAULT_CONTENT,
                fd,
                indent=self.INDENT,
                separators=self.SEPARATORS,
            )

    def load_template(self, path):
        """Return the parsed JSON template found at ``path``.

        ``path`` is tried as given first and then relative to the templates
        directory.

        Raises:
            DvcException: if neither location holds the file.
        """
        candidates = (path, os.path.join(self.plot_templates_dir, path))
        for candidate in candidates:
            try:
                with open(candidate, "r") as fd:
                    return json.load(fd)
            except FileNotFoundError:
                continue

        raise DvcException(
            "Template '{}' was found neither as given nor in '{}'.".format(
                path, self.plot_templates_dir
            )
        )

    def fill(self, template_path, data, data_src=""):
        """Load a template and inject ``data`` and a title into it.

        Args:
            template_path: template location, resolved by ``load_template``.
            data: list of dicts, each with exactly the keys ``x``, ``y`` and
                ``revision``.
            data_src: text stored under the ``title`` key of the result.

        Returns:
            The template dict with ``data`` and ``title`` replaced.

        Raises:
            DvcException: if ``data`` does not have the shape described above.
        """
        # Explicit checks instead of ``assert``: assertions vanish under -O.
        if not isinstance(data, list):
            raise DvcException(
                "Plot data has to be a list, got '{}'.".format(
                    type(data).__name__
                )
            )
        for point in data:
            if not isinstance(point, dict) or set(point) != self.REQUIRED_KEYS:
                raise DvcException(
                    "Each plot data point needs exactly the keys {}, "
                    "got: {!r}".format(sorted(self.REQUIRED_KEYS), point)
                )

        vega_spec = self.load_template(template_path)
        vega_spec.update({"data": {"values": data}, "title": data_src})
        return vega_spec


class DefaultLinearTemplate(Template):
    """Line chart of ``y`` over ``x``, one colour per ``revision``."""

    TEMPLATE_NAME = "default.json"

    DEFAULT_CONTENT = {
        "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
        "data": {"values": []},
        "mark": {"type": "line"},
        "encoding": {
            "x": {"field": "x", "type": "quantitative"},
            "y": {"field": "y", "type": "quantitative"},
            "color": {"field": "revision", "type": "nominal"},
        },
    }


class DefaultConfusionTemplate(Template):
    """Rect chart counting (x, y) pairs, axes titled predicted / actual."""

    TEMPLATE_NAME = "cf.json"
    DEFAULT_CONTENT = {
        "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
        "data": {"values": []},
        "mark": "rect",
        "encoding": {
            "x": {
                "field": "x",
                "type": "nominal",
                "sort": "ascending",
                "title": "Predicted value",
            },
            "y": {
                "field": "y",
                "type": "nominal",
                "sort": "ascending",
                "title": "Actual value",
            },
            "color": {"aggregate": "count", "type": "quantitative"},
        },
    }


class PlotTemplates:
    """Owns the ``plot`` directory under ``dvc_dir`` and its default files."""

    TEMPLATES_DIR = "plot"
    TEMPLATES = [DefaultLinearTemplate, DefaultConfusionTemplate]

    def __init__(self, dvc_dir):
        """Create the templates directory with the defaults if it is missing.

        An existing directory is left untouched, so files already in it are
        never rewritten.
        """
        self.dvc_dir = dvc_dir

        if not os.path.exists(self.templates_dir):
            makedirs(self.templates_dir, exist_ok=True)
            for template_cls in self.TEMPLATES:
                template_cls(self.templates_dir).dump()

    @cached_property
    def templates_dir(self):
        """Path of the templates directory: ``<dvc_dir>/plot``."""
        return os.path.join(self.dvc_dir, self.TEMPLATES_DIR)


# ---------------------------------------------------------------------------
# Hunks of the same patch against files whose definitions are only partly
# visible here (summarised, not rewritten):
#   dvc/cli.py           - adds `plot` to a parenthesised import list and to
#                          the list that already holds `version`, `update`
#                          and `git_hook`.
#   dvc/repo/__init__.py - imports PlotTemplates, adds
#                          `from dvc.repo.plot import plot` inside `Repo`, and
#                          sets `self.plot_templates = PlotTemplates(
#                          self.dvc_dir)` in `Repo.__init__` after
#                          `self._ignore()`.
#   dvc/repo/init.py     - `scm.add([config.files["repo"]])` becomes
#                          `scm.add([config.files["repo"],
#                          proj.plot_templates.templates_dir])`.
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# dvc/repo/plot.py (new file)
# ---------------------------------------------------------------------------
import json
import logging
import os
import random
import re
import string

from dvc.exceptions import DvcException
from dvc.plot import Template
from dvc.repo import locked

logger = logging.getLogger(__name__)

# NOTE(review): in the text under review both HTML constants had lost their
# markup (DIV_HTML had no "{id}" / "{vega_json}" placeholders at all), so the
# generated page could never contain a spec. The markup below is rebuilt
# around the surviving pieces ("dvc plot" title, "{divs}" slot) and the
# vega-embed usage pattern; confirm the script versions before release.
PAGE_HTML = """<html>
<head>
    <title>dvc plot</title>
    <script src="https://cdn.jsdelivr.net/npm/vega@5"></script>
    <script src="https://cdn.jsdelivr.net/npm/vega-lite@4"></script>
    <script src="https://cdn.jsdelivr.net/npm/vega-embed@6"></script>
</head>
<body>
    {divs}
</body>
</html>"""

DIV_HTML = """<div id="{id}"></div>
<script type="text/javascript">
    var spec = {vega_json};
    vegaEmbed('#{id}', spec);
</script>"""

# NOTE(review): the original pattern literal was empty, which matches the
# empty string at every position and makes every marker fail to parse. This
# pattern accepts any "<NAME::arg[::arg]>" token, which is the shape
# `_parse_plot_str` takes apart; confirm the intended marker name.
PLOT_MARKER_RE = re.compile(r"<[^<>]+?::[^<>]+>")

DEFAULT_TEMPLATE = "default.json"


def _save_plot_html(divs, path):
    """Write a page to ``path`` holding ``divs`` joined by newlines."""
    page = PAGE_HTML.format(divs="\n".join(divs))
    with open(path, "w") as fobj:
        fobj.write(page)


def _prepare_div(vega_dict):
    """Return the HTML snippet that embeds ``vega_dict`` under a random id."""
    # Eight distinct lowercase letters: always starts with a letter, so the
    # value is usable in the '#id' selector handed to vegaEmbed.
    div_id = "".join(random.sample(string.ascii_lowercase, 8))
    vega_json = json.dumps(vega_dict, indent=4, separators=(",", ": "))
    # A "</" inside a string value would end the inline <script> early;
    # "<\/" is the same character sequence to a JSON/JS parser.
    vega_json = vega_json.replace("</", "<\\/")
    return DIV_HTML.format(id=div_id, vega_json=vega_json)


def _load_data(tree, target, revision="current workspace"):
    """Load JSON from ``target`` via ``tree`` and tag each item with a revision.

    Every element of the loaded data gets ``revision`` stored under the
    ``"revision"`` key.
    """
    with tree.open(target, "r") as fobj:
        data = json.load(fobj)
    for point in data:
        point["revision"] = revision
    return data


def _parse_plots(path):
    """Return ``(is_html, markers)`` for the file at ``path``.

    ``markers`` lists every plot marker found in the file, in order.
    ``is_html`` is always False for now.
    """
    with open(path, "r") as fobj:
        content = fobj.read()

    return False, PLOT_MARKER_RE.findall(content)


def _parse_plot_str(plot_str):
    """Split a plot marker into ``(datafile, templatefile)``.

    Angle brackets are removed, the text is split on ``::`` and the first
    field is dropped. One remaining field is the data file (the template
    defaults to ``default.json``); two are data file and template.

    Raises:
        DvcException: for any other number of fields.
    """
    content = plot_str.replace("<", "").replace(">", "")
    args = content.split("::")[1:]
    if len(args) == 2:
        return tuple(args)
    if len(args) == 1:
        return args[0], DEFAULT_TEMPLATE
    raise DvcException(
        "Error parsing plot marker '{}': expected a data file and an "
        "optional template separated by '::'.".format(plot_str)
    )


def to_div(repo, plot_str):
    """Turn one plot marker into an HTML snippet with the filled template."""
    datafile, templatefile = _parse_plot_str(plot_str)

    data = _load_data(repo.tree, datafile)
    vega_plot_json = Template(repo.plot_templates.templates_dir).fill(
        templatefile, data, datafile
    )
    return _prepare_div(vega_plot_json)


@locked
def plot(repo, template_file, revisions=None):
    """Render every plot marker of ``template_file`` into an HTML page.

    Args:
        repo: repository providing ``tree`` and ``plot_templates``.
        template_file: path of the file containing the plot markers.
        revisions: accepted but not used by this implementation.

    Returns:
        Path of the written page: ``template_file`` with its extension
        replaced by ``.html``.

    Raises:
        NotImplementedError: if the input is reported as HTML.
        DvcException: if the output path would be the input file itself.
    """
    is_html, plot_strings = _parse_plots(template_file)
    if is_html:
        raise NotImplementedError

    # Swap only the extension. A plain str.replace(".dvct", ".html") rewrites
    # every occurrence in the path, and returns the input path unchanged when
    # the suffix is absent, which would overwrite the template.
    root, _ = os.path.splitext(template_file)
    result = root + ".html"
    if result == template_file:
        raise DvcException(
            "Refusing to overwrite the plot template '{}' with its own "
            "output.".format(template_file)
        )

    # One div per marker occurrence, so a repeated marker still gets a
    # distinct element id.
    divs = [to_div(repo, plot_str) for plot_str in plot_strings]
    _save_plot_html(divs, result)
    return result


# ---------------------------------------------------------------------------
# Hunks of the same patch against files whose definitions are only partly
# visible here (summarised, not rewritten):
#   dvc/repo/init.py - unchanged context only:
#                      `scm.add([os.path.join(dvc_dir, scm.ignore_file)])`
#                      guarded by `scm.ignore_file`.
#   setup.py         - adds "beautifulsoup4==4.4.0" to the list that already
#                      holds "rangehttpserver==1.2.0".
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# tests/func/test_plot.py (new file)
# ---------------------------------------------------------------------------
import json

from bs4 import BeautifulSoup
from funcy import first

# NOTE(review): both template bodies were empty strings in the text under
# review, so no plot was ever requested. The marker shape follows what
# `_parse_plot_str` in dvc/repo/plot.py takes apart ("<NAME::data[::tpl]>");
# the NAME part is an assumption to be confirmed.
PLOT_MARKER = "<DVC_PLOT::{}>"


def _run_with_metric(tmp_dir, dvc, metric, metric_filename, commit=None):
    """Write ``metric`` as JSON, register it as a metric and add it to SCM.

    When ``dvc`` has an ``scm`` attribute the metric file and its ``.dvc``
    file are added, and committed with message ``commit`` if one is given.
    """
    tmp_dir.gen({metric_filename: json.dumps(metric)})
    dvc.run(metrics_no_cache=[metric_filename])
    if hasattr(dvc, "scm"):
        dvc.scm.add([metric_filename, metric_filename + ".dvc"])
        if commit:
            dvc.scm.commit(commit)


def _spec_json(spec):
    """Serialise ``spec`` with indent 4 and ``(",", ": ")`` separators."""
    return json.dumps(spec, indent=4, separators=(",", ": "))


def _first_body_script(page_path):
    """Return the first content node of the first <script> in the page body."""
    # Name the parser explicitly: without it bs4 picks whichever one happens
    # to be installed and emits a warning.
    page = BeautifulSoup(page_path.read_text(), "html.parser")
    return first(page.body.script.contents)


# TODO: cover plot markers embedded in an existing HTML file.
def test_plot_in_html_file(tmp_dir):
    pass


def test_plot_in_no_html(tmp_dir, scm, dvc):
    metric = [{"x": 1, "y": 2}, {"x": 2, "y": 3}]
    _run_with_metric(tmp_dir, dvc, metric, "metric.json", "first run")

    template_content = PLOT_MARKER.format("metric.json")
    (tmp_dir / "template.dvct").write_text(template_content)

    result = dvc.plot("template.dvct")

    expected = {
        "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
        "data": {
            "values": [
                {"x": 1, "y": 2, "revision": "current workspace"},
                {"x": 2, "y": 3, "revision": "current workspace"},
            ]
        },
        "mark": {"type": "line"},
        "encoding": {
            "x": {"field": "x", "type": "quantitative"},
            "y": {"field": "y", "type": "quantitative"},
            "color": {"field": "revision", "type": "nominal"},
        },
        "title": "metric.json",
    }
    assert _spec_json(expected) in _first_body_script(tmp_dir / result)


def test_plot_confusion(tmp_dir, dvc):
    confusion_matrix = [{"x": "B", "y": "A"}, {"x": "A", "y": "A"}]
    _run_with_metric(
        tmp_dir, dvc, confusion_matrix, "metric.json", "first run"
    )

    # The expected spec below is the "cf.json" template, so the marker has to
    # name it explicitly; a bare data file would select "default.json".
    template_content = PLOT_MARKER.format("metric.json::cf.json")
    (tmp_dir / "template.dvct").write_text(template_content)

    result = dvc.plot("template.dvct")

    expected = {
        "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
        "data": {
            "values": [
                {"x": "B", "y": "A", "revision": "current workspace"},
                {"x": "A", "y": "A", "revision": "current workspace"},
            ]
        },
        "mark": "rect",
        "encoding": {
            "x": {
                "field": "x",
                "type": "nominal",
                "sort": "ascending",
                "title": "Predicted value",
            },
            "y": {
                "field": "y",
                "type": "nominal",
                "sort": "ascending",
                "title": "Actual value",
            },
            "color": {"aggregate": "count", "type": "quantitative"},
        },
        "title": "metric.json",
    }
    assert _spec_json(expected) in _first_body_script(tmp_dir / result)