From 7a138ad70df6162d1d3218a575949a0c15506530 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 13:24:30 +0000 Subject: [PATCH] Better handling for Pad's fill argument --- test/test_transforms.py | 6 +++--- torchvision/transforms/functional_pil.py | 8 +++++++- torchvision/transforms/transforms.py | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index c2e93ec497e..c0cc8ea71fe 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -452,12 +452,12 @@ def test_resize_size_equals_small_edge_size(height, width): class TestPad: - def test_pad(self): + @pytest.mark.parametrize("fill", [85, 85.0]) + def test_pad(self, fill): height = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2 img = torch.ones(3, height, width, dtype=torch.uint8) padding = random.randint(1, 20) - fill = random.randint(1, 50) result = transforms.Compose( [ transforms.ToPILImage(), @@ -484,7 +484,7 @@ def test_pad_with_tuple_of_pad_values(self): output = transforms.Pad(padding)(img) assert output.size == (width + padding[0] * 2, height + padding[1] * 2) - padding = tuple(random.randint(1, 20) for _ in range(4)) + padding = [random.randint(1, 20) for _ in range(4)] output = transforms.Pad(padding)(img) assert output.size[0] == width + padding[0] + padding[2] assert output.size[1] == height + padding[1] + padding[3] diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 261c4000bac..ba0159a1123 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -154,7 +154,7 @@ def pad( if not isinstance(padding, (numbers.Number, tuple, list)): raise TypeError("Got inappropriate padding arg") - if not isinstance(fill, (numbers.Number, str, tuple)): + if not isinstance(fill, (numbers.Number, str, tuple, list)): raise TypeError("Got inappropriate fill arg") if not isinstance(padding_mode, str): raise TypeError("Got inappropriate padding_mode arg") @@ -301,6 +301,12 @@ def _parse_fill( fill = tuple(fill) + if img.mode != "F": + if isinstance(fill, (list, tuple)): + fill = tuple(int(x) for x in fill) + else: + fill = int(fill) + return {name: fill} diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 2324acdd592..a95977ad704 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -428,7 +428,7 @@ def __init__(self, padding, fill=0, padding_mode="constant"): if not isinstance(padding, (numbers.Number, tuple, list)): raise TypeError("Got inappropriate padding arg") - if not isinstance(fill, (numbers.Number, str, tuple)): + if not isinstance(fill, (numbers.Number, str, tuple, list)): raise TypeError("Got inappropriate fill arg") if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: