diff --git a/dvc/cli.py b/dvc/cli.py index 0e4a22a25c..98521646a1 100644 --- a/dvc/cli.py +++ b/dvc/cli.py @@ -35,6 +35,7 @@ update, version, git_hook, + plot, ) from .command.base import fix_subparsers from .exceptions import DvcParserError @@ -74,6 +75,7 @@ version, update, git_hook, + plot, ] diff --git a/dvc/command/plot.py b/dvc/command/plot.py new file mode 100644 index 0000000000..b507e73063 --- /dev/null +++ b/dvc/command/plot.py @@ -0,0 +1,243 @@ +import argparse +import logging +import os + +from dvc.command.base import append_doc_link, CmdBase, fix_subparsers +from dvc.exceptions import DvcException +from dvc.repo.plot.data import WORKSPACE_REVISION_NAME + +logger = logging.getLogger(__name__) + + +class CmdPlot(CmdBase): + def _revisions(self): + raise NotImplementedError + + def _result_file(self): + if self.args.file: + return self.args.file + + extension = self._result_extension() + base = self._result_basename() + + result_file = base + extension + return result_file + + def _result_basename(self): + if self.args.datafile: + return self.args.datafile + return "plot" + + def _result_extension(self): + if not self.args.no_html: + return ".html" + elif self.args.template: + return os.path.splitext(self.args.template)[-1] + return ".json" + + def run(self): + fields = None + jsonpath = None + if self.args.select: + if self.args.select.startswith("$"): + jsonpath = self.args.select + else: + fields = set(self.args.select.split(",")) + try: + plot_string = self.repo.plot( + datafile=self.args.datafile, + template=self.args.template, + revisions=self._revisions(), + fields=fields, + x_field=self.args.x, + y_field=self.args.y, + path=jsonpath, + embed=not self.args.no_html, + csv_header=not self.args.no_csv_header, + title=self.args.title, + x_title=self.args.xlab, + y_title=self.args.ylab, + ) + + if self.args.stdout: + logger.info(plot_string) + else: + result_path = self._result_file() + with open(result_path, "w") as fobj: + fobj.write(plot_string) + + logger.info( + "file://{}".format( + os.path.join(self.repo.root_dir, result_path) + ) + ) + + except DvcException: + logger.exception("") + return 1 + + return 0 + + +class CmdPlotShow(CmdPlot): + def _revisions(self): + return None + + +class CmdPlotDiff(CmdPlot): + def _revisions(self): + revisions = self.args.revisions or [] + if len(revisions) <= 1: + if len(revisions) == 0 and self.repo.scm.is_dirty(): + revisions.append("HEAD") + revisions.append(WORKSPACE_REVISION_NAME) + return revisions + + +def add_parser(subparsers, parent_parser): + PLOT_HELP = ( + "Generate plots for continuous metrics stored in structured files " + "(JSON, CSV, TSV)." + ) + + plot_parser = subparsers.add_parser( + "plot", + parents=[parent_parser], + description=append_doc_link(PLOT_HELP, "plot"), + help=PLOT_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + plot_subparsers = plot_parser.add_subparsers( + dest="cmd", + help="Use `dvc plot CMD --help` to display command-specific help.", + ) + + fix_subparsers(plot_subparsers) + + SHOW_HELP = "Generate a plot image file from a continuous metrics file."
+ plot_show_parser = plot_subparsers.add_parser( + "show", + parents=[parent_parser], + description=append_doc_link(SHOW_HELP, "plot/show"), + help=SHOW_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + plot_show_parser.add_argument( + "-t", + "--template", + nargs="?", + default=None, + help="File to be injected with data.", + ) + plot_show_parser.add_argument( + "-f", "--file", default=None, help="Name of the generated file." + ) + plot_show_parser.add_argument( + "-s", + "--select", + default=None, + help="Choose which field(s) or JSONPath to include in the plot.", + ) + plot_show_parser.add_argument( + "-x", default=None, help="Field name for x axis." + ) + plot_show_parser.add_argument( + "-y", default=None, help="Field name for y axis." + ) + plot_show_parser.add_argument( + "--stdout", + action="store_true", + default=False, + help="Print plot specification to stdout.", + ) + plot_show_parser.add_argument( + "--no-csv-header", + action="store_true", + default=False, + help="Required when CSV or TSV datafile does not have a header.", + ) + plot_show_parser.add_argument( + "--no-html", + action="store_true", + default=False, + help="Do not wrap Vega plot JSON with HTML.", + ) + plot_show_parser.add_argument("--title", default=None, help="Plot title.") + plot_show_parser.add_argument("--xlab", default=None, help="X axis title.") + plot_show_parser.add_argument("--ylab", default=None, help="Y axis title.") + plot_show_parser.add_argument( + "datafile", + nargs="?", + default=None, + help="Continuous metrics file to visualize.", + ) + plot_show_parser.set_defaults(func=CmdPlotShow) + + PLOT_DIFF_HELP = ( + "Plot continuous metrics differences between commits in the DVC " + "repository, or between the last commit and the workspace." + ) + plot_diff_parser = plot_subparsers.add_parser( + "diff", + parents=[parent_parser], + description=append_doc_link(PLOT_DIFF_HELP, "plot/diff"), + help=PLOT_DIFF_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + plot_diff_parser.add_argument( + "-t", + "--template", + nargs="?", + default=None, + help="File to be injected with data.", + ) + plot_diff_parser.add_argument( + "-d", + "--datafile", + nargs="?", + default=None, + help="Continuous metrics file to visualize.", + ) + plot_diff_parser.add_argument( + "-f", "--file", default=None, help="Name of the generated file." + ) + plot_diff_parser.add_argument( + "-s", + "--select", + default=None, + help="Choose which field(s) or JSONPath to include in the plot.", + ) + plot_diff_parser.add_argument( + "-x", default=None, help="Field name for x axis." + ) + plot_diff_parser.add_argument( + "-y", default=None, help="Field name for y axis." 
+ ) + plot_diff_parser.add_argument( + "--stdout", + action="store_true", + default=False, + help="Print plot specification to stdout.", + ) + plot_diff_parser.add_argument( + "--no-csv-header", + action="store_true", + default=False, + help="Required when CSV or TSV datafile does not have a header.", + ) + plot_diff_parser.add_argument( + "--no-html", + action="store_true", + default=False, + help="Do not wrap Vega plot JSON with HTML.", + ) + plot_diff_parser.add_argument("--title", default=None, help="Plot title.") + plot_diff_parser.add_argument("--xlab", default=None, help="X axis title.") + plot_diff_parser.add_argument("--ylab", default=None, help="Y axis title.") + plot_diff_parser.add_argument( + "revisions", + nargs="*", + default=None, + help="Git revisions to plot from.", + ) + plot_diff_parser.set_defaults(func=CmdPlotDiff) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 33431fa4f0..b9c38f575c 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -61,6 +61,7 @@ class Repo(object): from dvc.repo.get import get from dvc.repo.get_url import get_url from dvc.repo.update import update + from dvc.repo.plot import plot def __init__(self, root_dir=None): from dvc.state import State @@ -426,6 +427,12 @@ def stages(self): """ return self._collect_stages() + @cached_property + def plot_templates(self): + from dvc.repo.plot.template import PlotTemplates + + return PlotTemplates(self.dvc_dir) + def _collect_stages(self): from dvc.dvcfile import Dvcfile, is_valid_filename diff --git a/dvc/repo/init.py b/dvc/repo/init.py index 3238bb8e94..dda339bbd7 100644 --- a/dvc/repo/init.py +++ b/dvc/repo/init.py @@ -102,7 +102,7 @@ def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False): proj = Repo(root_dir) - scm.add([config.files["repo"]]) + scm.add([config.files["repo"], proj.plot_templates.templates_dir]) if scm.ignore_file: scm.add([os.path.join(dvc_dir, scm.ignore_file)]) diff --git a/dvc/repo/plot/__init__.py b/dvc/repo/plot/__init__.py new file mode 100644 index 0000000000..92a1376e31 --- /dev/null +++ b/dvc/repo/plot/__init__.py @@ -0,0 +1,164 @@ +import logging +import os + +from funcy import first, last + +from dvc.exceptions import DvcException +from dvc.repo.plot.data import PlotData +from dvc.repo.plot.template import Template, NoDataForTemplateError +from dvc.repo import locked + +logger = logging.getLogger(__name__) + +PAGE_HTML = """<!DOCTYPE html> +<html> +<head> + <title>DVC Plot</title> + <script src="https://cdn.jsdelivr.net/npm/vega@5"></script> + <script src="https://cdn.jsdelivr.net/npm/vega-lite@4"></script> + <script src="https://cdn.jsdelivr.net/npm/vega-embed@6"></script> +</head> +<body> + {divs} +</body> +</html>""" + +DIV_HTML = """<div id="{id}"></div> +<script type="text/javascript"> + var spec = {vega_json}; + vegaEmbed('#{id}', spec);
+""" + + +class TooManyDataSourcesError(DvcException): + def __init__(self, datafile, template_datafiles): + super().__init__( + "Unable to infer which of possible data sources: '{}' " + "should be replaced with '{}'.".format( + ", ".join(template_datafiles), datafile + ) + ) + + +class NoDataOrTemplateProvided(DvcException): + def __init__(self): + super().__init__("Datafile or template is not specified.") + + +def _evaluate_templatepath(repo, template=None): + if not template: + return repo.plot_templates.default_template + + if os.path.exists(template): + return template + return repo.plot_templates.get_template(template) + + +@locked +def fill_template( + repo, + datafile, + template_path, + revisions, + fields=None, + path=None, + csv_header=True, + x_field=None, + y_field=None, + **kwargs +): + if x_field and fields: + fields.add(x_field) + + if y_field and fields: + fields.add(y_field) + + template_datafiles, x_anchor, y_anchor = _parse_template( + template_path, datafile + ) + append_index = x_anchor and not x_field + if append_index: + x_field = PlotData.INDEX_FIELD + + template_data = {} + for template_datafile in template_datafiles: + from dvc.repo.plot.data import _load_from_revisions + + plot_datas = _load_from_revisions(repo, template_datafile, revisions) + tmp_data = [] + for pd in plot_datas: + rev_data_points = pd.to_datapoints( + fields=fields, + path=path, + csv_header=csv_header, + append_index=append_index, + ) + + if y_anchor and not y_field: + y_field = _infer_y_field(rev_data_points, x_field) + tmp_data.extend(rev_data_points) + + template_data[template_datafile] = tmp_data + + if len(template_data) == 0: + raise NoDataForTemplateError(template_path) + + return Template.fill( + template_path, + template_data, + priority_datafile=datafile, + x_field=x_field, + y_field=y_field, + **kwargs + ) + + +def _infer_y_field(rev_data_points, x_field): + all_fields = list(first(rev_data_points).keys()) + all_fields.remove(PlotData.REVISION_FIELD) + if x_field and x_field in all_fields: + all_fields.remove(x_field) + y_field = last(all_fields) + return y_field + + +def plot( + repo, datafile=None, template=None, revisions=None, embed=False, **kwargs +): + if revisions is None: + from dvc.repo.plot.data import WORKSPACE_REVISION_NAME + + revisions = [WORKSPACE_REVISION_NAME] + + if not datafile and not template: + raise NoDataOrTemplateProvided() + + template_path = _evaluate_templatepath(repo, template) + + plot_content = fill_template( + repo, datafile, template_path, revisions, **kwargs + ) + + if embed: + div = DIV_HTML.format(id="plot", vega_json=plot_content) + plot_content = PAGE_HTML.format(divs=div) + + return plot_content + + +def _parse_template(template_path, priority_datafile): + with open(template_path, "r") as fobj: + tempalte_content = fobj.read() + + template_datafiles = Template.parse_data_anchors(tempalte_content) + if priority_datafile: + if len(template_datafiles) > 1: + raise TooManyDataSourcesError( + priority_datafile, template_datafiles + ) + template_datafiles = {priority_datafile} + + return ( + template_datafiles, + Template.X_ANCHOR in tempalte_content, + Template.Y_ANCHOR in tempalte_content, + ) diff --git a/dvc/repo/plot/data.py b/dvc/repo/plot/data.py new file mode 100644 index 0000000000..31b53b2692 --- /dev/null +++ b/dvc/repo/plot/data.py @@ -0,0 +1,292 @@ +import csv +import io +import json +import logging +import os +from collections import OrderedDict +from copy import copy + +import yaml +from funcy import first +from yaml import SafeLoader 
+ +from dvc.exceptions import DvcException, PathMissingError + + +logger = logging.getLogger(__name__) + +WORKSPACE_REVISION_NAME = "workspace" + + +class PlotMetricTypeError(DvcException): + def __init__(self, file): + super().__init__( + "'{}' - file type error\n" + "Only JSON, YAML, CSV and TSV formats are supported.".format(file) + ) + + +class PlotDataStructureError(DvcException): + def __init__(self): + super().__init__( + "Plot data extraction failed. Please see " + "documentation for supported data formats." + ) + + +class JsonParsingError(DvcException): + def __init__(self, file): + super().__init__( + "Failed to infer data structure from '{}'. Did you forget " + "to specify JSONpath?".format(file) + ) + + +class NoMetricOnRevisionError(DvcException): + def __init__(self, path, revision): + self.path = path + self.revision = revision + super().__init__( + "Could not find '{}' on revision '{}'".format(path, revision) + ) + + +class NoMetricInHistoryError(DvcException): + def __init__(self, path): + super().__init__("Could not find '{}'.".format(path)) + + +def plot_data(filename, revision, content): + _, extension = os.path.splitext(filename.lower()) + if extension == ".json": + return JSONPlotData(filename, revision, content) + elif extension == ".csv": + return CSVPlotData(filename, revision, content) + elif extension == ".tsv": + return CSVPlotData(filename, revision, content, delimiter="\t") + elif extension == ".yaml": + return YAMLPlotData(filename, revision, content) + raise PlotMetricTypeError(filename) + + +def _filter_fields(data_points, filename, revision, fields=None, **kwargs): + if not fields: + return data_points + assert isinstance(fields, set) + + new_data = [] + for data_point in data_points: + new_dp = copy(data_point) + + keys = set(data_point.keys()) + if keys & fields != fields: + raise DvcException( + "Could not find fields: '{}' for '{}' at '{}'.".format( + ", " "".join(fields), filename, revision + ) + ) + + to_del = keys - fields + for key in to_del: + del new_dp[key] + new_data.append(new_dp) + return new_data + + +def _apply_path(data, path=None, **kwargs): + if not path or not isinstance(data, dict): + return data + + import jsonpath_ng + + found = jsonpath_ng.parse(path).find(data) + first_datum = first(found) + if ( + len(found) == 1 + and isinstance(first_datum.value, list) + and isinstance(first(first_datum.value), dict) + ): + data_points = first_datum.value + elif len(first_datum.path.fields) == 1: + field_name = first(first_datum.path.fields) + data_points = [{field_name: datum.value} for datum in found] + else: + raise PlotDataStructureError() + + if not isinstance(data_points, list) or not ( + isinstance(first(data_points), dict) + ): + raise PlotDataStructureError() + + return data_points + + +def _lists(dictionary): + for _, value in dictionary.items(): + if isinstance(value, dict): + yield from (_lists(value)) + elif isinstance(value, list): + yield value + + +def _find_data(data, fields=None, **kwargs): + if not isinstance(data, dict): + return data + + if not fields: + # just look for first list of dicts + fields = set() + + for l in _lists(data): + if all(isinstance(dp, dict) for dp in l): + if set(first(l).keys()) & fields == fields: + return l + raise PlotDataStructureError() + + +def _append_index(data_points, append_index=False, **kwargs): + if not append_index: + return data_points + + if PlotData.INDEX_FIELD in first(data_points).keys(): + raise DvcException( + "Cannot append index. Field of same name ('{}') found in data. 
" + "Use `-x` to specify x axis field.".format(PlotData.INDEX_FIELD) + ) + + for index, data_point in enumerate(data_points): + data_point[PlotData.INDEX_FIELD] = index + return data_points + + +def _append_revision(data_points, revision, **kwargs): + for data_point in data_points: + data_point[PlotData.REVISION_FIELD] = revision + return data_points + + +class PlotData: + REVISION_FIELD = "rev" + INDEX_FIELD = "index" + + def __init__(self, filename, revision, content, **kwargs): + self.filename = filename + self.revision = revision + self.content = content + + def raw(self, **kwargs): + raise NotImplementedError + + def _processors(self): + return [_filter_fields, _append_index, _append_revision] + + def to_datapoints(self, **kwargs): + data = self.raw(**kwargs) + + for data_proc in self._processors(): + data = data_proc( + data, filename=self.filename, revision=self.revision, **kwargs + ) + return data + + +class JSONPlotData(PlotData): + def raw(self, **kwargs): + return json.loads(self.content, object_pairs_hook=OrderedDict) + + def _processors(self): + parent_processors = super(JSONPlotData, self)._processors() + return [_apply_path, _find_data] + parent_processors + + +class CSVPlotData(PlotData): + def __init__(self, filename, revision, content, delimiter=","): + super(CSVPlotData, self).__init__(filename, revision, content) + self.delimiter = delimiter + + def raw(self, csv_header=True, **kwargs): + first_row = first(csv.reader(io.StringIO(self.content))) + + if csv_header: + reader = csv.DictReader( + io.StringIO(self.content), delimiter=self.delimiter, + ) + else: + reader = csv.DictReader( + io.StringIO(self.content), + delimiter=self.delimiter, + fieldnames=[str(i) for i in range(len(first_row))], + ) + + fieldnames = reader.fieldnames + data = [row for row in reader] + + return [ + OrderedDict([(field, data_point[field]) for field in fieldnames]) + for data_point in data + ] + + +class YAMLPlotData(PlotData): + def raw(self, **kwargs): + class OrderedLoader(SafeLoader): + pass + + def construct_mapping(loader, node): + loader.flatten_mapping(node) + return OrderedDict(loader.construct_pairs(node)) + + OrderedLoader.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping + ) + + return yaml.load(self.content, OrderedLoader) + + +def _load_from_revision(repo, datafile, revision): + if revision is WORKSPACE_REVISION_NAME: + + def open_datafile(): + return repo.tree.open(datafile, "r") + + else: + + def open_datafile(): + from dvc import api + + return api.open(datafile, repo.root_dir, revision) + + try: + with open_datafile() as fobj: + datafile_content = fobj.read() + + except (FileNotFoundError, PathMissingError): + raise NoMetricOnRevisionError(datafile, revision) + + return plot_data(datafile, revision, datafile_content) + + +def _load_from_revisions(repo, datafile, revisions): + data = [] + exceptions = [] + + for rev in revisions: + try: + data.append(_load_from_revision(repo, datafile, rev)) + except NoMetricOnRevisionError as e: + exceptions.append(e) + except PlotMetricTypeError: + raise + except (yaml.error.YAMLError, json.decoder.JSONDecodeError, csv.Error): + logger.error("Failed to parse '{}' at '{}'.".format(datafile, rev)) + raise + + if not data and exceptions: + raise NoMetricInHistoryError(datafile) + else: + for e in exceptions: + logger.warning( + "File '{}' was not found at: '{}'. 
It will not be " + "plotted.".format(e.path, e.revision) + ) + return data diff --git a/dvc/repo/plot/template.py b/dvc/repo/plot/template.py new file mode 100644 index 0000000000..2850a6220f --- /dev/null +++ b/dvc/repo/plot/template.py @@ -0,0 +1,276 @@ +import json +import logging +import os +import re + +from funcy import cached_property + +from dvc.exceptions import DvcException +from dvc.utils.fs import makedirs + + +logger = logging.getLogger(__name__) + + +class TemplateNotFoundError(DvcException): + def __init__(self, path): + super().__init__("Template '{}' not found.".format(path)) + + +class NoDataForTemplateError(DvcException): + def __init__(self, template_path): + super().__init__( + "No data provided for '{}'.".format(os.path.relpath(template_path)) + ) + + +class Template: + INDENT = 4 + SEPARATORS = (",", ": ") + EXTENSION = ".json" + METRIC_DATA_ANCHOR = "<DVC_METRIC_DATA>" + X_ANCHOR = "<DVC_METRIC_X>" + Y_ANCHOR = "<DVC_METRIC_Y>" + TITLE_ANCHOR = "<DVC_METRIC_TITLE>" + X_TITLE_ANCHOR = "<DVC_METRIC_X_TITLE>" + Y_TITLE_ANCHOR = "<DVC_METRIC_Y_TITLE>" + + def __init__(self, templates_dir): + self.plot_templates_dir = templates_dir + + def dump(self): + makedirs(self.plot_templates_dir, exist_ok=True) + + with open( + os.path.join( + self.plot_templates_dir, self.TEMPLATE_NAME + self.EXTENSION + ), + "w", + ) as fobj: + json.dump( + self.DEFAULT_CONTENT, + fobj, + indent=self.INDENT, + separators=self.SEPARATORS, + ) + + @staticmethod + def get_data_anchor(template_content): + regex = re.compile('"<DVC_METRIC_DATA[^>"]*>"') + return regex.findall(template_content) + + @staticmethod + def parse_data_anchors(template_content): + data_files = { + Template.get_datafile(m) + for m in Template.get_data_anchor(template_content) + } + return {df for df in data_files if df} + + @staticmethod + def get_datafile(anchor_string): + return ( + anchor_string.replace("<", "") + .replace(">", "") + .replace('"', "") + .replace("DVC_METRIC_DATA", "") + .replace(",", "") + ) + + @staticmethod + def fill( + template_path, + data, + priority_datafile=None, + x_field=None, + y_field=None, + title=None, + x_title=None, + y_title=None, + ): + with open(template_path, "r") as fobj: + result_content = fobj.read() + + result_content = Template._replace_data_anchors( + result_content, data, priority_datafile + ) + + result_content = Template._replace_metadata_anchors( + result_content, title, x_field, x_title, y_field, y_title + ) + + return result_content + + @staticmethod + def _replace_metadata_anchors( + result_content, title, x_field, x_title, y_field, y_title + ): + if Template.TITLE_ANCHOR in result_content: + if title: + result_content = result_content.replace( + Template.TITLE_ANCHOR, title + ) + else: + result_content = result_content.replace( + Template.TITLE_ANCHOR, "" + ) + if Template.X_ANCHOR in result_content and x_field: + result_content = result_content.replace(Template.X_ANCHOR, x_field) + if Template.Y_ANCHOR in result_content and y_field: + result_content = result_content.replace(Template.Y_ANCHOR, y_field) + if Template.X_TITLE_ANCHOR in result_content: + if not x_title and x_field: + x_title = x_field + result_content = result_content.replace( + Template.X_TITLE_ANCHOR, x_title + ) + if Template.Y_TITLE_ANCHOR in result_content: + if not y_title and y_field: + y_title = y_field + result_content = result_content.replace( + Template.Y_TITLE_ANCHOR, y_title + ) + return result_content + + @staticmethod + def _replace_data_anchors(result_content, data, priority_datafile): + for anchor in Template.get_data_anchor(result_content): + file = Template.get_datafile(anchor) + + if not file or
priority_datafile: + key = priority_datafile + else: + key = file + + result_content = result_content.replace( + anchor, + json.dumps( + data[key], + indent=Template.INDENT, + separators=Template.SEPARATORS, + sort_keys=True, + ), + ) + return result_content + + +class DefaultLinearTemplate(Template): + TEMPLATE_NAME = "default" + + DEFAULT_CONTENT = { + "$schema": "https://vega.github.io/schema/vega-lite/v4.json", + "data": {"values": Template.METRIC_DATA_ANCHOR}, + "title": Template.TITLE_ANCHOR, + "mark": {"type": "line"}, + "encoding": { + "x": { + "field": Template.X_ANCHOR, + "type": "quantitative", + "title": Template.X_TITLE_ANCHOR, + }, + "y": { + "field": Template.Y_ANCHOR, + "type": "quantitative", + "title": Template.Y_TITLE_ANCHOR, + }, + "color": {"field": "rev", "type": "nominal"}, + }, + } + + +class DefaultConfusionTemplate(Template): + TEMPLATE_NAME = "confusion" + DEFAULT_CONTENT = { + "$schema": "https://vega.github.io/schema/vega-lite/v4.json", + "data": {"values": Template.METRIC_DATA_ANCHOR}, + "title": Template.TITLE_ANCHOR, + "mark": "rect", + "encoding": { + "x": { + "field": Template.X_ANCHOR, + "type": "nominal", + "sort": "ascending", + "title": Template.X_TITLE_ANCHOR, + }, + "y": { + "field": Template.Y_ANCHOR, + "type": "nominal", + "sort": "ascending", + "title": Template.Y_TITLE_ANCHOR, + }, + "color": {"aggregate": "count", "type": "quantitative"}, + "facet": {"field": "rev", "type": "nominal"}, + }, + } + + +class DefaultScatterTemplate(Template): + TEMPLATE_NAME = "scatter" + DEFAULT_CONTENT = { + "$schema": "https://vega.github.io/schema/vega-lite/v4.json", + "data": {"values": Template.METRIC_DATA_ANCHOR}, + "title": Template.TITLE_ANCHOR, + "mark": "point", + "encoding": { + "x": { + "field": Template.X_ANCHOR, + "type": "quantitative", + "title": Template.X_TITLE_ANCHOR, + }, + "y": { + "field": Template.Y_ANCHOR, + "type": "quantitative", + "title": Template.Y_TITLE_ANCHOR, + }, + "color": {"field": "rev", "type": "nominal"}, + }, + } + + +class PlotTemplates: + TEMPLATES_DIR = "plot" + TEMPLATES = [ + DefaultLinearTemplate, + DefaultConfusionTemplate, + DefaultScatterTemplate, + ] + + @cached_property + def templates_dir(self): + return os.path.join(self.dvc_dir, self.TEMPLATES_DIR) + + @cached_property + def default_template(self): + default_plot_path = os.path.join(self.templates_dir, "default.json") + if not os.path.exists(default_plot_path): + raise TemplateNotFoundError(os.path.relpath(default_plot_path)) + return default_plot_path + + def get_template(self, path): + t_path = os.path.join(self.templates_dir, path) + if os.path.exists(t_path): + return t_path + + all_templates = [ + os.path.join(root, file) + for root, _, files in os.walk(self.templates_dir) + for file in files + ] + matches = [ + template + for template in all_templates + if os.path.splitext(template)[0] == t_path + ] + if matches: + assert len(matches) == 1 + return matches[0] + + raise TemplateNotFoundError(path) + + def __init__(self, dvc_dir): + self.dvc_dir = dvc_dir + + if not os.path.exists(self.templates_dir): + makedirs(self.templates_dir, exist_ok=True) + for t in self.TEMPLATES: + t(self.templates_dir).dump() diff --git a/scripts/completion/dvc.bash b/scripts/completion/dvc.bash index cde8130e62..78f966acde 100644 --- a/scripts/completion/dvc.bash +++ b/scripts/completion/dvc.bash @@ -5,7 +5,7 @@ # - https://stackoverflow.com/questions/12933362 _dvc_commands='add cache checkout commit config destroy diff fetch get-url get gc \ - import-url import init install lock 
list metrics move pipeline pull push \ + import-url import init install lock list metrics move pipeline plot pull push \ remote remove repro root run status unlock unprotect update version' _dvc_options='-h --help -V --version' @@ -51,6 +51,9 @@ _dvc_pipeline='list show' _dvc_pipeline_list='' _dvc_pipeline_show='-c --commands -o --outs --ascii --dot --tree -l --locked' _dvc_pipeline_show_COMPGEN=_dvc_compgen_DVCFiles +_dvc_plot='show diff' +_dvc_plot_show='-t --template -f --file -s --select -x -y --stdout --no-csv-header --no-html --title --xlab --ylab' +_dvc_plot_diff='-t --template -d --datafile -f --file -s --select -x -y --stdout --no-csv-header --no-html --title --xlab --ylab' _dvc_pull='-j --jobs -r --remote -a --all-branches -T --all-tags -f --force -d --with-deps -R --recursive' _dvc_pull_COMPGEN=_dvc_compgen_DVCFiles _dvc_push='-j --jobs -r --remote -a --all-branches -T --all-tags -d --with-deps -R --recursive' diff --git a/setup.py b/setup.py index a77f5b767d..e27b30843e 100644 --- a/setup.py +++ b/setup.py @@ -129,6 +129,7 @@ def run(self): "mock-ssh-server>=0.6.0", "moto==1.3.14.dev464", "rangehttpserver==1.2.0", + "beautifulsoup4==4.4.0", ] if (sys.version_info) >= (3, 6): diff --git a/tests/func/test_plot.py b/tests/func/test_plot.py new file mode 100644 index 0000000000..6be11320f8 --- /dev/null +++ b/tests/func/test_plot.py @@ -0,0 +1,504 @@ +import csv +import json +import logging +import shutil +from collections import OrderedDict + +import pytest +import yaml +from bs4 import BeautifulSoup +from funcy import first + +from dvc.compat import fspath +from dvc.repo.plot.data import ( + NoMetricInHistoryError, + PlotMetricTypeError, + PlotData, +) +from dvc.repo.plot.template import ( + TemplateNotFoundError, + NoDataForTemplateError, +) +from dvc.repo.plot import NoDataOrTemplateProvided + + +def _remove_whitespace(value): + return value.replace(" ", "").replace("\n", "") + + +def _run_with_metric(tmp_dir, metric_filename, commit=None, tag=None): + tmp_dir.dvc.run(metrics_no_cache=[metric_filename]) + if hasattr(tmp_dir.dvc, "scm"): + tmp_dir.dvc.scm.add([metric_filename, metric_filename + ".dvc"]) + if commit: + tmp_dir.dvc.scm.commit(commit) + if tag: + tmp_dir.dvc.scm.tag(tag) + + +def _write_csv(metric, filename, header=True): + with open(filename, "w", newline="") as csvobj: + if header: + writer = csv.DictWriter( + csvobj, fieldnames=list(first(metric).keys()) + ) + writer.writeheader() + writer.writerows(metric) + else: + writer = csv.writer(csvobj) + for d in metric: + assert len(d) == 1 + writer.writerow(list(d.values())) + + +def _write_json(tmp_dir, metric, filename): + tmp_dir.gen(filename, json.dumps(metric, sort_keys=True)) + + +def test_plot_csv_one_column(tmp_dir, scm, dvc): + # no header + metric = [{"val": 2}, {"val": 3}] + _write_csv(metric, "metric.csv", header=False) + _run_with_metric(tmp_dir, metric_filename="metric.csv") + + plot_string = dvc.plot( + "metric.csv", + csv_header=False, + x_title="x_title", + y_title="y_title", + title="mytitle", + ) + + plot_content = json.loads(plot_string) + assert plot_content["title"] == "mytitle" + assert plot_content["data"]["values"] == [ + {"0": "2", PlotData.INDEX_FIELD: 0, "rev": "workspace"}, + {"0": "3", PlotData.INDEX_FIELD: 1, "rev": "workspace"}, + ] + assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD + assert plot_content["encoding"]["y"]["field"] == "0" + assert plot_content["encoding"]["x"]["title"] == "x_title" + assert plot_content["encoding"]["y"]["title"] == "y_title" + + +def 
test_plot_csv_multiple_columns(tmp_dir, scm, dvc): + metric = [ + OrderedDict([("first_val", 100), ("second_val", 100), ("val", 2)]), + OrderedDict([("first_val", 200), ("second_val", 300), ("val", 3)]), + ] + _write_csv(metric, "metric.csv") + _run_with_metric(tmp_dir, metric_filename="metric.csv") + + plot_string = dvc.plot("metric.csv") + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + { + "val": "2", + PlotData.INDEX_FIELD: 0, + "rev": "workspace", + "first_val": "100", + "second_val": "100", + }, + { + "val": "3", + PlotData.INDEX_FIELD: 1, + "rev": "workspace", + "first_val": "200", + "second_val": "300", + }, + ] + assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD + assert plot_content["encoding"]["y"]["field"] == "val" + + +def test_plot_csv_choose_axes(tmp_dir, scm, dvc): + metric = [ + OrderedDict([("first_val", 100), ("second_val", 100), ("val", 2)]), + OrderedDict([("first_val", 200), ("second_val", 300), ("val", 3)]), + ] + _write_csv(metric, "metric.csv") + _run_with_metric(tmp_dir, metric_filename="metric.csv") + + plot_string = dvc.plot( + "metric.csv", x_field="first_val", y_field="second_val" + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + { + "val": "2", + "rev": "workspace", + "first_val": "100", + "second_val": "100", + }, + { + "val": "3", + "rev": "workspace", + "first_val": "200", + "second_val": "300", + }, + ] + assert plot_content["encoding"]["x"]["field"] == "first_val" + assert plot_content["encoding"]["y"]["field"] == "second_val" + + +def test_plot_json_single_val(tmp_dir, scm, dvc): + metric = [{"val": 2}, {"val": 3}] + _write_json(tmp_dir, metric, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "first run") + + plot_string = dvc.plot("metric.json") + + plot_json = json.loads(plot_string) + assert plot_json["data"]["values"] == [ + {"val": 2, PlotData.INDEX_FIELD: 0, "rev": "workspace"}, + {"val": 3, PlotData.INDEX_FIELD: 1, "rev": "workspace"}, + ] + assert plot_json["encoding"]["x"]["field"] == PlotData.INDEX_FIELD + assert plot_json["encoding"]["y"]["field"] == "val" + + +def test_plot_json_multiple_val(tmp_dir, scm, dvc): + metric = [ + {"first_val": 100, "val": 2}, + {"first_val": 200, "val": 3}, + ] + _write_json(tmp_dir, metric, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "first run") + + plot_string = dvc.plot("metric.json") + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + { + "val": 2, + PlotData.INDEX_FIELD: 0, + "first_val": 100, + "rev": "workspace", + }, + { + "val": 3, + PlotData.INDEX_FIELD: 1, + "first_val": 200, + "rev": "workspace", + }, + ] + assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD + assert plot_content["encoding"]["y"]["field"] == "val" + + +def test_plot_confusion(tmp_dir, dvc): + confusion_matrix = [ + {"predicted": "B", "actual": "A"}, + {"predicted": "A", "actual": "A"}, + ] + _write_json(tmp_dir, confusion_matrix, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "first run") + + plot_string = dvc.plot( + datafile="metric.json", + template="confusion", + x_field="predicted", + y_field="actual", + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"predicted": "B", "actual": "A", "rev": "workspace"}, + {"predicted": "A", "actual": "A", "rev": "workspace"}, + ] + assert plot_content["encoding"]["x"]["field"] == "predicted" + assert plot_content["encoding"]["y"]["field"] == "actual" + + 
+def test_plot_multiple_revs_default(tmp_dir, scm, dvc): + metric_1 = [{"y": 2}, {"y": 3}] + _write_json(tmp_dir, metric_1, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "init", "v1") + + metric_2 = [{"y": 3}, {"y": 5}] + _write_json(tmp_dir, metric_2, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "second", "v2") + + metric_3 = [{"y": 5}, {"y": 6}] + _write_json(tmp_dir, metric_3, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "third") + + plot_string = dvc.plot( + "metric.json", fields={"y"}, revisions=["HEAD", "v2", "v1"], + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"y": 5, PlotData.INDEX_FIELD: 0, "rev": "HEAD"}, + {"y": 6, PlotData.INDEX_FIELD: 1, "rev": "HEAD"}, + {"y": 3, PlotData.INDEX_FIELD: 0, "rev": "v2"}, + {"y": 5, PlotData.INDEX_FIELD: 1, "rev": "v2"}, + {"y": 2, PlotData.INDEX_FIELD: 0, "rev": "v1"}, + {"y": 3, PlotData.INDEX_FIELD: 1, "rev": "v1"}, + ] + assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD + assert plot_content["encoding"]["y"]["field"] == "y" + + +def test_plot_multiple_revs(tmp_dir, scm, dvc): + shutil.copy( + fspath(tmp_dir / ".dvc" / "plot" / "default.json"), "template.json" + ) + + metric_1 = [{"y": 2}, {"y": 3}] + _write_json(tmp_dir, metric_1, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "init", "v1") + + metric_2 = [{"y": 3}, {"y": 5}] + _write_json(tmp_dir, metric_2, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "second", "v2") + + metric_3 = [{"y": 5}, {"y": 6}] + _write_json(tmp_dir, metric_3, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "third") + + plot_string = dvc.plot( + "metric.json", + template="template.json", + revisions=["HEAD", "v2", "v1"], + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"y": 5, PlotData.INDEX_FIELD: 0, "rev": "HEAD"}, + {"y": 6, PlotData.INDEX_FIELD: 1, "rev": "HEAD"}, + {"y": 3, PlotData.INDEX_FIELD: 0, "rev": "v2"}, + {"y": 5, PlotData.INDEX_FIELD: 1, "rev": "v2"}, + {"y": 2, PlotData.INDEX_FIELD: 0, "rev": "v1"}, + {"y": 3, PlotData.INDEX_FIELD: 1, "rev": "v1"}, + ] + assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD + assert plot_content["encoding"]["y"]["field"] == "y" + + +def test_plot_even_if_metric_missing(tmp_dir, scm, dvc, caplog): + tmp_dir.scm_gen("some_file", "content", commit="there is no metric") + scm.tag("v1") + + metric = [{"y": 2}, {"y": 3}] + _write_json(tmp_dir, metric, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "there is metric", "v2") + + caplog.clear() + with caplog.at_level(logging.WARNING, "dvc"): + plot_string = dvc.plot("metric.json", revisions=["v1", "v2"]) + assert ( + "File 'metric.json' was not found at: 'v1'. " + "It will not be plotted." 
in caplog.text + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"y": 2, PlotData.INDEX_FIELD: 0, "rev": "v2"}, + {"y": 3, PlotData.INDEX_FIELD: 1, "rev": "v2"}, + ] + assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD + assert plot_content["encoding"]["y"]["field"] == "y" + + +def test_throw_on_no_metric_at_all(tmp_dir, scm, dvc, caplog): + tmp_dir.scm_gen("some_file", "content", commit="there is no metric") + scm.tag("v1") + + tmp_dir.gen("some_file", "make repo dirty") + + caplog.clear() + with pytest.raises(NoMetricInHistoryError) as error, caplog.at_level( + logging.WARNING, "dvc" + ): + dvc.plot("metric.json", revisions=["v1"]) + + # do not warn if none found + assert len(caplog.messages) == 0 + + assert str(error.value) == "Could not find 'metric.json'." + + +@pytest.fixture() +def custom_template(tmp_dir, dvc): + custom_template = tmp_dir / "custom_template.json" + shutil.copy( + fspath(tmp_dir / ".dvc" / "plot" / "default.json"), + fspath(custom_template), + ) + return custom_template + + +def test_custom_template(tmp_dir, scm, dvc, custom_template): + metric = [{"a": 1, "b": 2}, {"a": 2, "b": 3}] + _write_json(tmp_dir, metric, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "init", "v1") + + plot_string = dvc.plot( + "metric.json", fspath(custom_template), x_field="a", y_field="b" + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"a": 1, "b": 2, "rev": "workspace"}, + {"a": 2, "b": 3, "rev": "workspace"}, + ] + assert plot_content["encoding"]["x"]["field"] == "a" + assert plot_content["encoding"]["y"]["field"] == "b" + + +def _replace(path, src, dst): + path.write_text(path.read_text().replace(src, dst)) + + +def test_custom_template_with_specified_data( + tmp_dir, scm, dvc, custom_template +): + _replace( + custom_template, "DVC_METRIC_DATA", "DVC_METRIC_DATA,metric.json", + ) + + metric = [{"a": 1, "b": 2}, {"a": 2, "b": 3}] + _write_json(tmp_dir, metric, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "init", "v1") + + plot_string = dvc.plot( + datafile=None, + template=fspath(custom_template), + x_field="a", + y_field="b", + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"a": 1, "b": 2, "rev": "workspace"}, + {"a": 2, "b": 3, "rev": "workspace"}, + ] + assert plot_content["encoding"]["x"]["field"] == "a" + assert plot_content["encoding"]["y"]["field"] == "b" + + +def test_plot_override_specified_data_source(tmp_dir, scm, dvc): + shutil.copy( + fspath(tmp_dir / ".dvc" / "plot" / "default.json"), + fspath(tmp_dir / "newtemplate.json"), + ) + _replace( + tmp_dir / "newtemplate.json", + "DVC_METRIC_DATA", + "DVC_METRIC_DATA,metric.json", + ) + + metric = [{"a": 1, "b": 2}, {"a": 2, "b": 3}] + _write_json(tmp_dir, metric, "metric2.json") + _run_with_metric(tmp_dir, "metric2.json", "init", "v1") + + plot_string = dvc.plot( + datafile="metric2.json", template="newtemplate.json", x_field="a" + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"a": 1, "b": 2, "rev": "workspace"}, + {"a": 2, "b": 3, "rev": "workspace"}, + ] + assert plot_content["encoding"]["x"]["field"] == "a" + assert plot_content["encoding"]["y"]["field"] == "b" + + +def test_should_raise_on_no_template_and_datafile(tmp_dir, dvc): + with pytest.raises(NoDataOrTemplateProvided): + dvc.plot() + + +def test_should_raise_on_no_template(tmp_dir, dvc): + with pytest.raises(TemplateNotFoundError): + 
dvc.plot("metric.json", "non_existing_template.json") + + +def test_plot_no_data(tmp_dir, dvc): + with pytest.raises(NoDataForTemplateError): + dvc.plot(template="default") + + +def test_plot_wrong_metric_type(tmp_dir, scm, dvc): + tmp_dir.scm_gen("metric.txt", "content", commit="initial") + with pytest.raises(PlotMetricTypeError): + dvc.plot(datafile="metric.txt") + + +def test_plot_choose_columns(tmp_dir, scm, dvc, custom_template): + metric = [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 3, "c": 4}] + _write_json(tmp_dir, metric, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "init", "v1") + + plot_string = dvc.plot( + "metric.json", + fspath(custom_template), + fields={"b", "c"}, + x_field="b", + y_field="c", + ) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"b": 2, "c": 3, "rev": "workspace"}, + {"b": 3, "c": 4, "rev": "workspace"}, + ] + assert plot_content["encoding"]["x"]["field"] == "b" + assert plot_content["encoding"]["y"]["field"] == "c" + + +def test_plot_default_choose_column(tmp_dir, scm, dvc): + metric = [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 3, "c": 4}] + _write_json(tmp_dir, metric, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "init", "v1") + + plot_string = dvc.plot("metric.json", fields={"b"}) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {PlotData.INDEX_FIELD: 0, "b": 2, "rev": "workspace"}, + {PlotData.INDEX_FIELD: 1, "b": 3, "rev": "workspace"}, + ] + assert plot_content["encoding"]["x"]["field"] == PlotData.INDEX_FIELD + assert plot_content["encoding"]["y"]["field"] == "b" + + +def test_plot_embed(tmp_dir, scm, dvc): + metric = [{"val": 2}, {"val": 3}] + _write_json(tmp_dir, metric, "metric.json") + _run_with_metric(tmp_dir, "metric.json", "first run") + + plot_string = dvc.plot("metric.json", embed=True, y_field="val") + + page_content = BeautifulSoup(plot_string) + data_dump = json.dumps( + [ + {"val": 2, PlotData.INDEX_FIELD: 0, "rev": "workspace"}, + {"val": 3, PlotData.INDEX_FIELD: 1, "rev": "workspace"}, + ], + sort_keys=True, + ) + + assert _remove_whitespace(data_dump) in _remove_whitespace( + first(page_content.body.script.contents) + ) + + +def test_plot_yaml(tmp_dir, scm, dvc): + metric = [{"val": 2}, {"val": 3}] + with open("metric.yaml", "w") as fobj: + yaml.dump(metric, fobj) + + _run_with_metric(tmp_dir, metric_filename="metric.yaml") + + plot_string = dvc.plot("metric.yaml",) + + plot_content = json.loads(plot_string) + assert plot_content["data"]["values"] == [ + {"val": 2, PlotData.INDEX_FIELD: 0, "rev": "workspace"}, + {"val": 3, PlotData.INDEX_FIELD: 1, "rev": "workspace"}, + ] diff --git a/tests/unit/command/test_plot.py b/tests/unit/command/test_plot.py new file mode 100644 index 0000000000..c0a3b00bc9 --- /dev/null +++ b/tests/unit/command/test_plot.py @@ -0,0 +1,122 @@ +import pytest + +from dvc.cli import parse_args +from dvc.command.plot import CmdPlotShow, CmdPlotDiff + + +def test_metrics_diff(mocker): + cli_args = parse_args( + [ + "plot", + "diff", + "--file", + "result.extension", + "-t", + "template", + "-d", + "datafile", + "--select", + "column1,column2", + "--no-html", + "--stdout", + "-x", + "x_field", + "-y", + "y_field", + "--title", + "my_title", + "--xlab", + "x_title", + "--ylab", + "y_title", + "HEAD", + "tag1", + "tag2", + ] + ) + assert cli_args.func == CmdPlotDiff + + cmd = cli_args.func(cli_args) + + m = mocker.patch.object(cmd.repo, "plot", autospec=True) + mocker.patch("builtins.open") + 
mocker.patch("os.path.join") + + assert cmd.run() == 0 + + m.assert_called_once_with( + datafile="datafile", + template="template", + revisions=["HEAD", "tag1", "tag2"], + fields={"column1", "column2"}, + path=None, + embed=False, + x_field="x_field", + y_field="y_field", + csv_header=True, + title="my_title", + x_title="x_title", + y_title="y_title", + ) + + +def test_metrics_show(mocker): + cli_args = parse_args( + [ + "plot", + "show", + "-f", + "result.extension", + "-t", + "template", + "-s", + "$.data", + "--no-html", + "--stdout", + "--no-csv-header", + "datafile", + ] + ) + assert cli_args.func == CmdPlotShow + + cmd = cli_args.func(cli_args) + + m = mocker.patch.object(cmd.repo, "plot", autospec=True) + mocker.patch("builtins.open") + mocker.patch("os.path.join") + + assert cmd.run() == 0 + + m.assert_called_once_with( + datafile="datafile", + template="template", + revisions=None, + fields=None, + path="$.data", + embed=False, + x_field=None, + y_field=None, + csv_header=False, + title=None, + x_title=None, + y_title=None, + ) + + +@pytest.mark.parametrize( + "arg_revisions,is_dirty,expected_revisions", + [ + ([], False, ["workspace"]), + ([], True, ["HEAD", "workspace"]), + (["v1", "v2", "workspace"], False, ["v1", "v2", "workspace"]), + (["v1", "v2", "workspace"], True, ["v1", "v2", "workspace"]), + ], +) +def test_revisions(mocker, arg_revisions, is_dirty, expected_revisions): + args = mocker.MagicMock() + + cmd = CmdPlotDiff(args) + mocker.patch.object(args, "revisions", arg_revisions) + mocker.patch.object(cmd.repo.scm, "is_dirty", return_value=is_dirty) + + assert cmd._revisions() == expected_revisions diff --git a/tests/unit/test_plot.py b/tests/unit/test_plot.py new file mode 100644 index 0000000000..fcf2ce30b1 --- /dev/null +++ b/tests/unit/test_plot.py @@ -0,0 +1,46 @@ +import pytest + +from dvc.repo.plot.data import _apply_path, _lists, _find_data + + +@pytest.mark.parametrize( + "path,expected_result", + [ + ("$.some.path[*].a", [{"a": 1}, {"a": 4}]), + ("$.some.path", [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]), + ], +) +def test_parse_json(path, expected_result): + value = { + "some": {"path": [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]} + } + + result = _apply_path(value, path=path) + + assert result == expected_result + + +@pytest.mark.parametrize( + "dictionary, expected_result", + [ + ({}, []), + ({"x": ["a", "b", "c"]}, [["a", "b", "c"]]), + ( + {"x": {"y": ["a", "b"]}, "z": {"w": ["c", "d"]}}, + [["a", "b"], ["c", "d"]], + ), + ], +) +def test_finding_lists(dictionary, expected_result): + result = _lists(dictionary) + + assert list(result) == expected_result + + +@pytest.mark.parametrize("fields", [{"x"}, set()]) +def test_finding_data(fields): + data = {"a": {"b": [{"x": 2, "y": 3}, {"x": 1, "y": 5}]}} + + result = _find_data(data, fields=fields) + + assert result == [{"x": 2, "y": 3}, {"x": 1, "y": 5}]