diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index ea30dfd785..7efa4616cd 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -1,4 +1,5 @@ import logging +import os from collections.abc import Mapping, Sequence from copy import deepcopy from typing import ( @@ -134,9 +135,19 @@ def make_definition( class DataResolver: def __init__(self, repo: "Repo", wdir: str, d: dict): + from dvc.fs import LocalFileSystem + self.fs = fs = repo.fs + + if os.path.isabs(wdir): + start = ( + os.curdir if isinstance(fs, LocalFileSystem) else repo.root_dir + ) + wdir = relpath(wdir, start) + wdir = "" if wdir == os.curdir else wdir + self.wdir = wdir - self.relpath = relpath(fs.path.join(self.wdir, "dvc.yaml")) + self.relpath = os.path.normpath(fs.path.join(self.wdir, "dvc.yaml")) vars_ = d.get(VARS_KWD, []) check_interpolations(vars_, VARS_KWD, self.relpath) diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index 2e95594c3b..98db8852f5 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -19,7 +19,6 @@ recurse, str_interpolate, ) -from dvc.utils import relpath logger = logging.getLogger(__name__) SeqOrMap = Union[Sequence, Mapping] @@ -359,20 +358,19 @@ def load_from( ) -> "Context": from dvc.utils.serialize import LOADERS - file = relpath(path) if not fs.exists(path): - raise ParamsLoadError(f"'{file}' does not exist") + raise ParamsLoadError(f"'{path}' does not exist") if fs.isdir(path): - raise ParamsLoadError(f"'{file}' is a directory") + raise ParamsLoadError(f"'{path}' is a directory") - _, ext = os.path.splitext(file) + _, ext = os.path.splitext(path) loader = LOADERS[ext] data = loader(path, fs=fs) if not isinstance(data, Mapping): typ = type(data).__name__ raise ParamsLoadError( - f"expected a dictionary, got '{typ}' in file '{file}'" + f"expected a dictionary, got '{typ}' in file '{path}'" ) if select_keys: @@ -381,12 +379,12 @@ def load_from( except KeyError as exc: key, *_ = exc.args raise ParamsLoadError( - f"could not find '{key}' in '{file}'" + f"could not find '{key}' in '{path}'" ) from exc - meta = Meta(source=file, local=False) + meta = Meta(source=path, local=False) ctx = cls(data, meta=meta) - ctx.imports[os.path.abspath(path)] = select_keys + ctx.imports[path] = select_keys return ctx def merge_update(self, other: "Context", overwrite=False): @@ -397,26 +395,26 @@ def merge_update(self, other: "Context", overwrite=False): def merge_from(self, fs, item: str, wdir: str, overwrite=False): path, _, keys_str = item.partition(":") - select_keys = lfilter(bool, keys_str.split(",")) if keys_str else None + path = os.path.normpath(fs.path.join(wdir, path)) - abspath = os.path.abspath(fs.path.join(wdir, path)) - if abspath in self.imports: - if not select_keys and self.imports[abspath] is None: + select_keys = lfilter(bool, keys_str.split(",")) if keys_str else None + if path in self.imports: + if not select_keys and self.imports[path] is None: return # allow specifying complete filepath multiple times - self.check_loaded(abspath, item, select_keys) + self.check_loaded(path, item, select_keys) - ctx = Context.load_from(fs, abspath, select_keys) + ctx = Context.load_from(fs, path, select_keys) try: self.merge_update(ctx, overwrite=overwrite) except ReservedKeyError as exc: raise ReservedKeyError(exc.keys, item) from exc - cp = ctx.imports[abspath] - if abspath not in self.imports: - self.imports[abspath] = cp + cp = ctx.imports[path] + if path not in self.imports: + self.imports[path] = cp elif cp: - self.imports[abspath].extend(cp) + self.imports[path].extend(cp) def check_loaded(self, path, item, keys): if not keys and isinstance(self.imports[path], list): diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 701dec1ce8..15d906a3b4 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -10,7 +10,12 @@ from dvc.scm import NoSCMError from dvc.stage import PipelineStage from dvc.ui import ui -from dvc.utils import error_handler, errored_revisions, onerror_collect +from dvc.utils import ( + error_handler, + errored_revisions, + onerror_collect, + relpath, +) from dvc.utils.serialize import LOADERS if TYPE_CHECKING: @@ -85,10 +90,20 @@ def _read_params( def _collect_vars(repo, params) -> Dict: + from dvc.fs.git import GitFileSystem + vars_params: Dict[str, Dict] = defaultdict(dict) + rel_to_root = relpath(repo.root_dir) + for stage in repo.index.stages: if isinstance(stage, PipelineStage) and stage.tracked_vars: for file, vars_ in stage.tracked_vars.items(): + if isinstance(repo.fs, GitFileSystem): + # GitFileSystem uses relatively-absolute paths from the + # root of the repo. We need to convert them to relative + # paths based on the current working directory. + file = os.path.normpath(os.path.join(rel_to_root, file)) + # `params` file are shown regardless of `tracked` or not # to reduce noise and duplication, they are skipped if file in params: diff --git a/tests/func/parsing/test_foreach.py b/tests/func/parsing/test_foreach.py index cfc10f5465..87ee97f74b 100644 --- a/tests/func/parsing/test_foreach.py +++ b/tests/func/parsing/test_foreach.py @@ -386,9 +386,7 @@ def test_foreach_with_interpolated_wdir_and_local_vars( } }, } - assert resolver.context.imports == { - str(tmp_dir / DEFAULT_PARAMS_FILE): None - } + assert resolver.context.imports == {DEFAULT_PARAMS_FILE: None} def test_foreach_do_syntax_is_checked_once(tmp_dir, dvc, mocker): diff --git a/tests/func/parsing/test_interpolated_entry.py b/tests/func/parsing/test_interpolated_entry.py index 21223624fd..6c9650d221 100644 --- a/tests/func/parsing/test_interpolated_entry.py +++ b/tests/func/parsing/test_interpolated_entry.py @@ -160,7 +160,7 @@ def test_with_templated_wdir(tmp_dir, dvc): DEFAULT_PARAMS_FILE: {"dict.bar": "bar", "dict.ws": "data"}, } } - assert resolver.context.imports == {str(tmp_dir / "params.yaml"): None} + assert resolver.context.imports == {"params.yaml": None} assert resolver.context == {"dict": {"bar": "bar", "ws": "data"}} @@ -236,7 +236,7 @@ def test_vars_relpath_overwrite(tmp_dir, dvc): } resolver = DataResolver(dvc, tmp_dir.fs_path, d) resolver.resolve() - assert resolver.context.imports == {str(tmp_dir / "params.yaml"): None} + assert resolver.context.imports == {"params.yaml": None} @pytest.mark.parametrize("local", [True, False]) diff --git a/tests/func/test_import.py b/tests/func/test_import.py index 59c732868d..1c9dd986ed 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -600,3 +600,33 @@ def test_circular_import(tmp_dir, dvc, scm, erepo_dir): erepo_dir.dvc.imp( os.fspath(tmp_dir), "dir_imported", "circular_import" ) + + +@pytest.mark.parametrize("paths", ([], ["dir"])) +def test_parameterized_repo(tmp_dir, dvc, scm, erepo_dir, paths): + path = erepo_dir.joinpath(*paths) + path.mkdir(parents=True, exist_ok=True) + (path / "params.yaml").dump({"out": "foo"}) + (path / "dvc.yaml").dump( + { + "stages": { + "train": {"cmd": "echo ${out} > ${out}", "outs": ["${out}"]}, + } + } + ) + path.gen({"foo": "foo"}) + with path.chdir(): + erepo_dir.dvc.commit(None, force=True) + erepo_dir.scm.add_commit( + ["params.yaml", "dvc.yaml", "dvc.lock", ".gitignore"], + message="init", + ) + + to_import = os.path.join(*paths, "foo") + stage = dvc.imp(os.fspath(erepo_dir), to_import, "foo_imported") + + assert (tmp_dir / "foo_imported").read_text() == "foo" + assert stage.deps[0].def_repo == { + "url": os.fspath(erepo_dir), + "rev_lock": erepo_dir.scm.get_rev(), + } diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 5d942941c0..d89e2ed088 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -288,14 +288,13 @@ def test_track(tmp_dir): "dct": {"foo": "foo", "bar": "bar", "baz": "baz"}, } fs = LocalFileSystem() - path = tmp_dir / "params.yaml" - path.dump(d, fs=fs) + (tmp_dir / "params.yaml").dump(d, fs=fs) - context = Context.load_from(fs, path) + context = Context.load_from(fs, "params.yaml") def key_tracked(d, key): assert len(d) == 1 - return key in d[relpath(path)] + return key in d["params.yaml"] with context.track() as tracked: context.select("lst") @@ -323,10 +322,10 @@ def test_track_from_multiple_files(tmp_dir): d2 = {"Train": {"us": {"layers": 100}}} fs = LocalFileSystem() - path1 = tmp_dir / "params.yaml" - path2 = tmp_dir / "params2.yaml" - path1.dump(d1, fs=fs) - path2.dump(d2, fs=fs) + path1 = "params.yaml" + path2 = "params2.yaml" + (tmp_dir / path1).dump(d1, fs=fs) + (tmp_dir / path2).dump(d2, fs=fs) context = Context.load_from(fs, path1) c = Context.load_from(fs, path2) @@ -428,16 +427,15 @@ def test_resolve_resolves_boolean_value(): def test_load_from_raises_if_file_not_exist(tmp_dir, dvc): with pytest.raises(ParamsLoadError) as exc_info: - Context.load_from(dvc.fs, tmp_dir / DEFAULT_PARAMS_FILE) + Context.load_from(dvc.fs, DEFAULT_PARAMS_FILE) assert str(exc_info.value) == "'params.yaml' does not exist" def test_load_from_raises_if_file_is_directory(tmp_dir, dvc): - data_dir = tmp_dir / "data" - data_dir.mkdir() + (tmp_dir / "data").mkdir() with pytest.raises(ParamsLoadError) as exc_info: - Context.load_from(dvc.fs, data_dir) + Context.load_from(dvc.fs, "data") assert str(exc_info.value) == "'data' is a directory"