Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 13 additions & 109 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import itertools
import os
import unittest
import colorsys
import math

Expand Down Expand Up @@ -31,113 +30,18 @@
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


class Tester(unittest.TestCase):

def setUp(self):
self.device = "cpu"

def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
img_size = pil_img.size
dt = tensor.dtype
for r in [NEAREST, ]:
for a in range(-180, 180, 17):
for e in [True, False]:
for c in centers:
for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
f_pil = int(f[0]) if f is not None and len(f) == 1 else f
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)

self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
))

num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
(img_size, r, dt, a, e, c, f),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
)

def test_rotate(self):
# Tests on square image
scripted_rotate = torch.jit.script(F.rotate)

data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data:

img_size = pil_img.size
centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]

for dt in [None, torch.float32, torch.float64, torch.float16]:

if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue

if dt is not None:
tensor = tensor.to(dtype=dt)

self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)

batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)

center = (20, 22)
_test_fn_on_batch(
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
)
tensor, pil_img = data[0]
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):

def setUp(self):
self.device = "cuda"

def test_scale_channel(self):
"""Make sure that _scale_channel gives the same results on CPU and GPU as
histc or bincount are used depending on the device.
"""
# TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
# only use bincount and remove that test.
size = (1_000,)
img_chan = torch.randint(0, 256, size=size).to('cpu')
scaled_cpu = F_t._scale_channel(img_chan)
scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
assert_equal(scaled_cpu, scaled_cuda.to('cpu'))
@needs_cuda
def test_scale_channel():
"""Make sure that _scale_channel gives the same results on CPU and GPU as
histc or bincount are used depending on the device.
"""
# TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
# only use bincount and remove that test.
size = (1_000,)
img_chan = torch.randint(0, 256, size=size).to('cpu')
scaled_cpu = F_t._scale_channel(img_chan)
scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
assert_equal(scaled_cpu, scaled_cuda.to('cpu'))


class TestRotate:
Expand Down Expand Up @@ -1271,4 +1175,4 @@ def test_ten_crop(device):


if __name__ == '__main__':
unittest.main()
pytest.main([__file__])