From 22e6ff3bfa4fef63e9cc194b6bd10b4e46e0ded5 Mon Sep 17 00:00:00 2001 From: Alykhan Tejani Date: Thu, 12 Oct 2017 17:49:47 +0100 Subject: [PATCH 1/2] added tests for transforms.Normalize --- test/test_transforms.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8ede5dd7bea..8f9167f6deb 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -430,6 +430,22 @@ def test_random_horizontal_flip(self): random.setstate(random_state) assert p_value > 0.0001 + @unittest.skipIf(stats is None, 'scipt.stats is not available') + def test_normalize(self): + def samples_from_standard_normal(tensor): + p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue + return p_value > 0.0001 + + random_state = random.getstate() + random.seed(42) + for channels in [1, 3]: + img = torch.rand(channels, 10, 10) + mean = [img[c].mean() for c in range(channels)] + std = [img[c].std() for c in range(channels)] + normalized = transforms.Normalize(mean, std)(img) + assert samples_from_standard_normal(normalized) + random.setstate(random_state) + def test_adjust_brightness(self): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] From 38d404156211d5d088cfdb81584de49d53094e13 Mon Sep 17 00:00:00 2001 From: Alykhan Tejani Date: Thu, 12 Oct 2017 17:57:07 +0100 Subject: [PATCH 2/2] add docs for the Normalize transform --- torchvision/transforms.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 24a85e89afe..50b243d8275 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -127,12 +127,11 @@ def normalize(tensor, mean, std): Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. - mean (sequence): Sequence of means for R, G, B channels respecitvely. - std (sequence): Sequence of standard deviations for R, G, B channels - respecitvely. + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channely. Returns: - Tensor: Normalized image. + Tensor: Normalized Tensor image. """ if not _is_tensor_image(tensor): raise TypeError('tensor is not a torch image.') @@ -557,15 +556,13 @@ def __call__(self, pic): class Normalize(object): """Normalize an tensor image with mean and standard deviation. - - Given mean: (R, G, B) and std: (R, G, B), - will normalize each channel of the torch.*Tensor, i.e. - channel = (channel - mean) / std + Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` Args: - mean (sequence): Sequence of means for R, G, B channels respecitvely. - std (sequence): Sequence of standard deviations for R, G, B channels - respecitvely. + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. """ def __init__(self, mean, std): @@ -578,7 +575,7 @@ def __call__(self, tensor): tensor (Tensor): Tensor image of size (C, H, W) to be normalized. Returns: - Tensor: Normalized image. + Tensor: Normalized Tensor image. """ return normalize(tensor, self.mean, self.std)