diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 784394d2955..9734a5dc30a 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -715,30 +715,38 @@ def test__get_params(self, padding, pad_if_needed, size, mocker): if padding is not None: if isinstance(padding, int): - h += 2 * padding - w += 2 * padding + pad_top = pad_bottom = pad_left = pad_right = padding elif isinstance(padding, list) and len(padding) == 2: - w += 2 * padding[0] - h += 2 * padding[1] + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] elif isinstance(padding, list) and len(padding) == 4: - w += padding[0] + padding[2] - h += padding[1] + padding[3] + pad_left, pad_top, pad_right, pad_bottom = padding - expected_input_width = w - expected_input_height = h + h += pad_top + pad_bottom + w += pad_left + pad_right + else: + pad_left = pad_right = pad_top = pad_bottom = 0 if pad_if_needed: if w < size[1]: - w += 2 * (size[1] - w) + diff = size[1] - w + pad_left += diff + pad_right += diff + w += 2 * diff if h < size[0]: - h += 2 * (size[0] - h) + diff = size[0] - h + pad_top += diff + pad_bottom += diff + h += 2 * diff + + padding = [pad_left, pad_top, pad_right, pad_bottom] assert 0 <= params["top"] <= h - size[0] + 1 assert 0 <= params["left"] <= w - size[1] + 1 assert params["height"] == size[0] assert params["width"] == size[1] - assert params["input_width"] == expected_input_width - assert params["input_height"] == expected_input_height + assert params["needs_pad"] is any(padding) + assert params["padding"] == padding @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) @pytest.mark.parametrize("pad_if_needed", [False, True]) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 9e2e3051189..c8debe1e293 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -966,7 +966,7 @@ def _transform(self, inpt, params): class TestRefSegTransforms: def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): - size = (256, 640) + size = (256, 460) num_categories = 21 conv_fns = [] diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index df1f09fc192..008d4d195cb 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -414,78 +414,80 @@ def __init__( _check_padding_arg(padding) _check_padding_mode_arg(padding_mode) - self.padding = padding + self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type] self.pad_if_needed = pad_if_needed self.fill = _setup_fill_arg(fill) self.padding_mode = padding_mode def _get_params(self, sample: Any) -> Dict[str, Any]: - _, height, width = query_chw(sample) + _, padded_height, padded_width = query_chw(sample) if self.padding is not None: - # update height, width with static padding data - padding = self.padding - if isinstance(padding, Sequence): - padding = list(padding) - pad_left, pad_right, pad_top, pad_bottom = F._geometry._parse_pad_padding(padding) - height += pad_top + pad_bottom - width += pad_left + pad_right - - output_height, output_width = self.size - # We have to store maybe padded image size for pad_if_needed branch in _transform - input_height, input_width = height, width + pad_left, pad_right, pad_top, pad_bottom = self.padding + padded_height += pad_top + pad_bottom + padded_width += pad_left + pad_right + else: + pad_left = pad_right = pad_top = pad_bottom = 0 + + cropped_height, cropped_width = self.size if self.pad_if_needed: - # pad width if needed - if width < output_width: - width += 2 * (output_width - width) - # pad height if needed - if height < output_height: - height += 2 * (output_height - height) - - if height < output_height or width < output_width: + if padded_height < cropped_height: + diff = cropped_height - padded_height + + pad_top += diff + pad_bottom += diff + padded_height += 2 * diff + + if padded_width < cropped_width: + diff = cropped_width - padded_width + + pad_left += diff + pad_right += diff + padded_width += 2 * diff + + if padded_height < cropped_height or padded_width < cropped_width: raise ValueError( - f"Required crop size {(output_height, output_width)} is larger then input image size {(height, width)}" + f"Required crop size {(cropped_height, cropped_width)} is larger than " + f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}." ) - if width == output_width and height == output_height: - return dict(top=0, left=0, height=height, width=width, input_width=input_width, input_height=input_height) + # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad` + padding = [pad_left, pad_top, pad_right, pad_bottom] + needs_pad = any(padding) - top = torch.randint(0, height - output_height + 1, size=(1,)).item() - left = torch.randint(0, width - output_width + 1, size=(1,)).item() + needs_vert_crop, top = ( + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + if padded_height > cropped_height + else (False, 0) + ) + needs_horz_crop, left = ( + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + if padded_width > cropped_width + else (False, 0) + ) return dict( + needs_crop=needs_vert_crop or needs_horz_crop, top=top, left=left, - height=output_height, - width=output_width, - input_width=input_width, - input_height=input_height, + height=cropped_height, + width=cropped_width, + needs_pad=needs_pad, + padding=padding, ) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - # TODO: (PERF) check for speed optimization if we avoid repeated pad calls - fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) + if params["needs_pad"]: + fill = self.fill[type(inpt)] + fill = F._geometry._convert_fill_arg(fill) - if self.padding is not None: - # This cast does Sequence[int] -> List[int] and is required to make mypy happy - padding = self.padding - if not isinstance(padding, int): - padding = list(padding) + inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) - inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) + if params["needs_crop"]: + inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) - if self.pad_if_needed: - input_width, input_height = params["input_width"], params["input_height"] - if input_width < self.size[1]: - padding = [self.size[1] - input_width, 0] - inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) - if input_height < self.size[0]: - padding = [0, self.size[0] - input_height] - inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) - - return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + return inpt class RandomPerspective(_RandomApplyTransform):