From 2535cccf7cc08e61977f0fced51355b04e2a1d0e Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Wed, 21 Oct 2020 12:36:44 -0500 Subject: [PATCH 1/3] Allow setting of list parameteres Check if dict is Mapping or list before using .get() --- dvc/repo/experiments/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index e2468704c3..3401a9e3b1 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -319,7 +319,12 @@ def _update_params(self, params: dict): def _update(dict_, other): for key, value in other.items(): if isinstance(value, Mapping): - dict_[key] = _update(dict_.get(key, {}), value) + if isinstance(dict_, Mapping): + fallback_value = dict_.get(key, {}) + elif isinstance(dict_, list) and key.isdigit(): + key = int(key) + fallback_value = dict_[key] + dict_[key] = _update(fallback_value, value) else: dict_[key] = value return dict_ From 05a2f6e47b247b136d3228b8128da7e10ba01b20 Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Thu, 22 Oct 2020 12:05:31 -0500 Subject: [PATCH 2/3] Added test for dvc exp run with list paramater --- tests/func/experiments/test_experiments.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index afe970481e..1a8699d491 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -93,6 +93,27 @@ def test_update_with_pull(tmp_dir, scm, dvc, mocker): assert exp_scm.has_rev(rev) +def test_modify_list_parameter(tmp_dir, scm, dvc, mocker): + tmp_dir.gen("copy.py", COPY_SCRIPT) + tmp_dir.gen("params.yaml", "foo: [bar: 1, baz: 2]") + stage = dvc.run( + cmd="python copy.py params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + name="copy-file", + ) + scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) + scm.commit("init") + + new_mock = mocker.spy(dvc.experiments, "new") + dvc.experiments.run(stage.addressing, params=["foo.1.baz=3"]) + + new_mock.assert_called_once() + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text().strip() == "foo: [bar: 1, baz: 3]" + + def test_checkout(tmp_dir, scm, dvc): tmp_dir.gen("copy.py", COPY_SCRIPT) tmp_dir.gen("params.yaml", "foo: 1") From 033c773c73e0b7f61ca3e7fa89f7b59022e4726c Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Thu, 22 Oct 2020 15:08:19 -0500 Subject: [PATCH 3/3] Add more tests for string parameter modification --- dvc/repo/experiments/__init__.py | 9 +++++---- tests/func/experiments/test_experiments.py | 14 +++++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 3401a9e3b1..55e6c26071 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -318,12 +318,13 @@ def _update_params(self, params: dict): # recursive dict update def _update(dict_, other): for key, value in other.items(): + if isinstance(dict_, list) and key.isdigit(): + key = int(key) if isinstance(value, Mapping): - if isinstance(dict_, Mapping): - fallback_value = dict_.get(key, {}) - elif isinstance(dict_, list) and key.isdigit(): - key = int(key) + if isinstance(dict_, list): fallback_value = dict_[key] + else: + fallback_value = dict_.get(key, {}) dict_[key] = _update(fallback_value, value) else: dict_[key] = value diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 1a8699d491..74c3d2c7b6 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -93,7 +93,15 @@ def test_update_with_pull(tmp_dir, scm, dvc, mocker): assert exp_scm.has_rev(rev) -def test_modify_list_parameter(tmp_dir, scm, dvc, mocker): +@pytest.mark.parametrize( + "change, expected", + [ + ["foo.1.baz=3", "foo: [bar: 1, baz: 3]"], + ["foo.0=bar", "foo: [bar, baz: 2]"], + ["foo.1=- baz\n- goo", "foo: [bar: 1, [baz, goo]]"], + ], +) +def test_modify_list_param(tmp_dir, scm, dvc, mocker, change, expected): tmp_dir.gen("copy.py", COPY_SCRIPT) tmp_dir.gen("params.yaml", "foo: [bar: 1, baz: 2]") stage = dvc.run( @@ -106,12 +114,12 @@ def test_modify_list_parameter(tmp_dir, scm, dvc, mocker): scm.commit("init") new_mock = mocker.spy(dvc.experiments, "new") - dvc.experiments.run(stage.addressing, params=["foo.1.baz=3"]) + dvc.experiments.run(stage.addressing, params=[change]) new_mock.assert_called_once() assert ( tmp_dir / ".dvc" / "experiments" / "metrics.yaml" - ).read_text().strip() == "foo: [bar: 1, baz: 3]" + ).read_text().strip() == expected def test_checkout(tmp_dir, scm, dvc):