diff --git a/test/test_image.py b/test/test_image.py index cdeadf7a0a0..ebc9a221f6d 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1,17 +1,16 @@ -import os -import io import glob +import io +import os import unittest +import numpy as np import torch from PIL import Image +from common_utils import get_tmp_dir + from torchvision.io.image import ( decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, encode_png, write_png, write_file, ImageReadMode) -import numpy as np - -from common_utils import get_tmp_dir - IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") @@ -22,14 +21,10 @@ def get_images(directory, img_ext): assert os.path.isdir(directory) - for root, _, files in os.walk(directory): - if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}: - continue - - for fl in files: - _, ext = os.path.splitext(fl) - if ext == img_ext: - yield os.path.join(root, fl) + image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True) + for path in image_paths: + if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']: + yield path def pil_read_image(img_path): @@ -75,7 +70,7 @@ def test_decode_jpeg(self): decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): - decode_jpeg(torch.empty((100, ), dtype=torch.float16)) + decode_jpeg(torch.empty((100,), dtype=torch.float16)) with self.assertRaises(RuntimeError): decode_jpeg(torch.empty((100), dtype=torch.uint8)) @@ -119,12 +114,12 @@ def test_encode_jpeg(self): with self.assertRaisesRegex( ValueError, "Image quality should be a positive number " - "between 1 and 100"): + "between 1 and 100"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) with self.assertRaisesRegex( ValueError, "Image quality should be a positive number " - "between 1 and 100"): + "between 1 and 100"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) with self.assertRaisesRegex( @@ -140,27 +135,27 @@ def test_encode_jpeg(self): encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) def test_write_jpeg(self): - for img_path in get_images(ENCODE_JPEG, ".jpg"): - data = read_file(img_path) - img = decode_jpeg(data) + with get_tmp_dir() as d: + for img_path in get_images(ENCODE_JPEG, ".jpg"): + data = read_file(img_path) + img = decode_jpeg(data) - basedir = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_jpeg = os.path.join( - basedir, '{0}_torch.jpg'.format(filename)) - pil_jpeg = os.path.join( - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) + basedir = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_jpeg = os.path.join( + d, '{0}_torch.jpg'.format(filename)) + pil_jpeg = os.path.join( + basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) - write_jpeg(img, torch_jpeg, quality=75) + write_jpeg(img, torch_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: - torch_bytes = f.read() + with open(torch_jpeg, 'rb') as f: + torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: - pil_bytes = f.read() + with open(pil_jpeg, 'rb') as f: + pil_bytes = f.read() - os.remove(torch_jpeg) - self.assertEqual(torch_bytes, pil_bytes) + self.assertEqual(torch_bytes, pil_bytes) def test_decode_png(self): conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), @@ -216,20 +211,19 @@ def test_encode_png(self): encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) def test_write_png(self): - for img_path in get_images(IMAGE_DIR, '.png'): - pil_image = Image.open(img_path) - img_pil = torch.from_numpy(np.array(pil_image)) - img_pil = img_pil.permute(2, 0, 1) - - basedir = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename)) - write_png(img_pil, torch_png, compression_level=6) - saved_image = torch.from_numpy(np.array(Image.open(torch_png))) - os.remove(torch_png) - saved_image = saved_image.permute(2, 0, 1) - - self.assertTrue(img_pil.equal(saved_image)) + with get_tmp_dir() as d: + for img_path in get_images(IMAGE_DIR, '.png'): + pil_image = Image.open(img_path) + img_pil = torch.from_numpy(np.array(pil_image)) + img_pil = img_pil.permute(2, 0, 1) + + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) + write_png(img_pil, torch_png, compression_level=6) + saved_image = torch.from_numpy(np.array(Image.open(torch_png))) + saved_image = saved_image.permute(2, 0, 1) + + self.assertTrue(img_pil.equal(saved_image)) def test_read_file(self): with get_tmp_dir() as d: