Refactor adjust ops tests #2595

Merged: 11 commits merged on Sep 1, 2020
test/test_functional_tensor.py: 133 changes (59 additions & 74 deletions)
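
For reference, the per-(dtype, config) check that the new _test_adjust_fn helper centralizes can be written as a standalone sketch. This is a minimal sketch only, assuming the same F / F_pil / F_t / np imports that test_functional_tensor.py already uses; check_adjust_op and its signature are illustrative and not part of the PR.

import numpy as np
import torch

import torchvision.transforms.functional as F
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional_tensor as F_t


def check_adjust_op(fn, fn_pil, fn_t, config, tensor, pil_img):
    # The scripted generic op, applied to a tensor, must match the tensor
    # backend exactly.
    script_fn = torch.jit.script(fn)
    out_tensor = fn_t(tensor, **config)
    out_scripted = script_fn(tensor, **config)
    assert out_tensor.equal(out_scripted)

    # Compare against the PIL backend in uint8 space; results may differ by
    # at most 1 out of 255 because convert_image_dtype and PIL round float
    # values differently.
    out_pil = fn_pil(pil_img, **config)
    rgb = out_tensor
    if rgb.dtype != torch.uint8:
        rgb = F.convert_image_dtype(rgb, torch.uint8)
    rgb = rgb.float()
    pil_as_tensor = torch.as_tensor(np.array(out_pil).transpose((2, 0, 1))).to(rgb)
    assert torch.abs(rgb - pil_as_tensor).max().item() <= 1.0

In the diff below, _test_adjust_fn runs this same logic for every dtype in (None, float32, float64) and every config in the per-op list, so each per-op test reduces to one call.
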
@@ -26,8 +26,7 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
         if np_pil_image.ndim == 2:
             np_pil_image = np_pil_image[:, :, None]
         pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
-        if msg is None:
-            msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
+        msg = "{}: tensor:\n{} \ndid not equal PIL tensor:\n{}".format(msg, tensor, pil_tensor)
         self.assertTrue(tensor.equal(pil_tensor), msg)
 
     def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
@@ -130,64 +129,6 @@ def test_rgb2hsv(self):
 
         self.assertLess(max_diff, 1e-5)
 
-    def test_adjustments(self):
-        script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
-        script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
-        script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
-
-        fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
-               (F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
-               (F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))
-
-        for _ in range(20):
-            channels = 3
-            dims = torch.randint(1, 50, (2,))
-            shape = (channels, dims[0], dims[1])
-
-            if torch.randint(0, 2, (1,)) == 0:
-                img = torch.rand(*shape, dtype=torch.float)
-            else:
-                img = torch.randint(0, 256, shape, dtype=torch.uint8)
-
-            factor = 3 * torch.rand(1)
-            img_clone = img.clone()
-            for f, ft, sft in fns:
-
-                ft_img = ft(img, factor)
-                sft_img = sft(img, factor)
-                if not img.dtype.is_floating_point:
-                    ft_img = ft_img.to(torch.float) / 255
-                    sft_img = sft_img.to(torch.float) / 255
-
-                img_pil = transforms.ToPILImage()(img)
-                f_img_pil = f(img_pil, factor)
-                f_img = transforms.ToTensor()(f_img_pil)
-
-                # F uses uint8 and F_t uses float, so there is a small
-                # difference in values caused by (at most 5) truncations.
-                max_diff = (ft_img - f_img).abs().max()
-                max_diff_scripted = (sft_img - f_img).abs().max()
-                self.assertLess(max_diff, 5 / 255 + 1e-5)
-                self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
-                self.assertTrue(torch.equal(img, img_clone))
-
-            # test for class interface
-            f = transforms.ColorJitter(brightness=factor.item())
-            scripted_fn = torch.jit.script(f)
-            scripted_fn(img)
-
-            f = transforms.ColorJitter(contrast=factor.item())
-            scripted_fn = torch.jit.script(f)
-            scripted_fn(img)
-
-            f = transforms.ColorJitter(saturation=factor.item())
-            scripted_fn = torch.jit.script(f)
-            scripted_fn(img)
-
-            f = transforms.ColorJitter(brightness=1)
-            scripted_fn = torch.jit.script(f)
-            scripted_fn(img)
-
     def test_rgb_to_grayscale(self):
         script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
 
@@ -286,32 +227,76 @@ def test_pad(self):
         with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
             F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
 
-    def test_adjust_gamma(self):
-        script_fn = torch.jit.script(F_t.adjust_gamma)
-        tensor, pil_img = self._create_data(26, 36)
+    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
+        script_fn = torch.jit.script(fn)
 
-        for dt in [torch.float64, torch.float32, None]:
+        torch.manual_seed(15)
+
+        tensor, pil_img = self._create_data(26, 34)
+
+        for dt in [None, torch.float32, torch.float64]:
 
             if dt is not None:
                 tensor = F.convert_image_dtype(tensor, dt)
 
-            gammas = [0.8, 1.0, 1.2]
-            gains = [0.7, 1.0, 1.3]
-            for gamma, gain in zip(gammas, gains):
+            for config in configs:
 
-                adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain)
-                adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain)
-                scripted_result = script_fn(tensor, gamma, gain)
-                self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
-                self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])
+                adjusted_tensor = fn_t(tensor, **config)
+                adjusted_pil = fn_pil(pil_img, **config)
+                scripted_result = script_fn(tensor, **config)
+                msg = "{}, {}".format(dt, config)
+                self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
+                self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)
 
                 rbg_tensor = adjusted_tensor
+
                 if adjusted_tensor.dtype != torch.uint8:
                     rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
 
-                self.compareTensorToPIL(rbg_tensor, adjusted_pil)
+                # Check that max difference does not exceed 1 in [0, 255] range
+                # Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
+                rbg_tensor = rbg_tensor.float()
+                adjusted_pil_tensor = torch.as_tensor(np.array(adjusted_pil).transpose((2, 0, 1))).to(rbg_tensor)
+                max_diff = torch.abs(rbg_tensor - adjusted_pil_tensor).max().item()
+                self.assertLessEqual(
+                    max_diff,
+                    1.0,
+                    msg="{}: tensor:\n{} \ndid not equal PIL tensor:\n{}".format(msg, rbg_tensor, adjusted_pil_tensor)
+                )
+
+                self.assertTrue(adjusted_tensor.equal(scripted_result), msg=msg)
+
+    def test_adjust_brightness(self):
+        self._test_adjust_fn(
+            F.adjust_brightness,
+            F_pil.adjust_brightness,
+            F_t.adjust_brightness,
+            [{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
+        )
+
+    def test_adjust_contrast(self):
+        self._test_adjust_fn(
+            F.adjust_contrast,
+            F_pil.adjust_contrast,
+            F_t.adjust_contrast,
+            [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
+        )
 
-                self.assertTrue(adjusted_tensor.equal(scripted_result))
+    def test_adjust_saturation(self):
+        self._test_adjust_fn(
+            F.adjust_saturation,
+            F_pil.adjust_saturation,
+            F_t.adjust_saturation,
+            [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.25, 1.5]]
+        )
+
+    def test_adjust_gamma(self):
+        self._test_adjust_fn(
+            F.adjust_gamma,
+            F_pil.adjust_gamma,
+            F_t.adjust_gamma,
+            [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
+        )
 
     def test_resize(self):
         script_fn = torch.jit.script(F_t.resize)
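
To run just the refactored tests locally, something like the following short Python sketch should work. It assumes pytest is installed and the snippet is executed from the repository root; the tests themselves are plain unittest.TestCase methods, which pytest collects.

import pytest

# Collect only the adjust-related tests from this file and run them verbosely.
pytest.main(["test/test_functional_tensor.py", "-k", "adjust", "-v"])
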