From ea7c513ff69dedface1d5e4c1708ae5b6ebe9fdf Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 15:36:11 +0000
Subject: [PATCH 1/4] Refactored and modified private API for resize functional op

---
 test/test_transforms_tensor.py              | 20 +++------
 torchvision/transforms/functional.py        | 46 +++++++++++++++++++-
 torchvision/transforms/functional_pil.py    | 34 +--------------
 torchvision/transforms/functional_tensor.py | 47 +--------------------
 4 files changed, 52 insertions(+), 95 deletions(-)

diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index ba2321ec455..f0cd3ba0021 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -394,9 +394,7 @@ def test_resize_int(self, size):
     @pytest.mark.parametrize(
         "size",
         [
-            [
-                32,
-            ],
+            [32],
             [32, 32],
             (32, 32),
             [34, 35],
@@ -412,7 +410,7 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device):
         # This is a trivial cast to float of uint8 data to test all cases
         tensor = tensor.to(dt)
         if max_size is not None and len(size) != 1:
-            pytest.xfail("with max_size, size must be a sequence with 2 elements")
+            pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")
 
         transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size)
         s_transform = torch.jit.script(transform)
@@ -420,11 +418,7 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device):
         _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_resize_save(self, tmpdir):
-        transform = T.Resize(
-            size=[
-                32,
-            ]
-        )
+        transform = T.Resize(size=[32])
         s_transform = torch.jit.script(transform)
         s_transform.save(os.path.join(tmpdir, "t_resize.pt"))
 
@@ -435,12 +429,8 @@ def test_resize_save(self, tmpdir):
         "size",
         [
             (32,),
-            [
-                44,
-            ],
-            [
-                32,
-            ],
+            [44],
+            [32],
             [32, 32],
             (32, 32),
             [44, 55],
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index c40ae1eb92b..609c64ad4ff 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -360,6 +360,31 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
 
 
+def _compute_output_size(
+    image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+) -> Tuple[int, int]:
+    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
+        h, w = image_size
+        short, long = (w, h) if w <= h else (h, w)
+        requested_new_short = size if isinstance(size, int) else size[0]
+
+        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+        if max_size is not None:
+            if max_size <= requested_new_short:
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+    else:  # specified both h and w
+        new_w, new_h = size[1], size[0]
+    return new_h, new_w
+
+
 def resize(
     img: Tensor,
     size: List[int],
@@ -423,13 +448,30 @@ def resize(
     if not isinstance(interpolation, InterpolationMode):
         raise TypeError("Argument interpolation should be a InterpolationMode")
 
+    if isinstance(size, (list, tuple)):
+        if len(size) not in [1, 2]:
+            raise ValueError(
+                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
tuple/list" + ) + if max_size is not None and len(size) != 1: + raise ValueError( + "max_size should only be passed if size specifies the length of the smaller edge, " + "i.e. size should be an int or a sequence of length 1 in torchscript mode." + ) + + _, image_height, image_width = get_dimensions(img) + output_size = _compute_output_size((image_height, image_width), size, max_size) + + if (image_height, image_width) == output_size: + return img + if not isinstance(img, torch.Tensor): if antialias is not None and not antialias: warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") pil_interpolation = pil_modes_mapping[interpolation] - return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size) + return F_pil.resize(img, size=output_size, interpolation=pil_interpolation) - return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias) + return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias) def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 0203ee4495b..3c1a911a5d4 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -242,44 +242,14 @@ def resize( img: Image.Image, size: Union[Sequence[int], int], interpolation: int = _pil_constants.BILINEAR, - max_size: Optional[int] = None, ) -> Image.Image: if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. Got {type(img)}") - if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))): + if not (isinstance(size, Sequence) and len(size) == 2): raise TypeError(f"Got inappropriate size arg: {size}") - if isinstance(size, Sequence) and len(size) == 1: - size = size[0] - if isinstance(size, int): - w, h = img.size - - short, long = (w, h) if w <= h else (h, w) - new_short, new_long = size, int(size * long / short) - - if max_size is not None: - if max_size <= size: - raise ValueError( - f"max_size = {max_size} must be strictly greater than the requested " - f"size for the smaller edge size = {size}" - ) - if new_long > max_size: - new_short, new_long = int(max_size * new_short / new_long), max_size - - new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) - - if (w, h) == (new_w, new_h): - return img - else: - return img.resize((new_w, new_h), interpolation) - else: - if max_size is not None: - raise ValueError( - "max_size should only be passed if size specifies the length of the smaller edge, " - "i.e. size should be an int or a sequence of length 1 in torchscript mode." 
-            )
-        return img.resize(size[::-1], interpolation)
+    return img.resize(size[::-1], interpolation)
 
 
 @torch.jit.unused
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 1899caebfc3..acc8d3ae3e1 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -430,70 +430,25 @@ def resize(
     img: Tensor,
     size: List[int],
     interpolation: str = "bilinear",
-    max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> Tensor:
     _assert_image_tensor(img)
 
-    if not isinstance(size, (int, tuple, list)):
-        raise TypeError("Got inappropriate size arg")
-    if not isinstance(interpolation, str):
-        raise TypeError("Got inappropriate interpolation arg")
-
-    if interpolation not in ["nearest", "bilinear", "bicubic"]:
-        raise ValueError("This interpolation mode is unsupported with Tensor input")
-
     if isinstance(size, tuple):
         size = list(size)
 
-    if isinstance(size, list):
-        if len(size) not in [1, 2]:
-            raise ValueError(
-                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
-            )
-        if max_size is not None and len(size) != 1:
-            raise ValueError(
-                "max_size should only be passed if size specifies the length of the smaller edge, "
-                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
-            )
-
     if antialias is None:
         antialias = False
 
     if antialias and interpolation not in ["bilinear", "bicubic"]:
         raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
 
-    _, h, w = get_dimensions(img)
-
-    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
-        short, long = (w, h) if w <= h else (h, w)
-        requested_new_short = size if isinstance(size, int) else size[0]
-
-        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
-
-        if max_size is not None:
-            if max_size <= requested_new_short:
-                raise ValueError(
-                    f"max_size = {max_size} must be strictly greater than the requested "
-                    f"size for the smaller edge size = {size}"
-                )
-            if new_long > max_size:
-                new_short, new_long = int(max_size * new_short / new_long), max_size
-
-        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
-
-        if (w, h) == (new_w, new_h):
-            return img
-
-    else:  # specified both h and w
-        new_w, new_h = size[1], size[0]
-
     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
 
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
-    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias)
+    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
 
     if interpolation == "bicubic" and out_dtype == torch.uint8:
         img = img.clamp(min=0, max=255)

From aade78f8a7bf36dbe70aaca7afcd2abf546d3ccb Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 16:07:31 +0000
Subject: [PATCH 2/4] Fixed failures

---
 torchvision/prototype/transforms/functional/_geometry.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index ac0e8e0eb13..f1d51fded82 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -42,6 +42,8 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
+    # TODO: use _compute_output_size to enable max_size option
+    max_size  # unused right now
     new_height, new_width = size
     num_channels, old_height, old_width = get_dimensions_image_tensor(image)
     batch_shape = image.shape[:-3]
@@ -49,7 +51,6 @@ def resize_image_tensor(
         image.reshape((-1, num_channels, old_height, old_width)),
         size=size,
         interpolation=interpolation.value,
-        max_size=max_size,
         antialias=antialias,
     ).reshape(batch_shape + (num_channels, new_height, new_width))
 
@@ -60,7 +61,9 @@ def resize_image_pil(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
 ) -> PIL.Image.Image:
-    return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size)
+    # TODO: use _compute_output_size to enable max_size option
+    max_size  # unused right now
+    return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
 
 
 def resize_segmentation_mask(

From a812a3bcdcca6ca8d7af79330220f1344ae89aa9 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 20:57:14 +0000
Subject: [PATCH 3/4] More updates

---
 torchvision/transforms/functional.py     | 12 ++++++------
 torchvision/transforms/functional_pil.py |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 609c64ad4ff..77feadc51f1 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -2,7 +2,7 @@
 import numbers
 import warnings
 from enum import Enum
-from typing import List, Tuple, Any, Optional
+from typing import List, Tuple, Any, Optional, Union
 
 import numpy as np
 import torch
@@ -360,10 +360,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
 
 
-def _compute_output_size(
-    image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
-) -> Tuple[int, int]:
-    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
+def _compute_output_size(image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None) -> List[int]:
+    if len(size) == 1:  # specified size only for the smallest edge
         h, w = image_size
         short, long = (w, h) if w <= h else (h, w)
         requested_new_short = size if isinstance(size, int) else size[0]
@@ -382,7 +380,7 @@ def _compute_output_size(
         new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
     else:  # specified both h and w
         new_w, new_h = size[1], size[0]
-    return new_h, new_w
+    return [new_h, new_w]
 
 
 def resize(
@@ -460,6 +458,8 @@ def resize(
             )
 
     _, image_height, image_width = get_dimensions(img)
+    if isinstance(size, int):
+        size = [size]
     output_size = _compute_output_size((image_height, image_width), size, max_size)
 
     if (image_height, image_width) == output_size:
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 3c1a911a5d4..7ebd9f71588 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -240,13 +240,13 @@ def crop(
 @torch.jit.unused
 def resize(
     img: Image.Image,
-    size: Union[Sequence[int], int],
+    size: Union[List[int], int],
     interpolation: int = _pil_constants.BILINEAR,
 ) -> Image.Image:
 
     if not _is_pil_image(img):
         raise TypeError(f"img should be PIL Image. Got {type(img)}")
-    if not (isinstance(size, Sequence) and len(size) == 2):
+    if not (isinstance(size, list) and len(size) == 2):
         raise TypeError(f"Got inappropriate size arg: {size}")
 
     return img.resize(size[::-1], interpolation)

From 09728221bbc2a1e59d143e620d5df86033ac677d Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 21:03:02 +0000
Subject: [PATCH 4/4] Fixed flake8

---
 torchvision/transforms/functional.py     | 2 +-
 torchvision/transforms/functional_pil.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 77feadc51f1..80444c31204 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -2,7 +2,7 @@
 import numbers
 import warnings
 from enum import Enum
-from typing import List, Tuple, Any, Optional, Union
+from typing import List, Tuple, Any, Optional
 
 import numpy as np
 import torch
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 7ebd9f71588..93bdeb8f308 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -1,5 +1,5 @@
 import numbers
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
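
For reference, a minimal sketch (not part of the patch series) of the behavior the refactor preserves: size given for the smaller edge only, with max_size capping the longer edge, now computed once in functional.resize() via the private _compute_output_size helper and handed to the PIL/tensor backends as an explicit [new_h, new_w] list. It assumes a torchvision build with all four patches applied and uses only the public entry points:

    import torch
    import torchvision.transforms as T
    import torchvision.transforms.functional as F

    img = torch.randint(0, 256, (3, 64, 128), dtype=torch.uint8)

    # Smaller edge -> 32 alone would give (32, 64); max_size=48 caps the
    # longer edge, so the output is rescaled to (24, 48) instead.
    out = F.resize(img, size=[32], max_size=48)
    print(out.shape)  # torch.Size([3, 24, 48])

    # Resize stays scriptable after the refactor, as the updated tests assert.
    scripted = torch.jit.script(T.Resize(size=[32], max_size=48))
    print(scripted(img).shape)  # torch.Size([3, 24, 48])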