Skip to content
12 changes: 12 additions & 0 deletions dvc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def __init__(self, msg, cause=None):
)


class NoRemoteError(ConfigError):
    """Config error raised when no remote has been specified.

    The message tells the user how to set a default remote or how to
    pass one explicitly to the command that failed.

    Args:
        command: name of the dvc command, interpolated into the
            `dvc <command> -r <name>` hint of the message.
        cause: optional underlying exception, forwarded unchanged to
            the parent class constructor.
    """

    def __init__(self, command, cause=None):
        # Keep the template separate so the single `{}` placeholder that
        # receives the command name is easy to spot.
        template = (
            "no remote specified. Setup default remote with\n"
            " dvc config core.remote <name>\n"
            "or use:\n"
            " dvc {} -r <name>\n"
        )

        super(NoRemoteError, self).__init__(
            template.format(command), cause=cause
        )


def supported_cache_type(types):
"""Checks if link type config option has a valid value.

Expand Down
9 changes: 2 additions & 7 deletions dvc/data_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging

from dvc.config import Config, ConfigError
from dvc.config import Config, NoRemoteError
from dvc.remote import Remote
from dvc.remote.s3 import RemoteS3
from dvc.remote.gs import RemoteGS
Expand Down Expand Up @@ -60,12 +60,7 @@ def get_remote(self, remote=None, command="<command>"):
if remote:
return self._init_remote(remote)

raise ConfigError(
"No remote repository specified. Setup default repository with\n"
" dvc config core.remote <name>\n"
"or use:\n"
" dvc {} -r <name>\n".format(command)
)
raise NoRemoteError(command)

def _init_remote(self, remote):
return Remote(self.repo, name=remote)
Expand Down
11 changes: 8 additions & 3 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def save(self):
def dumpd(self):
return {self.PARAM_PATH: self.def_path, self.PARAM_REPO: self.def_repo}

def download(self, to, resume=False):
def fetch(self):
with self._make_repo(
cache_dir=self.repo.cache.local.cache_dir
) as repo:
Expand All @@ -70,8 +70,13 @@ def download(self, to, resume=False):
out = repo.find_out_by_relpath(self.def_path)
with repo.state:
repo.cloud.pull(out.get_used_cache())
to.info = copy.copy(out.info)
to.checkout()

return out

def download(self, to):
    """Fetch the imported output and check it out as `to`.

    Args:
        to: the local output to populate; it receives a shallow copy of
            the fetched output's `info` and is then checked out.
    """
    fetched = self.fetch()
    # Copy rather than share the mapping so `to` owns its own info object.
    to.info = copy.copy(fetched.info)
    to.checkout()

def update(self):
with self._make_repo(rev_lock=None) as repo:
Expand Down
18 changes: 18 additions & 0 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,21 @@ def __init__(self, hook_name):
"https://man.dvc.org/install "
"for more info.".format(hook_name)
)


class DownloadError(DvcException):
    """Raised when one or more files failed to download.

    Attributes:
        amount: number of files that failed to download.
    """

    def __init__(self, amount):
        message = "{amount} files failed to download".format(amount=amount)

        # Expose the count so callers can aggregate failures.
        self.amount = amount
        super(DownloadError, self).__init__(message)


class UploadError(DvcException):
    """Raised when one or more files failed to upload.

    Attributes:
        amount: number of files that failed to upload.
    """

    def __init__(self, amount):
        message = "{amount} files failed to upload".format(amount=amount)

        # Expose the count so callers can aggregate failures.
        self.amount = amount
        super(UploadError, self).__init__(message)
3 changes: 0 additions & 3 deletions dvc/output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,6 @@ def get_used_cache(self, **kwargs):
include the `info` of its files.
"""

if self.stage.is_repo_import:
return []

if not self.use_cache:
return []

Expand Down
9 changes: 4 additions & 5 deletions dvc/remote/local/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
makedirs,
)
from dvc.config import Config
from dvc.exceptions import DvcException
from dvc.exceptions import DvcException, DownloadError, UploadError
from dvc.progress import Tqdm, TqdmThreadPoolExecutor

from dvc.path_info import PathInfo
Expand Down Expand Up @@ -388,10 +388,9 @@ def _process(
fails = sum(map(func, *plans))

if fails:
msg = "{} files failed to {}"
raise DvcException(
msg.format(fails, "download" if download else "upload")
)
if download:
raise DownloadError(fails)
raise UploadError(fails)

return len(plans[0])

Expand Down
5 changes: 5 additions & 0 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def used_cache(
cache["hdfs"] = []
cache["ssh"] = []
cache["azure"] = []
cache["repo"] = []

for branch in self.brancher(
all_branches=all_branches, all_tags=all_tags
Expand All @@ -231,6 +232,10 @@ def used_cache(
stages = self.stages()

for stage in stages:
if stage.is_repo_import:
cache["repo"] += stage.deps
continue

for out in stage.outs:
scheme = out.path_info.scheme
used_cache = out.get_used_cache(
Expand Down
58 changes: 55 additions & 3 deletions dvc/repo/fetch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from __future__ import unicode_literals

import logging

from dvc.config import NoRemoteError
from dvc.exceptions import DownloadError, OutputNotFoundError
from dvc.scm.base import CloneError


logger = logging.getLogger(__name__)


def fetch(
self,
Expand All @@ -12,6 +21,18 @@ def fetch(
all_tags=False,
recursive=False,
):
"""Download data items from a cloud and imported repositories

Returns:
int: number of successfully downloaded files

Raises:
DownloadError: thrown when there are failed downloads, either
during `cloud.pull` or trying to fetch imported files

config.NoRemoteError: thrown when downloading only local files and no
remote is configured
"""
with self.state:
used = self.used_cache(
targets,
Expand All @@ -22,7 +43,38 @@ def fetch(
remote=remote,
jobs=jobs,
recursive=recursive,
)["local"]
return self.cloud.pull(
used, jobs, remote=remote, show_checksums=show_checksums
)

downloaded = 0
failed = 0

try:
downloaded += self.cloud.pull(
used["local"],
jobs,
remote=remote,
show_checksums=show_checksums,
)
except NoRemoteError:
if not used["repo"] and used["local"]:
raise

except DownloadError as exc:
failed += exc.amount

for dep in used["repo"]:
try:
out = dep.fetch()
downloaded += out.get_files_number()
except DownloadError as exc:
failed += exc.amount
except (CloneError, OutputNotFoundError):
failed += 1
logger.exception(
"failed to fetch data for '{}'".format(dep.stage.outs[0])
)

if failed:
raise DownloadError(failed)

return downloaded
40 changes: 40 additions & 0 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import filecmp
import shutil

import pytest
from tests.utils import trees_equal

from dvc.stage import Stage
from dvc.exceptions import DownloadError


def test_import(repo_dir, git, dvc_repo, erepo):
src = erepo.FOO
Expand Down Expand Up @@ -39,3 +44,38 @@ def test_import_rev(repo_dir, git, dvc_repo, erepo):
with open(dst, "r+") as fobj:
assert fobj.read() == "branch"
assert git.git.check_ignore(dst)


def test_pull_imported_stage(dvc_repo, erepo):
    """Pulling an import stage restores the imported file and its cache."""
    imported = erepo.FOO + "_imported"
    dvc_repo.imp(erepo.root_dir, erepo.FOO, imported)

    stage = Stage.load(dvc_repo, "foo_imported.dvc")
    cache_path = stage.outs[0].cache_path

    # Remove both the workspace copy and the cache entry so that pull has
    # to get the data from the external repo again.
    for path in (imported, cache_path):
        os.remove(path)

    dvc_repo.pull(["foo_imported.dvc"])

    for path in (imported, cache_path):
        assert os.path.isfile(path)


def test_download_error_pulling_imported_stage(dvc_repo, erepo):
    """Pulling an import stage raises DownloadError once its source is gone."""
    imported = erepo.FOO + "_imported"
    dvc_repo.imp(erepo.root_dir, erepo.FOO, imported)

    stage = Stage.load(dvc_repo, "foo_imported.dvc")
    cache_path = stage.outs[0].cache_path

    # Delete the external repo first, then the local copies, so there is
    # nowhere left to fetch the data from.
    shutil.rmtree(erepo.root_dir)
    for path in (imported, cache_path):
        os.remove(path)

    with pytest.raises(DownloadError):
        dvc_repo.pull(["foo_imported.dvc"])