From d4aa27d6ab1892ae888471a17caef8208e52aa72 Mon Sep 17 00:00:00 2001 From: Aman Sharma Date: Fri, 8 Nov 2019 00:42:13 +0530 Subject: [PATCH] Ensure `copyfile` accepts str and Path-like objects --- dvc/remote/local.py | 5 +---- dvc/utils/__init__.py | 3 +++ tests/unit/utils/test_utils.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 286d150993..0cffce3d0e 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -236,10 +236,7 @@ def _download( self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs ): copyfile( - fspath_py35(from_info), - to_file, - no_progress_bar=no_progress_bar, - name=name, + from_info, to_file, no_progress_bar=no_progress_bar, name=name ) @staticmethod diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index c95e4dff67..848a27c505 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -121,6 +121,9 @@ def copyfile(src, dest, no_progress_bar=False, name=None): from dvc.progress import Tqdm from dvc.system import System + src = fspath_py35(src) + dest = fspath_py35(dest) + name = name if name else os.path.basename(dest) total = os.stat(src).st_size diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 8cf5ee9933..02c256afd9 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -1,11 +1,14 @@ +import filecmp import os import pytest from dvc.path_info import PathInfo +from dvc.utils import copyfile from dvc.utils import file_md5 from dvc.utils import fix_env from dvc.utils import to_chunks +from tests.basic_env import TestDir @pytest.mark.parametrize( @@ -79,3 +82,29 @@ def test_file_md5(repo_dir): fname = repo_dir.FOO fname_object = PathInfo(fname) assert file_md5(fname) == file_md5(fname_object) + + +@pytest.mark.parametrize("path", [TestDir.DATA, TestDir.DATA_DIR]) +def test_copyfile(path, repo_dir): + src = repo_dir.FOO + dest = path + src_info = PathInfo(repo_dir.BAR) + dest_info = PathInfo(path) + + copyfile(src, dest) + if os.path.isdir(dest): + assert filecmp.cmp( + src, os.path.join(dest, os.path.basename(src)), shallow=False + ) + else: + assert filecmp.cmp(src, dest, shallow=False) + + copyfile(src_info, dest_info) + if os.path.isdir(dest_info.fspath): + assert filecmp.cmp( + src_info.fspath, + os.path.join(dest_info.fspath, os.path.basename(src_info.fspath)), + shallow=False, + ) + else: + assert filecmp.cmp(src_info.fspath, dest_info.fspath, shallow=False)