diff --git a/dvc/repo/plots/__init__.py b/dvc/repo/plots/__init__.py index 281a8e8b28..5973f1198f 100644 --- a/dvc/repo/plots/__init__.py +++ b/dvc/repo/plots/__init__.py @@ -472,15 +472,12 @@ def _collect_pipeline_files(repo, targets: List[str], props, onerror=None): for dvcfile, plots_def in top_plots.items(): dvcfile_path = _relpath(repo.dvcfs, dvcfile) dvcfile_defs_dict: Dict[str, Union[Dict, None]] = {} - if isinstance(plots_def, list): - for elem in plots_def: - if isinstance(elem, str): - dvcfile_defs_dict[elem] = None - else: - k, v = list(elem.items())[0] - dvcfile_defs_dict[k] = v - else: - dvcfile_defs_dict = plots_def + for elem in plots_def: + if isinstance(elem, str): + dvcfile_defs_dict[elem] = None + else: + k, v = list(elem.items())[0] + dvcfile_defs_dict[k] = v resolved = _resolve_definitions( repo.dvcfs, diff --git a/dvc/schema.py b/dvc/schema.py index 5ea24a26cd..42127e51e7 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -121,7 +121,7 @@ def validator(data): str: either_or(STAGE_DEFINITION, FOREACH_IN, [FOREACH_KWD, DO_KWD]) } MULTI_STAGE_SCHEMA = { - PLOTS: Any(SINGLE_PLOT_SCHEMA, [Any(str, SINGLE_PLOT_SCHEMA)]), + PLOTS: [Any(str, SINGLE_PLOT_SCHEMA)], STAGES: SINGLE_PIPELINE_STAGE_SCHEMA, VARS_KWD: VARS_SCHEMA, StageParams.PARAM_PARAMS: [str], diff --git a/tests/func/plots/test_show.py b/tests/func/plots/test_show.py index 0a31c7fcf7..39b71fdd41 100644 --- a/tests/func/plots/test_show.py +++ b/tests/func/plots/test_show.py @@ -377,7 +377,7 @@ def test_top_level_plots( config_file = "dvc.yaml" with modify_yaml(config_file) as dvcfile_content: - dvcfile_content["plots"] = plot_config + dvcfile_content["plots"] = [plot_config] result = dvc.plots.show() diff --git a/tests/func/test_dvcfile.py b/tests/func/test_dvcfile.py index fb8bd9d37e..2ceccd44a1 100644 --- a/tests/func/test_dvcfile.py +++ b/tests/func/test_dvcfile.py @@ -407,17 +407,8 @@ def test_dvcfile_load_dump_stage_with_desc_meta(tmp_dir, dvc): assert (tmp_dir / "dvc.yaml").parse() == data -@pytest.mark.parametrize( - "data", - ( - { - "plots": { - "path/to/plot": {"x": "value", "y": "value"}, - "path/to/another/plot": {"x": "value", "y": "value"}, - "path/to/empty/plot": None, - }, - "stages": STAGE_EXAMPLE, - }, +def test_dvcfile_load_with_plots(tmp_dir, dvc): + (tmp_dir / "dvc.yaml").dump( { "plots": [ {"path/to/plot": {"x": "value", "y": "value"}}, @@ -427,10 +418,7 @@ def test_dvcfile_load_dump_stage_with_desc_meta(tmp_dir, dvc): ], "stages": STAGE_EXAMPLE, }, - ), -) -def test_dvcfile_load_with_plots(tmp_dir, dvc, data): - (tmp_dir / "dvc.yaml").dump(data) + ) plots = list(dvc.plots.collect()) top_level_plots = plots[0]["workspace"]["definitions"]["data"]["dvc.yaml"]["data"] assert all( diff --git a/tests/integration/plots/conftest.py b/tests/integration/plots/conftest.py index 252107eaa1..1e6f36ea78 100644 --- a/tests/integration/plots/conftest.py +++ b/tests/integration/plots/conftest.py @@ -153,21 +153,25 @@ def make(): outs=["confusion_test.json"], commit="confusion_test", ) - plots_config = { - "linear_train_vs_test": { - "x": "x", - "y": {"linear_train.json": "y", "linear_test.json": "y"}, - "title": "linear plot", + plots_config = [ + { + "linear_train_vs_test": { + "x": "x", + "y": {"linear_train.json": "y", "linear_test.json": "y"}, + "title": "linear plot", + } }, - "confusion_train_vs_test": { - "x": "actual", - "y": { - "confusion_train.json": "predicted", - "confusion_test.json": "predicted", - }, - "template": "confusion", + { + "confusion_train_vs_test": { + "x": "actual", + "y": { + "confusion_train.json": "predicted", + "confusion_test.json": "predicted", + }, + "template": "confusion", + } }, - } + ] from dvc.utils.serialize import modify_yaml diff --git a/tests/integration/plots/test_repo_plots_api.py b/tests/integration/plots/test_repo_plots_api.py index 252b4909fc..c3eab5f81e 100644 --- a/tests/integration/plots/test_repo_plots_api.py +++ b/tests/integration/plots/test_repo_plots_api.py @@ -1,4 +1,5 @@ import pytest +from funcy import merge from tests.utils.plots import get_plot @@ -108,10 +109,9 @@ def test_api_with_config_plots(tmp_dir, dvc, capsys, repo_with_config_plots): plots_data = next(dvc.plots.collect()) - assert ( - get_plot(plots_data, "workspace", typ="definitions", file="dvc.yaml") - == plots_state["configs"]["dvc.yaml"] - ) + assert get_plot( + plots_data, "workspace", typ="definitions", file="dvc.yaml" + ) == merge(*plots_state["configs"]["dvc.yaml"]) for file in plots_state["data"]: data_source = get_plot(plots_data, "workspace", file=file, endkey="data_source")