Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 95 additions & 22 deletions dvc/utils/yaml.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from collections import OrderedDict

from funcy import reraise
from ruamel.yaml import YAML
from ruamel.yaml.emitter import Emitter
from ruamel.yaml.error import YAMLError
from ruamel.yaml.events import DocumentStartEvent

from dvc.exceptions import YAMLFileCorruptedError

Expand All @@ -11,21 +14,38 @@
from yaml import SafeLoader


def load_yaml(path):
with open(path, encoding="utf-8") as fd:
return parse_yaml(fd.read(), path)
class YAMLVersion:
    """Namespace of the YAML spec versions understood by this module.

    Each constant is a ``(major, minor)`` tuple; the helpers in this module
    compare against these values and assign them to ``YAML.version``.
    """

    # YAML 1.1
    V11 = (1, 1)
    # YAML 1.2
    V12 = (1, 2)


def parse_yaml(text, path):
try:
import yaml
def _parse_yaml_v1_1(text, path):
    """Parse ``text`` with PyYAML's ``SafeLoader``.

    Any ``yaml.error.YAMLError`` is re-raised as
    ``YAMLFileCorruptedError(path)``. A falsy parse result (for example an
    empty document) is returned as an empty dict.
    """
    # Imported lazily, inside the function, exactly as the module does.
    import yaml

    corrupted = YAMLFileCorruptedError(path)
    with reraise(yaml.error.YAMLError, corrupted):
        parsed = yaml.load(text, Loader=SafeLoader)
    return parsed or {}
except yaml.error.YAMLError as exc:
raise YAMLFileCorruptedError(path) from exc


def parse_yaml_for_update(text, path):
def _parse_yaml_v1_2(text, path):
    """Parse ``text`` with ruamel.yaml's safe loader pinned to YAML 1.2.

    Any ruamel ``YAMLError`` is re-raised as ``YAMLFileCorruptedError(path)``.
    A falsy parse result (for example an empty document) is returned as an
    empty dict.
    """
    loader = YAML(typ="safe")
    loader.version = YAMLVersion.V12
    corrupted = YAMLFileCorruptedError(path)
    with reraise(YAMLError, corrupted):
        parsed = loader.load(text)
    return parsed or {}


def parse_yaml(text, path, *, version=None):
    """Parse YAML ``text`` and return the resulting Python structure.

    ``path`` is only forwarded to the parser, which uses it when reporting a
    corrupted file. The YAML 1.2 parser is used when ``version`` equals
    ``YAMLVersion.V12``; every other value (including the default ``None``)
    selects the YAML 1.1 parser.
    """
    use_v12 = version == YAMLVersion.V12
    parser = _parse_yaml_v1_2 if use_v12 else _parse_yaml_v1_1
    return parser(text, path)


def load_yaml(path, *, version=None):
    """Read the UTF-8 file at ``path`` and parse it with ``parse_yaml()``.

    ``version`` is passed through to ``parse_yaml()`` unchanged.
    """
    with open(path, encoding="utf-8") as stream:
        contents = stream.read()
    return parse_yaml(contents, path, version=version)


def parse_yaml_for_update(text, path, *, version=YAMLVersion.V11):
"""Parses text into Python structure.

Unlike `parse_yaml()` this returns ordered dicts, values have special
Expand All @@ -34,20 +54,73 @@ def parse_yaml_for_update(text, path):

This one is, however, several times slower than simple `parse_yaml()`.
"""
try:
yaml = YAML()
yaml = YAML()
yaml.version = version
with reraise(YAMLError, YAMLFileCorruptedError(path)):
return yaml.load(text) or {}
except YAMLError as exc:
raise YAMLFileCorruptedError(path) from exc


def dump_yaml(path, data):
class _YAMLEmitterNoVersionDirective(Emitter):
    """
    This emitter skips printing version directive when we set yaml version
    on `dump_yaml()`. Also, ruamel.yaml will still try to add a document start
    marker line (assuming version directive was written), for which we
    need to find a _hack_ to ensure the marker line is not written to the
    stream, as our dvcfiles and hopefully, params file are single document
    YAML files.

    NOTE: do not use this emitter during load/parse, only when dump for 1.1
    """

    MARKER_START_LINE = "---"

    # True only while `expect_first_document_start()` is running; it lets
    # `write_indicator()` tell the first document's marker from later ones.
    # Declared on the class so it always has a value, even before the first
    # document start has been handled.
    _first_document = False

    def write_version_directive(self, version_text):
        """Do not write version directive at all."""

    def expect_first_document_start(self):
        """Handle the first document start with the marker-skip flag raised.

        Returns whatever the parent implementation returns.
        """
        # as our yaml files are expected to only have a single document,
        # this is not needed, just trying to make it a bit resilient,
        # but it's not well-thought out.
        self._first_document = True
        try:
            return super().expect_first_document_start()
        finally:
            # Reset even when the parent raises, so the flag cannot stay
            # stuck at True and suppress markers of later documents.
            self._first_document = False

    # pylint: disable=signature-differs
    def write_indicator(self, indicator, *args, **kwargs):
        """Write ``indicator`` unless it is the implicit first `---` marker.

        The marker is dropped only for the first document, and only when the
        document start is not explicit, canonical output is off and the
        event carries no tags. Everything else is delegated to the parent.
        """
        # NOTE: if the yaml file already have a directive,
        # this will strip it
        skip_marker = (
            indicator == self.MARKER_START_LINE
            and isinstance(self.event, DocumentStartEvent)
            # see comments in expect_first_document_start()
            and self._first_document
            and not self.event.explicit
            and not self.canonical
            and not self.event.tags
        )
        if skip_marker:
            return
        super().write_indicator(indicator, *args, **kwargs)


def _dump_yaml(data, stream, *, version=None, with_directive=False):
    """Serialize ``data`` as YAML into ``stream`` using ruamel.yaml.

    ``version`` of ``None`` or ``YAMLVersion.V11`` dumps as YAML 1.1; unless
    ``with_directive`` is true, the emitter that omits the version directive
    is installed. For ``YAMLVersion.V12`` the version is set only when
    ``with_directive`` is true. ``OrderedDict`` values are represented as
    plain mappings.
    """
    dumper = YAML()

    wants_v11 = version is None or version == YAMLVersion.V11
    if wants_v11:
        dumper.version = YAMLVersion.V11
        if not with_directive:
            dumper.Emitter = _YAMLEmitterNoVersionDirective
    elif with_directive and version == YAMLVersion.V12:
        # `ruamel.yaml` dumps in 1.2 by default
        dumper.version = version

    dumper.default_flow_style = False

    # tell the representer to treat OrderedDict as a normal dict
    representer_cls = dumper.Representer
    representer_cls.add_representer(
        OrderedDict, representer_cls.represent_dict
    )

    dumper.dump(data, stream)


def dump_yaml(path, data, *, version=None, with_directive=False):
with open(path, "w", encoding="utf-8") as fd:
yaml = YAML()
yaml.default_flow_style = False
# tell Dumper to represent OrderedDict as
# normal dict
yaml.Representer.add_representer(
OrderedDict, yaml.Representer.represent_dict
)
yaml.dump(data, fd)
_dump_yaml(data, fd, version=version, with_directive=with_directive)
74 changes: 74 additions & 0 deletions tests/unit/utils/test_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest

from dvc.exceptions import YAMLFileCorruptedError
from dvc.utils.yaml import (
YAMLVersion,
dump_yaml,
parse_yaml,
parse_yaml_for_update,
)

V12 = YAMLVersion.V12
V11 = YAMLVersion.V11


def _get_directive(version):
return "%YAML {ver}\n---\n".format(
ver=".".join(str(num) for num in version)
)


@pytest.mark.parametrize("data", [{"x": 3e24}])
@pytest.mark.parametrize("with_directive", [True, False])
@pytest.mark.parametrize(
    "ver, directive, expected",
    [
        # YAML 1.2 accepts a mantissa without a dot,
        # whereas YAML 1.1 requires one
        (V12, _get_directive(V12), "x: 3e+24\n"),
        (V11, _get_directive(V11), "x: 3.0e+24\n"),
    ],
)
def test_dump_yaml_with_directive(
    tmp_dir, ver, directive, expected, with_directive, data
):
    """Dumped text carries the directive only when `with_directive` is set."""
    dump_yaml("data.yaml", data, version=ver, with_directive=with_directive)

    prefix = directive if with_directive else ""
    written = (tmp_dir / "data.yaml").read_text()
    assert written == prefix + expected


def test_load_yaml():
    """`parse_yaml()` resolves floats per YAML version and wraps errors."""
    path = "data.yaml"

    # Default (1.1) parsing: a mantissa without a dot stays a string.
    assert parse_yaml("x: 3e24", path) == {"x": "3e24"}
    assert parse_yaml("x: 3.0e+24", path) == {"x": 3e24}

    # 1.2 parsing: both spellings are floats.
    for source in ("x: 3e24", "x: 3.0e+24"):
        assert parse_yaml(source, path, version=V12) == {"x": 3e24}

    # Malformed input raises the same error under either parser.
    for options in ({}, {"version": V12}):
        with pytest.raises(YAMLFileCorruptedError):
            parse_yaml("invalid: '", path, **options)


def test_comments_are_preserved_on_update_and_dump(tmp_dir):
    """Comments survive a `parse_yaml_for_update()` / `dump_yaml()` cycle."""
    source = "x: 3 # this is a comment"
    doc = parse_yaml_for_update(source, "data.yaml")
    doc["w"] = 7e24

    out_file = tmp_dir / "data.yaml"

    def dump_and_read(**options):
        # Dump with the given options and return what landed on disk.
        dump_yaml("data.yaml", doc, **options)
        return out_file.read_text()

    body_v11 = f"{source}\nw: 7.0e+24\n"
    body_v12 = f"{source}\nw: 7e+24\n"

    assert dump_and_read() == body_v11
    assert (
        dump_and_read(with_directive=True) == _get_directive(V11) + body_v11
    )
    assert dump_and_read(version=V12) == body_v12
    assert (
        dump_and_read(with_directive=True, version=V12)
        == _get_directive(V12) + body_v12
    )