diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 903e2b6cf5b..310baf14938 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -94,7 +94,7 @@ def __call__(self, pic): class Normalize(object): """Normalize an tensor image with mean and standard deviation. - Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform will normalize each channel of the input ``torch.*Tensor`` i.e. ``input[channel] = (input[channel] - mean[channel]) / std[channel]``