diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index 4c1af7f739..6932a7bb7d 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -19,6 +19,7 @@ from dvc.utils.serialize import ( dump_yaml, load_yaml, + modify_yaml, parse_yaml, parse_yaml_for_update, ) @@ -193,30 +194,27 @@ def _dump_lockfile(self, stage): self._lockfile.dump(stage) def _dump_pipeline_file(self, stage): - data = {} - if self.exists(): - with open(self.path) as fd: - data = parse_yaml_for_update(fd.read(), self.path) - else: - logger.info("Creating '%s'", self.relpath) - open(self.path, "w+").close() - - data["stages"] = data.get("stages", {}) stage_data = serialize.to_pipeline_file(stage) - existing_entry = stage.name in data["stages"] - action = "Modifying" if existing_entry else "Adding" - logger.info("%s stage '%s' in '%s'", action, stage.name, self.relpath) + with modify_yaml(self.path, tree=self.repo.tree) as data: + if not data: + logger.info("Creating '%s'", self.relpath) - if existing_entry: - orig_stage_data = data["stages"][stage.name] - if "meta" in orig_stage_data: - stage_data[stage.name]["meta"] = orig_stage_data["meta"] - apply_diff(stage_data[stage.name], orig_stage_data) - else: - data["stages"].update(stage_data) + data["stages"] = data.get("stages", {}) + existing_entry = stage.name in data["stages"] + action = "Modifying" if existing_entry else "Adding" + logger.info( + "%s stage '%s' in '%s'", action, stage.name, self.relpath + ) + + if existing_entry: + orig_stage_data = data["stages"][stage.name] + if "meta" in orig_stage_data: + stage_data[stage.name]["meta"] = orig_stage_data["meta"] + apply_diff(stage_data[stage.name], orig_stage_data) + else: + data["stages"].update(stage_data) - dump_yaml(self.path, data) self.repo.scm.track_file(self.relpath) @property @@ -281,21 +279,18 @@ def load(self): def dump(self, stage, **kwargs): stage_data = serialize.to_lockfile(stage) - if not self.exists(): - modified = True - logger.info("Generating lock file '%s'", self.relpath) - data = stage_data - open(self.path, "w+").close() - else: - with self.repo.tree.open(self.path, "r") as fd: - data = parse_yaml_for_update(fd.read(), self.path) + + with modify_yaml(self.path, tree=self.repo.tree) as data: + if not data: + logger.info("Generating lock file '%s'", self.relpath) + modified = data.get(stage.name, {}) != stage_data.get( stage.name, {} ) if modified: logger.info("Updating lock file '%s'", self.relpath) data.update(stage_data) - dump_yaml(self.path, data) + if modified: self.repo.scm.track_file(self.relpath) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 0b51d04449..516c7eebd9 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -2,7 +2,6 @@ import os import re import tempfile -from collections import defaultdict from collections.abc import Mapping from concurrent.futures import ( ProcessPoolExecutor, @@ -213,12 +212,7 @@ def _unpack_args(self, tree=None): def _update_params(self, params: dict): """Update experiment params files with the specified values.""" - from dvc.utils.serialize import ( - dump_toml, - dump_yaml, - parse_toml_for_update, - parse_yaml_for_update, - ) + from dvc.utils.serialize import MODIFIERS logger.debug("Using experiment params '%s'", params) @@ -231,19 +225,12 @@ def _update(dict_, other): dict_[key] = value return dict_ - loaders = defaultdict(lambda: parse_yaml_for_update) - loaders.update({".toml": parse_toml_for_update}) - dumpers = defaultdict(lambda: dump_yaml) - dumpers.update({".toml": dump_toml}) - for params_fname in params: path = PathInfo(self.exp_dvc.root_dir) / params_fname - with self.exp_dvc.tree.open(path, "r") as fobj: - text = fobj.read() suffix = path.suffix.lower() - data = loaders[suffix](text, path) - _update(data, params[params_fname]) - dumpers[suffix](path, data) + modify_data = MODIFIERS[suffix] + with modify_data(path, tree=self.exp_dvc.tree) as data: + _update(data, params[params_fname]) def _commit(self, exp_hash, check_exists=True, branch=True): """Commit stages as an experiment and return the commit SHA.""" diff --git a/dvc/scm/git.py b/dvc/scm/git.py index aa41d24832..b777488ac8 100644 --- a/dvc/scm/git.py +++ b/dvc/scm/git.py @@ -19,7 +19,7 @@ ) from dvc.utils import fix_env, is_binary, relpath from dvc.utils.fs import path_isin -from dvc.utils.serialize import dump_yaml, load_yaml +from dvc.utils.serialize import modify_yaml logger = logging.getLogger(__name__) @@ -330,36 +330,32 @@ def install(self, use_pre_commit_tool=False): return config_path = os.path.join(self.root_dir, ".pre-commit-config.yaml") - config = load_yaml(config_path) if os.path.exists(config_path) else {} - - entry = { - "repo": "https://github.com/iterative/dvc", - "rev": "master", - "hooks": [ - { - "id": "dvc-pre-commit", - "language_version": "python3", - "stages": ["commit"], - }, - { - "id": "dvc-pre-push", - "language_version": "python3", - "stages": ["push"], - }, - { - "id": "dvc-post-checkout", - "language_version": "python3", - "stages": ["post-checkout"], - "always_run": True, - }, - ], - } - - if entry in config["repos"]: - return - - config["repos"].append(entry) - dump_yaml(config_path, config) + with modify_yaml(config_path) as config: + entry = { + "repo": "https://github.com/iterative/dvc", + "rev": "master", + "hooks": [ + { + "id": "dvc-pre-commit", + "language_version": "python3", + "stages": ["commit"], + }, + { + "id": "dvc-pre-push", + "language_version": "python3", + "stages": ["push"], + }, + { + "id": "dvc-post-checkout", + "language_version": "python3", + "stages": ["post-checkout"], + "always_run": True, + }, + ], + } + + if entry not in config["repos"]: + config["repos"].append(entry) def cleanup_ignores(self): for path in self.ignored_paths: diff --git a/dvc/utils/serialize/__init__.py b/dvc/utils/serialize/__init__.py index 33b85159e3..1dd1e92778 100644 --- a/dvc/utils/serialize/__init__.py +++ b/dvc/utils/serialize/__init__.py @@ -6,3 +6,6 @@ LOADERS = defaultdict(lambda: load_yaml) # noqa: F405 LOADERS.update({".toml": load_toml}) # noqa: F405 + +MODIFIERS = defaultdict(lambda: modify_yaml) # noqa: F405 +MODIFIERS.update({".toml": modify_toml}) # noqa: F405 diff --git a/dvc/utils/serialize/_common.py b/dvc/utils/serialize/_common.py index 10dbab3bcf..e01b2c9572 100644 --- a/dvc/utils/serialize/_common.py +++ b/dvc/utils/serialize/_common.py @@ -1,4 +1,6 @@ """Common utilities for serialize.""" +import os +from contextlib import contextmanager from dvc.exceptions import DvcException from dvc.utils import relpath @@ -22,3 +24,11 @@ def _dump_data(path, data, dumper, tree=None): open_fn = tree.open if tree else open with open_fn(path, "w+", encoding="utf-8") as fd: dumper(data, fd) + + +@contextmanager +def _modify_data(path, parser, dumper, tree=None): + exists = tree.exists if tree else os.path.exists + data = _load_data(path, parser=parser, tree=tree) if exists(path) else {} + yield data + dumper(path, data, tree=tree) diff --git a/dvc/utils/serialize/_toml.py b/dvc/utils/serialize/_toml.py index 916c779d45..ce703a8a78 100644 --- a/dvc/utils/serialize/_toml.py +++ b/dvc/utils/serialize/_toml.py @@ -1,7 +1,9 @@ +from contextlib import contextmanager + import toml from funcy import reraise -from ._common import ParseError, _dump_data, _load_data +from ._common import ParseError, _dump_data, _load_data, _modify_data class TOMLFileCorruptedError(ParseError): @@ -35,3 +37,9 @@ def _dump(data, stream): def dump_toml(path, data, tree=None): return _dump_data(path, data, dumper=_dump, tree=tree) + + +@contextmanager +def modify_toml(path, tree=None): + with _modify_data(path, parse_toml_for_update, dump_toml, tree=tree) as d: + yield d diff --git a/dvc/utils/serialize/_yaml.py b/dvc/utils/serialize/_yaml.py index ec04ac30b6..7738530848 100644 --- a/dvc/utils/serialize/_yaml.py +++ b/dvc/utils/serialize/_yaml.py @@ -1,11 +1,12 @@ import io from collections import OrderedDict +from contextlib import contextmanager from funcy import reraise from ruamel.yaml import YAML from ruamel.yaml.error import YAMLError -from ._common import ParseError, _dump_data, _load_data +from ._common import ParseError, _dump_data, _load_data, _modify_data class YAMLFileCorruptedError(ParseError): @@ -60,6 +61,11 @@ def loads_yaml(s, typ="safe"): def dumps_yaml(d): stream = io.StringIO() - yaml = _get_yaml() - yaml.dump(d, stream) + _dump(d, stream) return stream.getvalue() + + +@contextmanager +def modify_yaml(path, tree=None): + with _modify_data(path, parse_yaml_for_update, dump_yaml, tree=tree) as d: + yield d