From 9c7076c65df28e998be3681a5df6387b98aceac0 Mon Sep 17 00:00:00 2001 From: Milos Kondela Date: Sun, 29 Nov 2020 13:23:49 +0100 Subject: [PATCH 1/3] Fix writing to files by using get_tmp_dir() --- test/test_image.py | 81 ++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 45a4258816e..1aae291a5c7 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1,19 +1,16 @@ -import os -import io import glob +import io +import os import unittest -import sys +import numpy as np import torch -import torchvision from PIL import Image +from common_utils import get_tmp_dir + from torchvision.io.image import ( decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, encode_png, write_png, write_file) -import numpy as np - -from common_utils import get_tmp_dir - IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") @@ -23,9 +20,10 @@ def get_images(directory, img_ext): assert os.path.isdir(directory) - for root, _, files in os.walk(directory): - if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}: - continue + image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True) + for path in image_paths: + if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']: + yield path for fl in files: _, ext = os.path.splitext(fl) @@ -46,7 +44,7 @@ def test_decode_jpeg(self): decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): - decode_jpeg(torch.empty((100, ), dtype=torch.float16)) + decode_jpeg(torch.empty((100,), dtype=torch.float16)) with self.assertRaises(RuntimeError): decode_jpeg(torch.empty((100), dtype=torch.uint8)) @@ -90,12 +88,12 @@ def test_encode_jpeg(self): with self.assertRaisesRegex( ValueError, "Image quality should be a positive number " - "between 1 and 100"): + "between 1 and 100"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) with self.assertRaisesRegex( ValueError, "Image quality should be a positive number " - "between 1 and 100"): + "between 1 and 100"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) with self.assertRaisesRegex( @@ -111,27 +109,27 @@ def test_encode_jpeg(self): encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) def test_write_jpeg(self): - for img_path in get_images(IMAGE_ROOT, ".jpg"): - data = read_file(img_path) - img = decode_jpeg(data) + with get_tmp_dir() as d: + for img_path in get_images(IMAGE_ROOT, ".jpg"): + data = read_file(img_path) + img = decode_jpeg(data) - basedir = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_jpeg = os.path.join( - basedir, '{0}_torch.jpg'.format(filename)) - pil_jpeg = os.path.join( - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) + basedir = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_jpeg = os.path.join( + d, '{0}_torch.jpg'.format(filename)) + pil_jpeg = os.path.join( + basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) - write_jpeg(img, torch_jpeg, quality=75) + write_jpeg(img, torch_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: - torch_bytes = f.read() + with open(torch_jpeg, 'rb') as f: + torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: - pil_bytes = f.read() + with open(pil_jpeg, 'rb') as f: + pil_bytes = f.read() - os.remove(torch_jpeg) - self.assertEqual(torch_bytes, pil_bytes) + self.assertEqual(torch_bytes, pil_bytes) def test_decode_png(self): for img_path in get_images(FAKEDATA_DIR, ".png"): @@ -181,20 +179,19 @@ def test_encode_png(self): encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) def test_write_png(self): - for img_path in get_images(IMAGE_DIR, '.png'): - pil_image = Image.open(img_path) - img_pil = torch.from_numpy(np.array(pil_image)) - img_pil = img_pil.permute(2, 0, 1) + with get_tmp_dir() as d: + for img_path in get_images(IMAGE_DIR, '.png'): + pil_image = Image.open(img_path) + img_pil = torch.from_numpy(np.array(pil_image)) + img_pil = img_pil.permute(2, 0, 1) - basedir = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename)) - write_png(img_pil, torch_png, compression_level=6) - saved_image = torch.from_numpy(np.array(Image.open(torch_png))) - os.remove(torch_png) - saved_image = saved_image.permute(2, 0, 1) + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) + write_png(img_pil, torch_png, compression_level=6) + saved_image = torch.from_numpy(np.array(Image.open(torch_png))) + saved_image = saved_image.permute(2, 0, 1) - self.assertTrue(img_pil.equal(saved_image)) + self.assertTrue(img_pil.equal(saved_image)) def test_decode_image(self): for img_path in get_images(IMAGE_ROOT, ".jpg"): From 4014736846c452ba6ced3d18ba17d4af5a83b532 Mon Sep 17 00:00:00 2001 From: Milos Kondela Date: Sun, 29 Nov 2020 14:34:45 +0100 Subject: [PATCH 2/3] Add ImageReadMode to imports --- test/test_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_image.py b/test/test_image.py index accaf2d0499..921b9715a46 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -10,7 +10,7 @@ from torchvision.io.image import ( decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, - encode_png, write_png, write_file) + encode_png, write_png, write_file, ImageReadMode) IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") From ddd9ca876e14662700eefb899da5ec5976f26577 Mon Sep 17 00:00:00 2001 From: Milos Kondela Date: Sun, 29 Nov 2020 15:13:53 +0100 Subject: [PATCH 3/3] Fix failing test due to incorrect image path --- test/test_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_image.py b/test/test_image.py index 921b9715a46..ebc9a221f6d 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -136,7 +136,7 @@ def test_encode_jpeg(self): def test_write_jpeg(self): with get_tmp_dir() as d: - for img_path in get_images(IMAGE_ROOT, ".jpg"): + for img_path in get_images(ENCODE_JPEG, ".jpg"): data = read_file(img_path) img = decode_jpeg(data)