references/detection/transforms.py (50 additions & 1 deletion)
@@ -4,7 +4,7 @@
import torchvision
from torch import nn, Tensor
from torchvision.transforms import functional as F
-from torchvision.transforms import transforms as T
+from torchvision.transforms import transforms as T, InterpolationMode


def _flip_coco_person_keypoints(kps, width):
@@ -282,3 +282,52 @@ def forward(
            image = F.to_pil_image(image)

        return image, target


class ScaleJitter(nn.Module):
    """Randomly resizes the image and its bounding boxes (and masks, if present) within the specified scale range.
    The class implements the Scale Jitter augmentation as described in the paper
    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.

    Args:
        target_size (tuple of ints): The target size for the transform provided in (height, width) format.
        scale_range (tuple of floats): scaling factor interval, e.g. (a, b); the scale factor is randomly
            sampled from the range a <= scale <= b.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                # Promote single-channel (H, W) images to (1, H, W).
                image = image.unsqueeze(0)

        orig_width, orig_height = F.get_image_size(image)

        # Sample a scale factor uniformly from [scale_range[0], scale_range[1]].
        r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
        new_width = int(self.target_size[1] * r)
        new_height = int(self.target_size[0] * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

        if target is not None:
            # Boxes must be scaled by the ratios actually applied to each axis of the
            # image; these equal r only when the input already has target_size.
            target["boxes"][:, 0::2] *= new_width / orig_width
            target["boxes"][:, 1::2] *= new_height / orig_height
            if "masks" in target:
                # Masks are resized with nearest-neighbor interpolation so they stay binary.
                target["masks"] = F.resize(
                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
                )

        return image, target
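
For reference, a minimal usage sketch (not part of the diff): it assumes the patched file is importable as the references' `transforms` module, and the dummy image, box, and mask values below are made up for illustration.

```python
import torch

from transforms import ScaleJitter  # assumes references/detection is on sys.path

# Dummy 3-channel image with one box in (x1, y1, x2, y2) format and one instance mask.
image = torch.rand(3, 480, 640)
target = {
    "boxes": torch.tensor([[10.0, 20.0, 300.0, 400.0]]),
    "masks": torch.zeros(1, 480, 640, dtype=torch.uint8),
}

jitter = ScaleJitter(target_size=(1024, 1024), scale_range=(0.1, 2.0))
out_image, out_target = jitter(image, target)

# Output spatial size is target_size scaled by a random factor in scale_range;
# boxes and masks are rescaled with the same per-axis ratios as the image.
print(out_image.shape, out_target["boxes"].tolist(), out_target["masks"].shape)
```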