Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import torch
import torchvision.utils as utils
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image


class Tester(unittest.TestCase):
Expand Down Expand Up @@ -49,6 +52,26 @@ def test_save_image_single_pixel(self):
utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save'

def test_save_image_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Image not stored in file object'

def test_save_image_single_pixel_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Pixel Image not stored in file object'


if __name__ == '__main__':
unittest.main()
9 changes: 6 additions & 3 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,16 @@ def norm_range(t, range):
return grid


def save_image(tensor, filename, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0):
def save_image(tensor, fp, nrow=8, padding=2,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add the format option, shouldn't we allow all additional parameters (dpi for .png, ...)for Image.save()? We could add a **kwargs or save_kwargs which is simply passed to im.save().

Copy link
Contributor Author

@surgan12 surgan12 Sep 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, they could be added based on the use cases.
IMO file option provides an interface to write as bytes and this might be used in a couple of cases .
What do you think ? @pmeier

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear: I'm in favour of adding this as you did. I was just asking myself why we should only include the format option although Image.save() has a lot more options for the different file formats. I think if we give the user more control over the options we might as well give him all options. This of course does not need to happen in this PR. I just wanted to put it out there, so I or we don't forget it.

Copy link
Contributor Author

@surgan12 surgan12 Sep 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier , I also feel the same, providing a full control over the save would be a better thing for the user.
Also the save_image in the current utils only differs in the make_grid part and there should be no harm in making the rest similar to PIL.

normalize=False, range=None, scale_each=False, pad_value=0, format=None):
"""Save a given Tensor into an image file.

Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
saves the tensor as a grid of images by calling ``make_grid``.
fp - A filename(string) or file object
format(Optional): If omitted, the format to use is determined from the filename extension.
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
from PIL import Image
Expand All @@ -102,4 +105,4 @@ def save_image(tensor, filename, nrow=8, padding=2,
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(filename)
im.save(fp, format=format)