From ce87ea7f750232163b431eda77769cf416f4ee26 Mon Sep 17 00:00:00 2001 From: surgan12 Date: Fri, 6 Sep 2019 11:34:33 +0530 Subject: [PATCH 1/4] format param added --- test/test_utils.py | 22 +++++++++++++++++++++- torchvision/utils.py | 9 ++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 4c39520a692..67d4c853c7a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,7 +3,8 @@ import torch import torchvision.utils as utils import unittest - +from io import BytesIO +import torchvision.transforms.functional as F class Tester(unittest.TestCase): @@ -49,6 +50,25 @@ def test_save_image_single_pixel(self): utils.save_image(t, f.name) assert os.path.exists(f.name), 'The pixel image is not present after save' + def test_save_image_file_object(self): + with tempfile.NamedTemporaryFile(suffix='.png') as f: + t = torch.rand(2, 3, 64, 64) + utils.save_image(t, f.name) + img_orig = Image.open(f.name) + fp = BytesIO() + utils.save_image(t, fp, format='png') + img_bytes = Image.open(fp) + assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Image not stored in file object' + + def test_save_image_single_pixel_file_object(self): + with tempfile.NamedTemporaryFile(suffix='.png') as f: + t = torch.rand(1, 3, 1, 1) + utils.save_image(t, f.name) + img_orig = Image.open(f.name) + fp = BytesIO() + utils.save_image(t, fp, format='png') + img_bytes = Image.open(fp) + assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Pixel Image not stored in file object' if __name__ == '__main__': unittest.main() diff --git a/torchvision/utils.py b/torchvision/utils.py index f07a3bb4016..f05b14ec4e5 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -87,13 +87,16 @@ def norm_range(t, range): return grid -def save_image(tensor, filename, nrow=8, padding=2, - normalize=False, range=None, scale_each=False, pad_value=0): +def save_image(tensor, fp, nrow=8, padding=2, + normalize=False, range=None, scale_each=False, pad_value=0, format=None): """Save a given Tensor into an image file. Args: tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, saves the tensor as a grid of images by calling ``make_grid``. + fp - A filename(string) or file object + format(Optional): If omitted, the format to use is determined from the filename extension. + If a file object was used instead of a filename, this parameter should always be used. **kwargs: Other arguments are documented in ``make_grid``. """ from PIL import Image @@ -102,4 +105,4 @@ def save_image(tensor, filename, nrow=8, padding=2, # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() im = Image.fromarray(ndarr) - im.save(filename) + im.save(fp,format=format) From 6ab940805b21d36528b4843149201cf024c50d03 Mon Sep 17 00:00:00 2001 From: surgan12 Date: Fri, 6 Sep 2019 11:52:53 +0530 Subject: [PATCH 2/4] lint fixes --- test/test_utils.py | 2 ++ torchvision/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 67d4c853c7a..15676d77730 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,6 +5,7 @@ import unittest from io import BytesIO import torchvision.transforms.functional as F +from PIL import Image class Tester(unittest.TestCase): @@ -70,5 +71,6 @@ def test_save_image_single_pixel_file_object(self): img_bytes = Image.open(fp) assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Pixel Image not stored in file object' + if __name__ == '__main__': unittest.main() diff --git a/torchvision/utils.py b/torchvision/utils.py index f05b14ec4e5..18c282c42c8 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -95,7 +95,7 @@ def save_image(tensor, fp, nrow=8, padding=2, tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, saves the tensor as a grid of images by calling ``make_grid``. fp - A filename(string) or file object - format(Optional): If omitted, the format to use is determined from the filename extension. + format(Optional): If omitted, the format to use is determined from the filename extension. If a file object was used instead of a filename, this parameter should always be used. **kwargs: Other arguments are documented in ``make_grid``. """ @@ -105,4 +105,4 @@ def save_image(tensor, fp, nrow=8, padding=2, # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() im = Image.fromarray(ndarr) - im.save(fp,format=format) + im.save(fp, format=format) From fae90d2a90b981a05e03fcbd1c1d3b70c04cdbda Mon Sep 17 00:00:00 2001 From: surgan12 Date: Fri, 6 Sep 2019 11:58:05 +0530 Subject: [PATCH 3/4] lint fixes --- test/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 15676d77730..544e218814e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,10 +5,11 @@ import unittest from io import BytesIO import torchvision.transforms.functional as F -from PIL import Image +from PIL import Image class Tester(unittest.TestCase): + def test_make_grid_not_inplace(self): t = torch.rand(5, 3, 10, 10) t_clone = t.clone() From f80f8d4de18bfa73df612df0cb3a292e9ccc1761 Mon Sep 17 00:00:00 2001 From: surgan12 Date: Fri, 6 Sep 2019 12:03:27 +0530 Subject: [PATCH 4/4] lint fixes --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 544e218814e..2ad2189c57f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,8 +7,8 @@ import torchvision.transforms.functional as F from PIL import Image -class Tester(unittest.TestCase): +class Tester(unittest.TestCase): def test_make_grid_not_inplace(self): t = torch.rand(5, 3, 10, 10)