From 61613aa0ea744e0aac0024afbba694930a1cb524 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 6 Jul 2022 13:38:08 +0200 Subject: [PATCH] metrics: support TOML files Fixes #6402 --- dvc/repo/metrics/show.py | 6 ++++-- tests/func/metrics/test_show.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/dvc/repo/metrics/show.py b/dvc/repo/metrics/show.py index 166765a1ef..a829413e53 100644 --- a/dvc/repo/metrics/show.py +++ b/dvc/repo/metrics/show.py @@ -12,7 +12,7 @@ from dvc.scm import NoSCMError from dvc.utils import error_handler, errored_revisions, onerror_collect from dvc.utils.collections import ensure_list -from dvc.utils.serialize import load_yaml +from dvc.utils.serialize import LOADERS logger = logging.getLogger(__name__) @@ -71,7 +71,9 @@ def _extract_metrics(metrics, path, rev): @error_handler def _read_metric(path, fs, rev, **kwargs): - val = load_yaml(path, fs=fs) + suffix = fs.path.suffix(path).lower() + loader = LOADERS[suffix] + val = loader(path, fs=fs) val = _extract_metrics(val, path, rev) return val or {} diff --git a/tests/func/metrics/test_show.py b/tests/func/metrics/test_show.py index bcbdc2147d..d7a92739e2 100644 --- a/tests/func/metrics/test_show.py +++ b/tests/func/metrics/test_show.py @@ -8,7 +8,7 @@ from dvc.exceptions import OverlappingOutputPathsError from dvc.repo import Repo from dvc.utils.fs import remove -from dvc.utils.serialize import YAMLFileCorruptedError +from dvc.utils.serialize import JSONFileCorruptedError, YAMLFileCorruptedError def test_show_simple(tmp_dir, dvc, run_copy_metrics): @@ -31,6 +31,16 @@ def test_show(tmp_dir, dvc, run_copy_metrics): } +def test_show_toml(tmp_dir, dvc, run_copy_metrics): + tmp_dir.gen("metrics_t.toml", "[foo]\nbar = 1.2") + run_copy_metrics( + "metrics_t.toml", "metrics.toml", metrics=["metrics.toml"] + ) + assert dvc.metrics.show() == { + "": {"data": {"metrics.toml": {"data": {"foo": {"bar": 1.2}}}}} + } + + def test_show_targets(tmp_dir, dvc, run_copy_metrics): tmp_dir.gen("metrics_t.yaml", "foo: 1.1") run_copy_metrics( @@ -218,7 +228,7 @@ def test_show_malformed_metric(tmp_dir, scm, dvc, caplog): dvc.metrics.show(targets=["metric.json"])[""]["data"]["metric.json"][ "error" ], - YAMLFileCorruptedError, + JSONFileCorruptedError, )