From 3970299344985e0a0407ea443565af367e34ab76 Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 20 Aug 2020 15:16:30 +0200 Subject: [PATCH] Adapted functional tensor tests on CPU/CUDA (#2569) * Adapted almost all functional tensor tests on CPU/CUDA - fixed bug with transforms using generated grid - remains *_crop, blocked by #2568 - TODO: test_adjustments * Apply suggestions from code review Co-authored-by: Francisco Massa * Fixed issues according to review * Split tests into two: cpu and cuda * Updated test_adjustments to run on CPU and CUDA Co-authored-by: Francisco Massa --- test/test_functional_tensor.py | 233 ++++++++++++++------ torchvision/transforms/functional.py | 6 +- torchvision/transforms/functional_tensor.py | 31 +-- 3 files changed, 192 insertions(+), 78 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index aab9d3d9b02..697316011a9 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -17,16 +17,16 @@ class Tester(unittest.TestCase): - def _create_data(self, height=3, width=3, channels=3): - tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8) - pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy()) + def _create_data(self, height=3, width=3, channels=3, device="cpu"): + tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device) + pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy()) return tensor, pil_img def compareTensorToPIL(self, tensor, pil_image, msg=None): pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) if msg is None: msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor) - self.assertTrue(tensor.equal(pil_tensor), msg) + self.assertTrue(tensor.cpu().equal(pil_tensor), msg) def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None): pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor) @@ -36,9 +36,9 @@ def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None): msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10]) ) - def test_vflip(self): + def _test_vflip(self, device): script_vflip = torch.jit.script(F_t.vflip) - img_tensor = torch.randn(3, 16, 16) + img_tensor = torch.randn(3, 16, 16, device=device) img_tensor_clone = img_tensor.clone() vflipped_img = F_t.vflip(img_tensor) vflipped_img_again = F_t.vflip(vflipped_img) @@ -49,9 +49,16 @@ def test_vflip(self): vflipped_img_script = script_vflip(img_tensor) self.assertTrue(torch.equal(vflipped_img, vflipped_img_script)) - def test_hflip(self): + def test_vflip_cpu(self): + self._test_vflip("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_vflip_cuda(self): + self._test_vflip("cuda") + + def _test_hflip(self, device): script_hflip = torch.jit.script(F_t.hflip) - img_tensor = torch.randn(3, 16, 16) + img_tensor = torch.randn(3, 16, 16, device=device) img_tensor_clone = img_tensor.clone() hflipped_img = F_t.hflip(img_tensor) hflipped_img_again = F_t.hflip(hflipped_img) @@ -62,10 +69,17 @@ def test_hflip(self): hflipped_img_script = script_hflip(img_tensor) self.assertTrue(torch.equal(hflipped_img, hflipped_img_script)) - def test_crop(self): - script_crop = torch.jit.script(F_t.crop) + def test_hflip_cpu(self): + self._test_hflip("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_hflip_cuda(self): + self._test_hflip("cuda") + + def _test_crop(self, device): + script_crop = torch.jit.script(F.crop) - img_tensor, pil_img = self._create_data(16, 18) + img_tensor, pil_img = self._create_data(16, 18, device=device) test_configs = [ (1, 2, 4, 5), # crop inside top-left corner @@ -83,6 +97,13 @@ def test_crop(self): img_tensor_cropped = script_crop(img_tensor, top, left, height, width) self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped) + def test_crop_cpu(self): + self._test_crop("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_crop_cuda(self): + self._test_crop("cuda") + def test_hsv2rgb(self): shape = (3, 100, 150) for _ in range(20): @@ -128,7 +149,7 @@ def test_rgb2hsv(self): self.assertLess(max_diff, 1e-5) - def test_adjustments(self): + def _test_adjustments(self, device): script_adjust_brightness = torch.jit.script(F_t.adjust_brightness) script_adjust_contrast = torch.jit.script(F_t.adjust_contrast) script_adjust_saturation = torch.jit.script(F_t.adjust_saturation) @@ -143,16 +164,16 @@ def test_adjustments(self): shape = (channels, dims[0], dims[1]) if torch.randint(0, 2, (1,)) == 0: - img = torch.rand(*shape, dtype=torch.float) + img = torch.rand(*shape, dtype=torch.float, device=device) else: - img = torch.randint(0, 256, shape, dtype=torch.uint8) + img = torch.randint(0, 256, shape, dtype=torch.uint8, device=device) - factor = 3 * torch.rand(1) + factor = 3 * torch.rand(1).item() img_clone = img.clone() for f, ft, sft in fns: - ft_img = ft(img, factor) - sft_img = sft(img, factor) + ft_img = ft(img, factor).cpu() + sft_img = sft(img, factor).cpu() if not img.dtype.is_floating_point: ft_img = ft_img.to(torch.float) / 255 sft_img = sft_img.to(torch.float) / 255 @@ -170,15 +191,15 @@ def test_adjustments(self): self.assertTrue(torch.equal(img, img_clone)) # test for class interface - f = transforms.ColorJitter(brightness=factor.item()) + f = transforms.ColorJitter(brightness=factor) scripted_fn = torch.jit.script(f) scripted_fn(img) - f = transforms.ColorJitter(contrast=factor.item()) + f = transforms.ColorJitter(contrast=factor) scripted_fn = torch.jit.script(f) scripted_fn(img) - f = transforms.ColorJitter(saturation=factor.item()) + f = transforms.ColorJitter(saturation=factor) scripted_fn = torch.jit.script(f) scripted_fn(img) @@ -186,6 +207,13 @@ def test_adjustments(self): scripted_fn = torch.jit.script(f) scripted_fn(img) + def test_adjustments(self): + self._test_adjustments("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_adjustments_cuda(self): + self._test_adjustments("cuda") + def test_rgb_to_grayscale(self): script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale) img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) @@ -199,10 +227,10 @@ def test_rgb_to_grayscale(self): grayscale_script = script_rgb_to_grayscale(img_tensor).to(int) self.assertTrue(torch.equal(grayscale_script, grayscale_tensor)) - def test_center_crop(self): + def _test_center_crop(self, device): script_center_crop = torch.jit.script(F.center_crop) - img_tensor, pil_img = self._create_data(32, 34) + img_tensor, pil_img = self._create_data(32, 34, device=device) cropped_pil_image = F.center_crop(pil_img, [10, 11]) @@ -212,10 +240,17 @@ def test_center_crop(self): cropped_tensor = script_center_crop(img_tensor, [10, 11]) self.compareTensorToPIL(cropped_tensor, cropped_pil_image) - def test_five_crop(self): + def test_center_crop(self): + self._test_center_crop("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_center_crop_cuda(self): + self._test_center_crop("cuda") + + def _test_five_crop(self, device): script_five_crop = torch.jit.script(F.five_crop) - img_tensor, pil_img = self._create_data(32, 34) + img_tensor, pil_img = self._create_data(32, 34, device=device) cropped_pil_images = F.five_crop(pil_img, [10, 11]) @@ -227,10 +262,17 @@ def test_five_crop(self): for i in range(5): self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) - def test_ten_crop(self): + def test_five_crop(self): + self._test_five_crop("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_five_crop_cuda(self): + self._test_five_crop("cuda") + + def _test_ten_crop(self, device): script_ten_crop = torch.jit.script(F.ten_crop) - img_tensor, pil_img = self._create_data(32, 34) + img_tensor, pil_img = self._create_data(32, 34, device=device) cropped_pil_images = F.ten_crop(pil_img, [10, 11]) @@ -242,9 +284,16 @@ def test_ten_crop(self): for i in range(10): self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) - def test_pad(self): + def test_ten_crop(self): + self._test_ten_crop("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_ten_crop_cuda(self): + self._test_ten_crop("cuda") + + def _test_pad(self, device): script_fn = torch.jit.script(F_t.pad) - tensor, pil_img = self._create_data(7, 8) + tensor, pil_img = self._create_data(7, 8, device=device) for dt in [None, torch.float32, torch.float64]: if dt is not None: @@ -280,9 +329,16 @@ def test_pad(self): with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"): F_t.pad(tensor, (-2, -3), padding_mode="symmetric") - def test_adjust_gamma(self): - script_fn = torch.jit.script(F_t.adjust_gamma) - tensor, pil_img = self._create_data(26, 36) + def test_pad_cpu(self): + self._test_pad("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_pad_cuda(self): + self._test_pad("cuda") + + def _test_adjust_gamma(self, device): + script_fn = torch.jit.script(F.adjust_gamma) + tensor, pil_img = self._create_data(26, 36, device=device) for dt in [torch.float64, torch.float32, None]: @@ -293,8 +349,8 @@ def test_adjust_gamma(self): gains = [0.7, 1.0, 1.3] for gamma, gain in zip(gammas, gains): - adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain) - adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain) + adjusted_tensor = F.adjust_gamma(tensor, gamma, gain) + adjusted_pil = F.adjust_gamma(pil_img, gamma, gain) scripted_result = script_fn(tensor, gamma, gain) self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype) self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1]) @@ -305,11 +361,18 @@ def test_adjust_gamma(self): self.compareTensorToPIL(rbg_tensor, adjusted_pil) - self.assertTrue(adjusted_tensor.equal(scripted_result)) + self.assertTrue(adjusted_tensor.allclose(scripted_result)) + + def test_adjust_gamma_cpu(self): + self._test_adjust_gamma("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_adjust_gamma_cuda(self): + self._test_adjust_gamma("cuda") - def test_resize(self): + def _test_resize(self, device): script_fn = torch.jit.script(F_t.resize) - tensor, pil_img = self._create_data(26, 36) + tensor, pil_img = self._create_data(26, 36, device=device) for dt in [None, torch.float32, torch.float64]: if dt is not None: @@ -345,16 +408,23 @@ def test_resize(self): resize_result = script_fn(tensor, size=script_size, interpolation=interpolation) self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation)) - def test_resized_crop(self): + def test_resize_cpu(self): + self._test_resize("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_resize_cuda(self): + self._test_resize("cuda") + + def _test_resized_crop(self, device): # test values of F.resized_crop in several cases: # 1) resize to the same size, crop to the same size => should be identity - tensor, _ = self._create_data(26, 36) + tensor, _ = self._create_data(26, 36, device=device) for i in [0, 2, 3]: out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i) self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) # 2) resize by half and crop a TL corner - tensor, _ = self._create_data(26, 36) + tensor, _ = self._create_data(26, 36, device=device) out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0) expected_out_tensor = tensor[:, :20:2, :30:2] self.assertTrue( @@ -362,11 +432,18 @@ def test_resized_crop(self): msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]) ) - def test_affine(self): + def test_resized_crop_cpu(self): + self._test_resized_crop("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_resized_crop_cuda(self): + self._test_resized_crop("cuda") + + def _test_affine(self, device): # Tests on square and rectangular images scripted_affine = torch.jit.script(F.affine) - for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]: + for tensor, pil_img in [self._create_data(26, 26, device=device), self._create_data(32, 26, device=device)]: # 1) identity map out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) @@ -390,8 +467,16 @@ def test_affine(self): (180, torch.rot90(tensor, k=2, dims=(-1, -2))), ] for a, true_tensor in test_configs: + + out_pil_img = F.affine( + pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + ) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(device) + for fn in [F.affine, scripted_affine]: - out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + out_tensor = fn( + tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + ) if true_tensor is not None: self.assertTrue( true_tensor.equal(out_tensor), @@ -400,11 +485,6 @@ def test_affine(self): else: true_tensor = out_tensor - out_pil_img = F.affine( - pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 - ) - out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) - num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2] # Tolerance : less than 6% of different pixels @@ -420,12 +500,16 @@ def test_affine(self): 90, 45, 15, -30, -60, -120 ] for a in test_configs: + + out_pil_img = F.affine( + pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + ) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + for fn in [F.affine, scripted_affine]: - out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) - out_pil_img = F.affine( - pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 - ) - out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + out_tensor = fn( + tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + ).cpu() num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] @@ -443,9 +527,12 @@ def test_affine(self): [10, 12], (-12, -13) ] for t in test_configs: + + out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) + for fn in [F.affine, scripted_affine]: out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) - out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) + self.compareTensorToPIL(out_tensor, out_pil_img) # 3) Test rotation + translation + scale + share @@ -467,23 +554,31 @@ def test_affine(self): out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.affine, scripted_affine]: - out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r) + out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu() num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] - # Tolerance : less than 5% of different pixels + # Tolerance : less than 5% (cpu), 6% (cuda) of different pixels + tol = 0.06 if device == "cuda" else 0.05 self.assertLess( ratio_diff_pixels, - 0.05, + tol, msg="{}: {}\n{} vs \n{}".format( (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ) ) - def test_rotate(self): + def test_affine_cpu(self): + self._test_affine("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_affine_cuda(self): + self._test_affine("cuda") + + def _test_rotate(self, device): # Tests on square image scripted_rotate = torch.jit.script(F.rotate) - for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]: + for tensor, pil_img in [self._create_data(26, 26, device=device), self._create_data(32, 26, device=device)]: img_size = pil_img.size centers = [ @@ -500,7 +595,7 @@ def test_rotate(self): out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.rotate, scripted_rotate]: - out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c) + out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu() self.assertEqual( out_tensor.shape, @@ -523,11 +618,18 @@ def test_rotate(self): ) ) - def test_perspective(self): + def test_rotate_cpu(self): + self._test_rotate("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_rotate_cuda(self): + self._test_rotate("cuda") + + def _test_perspective(self, device): from torchvision.transforms import RandomPerspective - for tensor, pil_img in [self._create_data(26, 34), self._create_data(26, 26)]: + for tensor, pil_img in [self._create_data(26, 34, device=device), self._create_data(26, 26, device=device)]: scripted_tranform = torch.jit.script(F.perspective) @@ -547,7 +649,7 @@ def test_perspective(self): out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.perspective, scripted_tranform]: - out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r) + out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu() num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] @@ -563,6 +665,13 @@ def test_perspective(self): ) ) + def test_perspective_cpu(self): + self._test_perspective("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_perspective_cuda(self): + self._test_perspective("cuda") + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 4d607400b48..06b2a0e1f80 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -223,10 +223,10 @@ def to_pil_image(pic, mode=None): pic = np.expand_dims(pic, 2) npimg = pic - if isinstance(pic, torch.FloatTensor) and mode != 'F': - pic = pic.mul(255).byte() if isinstance(pic, torch.Tensor): - npimg = np.transpose(pic.numpy(), (1, 2, 0)) + if pic.is_floating_point() and mode != 'F': + pic = pic.mul(255).byte() + npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) if not isinstance(npimg, np.ndarray): raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 0b14e9acab7..36a12280310 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from torch.nn.functional import affine_grid, grid_sample +from torch.nn.functional import grid_sample from torch.jit.annotations import List, BroadcastingList2 @@ -714,12 +714,13 @@ def _gen_affine_grid( # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate d = 0.5 - base_grid = torch.empty(1, oh, ow, 3) + base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device) base_grid[..., 0].copy_(torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow)) base_grid[..., 1].copy_(torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh).unsqueeze_(-1)) base_grid[..., 2].fill_(1) - output_grid = base_grid.view(1, oh * ow, 3).bmm(theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h])) + rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device) + output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta) return output_grid.view(1, oh, ow, 2) @@ -746,14 +747,15 @@ def affine( _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes) - theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) + theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3) shape = img.shape + # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) mode = _interpolation_modes[resample] return _apply_grid_transform(img, grid, mode) -def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]: +def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: # Inspired of PIL implementation: # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 @@ -765,6 +767,7 @@ def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]: [0.5 * w, 0.5 * h, 1.0], [0.5 * w, -0.5 * h, 1.0], ]) + theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) min_vals, _ = new_pts.min(dim=0) max_vals, _ = new_pts.max(dim=0) @@ -807,16 +810,17 @@ def rotate( } _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes) - theta = torch.tensor(matrix).reshape(1, 2, 3) w, h = img.shape[-1], img.shape[-2] - ow, oh = _compute_output_size(theta, w, h) if expand else (w, h) + ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) + theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3) + # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) mode = _interpolation_modes[resample] return _apply_grid_transform(img, grid, mode) -def _perspective_grid(coeffs: List[float], ow: int, oh: int): +def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device): # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ # src/libImaging/Geometry.c#L394 @@ -828,19 +832,20 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int): theta1 = torch.tensor([[ [coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]] - ]]) + ]], dtype=torch.float, device=device) theta2 = torch.tensor([[ [coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0] - ]]) + ]], dtype=torch.float, device=device) d = 0.5 - base_grid = torch.empty(1, oh, ow, 3) + base_grid = torch.empty(1, oh, ow, 3, dtype=torch.float, device=device) base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow)) base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1)) base_grid[..., 2].fill_(1) - output_grid1 = base_grid.view(1, oh * ow, 3).bmm(theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh])) + rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=torch.float, device=device) + output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1) output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) output_grid = output_grid1 / output_grid2 - 1.0 @@ -880,7 +885,7 @@ def perspective( ) ow, oh = img.shape[-1], img.shape[-2] - grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh) + grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, device=img.device) mode = _interpolation_modes[interpolation] return _apply_grid_transform(img, grid, mode)