Skip to content

Commit

Permalink
Support normalising tensors with >=2 dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
willprice committed Jan 8, 2019
1 parent 5a3824b commit a2c3180
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
18 changes: 13 additions & 5 deletions src/torchvideo/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,28 @@ def normalize(
"""
channel_count = tensor.shape[0]
if len(mean) != len(std):
raise ValueError(
"Expected mean and std to be of the same length, but were "
"{} and {} respectively".format(len(mean), len(std))
)
if len(mean) != channel_count:
raise ValueError(
"Expected mean to be the same length as the number of " "channels"
"Expected mean to be the same length, {}, as the number of channels"
"{}".format(len(mean), channel_count)
)
if len(std) != channel_count:
raise ValueError(
"Expected std to be the same length as the number of " "channels"
"Expected std to be the same length, {}, as the number of channels, "
"{}".format(len(std), channel_count)
)
if not inplace:
tensor = tensor.clone()

mean = torch.tensor(mean, dtype=torch.float32)
std = torch.tensor(std, dtype=torch.float32)
tensor.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
statistic_shape = tuple([-1] + [1] * ((tensor.dim() - 1)))
mean = torch.tensor(mean, dtype=torch.float32).view(*statistic_shape)
std = torch.tensor(std, dtype=torch.float32).view(*statistic_shape)
tensor.sub_(mean).div_(std)
return tensor


Expand Down
13 changes: 10 additions & 3 deletions tests/unit/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def test_raises_value_error_on_0_element_in_std_vector(self):
NormalizeVideo([10, 10], [5, 0])

@pytest.mark.skipif(stats is None, reason="scipy.stats is not available")
def test_distribution_is_normal_after_transform(self):
@given(st.integers(2, 4))
def test_distribution_is_normal_after_transform(self, ndim):
"""Basically a direct copy of
https://github.com/pytorch/vision/blob/master/test/test_transforms.py#L753"""

Expand All @@ -155,8 +156,14 @@ def kstest(tensor):

p_value = 0.0001
for channel_count in [1, 3]:
# video is uniformly distributed in [0, 1]
video = torch.randn(channel_count, 5, 10, 10) * 10 + 5
# video is normally distributed ~ N(5, 10)
if ndim == 2:
shape = [channel_count, 500]
elif ndim == 3:
shape = [channel_count, 10, 50]
else:
shape = [channel_count, 5, 10, 10]
video = torch.randn(*shape) * 10 + 5
# We want the video not to be sampled from N(0, 1)
# i.e. we want to reject the null hypothesis that video is from this
# distribution
Expand Down

0 comments on commit a2c3180

Please sign in to comment.