Skip to content

Commit

Permalink
[fbsync] refactor to_pil_image and align array with tensor inputs (#8097)
Browse files Browse the repository at this point in the history

Summary: Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

Reviewed By: vmoens

Differential Revision: D51391967

fbshipit-source-id: d64cd63a7417ceea4c7eee88e02c1866411e61f1
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Nov 16, 2023
1 parent 85546aa commit f7bc701
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 33 deletions.
10 changes: 6 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
@pytest.mark.parametrize(
"img_data, expected_mode",
[
(torch.Tensor(4, 4, 1).uniform_().numpy(), "F"),
(torch.Tensor(4, 4, 1).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
Expand All @@ -670,6 +670,8 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data)
assert img.mode == expected_mode
if np.issubdtype(img_data.dtype, np.floating):
img_data = (img_data * 255).astype(np.uint8)
# note: we explicitly convert img's dtype because pytorch doesn't support uint16
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))
Expand Down Expand Up @@ -740,7 +742,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe
@pytest.mark.parametrize(
"img_data, expected_mode",
[
(torch.Tensor(4, 4).uniform_().numpy(), "F"),
(torch.Tensor(4, 4).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4).random_().numpy(), "I"),
Expand All @@ -750,6 +752,8 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data)
assert img.mode == expected_mode
if np.issubdtype(img_data.dtype, np.floating):
img_data = (img_data * 255).astype(np.uint8)
np.testing.assert_allclose(img_data, img)

@pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
Expand Down Expand Up @@ -873,8 +877,6 @@ def test_ndarray_bad_types_to_pil_image(self):
trans(np.ones([4, 4, 1], np.uint16))
with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.uint32))
with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.float64))

with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
Expand Down
43 changes: 14 additions & 29 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,41 +258,26 @@ def to_pil_image(pic, mode=None):
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_pil_image)

if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
if isinstance(pic, torch.Tensor):
if pic.ndim == 3:
pic = pic.permute((1, 2, 0))
pic = pic.numpy(force=True)
elif not isinstance(pic, np.ndarray):
raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

elif isinstance(pic, torch.Tensor):
if pic.ndimension() not in {2, 3}:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")

elif pic.ndimension() == 2:
# if 2D image, add channel dimension (CHW)
pic = pic.unsqueeze(0)

# check number of channels
if pic.shape[-3] > 4:
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")

elif isinstance(pic, np.ndarray):
if pic.ndim not in {2, 3}:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

elif pic.ndim == 2:
# if 2D image, add channel dimension (HWC)
pic = np.expand_dims(pic, 2)
if pic.ndim == 2:
# if 2D image, add channel dimension (HWC)
pic = np.expand_dims(pic, 2)
if pic.ndim != 3:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

# check number of channels
if pic.shape[-1] > 4:
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
if pic.shape[-1] > 4:
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

npimg = pic
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != "F":
pic = pic.mul(255).byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

if not isinstance(npimg, np.ndarray):
raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")
if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
npimg = (npimg * 255).astype(np.uint8)

if npimg.shape[2] == 1:
expected_mode = None
Expand Down

0 comments on commit f7bc701

Please sign in to comment.