test/test_transforms_v2.py (5 additions, 12 deletions)
@@ -631,7 +631,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):
)


-def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, clamp=True):
+def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, cast=True):
canvas_size = new_canvas_size or keypoints.canvas_size

def affine_keypoints(keypoints):
@@ -650,10 +650,7 @@ def affine_keypoints(keypoints):
float(transformed_points[0, 1]),
]
)

-if clamp:
-    output = F.clamp_keypoints(output, canvas_size=canvas_size)
-else:
+if not cast:
dtype = output.dtype

return output.to(dtype=dtype, device=device)
@@ -2293,10 +2290,10 @@ def _reference_rotate_keypoints(self, keypoints, *, angle, expand, center):
keypoints,
affine_matrix=affine_matrix,
new_canvas_size=new_canvas_size,
-clamp=False,
+cast=False,
)

-return F.clamp_keypoints(self._recenter_keypoints_after_expand(output, recenter_xy=recenter_xy)).to(keypoints)
+return self._recenter_keypoints_after_expand(output, recenter_xy=recenter_xy).to(keypoints)

@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("expand", [False, True])
@@ -5360,11 +5357,7 @@ def perspective_keypoints(keypoints):
]
)

-# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
-return F.clamp_keypoints(
-    output,
-    canvas_size=canvas_size,
-).to(dtype=dtype, device=device)
+return output.to(dtype=dtype, device=device)

return tv_tensors.KeyPoints(
torch.cat([perspective_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(
torchvision/transforms/v2/functional/_geometry.py (8 additions, 8 deletions)
@@ -24,7 +24,7 @@

from torchvision.utils import _log_api_usage_once

-from ._meta import _get_size_image_pil, clamp_bounding_boxes, clamp_keypoints, convert_bounding_box_format
+from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal

@@ -71,7 +71,7 @@ def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]
shape = keypoints.shape
keypoints = keypoints.clone().reshape(-1, 2)
keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_()
-return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
+return keypoints.reshape(shape)


@_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -159,7 +159,7 @@ def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]
shape = keypoints.shape
keypoints = keypoints.clone().reshape(-1, 2)
keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_()
-return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
+return keypoints.reshape(shape)


def vertical_flip_bounding_boxes(
@@ -1026,7 +1026,7 @@ def _affine_keypoints_with_expand(
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
canvas_size = (new_height, new_width)

-out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size).reshape(original_shape)
+out_keypoints = transformed_points.reshape(original_shape)
out_keypoints = out_keypoints.to(original_dtype)

return out_keypoints, canvas_size
@@ -1695,7 +1695,7 @@ def pad_keypoints(
left, right, top, bottom = _parse_pad_padding(padding)
pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right)
-return clamp_keypoints(keypoints + pad, canvas_size), canvas_size
+return keypoints + pad, canvas_size


@_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -1817,7 +1817,7 @@ def crop_keypoints(
keypoints = keypoints - torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
canvas_size = (height, width)

-return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size
+return keypoints, canvas_size


@_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -2047,7 +2047,7 @@ def perspective_keypoints(
numer_points = torch.matmul(points, theta1.T)
denom_points = torch.matmul(points, theta2.T)
transformed_points = numer_points.div_(denom_points)
-return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size).reshape(original_shape)
+return transformed_points.to(keypoints.dtype).reshape(original_shape)


@_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -2376,7 +2376,7 @@ def elastic_keypoints(
t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

-return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size).reshape(original_shape)
+return transformed_points.to(keypoints.dtype).reshape(original_shape)


@_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
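Note on the net effect: with clamp_keypoints dropped from these kernels, the KeyPoints geometry kernels (flip, pad, crop, affine-with-expand, perspective, elastic) now return the transformed coordinates as-is, and clamping becomes the caller's responsibility. The sketch below is not part of the diff; it is a minimal illustration of the new calling pattern and assumes the tv_tensors.KeyPoints constructor takes a canvas_size keyword (mirroring BoundingBoxes) and that F.clamp_keypoints reads the canvas size from the tv_tensor when called without one.

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    # A single keypoint on a 10x10 canvas; the canvas_size kwarg is an assumption here.
    kp = tv_tensors.KeyPoints(torch.tensor([[12.0, 5.0]]), canvas_size=(10, 10))

    # horizontal_flip_keypoints maps x to (W - 1) - x; after this change it no longer
    # clamps the result, so x becomes 9 - 12 = -3 and stays out of canvas.
    flipped = F.horizontal_flip(kp)

    # Callers that relied on the kernel clamping must now clamp explicitly
    # (assuming clamp_keypoints, like clamp_bounding_boxes, uses the tv_tensor's own canvas_size).
    clamped = F.clamp_keypoints(flipped)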