diff --git a/test/test_utils.py b/test/test_utils.py index 4c39520a692..2ad2189c57f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,6 +3,9 @@ import torch import torchvision.utils as utils import unittest +from io import BytesIO +import torchvision.transforms.functional as F +from PIL import Image class Tester(unittest.TestCase): @@ -49,6 +52,26 @@ def test_save_image_single_pixel(self): utils.save_image(t, f.name) assert os.path.exists(f.name), 'The pixel image is not present after save' + def test_save_image_file_object(self): + with tempfile.NamedTemporaryFile(suffix='.png') as f: + t = torch.rand(2, 3, 64, 64) + utils.save_image(t, f.name) + img_orig = Image.open(f.name) + fp = BytesIO() + utils.save_image(t, fp, format='png') + img_bytes = Image.open(fp) + assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Image not stored in file object' + + def test_save_image_single_pixel_file_object(self): + with tempfile.NamedTemporaryFile(suffix='.png') as f: + t = torch.rand(1, 3, 1, 1) + utils.save_image(t, f.name) + img_orig = Image.open(f.name) + fp = BytesIO() + utils.save_image(t, fp, format='png') + img_bytes = Image.open(fp) + assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Pixel Image not stored in file object' + if __name__ == '__main__': unittest.main() diff --git a/torchvision/utils.py b/torchvision/utils.py index f07a3bb4016..18c282c42c8 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -87,13 +87,16 @@ def norm_range(t, range): return grid -def save_image(tensor, filename, nrow=8, padding=2, - normalize=False, range=None, scale_each=False, pad_value=0): +def save_image(tensor, fp, nrow=8, padding=2, + normalize=False, range=None, scale_each=False, pad_value=0, format=None): """Save a given Tensor into an image file. Args: tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, saves the tensor as a grid of images by calling ``make_grid``. + fp - A filename(string) or file object + format(Optional): If omitted, the format to use is determined from the filename extension. + If a file object was used instead of a filename, this parameter should always be used. **kwargs: Other arguments are documented in ``make_grid``. """ from PIL import Image @@ -102,4 +105,4 @@ def save_image(tensor, filename, nrow=8, padding=2, # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() im = Image.fromarray(ndarr) - im.save(filename) + im.save(fp, format=format)