Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions dvc/repo/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,15 +472,12 @@ def _collect_pipeline_files(repo, targets: List[str], props, onerror=None):
for dvcfile, plots_def in top_plots.items():
dvcfile_path = _relpath(repo.dvcfs, dvcfile)
dvcfile_defs_dict: Dict[str, Union[Dict, None]] = {}
if isinstance(plots_def, list):
for elem in plots_def:
if isinstance(elem, str):
dvcfile_defs_dict[elem] = None
else:
k, v = list(elem.items())[0]
dvcfile_defs_dict[k] = v
else:
dvcfile_defs_dict = plots_def
for elem in plots_def:
if isinstance(elem, str):
dvcfile_defs_dict[elem] = None
else:
k, v = list(elem.items())[0]
dvcfile_defs_dict[k] = v

resolved = _resolve_definitions(
repo.dvcfs,
Expand Down
2 changes: 1 addition & 1 deletion dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def validator(data):
str: either_or(STAGE_DEFINITION, FOREACH_IN, [FOREACH_KWD, DO_KWD])
}
MULTI_STAGE_SCHEMA = {
PLOTS: Any(SINGLE_PLOT_SCHEMA, [Any(str, SINGLE_PLOT_SCHEMA)]),
PLOTS: [Any(str, SINGLE_PLOT_SCHEMA)],
STAGES: SINGLE_PIPELINE_STAGE_SCHEMA,
VARS_KWD: VARS_SCHEMA,
StageParams.PARAM_PARAMS: [str],
Expand Down
2 changes: 1 addition & 1 deletion tests/func/plots/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def test_top_level_plots(

config_file = "dvc.yaml"
with modify_yaml(config_file) as dvcfile_content:
dvcfile_content["plots"] = plot_config
dvcfile_content["plots"] = [plot_config]

result = dvc.plots.show()

Expand Down
18 changes: 3 additions & 15 deletions tests/func/test_dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,17 +407,8 @@ def test_dvcfile_load_dump_stage_with_desc_meta(tmp_dir, dvc):
assert (tmp_dir / "dvc.yaml").parse() == data


@pytest.mark.parametrize(
"data",
(
{
"plots": {
"path/to/plot": {"x": "value", "y": "value"},
"path/to/another/plot": {"x": "value", "y": "value"},
"path/to/empty/plot": None,
},
"stages": STAGE_EXAMPLE,
},
def test_dvcfile_load_with_plots(tmp_dir, dvc):
(tmp_dir / "dvc.yaml").dump(
{
"plots": [
{"path/to/plot": {"x": "value", "y": "value"}},
Expand All @@ -427,10 +418,7 @@ def test_dvcfile_load_dump_stage_with_desc_meta(tmp_dir, dvc):
],
"stages": STAGE_EXAMPLE,
},
),
)
def test_dvcfile_load_with_plots(tmp_dir, dvc, data):
(tmp_dir / "dvc.yaml").dump(data)
)
plots = list(dvc.plots.collect())
top_level_plots = plots[0]["workspace"]["definitions"]["data"]["dvc.yaml"]["data"]
assert all(
Expand Down
30 changes: 17 additions & 13 deletions tests/integration/plots/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,25 @@ def make():
outs=["confusion_test.json"],
commit="confusion_test",
)
plots_config = {
"linear_train_vs_test": {
"x": "x",
"y": {"linear_train.json": "y", "linear_test.json": "y"},
"title": "linear plot",
plots_config = [
{
"linear_train_vs_test": {
"x": "x",
"y": {"linear_train.json": "y", "linear_test.json": "y"},
"title": "linear plot",
}
},
"confusion_train_vs_test": {
"x": "actual",
"y": {
"confusion_train.json": "predicted",
"confusion_test.json": "predicted",
},
"template": "confusion",
{
"confusion_train_vs_test": {
"x": "actual",
"y": {
"confusion_train.json": "predicted",
"confusion_test.json": "predicted",
},
"template": "confusion",
}
},
}
]

from dvc.utils.serialize import modify_yaml

Expand Down
8 changes: 4 additions & 4 deletions tests/integration/plots/test_repo_plots_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from funcy import merge

from tests.utils.plots import get_plot

Expand Down Expand Up @@ -108,10 +109,9 @@ def test_api_with_config_plots(tmp_dir, dvc, capsys, repo_with_config_plots):

plots_data = next(dvc.plots.collect())

assert (
get_plot(plots_data, "workspace", typ="definitions", file="dvc.yaml")
== plots_state["configs"]["dvc.yaml"]
)
assert get_plot(
plots_data, "workspace", typ="definitions", file="dvc.yaml"
) == merge(*plots_state["configs"]["dvc.yaml"])

for file in plots_state["data"]:
data_source = get_plot(plots_data, "workspace", file=file, endkey="data_source")
Expand Down