From b21369b6dac9d2eb192c3ae2611dbb66e67dcbd4 Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Mon, 6 Jun 2022 19:24:56 +0200 Subject: [PATCH 1/4] repo.collect: Add `duplicates` option. Defaults to `False`. If `True`, multiple `outs` sharing a provided `target_path` will not be filtered. --- dvc/repo/collect.py | 12 +++++++----- tests/unit/test_collect.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/dvc/repo/collect.py b/dvc/repo/collect.py index dda93ec496..a6000307ba 100644 --- a/dvc/repo/collect.py +++ b/dvc/repo/collect.py @@ -51,8 +51,8 @@ def _collect_paths( return target_paths -def _filter_duplicates( - outs: Outputs, fs_paths: StrPaths +def _filter_outs( + outs: Outputs, fs_paths: StrPaths, duplicates=False ) -> Tuple[Outputs, StrPaths]: res_outs: Outputs = [] fs_res_paths = fs_paths @@ -61,8 +61,9 @@ def _filter_duplicates( fs_path = out.repo.dvcfs.from_os_path(out.fs_path) if fs_path in fs_paths: res_outs.append(out) - # MUTATING THE SAME LIST!! - fs_res_paths.remove(fs_path) + if not duplicates: + # MUTATING THE SAME LIST!! + fs_res_paths.remove(fs_path) return res_outs, fs_res_paths @@ -74,6 +75,7 @@ def collect( output_filter: FilterFn = None, rev: str = None, recursive: bool = False, + duplicates: bool = False, ) -> Tuple[Outputs, StrPaths]: assert targets or output_filter @@ -85,4 +87,4 @@ def collect( target_paths = _collect_paths(repo, targets, recursive=recursive, rev=rev) - return _filter_duplicates(outs, target_paths) + return _filter_outs(outs, target_paths, duplicates=duplicates) diff --git a/tests/unit/test_collect.py b/tests/unit/test_collect.py index 9237e436f5..0e64c42126 100644 --- a/tests/unit/test_collect.py +++ b/tests/unit/test_collect.py @@ -8,3 +8,20 @@ def test_no_file_on_target_rev(tmp_dir, scm, dvc, caplog): collect(dvc, targets=["file.yaml"], rev="current_branch") assert "'file.yaml' was not found at: 'current_branch'." in caplog.text + + +def test_collect_duplicates(tmp_dir, scm, dvc): + tmp_dir.gen("params.yaml", "foo: 1\nbar: 2") + tmp_dir.gen("foobar", "") + + dvc.run(name="stage-1", cmd="echo stage-1", params=["foo"]) + dvc.run(name="stage-2", cmd="echo stage-2", params=["bar"]) + + outs, _ = collect(dvc, deps=True, targets=["params.yaml"]) + assert len(outs) == 1 + + outs, _ = collect(dvc, deps=True, targets=["params.yaml"], duplicates=True) + assert len(outs) == 2 + + outs, _ = collect(dvc, deps=True, targets=["foobar"], duplicates=True) + assert not outs From 7e89223932e30b7b4d1102ea8f7f2aa059165386 Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Sat, 4 Jun 2022 16:31:00 +0200 Subject: [PATCH 2/4] params.show: Fix `deps` for stages using same params file. --- dvc/repo/params/show.py | 11 +++++--- tests/func/params/test_show.py | 47 ++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 2b55ec0d5f..668577af2e 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -27,7 +27,7 @@ def _is_params(dep: "Output"): def _collect_configs( - repo: "Repo", rev, targets=None + repo: "Repo", rev, targets=None, duplicates=False ) -> Tuple[List["Output"], List[str]]: params, fs_paths = collect( @@ -36,6 +36,7 @@ def _collect_configs( deps=True, output_filter=_is_params, rev=rev, + duplicates=duplicates, ) all_fs_paths = fs_paths + [p.fs_path for p in params] if not targets: @@ -63,7 +64,7 @@ def _read_params( deps=False, onerror: Optional[Callable] = None, ): - res: Dict[str, Dict] = defaultdict(dict) + res: Dict[str, Dict] = defaultdict(lambda: defaultdict(dict)) fs_paths = copy(params_fs_paths) if deps: @@ -73,7 +74,7 @@ def _read_params( ) if params_dict: name = os.sep.join(repo.fs.path.relparts(param.fs_path)) - res[name] = params_dict + res[name]["data"].update(params_dict["data"]) else: fs_paths += [param.fs_path for param in params] @@ -138,7 +139,9 @@ def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None): def _gather_params(repo, rev, targets=None, deps=False, onerror=None): - param_outs, params_fs_paths = _collect_configs(repo, rev, targets=targets) + param_outs, params_fs_paths = _collect_configs( + repo, rev, targets=targets, duplicates=deps + ) params = _read_params( repo, params=param_outs, diff --git a/tests/func/params/test_show.py b/tests/func/params/test_show.py index b7ea742dbe..7609145b58 100644 --- a/tests/func/params/test_show.py +++ b/tests/func/params/test_show.py @@ -192,3 +192,50 @@ def test_show_without_targets_specified(tmp_dir, dvc, scm, file): ) assert dvc.params.show() == {"": {"data": {file: {"data": data}}}} + + +def test_deps_multi_stage(tmp_dir, scm, dvc, run_copy): + tmp_dir.gen( + {"foo": "foo", "params.yaml": "foo: bar\nxyz: val\nabc: ignore"} + ) + run_copy("foo", "bar", name="copy-foo-bar", params=["foo"]) + run_copy("foo", "bar1", name="copy-foo-bar-1", params=["xyz"]) + + scm.add(["params.yaml", PIPELINE_FILE]) + scm.commit("add stage") + + assert dvc.params.show(revs=["master"], deps=True) == { + "master": { + "data": {"params.yaml": {"data": {"foo": "bar", "xyz": "val"}}} + } + } + + +def test_deps_with_targets(tmp_dir, scm, dvc, run_copy): + tmp_dir.gen( + {"foo": "foo", "params.yaml": "foo: bar\nxyz: val\nabc: ignore"} + ) + run_copy("foo", "bar", name="copy-foo-bar", params=["foo"]) + run_copy("foo", "bar1", name="copy-foo-bar-1", params=["xyz"]) + + scm.add(["params.yaml", PIPELINE_FILE]) + scm.commit("add stage") + + assert dvc.params.show(targets=["params.yaml"], deps=True) == { + "": {"data": {"params.yaml": {"data": {"foo": "bar", "xyz": "val"}}}} + } + + +def test_deps_with_bad_target(tmp_dir, scm, dvc, run_copy): + tmp_dir.gen( + { + "foo": "foo", + "foobar": "", + "params.yaml": "foo: bar\nxyz: val\nabc: ignore", + } + ) + run_copy("foo", "bar", name="copy-foo-bar", params=["foo"]) + run_copy("foo", "bar1", name="copy-foo-bar-1", params=["xyz"]) + scm.add(["params.yaml", PIPELINE_FILE]) + scm.commit("add stage") + assert dvc.params.show(targets=["foobar"], deps=True) == {} From 574d50c7dc22e52c46cebe9ad484fbfded498c22 Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Thu, 2 Jun 2022 12:32:40 +0200 Subject: [PATCH 3/4] api: Add `params_show`. Closes #6507 Uses `repo.params.show` with custom error_handler and postprocess the outputs for more user-friendly structure. Extend `repo.params.show` to accept `stages` argument to cover the "params of current stage" use case. --- dvc/api.py | 266 +++++++++++++++++++++++++++++++++++++++- dvc/repo/params/show.py | 46 +++++-- tests/func/test_api.py | 126 +++++++++++++++++++ 3 files changed, 428 insertions(+), 10 deletions(-) diff --git a/dvc/api.py b/dvc/api.py index a57a67fefc..61ff23af46 100644 --- a/dvc/api.py +++ b/dvc/api.py @@ -1,8 +1,9 @@ import os +from collections import Counter from contextlib import _GeneratorContextManager as GCM -from typing import Optional +from typing import Dict, Iterable, Optional, Union -from funcy import reraise +from funcy import first, reraise from dvc.exceptions import OutputNotFoundError, PathMissingError from dvc.repo import Repo @@ -214,6 +215,267 @@ def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None): return fd.read() +def params_show( + *targets: str, + repo: Optional[str] = None, + stages: Optional[Union[str, Iterable[str]]] = None, + rev: Optional[str] = None, + deps: bool = False, +) -> Dict: + """Get parameters tracked in `repo`. + + Without arguments, this function will retrieve all params from all tracked + parameter files, for the current working tree. + + See the options below to restrict the parameters retrieved. + + Args: + *targets (str, optional): Names of the parameter files to retrieve + params from. For example, "params.py, myparams.toml". + If no `targets` are provided, all parameter files tracked in `dvc.yaml` + will be used. + Note that targets don't necessarily have to be defined in `dvc.yaml`. + repo (str, optional): location of the DVC repository. + Defaults to the current project (found by walking up from the + current working directory tree). + It can be a URL or a file system path. + Both HTTP and SSH protocols are supported for online Git repos + (e.g. [user@]server:project.git). + stages (Union[str, Iterable[str]], optional): Name or names of the + stages to retrieve parameters from. + Defaults to `None`. + If `None`, all parameters from all stages will be retrieved. + rev (str, optional): Name of the `Git revision`_ to retrieve parameters + from. + Defaults to `None`. + An example of git revision can be a branch or tag name, a commit + hash or a dvc experiment name. + If `repo` is not a Git repo, this option is ignored. + If `None`, the current working tree will be used. + deps (bool, optional): Whether to retrieve only parameters that are + stage dependencies or not. + Defaults to `False`. + + Returns: + Dict: See Examples below. + + Examples: + + - No arguments. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show() + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Filtering with `stages`. + + Working on https://github.com/iterative/example-get-started + + `stages` can a single string: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages="prepare") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + } + } + + Or an iterable of strings: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages=["prepare", "train"]) + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Using `rev`. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(rev="tune-hyperparams") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 100, + "min_split": 8 + } + } + + --- + + - Using `targets`. + + Working on `multi-params-files` folder of + https://github.com/iterative/pipeline-conifguration + + You can pass a single target: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show("params.yaml") + >>> print(json.dumps(params, indent=4)) + { + "run_mode": "prod", + "configs": { + "dev": "configs/params_dev.yaml", + "test": "configs/params_test.yaml", + "prod": "configs/params_prod.yaml" + }, + "evaluate": { + "dataset": "micro", + "size": 5000, + "metrics": ["f1", "roc-auc"], + "metrics_file": "reports/metrics.json", + "plots_cm": "reports/plot_confusion_matrix.png" + } + } + + + Or multiple targets: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... "configs/params_dev.yaml", "configs/params_prod.yaml") + >>> print(json.dumps(params, indent=4)) + { + "configs/params_prod.yaml:run_mode": "prod", + "configs/params_prod.yaml:config_file": "configs/params_prod.yaml", + "configs/params_prod.yaml:data_load": { + "dataset": "large", + "sampling": { + "enable": true, + "size": 50000 + } + }, + "configs/params_prod.yaml:train": { + "epochs": 1000 + }, + "configs/params_dev.yaml:run_mode": "dev", + "configs/params_dev.yaml:config_file": "configs/params_dev.yaml", + "configs/params_dev.yaml:data_load": { + "dataset": "development", + "sampling": { + "enable": true, + "size": 1000 + } + }, + "configs/params_dev.yaml:train": { + "epochs": 10 + } + } + + --- + + - Git URL as `repo`. + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... repo="https://github.com/iterative/demo-fashion-mnist") + { + "train": { + "batch_size": 128, + "hidden_units": 64, + "dropout": 0.4, + "num_epochs": 10, + "lr": 0.001, + "conv_activation": "relu" + } + } + + + .. _Git revision: + https://git-scm.com/docs/revisions + + """ + if isinstance(stages, str): + stages = [stages] + + def _onerror_raise(result: Dict, exception: Exception, *args, **kwargs): + raise exception + + def _postprocess(params): + processed = {} + for rev, rev_data in params.items(): + processed[rev] = {} + + counts = Counter() + for file_data in rev_data["data"].values(): + for k in file_data["data"]: + counts[k] += 1 + + for file_name, file_data in rev_data["data"].items(): + to_merge = { + (k if counts[k] == 1 else f"{file_name}:{k}"): v + for k, v in file_data["data"].items() + } + processed[rev] = {**processed[rev], **to_merge} + + if "workspace" in processed: + del processed["workspace"] + + return processed[first(processed)] + + with Repo.open(repo) as _repo: + params = _repo.params.show( + revs=rev if rev is None else [rev], + targets=targets, + deps=deps, + onerror=_onerror_raise, + stages=stages, + ) + + return _postprocess(params) + + def make_checkpoint(): """ Signal DVC to create a checkpoint experiment. diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 668577af2e..573093f4a4 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -2,7 +2,15 @@ import os from collections import defaultdict from copy import copy -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, +) from scmrepo.exceptions import SCMError @@ -63,18 +71,23 @@ def _read_params( params_fs_paths, deps=False, onerror: Optional[Callable] = None, + stages: Optional[Iterable[str]] = None, ): res: Dict[str, Dict] = defaultdict(lambda: defaultdict(dict)) fs_paths = copy(params_fs_paths) - if deps: + if deps or stages: for param in params: + if stages and param.stage.addressing not in stages: + continue params_dict = error_handler(param.read_params)( onerror=onerror, flatten=False ) if params_dict: name = os.sep.join(repo.fs.path.relparts(param.fs_path)) res[name]["data"].update(params_dict["data"]) + if name in fs_paths: + fs_paths.remove(name) else: fs_paths += [param.fs_path for param in params] @@ -87,11 +100,13 @@ def _read_params( return res -def _collect_vars(repo, params) -> Dict: +def _collect_vars(repo, params, stages=None) -> Dict: vars_params: Dict[str, Dict] = defaultdict(dict) for stage in repo.index.stages: if isinstance(stage, PipelineStage) and stage.tracked_vars: + if stages and stage.addressing not in stages: + continue for file, vars_ in stage.tracked_vars.items(): # `params` file are shown regardless of `tracked` or not # to reduce noise and duplication, they are skipped @@ -104,14 +119,26 @@ def _collect_vars(repo, params) -> Dict: @locked -def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None): +def show( + repo, + revs=None, + targets=None, + deps=False, + onerror: Callable = None, + stages=None, +): if onerror is None: onerror = onerror_collect res = {} for branch in repo.brancher(revs=revs): params = error_handler(_gather_params)( - repo=repo, rev=branch, targets=targets, deps=deps, onerror=onerror + repo=repo, + rev=branch, + targets=targets, + deps=deps, + onerror=onerror, + stages=stages, ) if params: @@ -138,9 +165,11 @@ def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None): return res -def _gather_params(repo, rev, targets=None, deps=False, onerror=None): +def _gather_params( + repo, rev, targets=None, deps=False, onerror=None, stages=None +): param_outs, params_fs_paths = _collect_configs( - repo, rev, targets=targets, duplicates=deps + repo, rev, targets=targets, duplicates=deps or stages ) params = _read_params( repo, @@ -148,8 +177,9 @@ def _gather_params(repo, rev, targets=None, deps=False, onerror=None): params_fs_paths=params_fs_paths, deps=deps, onerror=onerror, + stages=stages, ) - vars_params = _collect_vars(repo, params) + vars_params = _collect_vars(repo, params, stages=stages) # NOTE: only those that are not added as a ParamDependency are # included so we don't need to recursively merge them yet. diff --git a/tests/func/test_api.py b/tests/func/test_api.py index 44bfd54f21..116318e3ad 100644 --- a/tests/func/test_api.py +++ b/tests/func/test_api.py @@ -1,4 +1,5 @@ import os +from textwrap import dedent import pytest from funcy import first, get_in @@ -228,3 +229,128 @@ def test_open_from_remote(tmp_dir, erepo_dir, cloud, local_cloud): remote="other", ) as fd: assert fd.read() == "foo content" + + +@pytest.fixture +def params_repo(tmp_dir, scm, dvc): + tmp_dir.gen("params.yaml", "foo: 1") + tmp_dir.gen("params.json", '{"bar": 2, "foobar": 3}') + tmp_dir.gen("other_params.json", '{"foo": {"bar": 4}}') + + dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo", "params.json:bar"], + ) + + dvc.run( + name="stage-2", + cmd="echo stage-2", + params=["other_params.json:foo"], + ) + + dvc.run( + name="stage-3", + cmd="echo stage-2", + params=["params.json:foobar"], + ) + + scm.add( + [ + "params.yaml", + "params.json", + "other_params.json", + "dvc.yaml", + "dvc.lock", + ] + ) + scm.commit("commit dvc files") + + tmp_dir.gen("params.yaml", "foo: 5") + scm.add(["params.yaml"]) + scm.commit("update params.yaml") + + +def test_params_show_no_args(params_repo): + assert api.params_show() == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_targets(params_repo): + assert api.params_show("params.yaml") == {"foo": 5} + assert api.params_show("params.yaml", "params.json") == { + "foo": 5, + "bar": 2, + "foobar": 3, + } + + +def test_params_show_deps(params_repo): + params = api.params_show(deps=True) + assert params == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_stages(params_repo): + assert api.params_show(stages="stage-2") == {"foo": {"bar": 4}} + + assert api.params_show() == api.params_show( + stages=["stage-1", "stage-2", "stage-3"] + ) + + assert api.params_show("params.json", stages="stage-3") == {"foobar": 3} + + +def test_params_show_revs(params_repo): + assert api.params_show(rev="HEAD~1") == { + "params.yaml:foo": 1, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_while_running_stage(tmp_dir, dvc): + (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) + (tmp_dir / "params.json").dump({"bar": 2}) + + tmp_dir.gen( + "merge.py", + dedent( + """ + import json + from dvc import api + with open("merged.json", "w") as f: + json.dump(api.params_show(stages="merge"), f) + """ + ), + ) + dvc.stage.add( + name="merge", + cmd="python merge.py", + params=["foo.bar", {"params.json": ["bar"]}], + outs=["merged.json"], + ) + + dvc.reproduce() + + assert (tmp_dir / "merged.json").parse() == {"foo": {"bar": 1}, "bar": 2} + + +def test_params_show_repo(tmp_dir, erepo_dir): + with erepo_dir.chdir(): + erepo_dir.scm_gen("params.yaml", "foo: 1", commit="Create params.yaml") + erepo_dir.dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo"], + ) + assert api.params_show(repo=erepo_dir) == {"foo": 1} From 7d89b42b14f464d66a3ad5e47d9ffa88ac111c8e Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Mon, 20 Jun 2022 16:06:01 +0200 Subject: [PATCH 4/4] api: Move to sub-package. Split into sub-modules. Split tests. --- dvc/api.py | 508 ------------------- dvc/api/__init__.py | 9 + dvc/api/data.py | 213 ++++++++ dvc/api/experiments.py | 33 ++ dvc/api/params.py | 267 ++++++++++ dvc/testing/tmp_dir.py | 9 + tests/func/{test_api.py => api/test_data.py} | 128 +---- tests/func/api/test_params.py | 130 +++++ tests/func/experiments/test_experiments.py | 2 +- tests/func/test_external_repo.py | 2 +- tests/func/test_get.py | 2 +- tests/func/test_import.py | 2 +- tests/func/test_update.py | 2 +- tests/unit/fs/test_dvc.py | 10 +- tests/unit/fs/test_dvc_info.py | 2 +- tests/unit/test_external_repo.py | 2 +- 16 files changed, 670 insertions(+), 651 deletions(-) delete mode 100644 dvc/api.py create mode 100644 dvc/api/__init__.py create mode 100644 dvc/api/data.py create mode 100644 dvc/api/experiments.py create mode 100644 dvc/api/params.py rename tests/func/{test_api.py => api/test_data.py} (69%) create mode 100644 tests/func/api/test_params.py diff --git a/dvc/api.py b/dvc/api.py deleted file mode 100644 index 61ff23af46..0000000000 --- a/dvc/api.py +++ /dev/null @@ -1,508 +0,0 @@ -import os -from collections import Counter -from contextlib import _GeneratorContextManager as GCM -from typing import Dict, Iterable, Optional, Union - -from funcy import first, reraise - -from dvc.exceptions import OutputNotFoundError, PathMissingError -from dvc.repo import Repo - - -def get_url(path, repo=None, rev=None, remote=None): - """ - Returns the URL to the storage location of a data file or directory tracked - in a DVC repo. For Git repos, HEAD is used unless a rev argument is - supplied. The default remote is tried unless a remote argument is supplied. - - Raises OutputNotFoundError if the file is not tracked by DVC. - - NOTE: This function does not check for the actual existence of the file or - directory in the remote storage. - """ - with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: - fs_path = _repo.dvcfs.from_os_path(path) - with reraise(FileNotFoundError, PathMissingError(path, repo)): - info = _repo.dvcfs.info(fs_path) - - dvc_info = info.get("dvc_info") - if not dvc_info: - raise OutputNotFoundError(path, repo) - - dvc_repo = info["repo"] - md5 = dvc_info["md5"] - - return dvc_repo.cloud.get_url_for(remote, checksum=md5) - - -class _OpenContextManager(GCM): - def __init__( - self, func, args, kwds - ): # pylint: disable=super-init-not-called - self.gen = func(*args, **kwds) - self.func, self.args, self.kwds = func, args, kwds - - def __getattr__(self, name): - raise AttributeError( - "dvc.api.open() should be used in a with statement." - ) - - -def open( # noqa, pylint: disable=redefined-builtin - path: str, - repo: Optional[str] = None, - rev: Optional[str] = None, - remote: Optional[str] = None, - mode: str = "r", - encoding: Optional[str] = None, -): - """ - Opens a file tracked in a DVC project. - - This function may only be used as a context manager (using the `with` - keyword, as shown in the examples). - - This function makes a direct connection to the remote storage, so the file - contents can be streamed. Your code can process the data buffer as it's - streamed, which optimizes memory usage. - - Note: - Use dvc.api.read() to load the complete file contents - in a single function call, no context manager involved. - Neither function utilizes disc space. - - Args: - path (str): location and file name of the target to open, - relative to the root of `repo`. - repo (str, optional): location of the DVC project or Git Repo. - Defaults to the current DVC project (found by walking up from the - current working directory tree). - It can be a URL or a file system path. - Both HTTP and SSH protocols are supported for online Git repos - (e.g. [user@]server:project.git). - rev (str, optional): Any `Git revision`_ such as a branch or tag name, - a commit hash or a dvc experiment name. - Defaults to HEAD. - If `repo` is not a Git repo, this option is ignored. - remote (str, optional): Name of the `DVC remote`_ used to form the - returned URL string. - Defaults to the `default remote`_ of `repo`. - For local projects, the cache is tried before the default remote. - mode (str, optional): Specifies the mode in which the file is opened. - Defaults to "r" (read). - Mirrors the namesake parameter in builtin `open()`_. - Only reading `mode` is supported. - encoding(str, optional): `Codec`_ used to decode the file contents. - Defaults to None. - This should only be used in text mode. - Mirrors the namesake parameter in builtin `open()`_. - - Returns: - _OpenContextManager: A context manager that generatse a corresponding - `file object`_. - The exact type of file object depends on the mode used. - For more details, please refer to Python's `open()`_ built-in, - which is used under the hood. - - Raises: - AttributeError: If this method is not used as a context manager. - ValueError: If non-read `mode` is used. - - Examples: - - - Use data or models from a DVC repository. - - Any file tracked in a DVC project (and stored remotely) can be - processed directly in your Python code with this API. - For example, an XML file tracked in a public DVC repo on GitHub can be - processed like this: - - >>> from xml.sax import parse - >>> import dvc.api - >>> from mymodule import mySAXHandler - - >>> with dvc.api.open( - ... 'get-started/data.xml', - ... repo='https://github.com/iterative/dataset-registry' - ... ) as fd: - ... parse(fd, mySAXHandler) - - We use a SAX XML parser here because dvc.api.open() is able to stream - the data from remote storage. - The mySAXHandler object should handle the event-driven parsing of the - document in this case. - This increases the performance of the code (minimizing memory usage), - and is typically faster than loading the whole data into memory. - - - Accessing private repos - - This is just a matter of using the right repo argument, for example an - SSH URL (requires that the credentials are configured locally): - - >>> import dvc.api - - >>> with dvc.api.open( - ... 'features.dat', - ... repo='git@server.com:path/to/repo.git' - ... ) as fd: - ... # ... Process 'features' - ... pass - - - Use different versions of data - - Any git revision (see `rev`) can be accessed programmatically. - For example, if your DVC repo has tagged releases of a CSV dataset: - - >>> import csv - >>> import dvc.api - >>> with dvc.api.open( - ... 'clean.csv', - ... rev='v1.1.0' - ... ) as fd: - ... reader = csv.reader(fd) - ... # ... Process 'clean' data from version 1.1.0 - - .. _Git revision: - https://git-scm.com/docs/revisions - - .. _DVC remote: - https://dvc.org/doc/command-reference/remote - - .. _default remote: - https://dvc.org/doc/command-reference/remote/default - - .. _open(): - https://docs.python.org/3/library/functions.html#open - - .. _Codec: - https://docs.python.org/3/library/codecs.html#standard-encodings - - .. _file object: - https://docs.python.org/3/glossary.html#term-file-object - - """ - if "r" not in mode: - raise ValueError("Only reading `mode` is supported.") - - args = (path,) - kwargs = { - "repo": repo, - "remote": remote, - "rev": rev, - "mode": mode, - "encoding": encoding, - } - return _OpenContextManager(_open, args, kwargs) - - -def _open(path, repo=None, rev=None, remote=None, mode="r", encoding=None): - with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: - with _repo.open_by_relpath( - path, remote=remote, mode=mode, encoding=encoding - ) as fd: - yield fd - - -def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None): - """ - Returns the contents of a tracked file (by DVC or Git). For Git repos, HEAD - is used unless a rev argument is supplied. The default remote is tried - unless a remote argument is supplied. - """ - with open( - path, repo=repo, rev=rev, remote=remote, mode=mode, encoding=encoding - ) as fd: - return fd.read() - - -def params_show( - *targets: str, - repo: Optional[str] = None, - stages: Optional[Union[str, Iterable[str]]] = None, - rev: Optional[str] = None, - deps: bool = False, -) -> Dict: - """Get parameters tracked in `repo`. - - Without arguments, this function will retrieve all params from all tracked - parameter files, for the current working tree. - - See the options below to restrict the parameters retrieved. - - Args: - *targets (str, optional): Names of the parameter files to retrieve - params from. For example, "params.py, myparams.toml". - If no `targets` are provided, all parameter files tracked in `dvc.yaml` - will be used. - Note that targets don't necessarily have to be defined in `dvc.yaml`. - repo (str, optional): location of the DVC repository. - Defaults to the current project (found by walking up from the - current working directory tree). - It can be a URL or a file system path. - Both HTTP and SSH protocols are supported for online Git repos - (e.g. [user@]server:project.git). - stages (Union[str, Iterable[str]], optional): Name or names of the - stages to retrieve parameters from. - Defaults to `None`. - If `None`, all parameters from all stages will be retrieved. - rev (str, optional): Name of the `Git revision`_ to retrieve parameters - from. - Defaults to `None`. - An example of git revision can be a branch or tag name, a commit - hash or a dvc experiment name. - If `repo` is not a Git repo, this option is ignored. - If `None`, the current working tree will be used. - deps (bool, optional): Whether to retrieve only parameters that are - stage dependencies or not. - Defaults to `False`. - - Returns: - Dict: See Examples below. - - Examples: - - - No arguments. - - Working on https://github.com/iterative/example-get-started - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show() - >>> print(json.dumps(params, indent=4)) - { - "prepare": { - "split": 0.2, - "seed": 20170428 - }, - "featurize": { - "max_features": 200, - "ngrams": 2 - }, - "train": { - "seed": 20170428, - "n_est": 50, - "min_split": 0.01 - } - } - - --- - - - Filtering with `stages`. - - Working on https://github.com/iterative/example-get-started - - `stages` can a single string: - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show(stages="prepare") - >>> print(json.dumps(params, indent=4)) - { - "prepare": { - "split": 0.2, - "seed": 20170428 - } - } - - Or an iterable of strings: - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show(stages=["prepare", "train"]) - >>> print(json.dumps(params, indent=4)) - { - "prepare": { - "split": 0.2, - "seed": 20170428 - }, - "train": { - "seed": 20170428, - "n_est": 50, - "min_split": 0.01 - } - } - - --- - - - Using `rev`. - - Working on https://github.com/iterative/example-get-started - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show(rev="tune-hyperparams") - >>> print(json.dumps(params, indent=4)) - { - "prepare": { - "split": 0.2, - "seed": 20170428 - }, - "featurize": { - "max_features": 200, - "ngrams": 2 - }, - "train": { - "seed": 20170428, - "n_est": 100, - "min_split": 8 - } - } - - --- - - - Using `targets`. - - Working on `multi-params-files` folder of - https://github.com/iterative/pipeline-conifguration - - You can pass a single target: - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show("params.yaml") - >>> print(json.dumps(params, indent=4)) - { - "run_mode": "prod", - "configs": { - "dev": "configs/params_dev.yaml", - "test": "configs/params_test.yaml", - "prod": "configs/params_prod.yaml" - }, - "evaluate": { - "dataset": "micro", - "size": 5000, - "metrics": ["f1", "roc-auc"], - "metrics_file": "reports/metrics.json", - "plots_cm": "reports/plot_confusion_matrix.png" - } - } - - - Or multiple targets: - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show( - ... "configs/params_dev.yaml", "configs/params_prod.yaml") - >>> print(json.dumps(params, indent=4)) - { - "configs/params_prod.yaml:run_mode": "prod", - "configs/params_prod.yaml:config_file": "configs/params_prod.yaml", - "configs/params_prod.yaml:data_load": { - "dataset": "large", - "sampling": { - "enable": true, - "size": 50000 - } - }, - "configs/params_prod.yaml:train": { - "epochs": 1000 - }, - "configs/params_dev.yaml:run_mode": "dev", - "configs/params_dev.yaml:config_file": "configs/params_dev.yaml", - "configs/params_dev.yaml:data_load": { - "dataset": "development", - "sampling": { - "enable": true, - "size": 1000 - } - }, - "configs/params_dev.yaml:train": { - "epochs": 10 - } - } - - --- - - - Git URL as `repo`. - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show( - ... repo="https://github.com/iterative/demo-fashion-mnist") - { - "train": { - "batch_size": 128, - "hidden_units": 64, - "dropout": 0.4, - "num_epochs": 10, - "lr": 0.001, - "conv_activation": "relu" - } - } - - - .. _Git revision: - https://git-scm.com/docs/revisions - - """ - if isinstance(stages, str): - stages = [stages] - - def _onerror_raise(result: Dict, exception: Exception, *args, **kwargs): - raise exception - - def _postprocess(params): - processed = {} - for rev, rev_data in params.items(): - processed[rev] = {} - - counts = Counter() - for file_data in rev_data["data"].values(): - for k in file_data["data"]: - counts[k] += 1 - - for file_name, file_data in rev_data["data"].items(): - to_merge = { - (k if counts[k] == 1 else f"{file_name}:{k}"): v - for k, v in file_data["data"].items() - } - processed[rev] = {**processed[rev], **to_merge} - - if "workspace" in processed: - del processed["workspace"] - - return processed[first(processed)] - - with Repo.open(repo) as _repo: - params = _repo.params.show( - revs=rev if rev is None else [rev], - targets=targets, - deps=deps, - onerror=_onerror_raise, - stages=stages, - ) - - return _postprocess(params) - - -def make_checkpoint(): - """ - Signal DVC to create a checkpoint experiment. - - If the current process is being run from DVC, this function will block - until DVC has finished creating the checkpoint. Otherwise, this function - will return immediately. - """ - import builtins - from time import sleep - - from dvc.env import DVC_CHECKPOINT, DVC_ROOT - from dvc.stage.monitor import CheckpointTask - - if os.getenv(DVC_CHECKPOINT) is None: - return - - root_dir = os.getenv(DVC_ROOT, Repo.find_root()) - signal_file = os.path.join( - root_dir, Repo.DVC_DIR, "tmp", CheckpointTask.SIGNAL_FILE - ) - - with builtins.open(signal_file, "w", encoding="utf-8") as fobj: - # NOTE: force flushing/writing empty file to disk, otherwise when - # run in certain contexts (pytest) file may not actually be written - fobj.write("") - fobj.flush() - os.fsync(fobj.fileno()) - while os.path.exists(signal_file): - sleep(0.1) diff --git a/dvc/api/__init__.py b/dvc/api/__init__.py new file mode 100644 index 0000000000..dde993ec76 --- /dev/null +++ b/dvc/api/__init__.py @@ -0,0 +1,9 @@ +from .data import ( # noqa, pylint: disable=redefined-builtin + get_url, + open, + read, +) +from .experiments import make_checkpoint +from .params import params_show + +__all__ = ["get_url", "make_checkpoint", "open", "params_show", "read"] diff --git a/dvc/api/data.py b/dvc/api/data.py new file mode 100644 index 0000000000..a063612f10 --- /dev/null +++ b/dvc/api/data.py @@ -0,0 +1,213 @@ +from contextlib import _GeneratorContextManager as GCM +from typing import Optional + +from funcy import reraise + +from dvc.exceptions import OutputNotFoundError, PathMissingError +from dvc.repo import Repo + + +def get_url(path, repo=None, rev=None, remote=None): + """ + Returns the URL to the storage location of a data file or directory tracked + in a DVC repo. For Git repos, HEAD is used unless a rev argument is + supplied. The default remote is tried unless a remote argument is supplied. + + Raises OutputNotFoundError if the file is not tracked by DVC. + + NOTE: This function does not check for the actual existence of the file or + directory in the remote storage. + """ + with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: + fs_path = _repo.dvcfs.from_os_path(path) + with reraise(FileNotFoundError, PathMissingError(path, repo)): + info = _repo.dvcfs.info(fs_path) + + dvc_info = info.get("dvc_info") + if not dvc_info: + raise OutputNotFoundError(path, repo) + + dvc_repo = info["repo"] + md5 = dvc_info["md5"] + + return dvc_repo.cloud.get_url_for(remote, checksum=md5) + + +class _OpenContextManager(GCM): + def __init__( + self, func, args, kwds + ): # pylint: disable=super-init-not-called + self.gen = func(*args, **kwds) + self.func, self.args, self.kwds = func, args, kwds + + def __getattr__(self, name): + raise AttributeError( + "dvc.api.open() should be used in a with statement." + ) + + +def open( # noqa, pylint: disable=redefined-builtin + path: str, + repo: Optional[str] = None, + rev: Optional[str] = None, + remote: Optional[str] = None, + mode: str = "r", + encoding: Optional[str] = None, +): + """ + Opens a file tracked in a DVC project. + + This function may only be used as a context manager (using the `with` + keyword, as shown in the examples). + + This function makes a direct connection to the remote storage, so the file + contents can be streamed. Your code can process the data buffer as it's + streamed, which optimizes memory usage. + + Note: + Use dvc.api.read() to load the complete file contents + in a single function call, no context manager involved. + Neither function utilizes disc space. + + Args: + path (str): location and file name of the target to open, + relative to the root of `repo`. + repo (str, optional): location of the DVC project or Git Repo. + Defaults to the current DVC project (found by walking up from the + current working directory tree). + It can be a URL or a file system path. + Both HTTP and SSH protocols are supported for online Git repos + (e.g. [user@]server:project.git). + rev (str, optional): Any `Git revision`_ such as a branch or tag name, + a commit hash or a dvc experiment name. + Defaults to HEAD. + If `repo` is not a Git repo, this option is ignored. + remote (str, optional): Name of the `DVC remote`_ used to form the + returned URL string. + Defaults to the `default remote`_ of `repo`. + For local projects, the cache is tried before the default remote. + mode (str, optional): Specifies the mode in which the file is opened. + Defaults to "r" (read). + Mirrors the namesake parameter in builtin `open()`_. + Only reading `mode` is supported. + encoding(str, optional): `Codec`_ used to decode the file contents. + Defaults to None. + This should only be used in text mode. + Mirrors the namesake parameter in builtin `open()`_. + + Returns: + _OpenContextManager: A context manager that generatse a corresponding + `file object`_. + The exact type of file object depends on the mode used. + For more details, please refer to Python's `open()`_ built-in, + which is used under the hood. + + Raises: + AttributeError: If this method is not used as a context manager. + ValueError: If non-read `mode` is used. + + Examples: + + - Use data or models from a DVC repository. + + Any file tracked in a DVC project (and stored remotely) can be + processed directly in your Python code with this API. + For example, an XML file tracked in a public DVC repo on GitHub can be + processed like this: + + >>> from xml.sax import parse + >>> import dvc.api + >>> from mymodule import mySAXHandler + + >>> with dvc.api.open( + ... 'get-started/data.xml', + ... repo='https://github.com/iterative/dataset-registry' + ... ) as fd: + ... parse(fd, mySAXHandler) + + We use a SAX XML parser here because dvc.api.open() is able to stream + the data from remote storage. + The mySAXHandler object should handle the event-driven parsing of the + document in this case. + This increases the performance of the code (minimizing memory usage), + and is typically faster than loading the whole data into memory. + + - Accessing private repos + + This is just a matter of using the right repo argument, for example an + SSH URL (requires that the credentials are configured locally): + + >>> import dvc.api + + >>> with dvc.api.open( + ... 'features.dat', + ... repo='git@server.com:path/to/repo.git' + ... ) as fd: + ... # ... Process 'features' + ... pass + + - Use different versions of data + + Any git revision (see `rev`) can be accessed programmatically. + For example, if your DVC repo has tagged releases of a CSV dataset: + + >>> import csv + >>> import dvc.api + >>> with dvc.api.open( + ... 'clean.csv', + ... rev='v1.1.0' + ... ) as fd: + ... reader = csv.reader(fd) + ... # ... Process 'clean' data from version 1.1.0 + + .. _Git revision: + https://git-scm.com/docs/revisions + + .. _DVC remote: + https://dvc.org/doc/command-reference/remote + + .. _default remote: + https://dvc.org/doc/command-reference/remote/default + + .. _open(): + https://docs.python.org/3/library/functions.html#open + + .. _Codec: + https://docs.python.org/3/library/codecs.html#standard-encodings + + .. _file object: + https://docs.python.org/3/glossary.html#term-file-object + + """ + if "r" not in mode: + raise ValueError("Only reading `mode` is supported.") + + args = (path,) + kwargs = { + "repo": repo, + "remote": remote, + "rev": rev, + "mode": mode, + "encoding": encoding, + } + return _OpenContextManager(_open, args, kwargs) + + +def _open(path, repo=None, rev=None, remote=None, mode="r", encoding=None): + with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: + with _repo.open_by_relpath( + path, remote=remote, mode=mode, encoding=encoding + ) as fd: + yield fd + + +def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None): + """ + Returns the contents of a tracked file (by DVC or Git). For Git repos, HEAD + is used unless a rev argument is supplied. The default remote is tried + unless a remote argument is supplied. + """ + with open( + path, repo=repo, rev=rev, remote=remote, mode=mode, encoding=encoding + ) as fd: + return fd.read() diff --git a/dvc/api/experiments.py b/dvc/api/experiments.py new file mode 100644 index 0000000000..f6448eedb0 --- /dev/null +++ b/dvc/api/experiments.py @@ -0,0 +1,33 @@ +import builtins +import os +from time import sleep + +from dvc.env import DVC_CHECKPOINT, DVC_ROOT +from dvc.repo import Repo +from dvc.stage.monitor import CheckpointTask + + +def make_checkpoint(): + """ + Signal DVC to create a checkpoint experiment. + + If the current process is being run from DVC, this function will block + until DVC has finished creating the checkpoint. Otherwise, this function + will return immediately. + """ + if os.getenv(DVC_CHECKPOINT) is None: + return + + root_dir = os.getenv(DVC_ROOT, Repo.find_root()) + signal_file = os.path.join( + root_dir, Repo.DVC_DIR, "tmp", CheckpointTask.SIGNAL_FILE + ) + + with builtins.open(signal_file, "w", encoding="utf-8") as fobj: + # NOTE: force flushing/writing empty file to disk, otherwise when + # run in certain contexts (pytest) file may not actually be written + fobj.write("") + fobj.flush() + os.fsync(fobj.fileno()) + while os.path.exists(signal_file): + sleep(0.1) diff --git a/dvc/api/params.py b/dvc/api/params.py new file mode 100644 index 0000000000..d66b19a591 --- /dev/null +++ b/dvc/api/params.py @@ -0,0 +1,267 @@ +from collections import Counter +from typing import Dict, Iterable, Optional, Union + +from funcy import first + +from dvc.repo import Repo + + +def params_show( + *targets: str, + repo: Optional[str] = None, + stages: Optional[Union[str, Iterable[str]]] = None, + rev: Optional[str] = None, + deps: bool = False, +) -> Dict: + """Get parameters tracked in `repo`. + + Without arguments, this function will retrieve all params from all tracked + parameter files, for the current working tree. + + See the options below to restrict the parameters retrieved. + + Args: + *targets (str, optional): Names of the parameter files to retrieve + params from. For example, "params.py, myparams.toml". + If no `targets` are provided, all parameter files tracked in `dvc.yaml` + will be used. + Note that targets don't necessarily have to be defined in `dvc.yaml`. + repo (str, optional): location of the DVC repository. + Defaults to the current project (found by walking up from the + current working directory tree). + It can be a URL or a file system path. + Both HTTP and SSH protocols are supported for online Git repos + (e.g. [user@]server:project.git). + stages (Union[str, Iterable[str]], optional): Name or names of the + stages to retrieve parameters from. + Defaults to `None`. + If `None`, all parameters from all stages will be retrieved. + rev (str, optional): Name of the `Git revision`_ to retrieve parameters + from. + Defaults to `None`. + An example of git revision can be a branch or tag name, a commit + hash or a dvc experiment name. + If `repo` is not a Git repo, this option is ignored. + If `None`, the current working tree will be used. + deps (bool, optional): Whether to retrieve only parameters that are + stage dependencies or not. + Defaults to `False`. + + Returns: + Dict: See Examples below. + + Examples: + + - No arguments. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show() + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Filtering with `stages`. + + Working on https://github.com/iterative/example-get-started + + `stages` can a single string: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages="prepare") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + } + } + + Or an iterable of strings: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages=["prepare", "train"]) + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Using `rev`. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(rev="tune-hyperparams") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 100, + "min_split": 8 + } + } + + --- + + - Using `targets`. + + Working on `multi-params-files` folder of + https://github.com/iterative/pipeline-conifguration + + You can pass a single target: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show("params.yaml") + >>> print(json.dumps(params, indent=4)) + { + "run_mode": "prod", + "configs": { + "dev": "configs/params_dev.yaml", + "test": "configs/params_test.yaml", + "prod": "configs/params_prod.yaml" + }, + "evaluate": { + "dataset": "micro", + "size": 5000, + "metrics": ["f1", "roc-auc"], + "metrics_file": "reports/metrics.json", + "plots_cm": "reports/plot_confusion_matrix.png" + } + } + + + Or multiple targets: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... "configs/params_dev.yaml", "configs/params_prod.yaml") + >>> print(json.dumps(params, indent=4)) + { + "configs/params_prod.yaml:run_mode": "prod", + "configs/params_prod.yaml:config_file": "configs/params_prod.yaml", + "configs/params_prod.yaml:data_load": { + "dataset": "large", + "sampling": { + "enable": true, + "size": 50000 + } + }, + "configs/params_prod.yaml:train": { + "epochs": 1000 + }, + "configs/params_dev.yaml:run_mode": "dev", + "configs/params_dev.yaml:config_file": "configs/params_dev.yaml", + "configs/params_dev.yaml:data_load": { + "dataset": "development", + "sampling": { + "enable": true, + "size": 1000 + } + }, + "configs/params_dev.yaml:train": { + "epochs": 10 + } + } + + --- + + - Git URL as `repo`. + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... repo="https://github.com/iterative/demo-fashion-mnist") + { + "train": { + "batch_size": 128, + "hidden_units": 64, + "dropout": 0.4, + "num_epochs": 10, + "lr": 0.001, + "conv_activation": "relu" + } + } + + + .. _Git revision: + https://git-scm.com/docs/revisions + + """ + if isinstance(stages, str): + stages = [stages] + + def _onerror_raise(result: Dict, exception: Exception, *args, **kwargs): + raise exception + + def _postprocess(params): + processed = {} + for rev, rev_data in params.items(): + processed[rev] = {} + + counts = Counter() + for file_data in rev_data["data"].values(): + for k in file_data["data"]: + counts[k] += 1 + + for file_name, file_data in rev_data["data"].items(): + to_merge = { + (k if counts[k] == 1 else f"{file_name}:{k}"): v + for k, v in file_data["data"].items() + } + processed[rev] = {**processed[rev], **to_merge} + + if "workspace" in processed: + del processed["workspace"] + + return processed[first(processed)] + + with Repo.open(repo) as _repo: + params = _repo.params.show( + revs=rev if rev is None else [rev], + targets=targets, + deps=deps, + onerror=_onerror_raise, + stages=stages, + ) + + return _postprocess(params) diff --git a/dvc/testing/tmp_dir.py b/dvc/testing/tmp_dir.py index 2517166eb4..ea4cd0d777 100644 --- a/dvc/testing/tmp_dir.py +++ b/dvc/testing/tmp_dir.py @@ -259,6 +259,15 @@ def modify(self, *args, **kwargs): dump_toml = partialmethod(serialize.dump_toml) +def make_subrepo(dir_: TmpDir, scm, config=None): + dir_.mkdir(parents=True, exist_ok=True) + with dir_.chdir(): + dir_.scm = scm + dir_.init(dvc=True, subdir=True) + if config: + dir_.add_remote(config=config) + + def _coerce_filenames(filenames): if isinstance(filenames, (str, bytes, pathlib.PurePath)): filenames = [filenames] diff --git a/tests/func/test_api.py b/tests/func/api/test_data.py similarity index 69% rename from tests/func/test_api.py rename to tests/func/api/test_data.py index 116318e3ad..d73226175d 100644 --- a/tests/func/test_api.py +++ b/tests/func/api/test_data.py @@ -1,5 +1,4 @@ import os -from textwrap import dedent import pytest from funcy import first, get_in @@ -11,8 +10,8 @@ PathMissingError, ) from dvc.testing.test_api import TestAPI # noqa, pylint: disable=unused-import +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import remove -from tests.unit.fs.test_dvc import make_subrepo def test_get_url_external(tmp_dir, erepo_dir, cloud): @@ -229,128 +228,3 @@ def test_open_from_remote(tmp_dir, erepo_dir, cloud, local_cloud): remote="other", ) as fd: assert fd.read() == "foo content" - - -@pytest.fixture -def params_repo(tmp_dir, scm, dvc): - tmp_dir.gen("params.yaml", "foo: 1") - tmp_dir.gen("params.json", '{"bar": 2, "foobar": 3}') - tmp_dir.gen("other_params.json", '{"foo": {"bar": 4}}') - - dvc.run( - name="stage-1", - cmd="echo stage-1", - params=["foo", "params.json:bar"], - ) - - dvc.run( - name="stage-2", - cmd="echo stage-2", - params=["other_params.json:foo"], - ) - - dvc.run( - name="stage-3", - cmd="echo stage-2", - params=["params.json:foobar"], - ) - - scm.add( - [ - "params.yaml", - "params.json", - "other_params.json", - "dvc.yaml", - "dvc.lock", - ] - ) - scm.commit("commit dvc files") - - tmp_dir.gen("params.yaml", "foo: 5") - scm.add(["params.yaml"]) - scm.commit("update params.yaml") - - -def test_params_show_no_args(params_repo): - assert api.params_show() == { - "params.yaml:foo": 5, - "bar": 2, - "foobar": 3, - "other_params.json:foo": {"bar": 4}, - } - - -def test_params_show_targets(params_repo): - assert api.params_show("params.yaml") == {"foo": 5} - assert api.params_show("params.yaml", "params.json") == { - "foo": 5, - "bar": 2, - "foobar": 3, - } - - -def test_params_show_deps(params_repo): - params = api.params_show(deps=True) - assert params == { - "params.yaml:foo": 5, - "bar": 2, - "foobar": 3, - "other_params.json:foo": {"bar": 4}, - } - - -def test_params_show_stages(params_repo): - assert api.params_show(stages="stage-2") == {"foo": {"bar": 4}} - - assert api.params_show() == api.params_show( - stages=["stage-1", "stage-2", "stage-3"] - ) - - assert api.params_show("params.json", stages="stage-3") == {"foobar": 3} - - -def test_params_show_revs(params_repo): - assert api.params_show(rev="HEAD~1") == { - "params.yaml:foo": 1, - "bar": 2, - "foobar": 3, - "other_params.json:foo": {"bar": 4}, - } - - -def test_params_show_while_running_stage(tmp_dir, dvc): - (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) - (tmp_dir / "params.json").dump({"bar": 2}) - - tmp_dir.gen( - "merge.py", - dedent( - """ - import json - from dvc import api - with open("merged.json", "w") as f: - json.dump(api.params_show(stages="merge"), f) - """ - ), - ) - dvc.stage.add( - name="merge", - cmd="python merge.py", - params=["foo.bar", {"params.json": ["bar"]}], - outs=["merged.json"], - ) - - dvc.reproduce() - - assert (tmp_dir / "merged.json").parse() == {"foo": {"bar": 1}, "bar": 2} - - -def test_params_show_repo(tmp_dir, erepo_dir): - with erepo_dir.chdir(): - erepo_dir.scm_gen("params.yaml", "foo: 1", commit="Create params.yaml") - erepo_dir.dvc.run( - name="stage-1", - cmd="echo stage-1", - params=["foo"], - ) - assert api.params_show(repo=erepo_dir) == {"foo": 1} diff --git a/tests/func/api/test_params.py b/tests/func/api/test_params.py new file mode 100644 index 0000000000..9a2fb0f83b --- /dev/null +++ b/tests/func/api/test_params.py @@ -0,0 +1,130 @@ +from textwrap import dedent + +import pytest + +from dvc import api + + +@pytest.fixture +def params_repo(tmp_dir, scm, dvc): + tmp_dir.gen("params.yaml", "foo: 1") + tmp_dir.gen("params.json", '{"bar": 2, "foobar": 3}') + tmp_dir.gen("other_params.json", '{"foo": {"bar": 4}}') + + dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo", "params.json:bar"], + ) + + dvc.run( + name="stage-2", + cmd="echo stage-2", + params=["other_params.json:foo"], + ) + + dvc.run( + name="stage-3", + cmd="echo stage-2", + params=["params.json:foobar"], + ) + + scm.add( + [ + "params.yaml", + "params.json", + "other_params.json", + "dvc.yaml", + "dvc.lock", + ] + ) + scm.commit("commit dvc files") + + tmp_dir.gen("params.yaml", "foo: 5") + scm.add(["params.yaml"]) + scm.commit("update params.yaml") + + +def test_params_show_no_args(params_repo): + assert api.params_show() == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_targets(params_repo): + assert api.params_show("params.yaml") == {"foo": 5} + assert api.params_show("params.yaml", "params.json") == { + "foo": 5, + "bar": 2, + "foobar": 3, + } + + +def test_params_show_deps(params_repo): + params = api.params_show(deps=True) + assert params == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_stages(params_repo): + assert api.params_show(stages="stage-2") == {"foo": {"bar": 4}} + + assert api.params_show() == api.params_show( + stages=["stage-1", "stage-2", "stage-3"] + ) + + assert api.params_show("params.json", stages="stage-3") == {"foobar": 3} + + +def test_params_show_revs(params_repo): + assert api.params_show(rev="HEAD~1") == { + "params.yaml:foo": 1, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_while_running_stage(tmp_dir, dvc): + (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) + (tmp_dir / "params.json").dump({"bar": 2}) + + tmp_dir.gen( + "merge.py", + dedent( + """ + import json + from dvc import api + with open("merged.json", "w") as f: + json.dump(api.params_show(stages="merge"), f) + """ + ), + ) + dvc.stage.add( + name="merge", + cmd="python merge.py", + params=["foo.bar", {"params.json": ["bar"]}], + outs=["merged.json"], + ) + + dvc.reproduce() + + assert (tmp_dir / "merged.json").parse() == {"foo": {"bar": 1}, "bar": 2} + + +def test_params_show_repo(tmp_dir, erepo_dir): + with erepo_dir.chdir(): + erepo_dir.scm_gen("params.yaml", "foo: 1", commit="Create params.yaml") + erepo_dir.dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo"], + ) + assert api.params_show(repo=erepo_dir) == {"foo": 1} diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 6ee62ff72e..3988e44db0 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -544,7 +544,7 @@ def test_subdir(tmp_dir, scm, dvc, workspace): @pytest.mark.parametrize("workspace", [True, False]) def test_subrepo(tmp_dir, scm, workspace): - from tests.unit.fs.test_dvc import make_subrepo + from dvc.testing.tmp_dir import make_subrepo subrepo = tmp_dir / "dir" / "repo" make_subrepo(subrepo, scm) diff --git a/tests/func/test_external_repo.py b/tests/func/test_external_repo.py index 36e2bad56e..6e31edef48 100644 --- a/tests/func/test_external_repo.py +++ b/tests/func/test_external_repo.py @@ -4,11 +4,11 @@ from scmrepo.git import Git from dvc.external_repo import CLONES, external_repo +from dvc.testing.tmp_dir import make_subrepo from dvc.utils import relpath from dvc.utils.fs import makedirs, remove from dvc_data.stage import stage from dvc_data.transfer import transfer -from tests.unit.fs.test_dvc import make_subrepo def test_external_repo(erepo_dir, mocker): diff --git a/tests/func/test_get.py b/tests/func/test_get.py index fb477dfe0b..d2d7eb3cfb 100644 --- a/tests/func/test_get.py +++ b/tests/func/test_get.py @@ -8,8 +8,8 @@ from dvc.odbmgr import ODBManager from dvc.repo import Repo from dvc.repo.get import GetDVCFileError +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import makedirs -from tests.unit.fs.test_dvc import make_subrepo def test_get_repo_file(tmp_dir, erepo_dir): diff --git a/tests/func/test_import.py b/tests/func/test_import.py index a0c160f960..6de58fe4c5 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -12,8 +12,8 @@ from dvc.fs import system from dvc.odbmgr import ODBManager from dvc.stage.exceptions import StagePathNotFoundError +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import makedirs, remove -from tests.unit.fs.test_dvc import make_subrepo def test_import(tmp_dir, scm, dvc, erepo_dir): diff --git a/tests/func/test_update.py b/tests/func/test_update.py index 94ba2c9d36..37f98f314a 100644 --- a/tests/func/test_update.py +++ b/tests/func/test_update.py @@ -4,7 +4,7 @@ from dvc.dvcfile import Dvcfile from dvc.exceptions import InvalidArgumentError -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo @pytest.mark.parametrize("cached", [True, False]) diff --git a/tests/unit/fs/test_dvc.py b/tests/unit/fs/test_dvc.py index 06e48f1b3e..f7c6dec7bb 100644 --- a/tests/unit/fs/test_dvc.py +++ b/tests/unit/fs/test_dvc.py @@ -6,6 +6,7 @@ import pytest from dvc.fs.dvc import DvcFileSystem +from dvc.testing.tmp_dir import make_subrepo from dvc_data.hashfile.hash_info import HashInfo from dvc_data.stage import stage @@ -318,15 +319,6 @@ def test_isdvc(tmp_dir, dvc): assert fs.isdvc("dir/baz", recursive=True) -def make_subrepo(dir_, scm, config=None): - dir_.mkdir(parents=True, exist_ok=True) - with dir_.chdir(): - dir_.scm = scm - dir_.init(dvc=True, subdir=True) - if config: - dir_.add_remote(config=config) - - def test_subrepos(tmp_dir, scm, dvc, mocker): tmp_dir.scm_gen( {"dir": {"repo.txt": "file to confuse DvcFileSystem"}}, diff --git a/tests/unit/fs/test_dvc_info.py b/tests/unit/fs/test_dvc_info.py index 03bdec1f7c..dc175662e5 100644 --- a/tests/unit/fs/test_dvc_info.py +++ b/tests/unit/fs/test_dvc_info.py @@ -3,7 +3,7 @@ import pytest from dvc.fs.dvc import DvcFileSystem -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo @pytest.fixture diff --git a/tests/unit/test_external_repo.py b/tests/unit/test_external_repo.py index 7f2f680b71..c69b46355b 100644 --- a/tests/unit/test_external_repo.py +++ b/tests/unit/test_external_repo.py @@ -4,7 +4,7 @@ import pytest from dvc.external_repo import external_repo -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo def test_hook_is_called(tmp_dir, erepo_dir, mocker):