Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,23 +240,6 @@ def disable_console_output():
yield


def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
callable_or_arg_name = callable_or_arg_names[0]
if callable(callable_or_arg_name):
argspec = inspect.getfullargspec(callable_or_arg_name)
arg_names = argspec.args
if isinstance(callable_or_arg_name, type):
# remove self
arg_names.pop(0)
else:
arg_names = callable_or_arg_names

args, kwargs = call_args
kwargs_only = kwargs.copy()
kwargs_only.update(dict(zip(arg_names, args)))
return kwargs_only


def cpu_and_gpu():
import pytest # noqa
return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda))
Expand Down
211 changes: 69 additions & 142 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import bz2
import os
import torchvision.datasets.utils as utils
import unittest
import unittest.mock
import pytest
import zipfile
import tarfile
import gzip
Expand All @@ -12,31 +11,32 @@
import itertools
import lzma

from common_utils import get_tmp_dir, call_args_to_kwargs_only
from common_utils import get_tmp_dir
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS


TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):
class TestDatasetsUtils:

def test_check_md5(self):
fpath = TEST_FILE
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = ''
self.assertTrue(utils.check_md5(fpath, correct_md5))
self.assertFalse(utils.check_md5(fpath, false_md5))
assert utils.check_md5(fpath, correct_md5)
assert not utils.check_md5(fpath, false_md5)

def test_check_integrity(self):
existing_fpath = TEST_FILE
nonexisting_fpath = ''
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = ''
self.assertTrue(utils.check_integrity(existing_fpath, correct_md5))
self.assertFalse(utils.check_integrity(existing_fpath, false_md5))
self.assertTrue(utils.check_integrity(existing_fpath))
self.assertFalse(utils.check_integrity(nonexisting_fpath))
assert utils.check_integrity(existing_fpath, correct_md5)
assert not utils.check_integrity(existing_fpath, false_md5)
assert utils.check_integrity(existing_fpath)
assert not utils.check_integrity(nonexisting_fpath)

def test_get_google_drive_file_id(self):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
Expand All @@ -50,44 +50,38 @@ def test_get_google_drive_file_id_invalid_url(self):

assert utils._get_google_drive_file_id(url) is None

def test_detect_file_type(self):
for file, expected in [
("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
("foo.tar", (".tar", ".tar", None)),
("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.tbz", (".tbz", ".tar", ".bz2")),
("foo.tbz2", (".tbz2", ".tar", ".bz2")),
("foo.tgz", (".tgz", ".tar", ".gz")),
("foo.bz2", (".bz2", None, ".bz2")),
("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")),
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.bar.gz", (".gz", None, ".gz")),
("foo.bar.zip", (".zip", ".zip", None)),
]:
with self.subTest(file=file):
self.assertSequenceEqual(utils._detect_file_type(file), expected)

def test_detect_file_type_no_ext(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo")

def test_detect_file_type_unknown_compression(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.tar.baz")

def test_detect_file_type_unknown_partial_ext(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar")

def test_decompress_bz2(self):
@pytest.mark.parametrize('file, expected', [
("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
("foo.tar", (".tar", ".tar", None)),
("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.tbz", (".tbz", ".tar", ".bz2")),
("foo.tbz2", (".tbz2", ".tar", ".bz2")),
("foo.tgz", (".tgz", ".tar", ".gz")),
("foo.bz2", (".bz2", None, ".bz2")),
("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")),
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.bar.gz", (".gz", None, ".gz")),
("foo.bar.zip", (".zip", ".zip", None))])
def test_detect_file_type(self, file, expected):
assert utils._detect_file_type(file) == expected

@pytest.mark.parametrize('file', ["foo", "foo.tar.baz", "foo.bar"])
def test_detect_file_type_incompatible(self, file):
# tests detect file type for no extension, unknown compression and unknown partial extension
with pytest.raises(RuntimeError):
utils._detect_file_type(file)

@pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
def test_decompress(self, extension):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.bz2"
compressed = f"{file}{extension}"
compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

with bz2.open(compressed, "wb") as fh:
with compressed_file_opener(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content
Expand All @@ -97,53 +91,13 @@ def create_compressed(root, content="this is the content"):

utils._decompress(compressed)

self.assertTrue(os.path.exists(file))
assert os.path.exists(file)

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_decompress_gzip(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"

with gzip.open(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content

with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)

utils._decompress(compressed)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_decompress_lzma(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.xz"

with lzma.open(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content

with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)

utils.extract_archive(compressed, temp_dir)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
assert fh.read() == content

def test_decompress_no_compression(self):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
utils._decompress("foo.tar")

def test_decompress_remove_finished(self):
Expand All @@ -161,21 +115,18 @@ def create_compressed(root, content="this is the content"):

utils.extract_archive(compressed, temp_dir, remove_finished=True)

self.assertFalse(os.path.exists(compressed))
assert not os.path.exists(compressed)

def test_extract_archive_defer_to_decompress(self):
@pytest.mark.parametrize('extension', [".gz", ".xz"])
@pytest.mark.parametrize('remove_finished', [True, False])
def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
filename = "foo"
for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)):
with self.subTest(ext=ext, remove_finished=remove_finished):
with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock:
file = f"{filename}{ext}"
utils.extract_archive(file, remove_finished=remove_finished)

mock.assert_called_once()
self.assertEqual(
call_args_to_kwargs_only(mock.call_args, utils._decompress),
dict(from_path=file, to_path=filename, remove_finished=remove_finished),
)
file = f"{filename}{extension}"

mocked = mocker.patch("torchvision.datasets.utils._decompress")
utils.extract_archive(file, remove_finished=remove_finished)

mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

def test_extract_zip(self):
def create_archive(root, content="this is the content"):
Expand All @@ -192,16 +143,18 @@ def create_archive(root, content="this is the content"):

utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))
assert os.path.exists(file)

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
assert fh.read() == content

def test_extract_tar(self):
def create_archive(root, ext, mode, content="this is the content"):
@pytest.mark.parametrize('extension, mode', [
('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
def test_extract_tar(self, extension, mode):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")
archive = os.path.join(root, f"archive{extension}")

with open(src, "w") as fh:
fh.write(content)
Expand All @@ -211,47 +164,21 @@ def create_archive(root, ext, mode, content="this is the content"):

return archive, dst, content

for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']):
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, ext, mode)

utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_extract_tar_xz(self):
def create_archive(root, ext, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")

with open(src, "w") as fh:
fh.write(content)

with tarfile.open(archive, mode=mode) as fh:
fh.add(src, arcname=os.path.basename(dst))

return archive, dst, content

for ext, mode in zip(['.tar.xz'], ['w:xz']):
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, ext, mode)
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, extension, mode)

utils.extract_archive(archive, temp_dir)
utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))
assert os.path.exists(file)

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
with open(file, "r") as fh:
assert fh.read() == content

def test_verify_str_arg(self):
self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
self.assertRaises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
assert "a" == utils.verify_str_arg("a", "arg", ("a",))
pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")


if __name__ == '__main__':
unittest.main()
pytest.main([__file__])