diff --git a/test/test_transforms_video.py b/test/test_transforms_video.py index 942bb010f71..81b65ef0a6d 100644 --- a/test/test_transforms_video.py +++ b/test/test_transforms_video.py @@ -1,6 +1,6 @@ import torch from torchvision.transforms import Compose -import unittest +import pytest import random import numpy as np import warnings @@ -17,7 +17,7 @@ import torchvision.transforms._transforms_video as transforms -class TestVideoTransforms(unittest.TestCase): +class TestVideoTransforms(): def test_random_crop_video(self): numFrames = random.randint(4, 128) @@ -30,8 +30,8 @@ def test_random_crop_video(self): transforms.ToTensorVideo(), transforms.RandomCropVideo((oheight, owidth)), ])(clip) - self.assertEqual(result.size(2), oheight) - self.assertEqual(result.size(3), owidth) + assert result.size(2) == oheight + assert result.size(3) == owidth transforms.RandomCropVideo((oheight, owidth)).__repr__() @@ -46,8 +46,8 @@ def test_random_resized_crop_video(self): transforms.ToTensorVideo(), transforms.RandomResizedCropVideo((oheight, owidth)), ])(clip) - self.assertEqual(result.size(2), oheight) - self.assertEqual(result.size(3), owidth) + assert result.size(2) == oheight + assert result.size(3) == owidth transforms.RandomResizedCropVideo((oheight, owidth)).__repr__() @@ -70,7 +70,7 @@ def test_center_crop_video(self): msg = "height: " + str(height) + " width: " \ + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) - self.assertEqual(result.sum().item(), 0, msg) + assert result.sum().item() == 0, msg oheight += 1 owidth += 1 @@ -82,7 +82,7 @@ def test_center_crop_video(self): msg = "height: " + str(height) + " width: " \ + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) - self.assertEqual(sum1.item() > 1, True, msg) + assert sum1.item() > 1, msg oheight += 1 owidth += 1 @@ -94,28 +94,29 @@ def test_center_crop_video(self): msg = "height: " + str(height) + " width: " \ + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) - self.assertTrue(sum2.item() > 1, msg) - self.assertTrue(sum2.item() > sum1.item(), msg) + assert sum2.item() > 1, msg + assert sum2.item() > sum1.item(), msg - @unittest.skipIf(stats is None, 'scipy.stats is not available') - def test_normalize_video(self): + @pytest.mark.skipif(stats is None, reason='scipy.stats is not available') + @pytest.mark.parametrize('channels', [1, 3]) + def test_normalize_video(self, channels): def samples_from_standard_normal(tensor): p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue return p_value > 0.0001 random_state = random.getstate() random.seed(42) - for channels in [1, 3]: - numFrames = random.randint(4, 128) - height = random.randint(32, 256) - width = random.randint(32, 256) - mean = random.random() - std = random.random() - clip = torch.normal(mean, std, size=(channels, numFrames, height, width)) - mean = [clip[c].mean().item() for c in range(channels)] - std = [clip[c].std().item() for c in range(channels)] - normalized = transforms.NormalizeVideo(mean, std)(clip) - self.assertTrue(samples_from_standard_normal(normalized)) + + numFrames = random.randint(4, 128) + height = random.randint(32, 256) + width = random.randint(32, 256) + mean = random.random() + std = random.random() + clip = torch.normal(mean, std, size=(channels, numFrames, height, width)) + mean = [clip[c].mean().item() for c in range(channels)] + std = [clip[c].std().item() for c in range(channels)] + normalized = transforms.NormalizeVideo(mean, std)(clip) + assert samples_from_standard_normal(normalized) random.setstate(random_state) # Checking the optional in-place behaviour @@ -129,11 +130,11 @@ def test_to_tensor_video(self): numFrames, height, width = 64, 4, 4 trans = transforms.ToTensorVideo() - with self.assertRaises(TypeError): + with pytest.raises(TypeError): trans(np.random.rand(numFrames, height, width, 1).tolist()) trans(torch.rand((numFrames, height, width, 1), dtype=torch.float)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8)) trans(torch.ones((height, width, 3), dtype=torch.uint8)) trans(torch.ones((width, 3), dtype=torch.uint8)) @@ -141,7 +142,7 @@ def test_to_tensor_video(self): trans.__repr__() - @unittest.skipIf(stats is None, 'scipy.stats not available') + @pytest.mark.skipif(stats is None, reason='scipy.stats not available') def test_random_horizontal_flip_video(self): random_state = random.getstate() random.seed(42) @@ -157,7 +158,7 @@ def test_random_horizontal_flip_video(self): p_value = stats.binom_test(num_horizontal, num_samples, p=0.5) random.setstate(random_state) - self.assertGreater(p_value, 0.0001) + assert p_value > 0.0001 num_samples = 250 num_horizontal = 0 @@ -168,10 +169,10 @@ def test_random_horizontal_flip_video(self): p_value = stats.binom_test(num_horizontal, num_samples, p=0.7) random.setstate(random_state) - self.assertGreater(p_value, 0.0001) + assert p_value > 0.0001 transforms.RandomHorizontalFlipVideo().__repr__() if __name__ == '__main__': - unittest.main() + pytest.main([__file__])