diff --git a/dvc/remote/base.py b/dvc/remote/base.py
index c12f5df91e..89de458e10 100644
--- a/dvc/remote/base.py
+++ b/dvc/remote/base.py
@@ -468,7 +468,7 @@ def _cache_is_copy(self, path_info):
         self.cache_type_confirmed = True
         return self.cache_types[0] == "copy"
 
-    def _save_dir(self, path_info, checksum):
+    def _save_dir(self, path_info, checksum, save_link=True):
         cache_info = self.checksum_to_path_info(checksum)
         dir_info = self.get_dir_cache(checksum)
 
@@ -479,7 +479,9 @@ def _save_dir(self, path_info, checksum):
             entry_checksum = entry[self.PARAM_CHECKSUM]
             self._save_file(entry_info, entry_checksum, save_link=False)
 
-        self.state.save_link(path_info)
+        if save_link:
+            self.state.save_link(path_info)
+
         self.state.save(cache_info, checksum)
         self.state.save(path_info, checksum)
 
@@ -510,7 +512,7 @@ def walk_files(self, path_info):
     def protect(path_info):
         pass
 
-    def save(self, path_info, checksum_info):
+    def save(self, path_info, checksum_info, save_link=True):
         if path_info.scheme != self.scheme:
             raise RemoteActionNotImplemented(
                 "save {} -> {}".format(path_info.scheme, self.scheme),
@@ -518,15 +520,15 @@ def save(self, path_info, checksum_info):
             )
 
         checksum = checksum_info[self.PARAM_CHECKSUM]
-        self._save(path_info, checksum)
+        self._save(path_info, checksum, save_link)
 
-    def _save(self, path_info, checksum):
+    def _save(self, path_info, checksum, save_link=True):
         to_info = self.checksum_to_path_info(checksum)
         logger.debug("Saving '%s' to '%s'.", path_info, to_info)
         if self.isdir(path_info):
-            self._save_dir(path_info, checksum)
+            self._save_dir(path_info, checksum, save_link)
             return
-        self._save_file(path_info, checksum)
+        self._save_file(path_info, checksum, save_link)
 
     def _handle_transfer_exception(
         self, from_info, to_info, exception, operation
diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py
index ab6e178d8c..fbbaab1be8 100644
--- a/dvc/repo/fetch.py
+++ b/dvc/repo/fetch.py
@@ -2,9 +2,9 @@
 
 from dvc.cache import NamedCache
 from dvc.config import NoRemoteError
-from dvc.exceptions import DownloadError
-from dvc.exceptions import OutputNotFoundError
+from dvc.exceptions import DownloadError, OutputNotFoundError
 from dvc.scm.base import CloneError
+from dvc.path_info import PathInfo
 
 
 logger = logging.getLogger(__name__)
@@ -69,35 +69,69 @@ def _fetch(
 
 
 def _fetch_external(self, repo_url, repo_rev, files, jobs):
-    from dvc.external_repo import external_repo
+    from dvc.external_repo import external_repo, ExternalRepo
 
-    failed = 0
+    failed, downloaded = 0, 0
     try:
         with external_repo(repo_url, repo_rev) as repo:
-            repo.cache.local.cache_dir = self.cache.local.cache_dir
-
-            with repo.state:
-                cache = NamedCache()
-                for name in files:
-                    try:
-                        out = repo.find_out_by_relpath(name)
-                    except OutputNotFoundError:
-                        failed += 1
-                        logger.exception(
-                            "failed to fetch data for '{}'".format(name)
-                        )
-                        continue
-                    else:
-                        cache.update(out.get_used_cache())
-
-                try:
-                    return repo.cloud.pull(cache, jobs=jobs), failed
-                except DownloadError as exc:
-                    failed += exc.amount
+            is_dvc_repo = isinstance(repo, ExternalRepo)
+            # gather git-only tracked files if dvc repo
+            git_files = [] if is_dvc_repo else files
+            if is_dvc_repo:
+                repo.cache.local.cache_dir = self.cache.local.cache_dir
+                with repo.state:
+                    cache = NamedCache()
+                    for name in files:
+                        try:
+                            out = repo.find_out_by_relpath(name)
+                        except OutputNotFoundError:
+                            # try to add to cache if they are git-tracked files
+                            git_files.append(name)
+                        else:
+                            cache.update(out.get_used_cache())
+
+                    try:
+                        downloaded += repo.cloud.pull(cache, jobs=jobs)
+                    except DownloadError as exc:
+                        failed += exc.amount
+
+            d, f = _git_to_cache(self.cache.local, repo.root_dir, git_files)
+            downloaded += d
+            failed += f
     except CloneError:
         failed += 1
         logger.exception(
             "failed to fetch data for '{}'".format(", ".join(files))
         )
 
-    return 0, failed
+    return downloaded, failed
+
+
+def _git_to_cache(cache, repo_root, files):
+    """Save files from a git repo directly to the cache.
+
+    Returns a ``(num_downloads, num_failed)`` tuple: files newly saved to
+    the cache, and files for which no checksum could be computed.
+    """
+    failed = set()
+    num_downloads = 0
+    repo_root = PathInfo(repo_root)
+    for file in files:
+        info = cache.save_info(repo_root / file)
+        if info.get(cache.PARAM_CHECKSUM) is None:
+            failed.add(file)
+            continue
+
+        if cache.changed_cache(info[cache.PARAM_CHECKSUM]):
+            logger.debug("fetched '%s' from '%s' repo", file, repo_root)
+            num_downloads += 1
+            cache.save(repo_root / file, info, save_link=False)
+
+    if failed:
+        # No exception is being handled here, so logger.exception() would
+        # append a bogus "NoneType: None" traceback; log a plain error.
+        logger.error(
+            "failed to fetch data for {}".format(", ".join(sorted(failed)))
+        )
+
+    return num_downloads, len(failed)
diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py
index aaf1a12c5c..8f4c6574b3 100644
--- a/tests/func/test_data_cloud.py
+++ b/tests/func/test_data_cloud.py
@@ -22,7 +22,9 @@
 from dvc.remote import RemoteSSH
 from dvc.remote.base import STATUS_DELETED, STATUS_NEW, STATUS_OK
 from dvc.utils import file_md5
+from dvc.utils.fs import remove
 from dvc.utils.stage import dump_stage_file, load_stage_file
+from dvc.external_repo import clean_repos
 
 from tests.basic_env import TestDvc
 from tests.remotes import (
@@ -709,3 +711,56 @@ def test_verify_checksums(tmp_dir, scm, dvc, mocker, tmp_path_factory):
     dvc.pull()
 
     assert checksum_spy.call_count == 3
+
+
+@pytest.mark.parametrize("erepo", ["git_dir", "erepo_dir"])
+def test_pull_git_imports(request, tmp_dir, dvc, scm, erepo):
+    erepo = request.getfixturevalue(erepo)
+    with erepo.chdir():
+        erepo.scm_gen({"dir": {"bar": "bar"}}, commit="second")
+        erepo.scm_gen("foo", "foo", commit="first")
+
+    dvc.imp(fspath(erepo), "foo")
+    dvc.imp(fspath(erepo), "dir", out="new_dir", rev="HEAD~")
+
+    assert dvc.pull()["downloaded"] == 0
+
+    for item in ["foo", "new_dir", dvc.cache.local.cache_dir]:
+        remove(item)
+    os.makedirs(dvc.cache.local.cache_dir, exist_ok=True)
+    clean_repos()
+
+    assert dvc.pull(force=True)["downloaded"] == 2
+
+    assert (tmp_dir / "foo").exists()
+    assert (tmp_dir / "foo").read_text() == "foo"
+
+    assert (tmp_dir / "new_dir").exists()
+    assert (tmp_dir / "new_dir" / "bar").read_text() == "bar"
+
+
+def test_pull_external_dvc_imports(tmp_dir, dvc, scm, erepo_dir):
+    with erepo_dir.chdir():
+        erepo_dir.dvc_gen({"dir": {"bar": "bar"}}, commit="second")
+        erepo_dir.dvc_gen("foo", "foo", commit="first")
+
+        os.remove("foo")
+        shutil.rmtree("dir")
+
+    dvc.imp(fspath(erepo_dir), "foo")
+    dvc.imp(fspath(erepo_dir), "dir", out="new_dir", rev="HEAD~")
+
+    assert dvc.pull()["downloaded"] == 0
+
+    for item in ["foo", "new_dir", dvc.cache.local.cache_dir]:
+        remove(item)
+    os.makedirs(dvc.cache.local.cache_dir, exist_ok=True)
+    clean_repos()
+
+    assert dvc.pull(force=True)["downloaded"] == 2
+
+    assert (tmp_dir / "foo").exists()
+    assert (tmp_dir / "foo").read_text() == "foo"
+
+    assert (tmp_dir / "new_dir").exists()
+    assert (tmp_dir / "new_dir" / "bar").read_text() == "bar"