diff --git a/dvc/config.py b/dvc/config.py index c7925aec21..d89b2fc6f0 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -31,6 +31,18 @@ def __init__(self, msg, cause=None): ) +class NoRemoteError(ConfigError): + def __init__(self, command, cause=None): + msg = ( + "no remote specified. Setup default remote with\n" + " dvc config core.remote \n" + "or use:\n" + " dvc {} -r \n".format(command) + ) + + super(NoRemoteError, self).__init__(msg, cause=cause) + + def supported_cache_type(types): """Checks if link type config option has a valid value. diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 3568ae4cbb..841c8ffa80 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -4,7 +4,7 @@ import logging -from dvc.config import Config, ConfigError +from dvc.config import Config, NoRemoteError from dvc.remote import Remote from dvc.remote.s3 import RemoteS3 from dvc.remote.gs import RemoteGS @@ -60,12 +60,7 @@ def get_remote(self, remote=None, command=""): if remote: return self._init_remote(remote) - raise ConfigError( - "No remote repository specified. Setup default repository with\n" - " dvc config core.remote \n" - "or use:\n" - " dvc {} -r \n".format(command) - ) + raise NoRemoteError(command) def _init_remote(self, remote): return Remote(self.repo, name=remote) diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 30d49fdc5f..b56bc8008d 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -61,7 +61,7 @@ def save(self): def dumpd(self): return {self.PARAM_PATH: self.def_path, self.PARAM_REPO: self.def_repo} - def download(self, to, resume=False): + def fetch(self): with self._make_repo( cache_dir=self.repo.cache.local.cache_dir ) as repo: @@ -70,8 +70,13 @@ def download(self, to, resume=False): out = repo.find_out_by_relpath(self.def_path) with repo.state: repo.cloud.pull(out.get_used_cache()) - to.info = copy.copy(out.info) - to.checkout() + + return out + + def download(self, to): + out = self.fetch() + to.info = copy.copy(out.info) + to.checkout() def update(self): with self._make_repo(rev_lock=None) as repo: diff --git a/dvc/exceptions.py b/dvc/exceptions.py index f7b6739a33..eb6697068f 100644 --- a/dvc/exceptions.py +++ b/dvc/exceptions.py @@ -303,3 +303,21 @@ def __init__(self, hook_name): "https://man.dvc.org/install " "for more info.".format(hook_name) ) + + +class DownloadError(DvcException): + def __init__(self, amount): + self.amount = amount + + super(DownloadError, self).__init__( + "{amount} files failed to download".format(amount=amount) + ) + + +class UploadError(DvcException): + def __init__(self, amount): + self.amount = amount + + super(UploadError, self).__init__( + "{amount} files failed to upload".format(amount=amount) + ) diff --git a/dvc/output/base.py b/dvc/output/base.py index c478f6190f..455f8951d6 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -402,9 +402,6 @@ def get_used_cache(self, **kwargs): include the `info` of its files. """ - if self.stage.is_repo_import: - return [] - if not self.use_cache: return [] diff --git a/dvc/remote/local/__init__.py b/dvc/remote/local/__init__.py index 1c1256efd7..7592ecb274 100644 --- a/dvc/remote/local/__init__.py +++ b/dvc/remote/local/__init__.py @@ -30,7 +30,7 @@ makedirs, ) from dvc.config import Config -from dvc.exceptions import DvcException +from dvc.exceptions import DvcException, DownloadError, UploadError from dvc.progress import Tqdm, TqdmThreadPoolExecutor from dvc.path_info import PathInfo @@ -388,10 +388,9 @@ def _process( fails = sum(map(func, *plans)) if fails: - msg = "{} files failed to {}" - raise DvcException( - msg.format(fails, "download" if download else "upload") - ) + if download: + raise DownloadError(fails) + raise UploadError(fails) return len(plans[0]) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 16317a68fc..31eca793e0 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -214,6 +214,7 @@ def used_cache( cache["hdfs"] = [] cache["ssh"] = [] cache["azure"] = [] + cache["repo"] = [] for branch in self.brancher( all_branches=all_branches, all_tags=all_tags @@ -231,6 +232,10 @@ def used_cache( stages = self.stages() for stage in stages: + if stage.is_repo_import: + cache["repo"] += stage.deps + continue + for out in stage.outs: scheme = out.path_info.scheme used_cache = out.get_used_cache( diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py index 3132e30d33..ee58c15423 100644 --- a/dvc/repo/fetch.py +++ b/dvc/repo/fetch.py @@ -1,5 +1,14 @@ from __future__ import unicode_literals +import logging + +from dvc.config import NoRemoteError +from dvc.exceptions import DownloadError, OutputNotFoundError +from dvc.scm.base import CloneError + + +logger = logging.getLogger(__name__) + def fetch( self, @@ -12,6 +21,18 @@ def fetch( all_tags=False, recursive=False, ): + """Download data items from a cloud and imported repositories + + Returns: + int: number of succesfully downloaded files + + Raises: + DownloadError: thrown when there are failed downloads, either + during `cloud.pull` or trying to fetch imported files + + config.NoRemoteError: thrown when downloading only local files and no + remote is configured + """ with self.state: used = self.used_cache( targets, @@ -22,7 +43,38 @@ def fetch( remote=remote, jobs=jobs, recursive=recursive, - )["local"] - return self.cloud.pull( - used, jobs, remote=remote, show_checksums=show_checksums ) + + downloaded = 0 + failed = 0 + + try: + downloaded += self.cloud.pull( + used["local"], + jobs, + remote=remote, + show_checksums=show_checksums, + ) + except NoRemoteError: + if not used["repo"] and used["local"]: + raise + + except DownloadError as exc: + failed += exc.amount + + for dep in used["repo"]: + try: + out = dep.fetch() + downloaded += out.get_files_number() + except DownloadError as exc: + failed += exc.amount + except (CloneError, OutputNotFoundError): + failed += 1 + logger.exception( + "failed to fetch data for '{}'".format(dep.stage.outs[0]) + ) + + if failed: + raise DownloadError(failed) + + return downloaded diff --git a/tests/func/test_import.py b/tests/func/test_import.py index b097f98450..b5b99d596a 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -1,8 +1,13 @@ import os import filecmp +import shutil +import pytest from tests.utils import trees_equal +from dvc.stage import Stage +from dvc.exceptions import DownloadError + def test_import(repo_dir, git, dvc_repo, erepo): src = erepo.FOO @@ -39,3 +44,38 @@ def test_import_rev(repo_dir, git, dvc_repo, erepo): with open(dst, "r+") as fobj: assert fobj.read() == "branch" assert git.git.check_ignore(dst) + + +def test_pull_imported_stage(dvc_repo, erepo): + src = erepo.FOO + dst = erepo.FOO + "_imported" + + dvc_repo.imp(erepo.root_dir, src, dst) + + dst_stage = Stage.load(dvc_repo, "foo_imported.dvc") + dst_cache = dst_stage.outs[0].cache_path + + os.remove(dst) + os.remove(dst_cache) + + dvc_repo.pull(["foo_imported.dvc"]) + + assert os.path.isfile(dst) + assert os.path.isfile(dst_cache) + + +def test_download_error_pulling_imported_stage(dvc_repo, erepo): + src = erepo.FOO + dst = erepo.FOO + "_imported" + + dvc_repo.imp(erepo.root_dir, src, dst) + + dst_stage = Stage.load(dvc_repo, "foo_imported.dvc") + dst_cache = dst_stage.outs[0].cache_path + + shutil.rmtree(erepo.root_dir) + os.remove(dst) + os.remove(dst_cache) + + with pytest.raises(DownloadError): + dvc_repo.pull(["foo_imported.dvc"])