diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 0963e1061d..595b836dc5 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -6,7 +6,9 @@ from funcy import merge from .local import DependencyLOCAL +from dvc.external_repo import cached_clone from dvc.external_repo import external_repo +from dvc.exceptions import NotDvcRepoError from dvc.exceptions import OutputNotFoundError from dvc.exceptions import PathMissingError from dvc.utils.fs import fs_copy @@ -75,27 +77,35 @@ def fetch(self): return out @staticmethod - def _is_git_file(repo, path): - if not os.path.isabs(path): - try: - output = repo.find_out_by_relpath(path) - if not output.use_cache: - return True - except OutputNotFoundError: - return True - return False + def _is_git_file(repo_dir, path): + from dvc.repo import Repo + + if os.path.isabs(path): + return False + + try: + repo = Repo(repo_dir) + except NotDvcRepoError: + return True + + try: + output = repo.find_out_by_relpath(path) + return not output.use_cache + except OutputNotFoundError: + return True + finally: + repo.close() def _copy_if_git_file(self, to_path): src_path = self.def_path - with self._make_repo( - cache_dir=self.repo.cache.local.cache_dir - ) as repo: - if not self._is_git_file(repo, src_path): - return False + repo_dir = cached_clone(**self.def_repo) + + if not self._is_git_file(repo_dir, src_path): + return False - src_full_path = os.path.join(repo.root_dir, src_path) - dst_full_path = os.path.abspath(to_path) - fs_copy(src_full_path, dst_full_path) + src_full_path = os.path.join(repo_dir, src_path) + dst_full_path = os.path.abspath(to_path) + fs_copy(src_full_path, dst_full_path) return True def download(self, to): diff --git a/dvc/external_repo.py b/dvc/external_repo.py index cf5ecacdd1..9ff2f2a413 100644 --- a/dvc/external_repo.py +++ b/dvc/external_repo.py @@ -33,18 +33,20 @@ def external_repo(url=None, rev=None, rev_lock=None, cache_dir=None): repo.close() -def _external_repo(url=None, rev=None, cache_dir=None): - from dvc.config import Config - from dvc.cache import CacheConfig - from dvc.repo import Repo +def cached_clone(url, rev=None, **_ignored_kwargs): + """Clone an external git repo to a temporary directory. - key = (url, rev, cache_dir) - if key in REPO_CACHE: - return REPO_CACHE[key] + Returns the path to a local temporary directory with the specified + revision checked out. + + Uses the REPO_CACHE to avoid accessing the remote server again if + cloning from the same URL twice in the same session. + + """ new_path = tempfile.mkdtemp("dvc-erepo") - # Copy and adjust existing clone + # Copy and adjust existing clean clone if (url, None, None) in REPO_CACHE: old_path = REPO_CACHE[url, None, None] @@ -59,13 +61,24 @@ def _external_repo(url=None, rev=None, cache_dir=None): copy_tree(new_path, clean_clone_path) REPO_CACHE[url, None, None] = clean_clone_path - # Adjust new clone/copy to fit rev and cache_dir - - # Checkout needs to be done first because current branch might not be - # DVC repository + # Check out the specified revision if rev is not None: _git_checkout(new_path, rev) + return new_path + + +def _external_repo(url=None, rev=None, cache_dir=None): + from dvc.config import Config + from dvc.cache import CacheConfig + from dvc.repo import Repo + + key = (url, rev, cache_dir) + if key in REPO_CACHE: + return REPO_CACHE[key] + + new_path = cached_clone(url, rev=rev) + repo = Repo(new_path) try: # check if the URL is local and no default remote is present diff --git a/tests/func/test_import.py b/tests/func/test_import.py index 99da2df896..19ef1edd1b 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -29,7 +29,12 @@ def test_import(tmp_dir, scm, dvc, erepo_dir, monkeypatch): assert scm.repo.git.check_ignore("foo_imported") -def test_import_git_file(erepo_dir, tmp_dir, dvc, scm): +@pytest.mark.parametrize("src_is_dvc", [True, False]) +def test_import_git_file(erepo_dir, tmp_dir, dvc, scm, src_is_dvc): + if not src_is_dvc: + erepo_dir.dvc.scm.repo.index.remove([".dvc"], r=True) + erepo_dir.dvc.scm.commit("remove .dvc") + src = "some_file" dst = "some_file_imported" @@ -44,7 +49,12 @@ def test_import_git_file(erepo_dir, tmp_dir, dvc, scm): assert tmp_dir.scm.repo.git.check_ignore(fspath(tmp_dir / dst)) -def test_import_git_dir(erepo_dir, tmp_dir, dvc, scm): +@pytest.mark.parametrize("src_is_dvc", [True, False]) +def test_import_git_dir(erepo_dir, tmp_dir, dvc, scm, src_is_dvc): + if not src_is_dvc: + erepo_dir.dvc.scm.repo.index.remove([".dvc"], r=True) + erepo_dir.dvc.scm.commit("remove .dvc") + src = "some_directory" dst = "some_directory_imported"