diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5b762ff2975..4b93f406152 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional +from typing import List, Tuple, Any, Optional, Union import numpy as np import torch @@ -334,7 +334,9 @@ def to_pil_image(pic, mode=None): return Image.fromarray(npimg, mode=mode) -def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: +def normalize( + tensor: Tensor, mean: Union[float, List[float]], std: Union[float, List[float]], inplace: bool = False +) -> Tensor: """Normalize a float tensor image with mean and standard deviation. This transform does not support PIL Image. diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index d8671405b96..e495fd5d7f9 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union import torch from torch import Tensor @@ -925,7 +925,9 @@ def equalize(img: Tensor) -> Tensor: return torch.stack([_equalize_single_image(x) for x in img]) -def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: +def normalize( + tensor: Tensor, mean: Union[float, List[float]], std: Union[float, List[float]], inplace: bool = False +) -> Tensor: _assert_image_tensor(tensor) if not tensor.is_floating_point(): @@ -939,6 +941,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool if not inplace: tensor = tensor.clone() + # Make sure the type of mean and std are List[float] + # Otherwise it will error on the torch.as_tensor call + if isinstance(mean, float): + mean = [mean] + if isinstance(std, float): + std = [std] + dtype = tensor.dtype mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 2324acdd592..7ca0c37dd06 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -257,6 +257,12 @@ def __init__(self, mean, std, inplace=False): _log_api_usage_once(self) self.mean = mean self.std = std + + if isinstance(mean, Sequence): + self.mean = list(mean) + if isinstance(std, Sequence): + self.std = list(std) + self.inplace = inplace def forward(self, tensor: Tensor) -> Tensor: