diff --git a/dvc/output/base.py b/dvc/output/base.py index f940d3f52e..8fa3ad367c 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -334,7 +334,7 @@ def unprotect(self): self.remote.unprotect(self.path_info) def _collect_used_dir_cache(self, remote=None, force=False, jobs=None): - """Get a list of `info`s retaled to the given directory. + """Get a list of `info`s related to the given directory. - Pull the directory entry from the remote cache if it was changed. diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py index 7029461630..b69c8c04b2 100644 --- a/dvc/repo/fetch.py +++ b/dvc/repo/fetch.py @@ -76,20 +76,20 @@ def _fetch_external(self, repo_url, repo_rev, files): cache_dir = self.cache.local.cache_dir try: with external_repo(repo_url, repo_rev, cache_dir=cache_dir) as repo: - cache = NamedCache() - for name in files: - try: - out = repo.find_out_by_relpath(name) - except OutputNotFoundError: - failed += 1 - logger.exception( - "failed to fetch data for '{}'".format(name) - ) - continue - else: - cache.update(out.get_used_cache()) - with repo.state: + cache = NamedCache() + for name in files: + try: + out = repo.find_out_by_relpath(name) + except OutputNotFoundError: + failed += 1 + logger.exception( + "failed to fetch data for '{}'".format(name) + ) + continue + else: + cache.update(out.get_used_cache()) + try: return repo.cloud.pull(cache), failed except DownloadError as exc: diff --git a/tests/func/test_import.py b/tests/func/test_import.py index fc00ee7b1e..4fe8857aad 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -2,6 +2,7 @@ import os import filecmp +import shutil import pytest from mock import patch @@ -67,6 +68,23 @@ def test_pull_imported_stage(dvc_repo, erepo): assert os.path.isfile(dst_cache) +def test_pull_imported_directory_stage(dvc_repo, erepo): + src = erepo.DATA_DIR + dst = erepo.DATA_DIR + "_imported" + stage_file = dst + ".dvc" + + dvc_repo.imp(erepo.root_dir, src, dst) + + shutil.rmtree(dst) + shutil.rmtree(dvc_repo.cache.local.cache_dir) + + dvc_repo.pull([stage_file]) + + assert os.path.exists(dst) + assert os.path.isdir(dst) + trees_equal(src, dst) + + def test_download_error_pulling_imported_stage(dvc_repo, erepo): src = erepo.FOO dst = erepo.FOO + "_imported"