Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import (
Expand Down Expand Up @@ -134,9 +135,19 @@ def make_definition(

class DataResolver:
def __init__(self, repo: "Repo", wdir: str, d: dict):
from dvc.fs import LocalFileSystem

self.fs = fs = repo.fs

if os.path.isabs(wdir):
start = (
os.curdir if isinstance(fs, LocalFileSystem) else repo.root_dir
)
wdir = relpath(wdir, start)
wdir = "" if wdir == os.curdir else wdir

self.wdir = wdir
self.relpath = relpath(fs.path.join(self.wdir, "dvc.yaml"))
self.relpath = os.path.normpath(fs.path.join(self.wdir, "dvc.yaml"))

vars_ = d.get(VARS_KWD, [])
check_interpolations(vars_, VARS_KWD, self.relpath)
Expand Down
36 changes: 17 additions & 19 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
recurse,
str_interpolate,
)
from dvc.utils import relpath

logger = logging.getLogger(__name__)
SeqOrMap = Union[Sequence, Mapping]
Expand Down Expand Up @@ -359,20 +358,19 @@ def load_from(
) -> "Context":
from dvc.utils.serialize import LOADERS

file = relpath(path)
if not fs.exists(path):
raise ParamsLoadError(f"'{file}' does not exist")
raise ParamsLoadError(f"'{path}' does not exist")
if fs.isdir(path):
raise ParamsLoadError(f"'{file}' is a directory")
raise ParamsLoadError(f"'{path}' is a directory")

_, ext = os.path.splitext(file)
_, ext = os.path.splitext(path)
loader = LOADERS[ext]

data = loader(path, fs=fs)
if not isinstance(data, Mapping):
typ = type(data).__name__
raise ParamsLoadError(
f"expected a dictionary, got '{typ}' in file '{file}'"
f"expected a dictionary, got '{typ}' in file '{path}'"
)

if select_keys:
Expand All @@ -381,12 +379,12 @@ def load_from(
except KeyError as exc:
key, *_ = exc.args
raise ParamsLoadError(
f"could not find '{key}' in '{file}'"
f"could not find '{key}' in '{path}'"
) from exc

meta = Meta(source=file, local=False)
meta = Meta(source=path, local=False)
ctx = cls(data, meta=meta)
ctx.imports[os.path.abspath(path)] = select_keys
ctx.imports[path] = select_keys
return ctx

def merge_update(self, other: "Context", overwrite=False):
Expand All @@ -397,26 +395,26 @@ def merge_update(self, other: "Context", overwrite=False):

def merge_from(self, fs, item: str, wdir: str, overwrite=False):
path, _, keys_str = item.partition(":")
select_keys = lfilter(bool, keys_str.split(",")) if keys_str else None
path = os.path.normpath(fs.path.join(wdir, path))

abspath = os.path.abspath(fs.path.join(wdir, path))
if abspath in self.imports:
if not select_keys and self.imports[abspath] is None:
select_keys = lfilter(bool, keys_str.split(",")) if keys_str else None
if path in self.imports:
if not select_keys and self.imports[path] is None:
return # allow specifying complete filepath multiple times
self.check_loaded(abspath, item, select_keys)
self.check_loaded(path, item, select_keys)

ctx = Context.load_from(fs, abspath, select_keys)
ctx = Context.load_from(fs, path, select_keys)

try:
self.merge_update(ctx, overwrite=overwrite)
except ReservedKeyError as exc:
raise ReservedKeyError(exc.keys, item) from exc

cp = ctx.imports[abspath]
if abspath not in self.imports:
self.imports[abspath] = cp
cp = ctx.imports[path]
if path not in self.imports:
self.imports[path] = cp
elif cp:
self.imports[abspath].extend(cp)
self.imports[path].extend(cp)

def check_loaded(self, path, item, keys):
if not keys and isinstance(self.imports[path], list):
Expand Down
17 changes: 16 additions & 1 deletion dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from dvc.scm import NoSCMError
from dvc.stage import PipelineStage
from dvc.ui import ui
from dvc.utils import error_handler, errored_revisions, onerror_collect
from dvc.utils import (
error_handler,
errored_revisions,
onerror_collect,
relpath,
)
from dvc.utils.serialize import LOADERS

if TYPE_CHECKING:
Expand Down Expand Up @@ -85,10 +90,20 @@ def _read_params(


def _collect_vars(repo, params) -> Dict:
from dvc.fs.git import GitFileSystem

vars_params: Dict[str, Dict] = defaultdict(dict)
rel_to_root = relpath(repo.root_dir)

for stage in repo.index.stages:
if isinstance(stage, PipelineStage) and stage.tracked_vars:
for file, vars_ in stage.tracked_vars.items():
if isinstance(repo.fs, GitFileSystem):
# GitFileSystem reports paths relative to the root of the repo.
# Convert them to paths relative to the current working
# directory.
file = os.path.normpath(os.path.join(rel_to_root, file))

# `params` files are shown regardless of whether they are tracked;
# to reduce noise and duplication, they are skipped here
if file in params:
Expand Down
4 changes: 1 addition & 3 deletions tests/func/parsing/test_foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,7 @@ def test_foreach_with_interpolated_wdir_and_local_vars(
}
},
}
assert resolver.context.imports == {
str(tmp_dir / DEFAULT_PARAMS_FILE): None
}
assert resolver.context.imports == {DEFAULT_PARAMS_FILE: None}


def test_foreach_do_syntax_is_checked_once(tmp_dir, dvc, mocker):
Expand Down
4 changes: 2 additions & 2 deletions tests/func/parsing/test_interpolated_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_with_templated_wdir(tmp_dir, dvc):
DEFAULT_PARAMS_FILE: {"dict.bar": "bar", "dict.ws": "data"},
}
}
assert resolver.context.imports == {str(tmp_dir / "params.yaml"): None}
assert resolver.context.imports == {"params.yaml": None}
assert resolver.context == {"dict": {"bar": "bar", "ws": "data"}}


Expand Down Expand Up @@ -236,7 +236,7 @@ def test_vars_relpath_overwrite(tmp_dir, dvc):
}
resolver = DataResolver(dvc, tmp_dir.fs_path, d)
resolver.resolve()
assert resolver.context.imports == {str(tmp_dir / "params.yaml"): None}
assert resolver.context.imports == {"params.yaml": None}


@pytest.mark.parametrize("local", [True, False])
Expand Down
30 changes: 30 additions & 0 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,33 @@ def test_circular_import(tmp_dir, dvc, scm, erepo_dir):
erepo_dir.dvc.imp(
os.fspath(tmp_dir), "dir_imported", "circular_import"
)


@pytest.mark.parametrize("paths", ([], ["dir"]))
def test_parameterized_repo(tmp_dir, dvc, scm, erepo_dir, paths):
    """Import an output from an external repo that has a templated stage.

    The external repo gets a ``dvc.yaml`` whose ``train`` stage uses
    ``${out}`` placeholders, next to a ``params.yaml`` that defines ``out``.
    Both files are written either at the repo root or under ``dir``,
    depending on ``paths``.
    """
    workspace = erepo_dir.joinpath(*paths)
    workspace.mkdir(parents=True, exist_ok=True)

    train_stage = {"cmd": "echo ${out} > ${out}", "outs": ["${out}"]}
    (workspace / "params.yaml").dump({"out": "foo"})
    (workspace / "dvc.yaml").dump({"stages": {"train": train_stage}})
    workspace.gen({"foo": "foo"})

    # The file names below are relative, so commit from inside the
    # directory that holds them.
    committed_files = ["params.yaml", "dvc.yaml", "dvc.lock", ".gitignore"]
    with workspace.chdir():
        erepo_dir.dvc.commit(None, force=True)
        erepo_dir.scm.add_commit(committed_files, message="init")

    source_url = os.fspath(erepo_dir)
    stage = dvc.imp(source_url, os.path.join(*paths, "foo"), "foo_imported")

    imported = tmp_dir / "foo_imported"
    assert imported.read_text() == "foo"
    assert stage.deps[0].def_repo == {
        "url": source_url,
        "rev_lock": erepo_dir.scm.get_rev(),
    }
22 changes: 10 additions & 12 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,13 @@ def test_track(tmp_dir):
"dct": {"foo": "foo", "bar": "bar", "baz": "baz"},
}
fs = LocalFileSystem()
path = tmp_dir / "params.yaml"
path.dump(d, fs=fs)
(tmp_dir / "params.yaml").dump(d, fs=fs)

context = Context.load_from(fs, path)
context = Context.load_from(fs, "params.yaml")

def key_tracked(d, key):
assert len(d) == 1
return key in d[relpath(path)]
return key in d["params.yaml"]

with context.track() as tracked:
context.select("lst")
Expand Down Expand Up @@ -323,10 +322,10 @@ def test_track_from_multiple_files(tmp_dir):
d2 = {"Train": {"us": {"layers": 100}}}

fs = LocalFileSystem()
path1 = tmp_dir / "params.yaml"
path2 = tmp_dir / "params2.yaml"
path1.dump(d1, fs=fs)
path2.dump(d2, fs=fs)
path1 = "params.yaml"
path2 = "params2.yaml"
(tmp_dir / path1).dump(d1, fs=fs)
(tmp_dir / path2).dump(d2, fs=fs)

context = Context.load_from(fs, path1)
c = Context.load_from(fs, path2)
Expand Down Expand Up @@ -428,16 +427,15 @@ def test_resolve_resolves_boolean_value():

def test_load_from_raises_if_file_not_exist(tmp_dir, dvc):
with pytest.raises(ParamsLoadError) as exc_info:
Context.load_from(dvc.fs, tmp_dir / DEFAULT_PARAMS_FILE)
Context.load_from(dvc.fs, DEFAULT_PARAMS_FILE)

assert str(exc_info.value) == "'params.yaml' does not exist"


def test_load_from_raises_if_file_is_directory(tmp_dir, dvc):
data_dir = tmp_dir / "data"
data_dir.mkdir()
(tmp_dir / "data").mkdir()

with pytest.raises(ParamsLoadError) as exc_info:
Context.load_from(dvc.fs, data_dir)
Context.load_from(dvc.fs, "data")

assert str(exc_info.value) == "'data' is a directory"