Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _cache_is_copy(self, path_info):
self.cache_type_confirmed = True
return self.cache_types[0] == "copy"

def _save_dir(self, path_info, checksum):
def _save_dir(self, path_info, checksum, save_link=True):
cache_info = self.checksum_to_path_info(checksum)
dir_info = self.get_dir_cache(checksum)

Expand All @@ -479,7 +479,9 @@ def _save_dir(self, path_info, checksum):
entry_checksum = entry[self.PARAM_CHECKSUM]
self._save_file(entry_info, entry_checksum, save_link=False)

self.state.save_link(path_info)
if save_link:
self.state.save_link(path_info)

self.state.save(cache_info, checksum)
self.state.save(path_info, checksum)

Expand Down Expand Up @@ -510,23 +512,23 @@ def walk_files(self, path_info):
def protect(path_info):
pass

def save(self, path_info, checksum_info):
def save(self, path_info, checksum_info, save_link=True):
if path_info.scheme != self.scheme:
raise RemoteActionNotImplemented(
"save {} -> {}".format(path_info.scheme, self.scheme),
self.scheme,
)

checksum = checksum_info[self.PARAM_CHECKSUM]
self._save(path_info, checksum)
self._save(path_info, checksum, save_link)

def _save(self, path_info, checksum):
def _save(self, path_info, checksum, save_link=True):
to_info = self.checksum_to_path_info(checksum)
logger.debug("Saving '%s' to '%s'.", path_info, to_info)
if self.isdir(path_info):
self._save_dir(path_info, checksum)
self._save_dir(path_info, checksum, save_link)
return
self._save_file(path_info, checksum)
self._save_file(path_info, checksum, save_link)

def _handle_transfer_exception(
self, from_info, to_info, exception, operation
Expand Down
78 changes: 53 additions & 25 deletions dvc/repo/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from dvc.cache import NamedCache
from dvc.config import NoRemoteError
from dvc.exceptions import DownloadError
from dvc.exceptions import OutputNotFoundError
from dvc.exceptions import DownloadError, OutputNotFoundError
from dvc.scm.base import CloneError
from dvc.path_info import PathInfo


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,35 +69,63 @@ def _fetch(


def _fetch_external(self, repo_url, repo_rev, files, jobs):
from dvc.external_repo import external_repo
from dvc.external_repo import external_repo, ExternalRepo

failed = 0
failed, downloaded = 0, 0
try:
with external_repo(repo_url, repo_rev) as repo:
repo.cache.local.cache_dir = self.cache.local.cache_dir

with repo.state:
cache = NamedCache()
for name in files:
try:
out = repo.find_out_by_relpath(name)
except OutputNotFoundError:
failed += 1
logger.exception(
"failed to fetch data for '{}'".format(name)
)
continue
else:
cache.update(out.get_used_cache())

try:
return repo.cloud.pull(cache, jobs=jobs), failed
except DownloadError as exc:
failed += exc.amount
is_dvc_repo = isinstance(repo, ExternalRepo)
# gather git-only tracked files if dvc repo
git_files = [] if is_dvc_repo else files
if is_dvc_repo:
repo.cache.local.cache_dir = self.cache.local.cache_dir
with repo.state:
cache = NamedCache()
for name in files:
try:
out = repo.find_out_by_relpath(name)
except OutputNotFoundError:
# try to add to cache if they are git-tracked files
git_files.append(name)
else:
cache.update(out.get_used_cache())

try:
downloaded += repo.cloud.pull(cache, jobs=jobs)
except DownloadError as exc:
failed += exc.amount

d, f = _git_to_cache(self.cache.local, repo.root_dir, git_files)
downloaded += d
failed += f
except CloneError:
failed += 1
logger.exception(
"failed to fetch data for '{}'".format(", ".join(files))
)

return 0, failed
return downloaded, failed


def _git_to_cache(cache, repo_root, files):
    """Save files from a git repo directly to the cache.

    Args:
        cache: cache object; this function uses its ``save_info``,
            ``changed_cache`` and ``save`` methods and its
            ``PARAM_CHECKSUM`` key name.
        repo_root: root directory of the checked-out git repo. It is wrapped
            in ``PathInfo`` so relative names can be joined onto it.
        files: iterable of paths, relative to ``repo_root``, to be cached.

    Returns:
        A ``(num_downloads, num_failed)`` tuple. ``num_downloads`` counts the
        files for which ``cache.changed_cache`` reported a change and which
        were therefore saved; ``num_failed`` counts the distinct names for
        which ``cache.save_info`` produced no checksum.
    """
    failed = set()
    num_downloads = 0
    repo_root = PathInfo(repo_root)
    for name in files:
        path_info = repo_root / name
        info = cache.save_info(path_info)
        checksum = info.get(cache.PARAM_CHECKSUM)
        if checksum is None:
            # No checksum could be computed for this path; report it below.
            failed.add(name)
            continue

        if cache.changed_cache(checksum):
            logger.debug("fetched '%s' from '%s' repo", name, repo_root)
            num_downloads += 1
            # save_link=False: the source lives in a foreign repo checkout,
            # so no link to it is recorded in the state.
            cache.save(path_info, info, save_link=False)

    if failed:
        # No exception is being handled at this point, so logger.exception()
        # would attach a meaningless "NoneType: None" traceback; log a plain
        # error instead. Sorting keeps the message deterministic.
        logger.error(
            "failed to fetch data for {}".format(", ".join(sorted(failed)))
        )

    return num_downloads, len(failed)
55 changes: 55 additions & 0 deletions tests/func/test_data_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from dvc.remote import RemoteSSH
from dvc.remote.base import STATUS_DELETED, STATUS_NEW, STATUS_OK
from dvc.utils import file_md5
from dvc.utils.fs import remove
from dvc.utils.stage import dump_stage_file, load_stage_file
from dvc.external_repo import clean_repos
from tests.basic_env import TestDvc

from tests.remotes import (
Expand Down Expand Up @@ -709,3 +711,56 @@ def test_verify_checksums(tmp_dir, scm, dvc, mocker, tmp_path_factory):

dvc.pull()
assert checksum_spy.call_count == 3


@pytest.mark.parametrize("erepo", ["git_dir", "erepo_dir"])
def test_pull_git_imports(request, tmp_dir, dvc, scm, erepo):
    """Imports of git-tracked files can be pulled again after local loss.

    Runs once per source-repo fixture named in the parametrization.
    """
    source_repo = request.getfixturevalue(erepo)
    with source_repo.chdir():
        # Note the commit messages: "dir" is committed before "foo", so
        # "dir" is what HEAD~ refers to further down.
        source_repo.scm_gen({"dir": {"bar": "bar"}}, commit="second")
        source_repo.scm_gen("foo", "foo", commit="first")

    source_url = fspath(source_repo)
    dvc.imp(source_url, "foo")
    dvc.imp(source_url, "dir", out="new_dir", rev="HEAD~")

    # Everything was just imported, so a pull has nothing to download.
    stats = dvc.pull()
    assert stats["downloaded"] == 0

    # Remove the imported outputs and the local cache, then recreate an
    # empty cache directory and drop the cached external repo clones.
    cache_dir = dvc.cache.local.cache_dir
    for stale_path in ("foo", "new_dir", cache_dir):
        remove(stale_path)
    os.makedirs(cache_dir, exist_ok=True)
    clean_repos()

    stats = dvc.pull(force=True)
    assert stats["downloaded"] == 2

    imported_file = tmp_dir / "foo"
    assert imported_file.exists()
    assert imported_file.read_text() == "foo"

    imported_dir = tmp_dir / "new_dir"
    assert imported_dir.exists()
    assert (imported_dir / "bar").read_text() == "bar"


def test_pull_external_dvc_imports(tmp_dir, dvc, scm, erepo_dir):
    """Imports of DVC-tracked outputs can be pulled again after local loss."""
    with erepo_dir.chdir():
        # Note the commit messages: "dir" is committed before "foo", so
        # "dir" is what HEAD~ refers to further down.
        erepo_dir.dvc_gen({"dir": {"bar": "bar"}}, commit="second")
        erepo_dir.dvc_gen("foo", "foo", commit="first")

        # Delete the workspace copies inside the source repo.
        os.remove("foo")
        shutil.rmtree("dir")

    source_url = fspath(erepo_dir)
    dvc.imp(source_url, "foo")
    dvc.imp(source_url, "dir", out="new_dir", rev="HEAD~")

    # Everything was just imported, so a pull has nothing to download.
    stats = dvc.pull()
    assert stats["downloaded"] == 0

    # Remove the imported outputs and the local cache, then recreate an
    # empty cache directory and drop the cached external repo clones.
    cache_dir = dvc.cache.local.cache_dir
    for stale_path in ("foo", "new_dir", cache_dir):
        remove(stale_path)
    os.makedirs(cache_dir, exist_ok=True)
    clean_repos()

    stats = dvc.pull(force=True)
    assert stats["downloaded"] == 2

    imported_file = tmp_dir / "foo"
    assert imported_file.exists()
    assert imported_file.read_text() == "foo"

    imported_dir = tmp_dir / "new_dir"
    assert imported_dir.exists()
    assert (imported_dir / "bar").read_text() == "bar"