diff --git a/tests/func/test_add.py b/tests/func/test_add.py index 1c885ead76..cdd43f9cc3 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -9,7 +9,7 @@ import pytest from mock import patch -import dvc +import dvc as dvc_module from dvc.cache import Cache from dvc.exceptions import DvcException from dvc.exceptions import RecursiveAddingWhileUsingFilename @@ -28,7 +28,6 @@ from dvc.utils.stage import load_stage_file from tests.basic_env import TestDvc from tests.utils import get_gitignore_content -from tests.utils import spy def test_add(tmp_dir, dvc): @@ -245,66 +244,66 @@ def test_dir(self): self.assertEqual(ret, 0) -class TestShouldUpdateStateEntryForFileAfterAdd(TestDvc): - def test(self): - file_md5_counter = spy(dvc.remote.local.file_md5) - with patch.object(dvc.remote.local, "file_md5", file_md5_counter): - ret = main(["config", "cache.type", "copy"]) - self.assertEqual(ret, 0) +def test_should_update_state_entry_for_file_after_add(mocker, dvc, tmp_dir): + file_md5_counter = mocker.spy(dvc_module.remote.local, "file_md5") + tmp_dir.gen("foo", "foo") - ret = main(["add", self.FOO]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 1) + ret = main(["config", "cache.type", "copy"]) + assert ret == 0 - ret = main(["status"]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 1) + ret = main(["add", "foo"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 1 - ret = main(["run", "-d", self.FOO, "echo foo"]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 1) + ret = main(["status"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 1 - os.rename(self.FOO, self.FOO + ".back") - ret = main(["checkout"]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 1) + ret = main(["run", "-d", "foo", "echo foo"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 1 - ret = main(["status"]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 1) + os.rename("foo", "foo.back") + ret = main(["checkout"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 1 + ret = main(["status"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 1 -class TestShouldUpdateStateEntryForDirectoryAfterAdd(TestDvc): - def test(self): - file_md5_counter = spy(dvc.remote.local.file_md5) - with patch.object(dvc.remote.local, "file_md5", file_md5_counter): - ret = main(["config", "cache.type", "copy"]) - self.assertEqual(ret, 0) +def test_should_update_state_entry_for_directory_after_add( + mocker, dvc, tmp_dir +): + file_md5_counter = mocker.spy(dvc_module.remote.local, "file_md5") - ret = main(["add", self.DATA_DIR]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 3) + tmp_dir.gen({"data/data": "foo", "data/data_sub/sub_data": "foo"}) - ret = main(["status"]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 3) + ret = main(["config", "cache.type", "copy"]) + assert ret == 0 - ls = "dir" if os.name == "nt" else "ls" - ret = main( - ["run", "-d", self.DATA_DIR, "{} {}".format(ls, self.DATA_DIR)] - ) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 3) + ret = main(["add", "data"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 3 - os.rename(self.DATA_DIR, self.DATA_DIR + ".back") - ret = main(["checkout"]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 3) + ret = main(["status"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 3 - ret = main(["status"]) - self.assertEqual(ret, 0) - self.assertEqual(file_md5_counter.mock.call_count, 3) + ls = "dir" if os.name == "nt" else "ls" + ret = main(["run", "-d", "data", "{} {}".format(ls, "data")]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 3 + + os.rename("data", "data" + ".back") + ret = main(["checkout"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 3 + + ret = main(["status"]) + assert ret == 0 + assert file_md5_counter.mock.call_count == 3 class TestAddCommit(TestDvc): @@ -320,23 +319,18 @@ def test(self): self.assertEqual(len(os.listdir(self.dvc.cache.local.cache_dir)), 1) -class TestShouldCollectDirCacheOnlyOnce(TestDvc): - def test(self): - from dvc.remote.local import RemoteLOCAL +def test_should_collect_dir_cache_only_once(mocker, tmp_dir, dvc): + tmp_dir.gen({"data/data": "foo"}) + get_dir_checksum_counter = mocker.spy(RemoteLOCAL, "get_dir_checksum") + ret = main(["add", "data"]) + assert ret == 0 - get_dir_checksum_counter = spy(RemoteLOCAL.get_dir_checksum) - with patch.object( - RemoteLOCAL, "get_dir_checksum", get_dir_checksum_counter - ): - ret = main(["add", self.DATA_DIR]) - self.assertEqual(0, ret) + ret = main(["status"]) + assert ret == 0 - ret = main(["status"]) - self.assertEqual(0, ret) - - ret = main(["status"]) - self.assertEqual(0, ret) - self.assertEqual(1, get_dir_checksum_counter.mock.call_count) + ret = main(["status"]) + assert ret == 0 + assert get_dir_checksum_counter.mock.call_count == 1 class SymlinkAddTestBase(TestDvc): @@ -477,17 +471,15 @@ def test_should_cleanup_after_failed_add(tmp_dir, scm, dvc, repo_template): assert "/bar" not in gitignore_content -class TestShouldNotTrackGitInternalFiles(TestDvc): - def test(self): - stage_creator_spy = spy(dvc.repo.add._create_stages) +def test_should_not_track_git_internal_files(mocker, dvc, tmp_dir): + stage_creator_spy = mocker.spy(dvc_module.repo.add, "_create_stages") - with patch.object(dvc.repo.add, "_create_stages", stage_creator_spy): - ret = main(["add", "-R", self.dvc.root_dir]) - self.assertEqual(0, ret) + ret = main(["add", "-R", dvc.root_dir]) + assert ret == 0 - created_stages_filenames = stage_creator_spy.mock.call_args[0][0] - for fname in created_stages_filenames: - self.assertNotIn(".git", fname) + created_stages_filenames = stage_creator_spy.mock.call_args[0][1] + for fname in created_stages_filenames: + assert ".git" not in fname class TestAddUnprotected(TestDvc): @@ -572,8 +564,7 @@ def test_readding_dir_should_not_unprotect_all(tmp_dir, dvc, mocker): dvc.add("dir") tmp_dir.gen("dir/new_file", "new_file_content") - unprotect_spy = spy(RemoteLOCAL.unprotect) - mocker.patch.object(RemoteLOCAL, "unprotect", unprotect_spy) + unprotect_spy = mocker.spy(RemoteLOCAL, "unprotect") dvc.add("dir") assert not unprotect_spy.mock.called @@ -587,8 +578,7 @@ def test_should_not_checkout_when_adding_cached_copy(tmp_dir, dvc, mocker): shutil.copy("bar", "foo") - copy_spy = spy(dvc.cache.local.copy) - mocker.patch.object(dvc.cache.local, "copy", copy_spy) + copy_spy = mocker.spy(dvc.cache.local, "copy") dvc.add("foo") diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index 748a7fe631..56189d8213 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -6,7 +6,6 @@ from unittest import SkipTest import pytest -from mock import patch from dvc.cache import NamedCache from dvc.config import Config @@ -28,7 +27,6 @@ from dvc.utils.stage import dump_stage_file from dvc.utils.stage import load_stage_file from tests.basic_env import TestDvc -from tests.utils import spy from tests.remotes import ( Azure, @@ -585,24 +583,21 @@ def test(self): self._test_recursive_pull() -class TestCheckSumRecalculation(TestDvc): - def test(self): - test_get_file_checksum = spy(RemoteLOCAL.get_file_checksum) - with patch.object( - RemoteLOCAL, "get_file_checksum", test_get_file_checksum - ): - url = Local.get_url() - ret = main(["remote", "add", "-d", TEST_REMOTE, url]) - self.assertEqual(ret, 0) - ret = main(["config", "cache.type", "hardlink"]) - self.assertEqual(ret, 0) - ret = main(["add", self.FOO]) - self.assertEqual(ret, 0) - ret = main(["push"]) - self.assertEqual(ret, 0) - ret = main(["run", "-d", self.FOO, "echo foo"]) - self.assertEqual(ret, 0) - self.assertEqual(test_get_file_checksum.mock.call_count, 1) +def test_checksum_recalculation(mocker, dvc, tmp_dir): + tmp_dir.gen({"foo": "foo"}) + test_get_file_checksum = mocker.spy(RemoteLOCAL, "get_file_checksum") + url = Local.get_url() + ret = main(["remote", "add", "-d", TEST_REMOTE, url]) + assert ret == 0 + ret = main(["config", "cache.type", "hardlink"]) + assert ret == 0 + ret = main(["add", "foo"]) + assert ret == 0 + ret = main(["push"]) + assert ret == 0 + ret = main(["run", "-d", "foo", "echo foo"]) + assert ret == 0 + assert test_get_file_checksum.mock.call_count == 1 class TestShouldWarnOnNoChecksumInLocalAndRemoteCache(TestDvc): diff --git a/tests/func/test_import_url.py b/tests/func/test_import_url.py index ae59e3aa57..bf042e32b8 100644 --- a/tests/func/test_import_url.py +++ b/tests/func/test_import_url.py @@ -2,13 +2,12 @@ from uuid import uuid4 import pytest -from mock import patch -import dvc +from dvc.stage import Stage from dvc.main import main from dvc.utils.fs import makedirs +from dvc.compat import fspath from tests.basic_env import TestDvc -from tests.utils import spy class TestCmdImport(TestDvc): @@ -41,23 +40,14 @@ def test(self): self.assertEqual(fd.read(), "content") -class TestShouldRemoveOutsBeforeImport(TestDvc): - def setUp(self): - super().setUp() - tmp_dir = self.mkdtemp() - self.external_source = os.path.join(tmp_dir, "file") - with open(self.external_source, "w") as fobj: - fobj.write("content") +def test_should_remove_outs_before_import(mocker, erepo_dir): + erepo_dir.gen({"foo": "foo"}) - def test(self): - remove_outs_call_counter = spy(dvc.stage.Stage.remove_outs) - with patch.object( - dvc.stage.Stage, "remove_outs", remove_outs_call_counter - ): - ret = main(["import-url", self.external_source]) - self.assertEqual(0, ret) - - self.assertEqual(1, remove_outs_call_counter.mock.call_count) + remove_outs_call_counter = mocker.spy(Stage, "remove_outs") + ret = main(["import-url", fspath(erepo_dir / "foo")]) + + assert ret == 0 + assert remove_outs_call_counter.mock.call_count == 1 class TestImportFilename(TestDvc): diff --git a/tests/unit/utils/test_fs.py b/tests/unit/utils/test_fs.py index 0c0e19a0be..07c62e0a2b 100644 --- a/tests/unit/utils/test_fs.py +++ b/tests/unit/utils/test_fs.py @@ -22,7 +22,6 @@ from dvc.utils.fs import makedirs from dvc.utils.fs import walk_files from tests.basic_env import TestDir -from tests.utils import spy class TestMtimeAndSize(TestDir): @@ -70,21 +69,6 @@ def test_should_return_false_on_no_more_dirs_below_path( ) dirname_patch.assert_called_once() - @patch.object(System, "is_symlink", return_value=False) - def test_should_call_recursive_on_no_condition_matched(self, _): - contains_symlink_spy = spy(contains_symlink_up_to) - with patch.object( - dvc.utils.fs, "contains_symlink_up_to", contains_symlink_spy - ): - - # call from full path to match contains_symlink_spy patch path - self.assertFalse( - dvc.utils.fs.contains_symlink_up_to( - os.path.join("foo", "path"), "foo" - ) - ) - self.assertEqual(2, contains_symlink_spy.mock.call_count) - @patch.object(System, "is_symlink", return_value=True) def test_should_return_false_when_base_path_is_symlink(self, _): base_path = "foo" @@ -109,6 +93,18 @@ def test_path_object_and_str_are_valid_arg_types(self): ) +def test_should_call_recursive_on_no_condition_matched(mocker): + mocker.patch.object(System, "is_symlink", return_value=False) + + contains_symlink_spy = mocker.spy(dvc.utils.fs, "contains_symlink_up_to") + + # call from full path to match contains_symlink_spy patch path + assert not dvc.utils.fs.contains_symlink_up_to( + os.path.join("foo", "path"), "foo" + ) + assert contains_symlink_spy.mock.call_count == 2 + + @pytest.mark.skipif(os.name != "nt", reason="Windows specific") def test_relpath_windows_different_drives(): path1 = os.path.join("A:", os.sep, "some", "path") diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index a40a91f2cf..9e5448abb3 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -2,22 +2,9 @@ from contextlib import contextmanager from filecmp import dircmp -from mock import MagicMock - from dvc.scm import Git -def spy(method_to_decorate): - mock = MagicMock() - - def wrapper(self, *args, **kwargs): - mock(*args, **kwargs) - return method_to_decorate(self, *args, **kwargs) - - wrapper.mock = mock - return wrapper - - def get_gitignore_content(): with open(Git.GITIGNORE, "r") as gitignore: return gitignore.read().splitlines()