Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions test/test_transforms_video.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torchvision.transforms import Compose
import unittest
import pytest
import random
import numpy as np
import warnings
Expand All @@ -17,7 +17,7 @@
import torchvision.transforms._transforms_video as transforms


class TestVideoTransforms(unittest.TestCase):
class TestVideoTransforms():

def test_random_crop_video(self):
numFrames = random.randint(4, 128)
Expand All @@ -30,8 +30,8 @@ def test_random_crop_video(self):
transforms.ToTensorVideo(),
transforms.RandomCropVideo((oheight, owidth)),
])(clip)
self.assertEqual(result.size(2), oheight)
self.assertEqual(result.size(3), owidth)
assert result.size(2) == oheight
assert result.size(3) == owidth

transforms.RandomCropVideo((oheight, owidth)).__repr__()

Expand All @@ -46,8 +46,8 @@ def test_random_resized_crop_video(self):
transforms.ToTensorVideo(),
transforms.RandomResizedCropVideo((oheight, owidth)),
])(clip)
self.assertEqual(result.size(2), oheight)
self.assertEqual(result.size(3), owidth)
assert result.size(2) == oheight
assert result.size(3) == owidth

transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()

Expand All @@ -70,7 +70,7 @@ def test_center_crop_video(self):

msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(result.sum().item(), 0, msg)
assert result.sum().item() == 0, msg

oheight += 1
owidth += 1
Expand All @@ -82,7 +82,7 @@ def test_center_crop_video(self):

msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(sum1.item() > 1, True, msg)
assert sum1.item() > 1, msg

oheight += 1
owidth += 1
Expand All @@ -94,28 +94,29 @@ def test_center_crop_video(self):

msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertTrue(sum2.item() > 1, msg)
self.assertTrue(sum2.item() > sum1.item(), msg)
assert sum2.item() > 1, msg
assert sum2.item() > sum1.item(), msg

@unittest.skipIf(stats is None, 'scipy.stats is not available')
def test_normalize_video(self):
@pytest.mark.skipif(stats is None, reason='scipy.stats is not available')
@pytest.mark.parametrize('channels', [1, 3])
def test_normalize_video(self, channels):
def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
return p_value > 0.0001

random_state = random.getstate()
random.seed(42)
for channels in [1, 3]:
numFrames = random.randint(4, 128)
height = random.randint(32, 256)
width = random.randint(32, 256)
mean = random.random()
std = random.random()
clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
mean = [clip[c].mean().item() for c in range(channels)]
std = [clip[c].std().item() for c in range(channels)]
normalized = transforms.NormalizeVideo(mean, std)(clip)
self.assertTrue(samples_from_standard_normal(normalized))

numFrames = random.randint(4, 128)
height = random.randint(32, 256)
width = random.randint(32, 256)
mean = random.random()
std = random.random()
clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
mean = [clip[c].mean().item() for c in range(channels)]
std = [clip[c].std().item() for c in range(channels)]
normalized = transforms.NormalizeVideo(mean, std)(clip)
assert samples_from_standard_normal(normalized)
random.setstate(random_state)

# Checking the optional in-place behaviour
Expand All @@ -129,19 +130,19 @@ def test_to_tensor_video(self):
numFrames, height, width = 64, 4, 4
trans = transforms.ToTensorVideo()

with self.assertRaises(TypeError):
with pytest.raises(TypeError):
trans(np.random.rand(numFrames, height, width, 1).tolist())
trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))

with self.assertRaises(ValueError):
with pytest.raises(ValueError):
trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
trans(torch.ones((height, width, 3), dtype=torch.uint8))
trans(torch.ones((width, 3), dtype=torch.uint8))
trans(torch.ones((3), dtype=torch.uint8))

trans.__repr__()

@unittest.skipIf(stats is None, 'scipy.stats not available')
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_random_horizontal_flip_video(self):
random_state = random.getstate()
random.seed(42)
Expand All @@ -157,7 +158,7 @@ def test_random_horizontal_flip_video(self):

p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
assert p_value > 0.0001

num_samples = 250
num_horizontal = 0
Expand All @@ -168,10 +169,10 @@ def test_random_horizontal_flip_video(self):

p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
assert p_value > 0.0001

transforms.RandomHorizontalFlipVideo().__repr__()


if __name__ == '__main__':
unittest.main()
pytest.main([__file__])