diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 10969250a0..10f21d1ea0 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -31,6 +31,7 @@ from dvc.utils.fs import move from dvc.utils.http import open_url + logger = logging.getLogger(__name__) STATUS_OK = 1 @@ -547,11 +548,45 @@ def download( if to_info.scheme != "local": raise NotImplementedError - logger.debug("Downloading '{}' to '{}'".format(from_info, to_info)) + if self.isdir(from_info): + file_to_infos = ( + to_info / file_to_info.relative_to(from_info) + for file_to_info in self.walk_files(from_info) + ) - name = name or to_info.name + with ThreadPoolExecutor(max_workers=self.JOBS) as executor: + file_from_info = list(self.walk_files(from_info)) + with Tqdm( + file_from_info, + total=len(file_from_info), + desc="Downloading directory", + ) as file_from_info: + return sum( + executor.map( + partial( + self.single_file_download, + name=name, + no_progress_bar=True, + file_mode=file_mode, + dir_mode=dir_mode, + ), + file_from_info, + file_to_infos, + ) + ) + else: + self.single_file_download( + from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ) + def single_file_download( + self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ): makedirs(to_info.parent, exist_ok=True, mode=dir_mode) + + logger.debug("Downloading '{}' to '{}'".format(from_info, to_info)) + name = name or to_info.name + tmp_file = tmp_fname(to_info) try: @@ -559,7 +594,7 @@ def download( from_info, tmp_file, name=name, no_progress_bar=no_progress_bar ) except Exception: - msg = "failed to download '{}' to '{}'" + msg = "failed to doooooownload '{}' to '{}'" logger.exception(msg.format(from_info, to_info)) return 1 # 1 fail diff --git a/tests/unit/remote/test_remote_dir.py b/tests/unit/remote/test_remote_dir.py index 583898a60d..9e740a9684 100644 --- a/tests/unit/remote/test_remote_dir.py +++ b/tests/unit/remote/test_remote_dir.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import pytest +import os from dvc.remote.s3 import RemoteS3 @@ -132,3 +133,12 @@ def test_isfile(remote): for expected, path in test_cases: assert remote.isfile(remote.path_info / path) == expected + + +@pytest.mark.parametrize("remote", remotes, indirect=True) +def test_download_dir(remote, tmpdir): + path = os.fspath(tmpdir / "data") + to_info = os.PathInfo(path) + remote.download(remote.path_info / "data", to_info) + assert os.path.isdir(path) + # check the list of files