From 325d6b3ef09c8205c443e23c58d6c4d9171fad3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Mon, 10 Aug 2020 23:15:50 +0545 Subject: [PATCH 1/2] temp checkpoint --- dvc/utils/yaml.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/dvc/utils/yaml.py b/dvc/utils/yaml.py index 70fb46f31a..f4dfdd593b 100644 --- a/dvc/utils/yaml.py +++ b/dvc/utils/yaml.py @@ -1,7 +1,9 @@ from collections import OrderedDict from ruamel.yaml import YAML +from ruamel.yaml.emitter import Emitter from ruamel.yaml.error import YAMLError +from ruamel.yaml.events import DocumentStartEvent from dvc.exceptions import YAMLFileCorruptedError @@ -41,12 +43,39 @@ def parse_yaml_for_update(text, path): raise YAMLFileCorruptedError(path) from exc +class YAMLEmitterNoVersionDirective(Emitter): + MARKER_START_LINE = "---" + + def write_version_directive(self, version_text): + """Do not write version directive at all.""" + + # pylint: disable=signature-differs + def write_indicator(self, indicator, *args, **kwargs): + if isinstance(self.event, DocumentStartEvent): + # TODO: need more tests, how reliable is this check? + skip_marker = ( + not self.event.explicit + and not self.canonical + and not self.event.tags + ) + # FIXME: if there is a marker for "% YAML 1.1", it might + # get removed + if skip_marker and indicator == self.MARKER_START_LINE: + # skip adding marker line + return + super().write_indicator(indicator, *args, **kwargs) + + def dump_yaml(path, data): with open(path, "w", encoding="utf-8") as fd: + yaml = YAML() + # dump by default in v1.1 + yaml.version = (1, 1) yaml.default_flow_style = False - # tell Dumper to represent OrderedDict as - # normal dict + # skip printing directive, and also skip marker line for document start + yaml.Emitter = YAMLEmitterNoVersionDirective + # tell Dumper to represent OrderedDict as a normal dict yaml.Representer.add_representer( OrderedDict, yaml.Representer.represent_dict ) From 265d18d63e6c48e2cc04de91cdbfc1d81390f3cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 11 Aug 2020 19:30:45 +0545 Subject: [PATCH 2/2] dump_yaml: dump as YAML 1.1 --- dvc/utils/yaml.py | 108 ++++++++++++++++++++++++---------- tests/unit/utils/test_yaml.py | 74 +++++++++++++++++++++++ 2 files changed, 150 insertions(+), 32 deletions(-) create mode 100644 tests/unit/utils/test_yaml.py diff --git a/dvc/utils/yaml.py b/dvc/utils/yaml.py index f4dfdd593b..2d81a31504 100644 --- a/dvc/utils/yaml.py +++ b/dvc/utils/yaml.py @@ -1,5 +1,6 @@ from collections import OrderedDict +from funcy import reraise from ruamel.yaml import YAML from ruamel.yaml.emitter import Emitter from ruamel.yaml.error import YAMLError @@ -13,21 +14,38 @@ from yaml import SafeLoader -def load_yaml(path): - with open(path, encoding="utf-8") as fd: - return parse_yaml(fd.read(), path) +class YAMLVersion: + V11 = (1, 1) + V12 = (1, 2) -def parse_yaml(text, path): - try: - import yaml +def _parse_yaml_v1_1(text, path): + import yaml + with reraise(yaml.error.YAMLError, YAMLFileCorruptedError(path)): return yaml.load(text, Loader=SafeLoader) or {} - except yaml.error.YAMLError as exc: - raise YAMLFileCorruptedError(path) from exc -def parse_yaml_for_update(text, path): +def _parse_yaml_v1_2(text, path): + yaml = YAML(typ="safe") + yaml.version = YAMLVersion.V12 + with reraise(YAMLError, YAMLFileCorruptedError(path)): + return yaml.load(text) or {} + + +def parse_yaml(text, path, *, version=None): + parser = _parse_yaml_v1_1 + if version == YAMLVersion.V12: + parser = _parse_yaml_v1_2 + return parser(text, path) + + +def load_yaml(path, *, version=None): + with open(path, encoding="utf-8") as fd: + return parse_yaml(fd.read(), path, version=version) + + +def parse_yaml_for_update(text, path, *, version=YAMLVersion.V11): """Parses text into Python structure. Unlike `parse_yaml()` this returns ordered dicts, values have special @@ -36,47 +54,73 @@ def parse_yaml_for_update(text, path): This one is, however, several times slower than simple `parse_yaml()`. """ - try: - yaml = YAML() + yaml = YAML() + yaml.version = version + with reraise(YAMLError, YAMLFileCorruptedError(path)): return yaml.load(text) or {} - except YAMLError as exc: - raise YAMLFileCorruptedError(path) from exc -class YAMLEmitterNoVersionDirective(Emitter): +class _YAMLEmitterNoVersionDirective(Emitter): + """ + This emitter skips printing version directive when we set yaml version + on `dump_yaml()`. Also, ruamel.yaml will still try to add a document start + marker line (assuming version directive was written), for which we + need to find a _hack_ to ensure the marker line is not written to the + stream, as our dvcfiles and hopefully, params file are single document + YAML files. + + NOTE: do not use this emitter during load/parse, only when dump for 1.1 + """ + MARKER_START_LINE = "---" def write_version_directive(self, version_text): """Do not write version directive at all.""" + def expect_first_document_start(self): + # as our yaml files are expected to only have a single document, + # this is not needed, just trying to make it a bit resilient, + # but it's not well-thought out. + # pylint: disable=attribute-defined-outside-init + self._first_document = True + ret = super().expect_first_document_start() + self._first_document = False + return ret + # pylint: disable=signature-differs def write_indicator(self, indicator, *args, **kwargs): + # NOTE: if the yaml file already have a directive, + # this will strip it if isinstance(self.event, DocumentStartEvent): - # TODO: need more tests, how reliable is this check? skip_marker = ( - not self.event.explicit + # see comments in _expect_first_document_start() + getattr(self, "_first_document", False) + and not self.event.explicit and not self.canonical and not self.event.tags ) - # FIXME: if there is a marker for "% YAML 1.1", it might - # get removed if skip_marker and indicator == self.MARKER_START_LINE: - # skip adding marker line return super().write_indicator(indicator, *args, **kwargs) -def dump_yaml(path, data): - with open(path, "w", encoding="utf-8") as fd: +def _dump_yaml(data, stream, *, version=None, with_directive=False): + yaml = YAML() + if version in (None, YAMLVersion.V11): + yaml.version = YAMLVersion.V11 + if not with_directive: + yaml.Emitter = _YAMLEmitterNoVersionDirective + elif with_directive and version == YAMLVersion.V12: + # `ruamel.yaml` dumps in 1.2 by default + yaml.version = version + + yaml.default_flow_style = False + yaml.Representer.add_representer( + OrderedDict, yaml.Representer.represent_dict + ) + yaml.dump(data, stream) - yaml = YAML() - # dump by default in v1.1 - yaml.version = (1, 1) - yaml.default_flow_style = False - # skip printing directive, and also skip marker line for document start - yaml.Emitter = YAMLEmitterNoVersionDirective - # tell Dumper to represent OrderedDict as a normal dict - yaml.Representer.add_representer( - OrderedDict, yaml.Representer.represent_dict - ) - yaml.dump(data, fd) + +def dump_yaml(path, data, *, version=None, with_directive=False): + with open(path, "w", encoding="utf-8") as fd: + _dump_yaml(data, fd, version=version, with_directive=with_directive) diff --git a/tests/unit/utils/test_yaml.py b/tests/unit/utils/test_yaml.py new file mode 100644 index 0000000000..d6b974e6dd --- /dev/null +++ b/tests/unit/utils/test_yaml.py @@ -0,0 +1,74 @@ +import pytest + +from dvc.exceptions import YAMLFileCorruptedError +from dvc.utils.yaml import ( + YAMLVersion, + dump_yaml, + parse_yaml, + parse_yaml_for_update, +) + +V12 = YAMLVersion.V12 +V11 = YAMLVersion.V11 + + +def _get_directive(version): + return "%YAML {ver}\n---\n".format( + ver=".".join(str(num) for num in version) + ) + + +@pytest.mark.parametrize("data", [{"x": 3e24}]) +@pytest.mark.parametrize("with_directive", [True, False]) +@pytest.mark.parametrize( + "ver, directive, expected", + [ + # dot before mantissa is not required in yaml1.2, + # whereas it's required in yaml1.1 + (V12, _get_directive(V12), "x: 3e+24\n"), + (V11, _get_directive(V11), "x: 3.0e+24\n"), + ], +) +def test_dump_yaml_with_directive( + tmp_dir, ver, directive, expected, with_directive, data +): + dump_yaml("data.yaml", data, version=ver, with_directive=with_directive) + actual = (tmp_dir / "data.yaml").read_text() + exp = expected if not with_directive else directive + expected + assert actual == exp + + +def test_load_yaml(): + assert parse_yaml("x: 3e24", "data.yaml") == {"x": "3e24"} + assert parse_yaml("x: 3.0e+24", "data.yaml") == {"x": 3e24} + + assert parse_yaml("x: 3e24", "data.yaml", version=V12) == {"x": 3e24} + assert parse_yaml("x: 3.0e+24", "data.yaml", version=V12) == {"x": 3e24} + + with pytest.raises(YAMLFileCorruptedError): + assert parse_yaml("invalid: '", "data.yaml") + + with pytest.raises(YAMLFileCorruptedError): + assert parse_yaml("invalid: '", "data.yaml", version=V12) + + +def test_comments_are_preserved_on_update_and_dump(tmp_dir): + text = "x: 3 # this is a comment" + d = parse_yaml_for_update(text, "data.yaml") + d["w"] = 7e24 + + dump_yaml("data.yaml", d) + assert (tmp_dir / "data.yaml").read_text() == f"{text}\nw: 7.0e+24\n" + + dump_yaml("data.yaml", d, with_directive=True) + assert (tmp_dir / "data.yaml").read_text() == _get_directive( + V11 + ) + f"{text}\nw: 7.0e+24\n" + + dump_yaml("data.yaml", d, version=V12) + assert (tmp_dir / "data.yaml").read_text() == f"{text}\nw: 7e+24\n" + + dump_yaml("data.yaml", d, with_directive=True, version=V12) + assert (tmp_dir / "data.yaml").read_text() == _get_directive( + V12 + ) + f"{text}\nw: 7e+24\n"