diff --git a/dvc/cli.py b/dvc/cli.py
index 0e4a22a25c..98521646a1 100644
--- a/dvc/cli.py
+++ b/dvc/cli.py
@@ -35,6 +35,7 @@
update,
version,
git_hook,
+ plot,
)
from .command.base import fix_subparsers
from .exceptions import DvcParserError
@@ -74,6 +75,7 @@
version,
update,
git_hook,
+ plot,
]
diff --git a/dvc/command/plot.py b/dvc/command/plot.py
new file mode 100644
index 0000000000..b507e73063
--- /dev/null
+++ b/dvc/command/plot.py
@@ -0,0 +1,243 @@
+import argparse
+import logging
+import os
+
+from dvc.command.base import append_doc_link, CmdBase, fix_subparsers
+from dvc.exceptions import DvcException
+from dvc.repo.plot.data import WORKSPACE_REVISION_NAME
+
+logger = logging.getLogger(__name__)
+
+
+class CmdPlot(CmdBase):
+ def _revisions(self):
+ raise NotImplementedError
+
+ def _result_file(self):
+ if self.args.file:
+ return self.args.file
+
+ extension = self._result_extension()
+ base = self._result_basename()
+
+ result_file = base + extension
+ return result_file
+
+ def _result_basename(self):
+ if self.args.datafile:
+ return self.args.datafile
+ return "plot"
+
+ def _result_extension(self):
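+ # HTML is the default output; with --no-html, reuse the template's
+ # extension if one was given, otherwise fall back to plain Vega JSON.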
+ if not self.args.no_html:
+ return ".html"
+ elif self.args.template:
+ return os.path.splitext(self.args.template)[-1]
+ return ".json"
+
+ def run(self):
+ fields = None
+ jsonpath = None
+ if self.args.select:
+ if self.args.select.startswith("$"):
+ jsonpath = self.args.select
+ else:
+ fields = set(self.args.select.split(","))
+ try:
+ plot_string = self.repo.plot(
+ datafile=self.args.datafile,
+ template=self.args.template,
+ revisions=self._revisions(),
+ fields=fields,
+ x_field=self.args.x,
+ y_field=self.args.y,
+ path=jsonpath,
+ embed=not self.args.no_html,
+ csv_header=not self.args.no_csv_header,
+ title=self.args.title,
+ x_title=self.args.xlab,
+ y_title=self.args.ylab,
+ )
+
+ if self.args.stdout:
+ logger.info(plot_string)
+ else:
+ result_path = self._result_file()
+ with open(result_path, "w") as fobj:
+ fobj.write(plot_string)
+
+ logger.info(
+ "file://{}".format(
+ os.path.join(self.repo.root_dir, result_path)
+ )
+ )
+
+ except DvcException:
+ logger.exception("")
+ return 1
+
+ return 0
+
+
+class CmdPlotShow(CmdPlot):
+ def _revisions(self):
+ return None
+
+
+class CmdPlotDiff(CmdPlot):
+ def _revisions(self):
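+ # With no revisions given, plot the workspace alone, or HEAD against
+ # the workspace if there are uncommitted changes; a single given
+ # revision is always compared against the workspace.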
+ revisions = self.args.revisions or []
+ if len(revisions) <= 1:
+ if len(revisions) == 0 and self.repo.scm.is_dirty():
+ revisions.append("HEAD")
+ revisions.append(WORKSPACE_REVISION_NAME)
+ return revisions
+
+
+def add_parser(subparsers, parent_parser):
+ PLOT_HELP = (
+ "Generate plots for continuous metrics stored in structured files "
+ "(JSON, YAML, CSV, TSV)."
+ )
+
+ plot_parser = subparsers.add_parser(
+ "plot",
+ parents=[parent_parser],
+ description=append_doc_link(PLOT_HELP, "plot"),
+ help=PLOT_HELP,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ plot_subparsers = plot_parser.add_subparsers(
+ dest="cmd",
+ help="Use `dvc plot CMD --help` to display command-specific help.",
+ )
+
+ fix_subparsers(plot_subparsers)
+
+ SHOW_HELP = "Generate a plot file from a continuous metrics file."
+ plot_show_parser = plot_subparsers.add_parser(
+ "show",
+ parents=[parent_parser],
+ description=append_doc_link(SHOW_HELP, "plot/show"),
+ help=SHOW_HELP,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ plot_show_parser.add_argument(
+ "-t",
+ "--template",
+ nargs="?",
+ default=None,
+ help="File to be injected with data.",
+ )
+ plot_show_parser.add_argument(
+ "-f", "--file", default=None, help="Name of the generated file."
+ )
+ plot_show_parser.add_argument(
+ "-s",
+ "--select",
+ default=None,
+ help="Choose which field(s) or JSONPath to include in the plot.",
+ )
+ plot_show_parser.add_argument(
+ "-x", default=None, help="Field name for x axis."
+ )
+ plot_show_parser.add_argument(
+ "-y", default=None, help="Field name for y axis."
+ )
+ plot_show_parser.add_argument(
+ "--stdout",
+ action="store_true",
+ default=False,
+ help="Print plot specification to stdout.",
+ )
+ plot_show_parser.add_argument(
+ "--no-csv-header",
+ action="store_true",
+ default=False,
+ help="Required when CSV or TSV datafile does not have a header.",
+ )
+ plot_show_parser.add_argument(
+ "--no-html",
+ action="store_true",
+ default=False,
+ help="Do not wrap Vega plot JSON with HTML.",
+ )
+ plot_show_parser.add_argument("--title", default=None, help="Plot title.")
+ plot_show_parser.add_argument("--xlab", default=None, help="X axis title.")
+ plot_show_parser.add_argument("--ylab", default=None, help="Y axis title.")
+ plot_show_parser.add_argument(
+ "datafile",
+ nargs="?",
+ default=None,
+ help="Continuous metrics file to visualize.",
+ )
+ plot_show_parser.set_defaults(func=CmdPlotShow)
+
+ PLOT_DIFF_HELP = (
+ "Plot differences in continuous metrics between commits in the DVC "
+ "repository, or between the last commit and the workspace."
+ )
+ plot_diff_parser = plot_subparsers.add_parser(
+ "diff",
+ parents=[parent_parser],
+ description=append_doc_link(PLOT_DIFF_HELP, "plot/diff"),
+ help=PLOT_DIFF_HELP,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ plot_diff_parser.add_argument(
+ "-t",
+ "--template",
+ nargs="?",
+ default=None,
+ help="File to be injected with data.",
+ )
+ plot_diff_parser.add_argument(
+ "-d",
+ "--datafile",
+ nargs="?",
+ default=None,
+ help="Continuous metrics file to visualize.",
+ )
+ plot_diff_parser.add_argument(
+ "-f", "--file", default=None, help="Name of the generated file."
+ )
+ plot_diff_parser.add_argument(
+ "-s",
+ "--select",
+ default=None,
+ help="Choose which field(s) or JSONPath to include in the plot.",
+ )
+ plot_diff_parser.add_argument(
+ "-x", default=None, help="Field name for x axis."
+ )
+ plot_diff_parser.add_argument(
+ "-y", default=None, help="Field name for y axis."
+ )
+ plot_diff_parser.add_argument(
+ "--stdout",
+ action="store_true",
+ default=False,
+ help="Print plot specification to stdout.",
+ )
+ plot_diff_parser.add_argument(
+ "--no-csv-header",
+ action="store_true",
+ default=False,
+ help="Provided CSV ot TSV datafile does not have a header.",
+ )
+ plot_diff_parser.add_argument(
+ "--no-html",
+ action="store_true",
+ default=False,
+ help="Do not wrap Vega plot JSON with HTML.",
+ )
+ plot_diff_parser.add_argument("--title", default=None, help="Plot title.")
+ plot_diff_parser.add_argument("--xlab", default=None, help="X axis title.")
+ plot_diff_parser.add_argument("--ylab", default=None, help="Y axis title.")
+ plot_diff_parser.add_argument(
+ "revisions",
+ nargs="*",
+ default=None,
+ help="Git revisions to plot from",
+ )
+ plot_diff_parser.set_defaults(func=CmdPlotDiff)
diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py
index 33431fa4f0..b9c38f575c 100644
--- a/dvc/repo/__init__.py
+++ b/dvc/repo/__init__.py
@@ -61,6 +61,7 @@ class Repo(object):
from dvc.repo.get import get
from dvc.repo.get_url import get_url
from dvc.repo.update import update
+ from dvc.repo.plot import plot
def __init__(self, root_dir=None):
from dvc.state import State
@@ -426,6 +427,12 @@ def stages(self):
"""
return self._collect_stages()
+ @cached_property
+ def plot_templates(self):
+ from dvc.repo.plot.template import PlotTemplates
+
+ return PlotTemplates(self.dvc_dir)
+
def _collect_stages(self):
from dvc.dvcfile import Dvcfile, is_valid_filename
diff --git a/dvc/repo/init.py b/dvc/repo/init.py
index 3238bb8e94..dda339bbd7 100644
--- a/dvc/repo/init.py
+++ b/dvc/repo/init.py
@@ -102,7 +102,7 @@ def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False):
proj = Repo(root_dir)
- scm.add([config.files["repo"]])
+ scm.add([config.files["repo"], proj.plot_templates.templates_dir])
if scm.ignore_file:
scm.add([os.path.join(dvc_dir, scm.ignore_file)])
diff --git a/dvc/repo/plot/__init__.py b/dvc/repo/plot/__init__.py
new file mode 100644
index 0000000000..92a1376e31
--- /dev/null
+++ b/dvc/repo/plot/__init__.py
@@ -0,0 +1,164 @@
+import logging
+import os
+
+from funcy import first, last
+
+from dvc.exceptions import DvcException
+from dvc.repo.plot.data import PlotData
+from dvc.repo.plot.template import Template, NoDataForTemplateError
+from dvc.repo import locked
+
+logger = logging.getLogger(__name__)
+
+PAGE_HTML = """
+
+ DVC Plot
+
+
+
+
+
+ {divs}
+
+"""
+
+DIV_HTML = """
+"""
+
+
+class TooManyDataSourcesError(DvcException):
+ def __init__(self, datafile, template_datafiles):
+ super().__init__(
+ "Unable to infer which of possible data sources: '{}' "
+ "should be replaced with '{}'.".format(
+ ", ".join(template_datafiles), datafile
+ )
+ )
+
+
+class NoDataOrTemplateProvided(DvcException):
+ def __init__(self):
+ super().__init__("Datafile or template is not specified.")
+
+
+def _evaluate_templatepath(repo, template=None):
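+ # Resolve the template argument: fall back to the repo's default template,
+ # accept an existing file path as-is, or look the name up in the plot
+ # templates directory.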
+ if not template:
+ return repo.plot_templates.default_template
+
+ if os.path.exists(template):
+ return template
+ return repo.plot_templates.get_template(template)
+
+
+@locked
+def fill_template(
+ repo,
+ datafile,
+ template_path,
+ revisions,
+ fields=None,
+ path=None,
+ csv_header=True,
+ x_field=None,
+ y_field=None,
+ **kwargs
+):
+ if x_field and fields:
+ fields.add(x_field)
+
+ if y_field and fields:
+ fields.add(y_field)
+
+ template_datafiles, x_anchor, y_anchor = _parse_template(
+ template_path, datafile
+ )
+ append_index = x_anchor and not x_field
+ if append_index:
+ x_field = PlotData.INDEX_FIELD
+
+ template_data = {}
+ for template_datafile in template_datafiles:
+ from dvc.repo.plot.data import _load_from_revisions
+
+ plot_datas = _load_from_revisions(repo, template_datafile, revisions)
+ tmp_data = []
+ for pd in plot_datas:
+ rev_data_points = pd.to_datapoints(
+ fields=fields,
+ path=path,
+ csv_header=csv_header,
+ append_index=append_index,
+ )
+
+ if y_anchor and not y_field:
+ y_field = _infer_y_field(rev_data_points, x_field)
+ tmp_data.extend(rev_data_points)
+
+ template_data[template_datafile] = tmp_data
+
+ if len(template_data) == 0:
+ raise NoDataForTemplateError(template_path)
+
+ return Template.fill(
+ template_path,
+ template_data,
+ priority_datafile=datafile,
+ x_field=x_field,
+ y_field=y_field,
+ **kwargs
+ )
+
+
+def _infer_y_field(rev_data_points, x_field):
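+ # Default the y axis to the last data field that is neither the revision
+ # marker nor the chosen x field.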
+ all_fields = list(first(rev_data_points).keys())
+ all_fields.remove(PlotData.REVISION_FIELD)
+ if x_field and x_field in all_fields:
+ all_fields.remove(x_field)
+ y_field = last(all_fields)
+ return y_field
+
+
+def plot(
+ repo, datafile=None, template=None, revisions=None, embed=False, **kwargs
+):
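+ # Entry point exposed as Repo.plot(): resolve the template, fill it with
+ # data from the requested revisions and optionally wrap it in HTML.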
+ if revisions is None:
+ from dvc.repo.plot.data import WORKSPACE_REVISION_NAME
+
+ revisions = [WORKSPACE_REVISION_NAME]
+
+ if not datafile and not template:
+ raise NoDataOrTemplateProvided()
+
+ template_path = _evaluate_templatepath(repo, template)
+
+ plot_content = fill_template(
+ repo, datafile, template_path, revisions, **kwargs
+ )
+
+ if embed:
+ div = DIV_HTML.format(id="plot", vega_json=plot_content)
+ plot_content = PAGE_HTML.format(divs=div)
+
+ return plot_content
+
+
+def _parse_template(template_path, priority_datafile):
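+ # Collect the data anchors declared in the template; an explicitly given
+ # datafile overrides them, but only if the template declares at most one.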
+ with open(template_path, "r") as fobj:
+ template_content = fobj.read()
+
+ template_datafiles = Template.parse_data_anchors(template_content)
+ if priority_datafile:
+ if len(template_datafiles) > 1:
+ raise TooManyDataSourcesError(
+ priority_datafile, template_datafiles
+ )
+ template_datafiles = {priority_datafile}
+
+ return (
+ template_datafiles,
+ Template.X_ANCHOR in template_content,
+ Template.Y_ANCHOR in template_content,
+ )
diff --git a/dvc/repo/plot/data.py b/dvc/repo/plot/data.py
new file mode 100644
index 0000000000..31b53b2692
--- /dev/null
+++ b/dvc/repo/plot/data.py
@@ -0,0 +1,292 @@
+import csv
+import io
+import json
+import logging
+import os
+from collections import OrderedDict
+from copy import copy
+
+import yaml
+from funcy import first
+from yaml import SafeLoader
+
+from dvc.exceptions import DvcException, PathMissingError
+
+
+logger = logging.getLogger(__name__)
+
+WORKSPACE_REVISION_NAME = "workspace"
+
+
+class PlotMetricTypeError(DvcException):
+ def __init__(self, file):
+ super().__init__(
+ "'{}' - file type error\n"
+ "Only JSON, YAML, CSV and TSV formats are supported.".format(file)
+ )
+
+
+class PlotDataStructureError(DvcException):
+ def __init__(self):
+ super().__init__(
+ "Plot data extraction failed. Please see "
+ "documentation for supported data formats."
+ )
+
+
+class JsonParsingError(DvcException):
+ def __init__(self, file):
+ super().__init__(
+ "Failed to infer data structure from '{}'. Did you forget "
+ "to specify JSONpath?".format(file)
+ )
+
+
+class NoMetricOnRevisionError(DvcException):
+ def __init__(self, path, revision):
+ self.path = path
+ self.revision = revision
+ super().__init__(
+ "Could not find '{}' on revision '{}'".format(path, revision)
+ )
+
+
+class NoMetricInHistoryError(DvcException):
+ def __init__(self, path):
+ super().__init__("Could not find '{}'.".format(path))
+
+
+def plot_data(filename, revision, content):
+ _, extension = os.path.splitext(filename.lower())
+ if extension == ".json":
+ return JSONPlotData(filename, revision, content)
+ elif extension == ".csv":
+ return CSVPlotData(filename, revision, content)
+ elif extension == ".tsv":
+ return CSVPlotData(filename, revision, content, delimiter="\t")
+ elif extension == ".yaml":
+ return YAMLPlotData(filename, revision, content)
+ raise PlotMetricTypeError(filename)
+
+
+def _filter_fields(data_points, filename, revision, fields=None, **kwargs):
+ if not fields:
+ return data_points
+ assert isinstance(fields, set)
+
+ new_data = []
+ for data_point in data_points:
+ new_dp = copy(data_point)
+
+ keys = set(data_point.keys())
+ if keys & fields != fields:
+ raise DvcException(
+ "Could not find fields: '{}' for '{}' at '{}'.".format(
+ ", " "".join(fields), filename, revision
+ )
+ )
+
+ to_del = keys - fields
+ for key in to_del:
+ del new_dp[key]
+ new_data.append(new_dp)
+ return new_data
+
+
+def _apply_path(data, path=None, **kwargs):
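+ # Apply a user-supplied JSONPath: either it selects a single list of dicts,
+ # or it selects values under one field, which are wrapped into one-field
+ # data points.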
+ if not path or not isinstance(data, dict):
+ return data
+
+ import jsonpath_ng
+
+ found = jsonpath_ng.parse(path).find(data)
+ first_datum = first(found)
+ if (
+ len(found) == 1
+ and isinstance(first_datum.value, list)
+ and isinstance(first(first_datum.value), dict)
+ ):
+ data_points = first_datum.value
+ elif len(first_datum.path.fields) == 1:
+ field_name = first(first_datum.path.fields)
+ data_points = [{field_name: datum.value} for datum in found]
+ else:
+ raise PlotDataStructureError()
+
+ if not isinstance(data_points, list) or not (
+ isinstance(first(data_points), dict)
+ ):
+ raise PlotDataStructureError()
+
+ return data_points
+
+
+def _lists(dictionary):
+ for _, value in dictionary.items():
+ if isinstance(value, dict):
+ yield from (_lists(value))
+ elif isinstance(value, list):
+ yield value
+
+
+def _find_data(data, fields=None, **kwargs):
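+ # Fall back to the first nested list whose items are all dicts and whose
+ # keys cover the requested fields (any list of dicts if no fields given).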
+ if not isinstance(data, dict):
+ return data
+
+ if not fields:
+ # just look for first list of dicts
+ fields = set()
+
+ for l in _lists(data):
+ if all(isinstance(dp, dict) for dp in l):
+ if set(first(l).keys()) & fields == fields:
+ return l
+ raise PlotDataStructureError()
+
+
+def _append_index(data_points, append_index=False, **kwargs):
+ if not append_index:
+ return data_points
+
+ if PlotData.INDEX_FIELD in first(data_points).keys():
+ raise DvcException(
+ "Cannot append index. Field of same name ('{}') found in data. "
+ "Use `-x` to specify x axis field.".format(PlotData.INDEX_FIELD)
+ )
+
+ for index, data_point in enumerate(data_points):
+ data_point[PlotData.INDEX_FIELD] = index
+ return data_points
+
+
+def _append_revision(data_points, revision, **kwargs):
+ for data_point in data_points:
+ data_point[PlotData.REVISION_FIELD] = revision
+ return data_points
+
+
+class PlotData:
+ REVISION_FIELD = "rev"
+ INDEX_FIELD = "index"
+
+ def __init__(self, filename, revision, content, **kwargs):
+ self.filename = filename
+ self.revision = revision
+ self.content = content
+
+ def raw(self, **kwargs):
+ raise NotImplementedError
+
+ def _processors(self):
+ return [_filter_fields, _append_index, _append_revision]
+
+ def to_datapoints(self, **kwargs):
+ data = self.raw(**kwargs)
+
+ for data_proc in self._processors():
+ data = data_proc(
+ data, filename=self.filename, revision=self.revision, **kwargs
+ )
+ return data
+
+
+class JSONPlotData(PlotData):
+ def raw(self, **kwargs):
+ return json.loads(self.content, object_pairs_hook=OrderedDict)
+
+ def _processors(self):
+ parent_processors = super(JSONPlotData, self)._processors()
+ return [_apply_path, _find_data] + parent_processors
+
+
+class CSVPlotData(PlotData):
+ def __init__(self, filename, revision, content, delimiter=","):
+ super(CSVPlotData, self).__init__(filename, revision, content)
+ self.delimiter = delimiter
+
+ def raw(self, csv_header=True, **kwargs):
+ first_row = first(csv.reader(io.StringIO(self.content)))
+
+ if csv_header:
+ reader = csv.DictReader(
+ io.StringIO(self.content), delimiter=self.delimiter,
+ )
+ else:
+ reader = csv.DictReader(
+ io.StringIO(self.content),
+ delimiter=self.delimiter,
+ fieldnames=[str(i) for i in range(len(first_row))],
+ )
+
+ fieldnames = reader.fieldnames
+ data = [row for row in reader]
+
+ return [
+ OrderedDict([(field, data_point[field]) for field in fieldnames])
+ for data_point in data
+ ]
+
+
+class YAMLPlotData(PlotData):
+ def raw(self, **kwargs):
+ class OrderedLoader(SafeLoader):
+ pass
+
+ def construct_mapping(loader, node):
+ loader.flatten_mapping(node)
+ return OrderedDict(loader.construct_pairs(node))
+
+ OrderedLoader.add_constructor(
+ yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping
+ )
+
+ return yaml.load(self.content, OrderedLoader)
+
+
+def _load_from_revision(repo, datafile, revision):
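+ # The workspace "revision" is read straight from the working tree; any
+ # other revision is fetched through dvc.api.open().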
+ if revision == WORKSPACE_REVISION_NAME:
+
+ def open_datafile():
+ return repo.tree.open(datafile, "r")
+
+ else:
+
+ def open_datafile():
+ from dvc import api
+
+ return api.open(datafile, repo.root_dir, revision)
+
+ try:
+ with open_datafile() as fobj:
+ datafile_content = fobj.read()
+
+ except (FileNotFoundError, PathMissingError):
+ raise NoMetricOnRevisionError(datafile, revision)
+
+ return plot_data(datafile, revision, datafile_content)
+
+
+def _load_from_revisions(repo, datafile, revisions):
+ data = []
+ exceptions = []
+
+ for rev in revisions:
+ try:
+ data.append(_load_from_revision(repo, datafile, rev))
+ except NoMetricOnRevisionError as e:
+ exceptions.append(e)
+ except PlotMetricTypeError:
+ raise
+ except (yaml.error.YAMLError, json.decoder.JSONDecodeError, csv.Error):
+ logger.error("Failed to parse '{}' at '{}'.".format(datafile, rev))
+ raise
+
+ if not data and exceptions:
+ raise NoMetricInHistoryError(datafile)
+ else:
+ for e in exceptions:
+ logger.warning(
+ "File '{}' was not found at: '{}'. It will not be "
+ "plotted.".format(e.path, e.revision)
+ )
+ return data
diff --git a/dvc/repo/plot/template.py b/dvc/repo/plot/template.py
new file mode 100644
index 0000000000..2850a6220f
--- /dev/null
+++ b/dvc/repo/plot/template.py
@@ -0,0 +1,276 @@
+import json
+import logging
+import os
+import re
+
+from funcy import cached_property
+
+from dvc.exceptions import DvcException
+from dvc.utils.fs import makedirs
+
+
+logger = logging.getLogger(__name__)
+
+
+class TemplateNotFoundError(DvcException):
+ def __init__(self, path):
+ super().__init__("Template '{}' not found.".format(path))
+
+
+class NoDataForTemplateError(DvcException):
+ def __init__(self, template_path):
+ super().__init__(
+ "No data provided for '{}'.".format(os.path.relpath(template_path))
+ )
+
+
+class Template:
+ INDENT = 4
+ SEPARATORS = (",", ": ")
+ EXTENSION = ".json"
+ METRIC_DATA_ANCHOR = "<DVC_METRIC_DATA>"
+ X_ANCHOR = "<DVC_METRIC_X>"
+ Y_ANCHOR = "<DVC_METRIC_Y>"
+ TITLE_ANCHOR = "<DVC_METRIC_TITLE>"
+ X_TITLE_ANCHOR = "<DVC_METRIC_X_TITLE>"
+ Y_TITLE_ANCHOR = "<DVC_METRIC_Y_TITLE>"
+
+ def __init__(self, templates_dir):
+ self.plot_templates_dir = templates_dir
+
+ def dump(self):
+ makedirs(self.plot_templates_dir, exist_ok=True)
+
+ with open(
+ os.path.join(
+ self.plot_templates_dir, self.TEMPLATE_NAME + self.EXTENSION
+ ),
+ "w",
+ ) as fobj:
+ json.dump(
+ self.DEFAULT_CONTENT,
+ fobj,
+ indent=self.INDENT,
+ separators=self.SEPARATORS,
+ )
+
+ @staticmethod
+ def get_data_anchor(template_content):
+ regex = re.compile('"<DVC_METRIC_DATA[^>"]*>"')
+ return regex.findall(template_content)
+
+ @staticmethod
+ def parse_data_anchors(template_content):
+ data_files = {
+ Template.get_datafile(m)
+ for m in Template.get_data_anchor(template_content)
+ }
+ return {df for df in data_files if df}
+
+ @staticmethod
+ def get_datafile(anchor_string):
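+ # Strip the anchor markup, e.g. '"<DVC_METRIC_DATA,metric.json>"'
+ # becomes 'metric.json'.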
+ return (
+ anchor_string.replace("<", "")
+ .replace(">", "")
+ .replace('"', "")
+ .replace("DVC_METRIC_DATA", "")
+ .replace(",", "")
+ )
+
+ @staticmethod
+ def fill(
+ template_path,
+ data,
+ priority_datafile=None,
+ x_field=None,
+ y_field=None,
+ title=None,
+ x_title=None,
+ y_title=None,
+ ):
+ with open(template_path, "r") as fobj:
+ result_content = fobj.read()
+
+ result_content = Template._replace_data_anchors(
+ result_content, data, priority_datafile
+ )
+
+ result_content = Template._replace_metadata_anchors(
+ result_content, title, x_field, x_title, y_field, y_title
+ )
+
+ return result_content
+
+ @staticmethod
+ def _replace_metadata_anchors(
+ result_content, title, x_field, x_title, y_field, y_title
+ ):
+ if Template.TITLE_ANCHOR in result_content:
+ if title:
+ result_content = result_content.replace(
+ Template.TITLE_ANCHOR, title
+ )
+ else:
+ result_content = result_content.replace(
+ Template.TITLE_ANCHOR, ""
+ )
+ if Template.X_ANCHOR in result_content and x_field:
+ result_content = result_content.replace(Template.X_ANCHOR, x_field)
+ if Template.Y_ANCHOR in result_content and y_field:
+ result_content = result_content.replace(Template.Y_ANCHOR, y_field)
+ if Template.X_TITLE_ANCHOR in result_content:
+ if not x_title and x_field:
+ x_title = x_field
+ result_content = result_content.replace(
+ Template.X_TITLE_ANCHOR, x_title
+ )
+ if Template.Y_TITLE_ANCHOR in result_content:
+ if not y_title and y_field:
+ y_title = y_field
+ result_content = result_content.replace(
+ Template.Y_TITLE_ANCHOR, y_title
+ )
+ return result_content
+
+ @staticmethod
+ def _replace_data_anchors(result_content, data, priority_datafile):
+ for anchor in Template.get_data_anchor(result_content):
+ file = Template.get_datafile(anchor)
+
+ if not file or priority_datafile:
+ key = priority_datafile
+ else:
+ key = file
+
+ result_content = result_content.replace(
+ anchor,
+ json.dumps(
+ data[key],
+ indent=Template.INDENT,
+ separators=Template.SEPARATORS,
+ sort_keys=True,
+ ),
+ )
+ return result_content
+
+
+class DefaultLinearTemplate(Template):
+ TEMPLATE_NAME = "default"
+
+ DEFAULT_CONTENT = {
+ "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
+ "data": {"values": Template.METRIC_DATA_ANCHOR},
+ "title": Template.TITLE_ANCHOR,
+ "mark": {"type": "line"},
+ "encoding": {
+ "x": {
+ "field": Template.X_ANCHOR,
+ "type": "quantitative",
+ "title": Template.X_TITLE_ANCHOR,
+ },
+ "y": {
+ "field": Template.Y_ANCHOR,
+ "type": "quantitative",
+ "title": Template.Y_TITLE_ANCHOR,
+ },
+ "color": {"field": "rev", "type": "nominal"},
+ },
+ }
+
+
+class DefaultConfusionTemplate(Template):
+ TEMPLATE_NAME = "confusion"
+ DEFAULT_CONTENT = {
+ "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
+ "data": {"values": Template.METRIC_DATA_ANCHOR},
+ "title": Template.TITLE_ANCHOR,
+ "mark": "rect",
+ "encoding": {
+ "x": {
+ "field": Template.X_ANCHOR,
+ "type": "nominal",
+ "sort": "ascending",
+ "title": Template.X_TITLE_ANCHOR,
+ },
+ "y": {
+ "field": Template.Y_ANCHOR,
+ "type": "nominal",
+ "sort": "ascending",
+ "title": Template.Y_TITLE_ANCHOR,
+ },
+ "color": {"aggregate": "count", "type": "quantitative"},
+ "facet": {"field": "rev", "type": "nominal"},
+ },
+ }
+
+
+class DefaultScatterTemplate(Template):
+ TEMPLATE_NAME = "scatter"
+ DEFAULT_CONTENT = {
+ "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
+ "data": {"values": Template.METRIC_DATA_ANCHOR},
+ "title": Template.TITLE_ANCHOR,
+ "mark": "point",
+ "encoding": {
+ "x": {
+ "field": Template.X_ANCHOR,
+ "type": "quantitative",
+ "title": Template.X_TITLE_ANCHOR,
+ },
+ "y": {
+ "field": Template.Y_ANCHOR,
+ "type": "quantitative",
+ "title": Template.Y_TITLE_ANCHOR,
+ },
+ "color": {"field": "rev", "type": "nominal"},
+ },
+ }
+
+
+class PlotTemplates:
+ TEMPLATES_DIR = "plot"
+ TEMPLATES = [
+ DefaultLinearTemplate,
+ DefaultConfusionTemplate,
+ DefaultScatterTemplate,
+ ]
+
+ @cached_property
+ def templates_dir(self):
+ return os.path.join(self.dvc_dir, self.TEMPLATES_DIR)
+
+ @cached_property
+ def default_template(self):
+ default_plot_path = os.path.join(self.templates_dir, "default.json")
+ if not os.path.exists(default_plot_path):
+ raise TemplateNotFoundError(os.path.relpath(default_plot_path))
+ return default_plot_path
+
+ def get_template(self, path):
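+ # Accept either a path inside the templates directory or a bare template
+ # name without its extension (e.g. "confusion").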
+ t_path = os.path.join(self.templates_dir, path)
+ if os.path.exists(t_path):
+ return t_path
+
+ all_templates = [
+ os.path.join(root, file)
+ for root, _, files in os.walk(self.templates_dir)
+ for file in files
+ ]
+ matches = [
+ template
+ for template in all_templates
+ if os.path.splitext(template)[0] == t_path
+ ]
+ if matches:
+ assert len(matches) == 1
+ return matches[0]
+
+ raise TemplateNotFoundError(path)
+
+ def __init__(self, dvc_dir):
+ self.dvc_dir = dvc_dir
+
+ if not os.path.exists(self.templates_dir):
+ makedirs(self.templates_dir, exist_ok=True)
+ for t in self.TEMPLATES:
+ t(self.templates_dir).dump()
diff --git a/scripts/completion/dvc.bash b/scripts/completion/dvc.bash
index cde8130e62..78f966acde 100644
--- a/scripts/completion/dvc.bash
+++ b/scripts/completion/dvc.bash
@@ -5,7 +5,7 @@
# - https://stackoverflow.com/questions/12933362
_dvc_commands='add cache checkout commit config destroy diff fetch get-url get gc \
- import-url import init install lock list metrics move pipeline pull push \
+ import-url import init install lock list metrics move pipeline plot pull push \
remote remove repro root run status unlock unprotect update version'
_dvc_options='-h --help -V --version'
@@ -51,6 +51,9 @@ _dvc_pipeline='list show'
_dvc_pipeline_list=''
_dvc_pipeline_show='-c --commands -o --outs --ascii --dot --tree -l --locked'
_dvc_pipeline_show_COMPGEN=_dvc_compgen_DVCFiles
+_dvc_plot='show diff'
+_dvc_plot_show='-t --template -f --file -s --select -x -y --stdout --no-csv-header --no-html --title --xlab --ylab'
+_dvc_plot_diff='-t --template -d --datafile -f --file -s --select -x -y --stdout --no-csv-header --no-html --title --xlab --ylab'
_dvc_pull='-j --jobs -r --remote -a --all-branches -T --all-tags -f --force -d --with-deps -R --recursive'
_dvc_pull_COMPGEN=_dvc_compgen_DVCFiles
_dvc_push='-j --jobs -r --remote -a --all-branches -T --all-tags -d --with-deps -R --recursive'
diff --git a/setup.py b/setup.py
index a77f5b767d..e27b30843e 100644
--- a/setup.py
+++ b/setup.py
@@ -129,6 +129,7 @@ def run(self):
"mock-ssh-server>=0.6.0",
"moto==1.3.14.dev464",
"rangehttpserver==1.2.0",
+ "beautifulsoup4==4.4.0",
]
if (sys.version_info) >= (3, 6):
diff --git a/tests/func/test_plot.py b/tests/func/test_plot.py
new file mode 100644
index 0000000000..6be11320f8
--- /dev/null
+++ b/tests/func/test_plot.py
@@ -0,0 +1,504 @@
+import csv
+import json
+import logging
+import shutil
+from collections import OrderedDict
+
+import pytest
+import yaml
+from bs4 import BeautifulSoup
+from funcy import first
+
+from dvc.compat import fspath
+from dvc.repo.plot.data import (
+ NoMetricInHistoryError,
+ PlotMetricTypeError,
+ PlotData,
+)
+from dvc.repo.plot.template import (
+ TemplateNotFoundError,
+ NoDataForTemplateError,
+)
+from dvc.repo.plot import NoDataOrTemplateProvided
+
+
+def _remove_whitespace(value):
+ return value.replace(" ", "").replace("\n", "")
+
+
+def _run_with_metric(tmp_dir, metric_filename, commit=None, tag=None):
+ tmp_dir.dvc.run(metrics_no_cache=[metric_filename])
+ if hasattr(tmp_dir.dvc, "scm"):
+ tmp_dir.dvc.scm.add([metric_filename, metric_filename + ".dvc"])
+ if commit:
+ tmp_dir.dvc.scm.commit(commit)
+ if tag:
+ tmp_dir.dvc.scm.tag(tag)
+
+
+def _write_csv(metric, filename, header=True):
+ with open(filename, "w", newline="") as csvobj:
+ if header:
+ writer = csv.DictWriter(
+ csvobj, fieldnames=list(first(metric).keys())
+ )
+ writer.writeheader()
+ writer.writerows(metric)
+ else:
+ writer = csv.writer(csvobj)
+ for d in metric:
+ assert len(d) == 1
+ writer.writerow(list(d.values()))
+
+
+def _write_json(tmp_dir, metric, filename):
+ tmp_dir.gen(filename, json.dumps(metric, sort_keys=True))
+
+
+def test_plot_csv_one_column(tmp_dir, scm, dvc):
+ # no header
+ metric = [{"val": 2}, {"val": 3}]
+ _write_csv(metric, "metric.csv", header=False)
+ _run_with_metric(tmp_dir, metric_filename="metric.csv")
+
+ plot_string = dvc.plot(
+ "metric.csv",
+ csv_header=False,
+ x_title="x_title",
+ y_title="y_title",
+ title="mytitle",
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["title"] == "mytitle"
+ assert plot_content["data"]["values"] == [
+ {"0": "2", PlotData.INDEX_FIELD: 0, "rev": "workspace"},
+ {"0": "3", PlotData.INDEX_FIELD: 1, "rev": "workspace"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD
+ assert plot_content["encoding"]["y"]["field"] == "0"
+ assert plot_content["encoding"]["x"]["title"] == "x_title"
+ assert plot_content["encoding"]["y"]["title"] == "y_title"
+
+
+def test_plot_csv_multiple_columns(tmp_dir, scm, dvc):
+ metric = [
+ OrderedDict([("first_val", 100), ("second_val", 100), ("val", 2)]),
+ OrderedDict([("first_val", 200), ("second_val", 300), ("val", 3)]),
+ ]
+ _write_csv(metric, "metric.csv")
+ _run_with_metric(tmp_dir, metric_filename="metric.csv")
+
+ plot_string = dvc.plot("metric.csv")
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {
+ "val": "2",
+ PlotData.INDEX_FIELD: 0,
+ "rev": "workspace",
+ "first_val": "100",
+ "second_val": "100",
+ },
+ {
+ "val": "3",
+ PlotData.INDEX_FIELD: 1,
+ "rev": "workspace",
+ "first_val": "200",
+ "second_val": "300",
+ },
+ ]
+ assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD
+ assert plot_content["encoding"]["y"]["field"] == "val"
+
+
+def test_plot_csv_choose_axes(tmp_dir, scm, dvc):
+ metric = [
+ OrderedDict([("first_val", 100), ("second_val", 100), ("val", 2)]),
+ OrderedDict([("first_val", 200), ("second_val", 300), ("val", 3)]),
+ ]
+ _write_csv(metric, "metric.csv")
+ _run_with_metric(tmp_dir, metric_filename="metric.csv")
+
+ plot_string = dvc.plot(
+ "metric.csv", x_field="first_val", y_field="second_val"
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {
+ "val": "2",
+ "rev": "workspace",
+ "first_val": "100",
+ "second_val": "100",
+ },
+ {
+ "val": "3",
+ "rev": "workspace",
+ "first_val": "200",
+ "second_val": "300",
+ },
+ ]
+ assert plot_content["encoding"]["x"]["field"] == "first_val"
+ assert plot_content["encoding"]["y"]["field"] == "second_val"
+
+
+def test_plot_json_single_val(tmp_dir, scm, dvc):
+ metric = [{"val": 2}, {"val": 3}]
+ _write_json(tmp_dir, metric, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "first run")
+
+ plot_string = dvc.plot("metric.json")
+
+ plot_json = json.loads(plot_string)
+ assert plot_json["data"]["values"] == [
+ {"val": 2, PlotData.INDEX_FIELD: 0, "rev": "workspace"},
+ {"val": 3, PlotData.INDEX_FIELD: 1, "rev": "workspace"},
+ ]
+ assert plot_json["encoding"]["x"]["field"] == PlotData.INDEX_FIELD
+ assert plot_json["encoding"]["y"]["field"] == "val"
+
+
+def test_plot_json_multiple_val(tmp_dir, scm, dvc):
+ metric = [
+ {"first_val": 100, "val": 2},
+ {"first_val": 200, "val": 3},
+ ]
+ _write_json(tmp_dir, metric, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "first run")
+
+ plot_string = dvc.plot("metric.json")
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {
+ "val": 2,
+ PlotData.INDEX_FIELD: 0,
+ "first_val": 100,
+ "rev": "workspace",
+ },
+ {
+ "val": 3,
+ PlotData.INDEX_FIELD: 1,
+ "first_val": 200,
+ "rev": "workspace",
+ },
+ ]
+ assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD
+ assert plot_content["encoding"]["y"]["field"] == "val"
+
+
+def test_plot_confusion(tmp_dir, dvc):
+ confusion_matrix = [
+ {"predicted": "B", "actual": "A"},
+ {"predicted": "A", "actual": "A"},
+ ]
+ _write_json(tmp_dir, confusion_matrix, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "first run")
+
+ plot_string = dvc.plot(
+ datafile="metric.json",
+ template="confusion",
+ x_field="predicted",
+ y_field="actual",
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"predicted": "B", "actual": "A", "rev": "workspace"},
+ {"predicted": "A", "actual": "A", "rev": "workspace"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == "predicted"
+ assert plot_content["encoding"]["y"]["field"] == "actual"
+
+
+def test_plot_multiple_revs_default(tmp_dir, scm, dvc):
+ metric_1 = [{"y": 2}, {"y": 3}]
+ _write_json(tmp_dir, metric_1, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "init", "v1")
+
+ metric_2 = [{"y": 3}, {"y": 5}]
+ _write_json(tmp_dir, metric_2, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "second", "v2")
+
+ metric_3 = [{"y": 5}, {"y": 6}]
+ _write_json(tmp_dir, metric_3, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "third")
+
+ plot_string = dvc.plot(
+ "metric.json", fields={"y"}, revisions=["HEAD", "v2", "v1"],
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"y": 5, PlotData.INDEX_FIELD: 0, "rev": "HEAD"},
+ {"y": 6, PlotData.INDEX_FIELD: 1, "rev": "HEAD"},
+ {"y": 3, PlotData.INDEX_FIELD: 0, "rev": "v2"},
+ {"y": 5, PlotData.INDEX_FIELD: 1, "rev": "v2"},
+ {"y": 2, PlotData.INDEX_FIELD: 0, "rev": "v1"},
+ {"y": 3, PlotData.INDEX_FIELD: 1, "rev": "v1"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD
+ assert plot_content["encoding"]["y"]["field"] == "y"
+
+
+def test_plot_multiple_revs(tmp_dir, scm, dvc):
+ shutil.copy(
+ fspath(tmp_dir / ".dvc" / "plot" / "default.json"), "template.json"
+ )
+
+ metric_1 = [{"y": 2}, {"y": 3}]
+ _write_json(tmp_dir, metric_1, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "init", "v1")
+
+ metric_2 = [{"y": 3}, {"y": 5}]
+ _write_json(tmp_dir, metric_2, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "second", "v2")
+
+ metric_3 = [{"y": 5}, {"y": 6}]
+ _write_json(tmp_dir, metric_3, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "third")
+
+ plot_string = dvc.plot(
+ "metric.json",
+ template="template.json",
+ revisions=["HEAD", "v2", "v1"],
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"y": 5, PlotData.INDEX_FIELD: 0, "rev": "HEAD"},
+ {"y": 6, PlotData.INDEX_FIELD: 1, "rev": "HEAD"},
+ {"y": 3, PlotData.INDEX_FIELD: 0, "rev": "v2"},
+ {"y": 5, PlotData.INDEX_FIELD: 1, "rev": "v2"},
+ {"y": 2, PlotData.INDEX_FIELD: 0, "rev": "v1"},
+ {"y": 3, PlotData.INDEX_FIELD: 1, "rev": "v1"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD
+ assert plot_content["encoding"]["y"]["field"] == "y"
+
+
+def test_plot_even_if_metric_missing(tmp_dir, scm, dvc, caplog):
+ tmp_dir.scm_gen("some_file", "content", commit="there is no metric")
+ scm.tag("v1")
+
+ metric = [{"y": 2}, {"y": 3}]
+ _write_json(tmp_dir, metric, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "there is metric", "v2")
+
+ caplog.clear()
+ with caplog.at_level(logging.WARNING, "dvc"):
+ plot_string = dvc.plot("metric.json", revisions=["v1", "v2"])
+ assert (
+ "File 'metric.json' was not found at: 'v1'. "
+ "It will not be plotted." in caplog.text
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"y": 2, PlotData.INDEX_FIELD: 0, "rev": "v2"},
+ {"y": 3, PlotData.INDEX_FIELD: 1, "rev": "v2"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD
+ assert plot_content["encoding"]["y"]["field"] == "y"
+
+
+def test_throw_on_no_metric_at_all(tmp_dir, scm, dvc, caplog):
+ tmp_dir.scm_gen("some_file", "content", commit="there is no metric")
+ scm.tag("v1")
+
+ tmp_dir.gen("some_file", "make repo dirty")
+
+ caplog.clear()
+ with pytest.raises(NoMetricInHistoryError) as error, caplog.at_level(
+ logging.WARNING, "dvc"
+ ):
+ dvc.plot("metric.json", revisions=["v1"])
+
+ # do not warn if none found
+ assert len(caplog.messages) == 0
+
+ assert str(error.value) == "Could not find 'metric.json'."
+
+
+@pytest.fixture()
+def custom_template(tmp_dir, dvc):
+ custom_template = tmp_dir / "custom_template.json"
+ shutil.copy(
+ fspath(tmp_dir / ".dvc" / "plot" / "default.json"),
+ fspath(custom_template),
+ )
+ return custom_template
+
+
+def test_custom_template(tmp_dir, scm, dvc, custom_template):
+ metric = [{"a": 1, "b": 2}, {"a": 2, "b": 3}]
+ _write_json(tmp_dir, metric, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "init", "v1")
+
+ plot_string = dvc.plot(
+ "metric.json", fspath(custom_template), x_field="a", y_field="b"
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"a": 1, "b": 2, "rev": "workspace"},
+ {"a": 2, "b": 3, "rev": "workspace"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == "a"
+ assert plot_content["encoding"]["y"]["field"] == "b"
+
+
+def _replace(path, src, dst):
+ path.write_text(path.read_text().replace(src, dst))
+
+
+def test_custom_template_with_specified_data(
+ tmp_dir, scm, dvc, custom_template
+):
+ _replace(
+ custom_template, "DVC_METRIC_DATA", "DVC_METRIC_DATA,metric.json",
+ )
+
+ metric = [{"a": 1, "b": 2}, {"a": 2, "b": 3}]
+ _write_json(tmp_dir, metric, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "init", "v1")
+
+ plot_string = dvc.plot(
+ datafile=None,
+ template=fspath(custom_template),
+ x_field="a",
+ y_field="b",
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"a": 1, "b": 2, "rev": "workspace"},
+ {"a": 2, "b": 3, "rev": "workspace"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == "a"
+ assert plot_content["encoding"]["y"]["field"] == "b"
+
+
+def test_plot_override_specified_data_source(tmp_dir, scm, dvc):
+ shutil.copy(
+ fspath(tmp_dir / ".dvc" / "plot" / "default.json"),
+ fspath(tmp_dir / "newtemplate.json"),
+ )
+ _replace(
+ tmp_dir / "newtemplate.json",
+ "DVC_METRIC_DATA",
+ "DVC_METRIC_DATA,metric.json",
+ )
+
+ metric = [{"a": 1, "b": 2}, {"a": 2, "b": 3}]
+ _write_json(tmp_dir, metric, "metric2.json")
+ _run_with_metric(tmp_dir, "metric2.json", "init", "v1")
+
+ plot_string = dvc.plot(
+ datafile="metric2.json", template="newtemplate.json", x_field="a"
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"a": 1, "b": 2, "rev": "workspace"},
+ {"a": 2, "b": 3, "rev": "workspace"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == "a"
+ assert plot_content["encoding"]["y"]["field"] == "b"
+
+
+def test_should_raise_on_no_template_and_datafile(tmp_dir, dvc):
+ with pytest.raises(NoDataOrTemplateProvided):
+ dvc.plot()
+
+
+def test_should_raise_on_no_template(tmp_dir, dvc):
+ with pytest.raises(TemplateNotFoundError):
+ dvc.plot("metric.json", "non_existing_template.json")
+
+
+def test_plot_no_data(tmp_dir, dvc):
+ with pytest.raises(NoDataForTemplateError):
+ dvc.plot(template="default")
+
+
+def test_plot_wrong_metric_type(tmp_dir, scm, dvc):
+ tmp_dir.scm_gen("metric.txt", "content", commit="initial")
+ with pytest.raises(PlotMetricTypeError):
+ dvc.plot(datafile="metric.txt")
+
+
+def test_plot_choose_columns(tmp_dir, scm, dvc, custom_template):
+ metric = [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 3, "c": 4}]
+ _write_json(tmp_dir, metric, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "init", "v1")
+
+ plot_string = dvc.plot(
+ "metric.json",
+ fspath(custom_template),
+ fields={"b", "c"},
+ x_field="b",
+ y_field="c",
+ )
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"b": 2, "c": 3, "rev": "workspace"},
+ {"b": 3, "c": 4, "rev": "workspace"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == "b"
+ assert plot_content["encoding"]["y"]["field"] == "c"
+
+
+def test_plot_default_choose_column(tmp_dir, scm, dvc):
+ metric = [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 3, "c": 4}]
+ _write_json(tmp_dir, metric, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "init", "v1")
+
+ plot_string = dvc.plot("metric.json", fields={"b"})
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {PlotData.INDEX_FIELD: 0, "b": 2, "rev": "workspace"},
+ {PlotData.INDEX_FIELD: 1, "b": 3, "rev": "workspace"},
+ ]
+ assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD
+ assert plot_content["encoding"]["y"]["field"] == "b"
+
+
+def test_plot_embed(tmp_dir, scm, dvc):
+ metric = [{"val": 2}, {"val": 3}]
+ _write_json(tmp_dir, metric, "metric.json")
+ _run_with_metric(tmp_dir, "metric.json", "first run")
+
+ plot_string = dvc.plot("metric.json", embed=True, y_field="val")
+
+ page_content = BeautifulSoup(plot_string, "html.parser")
+ data_dump = json.dumps(
+ [
+ {"val": 2, PlotData.INDEX_FIELD: 0, "rev": "workspace"},
+ {"val": 3, PlotData.INDEX_FIELD: 1, "rev": "workspace"},
+ ],
+ sort_keys=True,
+ )
+
+ assert _remove_whitespace(data_dump) in _remove_whitespace(
+ first(page_content.body.script.contents)
+ )
+
+
+def test_plot_yaml(tmp_dir, scm, dvc):
+ metric = [{"val": 2}, {"val": 3}]
+ with open("metric.yaml", "w") as fobj:
+ yaml.dump(metric, fobj)
+
+ _run_with_metric(tmp_dir, metric_filename="metric.yaml")
+
+ plot_string = dvc.plot("metric.yaml",)
+
+ plot_content = json.loads(plot_string)
+ assert plot_content["data"]["values"] == [
+ {"val": 2, PlotData.INDEX_FIELD: 0, "rev": "workspace"},
+ {"val": 3, PlotData.INDEX_FIELD: 1, "rev": "workspace"},
+ ]
diff --git a/tests/unit/command/test_plot.py b/tests/unit/command/test_plot.py
new file mode 100644
index 0000000000..c0a3b00bc9
--- /dev/null
+++ b/tests/unit/command/test_plot.py
@@ -0,0 +1,122 @@
+import pytest
+
+from dvc.cli import parse_args
+from dvc.command.plot import CmdPlotShow, CmdPlotDiff
+
+
+def test_plot_diff(mocker):
+ cli_args = parse_args(
+ [
+ "plot",
+ "diff",
+ "--file",
+ "result.extension",
+ "-t",
+ "template",
+ "-d",
+ "datafile",
+ "--select",
+ "column1,column2",
+ "--no-html",
+ "--stdout",
+ "-x",
+ "x_field",
+ "-y",
+ "y_field",
+ "--title",
+ "my_title",
+ "--xlab",
+ "x_title",
+ "--ylab",
+ "y_title",
+ "HEAD",
+ "tag1",
+ "tag2",
+ ]
+ )
+ assert cli_args.func == CmdPlotDiff
+
+ cmd = cli_args.func(cli_args)
+
+ m = mocker.patch.object(cmd.repo, "plot", autospec=True)
+ mocker.patch("builtins.open")
+ mocker.patch("os.path.join")
+
+ assert cmd.run() == 0
+
+ m.assert_called_once_with(
+ datafile="datafile",
+ template="template",
+ revisions=["HEAD", "tag1", "tag2"],
+ fields={"column1", "column2"},
+ path=None,
+ embed=False,
+ x_field="x_field",
+ y_field="y_field",
+ csv_header=True,
+ title="my_title",
+ x_title="x_title",
+ y_title="y_title",
+ )
+
+
+def test_plot_show(mocker):
+ cli_args = parse_args(
+ [
+ "plot",
+ "show",
+ "-f",
+ "result.extension",
+ "-t",
+ "template",
+ "-s",
+ "$.data",
+ "--no-html",
+ "--stdout",
+ "--no-csv-header",
+ "datafile",
+ ]
+ )
+ assert cli_args.func == CmdPlotShow
+
+ cmd = cli_args.func(cli_args)
+
+ m = mocker.patch.object(cmd.repo, "plot", autospec=True)
+ mocker.patch("builtins.open")
+ mocker.patch("os.path.join")
+
+ assert cmd.run() == 0
+
+ m.assert_called_once_with(
+ datafile="datafile",
+ template="template",
+ revisions=None,
+ fields=None,
+ path="$.data",
+ embed=False,
+ x_field=None,
+ y_field=None,
+ csv_header=False,
+ title=None,
+ x_title=None,
+ y_title=None,
+ )
+
+
+@pytest.mark.parametrize(
+ "arg_revisions,is_dirty,expected_revisions",
+ [
+ ([], False, ["workspace"]),
+ ([], True, ["HEAD", "workspace"]),
+ (["v1", "v2", "workspace"], False, ["v1", "v2", "workspace"]),
+ (["v1", "v2", "workspace"], True, ["v1", "v2", "workspace"]),
+ ],
+)
+def test_revisions(mocker, arg_revisions, is_dirty, expected_revisions):
+ args = mocker.MagicMock()
+
+ cmd = CmdPlotDiff(args)
+ mocker.patch.object(args, "revisions", arg_revisions)
+ mocker.patch.object(cmd.repo.scm, "is_dirty", return_value=is_dirty)
+
+ assert cmd._revisions() == expected_revisions
diff --git a/tests/unit/test_plot.py b/tests/unit/test_plot.py
new file mode 100644
index 0000000000..fcf2ce30b1
--- /dev/null
+++ b/tests/unit/test_plot.py
@@ -0,0 +1,46 @@
+import pytest
+
+from dvc.repo.plot.data import _apply_path, _lists, _find_data
+
+
+@pytest.mark.parametrize(
+ "path,expected_result",
+ [
+ ("$.some.path[*].a", [{"a": 1}, {"a": 4}]),
+ ("$.some.path", [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]),
+ ],
+)
+def test_parse_json(path, expected_result):
+ value = {
+ "some": {"path": [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]}
+ }
+
+ result = _apply_path(value, path=path)
+
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "dictionary, expected_result",
+ [
+ ({}, []),
+ ({"x": ["a", "b", "c"]}, [["a", "b", "c"]]),
+ (
+ {"x": {"y": ["a", "b"]}, "z": {"w": ["c", "d"]}},
+ [["a", "b"], ["c", "d"]],
+ ),
+ ],
+)
+def test_finding_lists(dictionary, expected_result):
+ result = _lists(dictionary)
+
+ assert list(result) == expected_result
+
+
+@pytest.mark.parametrize("fields", [{"x"}, set()])
+def test_finding_data(fields):
+ data = {"a": {"b": [{"x": 2, "y": 3}, {"x": 1, "y": 5}]}}
+
+ result = _find_data(data, fields=fields)
+
+ assert result == [{"x": 2, "y": 3}, {"x": 1, "y": 5}]