diff --git a/test/test_image.py b/test/test_image.py index eae4a1473c5..61bf3a4070c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -2,7 +2,6 @@ import io import os import sys -import unittest from pathlib import Path import pytest @@ -54,176 +53,209 @@ def normalize_dimensions(img_pil): return img_pil -class ImageTester(unittest.TestCase): - def test_decode_jpeg(self): - conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)] - for img_path in get_images(IMAGE_ROOT, ".jpg"): - for pil_mode, mode in conversion: - with Image.open(img_path) as img: - is_cmyk = img.mode == "CMYK" - if pil_mode is not None: - if is_cmyk: - # libjpeg does not support the conversion - continue - img = img.convert(pil_mode) - img_pil = torch.from_numpy(np.array(img)) - if is_cmyk: - # flip the colors to match libjpeg - img_pil = 255 - img_pil - - img_pil = normalize_dimensions(img_pil) - data = read_file(img_path) - img_ljpeg = decode_image(data, mode=mode) - - # Permit a small variation on pixel values to account for implementation - # differences between Pillow and LibJPEG. - abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() - self.assertTrue(abs_mean_diff < 2) - - with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"): - decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) - - with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): - decode_jpeg(torch.empty((100,), dtype=torch.float16)) - - with self.assertRaises(RuntimeError): - decode_jpeg(torch.empty((100), dtype=torch.uint8)) - - def test_damaged_images(self): - # Test image with bad Huffman encoding (should not raise) - bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) - try: - _ = decode_jpeg(bad_huff) - except RuntimeError: - self.assertTrue(False) - - # Truncated images should raise an exception - truncated_images = glob.glob( - os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) - for image_path in truncated_images: - data = read_file(image_path) - with self.assertRaises(RuntimeError): - decode_jpeg(data) - - def test_decode_png(self): - conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), - ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)] - for img_path in get_images(FAKEDATA_DIR, ".png"): - for pil_mode, mode in conversion: - with Image.open(img_path) as img: - if pil_mode is not None: - img = img.convert(pil_mode) - img_pil = torch.from_numpy(np.array(img)) - - img_pil = normalize_dimensions(img_pil) - data = read_file(img_path) - img_lpng = decode_image(data, mode=mode) - - tol = 0 if conversion is None else 1 - self.assertTrue(img_lpng.allclose(img_pil, atol=tol)) - - with self.assertRaises(RuntimeError): - decode_png(torch.empty((), dtype=torch.uint8)) - with self.assertRaises(RuntimeError): - decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) - - def test_encode_png(self): - for img_path in get_images(IMAGE_DIR, '.png'): - pil_image = Image.open(img_path) - img_pil = torch.from_numpy(np.array(pil_image)) - img_pil = img_pil.permute(2, 0, 1) - png_buf = encode_png(img_pil, compression_level=6) - - rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist()))) - rec_img = torch.from_numpy(np.array(rec_img)) - rec_img = rec_img.permute(2, 0, 1) - - assert_equal(img_pil, rec_img) - - with self.assertRaisesRegex( - RuntimeError, "Input tensor dtype should be uint8"): - encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) - - with self.assertRaisesRegex( - RuntimeError, "Compression level should be between 0 and 9"): - encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), - compression_level=-1) - - with self.assertRaisesRegex( - RuntimeError, "Compression level should be between 0 and 9"): - encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), - compression_level=10) - - with self.assertRaisesRegex( - RuntimeError, "The number of channels should be 1 or 3, got: 5"): - encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) - - def test_write_png(self): - with get_tmp_dir() as d: - for img_path in get_images(IMAGE_DIR, '.png'): - pil_image = Image.open(img_path) - img_pil = torch.from_numpy(np.array(pil_image)) - img_pil = img_pil.permute(2, 0, 1) - - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) - write_png(img_pil, torch_png, compression_level=6) - saved_image = torch.from_numpy(np.array(Image.open(torch_png))) - saved_image = saved_image.permute(2, 0, 1) - - assert_equal(img_pil, saved_image) - - def test_read_file(self): - with get_tmp_dir() as d: - fname, content = 'test1.bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - with open(fpath, 'wb') as f: - f.write(content) - - data = read_file(fpath) - expected = torch.tensor(list(content), dtype=torch.uint8) - assert_equal(data, expected) - os.unlink(fpath) - - with self.assertRaisesRegex( - RuntimeError, "No such file or directory: 'tst'"): - read_file('tst') - - def test_read_file_non_ascii(self): - with get_tmp_dir() as d: - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - with open(fpath, 'wb') as f: - f.write(content) - - data = read_file(fpath) - expected = torch.tensor(list(content), dtype=torch.uint8) - assert_equal(data, expected) - os.unlink(fpath) - - def test_write_file(self): - with get_tmp_dir() as d: - fname, content = 'test1.bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - content_tensor = torch.tensor(list(content), dtype=torch.uint8) - write_file(fpath, content_tensor) - - with open(fpath, 'rb') as f: - saved_content = f.read() - self.assertEqual(content, saved_content) - os.unlink(fpath) - - def test_write_file_non_ascii(self): - with get_tmp_dir() as d: - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - content_tensor = torch.tensor(list(content), dtype=torch.uint8) - write_file(fpath, content_tensor) - - with open(fpath, 'rb') as f: - saved_content = f.read() - self.assertEqual(content, saved_content) - os.unlink(fpath) +@pytest.mark.parametrize('img_path', [ + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) + for jpeg_path in get_images(IMAGE_ROOT, ".jpg") +]) +@pytest.mark.parametrize('pil_mode, mode', [ + (None, ImageReadMode.UNCHANGED), + ("L", ImageReadMode.GRAY), + ("RGB", ImageReadMode.RGB), +]) +def test_decode_jpeg(img_path, pil_mode, mode): + + with Image.open(img_path) as img: + is_cmyk = img.mode == "CMYK" + if pil_mode is not None: + if is_cmyk: + # libjpeg does not support the conversion + pytest.xfail("Decoding a CMYK jpeg isn't supported") + img = img.convert(pil_mode) + img_pil = torch.from_numpy(np.array(img)) + if is_cmyk: + # flip the colors to match libjpeg + img_pil = 255 - img_pil + + img_pil = normalize_dimensions(img_pil) + data = read_file(img_path) + img_ljpeg = decode_image(data, mode=mode) + + # Permit a small variation on pixel values to account for implementation + # differences between Pillow and LibJPEG. + abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() + assert abs_mean_diff < 2 + + +def test_decode_jpeg_errors(): + with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): + decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) + + with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): + decode_jpeg(torch.empty((100,), dtype=torch.float16)) + + with pytest.raises(RuntimeError, match="Not a JPEG file"): + decode_jpeg(torch.empty((100), dtype=torch.uint8)) + + +def test_decode_bad_huffman_images(): + # sanity check: make sure we can decode the bad Huffman encoding + bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) + decode_jpeg(bad_huff) + + +@pytest.mark.parametrize('img_path', [ + pytest.param(truncated_image, id=_get_safe_image_name(truncated_image)) + for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) +]) +def test_damaged_corrupt_images(img_path): + # Truncated images should raise an exception + data = read_file(img_path) + if 'corrupt34' in img_path: + match_message = "Image is incomplete or truncated" + else: + match_message = "Unsupported marker type" + with pytest.raises(RuntimeError, match=match_message): + decode_jpeg(data) + + +@pytest.mark.parametrize('img_path', [ + pytest.param(png_path, id=_get_safe_image_name(png_path)) + for png_path in get_images(FAKEDATA_DIR, ".png") +]) +@pytest.mark.parametrize('pil_mode, mode', [ + (None, ImageReadMode.UNCHANGED), + ("L", ImageReadMode.GRAY), + ("LA", ImageReadMode.GRAY_ALPHA), + ("RGB", ImageReadMode.RGB), + ("RGBA", ImageReadMode.RGB_ALPHA), +]) +def test_decode_png(img_path, pil_mode, mode): + + with Image.open(img_path) as img: + if pil_mode is not None: + img = img.convert(pil_mode) + img_pil = torch.from_numpy(np.array(img)) + + img_pil = normalize_dimensions(img_pil) + data = read_file(img_path) + img_lpng = decode_image(data, mode=mode) + + tol = 0 if pil_mode is None else 1 + assert img_lpng.allclose(img_pil, atol=tol) + + +def test_decode_png_errors(): + with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): + decode_png(torch.empty((), dtype=torch.uint8)) + with pytest.raises(RuntimeError, match="Content is not png"): + decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) + + +@pytest.mark.parametrize('img_path', [ + pytest.param(png_path, id=_get_safe_image_name(png_path)) + for png_path in get_images(IMAGE_DIR, ".png") +]) +def test_encode_png(img_path): + pil_image = Image.open(img_path) + img_pil = torch.from_numpy(np.array(pil_image)) + img_pil = img_pil.permute(2, 0, 1) + png_buf = encode_png(img_pil, compression_level=6) + + rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist()))) + rec_img = torch.from_numpy(np.array(rec_img)) + rec_img = rec_img.permute(2, 0, 1) + + assert_equal(img_pil, rec_img) + + +def test_encode_png_errors(): + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) + + with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): + encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), + compression_level=-1) + + with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): + encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), + compression_level=10) + + with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): + encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) + + +@pytest.mark.parametrize('img_path', [ + pytest.param(png_path, id=_get_safe_image_name(png_path)) + for png_path in get_images(IMAGE_DIR, ".png") +]) +def test_write_png(img_path): + with get_tmp_dir() as d: + pil_image = Image.open(img_path) + img_pil = torch.from_numpy(np.array(pil_image)) + img_pil = img_pil.permute(2, 0, 1) + + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) + write_png(img_pil, torch_png, compression_level=6) + saved_image = torch.from_numpy(np.array(Image.open(torch_png))) + saved_image = saved_image.permute(2, 0, 1) + + assert_equal(img_pil, saved_image) + + +def test_read_file(): + with get_tmp_dir() as d: + fname, content = 'test1.bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + with open(fpath, 'wb') as f: + f.write(content) + + data = read_file(fpath) + expected = torch.tensor(list(content), dtype=torch.uint8) + os.unlink(fpath) + assert_equal(data, expected) + + with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"): + read_file('tst') + + +def test_read_file_non_ascii(): + with get_tmp_dir() as d: + fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + with open(fpath, 'wb') as f: + f.write(content) + + data = read_file(fpath) + expected = torch.tensor(list(content), dtype=torch.uint8) + os.unlink(fpath) + assert_equal(data, expected) + + +def test_write_file(): + with get_tmp_dir() as d: + fname, content = 'test1.bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + content_tensor = torch.tensor(list(content), dtype=torch.uint8) + write_file(fpath, content_tensor) + + with open(fpath, 'rb') as f: + saved_content = f.read() + os.unlink(fpath) + assert content == saved_content + + +def test_write_file_non_ascii(): + with get_tmp_dir() as d: + fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + content_tensor = torch.tensor(list(content), dtype=torch.uint8) + write_file(fpath, content_tensor) + + with open(fpath, 'rb') as f: + saved_content = f.read() + os.unlink(fpath) + assert content == saved_content @needs_cuda @@ -236,14 +268,14 @@ def test_write_file_non_ascii(self): def test_decode_jpeg_cuda(mode, img_path, scripted): if 'cmyk' in img_path: pytest.xfail("Decoding a CMYK jpeg isn't supported") - tester = ImageTester() + data = read_file(img_path) img = decode_image(data, mode=mode) f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg img_nvjpeg = f(data, mode=mode, device='cuda') # Some difference expected between jpeg implementations - tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2) + assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 @needs_cuda @@ -304,7 +336,11 @@ def _inner(test_func): @cpu_only @_collect_if(cond=IS_WINDOWS) -def test_encode_jpeg_windows(): +@pytest.mark.parametrize('img_path', [ + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") +]) +def test_encode_jpeg_windows(img_path): # This test is *wrong*. # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it # starts encoding the torchvision version from an image that comes from @@ -315,48 +351,50 @@ def test_encode_jpeg_windows(): # these more correct tests fail on windows (probably because of a difference # in libjpeg) between torchvision and PIL. # FIXME: make the correct tests pass on windows and remove this. - for img_path in get_images(ENCODE_JPEG, ".jpg"): - dirname = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - write_folder = os.path.join(dirname, 'jpeg_write') - expected_file = os.path.join( - write_folder, '{0}_pil.jpg'.format(filename)) - img = decode_jpeg(read_file(img_path)) - - with open(expected_file, 'rb') as f: - pil_bytes = f.read() - pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) - for src_img in [img, img.contiguous()]: - # PIL sets jpeg quality to 75 by default - jpeg_bytes = encode_jpeg(src_img, quality=75) - assert_equal(jpeg_bytes, pil_bytes) + dirname = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + write_folder = os.path.join(dirname, 'jpeg_write') + expected_file = os.path.join( + write_folder, '{0}_pil.jpg'.format(filename)) + img = decode_jpeg(read_file(img_path)) + + with open(expected_file, 'rb') as f: + pil_bytes = f.read() + pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) + for src_img in [img, img.contiguous()]: + # PIL sets jpeg quality to 75 by default + jpeg_bytes = encode_jpeg(src_img, quality=75) + assert_equal(jpeg_bytes, pil_bytes) @cpu_only @_collect_if(cond=IS_WINDOWS) -def test_write_jpeg_windows(): +@pytest.mark.parametrize('img_path', [ + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") +]) +def test_write_jpeg_windows(img_path): # FIXME: Remove this eventually, see test_encode_jpeg_windows with get_tmp_dir() as d: - for img_path in get_images(ENCODE_JPEG, ".jpg"): - data = read_file(img_path) - img = decode_jpeg(data) + data = read_file(img_path) + img = decode_jpeg(data) - basedir = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_jpeg = os.path.join( - d, '{0}_torch.jpg'.format(filename)) - pil_jpeg = os.path.join( - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) + basedir = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_jpeg = os.path.join( + d, '{0}_torch.jpg'.format(filename)) + pil_jpeg = os.path.join( + basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) - write_jpeg(img, torch_jpeg, quality=75) + write_jpeg(img, torch_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: - torch_bytes = f.read() + with open(torch_jpeg, 'rb') as f: + torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: - pil_bytes = f.read() + with open(pil_jpeg, 'rb') as f: + pil_bytes = f.read() - assert_equal(torch_bytes, pil_bytes) + assert_equal(torch_bytes, pil_bytes) @cpu_only @@ -408,5 +446,5 @@ def test_write_jpeg(img_path): assert_equal(torch_bytes, pil_bytes) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + pytest.main([__file__])