Skip to content
6 changes: 4 additions & 2 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numbers
import warnings
from enum import Enum
from typing import List, Tuple, Any, Optional
from typing import List, Tuple, Any, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -334,7 +334,9 @@ def to_pil_image(pic, mode=None):
return Image.fromarray(npimg, mode=mode)


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
def normalize(
tensor: Tensor, mean: Union[float, List[float]], std: Union[float, List[float]], inplace: bool = False
) -> Tensor:
"""Normalize a float tensor image with mean and standard deviation.
This transform does not support PIL Image.

Expand Down
13 changes: 11 additions & 2 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -925,7 +925,9 @@ def equalize(img: Tensor) -> Tensor:
return torch.stack([_equalize_single_image(x) for x in img])


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
def normalize(
tensor: Tensor, mean: Union[float, List[float]], std: Union[float, List[float]], inplace: bool = False
) -> Tensor:
_assert_image_tensor(tensor)

if not tensor.is_floating_point():
Expand All @@ -939,6 +941,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
if not inplace:
tensor = tensor.clone()

# Make sure the type of mean and std are List[float]
# Otherwise it will error on the torch.as_tensor call
if isinstance(mean, float):
mean = [mean]
if isinstance(std, float):
std = [std]

dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
Expand Down
6 changes: 6 additions & 0 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ def __init__(self, mean, std, inplace=False):
_log_api_usage_once(self)
self.mean = mean
self.std = std

if isinstance(mean, Sequence):
self.mean = list(mean)
if isinstance(std, Sequence):
self.std = list(std)

self.inplace = inplace

def forward(self, tensor: Tensor) -> Tensor:
Expand Down