Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 141 additions & 138 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,82 +36,6 @@ class Tester(unittest.TestCase):
def setUp(self):
self.device = "cpu"

def test_assert_image_tensor(self):
shape = (100,)
tensor = torch.rand(*shape, dtype=torch.float, device=self.device)

list_of_methods = [(F_t._get_image_size, (tensor, )), (F_t.vflip, (tensor, )),
(F_t.hflip, (tensor, )), (F_t.crop, (tensor, 1, 2, 4, 5)),
(F_t.adjust_brightness, (tensor, 0.)), (F_t.adjust_contrast, (tensor, 1.)),
(F_t.adjust_hue, (tensor, -0.5)), (F_t.adjust_saturation, (tensor, 2.)),
(F_t.center_crop, (tensor, [10, 11])), (F_t.five_crop, (tensor, [10, 11])),
(F_t.ten_crop, (tensor, [10, 11])), (F_t.pad, (tensor, [2, ], 2, "constant")),
(F_t.resize, (tensor, [10, 11])), (F_t.perspective, (tensor, [0.2, ])),
(F_t.gaussian_blur, (tensor, (2, 2), (0.7, 0.5))),
(F_t.invert, (tensor, )), (F_t.posterize, (tensor, 0)),
(F_t.solarize, (tensor, 0.3)), (F_t.adjust_sharpness, (tensor, 0.3)),
(F_t.autocontrast, (tensor, )), (F_t.equalize, (tensor, ))]

for func, args in list_of_methods:
with self.assertRaises(Exception) as context:
func(*args)

self.assertTrue('Tensor is not a torch image.' in str(context.exception))

def test_vflip(self):
script_vflip = torch.jit.script(F.vflip)

img_tensor, pil_img = _create_data(16, 18, device=self.device)
vflipped_img = F.vflip(img_tensor)
vflipped_pil_img = F.vflip(pil_img)
_assert_equal_tensor_to_pil(vflipped_img, vflipped_pil_img)

# scriptable function test
vflipped_img_script = script_vflip(img_tensor)
assert_equal(vflipped_img, vflipped_img_script)

batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.vflip)

def test_hflip(self):
script_hflip = torch.jit.script(F.hflip)

img_tensor, pil_img = _create_data(16, 18, device=self.device)
hflipped_img = F.hflip(img_tensor)
hflipped_pil_img = F.hflip(pil_img)
_assert_equal_tensor_to_pil(hflipped_img, hflipped_pil_img)

# scriptable function test
hflipped_img_script = script_hflip(img_tensor)
assert_equal(hflipped_img, hflipped_img_script)

batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.hflip)

def test_crop(self):
script_crop = torch.jit.script(F.crop)

img_tensor, pil_img = _create_data(16, 18, device=self.device)

test_configs = [
(1, 2, 4, 5), # crop inside top-left corner
(2, 12, 3, 4), # crop inside top-right corner
(8, 3, 5, 6), # crop inside bottom-left corner
(8, 11, 4, 3), # crop inside bottom-right corner
]

for top, left, height, width in test_configs:
pil_img_cropped = F.crop(pil_img, top, left, height, width)

img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
_assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)

img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
_assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)

batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)

def test_hsv2rgb(self):
scripted_fn = torch.jit.script(F_t._hsv2rgb)
shape = (3, 100, 150)
Expand Down Expand Up @@ -610,68 +534,6 @@ def test_rotate(self):
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)

def test_gaussian_blur(self):
small_image_tensor = torch.from_numpy(
np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
).permute(2, 0, 1).to(self.device)

large_image_tensor = torch.from_numpy(
np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
).to(self.device)

scripted_transform = torch.jit.script(F.gaussian_blur)

# true_cv2_results = {
# # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
# "3_3_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
# "3_3_0.5": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
# "3_5_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
# "3_5_0.5": ...
# # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
# # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
# "23_23_1.7": ...
# }
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt')
true_cv2_results = torch.load(p)

for tensor in [small_image_tensor, large_image_tensor]:

for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue

if dt is not None:
tensor = tensor.to(dtype=dt)

for ksize in [(3, 3), [3, 5], (23, 23)]:
for sigma in [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]:

_ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
_sigma = sigma[0] if sigma is not None else None
shape = tensor.shape
gt_key = "{}_{}_{}__{}_{}_{}".format(
shape[-2], shape[-1], shape[-3],
_ksize[0], _ksize[1], _sigma
)
if gt_key not in true_cv2_results:
continue

true_out = torch.tensor(
true_cv2_results[gt_key]
).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)

for fn in [F.gaussian_blur, scripted_transform]:
out = fn(tensor, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(
out, true_out, rtol=0.0, atol=1.0, check_stride=False,
msg="{}, {}".format(ksize, sigma)
)


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down Expand Up @@ -1141,5 +1003,146 @@ def test_adjust_gamma(device, dtype, config):
)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('func, args', [
(F_t._get_image_size, ()), (F_t.vflip, ()),
(F_t.hflip, ()), (F_t.crop, (1, 2, 4, 5)),
(F_t.adjust_brightness, (0., )), (F_t.adjust_contrast, (1., )),
(F_t.adjust_hue, (-0.5, )), (F_t.adjust_saturation, (2., )),
(F_t.center_crop, ([10, 11], )), (F_t.five_crop, ([10, 11], )),
(F_t.ten_crop, ([10, 11], )), (F_t.pad, ([2, ], 2, "constant")),
(F_t.resize, ([10, 11], )), (F_t.perspective, ([0.2, ])),
(F_t.gaussian_blur, ((2, 2), (0.7, 0.5))),
(F_t.invert, ()), (F_t.posterize, (0, )),
(F_t.solarize, (0.3, )), (F_t.adjust_sharpness, (0.3, )),
(F_t.autocontrast, ()), (F_t.equalize, ())
])
def test_assert_image_tensor(device, func, args):
shape = (100,)
tensor = torch.rand(*shape, dtype=torch.float, device=device)
with pytest.raises(Exception, match=r"Tensor is not a torch image."):
func(tensor, *args)


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_vflip(device):
script_vflip = torch.jit.script(F.vflip)

img_tensor, pil_img = _create_data(16, 18, device=device)
vflipped_img = F.vflip(img_tensor)
vflipped_pil_img = F.vflip(pil_img)
_assert_equal_tensor_to_pil(vflipped_img, vflipped_pil_img)

# scriptable function test
vflipped_img_script = script_vflip(img_tensor)
assert_equal(vflipped_img, vflipped_img_script)

batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
_test_fn_on_batch(batch_tensors, F.vflip)


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_hflip(device):
script_hflip = torch.jit.script(F.hflip)

img_tensor, pil_img = _create_data(16, 18, device=device)
hflipped_img = F.hflip(img_tensor)
hflipped_pil_img = F.hflip(pil_img)
_assert_equal_tensor_to_pil(hflipped_img, hflipped_pil_img)

# scriptable function test
hflipped_img_script = script_hflip(img_tensor)
assert_equal(hflipped_img, hflipped_img_script)

batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
_test_fn_on_batch(batch_tensors, F.hflip)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('top, left, height, width', [
(1, 2, 4, 5), # crop inside top-left corner
(2, 12, 3, 4), # crop inside top-right corner
(8, 3, 5, 6), # crop inside bottom-left corner
(8, 11, 4, 3), # crop inside bottom-right corner
])
def test_crop(device, top, left, height, width):
script_crop = torch.jit.script(F.crop)

img_tensor, pil_img = _create_data(16, 18, device=device)

pil_img_cropped = F.crop(pil_img, top, left, height, width)

img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
_assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)

img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
_assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)

batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
_test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('image_size', ('small', 'large'))
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('ksize', [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize('sigma', [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
@pytest.mark.parametrize('fn', [F.gaussian_blur, torch.jit.script(F.gaussian_blur)])
def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):

# true_cv2_results = {
# # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
# "3_3_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
# "3_3_0.5": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
# "3_5_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
# "3_5_0.5": ...
# # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
# # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
# "23_23_1.7": ...
# }
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt')
true_cv2_results = torch.load(p)

if image_size == 'small':
tensor = torch.from_numpy(
np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
).permute(2, 0, 1).to(device)
else:
tensor = torch.from_numpy(
np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
).to(device)

if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
return

if dt is not None:
tensor = tensor.to(dtype=dt)

_ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
_sigma = sigma[0] if sigma is not None else None
shape = tensor.shape
gt_key = "{}_{}_{}__{}_{}_{}".format(
shape[-2], shape[-1], shape[-3],
_ksize[0], _ksize[1], _sigma
)
if gt_key not in true_cv2_results:
return

true_out = torch.tensor(
true_cv2_results[gt_key]
).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)

out = fn(tensor, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(
out, true_out, rtol=0.0, atol=1.0, check_stride=False,
msg="{}, {}".format(ksize, sigma)
)


if __name__ == '__main__':
unittest.main()