diff --git a/dvc/remote/base.py b/dvc/remote/base.py
index 10969250a0..70e5266390 100644
--- a/dvc/remote/base.py
+++ b/dvc/remote/base.py
@@ -547,11 +547,59 @@ def download(
         if to_info.scheme != "local":
             raise NotImplementedError
 
-        logger.debug("Downloading '{}' to '{}'".format(from_info, to_info))
+        if self.isdir(from_info):
+            return self._download_dir(
+                from_info, to_info, name, no_progress_bar, file_mode, dir_mode
+            )
+        return self._download_file(
+            from_info, to_info, name, no_progress_bar, file_mode, dir_mode
+        )
 
-        name = name or to_info.name
+    def _download_dir(
+        self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode
+    ):
+        """Download every file under `from_info` into `to_info`.
+
+        Each file keeps its path relative to `from_info` and is fetched
+        through `_download_file` on a pool of `self.JOBS` threads.
+
+        Returns the sum of the per-file `_download_file` results.
+        """
+        # Walk the remote only once: deriving the targets from the same
+        # listing guarantees that sources and targets pair up, and avoids
+        # a second remote traversal.
+        from_infos = list(self.walk_files(from_info))
+        to_infos = (
+            to_info / info.relative_to(from_info) for info in from_infos
+        )
+        download_files = partial(
+            self._download_file,
+            name=name,
+            # Per-file bars are replaced by the directory-level bar below.
+            no_progress_bar=True,
+            file_mode=file_mode,
+            dir_mode=dir_mode,
+        )
+        with ThreadPoolExecutor(max_workers=self.JOBS) as executor:
+            futures = executor.map(download_files, from_infos, to_infos)
+            with Tqdm(
+                futures,
+                total=len(from_infos),
+                desc="Downloading directory",
+                unit="Files",
+                disable=no_progress_bar,
+            ) as progress:
+                return sum(progress)
+
+    def _download_file(
+        self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode
+    ):
 
         makedirs(to_info.parent, exist_ok=True, mode=dir_mode)
+
+        logger.debug("Downloading '{}' to '{}'".format(from_info, to_info))
+        name = name or to_info.name
+
         tmp_file = tmp_fname(to_info)
 
         try:
diff --git a/tests/unit/remote/test_remote_dir.py b/tests/unit/remote/test_remote_dir.py
index 583898a60d..6c868a1a6e 100644
--- a/tests/unit/remote/test_remote_dir.py
+++ b/tests/unit/remote/test_remote_dir.py
@@ -1,7 +1,11 @@
 # -*- coding: utf-8 -*-
+import os
+
 import pytest
 
+from dvc.path_info import PathInfo
 from dvc.remote.s3 import RemoteS3
+from dvc.utils import walk_files
 
 from tests.remotes import GCP, S3Mocked
 
@@ -132,3 +136,24 @@ def test_isfile(remote):
 
     for expected, path in test_cases:
         assert remote.isfile(remote.path_info / path) == expected
+
+
+@pytest.mark.parametrize("remote", remotes, indirect=True)
+def test_download_dir(remote, tmpdir):
+    path = str(tmpdir / "data")
+    to_info = PathInfo(path)
+    remote.download(remote.path_info / "data", to_info)
+    assert os.path.isdir(path)
+    data_dir = tmpdir / "data"
+    assert len(list(walk_files(path, None))) == 7
+    assert (data_dir / "alice").read_text(encoding="utf-8") == "alice"
+    assert (data_dir / "alpha").read_text(encoding="utf-8") == "alpha"
+    assert (data_dir / "subdir-file.txt").read_text(
+        encoding="utf-8"
+    ) == "subdir"
+    assert (data_dir / "subdir" / "1").read_text(encoding="utf-8") == "1"
+    assert (data_dir / "subdir" / "2").read_text(encoding="utf-8") == "2"
+    assert (data_dir / "subdir" / "3").read_text(encoding="utf-8") == "3"
+    assert (data_dir / "subdir" / "empty_file").read_text(
+        encoding="utf-8"
+    ) == ""