diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index b7e27a808bf..649fc5ed1b0 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -972,7 +972,7 @@ def test_adjust_gamma(device, dtype, config, channels): [ {"padding_mode": "constant", "fill": 0}, {"padding_mode": "constant", "fill": 10}, - {"padding_mode": "constant", "fill": 20}, + {"padding_mode": "constant", "fill": 20.2}, {"padding_mode": "edge"}, {"padding_mode": "reflect"}, {"padding_mode": "symmetric"}, diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 2a4a7f1b6dd..eea53a228a9 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional +from typing import List, Tuple, Any, Optional, Union import numpy as np import torch @@ -474,7 +474,7 @@ def resize( return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias) -def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: +def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor: r"""Pad the given image on all sides with the given "pad" value. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index acc8d3ae3e1..50e41647af1 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union import torch from torch import Tensor @@ -370,7 +370,7 @@ def _parse_pad_padding(padding: List[int]) -> List[int]: return [pad_left, pad_right, pad_top, pad_bottom] -def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: +def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor: _assert_image_tensor(img) if not isinstance(padding, (int, tuple, list)):