diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 2e0c6cefb8d..63b6399b1fe 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -9,6 +9,11 @@ Functional transforms give fine-grained control over the transformations.
 This is useful if you have to build a more complex transformation pipeline
 (e.g. in the case of segmentation tasks).
 
+All transformations accept a PIL Image, a Tensor Image or a batch of Tensor Images as input. A Tensor Image is a tensor
+with ``(C, H, W)`` shape, where ``C`` is the number of channels and ``H`` and ``W`` are the image height and width. A
+batch of Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is the number of images in the batch.
+Deterministic or random transformations applied on a batch of Tensor Images transform all the images identically.
+
 .. autoclass:: Compose
 
 Transforms on PIL Image
diff --git a/test/common_utils.py b/test/common_utils.py
index 13e3561f19b..385bc670a2b 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -341,6 +341,15 @@ def _create_data(self, height=3, width=3, channels=3, device="cpu"):
         pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
         return tensor, pil_img
 
+    def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
+        batch_tensor = torch.randint(
+            0, 255,
+            (num_samples, channels, height, width),
+            dtype=torch.uint8,
+            device=device
+        )
+        return batch_tensor
+
     def compareTensorToPIL(self, tensor, pil_image, msg=None):
         np_pil_image = np.array(pil_image)
         if np_pil_image.ndim == 2:
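The paragraph added to transforms.rst above promises that a transform applied to a ``(B, C, H, W)`` batch matches applying it to each ``(C, H, W)`` image separately. A minimal sketch of that contract, assuming a torchvision build that includes this patch:

    import torch
    import torchvision.transforms.functional as F

    # A batch of four random uint8 images, shaped (B, C, H, W), built the
    # same way as the _create_data_batch test helper above.
    batch = torch.randint(0, 255, (4, 3, 16, 18), dtype=torch.uint8)

    # Deterministic transform: the batched result must equal the
    # per-image results, which is exactly what the new tests assert.
    flipped = F.vflip(batch)
    for i in range(batch.shape[0]):
        assert flipped[i].equal(F.vflip(batch[i]))
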
diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index c75c3cf93e2..87373359e83 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -6,7 +6,6 @@
 from PIL.Image import NEAREST, BILINEAR, BICUBIC
 
 import torch
-import torchvision.transforms as transforms
 import torchvision.transforms.functional_tensor as F_t
 import torchvision.transforms.functional_pil as F_pil
 import torchvision.transforms.functional as F
@@ -19,31 +18,47 @@ class Tester(TransformsTester):
     def setUp(self):
         self.device = "cpu"
 
+    def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs):
+        transformed_batch = fn(batch_tensors, **fn_kwargs)
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            transformed_img = fn(img_tensor, **fn_kwargs)
+            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
+
+        scripted_fn = torch.jit.script(fn)
+        # scriptable function test
+        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
+        self.assertTrue(transformed_batch.allclose(s_transformed_batch))
+
     def test_vflip(self):
-        script_vflip = torch.jit.script(F_t.vflip)
-        img_tensor = torch.randn(3, 16, 16, device=self.device)
-        img_tensor_clone = img_tensor.clone()
-        vflipped_img = F_t.vflip(img_tensor)
-        vflipped_img_again = F_t.vflip(vflipped_img)
-        self.assertEqual(vflipped_img.shape, img_tensor.shape)
-        self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
-        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        script_vflip = torch.jit.script(F.vflip)
+
+        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
+        vflipped_img = F.vflip(img_tensor)
+        vflipped_pil_img = F.vflip(pil_img)
+        self.compareTensorToPIL(vflipped_img, vflipped_pil_img)
 
         # scriptable function test
         vflipped_img_script = script_vflip(img_tensor)
-        self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))
+        self.assertTrue(vflipped_img.equal(vflipped_img_script))
+
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.vflip)
 
     def test_hflip(self):
-        script_hflip = torch.jit.script(F_t.hflip)
-        img_tensor = torch.randn(3, 16, 16, device=self.device)
-        img_tensor_clone = img_tensor.clone()
-        hflipped_img = F_t.hflip(img_tensor)
-        hflipped_img_again = F_t.hflip(hflipped_img)
-        self.assertEqual(hflipped_img.shape, img_tensor.shape)
-        self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
-        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        script_hflip = torch.jit.script(F.hflip)
+
+        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
+        hflipped_img = F.hflip(img_tensor)
+        hflipped_pil_img = F.hflip(pil_img)
+        self.compareTensorToPIL(hflipped_img, hflipped_pil_img)
 
         # scriptable function test
         hflipped_img_script = script_hflip(img_tensor)
-        self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))
+        self.assertTrue(hflipped_img.equal(hflipped_img_script))
+
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.hflip)
 
     def test_crop(self):
         script_crop = torch.jit.script(F.crop)
@@ -66,6 +81,9 @@ def test_crop(self):
             img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
             self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
 
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
+
     def test_hsv2rgb(self):
         scripted_fn = torch.jit.script(F_t._hsv2rgb)
         shape = (3, 100, 150)
@@ -89,6 +107,9 @@ def test_hsv2rgb(self):
             s_rgb_img = scripted_fn(hsv_img)
             self.assertTrue(rgb_img.allclose(s_rgb_img))
 
+        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
+        self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
+
     def test_rgb2hsv(self):
         scripted_fn = torch.jit.script(F_t._rgb2hsv)
         shape = (3, 150, 100)
@@ -97,7 +118,7 @@ def test_rgb2hsv(self):
         hsv_img = F_t._rgb2hsv(rgb_img)
         ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)
 
-        r, g, b, = rgb_img.unbind(0)
+        r, g, b = rgb_img.unbind(dim=-3)
         r = r.flatten().cpu().numpy()
         g = g.flatten().cpu().numpy()
         b = b.flatten().cpu().numpy()
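The unbind(0) to unbind(dim=-3) changes here and in functional_tensor.py below are what make the channel-splitting code batch-agnostic. A standalone illustration in plain PyTorch (shapes arbitrary):

    import torch

    img = torch.rand(3, 8, 8)        # (C, H, W)
    batch = torch.rand(4, 3, 8, 8)   # (B, C, H, W)

    # unbind(dim=-3) always splits the channel axis, whatever leading
    # dimensions precede it, so one code path handles both layouts.
    r, g, b = img.unbind(dim=-3)     # each (8, 8)
    r, g, b = batch.unbind(dim=-3)   # each (4, 8, 8)

    # unbind(0) would instead split the *batch* axis of a 4D input,
    # which is why the channel-wise code had to stop using it.
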
@@ -119,6 +140,9 @@ def test_rgb2hsv(self):
             s_hsv_img = scripted_fn(rgb_img)
             self.assertTrue(hsv_img.allclose(s_hsv_img))
 
+        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
+        self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
+
     def test_rgb_to_grayscale(self):
         script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
@@ -128,14 +152,14 @@ def test_rgb_to_grayscale(self):
             gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
             gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
 
-            if num_output_channels == 1:
-                print(gray_tensor.shape)
-
             self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
 
             s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
             self.assertTrue(s_gray_tensor.equal(gray_tensor))
 
+            batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+            self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
+
     def test_center_crop(self):
         script_center_crop = torch.jit.script(F.center_crop)
@@ -149,6 +173,9 @@ def test_center_crop(self):
         cropped_tensor = script_center_crop(img_tensor, [10, 11])
         self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
 
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
+
     def test_five_crop(self):
         script_five_crop = torch.jit.script(F.five_crop)
@@ -164,6 +191,23 @@ def test_five_crop(self):
         for i in range(5):
             self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
 
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
+            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
+
+            for j in range(len(tuple_transformed_imgs)):
+                true_transformed_img = tuple_transformed_imgs[j]
+                transformed_img = tuple_transformed_batches[j][i, ...]
+                self.assertTrue(true_transformed_img.equal(transformed_img))
+
+        # scriptable function test
+        s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
+        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
+            self.assertTrue(transformed_batch.equal(s_transformed_batch))
+
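For tuple-returning ops the new checks index the output as tuple_transformed_batches[j][i, ...]: with a batched input, each of the five crops is itself a batch. A short sketch, again assuming this patch is installed:

    import torch
    import torchvision.transforms.functional as F

    batch = torch.randint(0, 255, (4, 3, 16, 18), dtype=torch.uint8)

    # five_crop returns a 5-tuple; on batch input each element is itself
    # a batch, so crops[j][i] is crop j of image i.
    crops = F.five_crop(batch, [10, 11])
    for i in range(batch.shape[0]):
        single_crops = F.five_crop(batch[i], [10, 11])
        for j in range(5):
            assert crops[j][i].equal(single_crops[j])
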
     def test_ten_crop(self):
         script_ten_crop = torch.jit.script(F.ten_crop)
@@ -179,9 +223,27 @@ def test_ten_crop(self):
         for i in range(10):
             self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
 
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
+            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
+
+            for j in range(len(tuple_transformed_imgs)):
+                true_transformed_img = tuple_transformed_imgs[j]
+                transformed_img = tuple_transformed_batches[j][i, ...]
+                self.assertTrue(true_transformed_img.equal(transformed_img))
+
+        # scriptable function test
+        s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
+        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
+            self.assertTrue(transformed_batch.equal(s_transformed_batch))
+
     def test_pad(self):
-        script_fn = torch.jit.script(F_t.pad)
+        script_fn = torch.jit.script(F.pad)
         tensor, pil_img = self._create_data(7, 8, device=self.device)
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
 
         for dt in [None, torch.float32, torch.float64, torch.float16]:
@@ -192,6 +254,8 @@ def test_pad(self):
             if dt is not None:
                 # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
+                batch_tensors = batch_tensors.to(dt)
+
             for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
                 configs = [
                     {"padding_mode": "constant", "fill": 0},
@@ -219,6 +283,8 @@ def test_pad(self):
                 pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
                 self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
 
+                self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
+
         with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
             F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
 
@@ -226,11 +292,13 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method
         script_fn = torch.jit.script(fn)
         torch.manual_seed(15)
         tensor, pil_img = self._create_data(26, 34, device=self.device)
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
 
         for dt in [None, torch.float32, torch.float64]:
 
             if dt is not None:
                 tensor = F.convert_image_dtype(tensor, dt)
+                batch_tensors = F.convert_image_dtype(batch_tensors, dt)
 
             for config in configs:
                 adjusted_tensor = fn_t(tensor, **config)
@@ -254,6 +322,8 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method
                     atol = 1.0
                 self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
 
+                self._test_fn_on_batch(batch_tensors, fn, **config)
+
     def test_adjust_brightness(self):
         self._test_adjust_fn(
             F.adjust_brightness,
@@ -299,6 +369,7 @@ def test_adjust_gamma(self):
     def test_resize(self):
         script_fn = torch.jit.script(F_t.resize)
         tensor, pil_img = self._create_data(26, 36, device=self.device)
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
 
         for dt in [None, torch.float32, torch.float64, torch.float16]:
@@ -309,6 +380,8 @@ def test_resize(self):
             if dt is not None:
                 # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
+                batch_tensors = batch_tensors.to(dt)
+
             for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
                 for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                     resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
@@ -339,6 +412,10 @@ def test_resize(self):
                     resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
                     self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
 
+                    self._test_fn_on_batch(
+                        batch_tensors, F.resize, size=script_size, interpolation=interpolation
+                    )
+
     def test_resized_crop(self):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity
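The pad and resize loops above feed the same kwargs to the single-image and batched calls through _test_fn_on_batch; the pattern reduces to the following (values arbitrary, assumes this patch):

    import torch
    import torchvision.transforms.functional as F

    batch = torch.randint(0, 255, (4, 3, 16, 18), dtype=torch.uint8)

    # The same keyword arguments drive both the per-image and the batched
    # call, mirroring what _test_fn_on_batch does inside the test loops.
    kwargs = {"padding": [4, 2, 4, 3], "padding_mode": "constant", "fill": 0}
    padded_batch = F.pad(batch, **kwargs)
    for i in range(batch.shape[0]):
        assert padded_batch[i].equal(F.pad(batch[i], **kwargs))
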
@@ -356,6 +433,11 @@ def test_resized_crop(self):
                 msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
             )
 
+        batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
+        self._test_fn_on_batch(
+            batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=0
+        )
+
     def _test_affine_identity_map(self, tensor, scripted_affine):
         # 1) identity map
         out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
@@ -515,7 +597,52 @@ def test_affine(self):
             else:
                 self._test_affine_rect_rotations(tensor, pil_img, scripted_affine)
             self._test_affine_translations(tensor, pil_img, scripted_affine)
-            # self._test_affine_all_ops(tensor, pil_img, scripted_affine)
+            self._test_affine_all_ops(tensor, pil_img, scripted_affine)
+
+            batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
+            if dt is not None:
+                batch_tensors = batch_tensors.to(dtype=dt)
+
+            self._test_fn_on_batch(
+                batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
+            )
+
+    def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
+        img_size = pil_img.size
+        dt = tensor.dtype
+        for r in [0, ]:
+            for a in range(-180, 180, 17):
+                for e in [True, False]:
+                    for c in centers:
+
+                        out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                        for fn in [F.rotate, scripted_rotate]:
+                            out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
+
+                            if out_tensor.dtype != torch.uint8:
+                                out_tensor = out_tensor.to(torch.uint8)
+
+                            self.assertEqual(
+                                out_tensor.shape,
+                                out_pil_tensor.shape,
+                                msg="{}: {} vs {}".format(
+                                    (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
+                                )
+                            )
+                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                            # Tolerance : less than 3% of different pixels
+                            self.assertLess(
+                                ratio_diff_pixels,
+                                0.03,
+                                msg="{}: {}\n{} vs \n{}".format(
+                                    (img_size, r, dt, a, e, c),
+                                    ratio_diff_pixels,
+                                    out_tensor[0, :7, :7],
+                                    out_pil_tensor[0, :7, :7]
+                                )
+                            )
 
     def test_rotate(self):
         # Tests on square image
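_test_rotate_all_options (and _test_perspective below) compare tensor output against PIL with a pixel-ratio tolerance rather than exact equality, since the two backends rasterize geometric transforms slightly differently. The metric used is just:

    import torch

    def ratio_diff_pixels(out_tensor: torch.Tensor, out_pil_tensor: torch.Tensor) -> float:
        # Sum channel-wise mismatches, average over the 3 channels, then
        # normalise by the spatial area of the output.
        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
        return num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
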
@@ -540,39 +667,43 @@ def test_rotate(self):
             if dt is not None:
                 tensor = tensor.to(dtype=dt)
 
-            for r in [0, ]:
-                for a in range(-180, 180, 17):
-                    for e in [True, False]:
-                        for c in centers:
-
-                            out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
-                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-                            for fn in [F.rotate, scripted_rotate]:
-                                out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
-
-                                if out_tensor.dtype != torch.uint8:
-                                    out_tensor = out_tensor.to(torch.uint8)
-
-                                self.assertEqual(
-                                    out_tensor.shape,
-                                    out_pil_tensor.shape,
-                                    msg="{}: {} vs {}".format(
-                                        (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
-                                    )
-                                )
-                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                                # Tolerance : less than 3% of different pixels
-                                self.assertLess(
-                                    ratio_diff_pixels,
-                                    0.03,
-                                    msg="{}: {}\n{} vs \n{}".format(
-                                        (img_size, r, dt, a, e, c),
-                                        ratio_diff_pixels,
-                                        out_tensor[0, :7, :7],
-                                        out_pil_tensor[0, :7, :7]
-                                    )
-                                )
+            self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
+
+            batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
+            if dt is not None:
+                batch_tensors = batch_tensors.to(dtype=dt)
+
+            center = (20, 22)
+            self._test_fn_on_batch(
+                batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
+            )
+
+    def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
+        dt = tensor.dtype
+        for r in [0, ]:
+            for spoints, epoints in test_configs:
+                out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
+                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                for fn in [F.perspective, scripted_transform]:
+                    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
+
+                    if out_tensor.dtype != torch.uint8:
+                        out_tensor = out_tensor.to(torch.uint8)
+
+                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                    # Tolerance : less than 5% of different pixels
+                    self.assertLess(
+                        ratio_diff_pixels,
+                        0.05,
+                        msg="{}: {}\n{} vs \n{}".format(
                            (r, dt, spoints, epoints),
+                            ratio_diff_pixels,
+                            out_tensor[0, :7, :7],
+                            out_pil_tensor[0, :7, :7]
+                        )
+                    )
 
     def test_perspective(self):
@@ -602,30 +733,16 @@ def test_perspective(self):
             if dt is not None:
                 tensor = tensor.to(dtype=dt)
 
-            for r in [0, ]:
-                for spoints, epoints in test_configs:
-                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
-                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-
-                    for fn in [F.perspective, scripted_tranform]:
-                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
+            self._test_perspective(tensor, pil_img, scripted_tranform, test_configs)
 
-                        if out_tensor.dtype != torch.uint8:
-                            out_tensor = out_tensor.to(torch.uint8)
+            batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
+            if dt is not None:
+                batch_tensors = batch_tensors.to(dtype=dt)
 
-                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                        # Tolerance : less than 5% of different pixels
-                        self.assertLess(
-                            ratio_diff_pixels,
-                            0.05,
-                            msg="{}: {}\n{} vs \n{}".format(
-                                (r, dt, spoints, epoints),
-                                ratio_diff_pixels,
-                                out_tensor[0, :7, :7],
-                                out_pil_tensor[0, :7, :7]
-                            )
-                        )
+            for spoints, epoints in test_configs:
+                self._test_fn_on_batch(
+                    batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
+                )
 
 
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
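The batch helpers added to test_transforms_tensor.py below depend on reseeding the RNG before every call, so a random transform draws identical parameters for the eager, scripted, single-image and batched paths. In isolation (T.RandomHorizontalFlip chosen arbitrarily, assumes this patch):

    import torch
    import torchvision.transforms as T

    transform = T.RandomHorizontalFlip(p=0.5)
    s_transform = torch.jit.script(transform)

    tensor = torch.randint(0, 255, (3, 44, 56), dtype=torch.uint8)

    # Reseeding before each call makes the random decision identical,
    # so eager and scripted outputs can be compared exactly.
    torch.manual_seed(12)
    out1 = transform(tensor)
    torch.manual_seed(12)
    out2 = s_transform(tensor)
    assert out1.equal(out2)
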
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index d82c4f6b309..182c70712fe 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -19,20 +19,43 @@ def setUp(self):
     def _test_functional_op(self, func, fn_kwargs):
         if fn_kwargs is None:
             fn_kwargs = {}
+
+        f = getattr(F, func)
         tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
-        transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
-        transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
+        transformed_tensor = f(tensor, **fn_kwargs)
+        transformed_pil_img = f(pil_img, **fn_kwargs)
         self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
 
+    def _test_transform_vs_scripted(self, transform, s_transform, tensor):
+        torch.manual_seed(12)
+        out1 = transform(tensor)
+        torch.manual_seed(12)
+        out2 = s_transform(tensor)
+        self.assertTrue(out1.equal(out2))
+
+    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors):
+        torch.manual_seed(12)
+        transformed_batch = transform(batch_tensors)
+
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            torch.manual_seed(12)
+            transformed_img = transform(img_tensor)
+            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
+
+        torch.manual_seed(12)
+        s_transformed_batch = s_transform(batch_tensors)
+        self.assertTrue(transformed_batch.equal(s_transformed_batch))
+
     def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
         if meth_kwargs is None:
             meth_kwargs = {}
 
-        tensor, pil_img = self._create_data(26, 34, device=self.device)
         # test for class interface
         f = getattr(T, method)(**meth_kwargs)
         scripted_fn = torch.jit.script(f)
 
+        tensor, pil_img = self._create_data(26, 34, device=self.device)
         # set seed to reproduce the same transformation for tensor and PIL image
         torch.manual_seed(12)
         transformed_tensor = f(tensor)
@@ -47,6 +70,9 @@ def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **matc
         transformed_tensor_script = scripted_fn(tensor)
         self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
 
+        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
+        self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
+
     def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
         self._test_functional_op(func, fn_kwargs)
         self._test_class_op(method, meth_kwargs)
@@ -167,15 +193,18 @@ def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kw
             fn_kwargs = {}
         if meth_kwargs is None:
             meth_kwargs = {}
+
+        fn = getattr(F, func)
+        scripted_fn = torch.jit.script(fn)
+
         tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
-        transformed_t_list = getattr(F, func)(tensor, **fn_kwargs)
-        transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs)
+        transformed_t_list = fn(tensor, **fn_kwargs)
+        transformed_p_list = fn(pil_img, **fn_kwargs)
         self.assertEqual(len(transformed_t_list), len(transformed_p_list))
         self.assertEqual(len(transformed_t_list), out_length)
         for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
             self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
 
-        scripted_fn = torch.jit.script(getattr(F, func))
         transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
         self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
         self.assertEqual(len(transformed_t_list_script), out_length)
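_test_op_list_output drives both the functional and the class interface by name via getattr; the class-interface half boils down to the following sketch (T.FiveCrop as the concrete case used by test_five_crop, assumes this patch):

    import torch
    import torchvision.transforms as T

    transform = T.FiveCrop(size=(5,))
    scripted = torch.jit.script(transform)

    tensor = torch.randint(0, 255, (3, 20, 20), dtype=torch.uint8)

    # Both paths return a sequence of five crops of the same length.
    assert len(scripted(tensor)) == len(transform(tensor)) == 5
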
@@ -184,11 +213,24 @@ def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kw
                             msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
 
         # test for class interface
-        f = getattr(T, method)(**meth_kwargs)
-        scripted_fn = torch.jit.script(f)
+        fn = getattr(T, method)(**meth_kwargs)
+        scripted_fn = torch.jit.script(fn)
         output = scripted_fn(tensor)
         self.assertEqual(len(output), len(transformed_t_list_script))
 
+        # test on batch of tensors
+        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
+        torch.manual_seed(12)
+        transformed_batch_list = fn(batch_tensors)
+
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            torch.manual_seed(12)
+            transformed_img_list = fn(img_tensor)
+            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
+                self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
+                                msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
+
     def test_five_crop(self):
         fn_kwargs = meth_kwargs = {"size": (5,)}
         self._test_op_list_output(
@@ -227,6 +269,7 @@ def test_ten_crop(self):
 
     def test_resize(self):
         tensor, _ = self._create_data(height=34, width=36, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
         script_fn = torch.jit.script(F.resize)
 
         for dt in [None, torch.float32, torch.float64]:
@@ -247,13 +290,13 @@ def test_resize(self):
                     self.assertTrue(s_resized_tensor.equal(resized_tensor))
 
                     transform = T.Resize(size=script_size, interpolation=interpolation)
-                    resized_tensor = transform(tensor)
-                    script_transform = torch.jit.script(transform)
-                    s_resized_tensor = script_transform(tensor)
-                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
+                    s_transform = torch.jit.script(transform)
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_resized_crop(self):
         tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
 
         for scale in [(0.7, 1.2), [0.7, 1.2]]:
             for ratio in [(0.75, 1.333), [0.75, 1.333]]:
@@ -263,15 +306,12 @@ def test_resized_crop(self):
                         size=size, scale=scale, ratio=ratio, interpolation=interpolation
                     )
                     s_transform = torch.jit.script(transform)
-
-                    torch.manual_seed(12)
-                    out1 = transform(tensor)
-                    torch.manual_seed(12)
-                    out2 = s_transform(tensor)
-                    self.assertTrue(out1.equal(out2))
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_random_affine(self):
         tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
 
         for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
             for scale in [(0.7, 1.2), [0.7, 1.2]]:
@@ -284,14 +324,12 @@ def test_random_affine(self):
                         )
                         s_transform = torch.jit.script(transform)
 
-                        torch.manual_seed(12)
-                        out1 = transform(tensor)
-                        torch.manual_seed(12)
-                        out2 = s_transform(tensor)
-                        self.assertTrue(out1.equal(out2))
+                        self._test_transform_vs_scripted(transform, s_transform, tensor)
+                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_random_rotate(self):
         tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
 
         for center in [(0, 0), [10, 10], None, (56, 44)]:
             for expand in [True, False]:
@@ -302,14 +340,12 @@ def test_random_rotate(self):
                 )
                 s_transform = torch.jit.script(transform)
 
-                torch.manual_seed(12)
-                out1 = transform(tensor)
-                torch.manual_seed(12)
-                out2 = s_transform(tensor)
-                self.assertTrue(out1.equal(out2))
+                self._test_transform_vs_scripted(transform, s_transform, tensor)
+                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_random_perspective(self):
         tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
 
         for distortion_scale in np.linspace(0.1, 1.0, num=20):
             for interpolation in [NEAREST, BILINEAR]:
                 transform = T.RandomPerspective(
@@ -319,11 +355,8 @@ def test_random_perspective(self):
                 )
                 s_transform = torch.jit.script(transform)
 
-                torch.manual_seed(12)
-                out1 = transform(tensor)
-                torch.manual_seed(12)
-                out2 = s_transform(tensor)
-                self.assertTrue(out1.equal(out2))
+                self._test_transform_vs_scripted(transform, s_transform, tensor)
+                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_to_grayscale(self):
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 4518a916542..0ef607775e7 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -36,7 +36,7 @@ def vflip(img: Tensor) -> Tensor:
     Please, consider instead using methods from `transforms.functional` module.
 
     Args:
-        img (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [..., C, H, W].
 
     Returns:
         Tensor: Vertically flipped image Tensor.
@@ -56,7 +56,7 @@ def hflip(img: Tensor) -> Tensor:
     Please, consider instead using methods from `transforms.functional` module.
 
    Args:
-        img (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [..., C, H, W].
 
     Returns:
         Tensor: Horizontally flipped image Tensor.
@@ -183,7 +183,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')
 
-    mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+    mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
 
     return _blend(img, mean, contrast_factor)
 
@@ -229,9 +230,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     img = img.to(dtype=torch.float32) / 255.0
 
     img = _rgb2hsv(img)
-    h, s, v = img.unbind(0)
+    h, s, v = img.unbind(dim=-3)
     h = (h + hue_factor) % 1.0
-    img = torch.stack((h, s, v))
+    img = torch.stack((h, s, v), dim=-3)
     img_hue_adj = _hsv2rgb(img)
 
     if orig_dtype == torch.uint8:
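The adjust_contrast change above is the behavioural heart of the batch support: reducing over dim=(-3, -2, -1) with keepdim=True produces one grayscale mean per batch element instead of one global mean, and that mean then broadcasts inside _blend. A shape-only sketch (the gray line stands in for rgb_to_grayscale):

    import torch

    batch = torch.rand(4, 3, 16, 18)  # (B, C, H, W), float

    # One mean per image: reduce channels and spatial dims, but keep them
    # as size-1 axes so the result broadcasts against the (B, C, H, W) input.
    gray = batch.mean(dim=-3, keepdim=True)           # stand-in for rgb_to_grayscale
    mean = gray.mean(dim=(-3, -2, -1), keepdim=True)  # shape (4, 1, 1, 1)

    # Blending as in _blend: each image is pulled toward its own mean.
    contrast_factor = 0.5
    out = (contrast_factor * batch + (1.0 - contrast_factor) * mean).clamp(0, 1)
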
@@ -466,12 +467,12 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
 
 
 def _rgb2hsv(img):
-    r, g, b = img.unbind(0)
+    r, g, b = img.unbind(dim=-3)
 
     # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
     # src/libImaging/Convert.c#L330
-    maxc = torch.max(img, dim=0).values
-    minc = torch.min(img, dim=0).values
+    maxc = torch.max(img, dim=-3).values
+    minc = torch.min(img, dim=-3).values
 
     # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
     # from happening in the results, because
@@ -501,11 +502,11 @@ def _rgb2hsv(img):
     hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
     h = (hr + hg + hb)
     h = torch.fmod((h / 6.0 + 1.0), 1.0)
-    return torch.stack((h, s, maxc))
+    return torch.stack((h, s, maxc), dim=-3)
 
 
 def _hsv2rgb(img):
-    h, s, v = img.unbind(0)
+    h, s, v = img.unbind(dim=-3)
     i = torch.floor(h * 6.0)
     f = (h * 6.0) - i
     i = i.to(dtype=torch.int32)
@@ -515,14 +516,14 @@ def _hsv2rgb(img):
     t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
     i = i % 6
 
-    mask = i == torch.arange(6, device=i.device)[:, None, None]
+    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
 
-    a1 = torch.stack((v, q, p, p, t, v))
-    a2 = torch.stack((t, v, v, q, p, p))
-    a3 = torch.stack((p, p, t, v, v, q))
-    a4 = torch.stack((a1, a2, a3))
+    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
+    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
+    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
+    a4 = torch.stack((a1, a2, a3), dim=-4)
 
-    return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
+    return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
 
 
 def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
@@ -793,6 +794,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
         need_cast = True
         img = img.to(grid)
 
+    if img.shape[0] > 1:
+        # Apply same grid to a batch of images
+        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
     img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
 
     if need_squeeze:
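Finally, the _apply_grid_transform change broadcasts a single sampling grid across the batch dimension before calling grid_sample; the mechanism in isolation (plain PyTorch, shapes arbitrary):

    import torch
    from torch.nn.functional import grid_sample

    imgs = torch.rand(4, 3, 8, 8)   # (B, C, H, W) float batch

    # An affine/perspective grid computed once for the output geometry,
    # shaped (1, H_out, W_out, 2) with coordinates in [-1, 1].
    grid = torch.rand(1, 8, 8, 2) * 2 - 1

    # expand() performs no copy; it views the single grid as if it were
    # repeated B times so grid_sample sees matching batch sizes.
    grid = grid.expand(imgs.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
    out = grid_sample(imgs, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    print(out.shape)  # torch.Size([4, 3, 8, 8])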