From a2c3180c5381754827051fba373041b4012b6a5a Mon Sep 17 00:00:00 2001
From: Will Price
Date: Tue, 8 Jan 2019 14:21:39 +0000
Subject: [PATCH] Support normalising tensors with >=2 dimensions

---
 src/torchvideo/transforms/functional.py | 18 +++++++++++++-----
 tests/unit/test_transforms.py           | 13 ++++++++++---
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/src/torchvideo/transforms/functional.py b/src/torchvideo/transforms/functional.py
index 2bac714..6b21a4c 100644
--- a/src/torchvideo/transforms/functional.py
+++ b/src/torchvideo/transforms/functional.py
@@ -24,20 +24,28 @@ def normalize(
     """
     channel_count = tensor.shape[0]
+    if len(mean) != len(std):
+        raise ValueError(
+            "Expected mean and std to be of the same length, but were "
+            "{} and {} respectively".format(len(mean), len(std))
+        )
     if len(mean) != channel_count:
         raise ValueError(
-            "Expected mean to be the same length as the number of " "channels"
+            "Expected mean to be the same length, {}, as the number of channels, "
+            "{}".format(len(mean), channel_count)
         )
     if len(std) != channel_count:
         raise ValueError(
-            "Expected std to be the same length as the number of " "channels"
+            "Expected std to be the same length, {}, as the number of channels, "
+            "{}".format(len(std), channel_count)
         )
     if not inplace:
         tensor = tensor.clone()
-    mean = torch.tensor(mean, dtype=torch.float32)
-    std = torch.tensor(std, dtype=torch.float32)
-    tensor.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+    statistic_shape = tuple([-1] + [1] * (tensor.dim() - 1))
+    mean = torch.tensor(mean, dtype=torch.float32).view(*statistic_shape)
+    std = torch.tensor(std, dtype=torch.float32).view(*statistic_shape)
+    tensor.sub_(mean).div_(std)

     return tensor

diff --git a/tests/unit/test_transforms.py b/tests/unit/test_transforms.py
index b464f4e..65ab18a 100644
--- a/tests/unit/test_transforms.py
+++ b/tests/unit/test_transforms.py
@@ -145,7 +145,8 @@ def test_raises_value_error_on_0_element_in_std_vector(self):
         NormalizeVideo([10, 10], [5, 0])

     @pytest.mark.skipif(stats is None, reason="scipy.stats is not available")
-    def test_distribution_is_normal_after_transform(self):
+    @given(st.integers(2, 4))
+    def test_distribution_is_normal_after_transform(self, ndim):
         """Basically a direct copy of
         https://github.com/pytorch/vision/blob/master/test/test_transforms.py#L753"""

@@ -155,8 +156,14 @@ def kstest(tensor):
         p_value = 0.0001

         for channel_count in [1, 3]:
-            # video is uniformly distributed in [0, 1]
-            video = torch.randn(channel_count, 5, 10, 10) * 10 + 5
+            # video is normally distributed ~ N(5, 10)
+            if ndim == 2:
+                shape = [channel_count, 500]
+            elif ndim == 3:
+                shape = [channel_count, 10, 50]
+            else:
+                shape = [channel_count, 5, 10, 10]
+            video = torch.randn(*shape) * 10 + 5
             # We want the video not to be sampled from N(0, 1)
             # i.e. we want to reject the null hypothesis that video is from this
             # distribution
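
Usage sketch of the reshape-and-broadcast technique the patch applies, assuming only torch is available; normalize_nd below is a hypothetical stand-in for the patched torchvideo.transforms.functional.normalize, written to illustrate the idea rather than reproduce the library code exactly:

import torch


def normalize_nd(tensor, mean, std, inplace=False):
    """Normalise a (C, ...) tensor per channel, for any rank >= 2."""
    if not inplace:
        tensor = tensor.clone()
    # Reshape the per-channel statistics to (-1, 1, ..., 1) so they
    # broadcast across every non-channel dimension of the input.
    statistic_shape = (-1,) + (1,) * (tensor.dim() - 1)
    mean = torch.tensor(mean, dtype=torch.float32).view(statistic_shape)
    std = torch.tensor(std, dtype=torch.float32).view(statistic_shape)
    return tensor.sub_(mean).div_(std)


# The same call now handles a 2D (C, T) signal and a 4D (C, T, H, W) video,
# rather than only the 4D case hard-coded by the old [:, None, None, None] indexing.
signal = torch.randn(3, 500) * 10 + 5
video = torch.randn(3, 5, 10, 10) * 10 + 5
for t in (signal, video):
    out = normalize_nd(t, mean=[5.0, 5.0, 5.0], std=[10.0, 10.0, 10.0])
    print(out.mean().item(), out.std().item())  # both roughly 0 and 1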