From 15198ba3110fc933707069b96cea9d0ea24296cc Mon Sep 17 00:00:00 2001 From: Anirudh Dagar Date: Fri, 25 Jun 2021 06:37:35 +0530 Subject: [PATCH 1/4] Port test_datasets_utils to pytest --- test/test_datasets_utils.py | 209 ++++++++++++------------------------ 1 file changed, 70 insertions(+), 139 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index d89eced1241..fd7b25c7431 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -1,8 +1,7 @@ import bz2 import os import torchvision.datasets.utils as utils -import unittest -import unittest.mock +import pytest import zipfile import tarfile import gzip @@ -13,30 +12,31 @@ import lzma from common_utils import get_tmp_dir, call_args_to_kwargs_only +from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS TEST_FILE = get_file_path_2( os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') -class Tester(unittest.TestCase): +class TestDatasetsUtils: def test_check_md5(self): fpath = TEST_FILE correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' false_md5 = '' - self.assertTrue(utils.check_md5(fpath, correct_md5)) - self.assertFalse(utils.check_md5(fpath, false_md5)) + assert utils.check_md5(fpath, correct_md5) + assert not utils.check_md5(fpath, false_md5) def test_check_integrity(self): existing_fpath = TEST_FILE nonexisting_fpath = '' correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' false_md5 = '' - self.assertTrue(utils.check_integrity(existing_fpath, correct_md5)) - self.assertFalse(utils.check_integrity(existing_fpath, false_md5)) - self.assertTrue(utils.check_integrity(existing_fpath)) - self.assertFalse(utils.check_integrity(nonexisting_fpath)) + assert utils.check_integrity(existing_fpath, correct_md5) + assert not utils.check_integrity(existing_fpath, false_md5) + assert utils.check_integrity(existing_fpath) + assert not utils.check_integrity(nonexisting_fpath) def test_get_google_drive_file_id(self): url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view" @@ -50,44 +50,38 @@ def test_get_google_drive_file_id_invalid_url(self): assert utils._get_google_drive_file_id(url) is None - def test_detect_file_type(self): - for file, expected in [ - ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), - ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), - ("foo.tar", (".tar", ".tar", None)), - ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), - ("foo.tbz", (".tbz", ".tar", ".bz2")), - ("foo.tbz2", (".tbz2", ".tar", ".bz2")), - ("foo.tgz", (".tgz", ".tar", ".gz")), - ("foo.bz2", (".bz2", None, ".bz2")), - ("foo.gz", (".gz", None, ".gz")), - ("foo.zip", (".zip", ".zip", None)), - ("foo.xz", (".xz", None, ".xz")), - ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), - ("foo.bar.gz", (".gz", None, ".gz")), - ("foo.bar.zip", (".zip", ".zip", None)), - ]: - with self.subTest(file=file): - self.assertSequenceEqual(utils._detect_file_type(file), expected) - - def test_detect_file_type_no_ext(self): - with self.assertRaises(RuntimeError): - utils._detect_file_type("foo") - - def test_detect_file_type_unknown_compression(self): - with self.assertRaises(RuntimeError): - utils._detect_file_type("foo.tar.baz") - - def test_detect_file_type_unknown_partial_ext(self): - with self.assertRaises(RuntimeError): - utils._detect_file_type("foo.bar") - - def test_decompress_bz2(self): + @pytest.mark.parametrize('file, expected', [ + ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), + ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), + ("foo.tar", (".tar", ".tar", None)), + ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.tbz", (".tbz", ".tar", ".bz2")), + ("foo.tbz2", (".tbz2", ".tar", ".bz2")), + ("foo.tgz", (".tgz", ".tar", ".gz")), + ("foo.bz2", (".bz2", None, ".bz2")), + ("foo.gz", (".gz", None, ".gz")), + ("foo.zip", (".zip", ".zip", None)), + ("foo.xz", (".xz", None, ".xz")), + ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.bar.gz", (".gz", None, ".gz")), + ("foo.bar.zip", (".zip", ".zip", None))]) + def test_detect_file_type(self, file, expected): + assert utils._detect_file_type(file) == expected + + @pytest.mark.parametrize('file', ["foo", "foo.tar.baz", "foo.bar"]) + def test_detect_file_type_incompatible(self, file): + # tests detect file type for no extension, unknown compression and unknown partial extension + with pytest.raises(RuntimeError): + utils._detect_file_type(file) + + @pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"]) + def test_decompress(self, extension): def create_compressed(root, content="this is the content"): file = os.path.join(root, "file") - compressed = f"{file}.bz2" + compressed = f"{file}{extension}" + compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension] - with bz2.open(compressed, "wb") as fh: + with compressed_file_opener(compressed, "wb") as fh: fh.write(content.encode()) return compressed, file, content @@ -97,53 +91,13 @@ def create_compressed(root, content="this is the content"): utils._decompress(compressed) - self.assertTrue(os.path.exists(file)) + assert os.path.exists(file) with open(file, "r") as fh: - self.assertEqual(fh.read(), content) - - def test_decompress_gzip(self): - def create_compressed(root, content="this is the content"): - file = os.path.join(root, "file") - compressed = f"{file}.gz" - - with gzip.open(compressed, "wb") as fh: - fh.write(content.encode()) - - return compressed, file, content - - with get_tmp_dir() as temp_dir: - compressed, file, content = create_compressed(temp_dir) - - utils._decompress(compressed) - - self.assertTrue(os.path.exists(file)) - - with open(file, "r") as fh: - self.assertEqual(fh.read(), content) - - def test_decompress_lzma(self): - def create_compressed(root, content="this is the content"): - file = os.path.join(root, "file") - compressed = f"{file}.xz" - - with lzma.open(compressed, "wb") as fh: - fh.write(content.encode()) - - return compressed, file, content - - with get_tmp_dir() as temp_dir: - compressed, file, content = create_compressed(temp_dir) - - utils.extract_archive(compressed, temp_dir) - - self.assertTrue(os.path.exists(file)) - - with open(file, "r") as fh: - self.assertEqual(fh.read(), content) + assert fh.read() == content def test_decompress_no_compression(self): - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): utils._decompress("foo.tar") def test_decompress_remove_finished(self): @@ -161,21 +115,22 @@ def create_compressed(root, content="this is the content"): utils.extract_archive(compressed, temp_dir, remove_finished=True) - self.assertFalse(os.path.exists(compressed)) + assert not os.path.exists(compressed) - def test_extract_archive_defer_to_decompress(self): + @pytest.mark.parametrize('ext', [".gz", ".xz"]) + @pytest.mark.parametrize('remove_finished', [True, False]) + def test_extract_archive_defer_to_decompress(self, ext, remove_finished, mocker): filename = "foo" - for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)): - with self.subTest(ext=ext, remove_finished=remove_finished): - with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock: - file = f"{filename}{ext}" - utils.extract_archive(file, remove_finished=remove_finished) - - mock.assert_called_once() - self.assertEqual( - call_args_to_kwargs_only(mock.call_args, utils._decompress), - dict(from_path=file, to_path=filename, remove_finished=remove_finished), - ) + original_decompress = utils._decompress + mocked = mocker.patch("torchvision.datasets.utils._decompress") + file = f"{filename}{ext}" + utils.extract_archive(file, remove_finished=remove_finished) + + mocked.assert_called_once() + print(mocked.call_args) + + assert (call_args_to_kwargs_only(mocked.call_args, original_decompress) == + dict(from_path=file, to_path=filename, remove_finished=remove_finished)) def test_extract_zip(self): def create_archive(root, content="this is the content"): @@ -192,37 +147,14 @@ def create_archive(root, content="this is the content"): utils.extract_archive(archive, temp_dir) - self.assertTrue(os.path.exists(file)) + assert os.path.exists(file) with open(file, "r") as fh: - self.assertEqual(fh.read(), content) - - def test_extract_tar(self): - def create_archive(root, ext, mode, content="this is the content"): - src = os.path.join(root, "src.txt") - dst = os.path.join(root, "dst.txt") - archive = os.path.join(root, f"archive{ext}") - - with open(src, "w") as fh: - fh.write(content) - - with tarfile.open(archive, mode=mode) as fh: - fh.add(src, arcname=os.path.basename(dst)) - - return archive, dst, content - - for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']): - with get_tmp_dir() as temp_dir: - archive, file, content = create_archive(temp_dir, ext, mode) + assert fh.read() == content - utils.extract_archive(archive, temp_dir) - - self.assertTrue(os.path.exists(file)) - - with open(file, "r") as fh: - self.assertEqual(fh.read(), content) - - def test_extract_tar_xz(self): + @pytest.mark.parametrize('ext, mode', [ + ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')]) + def test_extract_tar(self, ext, mode): def create_archive(root, ext, mode, content="this is the content"): src = os.path.join(root, "src.txt") dst = os.path.join(root, "dst.txt") @@ -236,22 +168,21 @@ def create_archive(root, ext, mode, content="this is the content"): return archive, dst, content - for ext, mode in zip(['.tar.xz'], ['w:xz']): - with get_tmp_dir() as temp_dir: - archive, file, content = create_archive(temp_dir, ext, mode) + with get_tmp_dir() as temp_dir: + archive, file, content = create_archive(temp_dir, ext, mode) - utils.extract_archive(archive, temp_dir) + utils.extract_archive(archive, temp_dir) - self.assertTrue(os.path.exists(file)) + assert os.path.exists(file) - with open(file, "r") as fh: - self.assertEqual(fh.read(), content) + with open(file, "r") as fh: + assert fh.read() == content def test_verify_str_arg(self): - self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",))) - self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") - self.assertRaises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") + assert "a" == utils.verify_str_arg("a", "arg", ("a",)) + pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") + pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") if __name__ == '__main__': - unittest.main() + pytest.main([__file__]) From b0a8c4f9143d78d22df4dfbb808e58b0d0600b66 Mon Sep 17 00:00:00 2001 From: Anirudh Dagar Date: Fri, 25 Jun 2021 06:54:42 +0530 Subject: [PATCH 2/4] refactor ext -> extension --- test/test_datasets_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index fd7b25c7431..df7bf994439 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -117,13 +117,13 @@ def create_compressed(root, content="this is the content"): assert not os.path.exists(compressed) - @pytest.mark.parametrize('ext', [".gz", ".xz"]) + @pytest.mark.parametrize('extension', [".gz", ".xz"]) @pytest.mark.parametrize('remove_finished', [True, False]) - def test_extract_archive_defer_to_decompress(self, ext, remove_finished, mocker): + def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker): filename = "foo" original_decompress = utils._decompress mocked = mocker.patch("torchvision.datasets.utils._decompress") - file = f"{filename}{ext}" + file = f"{filename}{extension}" utils.extract_archive(file, remove_finished=remove_finished) mocked.assert_called_once() @@ -152,13 +152,13 @@ def create_archive(root, content="this is the content"): with open(file, "r") as fh: assert fh.read() == content - @pytest.mark.parametrize('ext, mode', [ + @pytest.mark.parametrize('extension, mode', [ ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')]) - def test_extract_tar(self, ext, mode): - def create_archive(root, ext, mode, content="this is the content"): + def test_extract_tar(self, extension, mode): + def create_archive(root, extension, mode, content="this is the content"): src = os.path.join(root, "src.txt") dst = os.path.join(root, "dst.txt") - archive = os.path.join(root, f"archive{ext}") + archive = os.path.join(root, f"archive{extension}") with open(src, "w") as fh: fh.write(content) @@ -169,7 +169,7 @@ def create_archive(root, ext, mode, content="this is the content"): return archive, dst, content with get_tmp_dir() as temp_dir: - archive, file, content = create_archive(temp_dir, ext, mode) + archive, file, content = create_archive(temp_dir, extension, mode) utils.extract_archive(archive, temp_dir) From 59884bccecd668b6b5b64f67d39debc58725074e Mon Sep 17 00:00:00 2001 From: Anirudh Dagar Date: Fri, 25 Jun 2021 07:01:02 +0530 Subject: [PATCH 3/4] remove redundant print --- test/test_datasets_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index df7bf994439..954e3f05a8c 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -127,7 +127,6 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, m utils.extract_archive(file, remove_finished=remove_finished) mocked.assert_called_once() - print(mocked.call_args) assert (call_args_to_kwargs_only(mocked.call_args, original_decompress) == dict(from_path=file, to_path=filename, remove_finished=remove_finished)) From 2a00109fd0f6d6ceaf664d4aff93020b536bca72 Mon Sep 17 00:00:00 2001 From: Anirudh Dagar Date: Fri, 25 Jun 2021 15:33:10 +0530 Subject: [PATCH 4/4] remove call_args_to_kwargs_only --- test/common_utils.py | 17 ----------------- test/test_datasets_utils.py | 11 ++++------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 06e0e16b1ef..3f8ad8a7f55 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -240,23 +240,6 @@ def disable_console_output(): yield -def call_args_to_kwargs_only(call_args, *callable_or_arg_names): - callable_or_arg_name = callable_or_arg_names[0] - if callable(callable_or_arg_name): - argspec = inspect.getfullargspec(callable_or_arg_name) - arg_names = argspec.args - if isinstance(callable_or_arg_name, type): - # remove self - arg_names.pop(0) - else: - arg_names = callable_or_arg_names - - args, kwargs = call_args - kwargs_only = kwargs.copy() - kwargs_only.update(dict(zip(arg_names, args))) - return kwargs_only - - def cpu_and_gpu(): import pytest # noqa return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda)) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 954e3f05a8c..34ca3da6847 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -11,7 +11,7 @@ import itertools import lzma -from common_utils import get_tmp_dir, call_args_to_kwargs_only +from common_utils import get_tmp_dir from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS @@ -121,15 +121,12 @@ def create_compressed(root, content="this is the content"): @pytest.mark.parametrize('remove_finished', [True, False]) def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker): filename = "foo" - original_decompress = utils._decompress - mocked = mocker.patch("torchvision.datasets.utils._decompress") file = f"{filename}{extension}" - utils.extract_archive(file, remove_finished=remove_finished) - mocked.assert_called_once() + mocked = mocker.patch("torchvision.datasets.utils._decompress") + utils.extract_archive(file, remove_finished=remove_finished) - assert (call_args_to_kwargs_only(mocked.call_args, original_decompress) == - dict(from_path=file, to_path=filename, remove_finished=remove_finished)) + mocked.assert_called_once_with(file, filename, remove_finished=remove_finished) def test_extract_zip(self): def create_archive(root, content="this is the content"):