Skip to content

Commit

Permalink
Align NumPy and PyTorch handling of floating point inputs in `to_pil_image`
Browse files (browse the repository at this point in the history)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
pmeier and NicolasHug committed Nov 7, 2023
1 parent 83a5ab6 commit 3e546cb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 6 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
@pytest.mark.parametrize(
"img_data, expected_mode",
[
(torch.Tensor(4, 4, 1).uniform_().numpy(), "F"),
(torch.Tensor(4, 4, 1).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
Expand All @@ -671,6 +671,8 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data)
assert img.mode == expected_mode
if np.issubdtype(img_data.dtype, np.floating):
img_data = (img_data * 255).astype(np.uint8)
# note: we explicitly convert img's dtype because pytorch doesn't support uint16
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))
Expand Down Expand Up @@ -741,7 +743,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe
@pytest.mark.parametrize(
"img_data, expected_mode",
[
(torch.Tensor(4, 4).uniform_().numpy(), "F"),
(torch.Tensor(4, 4).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4).random_().numpy(), "I"),
Expand All @@ -751,6 +753,8 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data)
assert img.mode == expected_mode
if np.issubdtype(img_data.dtype, np.floating):
img_data = (img_data * 255).astype(np.uint8)
np.testing.assert_allclose(img_data, img)

@pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
Expand Down Expand Up @@ -874,8 +878,6 @@ def test_ndarray_bad_types_to_pil_image(self):
trans(np.ones([4, 4, 1], np.uint16))
with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.uint32))
with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.float64))

with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,6 @@ def to_pil_image(pic, mode=None):
if isinstance(pic, Image.Image):
return pic
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != "F":
pic = pic.mul(255).byte()
if pic.ndim == 3:
pic = pic.permute((1, 2, 0))
pic = pic.numpy(force=True)
Expand All @@ -280,6 +278,9 @@ def to_pil_image(pic, mode=None):

npimg = pic

if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
npimg = (npimg * 255).astype(np.uint8)

if npimg.shape[2] == 1:
expected_mode = None
npimg = npimg[:, :, 0]
Expand Down

0 comments on commit 3e546cb

Please sign in to comment.