diff --git a/test/test_transforms.py b/test/test_transforms.py index 8abf9e88db3..392978d988b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -620,6 +620,20 @@ def test_to_tensor(self): output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + def test_to_tensor_with_other_default_dtypes(self): + current_def_dtype = torch.get_default_dtype() + + t = transforms.ToTensor() + np_arr = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + img = Image.fromarray(np_arr) + + for dtype in [torch.float16, torch.float, torch.double]: + torch.set_default_dtype(dtype) + res = t(img) + self.assertTrue(res.dtype == dtype, msg=f"{res.dtype} vs {dtype}") + + torch.set_default_dtype(current_def_dtype) + def test_max_value(self): for dtype in int_dtypes(): self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index ab1e2e9b29b..ea521213588 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -104,6 +104,8 @@ def to_tensor(pic): if _is_numpy(pic) and not _is_numpy_image(pic): raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) + default_float_dtype = torch.get_default_dtype() + if isinstance(pic, np.ndarray): # handle numpy array if pic.ndim == 2: @@ -112,12 +114,12 @@ def to_tensor(pic): img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous() # backward compatibility if isinstance(img, torch.ByteTensor): - return img.float().div(255) + return img.to(dtype=default_float_dtype).div(255) else: return img if accimage is not None and isinstance(pic, accimage.Image): - nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=default_float_dtype) pic.copyto(nppic) return torch.from_numpy(nppic) @@ -137,7 +139,7 @@ def to_tensor(pic): # put it from HWC to CHW format img = img.permute((2, 0, 1)).contiguous() if isinstance(img, torch.ByteTensor): - return img.float().div(255) + return img.to(dtype=default_float_dtype).div(255) else: return img