Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from funcy import cached_property

from dvc.dependency.param import MissingParamsError
from dvc.env import DVCLIVE_RESUME
from dvc.exceptions import DvcException
from dvc.ui import ui
Expand Down Expand Up @@ -283,7 +282,7 @@ def logs(
def _stash_exp(
self,
*args,
params: Optional[dict] = None,
params: Optional[Dict[str, List[str]]] = None,
resume_rev: Optional[str] = None,
baseline_rev: Optional[str] = None,
branch: Optional[str] = None,
Expand All @@ -292,10 +291,9 @@ def _stash_exp(
) -> QueueEntry:
"""Stash changes from the workspace as an experiment.

Arguments:
params: Optional dictionary of parameter values to be used.
Values take priority over any parameters specified in the
user's workspace.
Args:
params: Dict mapping paths to `Hydra Override`_ patterns,
provided via `exp run --set-param`.
resume_rev: Optional checkpoint resume rev.
baseline_rev: Optional baseline rev for this experiment, defaults
to the current SCM rev.
Expand All @@ -305,6 +303,9 @@ def _stash_exp(
name: Optional experiment name. If specified this will be used as
the human-readable name in the experiment branch ref. Has no
effect of branch is specified.

.. _Hydra Override:
https://hydra.cc/docs/next/advanced/override_grammar/basic/
"""
with self.scm.detach_head(client="dvc") as orig_head:
stash_head = orig_head
Expand Down Expand Up @@ -508,22 +509,22 @@ def _format_new_params_msg(new_params, config_path):
f"from '{config_path}': {param_list}"
)

def _update_params(self, params: dict):
"""Update experiment params files with the specified values."""
from dvc.utils.collections import NewParamsFound, merge_params
from dvc.utils.serialize import MODIFIERS
def _update_params(self, params: Dict[str, List[str]]):
"""Update param files with the provided `Hydra Override`_ patterns.

Args:
params: Dict mapping paths to `Hydra Override`_ patterns,
provided via `exp run --set-param`.

.. _Hydra Override:
https://hydra.cc/docs/next/advanced/override_grammar/basic/
"""
logger.debug("Using experiment params '%s'", params)

for path in params:
suffix = self.repo.fs.path.suffix(path).lower()
modify_data = MODIFIERS[suffix]
with modify_data(path, fs=self.repo.fs) as data:
try:
merge_params(data, params[path], allow_new=False)
except NewParamsFound as e:
msg = self._format_new_params_msg(e.new_params, path)
raise MissingParamsError(msg)
from dvc.utils.hydra import apply_overrides

for path, overrides in params.items():
apply_overrides(path, overrides)

# Force params file changes to be staged in git
# Otherwise in certain situations the changes to params file may be
Expand Down
4 changes: 2 additions & 2 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dvc.repo import locked
from dvc.ui import ui
from dvc.utils.cli_parse import loads_param_overrides
from dvc.utils.cli_parse import to_path_overrides

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -31,7 +31,7 @@ def run(
return repo.experiments.reproduce_celery(entries, jobs=jobs)

if params:
params = loads_param_overrides(params)
params = to_path_overrides(params)

if queue:
if not kwargs.get("checkpoint_resume", None):
Expand Down
26 changes: 0 additions & 26 deletions dvc/utils/_benedict.py

This file was deleted.

39 changes: 13 additions & 26 deletions dvc/utils/cli_parse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Dict, Iterable, List
from typing import Dict, Iterable, List


def parse_params(path_params: Iterable[str]) -> List[Dict[str, List[str]]]:
Expand All @@ -17,35 +17,22 @@ def parse_params(path_params: Iterable[str]) -> List[Dict[str, List[str]]]:
return [{path: params} for path, params in ret.items()]


def loads_param_overrides(
def to_path_overrides(
path_params: Iterable[str],
) -> Dict[str, Dict[str, Any]]:
"""Loads the content of params from the cli as Python object."""
from ruamel.yaml import YAMLError

) -> Dict[str, List[str]]:
"""Group overrides by path"""
from dvc.dependency.param import ParamsDependency
from dvc.exceptions import InvalidArgumentError

from .serialize import loads_yaml

ret: Dict[str, Dict[str, Any]] = defaultdict(dict)

path_overrides = defaultdict(list)
for path_param in path_params:
param_name, _, param_value = path_param.partition("=")
if not param_value:
raise InvalidArgumentError(
f"Must provide a value for parameter '{param_name}'"
)
path, _, param_name = param_name.partition(":")
if not param_name:
param_name = path

path_and_name = path_param.partition("=")[0]
if ":" not in path_and_name:
override = path_param
path = ParamsDependency.DEFAULT_PARAMS_FILE
else:
path, _, override = path_param.partition(":")

try:
ret[path][param_name] = loads_yaml(param_value)
except (ValueError, YAMLError):
raise InvalidArgumentError(
f"Invalid parameter value for '{param_name}': '{param_value}"
)
path_overrides[path].append(override)

return ret
return dict(path_overrides)
82 changes: 47 additions & 35 deletions dvc/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@
from functools import wraps
from typing import Callable, Dict, Iterable, List, TypeVar, Union

from dvc.exceptions import DvcException


class NewParamsFound(DvcException):
"""Thrown if new params were found during merge_params"""

def __init__(self, new_params: List, *args):
self.new_params = new_params
super().__init__("New params found during merge", *args)


def apply_diff(src, dest):
"""Recursively apply changes from src to dest.
Expand Down Expand Up @@ -61,6 +51,53 @@ def is_same_type(a, b):
)


def to_omegaconf(item):
Copy link
Contributor Author

@daavoo daavoo Jul 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these 4 functions had to be added because we try to preserve comments while modifying the params file.

I don't know if this comment-preserving logic was a user request or how relevant it actually is.
For example, YAML spec says:

Comments are a presentation detail and must not have any effect on the serialization tree or representation graph. 
In particular, comments are not associated with a particular node.

If we discard the idea of trying to preserve comments we could remove quite a lot of code

"""
Some parsers return custom classes (i.e. parse_yaml_for_update)
that can mess up with omegaconf logic.
Cast the custom classes to Python primitives.
"""
if isinstance(item, dict):
item = {k: to_omegaconf(v) for k, v in item.items()}
elif isinstance(item, list):
item = [to_omegaconf(x) for x in item]
return item


def remove_missing_keys(src, to_update):
keys = list(src.keys())
for key in keys:
if key not in to_update:
del src[key]
elif isinstance(src[key], dict):
remove_missing_keys(src[key], to_update[key])

return src


def _merge_item(d, key, value):
if key in d:
item = d.get(key, None)
if isinstance(item, dict) and isinstance(value, dict):
merge_dicts(item, value)
else:
d[key] = value
else:
d[key] = value


def merge_dicts(src: Dict, to_update: Dict) -> Dict:
"""Recursively merges dictionaries.

Args:
src (dict): source dictionary of parameters
to_update (dict): dictionary of parameters to merge into src
"""
for key, value in to_update.items():
_merge_item(src, key, value)
return src


def ensure_list(item: Union[Iterable[str], str, None]) -> List[str]:
if item is None:
return []
Expand All @@ -79,31 +116,6 @@ def chunk_dict(d: Dict[_KT, _VT], size: int = 1) -> List[Dict[_KT, _VT]]:
return [{key: d[key] for key in chunk} for chunk in chunks(size, d)]


def merge_params(src: Dict, to_update: Dict, allow_new: bool = True) -> Dict:
"""
Recursively merges params with benedict's syntax support in-place.

Args:
src (dict): source dictionary of parameters
to_update (dict): dictionary of parameters to merge into src
allow_new (bool): if False, raises an error if new keys would be
added to src
"""
from ._benedict import benedict

data = benedict(src)

if not allow_new:
new_params = list(
set(to_update.keys()) - set(data.keypaths(indexes=True))
)
if new_params:
raise NewParamsFound(new_params)

data.merge(to_update, overwrite=True)
return src


class _NamespacedDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
53 changes: 53 additions & 0 deletions dvc/utils/hydra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from pathlib import Path
from typing import TYPE_CHECKING, List

from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.errors import ConfigCompositionException, OverrideParseException
from hydra.types import RunMode
from omegaconf import OmegaConf

from dvc.exceptions import InvalidArgumentError

from .collections import merge_dicts, remove_missing_keys, to_omegaconf
from .serialize import MODIFIERS

if TYPE_CHECKING:
from dvc.types import StrPath


def apply_overrides(path: "StrPath", overrides: List[str]) -> None:
"""Update `path` params with the provided `Hydra Override`_ patterns.

Args:
overrides: List of `Hydra Override`_ patterns.

.. _Hydra Override:
https://hydra.cc/docs/next/advanced/override_grammar/basic/
"""
suffix = Path(path).suffix.lower()

hydra_errors = (ConfigCompositionException, OverrideParseException)

modify_data = MODIFIERS[suffix]
with modify_data(path) as original_data:
try:
parser = OverridesParser.create()
parsed = parser.parse_overrides(overrides=overrides)
ConfigLoaderImpl.validate_sweep_overrides_legal(
parsed, run_mode=RunMode.RUN, from_shell=True
)

new_data = OmegaConf.create(
to_omegaconf(original_data),
flags={"allow_objects": True},
)
OmegaConf.set_struct(new_data, True)
# pylint: disable=protected-access
ConfigLoaderImpl._apply_overrides_to_config(parsed, new_data)
new_data = OmegaConf.to_object(new_data)
except hydra_errors as e:
raise InvalidArgumentError("Invalid `--set-param` value") from e

merge_dicts(original_data, new_data)
remove_missing_keys(original_data, new_data)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ install_requires =
dpath>=2.0.2,<3
shtab>=1.3.4,<2
rich>=10.13.0
python-benedict>=0.24.2
pyparsing>=2.4.7
typing-extensions>=3.7.4
fsspec[http]>=2021.10.1
Expand All @@ -70,6 +69,7 @@ install_requires =
dvc-task==0.1.2
dvclive>=0.10.0
dvc-data==0.1.13
hydra-core>=1.1.0

[options.extras_require]
all =
Expand Down
14 changes: 14 additions & 0 deletions tests/func/experiments/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,17 @@ def http_auth_patch(mocker):
@pytest.fixture(params=[True, False])
def workspace(request, test_queue) -> bool: # noqa
return request.param


@pytest.fixture
def params_repo(tmp_dir, scm, dvc):
(tmp_dir / "params.yaml").dump(
{"foo": [{"bar": 1}, {"baz": 2}], "goo": {"bag": 3.0}, "lorem": False}
)
dvc.run(
cmd="echo foo",
params=["params.yaml:"],
name="foo",
)
scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml"])
scm.commit("init")
Loading