diff --git a/dvc/api/__init__.py b/dvc/api/__init__.py new file mode 100644 index 0000000000..dde993ec76 --- /dev/null +++ b/dvc/api/__init__.py @@ -0,0 +1,9 @@ +from .data import ( # noqa, pylint: disable=redefined-builtin + get_url, + open, + read, +) +from .experiments import make_checkpoint +from .params import params_show + +__all__ = ["get_url", "make_checkpoint", "open", "params_show", "read"] diff --git a/dvc/api.py b/dvc/api/data.py similarity index 88% rename from dvc/api.py rename to dvc/api/data.py index a57a67fefc..a063612f10 100644 --- a/dvc/api.py +++ b/dvc/api/data.py @@ -1,4 +1,3 @@ -import os from contextlib import _GeneratorContextManager as GCM from typing import Optional @@ -212,35 +211,3 @@ def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None): path, repo=repo, rev=rev, remote=remote, mode=mode, encoding=encoding ) as fd: return fd.read() - - -def make_checkpoint(): - """ - Signal DVC to create a checkpoint experiment. - - If the current process is being run from DVC, this function will block - until DVC has finished creating the checkpoint. Otherwise, this function - will return immediately. 
- """ - import builtins - from time import sleep - - from dvc.env import DVC_CHECKPOINT, DVC_ROOT - from dvc.stage.monitor import CheckpointTask - - if os.getenv(DVC_CHECKPOINT) is None: - return - - root_dir = os.getenv(DVC_ROOT, Repo.find_root()) - signal_file = os.path.join( - root_dir, Repo.DVC_DIR, "tmp", CheckpointTask.SIGNAL_FILE - ) - - with builtins.open(signal_file, "w", encoding="utf-8") as fobj: - # NOTE: force flushing/writing empty file to disk, otherwise when - # run in certain contexts (pytest) file may not actually be written - fobj.write("") - fobj.flush() - os.fsync(fobj.fileno()) - while os.path.exists(signal_file): - sleep(0.1) diff --git a/dvc/api/experiments.py b/dvc/api/experiments.py new file mode 100644 index 0000000000..f6448eedb0 --- /dev/null +++ b/dvc/api/experiments.py @@ -0,0 +1,33 @@ +import builtins +import os +from time import sleep + +from dvc.env import DVC_CHECKPOINT, DVC_ROOT +from dvc.repo import Repo +from dvc.stage.monitor import CheckpointTask + + +def make_checkpoint(): + """ + Signal DVC to create a checkpoint experiment. + + If the current process is being run from DVC, this function will block + until DVC has finished creating the checkpoint. Otherwise, this function + will return immediately. 
+ """ + if os.getenv(DVC_CHECKPOINT) is None: + return + + root_dir = os.getenv(DVC_ROOT, Repo.find_root()) + signal_file = os.path.join( + root_dir, Repo.DVC_DIR, "tmp", CheckpointTask.SIGNAL_FILE + ) + + with builtins.open(signal_file, "w", encoding="utf-8") as fobj: + # NOTE: force flushing/writing empty file to disk, otherwise when + # run in certain contexts (pytest) file may not actually be written + fobj.write("") + fobj.flush() + os.fsync(fobj.fileno()) + while os.path.exists(signal_file): + sleep(0.1) diff --git a/dvc/api/params.py b/dvc/api/params.py new file mode 100644 index 0000000000..d66b19a591 --- /dev/null +++ b/dvc/api/params.py @@ -0,0 +1,267 @@ +from collections import Counter +from typing import Dict, Iterable, Optional, Union + +from funcy import first + +from dvc.repo import Repo + + +def params_show( + *targets: str, + repo: Optional[str] = None, + stages: Optional[Union[str, Iterable[str]]] = None, + rev: Optional[str] = None, + deps: bool = False, +) -> Dict: + """Get parameters tracked in `repo`. + + Without arguments, this function will retrieve all params from all tracked + parameter files, for the current working tree. + + See the options below to restrict the parameters retrieved. + + Args: + *targets (str, optional): Names of the parameter files to retrieve + params from. For example, "params.py, myparams.toml". + If no `targets` are provided, all parameter files tracked in `dvc.yaml` + will be used. + Note that targets don't necessarily have to be defined in `dvc.yaml`. + repo (str, optional): location of the DVC repository. + Defaults to the current project (found by walking up from the + current working directory tree). + It can be a URL or a file system path. + Both HTTP and SSH protocols are supported for online Git repos + (e.g. [user@]server:project.git). + stages (Union[str, Iterable[str]], optional): Name or names of the + stages to retrieve parameters from. + Defaults to `None`. 
+ If `None`, all parameters from all stages will be retrieved. + rev (str, optional): Name of the `Git revision`_ to retrieve parameters + from. + Defaults to `None`. + An example of git revision can be a branch or tag name, a commit + hash or a dvc experiment name. + If `repo` is not a Git repo, this option is ignored. + If `None`, the current working tree will be used. + deps (bool, optional): Whether to retrieve only parameters that are + stage dependencies or not. + Defaults to `False`. + + Returns: + Dict: See Examples below. + + Examples: + + - No arguments. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show() + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Filtering with `stages`. + + Working on https://github.com/iterative/example-get-started + + `stages` can be a single string: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages="prepare") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + } + } + + Or an iterable of strings: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages=["prepare", "train"]) + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Using `rev`. 
+ + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(rev="tune-hyperparams") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 100, + "min_split": 8 + } + } + + --- + + - Using `targets`. + + Working on `multi-params-files` folder of + https://github.com/iterative/pipeline-conifguration + + You can pass a single target: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show("params.yaml") + >>> print(json.dumps(params, indent=4)) + { + "run_mode": "prod", + "configs": { + "dev": "configs/params_dev.yaml", + "test": "configs/params_test.yaml", + "prod": "configs/params_prod.yaml" + }, + "evaluate": { + "dataset": "micro", + "size": 5000, + "metrics": ["f1", "roc-auc"], + "metrics_file": "reports/metrics.json", + "plots_cm": "reports/plot_confusion_matrix.png" + } + } + + + Or multiple targets: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... "configs/params_dev.yaml", "configs/params_prod.yaml") + >>> print(json.dumps(params, indent=4)) + { + "configs/params_prod.yaml:run_mode": "prod", + "configs/params_prod.yaml:config_file": "configs/params_prod.yaml", + "configs/params_prod.yaml:data_load": { + "dataset": "large", + "sampling": { + "enable": true, + "size": 50000 + } + }, + "configs/params_prod.yaml:train": { + "epochs": 1000 + }, + "configs/params_dev.yaml:run_mode": "dev", + "configs/params_dev.yaml:config_file": "configs/params_dev.yaml", + "configs/params_dev.yaml:data_load": { + "dataset": "development", + "sampling": { + "enable": true, + "size": 1000 + } + }, + "configs/params_dev.yaml:train": { + "epochs": 10 + } + } + + --- + + - Git URL as `repo`. + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... 
repo="https://github.com/iterative/demo-fashion-mnist") + { + "train": { + "batch_size": 128, + "hidden_units": 64, + "dropout": 0.4, + "num_epochs": 10, + "lr": 0.001, + "conv_activation": "relu" + } + } + + + .. _Git revision: + https://git-scm.com/docs/revisions + + """ + if isinstance(stages, str): + stages = [stages] + + def _onerror_raise(result: Dict, exception: Exception, *args, **kwargs): + raise exception + + def _postprocess(params): + processed = {} + for rev, rev_data in params.items(): + processed[rev] = {} + + counts = Counter() + for file_data in rev_data["data"].values(): + for k in file_data["data"]: + counts[k] += 1 + + for file_name, file_data in rev_data["data"].items(): + to_merge = { + (k if counts[k] == 1 else f"{file_name}:{k}"): v + for k, v in file_data["data"].items() + } + processed[rev] = {**processed[rev], **to_merge} + + if "workspace" in processed: + del processed["workspace"] + + return processed[first(processed)] + + with Repo.open(repo) as _repo: + params = _repo.params.show( + revs=rev if rev is None else [rev], + targets=targets, + deps=deps, + onerror=_onerror_raise, + stages=stages, + ) + + return _postprocess(params) diff --git a/dvc/repo/collect.py b/dvc/repo/collect.py index dda93ec496..a6000307ba 100644 --- a/dvc/repo/collect.py +++ b/dvc/repo/collect.py @@ -51,8 +51,8 @@ def _collect_paths( return target_paths -def _filter_duplicates( - outs: Outputs, fs_paths: StrPaths +def _filter_outs( + outs: Outputs, fs_paths: StrPaths, duplicates=False ) -> Tuple[Outputs, StrPaths]: res_outs: Outputs = [] fs_res_paths = fs_paths @@ -61,8 +61,9 @@ def _filter_duplicates( fs_path = out.repo.dvcfs.from_os_path(out.fs_path) if fs_path in fs_paths: res_outs.append(out) - # MUTATING THE SAME LIST!! - fs_res_paths.remove(fs_path) + if not duplicates: + # MUTATING THE SAME LIST!! 
+ fs_res_paths.remove(fs_path) return res_outs, fs_res_paths @@ -74,6 +75,7 @@ def collect( output_filter: FilterFn = None, rev: str = None, recursive: bool = False, + duplicates: bool = False, ) -> Tuple[Outputs, StrPaths]: assert targets or output_filter @@ -85,4 +87,4 @@ def collect( target_paths = _collect_paths(repo, targets, recursive=recursive, rev=rev) - return _filter_duplicates(outs, target_paths) + return _filter_outs(outs, target_paths, duplicates=duplicates) diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 2b55ec0d5f..573093f4a4 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -2,7 +2,15 @@ import os from collections import defaultdict from copy import copy -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, +) from scmrepo.exceptions import SCMError @@ -27,7 +35,7 @@ def _is_params(dep: "Output"): def _collect_configs( - repo: "Repo", rev, targets=None + repo: "Repo", rev, targets=None, duplicates=False ) -> Tuple[List["Output"], List[str]]: params, fs_paths = collect( @@ -36,6 +44,7 @@ def _collect_configs( deps=True, output_filter=_is_params, rev=rev, + duplicates=duplicates, ) all_fs_paths = fs_paths + [p.fs_path for p in params] if not targets: @@ -62,18 +71,23 @@ def _read_params( params_fs_paths, deps=False, onerror: Optional[Callable] = None, + stages: Optional[Iterable[str]] = None, ): - res: Dict[str, Dict] = defaultdict(dict) + res: Dict[str, Dict] = defaultdict(lambda: defaultdict(dict)) fs_paths = copy(params_fs_paths) - if deps: + if deps or stages: for param in params: + if stages and param.stage.addressing not in stages: + continue params_dict = error_handler(param.read_params)( onerror=onerror, flatten=False ) if params_dict: name = os.sep.join(repo.fs.path.relparts(param.fs_path)) - res[name] = params_dict + res[name]["data"].update(params_dict["data"]) + if name in fs_paths: 
+ fs_paths.remove(name) else: fs_paths += [param.fs_path for param in params] @@ -86,11 +100,13 @@ def _read_params( return res -def _collect_vars(repo, params) -> Dict: +def _collect_vars(repo, params, stages=None) -> Dict: vars_params: Dict[str, Dict] = defaultdict(dict) for stage in repo.index.stages: if isinstance(stage, PipelineStage) and stage.tracked_vars: + if stages and stage.addressing not in stages: + continue for file, vars_ in stage.tracked_vars.items(): # `params` file are shown regardless of `tracked` or not # to reduce noise and duplication, they are skipped @@ -103,14 +119,26 @@ def _collect_vars(repo, params) -> Dict: @locked -def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None): +def show( + repo, + revs=None, + targets=None, + deps=False, + onerror: Callable = None, + stages=None, +): if onerror is None: onerror = onerror_collect res = {} for branch in repo.brancher(revs=revs): params = error_handler(_gather_params)( - repo=repo, rev=branch, targets=targets, deps=deps, onerror=onerror + repo=repo, + rev=branch, + targets=targets, + deps=deps, + onerror=onerror, + stages=stages, ) if params: @@ -137,16 +165,21 @@ def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None): return res -def _gather_params(repo, rev, targets=None, deps=False, onerror=None): - param_outs, params_fs_paths = _collect_configs(repo, rev, targets=targets) +def _gather_params( + repo, rev, targets=None, deps=False, onerror=None, stages=None +): + param_outs, params_fs_paths = _collect_configs( + repo, rev, targets=targets, duplicates=deps or stages + ) params = _read_params( repo, params=param_outs, params_fs_paths=params_fs_paths, deps=deps, onerror=onerror, + stages=stages, ) - vars_params = _collect_vars(repo, params) + vars_params = _collect_vars(repo, params, stages=stages) # NOTE: only those that are not added as a ParamDependency are # included so we don't need to recursively merge them yet. 
diff --git a/dvc/testing/tmp_dir.py b/dvc/testing/tmp_dir.py index 2517166eb4..ea4cd0d777 100644 --- a/dvc/testing/tmp_dir.py +++ b/dvc/testing/tmp_dir.py @@ -259,6 +259,15 @@ def modify(self, *args, **kwargs): dump_toml = partialmethod(serialize.dump_toml) +def make_subrepo(dir_: TmpDir, scm, config=None): + dir_.mkdir(parents=True, exist_ok=True) + with dir_.chdir(): + dir_.scm = scm + dir_.init(dvc=True, subdir=True) + if config: + dir_.add_remote(config=config) + + def _coerce_filenames(filenames): if isinstance(filenames, (str, bytes, pathlib.PurePath)): filenames = [filenames] diff --git a/tests/func/test_api.py b/tests/func/api/test_data.py similarity index 99% rename from tests/func/test_api.py rename to tests/func/api/test_data.py index 44bfd54f21..d73226175d 100644 --- a/tests/func/test_api.py +++ b/tests/func/api/test_data.py @@ -10,8 +10,8 @@ PathMissingError, ) from dvc.testing.test_api import TestAPI # noqa, pylint: disable=unused-import +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import remove -from tests.unit.fs.test_dvc import make_subrepo def test_get_url_external(tmp_dir, erepo_dir, cloud): diff --git a/tests/func/api/test_params.py b/tests/func/api/test_params.py new file mode 100644 index 0000000000..9a2fb0f83b --- /dev/null +++ b/tests/func/api/test_params.py @@ -0,0 +1,130 @@ +from textwrap import dedent + +import pytest + +from dvc import api + + +@pytest.fixture +def params_repo(tmp_dir, scm, dvc): + tmp_dir.gen("params.yaml", "foo: 1") + tmp_dir.gen("params.json", '{"bar": 2, "foobar": 3}') + tmp_dir.gen("other_params.json", '{"foo": {"bar": 4}}') + + dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo", "params.json:bar"], + ) + + dvc.run( + name="stage-2", + cmd="echo stage-2", + params=["other_params.json:foo"], + ) + + dvc.run( + name="stage-3", + cmd="echo stage-2", + params=["params.json:foobar"], + ) + + scm.add( + [ + "params.yaml", + "params.json", + "other_params.json", + "dvc.yaml", + "dvc.lock", 
+ ] + ) + scm.commit("commit dvc files") + + tmp_dir.gen("params.yaml", "foo: 5") + scm.add(["params.yaml"]) + scm.commit("update params.yaml") + + +def test_params_show_no_args(params_repo): + assert api.params_show() == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_targets(params_repo): + assert api.params_show("params.yaml") == {"foo": 5} + assert api.params_show("params.yaml", "params.json") == { + "foo": 5, + "bar": 2, + "foobar": 3, + } + + +def test_params_show_deps(params_repo): + params = api.params_show(deps=True) + assert params == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_stages(params_repo): + assert api.params_show(stages="stage-2") == {"foo": {"bar": 4}} + + assert api.params_show() == api.params_show( + stages=["stage-1", "stage-2", "stage-3"] + ) + + assert api.params_show("params.json", stages="stage-3") == {"foobar": 3} + + +def test_params_show_revs(params_repo): + assert api.params_show(rev="HEAD~1") == { + "params.yaml:foo": 1, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_while_running_stage(tmp_dir, dvc): + (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) + (tmp_dir / "params.json").dump({"bar": 2}) + + tmp_dir.gen( + "merge.py", + dedent( + """ + import json + from dvc import api + with open("merged.json", "w") as f: + json.dump(api.params_show(stages="merge"), f) + """ + ), + ) + dvc.stage.add( + name="merge", + cmd="python merge.py", + params=["foo.bar", {"params.json": ["bar"]}], + outs=["merged.json"], + ) + + dvc.reproduce() + + assert (tmp_dir / "merged.json").parse() == {"foo": {"bar": 1}, "bar": 2} + + +def test_params_show_repo(tmp_dir, erepo_dir): + with erepo_dir.chdir(): + erepo_dir.scm_gen("params.yaml", "foo: 1", commit="Create params.yaml") + erepo_dir.dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo"], + 
) + assert api.params_show(repo=erepo_dir) == {"foo": 1} diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 6ee62ff72e..3988e44db0 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -544,7 +544,7 @@ def test_subdir(tmp_dir, scm, dvc, workspace): @pytest.mark.parametrize("workspace", [True, False]) def test_subrepo(tmp_dir, scm, workspace): - from tests.unit.fs.test_dvc import make_subrepo + from dvc.testing.tmp_dir import make_subrepo subrepo = tmp_dir / "dir" / "repo" make_subrepo(subrepo, scm) diff --git a/tests/func/params/test_show.py b/tests/func/params/test_show.py index b7ea742dbe..7609145b58 100644 --- a/tests/func/params/test_show.py +++ b/tests/func/params/test_show.py @@ -192,3 +192,50 @@ def test_show_without_targets_specified(tmp_dir, dvc, scm, file): ) assert dvc.params.show() == {"": {"data": {file: {"data": data}}}} + + +def test_deps_multi_stage(tmp_dir, scm, dvc, run_copy): + tmp_dir.gen( + {"foo": "foo", "params.yaml": "foo: bar\nxyz: val\nabc: ignore"} + ) + run_copy("foo", "bar", name="copy-foo-bar", params=["foo"]) + run_copy("foo", "bar1", name="copy-foo-bar-1", params=["xyz"]) + + scm.add(["params.yaml", PIPELINE_FILE]) + scm.commit("add stage") + + assert dvc.params.show(revs=["master"], deps=True) == { + "master": { + "data": {"params.yaml": {"data": {"foo": "bar", "xyz": "val"}}} + } + } + + +def test_deps_with_targets(tmp_dir, scm, dvc, run_copy): + tmp_dir.gen( + {"foo": "foo", "params.yaml": "foo: bar\nxyz: val\nabc: ignore"} + ) + run_copy("foo", "bar", name="copy-foo-bar", params=["foo"]) + run_copy("foo", "bar1", name="copy-foo-bar-1", params=["xyz"]) + + scm.add(["params.yaml", PIPELINE_FILE]) + scm.commit("add stage") + + assert dvc.params.show(targets=["params.yaml"], deps=True) == { + "": {"data": {"params.yaml": {"data": {"foo": "bar", "xyz": "val"}}}} + } + + +def test_deps_with_bad_target(tmp_dir, scm, dvc, 
run_copy): + tmp_dir.gen( + { + "foo": "foo", + "foobar": "", + "params.yaml": "foo: bar\nxyz: val\nabc: ignore", + } + ) + run_copy("foo", "bar", name="copy-foo-bar", params=["foo"]) + run_copy("foo", "bar1", name="copy-foo-bar-1", params=["xyz"]) + scm.add(["params.yaml", PIPELINE_FILE]) + scm.commit("add stage") + assert dvc.params.show(targets=["foobar"], deps=True) == {} diff --git a/tests/func/test_external_repo.py b/tests/func/test_external_repo.py index 36e2bad56e..6e31edef48 100644 --- a/tests/func/test_external_repo.py +++ b/tests/func/test_external_repo.py @@ -4,11 +4,11 @@ from scmrepo.git import Git from dvc.external_repo import CLONES, external_repo +from dvc.testing.tmp_dir import make_subrepo from dvc.utils import relpath from dvc.utils.fs import makedirs, remove from dvc_data.stage import stage from dvc_data.transfer import transfer -from tests.unit.fs.test_dvc import make_subrepo def test_external_repo(erepo_dir, mocker): diff --git a/tests/func/test_get.py b/tests/func/test_get.py index fb477dfe0b..d2d7eb3cfb 100644 --- a/tests/func/test_get.py +++ b/tests/func/test_get.py @@ -8,8 +8,8 @@ from dvc.odbmgr import ODBManager from dvc.repo import Repo from dvc.repo.get import GetDVCFileError +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import makedirs -from tests.unit.fs.test_dvc import make_subrepo def test_get_repo_file(tmp_dir, erepo_dir): diff --git a/tests/func/test_import.py b/tests/func/test_import.py index a0c160f960..6de58fe4c5 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -12,8 +12,8 @@ from dvc.fs import system from dvc.odbmgr import ODBManager from dvc.stage.exceptions import StagePathNotFoundError +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import makedirs, remove -from tests.unit.fs.test_dvc import make_subrepo def test_import(tmp_dir, scm, dvc, erepo_dir): diff --git a/tests/func/test_update.py b/tests/func/test_update.py index 94ba2c9d36..37f98f314a 100644 --- 
a/tests/func/test_update.py +++ b/tests/func/test_update.py @@ -4,7 +4,7 @@ from dvc.dvcfile import Dvcfile from dvc.exceptions import InvalidArgumentError -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo @pytest.mark.parametrize("cached", [True, False]) diff --git a/tests/unit/fs/test_dvc.py b/tests/unit/fs/test_dvc.py index 06e48f1b3e..f7c6dec7bb 100644 --- a/tests/unit/fs/test_dvc.py +++ b/tests/unit/fs/test_dvc.py @@ -6,6 +6,7 @@ import pytest from dvc.fs.dvc import DvcFileSystem +from dvc.testing.tmp_dir import make_subrepo from dvc_data.hashfile.hash_info import HashInfo from dvc_data.stage import stage @@ -318,15 +319,6 @@ def test_isdvc(tmp_dir, dvc): assert fs.isdvc("dir/baz", recursive=True) -def make_subrepo(dir_, scm, config=None): - dir_.mkdir(parents=True, exist_ok=True) - with dir_.chdir(): - dir_.scm = scm - dir_.init(dvc=True, subdir=True) - if config: - dir_.add_remote(config=config) - - def test_subrepos(tmp_dir, scm, dvc, mocker): tmp_dir.scm_gen( {"dir": {"repo.txt": "file to confuse DvcFileSystem"}}, diff --git a/tests/unit/fs/test_dvc_info.py b/tests/unit/fs/test_dvc_info.py index 03bdec1f7c..dc175662e5 100644 --- a/tests/unit/fs/test_dvc_info.py +++ b/tests/unit/fs/test_dvc_info.py @@ -3,7 +3,7 @@ import pytest from dvc.fs.dvc import DvcFileSystem -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo @pytest.fixture diff --git a/tests/unit/test_collect.py b/tests/unit/test_collect.py index 9237e436f5..0e64c42126 100644 --- a/tests/unit/test_collect.py +++ b/tests/unit/test_collect.py @@ -8,3 +8,20 @@ def test_no_file_on_target_rev(tmp_dir, scm, dvc, caplog): collect(dvc, targets=["file.yaml"], rev="current_branch") assert "'file.yaml' was not found at: 'current_branch'." 
in caplog.text + + +def test_collect_duplicates(tmp_dir, scm, dvc): + tmp_dir.gen("params.yaml", "foo: 1\nbar: 2") + tmp_dir.gen("foobar", "") + + dvc.run(name="stage-1", cmd="echo stage-1", params=["foo"]) + dvc.run(name="stage-2", cmd="echo stage-2", params=["bar"]) + + outs, _ = collect(dvc, deps=True, targets=["params.yaml"]) + assert len(outs) == 1 + + outs, _ = collect(dvc, deps=True, targets=["params.yaml"], duplicates=True) + assert len(outs) == 2 + + outs, _ = collect(dvc, deps=True, targets=["foobar"], duplicates=True) + assert not outs diff --git a/tests/unit/test_external_repo.py b/tests/unit/test_external_repo.py index 7f2f680b71..c69b46355b 100644 --- a/tests/unit/test_external_repo.py +++ b/tests/unit/test_external_repo.py @@ -4,7 +4,7 @@ import pytest from dvc.external_repo import external_repo -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo def test_hook_is_called(tmp_dir, erepo_dir, mocker):