Fixes F.affine and F.rotate to support rectangular tensor images #2553

Merged
merged 14 commits on Aug 6, 2020
263 changes: 147 additions & 116 deletions test/test_functional_tensor.py
@@ -385,134 +385,165 @@ def test_resized_crop(self):
)

def test_affine(self):
# Tests on square image
tensor, pil_img = self._create_data(26, 26)

# Tests on square and rectangular images
scripted_affine = torch.jit.script(F.affine)
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)

# 2) Test rotation
test_configs = [
(90, torch.rot90(tensor, k=1, dims=(-1, -2))),
(45, None),
(30, None),
(-30, None),
(-45, None),
(-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
]
for a, true_tensor in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
)
else:
true_tensor = out_tensor

out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
# Tolerance : less than 6% of different pixels
self.assertLess(
ratio_diff_pixels,
0.06,
msg="{}\n{} vs \n{}".format(
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
# 3) Test translation
test_configs = [
[10, 12], (-12, -13)
]
for t in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
self.compareTensorToPIL(out_tensor, out_pil_img)

        # 3) Test rotation + translation + scale + shear
test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]),
(33, (5, -4), 1.0, [0.0, 0.0]),
(45, [-5, 4], 1.2, [0.0, 0.0]),
(33, (-4, -8), 2.0, [0.0, 0.0]),
(85, (10, -10), 0.7, [0.0, 0.0]),
(0, [0, 0], 1.0, [35.0, ]),
(25, [0, 0], 1.2, [0.0, 15.0]),
(45, [-10, 0], 0.7, [2.0, 5.0]),
(45, [-10, -10], 1.2, [4.0, 5.0]),
]
        for r in [0, ]:
            for a, t, s, sh in test_configs:
                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                for fn in [F.affine, scripted_affine]:
                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                    # Tolerance : less than 5% of different pixels
                    self.assertLess(
                        ratio_diff_pixels,
                        0.05,
                        msg="{}: {}\n{} vs \n{}".format(
                            (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                        )
                    )

for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:

# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)

if pil_img.size[0] == pil_img.size[1]:
# 2) Test rotation
test_configs = [
(90, torch.rot90(tensor, k=1, dims=(-1, -2))),
(45, None),
(30, None),
(-30, None),
(-45, None),
(-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
]
for a, true_tensor in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
)
else:
true_tensor = out_tensor

out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
# Tolerance : less than 6% of different pixels
self.assertLess(
ratio_diff_pixels,
0.06,
msg="{}\n{} vs \n{}".format(
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
else:
test_configs = [
90, 45, 15, -30, -60, -120
]
for a in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)

            # 3) Test translation
            test_configs = [
                [10, 12], (-12, -13)
            ]
            for t in test_configs:
                for fn in [F.affine, scripted_affine]:
                    out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
                    out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
                    self.compareTensorToPIL(out_tensor, out_pil_img)

            # 3) Test rotation + translation + scale + shear
            test_configs = [
                (45, [5, 6], 1.0, [0.0, 0.0]),
                (33, (5, -4), 1.0, [0.0, 0.0]),
                (45, [-5, 4], 1.2, [0.0, 0.0]),
                (33, (-4, -8), 2.0, [0.0, 0.0]),
                (85, (10, -10), 0.7, [0.0, 0.0]),
                (0, [0, 0], 1.0, [35.0, ]),
                (-25, [0, 0], 1.2, [0.0, 15.0]),
                (-45, [-10, 0], 0.7, [2.0, 5.0]),
                (-45, [-10, -10], 1.2, [4.0, 5.0]),
                (-90, [0, 0], 1.0, [0.0, 0.0]),
            ]
            for r in [0, ]:
                for a, t, s, sh in test_configs:
                    out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                    for fn in [F.affine, scripted_affine]:
                        out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                        # Tolerance : less than 5% of different pixels
                        self.assertLess(
                            ratio_diff_pixels,
                            0.05,
                            msg="{}: {}\n{} vs \n{}".format(
                                (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                            )
                        )
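Editor's note: the comparison strategy in the updated test_affine boils down to running the same affine call on a tensor image and on its PIL counterpart, then bounding the fraction of pixels on which the two backends disagree. A minimal standalone sketch of that check, outside the test class; the random 32x26 image, the parameter values, and the 5% threshold are illustrative only, and it assumes a torchvision build that already includes this PR:

import numpy as np
import torch
from PIL import Image
import torchvision.transforms.functional as F

# Random rectangular image: height 32, width 26, three channels.
np_img = np.random.randint(0, 256, size=(32, 26, 3), dtype=np.uint8)
pil_img = Image.fromarray(np_img)
tensor = torch.from_numpy(np_img).permute(2, 0, 1)  # CHW uint8, same pixels as pil_img

params = dict(angle=33, translate=[5, -4], scale=1.2, shear=[0.0, 0.0], resample=0)
out_tensor = F.affine(tensor, **params)
out_pil = F.affine(pil_img, **params)
out_pil_tensor = torch.from_numpy(np.array(out_pil)).permute(2, 0, 1)

# Average (over channels) count of pixels where the two backends disagree,
# expressed as a fraction of the image area -- the same metric as the test above.
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / (out_tensor.shape[-1] * out_tensor.shape[-2])
assert ratio_diff_pixels < 0.05, "tensor and PIL affine outputs diverge too much"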

def test_rotate(self):
# Tests on square image
tensor, pil_img = self._create_data(26, 26)
scripted_rotate = torch.jit.script(F.rotate)

img_size = pil_img.size

centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]

for r in [0, ]:
for a in range(-120, 120, 23):
for e in [True, False]:
for c in centers:

out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)

                            self.assertEqual(
                                out_tensor.shape,
                                out_pil_tensor.shape,
                                msg="{}: {} vs {}".format(
                                    (r, a, e, c), out_tensor.shape, out_pil_tensor.shape
                                )
                            )
                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                            # Tolerance : less than 2% of different pixels
                            self.assertLess(
                                ratio_diff_pixels,
                                0.02,
                                msg="{}: {}\n{} vs \n{}".format(
                                    (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                                )
                            )

        for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:

            img_size = pil_img.size
            centers = [
                None,
                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
            ]

            for r in [0, ]:
                for a in range(-180, 180, 17):
                    for e in [True, False]:
                        for c in centers:

                            out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                            for fn in [F.rotate, scripted_rotate]:
                                out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)

                                self.assertEqual(
                                    out_tensor.shape,
                                    out_pil_tensor.shape,
                                    msg="{}: {} vs {}".format(
                                        (img_size, r, a, e, c), out_tensor.shape, out_pil_tensor.shape
                                    )
                                )

                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                                # Tolerance : less than 2% of different pixels
                                self.assertLess(
                                    ratio_diff_pixels,
                                    0.02,
                                    msg="{}: {}\n{} vs \n{}".format(
                                        (img_size, r, a, e, c),
                                        ratio_diff_pixels,
                                        out_tensor[0, :7, :7],
                                        out_pil_tensor[0, :7, :7]
                                    )
                                )
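Editor's note: the same pattern outside unittest, for rotation of a rectangular tensor image. The angle, center, and image size below are arbitrary, and the shape check mirrors the assertEqual above; it assumes a torchvision build that includes this PR:

import numpy as np
import torch
from PIL import Image
import torchvision.transforms.functional as F

np_img = np.random.randint(0, 256, size=(26, 32, 3), dtype=np.uint8)  # H=26, W=32
pil_img = Image.fromarray(np_img)
tensor = torch.from_numpy(np_img).permute(2, 0, 1)

# center is given in (x, y) pixel coordinates, as on the PIL path.
for expand in (False, True):
    out_tensor = F.rotate(tensor, angle=45, resample=0, expand=expand, center=(10, 8))
    out_pil = F.rotate(pil_img, angle=45, resample=0, expand=expand, center=(10, 8))
    # With expand=True the canvas grows; both backends should still agree on the output size.
    assert out_tensor.shape[-2:] == (out_pil.size[1], out_pil.size[0])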


if __name__ == '__main__':
11 changes: 5 additions & 6 deletions torchvision/transforms/functional.py
@@ -848,8 +848,9 @@ def rotate(
center_f = [0.0, 0.0]
if center is not None:
img_size = _get_image_size(img)
# Center is normalized to [-1, +1]
center_f = [2.0 * t / s - 1.0 for s, t in zip(img_size, center)]
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to the image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
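Editor's note: a small worked example of the convention change for center (the values are arbitrary; _get_image_size returns (width, height), which is the zip order used above). The old code normalized the center to [-1, +1], while the new code keeps pixel units and only shifts the origin to the image center:

img_size = (32, 26)   # (width, height)
center = (10, 8)      # user-supplied (x, y) in pixels

old_center_f = [2.0 * c / s - 1.0 for s, c in zip(img_size, center)]    # [-0.375, -0.3846...]
new_center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]  # [-6.0, -5.0]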
@@ -926,10 +927,8 @@ def affine(

return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)

# we need to rescale translate by image size / 2 as its values can be between -1 and 1
translate = [2.0 * t / s for s, t in zip(img_size, translate)]

matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
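Editor's note: with the rescaling above removed, translate is now interpreted in pixels for tensor inputs, just as it already was for PIL inputs. A small usage sketch; the image size and offsets are arbitrary, and it assumes a torchvision build that includes this PR:

import torch
import torchvision.transforms.functional as F

img = torch.randint(0, 256, (3, 32, 26), dtype=torch.uint8)
# Shift by 5 pixels along x and -4 pixels along y; the canvas size is unchanged.
out = F.affine(img, angle=0.0, translate=[5, -4], scale=1.0, shear=[0.0, 0.0], resample=0)
assert out.shape == img.shape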

