diff --git a/dvc/api.py b/dvc/api.py index bf976dcf01..45c20c0c15 100644 --- a/dvc/api.py +++ b/dvc/api.py @@ -1,4 +1,7 @@ +import importlib import os +import sys +import copy from contextlib import contextmanager try: @@ -6,11 +9,38 @@ except ImportError: from contextlib import GeneratorContextManager as GCM -from dvc.utils.compat import urlparse +from dvc.utils.compat import urlparse, builtin_str + +import ruamel.yaml +from voluptuous import Schema, Required, Invalid + from dvc.repo import Repo +from dvc.exceptions import DvcException, FileMissingError from dvc.external_repo import external_repo +SUMMON_SCHEMA = Schema( + { + Required("objects"): [ + { + Required("name"): str, + "meta": dict, + Required("summon"): { + Required("type"): "python", + Required("call"): str, + "args": dict, + "deps": [str], + }, + } + ] + } +) + + +class SummonError(DvcException): + pass + + def get_url(path, repo=None, rev=None, remote=None): """Returns an url of a resource specified by path in repo""" with _make_repo(repo, rev=rev) as _repo: @@ -69,3 +99,99 @@ def _make_repo(repo_url, rev=None): else: with external_repo(url=repo_url, rev=rev) as repo: yield repo + + +def summon(name, repo=None, rev=None, summon_file="dvcsummon.yaml", args=None): + """Instantiate an object described in the summon file.""" + with _make_repo(repo, rev=rev) as _repo: + try: + path = os.path.join(_repo.root_dir, summon_file) + obj = _get_object_from_summon_file(name, path) + info = obj["summon"] + except SummonError as exc: + raise SummonError( + str(exc) + " at '{}' in '{}'".format(summon_file, repo), + cause=exc.cause, + ) + + _pull_dependencies(_repo, info.get("deps", [])) + + _args = copy.deepcopy(info.get("args", {})) + _args.update(args or {}) + + return _invoke_method(info["call"], _args, path=_repo.root_dir) + + +def _get_object_from_summon_file(name, path): + """ + Given a summonable object's name, search for it on the given file + and return its description. 
+ """ + try: + with open(path, "r") as fobj: + content = SUMMON_SCHEMA(ruamel.yaml.safe_load(fobj.read())) + objects = [x for x in content["objects"] if x["name"] == name] + + if not objects: + raise SummonError("No object with name '{}'".format(name)) + elif len(objects) >= 2: + raise SummonError( + "More than one object with name '{}'".format(name) + ) + + return objects[0] + + except FileMissingError: + raise SummonError("Summon file not found") + except ruamel.yaml.YAMLError as exc: + raise SummonError("Failed to parse summon file", exc) + except Invalid as exc: + raise SummonError(str(exc)) + + +def _pull_dependencies(repo, deps): + if not deps: + return + + outs = [repo.find_out_by_relpath(dep) for dep in deps] + + with repo.state: + for out in outs: + repo.cloud.pull(out.get_used_cache()) + out.checkout() + + +def _invoke_method(call, args, path): + # XXX: Some issues with this approach: + # * Not thread safe + # * Import will pollute sys.modules + # * Weird errors if there is a name clash within sys.modules + + # XXX: sys.path manipulation is "theoretically" not needed + # but tests are failing for an unknown reason. + cwd = os.getcwd() + + try: + os.chdir(path) + sys.path.insert(0, path) + method = _import_string(call) + return method(**args) + finally: + os.chdir(cwd) + sys.path.pop(0) + + +def _import_string(import_name): + """Imports an object based on a string. + Useful to delay import to not load everything on startup. + Use dotted notation in `import_name`, e.g. 'dvc.remote.gs.RemoteGS'. + + :return: imported object + """ + import_name = builtin_str(import_name) + + if "." 
in import_name: + module, obj = import_name.rsplit(".", 1) + else: + return importlib.import_module(import_name) + return getattr(importlib.import_module(module), obj) diff --git a/tests/func/test_api.py b/tests/func/test_api.py index 54623c080d..393e94edf8 100644 --- a/tests/func/test_api.py +++ b/tests/func/test_api.py @@ -1,15 +1,17 @@ -from __future__ import unicode_literals - import os import shutil +import copy +import ruamel.yaml import pytest from dvc import api +from dvc.api import SummonError from dvc.exceptions import FileMissingError from dvc.main import main from dvc.path_info import URLInfo from dvc.remote.config import RemoteConfig +from dvc.utils.compat import fspath from tests.remotes import Azure, GCP, HDFS, Local, OSS, S3, SSH @@ -126,3 +128,70 @@ def test_open_not_cached(dvc): os.remove(metric_file) with pytest.raises(FileMissingError): api.read(metric_file) + + +def test_summon(tmp_dir, erepo_dir, dvc, monkeypatch): + objects = { + "objects": [ + { + "name": "sum", + "meta": {"description": "Add to "}, + "summon": { + "type": "python", + "call": "calculator.add_to_num", + "args": {"x": 1}, + "deps": ["number"], + }, + } + ] + } + + other_objects = copy.deepcopy(objects) + other_objects["objects"][0]["summon"]["args"]["x"] = 100 + + dup_objects = copy.deepcopy(objects) + dup_objects["objects"] *= 2 + + with monkeypatch.context() as m: + m.chdir(fspath(erepo_dir)) + + erepo_dir.dvc_gen("number", "100", commit="Add number.dvc") + erepo_dir.scm_gen("dvcsummon.yaml", ruamel.yaml.dump(objects)) + erepo_dir.scm_gen("other.yaml", ruamel.yaml.dump(other_objects)) + erepo_dir.scm_gen("dup.yaml", ruamel.yaml.dump(dup_objects)) + erepo_dir.scm_gen("invalid.yaml", ruamel.yaml.dump({"name": "sum"})) + erepo_dir.scm_gen("not_yaml.yaml", "a: - this is not a YAML file") + erepo_dir.scm_gen( + "calculator.py", + "def add_to_num(x): return x + int(open('number').read())", + ) + erepo_dir.scm.commit("Add files") + + repo_url = "file://{}".format(erepo_dir) + + 
assert api.summon("sum", repo=repo_url) == 101 + assert api.summon("sum", repo=repo_url, args={"x": 2}) == 102 + assert api.summon("sum", repo=repo_url, summon_file="other.yaml") == 200 + + try: + api.summon("sum", repo=repo_url, summon_file="missing.yaml") + except SummonError as exc: + assert "Summon file not found" in str(exc) + assert "missing.yaml" in str(exc) + assert repo_url in str(exc) + else: + pytest.fail("Did not raise on missing summon file") + + with pytest.raises(SummonError, match=r"No object with name 'missing'"): + api.summon("missing", repo=repo_url) + + with pytest.raises( + SummonError, match=r"More than one object with name 'sum'" + ): + api.summon("sum", repo=repo_url, summon_file="dup.yaml") + + with pytest.raises(SummonError, match=r"extra keys not allowed"): + api.summon("sum", repo=repo_url, summon_file="invalid.yaml") + + with pytest.raises(SummonError, match=r"Failed to parse summon file"): + api.summon("sum", repo=repo_url, summon_file="not_yaml.yaml")