Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor to_pil_image and align array with tensor inputs #8097

Merged
merged 4 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
@pytest.mark.parametrize(
"img_data, expected_mode",
[
(torch.Tensor(4, 4, 1).uniform_().numpy(), "F"),
(torch.Tensor(4, 4, 1).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
Expand All @@ -671,6 +671,8 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data)
assert img.mode == expected_mode
if np.issubdtype(img_data.dtype, np.floating):
img_data = (img_data * 255).astype(np.uint8)
# note: we explicitly convert img's dtype because pytorch doesn't support uint16
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))
Expand Down Expand Up @@ -741,7 +743,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe
@pytest.mark.parametrize(
"img_data, expected_mode",
[
(torch.Tensor(4, 4).uniform_().numpy(), "F"),
(torch.Tensor(4, 4).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4).random_().numpy(), "I"),
Expand All @@ -751,6 +753,8 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data)
assert img.mode == expected_mode
if np.issubdtype(img_data.dtype, np.floating):
img_data = (img_data * 255).astype(np.uint8)
np.testing.assert_allclose(img_data, img)

@pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
Expand Down Expand Up @@ -874,8 +878,6 @@ def test_ndarray_bad_types_to_pil_image(self):
trans(np.ones([4, 4, 1], np.uint16))
with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.uint32))
with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.float64))

with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
Expand Down
43 changes: 14 additions & 29 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,41 +258,26 @@ def to_pil_image(pic, mode=None):
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_pil_image)

if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
if isinstance(pic, torch.Tensor):
if pic.ndim == 3:
pic = pic.permute((1, 2, 0))
pic = pic.numpy(force=True)
elif not isinstance(pic, np.ndarray):
raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

elif isinstance(pic, torch.Tensor):
if pic.ndimension() not in {2, 3}:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")

elif pic.ndimension() == 2:
# if 2D image, add channel dimension (CHW)
pic = pic.unsqueeze(0)

# check number of channels
if pic.shape[-3] > 4:
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")

elif isinstance(pic, np.ndarray):
if pic.ndim not in {2, 3}:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

elif pic.ndim == 2:
# if 2D image, add channel dimension (HWC)
pic = np.expand_dims(pic, 2)
if pic.ndim == 2:
# if 2D image, add channel dimension (HWC)
pic = np.expand_dims(pic, 2)
if pic.ndim != 3:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

# check number of channels
if pic.shape[-1] > 4:
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
if pic.shape[-1] > 4:
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

npimg = pic
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != "F":
pic = pic.mul(255).byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

if not isinstance(npimg, np.ndarray):
raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")
if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
npimg = (npimg * 255).astype(np.uint8)

if npimg.shape[2] == 1:
expected_mode = None
Expand Down