From 4194db4be11bce5c860abc5afe58d73ec6ba4342 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Thu, 27 May 2021 01:42:02 +0800 Subject: [PATCH 1/7] Port test/test_image.py to pytest --- test/test_image.py | 445 ++++++++++++++++++++++++--------------------- 1 file changed, 238 insertions(+), 207 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index eae4a1473c5..e14b213b169 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -2,7 +2,6 @@ import io import os import sys -import unittest from pathlib import Path import pytest @@ -54,176 +53,202 @@ def normalize_dimensions(img_pil): return img_pil -class ImageTester(unittest.TestCase): - def test_decode_jpeg(self): - conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)] - for img_path in get_images(IMAGE_ROOT, ".jpg"): - for pil_mode, mode in conversion: - with Image.open(img_path) as img: - is_cmyk = img.mode == "CMYK" - if pil_mode is not None: - if is_cmyk: - # libjpeg does not support the conversion - continue - img = img.convert(pil_mode) - img_pil = torch.from_numpy(np.array(img)) - if is_cmyk: - # flip the colors to match libjpeg - img_pil = 255 - img_pil - - img_pil = normalize_dimensions(img_pil) - data = read_file(img_path) - img_ljpeg = decode_image(data, mode=mode) - - # Permit a small variation on pixel values to account for implementation - # differences between Pillow and LibJPEG. - abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() - self.assertTrue(abs_mean_diff < 2) - - with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"): - decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) - - with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): - decode_jpeg(torch.empty((100,), dtype=torch.float16)) - - with self.assertRaises(RuntimeError): - decode_jpeg(torch.empty((100), dtype=torch.uint8)) - - def test_damaged_images(self): - # Test image with bad Huffman encoding (should not raise) - bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) - try: - _ = decode_jpeg(bad_huff) - except RuntimeError: - self.assertTrue(False) - - # Truncated images should raise an exception - truncated_images = glob.glob( - os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) - for image_path in truncated_images: - data = read_file(image_path) - with self.assertRaises(RuntimeError): - decode_jpeg(data) - - def test_decode_png(self): - conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), - ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)] - for img_path in get_images(FAKEDATA_DIR, ".png"): - for pil_mode, mode in conversion: - with Image.open(img_path) as img: - if pil_mode is not None: - img = img.convert(pil_mode) - img_pil = torch.from_numpy(np.array(img)) - - img_pil = normalize_dimensions(img_pil) - data = read_file(img_path) - img_lpng = decode_image(data, mode=mode) - - tol = 0 if conversion is None else 1 - self.assertTrue(img_lpng.allclose(img_pil, atol=tol)) - - with self.assertRaises(RuntimeError): - decode_png(torch.empty((), dtype=torch.uint8)) - with self.assertRaises(RuntimeError): - decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) - - def test_encode_png(self): - for img_path in get_images(IMAGE_DIR, '.png'): - pil_image = Image.open(img_path) - img_pil = torch.from_numpy(np.array(pil_image)) - img_pil = img_pil.permute(2, 0, 1) - png_buf = encode_png(img_pil, compression_level=6) - - rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist()))) - rec_img = torch.from_numpy(np.array(rec_img)) - rec_img = rec_img.permute(2, 0, 1) - - assert_equal(img_pil, rec_img) - - with self.assertRaisesRegex( - RuntimeError, "Input tensor dtype should be uint8"): - encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) - - with self.assertRaisesRegex( - RuntimeError, "Compression level should be between 0 and 9"): - encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), - compression_level=-1) - - with self.assertRaisesRegex( - RuntimeError, "Compression level should be between 0 and 9"): - encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), - compression_level=10) - - with self.assertRaisesRegex( - RuntimeError, "The number of channels should be 1 or 3, got: 5"): - encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) - - def test_write_png(self): - with get_tmp_dir() as d: - for img_path in get_images(IMAGE_DIR, '.png'): - pil_image = Image.open(img_path) - img_pil = torch.from_numpy(np.array(pil_image)) - img_pil = img_pil.permute(2, 0, 1) - - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) - write_png(img_pil, torch_png, compression_level=6) - saved_image = torch.from_numpy(np.array(Image.open(torch_png))) - saved_image = saved_image.permute(2, 0, 1) - - assert_equal(img_pil, saved_image) - - def test_read_file(self): - with get_tmp_dir() as d: - fname, content = 'test1.bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - with open(fpath, 'wb') as f: - f.write(content) - - data = read_file(fpath) - expected = torch.tensor(list(content), dtype=torch.uint8) - assert_equal(data, expected) - os.unlink(fpath) - - with self.assertRaisesRegex( - RuntimeError, "No such file or directory: 'tst'"): - read_file('tst') - - def test_read_file_non_ascii(self): - with get_tmp_dir() as d: - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - with open(fpath, 'wb') as f: - f.write(content) - - data = read_file(fpath) - expected = torch.tensor(list(content), dtype=torch.uint8) - assert_equal(data, expected) - os.unlink(fpath) - - def test_write_file(self): - with get_tmp_dir() as d: - fname, content = 'test1.bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - content_tensor = torch.tensor(list(content), dtype=torch.uint8) - write_file(fpath, content_tensor) - - with open(fpath, 'rb') as f: - saved_content = f.read() - self.assertEqual(content, saved_content) - os.unlink(fpath) - - def test_write_file_non_ascii(self): - with get_tmp_dir() as d: - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' - fpath = os.path.join(d, fname) - content_tensor = torch.tensor(list(content), dtype=torch.uint8) - write_file(fpath, content_tensor) - - with open(fpath, 'rb') as f: - saved_content = f.read() - self.assertEqual(content, saved_content) - os.unlink(fpath) +@pytest.mark.parametrize('img_path', [ + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) + for jpeg_path in get_images(IMAGE_ROOT, ".jpg") +]) +@pytest.mark.parametrize('conversion', [ + (None, ImageReadMode.UNCHANGED), + ("L", ImageReadMode.GRAY), + ("RGB", ImageReadMode.RGB), +]) +def test_decode_jpeg(img_path, conversion): + pil_mode, mode = conversion + + with Image.open(img_path) as img: + is_cmyk = img.mode == "CMYK" + if pil_mode is not None: + if is_cmyk: + # libjpeg does not support the conversion + pytest.xfail("Decoding a CMYK jpeg isn't supported") + img = img.convert(pil_mode) + img_pil = torch.from_numpy(np.array(img)) + if is_cmyk: + # flip the colors to match libjpeg + img_pil = 255 - img_pil + + img_pil = normalize_dimensions(img_pil) + data = read_file(img_path) + img_ljpeg = decode_image(data, mode=mode) + + # Permit a small variation on pixel values to account for implementation + # differences between Pillow and LibJPEG. + abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() + assert abs_mean_diff < 2 + + with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): + decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) + + with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): + decode_jpeg(torch.empty((100,), dtype=torch.float16)) + + with pytest.raises(RuntimeError): + decode_jpeg(torch.empty((100), dtype=torch.uint8)) + + +@pytest.mark.parametrize('img_path', [ + pytest.param(truncated_image, id=_get_safe_image_name(truncated_image)) + for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) +]) +def test_damaged_images(img_path): + # Test image with bad Huffman encoding (should not raise) + bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) + try: + _ = decode_jpeg(bad_huff) + except RuntimeError: + assert False + + # Truncated images should raise an exception + data = read_file(img_path) + with pytest.raises(RuntimeError): + decode_jpeg(data) + + +@pytest.mark.parametrize('img_path', [ + pytest.param(png_path, id=_get_safe_image_name(png_path)) + for png_path in get_images(FAKEDATA_DIR, ".png") +]) +@pytest.mark.parametrize('conversion', [ + (None, ImageReadMode.UNCHANGED), + ("L", ImageReadMode.GRAY), + ("LA", ImageReadMode.GRAY_ALPHA), + ("RGB", ImageReadMode.RGB), + ("RGBA", ImageReadMode.RGB_ALPHA), +]) +def test_decode_png(img_path, conversion): + pil_mode, mode = conversion + + with Image.open(img_path) as img: + if pil_mode is not None: + img = img.convert(pil_mode) + img_pil = torch.from_numpy(np.array(img)) + + img_pil = normalize_dimensions(img_pil) + data = read_file(img_path) + img_lpng = decode_image(data, mode=mode) + + tol = 0 if conversion is None else 1 + assert img_lpng.allclose(img_pil, atol=tol) + + with pytest.raises(RuntimeError): + decode_png(torch.empty((), dtype=torch.uint8)) + with pytest.raises(RuntimeError): + decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) + + +@pytest.mark.parametrize('img_path', [ + pytest.param(png_path, id=_get_safe_image_name(png_path)) + for png_path in get_images(IMAGE_DIR, ".png") +]) +def test_encode_png(img_path): + pil_image = Image.open(img_path) + img_pil = torch.from_numpy(np.array(pil_image)) + img_pil = img_pil.permute(2, 0, 1) + png_buf = encode_png(img_pil, compression_level=6) + + rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist()))) + rec_img = torch.from_numpy(np.array(rec_img)) + rec_img = rec_img.permute(2, 0, 1) + + assert_equal(img_pil, rec_img) + + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) + + with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): + encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), + compression_level=-1) + + with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): + encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), + compression_level=10) + + with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): + encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) + + +@pytest.mark.parametrize('img_path', [ + pytest.param(png_path, id=_get_safe_image_name(png_path)) + for png_path in get_images(IMAGE_DIR, ".png") +]) +def test_write_png(img_path): + with get_tmp_dir() as d: + pil_image = Image.open(img_path) + img_pil = torch.from_numpy(np.array(pil_image)) + img_pil = img_pil.permute(2, 0, 1) + + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) + write_png(img_pil, torch_png, compression_level=6) + saved_image = torch.from_numpy(np.array(Image.open(torch_png))) + saved_image = saved_image.permute(2, 0, 1) + + assert_equal(img_pil, saved_image) + + +def test_read_file(): + with get_tmp_dir() as d: + fname, content = 'test1.bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + with open(fpath, 'wb') as f: + f.write(content) + + data = read_file(fpath) + expected = torch.tensor(list(content), dtype=torch.uint8) + assert_equal(data, expected) + os.unlink(fpath) + + with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"): + read_file('tst') + + +def test_read_file_non_ascii(): + with get_tmp_dir() as d: + fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + with open(fpath, 'wb') as f: + f.write(content) + + data = read_file(fpath) + expected = torch.tensor(list(content), dtype=torch.uint8) + assert_equal(data, expected) + os.unlink(fpath) + + +def test_write_file(): + with get_tmp_dir() as d: + fname, content = 'test1.bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + content_tensor = torch.tensor(list(content), dtype=torch.uint8) + write_file(fpath, content_tensor) + + with open(fpath, 'rb') as f: + saved_content = f.read() + assert content == saved_content + os.unlink(fpath) + + +def test_write_file_non_ascii(): + with get_tmp_dir() as d: + fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + content_tensor = torch.tensor(list(content), dtype=torch.uint8) + write_file(fpath, content_tensor) + + with open(fpath, 'rb') as f: + saved_content = f.read() + assert content == saved_content + os.unlink(fpath) @needs_cuda @@ -236,14 +261,14 @@ def test_write_file_non_ascii(self): def test_decode_jpeg_cuda(mode, img_path, scripted): if 'cmyk' in img_path: pytest.xfail("Decoding a CMYK jpeg isn't supported") - tester = ImageTester() + data = read_file(img_path) img = decode_image(data, mode=mode) f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg img_nvjpeg = f(data, mode=mode, device='cuda') # Some difference expected between jpeg implementations - tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2) + assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 @needs_cuda @@ -304,7 +329,11 @@ def _inner(test_func): @cpu_only @_collect_if(cond=IS_WINDOWS) -def test_encode_jpeg_windows(): +@pytest.mark.parametrize('img_path', [ + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") +]) +def test_encode_jpeg_windows(img_path): # This test is *wrong*. # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it # starts encoding the torchvision version from an image that comes from @@ -315,48 +344,50 @@ def test_encode_jpeg_windows(): # these more correct tests fail on windows (probably because of a difference # in libjpeg) between torchvision and PIL. # FIXME: make the correct tests pass on windows and remove this. - for img_path in get_images(ENCODE_JPEG, ".jpg"): - dirname = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - write_folder = os.path.join(dirname, 'jpeg_write') - expected_file = os.path.join( - write_folder, '{0}_pil.jpg'.format(filename)) - img = decode_jpeg(read_file(img_path)) - - with open(expected_file, 'rb') as f: - pil_bytes = f.read() - pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) - for src_img in [img, img.contiguous()]: - # PIL sets jpeg quality to 75 by default - jpeg_bytes = encode_jpeg(src_img, quality=75) - assert_equal(jpeg_bytes, pil_bytes) + dirname = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + write_folder = os.path.join(dirname, 'jpeg_write') + expected_file = os.path.join( + write_folder, '{0}_pil.jpg'.format(filename)) + img = decode_jpeg(read_file(img_path)) + + with open(expected_file, 'rb') as f: + pil_bytes = f.read() + pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) + for src_img in [img, img.contiguous()]: + # PIL sets jpeg quality to 75 by default + jpeg_bytes = encode_jpeg(src_img, quality=75) + assert_equal(jpeg_bytes, pil_bytes) @cpu_only @_collect_if(cond=IS_WINDOWS) -def test_write_jpeg_windows(): +@pytest.mark.parametrize('img_path', [ + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") +]) +def test_write_jpeg_windows(img_path): # FIXME: Remove this eventually, see test_encode_jpeg_windows with get_tmp_dir() as d: - for img_path in get_images(ENCODE_JPEG, ".jpg"): - data = read_file(img_path) - img = decode_jpeg(data) + data = read_file(img_path) + img = decode_jpeg(data) - basedir = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_jpeg = os.path.join( - d, '{0}_torch.jpg'.format(filename)) - pil_jpeg = os.path.join( - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) + basedir = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_jpeg = os.path.join( + d, '{0}_torch.jpg'.format(filename)) + pil_jpeg = os.path.join( + basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) - write_jpeg(img, torch_jpeg, quality=75) + write_jpeg(img, torch_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: - torch_bytes = f.read() + with open(torch_jpeg, 'rb') as f: + torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: - pil_bytes = f.read() + with open(pil_jpeg, 'rb') as f: + pil_bytes = f.read() - assert_equal(torch_bytes, pil_bytes) + assert_equal(torch_bytes, pil_bytes) @cpu_only @@ -408,5 +439,5 @@ def test_write_jpeg(img_path): assert_equal(torch_bytes, pil_bytes) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + pytest.main([__file__]) From 96ea299217068f0f8716bd1258b1a0b7c9d6f314 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Thu, 27 May 2021 01:53:07 +0800 Subject: [PATCH 2/7] Call os.unlink before the assert --- test/test_image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index e14b213b169..1e8e439e143 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -205,8 +205,8 @@ def test_read_file(): data = read_file(fpath) expected = torch.tensor(list(content), dtype=torch.uint8) - assert_equal(data, expected) os.unlink(fpath) + assert_equal(data, expected) with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"): read_file('tst') @@ -221,8 +221,8 @@ def test_read_file_non_ascii(): data = read_file(fpath) expected = torch.tensor(list(content), dtype=torch.uint8) - assert_equal(data, expected) os.unlink(fpath) + assert_equal(data, expected) def test_write_file(): @@ -234,8 +234,8 @@ def test_write_file(): with open(fpath, 'rb') as f: saved_content = f.read() - assert content == saved_content os.unlink(fpath) + assert content == saved_content def test_write_file_non_ascii(): @@ -247,8 +247,8 @@ def test_write_file_non_ascii(): with open(fpath, 'rb') as f: saved_content = f.read() - assert content == saved_content os.unlink(fpath) + assert content == saved_content @needs_cuda From b25c100d01d82782f720bd1b1d75202a12639397 Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Thu, 27 May 2021 05:39:24 -0400 Subject: [PATCH 3/7] Use tuples parametrize to unpack conversion --- test/test_image.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 1e8e439e143..ff179f4bc7f 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -57,13 +57,12 @@ def normalize_dimensions(img_pil): pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg") ]) -@pytest.mark.parametrize('conversion', [ +@pytest.mark.parametrize('pil_mode, mode', [ (None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB), ]) -def test_decode_jpeg(img_path, conversion): - pil_mode, mode = conversion +def test_decode_jpeg(img_path, pil_mode, mode): with Image.open(img_path) as img: is_cmyk = img.mode == "CMYK" @@ -118,15 +117,14 @@ def test_damaged_images(img_path): pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png") ]) -@pytest.mark.parametrize('conversion', [ +@pytest.mark.parametrize('pil_mode, mode', [ (None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA), ]) -def test_decode_png(img_path, conversion): - pil_mode, mode = conversion +def test_decode_png(img_path, pil_mode, mode): with Image.open(img_path) as img: if pil_mode is not None: @@ -137,7 +135,7 @@ def test_decode_png(img_path, conversion): data = read_file(img_path) img_lpng = decode_image(data, mode=mode) - tol = 0 if conversion is None else 1 + tol = 0 if pil_mode is None else 1 assert img_lpng.allclose(img_pil, atol=tol) with pytest.raises(RuntimeError): From 40311dc13b3f1f146eda681373efdb28f0bb2bcb Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Thu, 27 May 2021 06:50:01 -0400 Subject: [PATCH 4/7] Process bad huffman and currupt images separately --- test/test_image.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index ff179f4bc7f..45d63001fc0 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -85,31 +85,32 @@ def test_decode_jpeg(img_path, pil_mode, mode): abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() assert abs_mean_diff < 2 + +def test_decode_jpeg_errors(): with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): decode_jpeg(torch.empty((100,), dtype=torch.float16)) - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError, match="Not a JPEG file"): decode_jpeg(torch.empty((100), dtype=torch.uint8)) +def test_decode_bad_huffman_images(): + # sanity check: make sure we can decode the bad Huffman encoding + bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) + decode_jpeg(bad_huff) + + @pytest.mark.parametrize('img_path', [ pytest.param(truncated_image, id=_get_safe_image_name(truncated_image)) for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) ]) -def test_damaged_images(img_path): - # Test image with bad Huffman encoding (should not raise) - bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) - try: - _ = decode_jpeg(bad_huff) - except RuntimeError: - assert False - +def test_damaged_corrupt_images(img_path): # Truncated images should raise an exception data = read_file(img_path) - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError, match="Image is incomplete or truncated"): decode_jpeg(data) From 14bc8e5c11f1aace04068c2c7cae9da033fe4421 Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Thu, 27 May 2021 06:52:40 -0400 Subject: [PATCH 5/7] Separate test_encode_png_errors --- test/test_image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_image.py b/test/test_image.py index 45d63001fc0..52e7307fca5 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -161,6 +161,8 @@ def test_encode_png(img_path): assert_equal(img_pil, rec_img) + +def test_encode_png_errors(): with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) From 82fa4d0d9c383ee99564816be37a3055018b7956 Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Thu, 27 May 2021 06:56:35 -0400 Subject: [PATCH 6/7] Separate test_decode_png_errors and supplement error message --- test/test_image.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 52e7307fca5..c707c5beaf5 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -139,9 +139,11 @@ def test_decode_png(img_path, pil_mode, mode): tol = 0 if pil_mode is None else 1 assert img_lpng.allclose(img_pil, atol=tol) - with pytest.raises(RuntimeError): + +def test_decode_png_errors(): + with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): decode_png(torch.empty((), dtype=torch.uint8)) - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError, match="Content is not png"): decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) From 608a47f79cfea49b689386e6e0012aff7f1c2ab7 Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Thu, 27 May 2021 07:34:56 -0400 Subject: [PATCH 7/7] To handle different error messages --- test/test_image.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_image.py b/test/test_image.py index c707c5beaf5..61bf3a4070c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -110,7 +110,11 @@ def test_decode_bad_huffman_images(): def test_damaged_corrupt_images(img_path): # Truncated images should raise an exception data = read_file(img_path) - with pytest.raises(RuntimeError, match="Image is incomplete or truncated"): + if 'corrupt34' in img_path: + match_message = "Image is incomplete or truncated" + else: + match_message = "Unsupported marker type" + with pytest.raises(RuntimeError, match=match_message): decode_jpeg(data)