diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 16b9a8826a8..5924fa0c560 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -4,7 +4,7 @@
 import torchvision
 from torch import nn, Tensor
 from torchvision.transforms import functional as F
-from torchvision.transforms import transforms as T
+from torchvision.transforms import transforms as T, InterpolationMode
 
 
 def _flip_coco_person_keypoints(kps, width):
@@ -282,3 +282,59 @@ def forward(
         image = F.to_pil_image(image)
 
         return image, target
+
+
+class ScaleJitter(nn.Module):
+    """Randomly resizes the image and its bounding boxes within the specified scale range.
+    The class implements the Scale Jitter augmentation as described in the paper
+    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
+
+    Args:
+        target_size (tuple of ints): The target size for the transform provided in (height, width) format.
+        scale_range (tuple of floats): scaling factor interval, e.g. (a, b), then scale is randomly sampled from the
+            range a <= scale <= b.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
+    """
+
+    def __init__(
+        self,
+        target_size: Tuple[int, int],
+        scale_range: Tuple[float, float] = (0.1, 2.0),
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ):
+        super().__init__()
+        self.target_size = target_size
+        self.scale_range = scale_range
+        self.interpolation = interpolation
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        orig_width, orig_height = F.get_image_size(image)
+
+        # Sample the jitter factor uniformly from scale_range and rescale the
+        # target size by it.
+        r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
+        new_width = int(self.target_size[1] * r)
+        new_height = int(self.target_size[0] * r)
+
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+
+        if target is not None:
+            # Boxes must be scaled by the actual per-axis resize ratio (new size
+            # over original size), not by the raw jitter factor r.
+            target["boxes"][:, 0::2] *= new_width / orig_width
+            target["boxes"][:, 1::2] *= new_height / orig_height
+            if "masks" in target:
+                target["masks"] = F.resize(
+                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                )
+
+        return image, target
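
For context, a minimal sketch of how the new transform could be wired into the detection reference pipeline. The `get_transform` helper, the (1024, 1024) target size, and the surrounding transforms are illustrative assumptions based on the Copy-Paste paper's large-scale jittering recipe, not part of this diff:

```python
# Usage sketch, assuming this diff is applied to references/detection/transforms.py
# and that module is importable as `transforms`. The (1024, 1024) target size and
# the (0.1, 2.0) scale range follow the large-scale jittering setup described in
# the Copy-Paste paper; the helper and transform choices are illustrative.
import torch

import transforms as T  # references/detection/transforms.py


def get_transform(train):
    tfms = [T.PILToTensor(), T.ConvertImageDtype(torch.float)]
    if train:
        # Jitter the image, boxes, and masks together. Note the output size
        # varies per sample, so a fixed-size crop or pad typically follows.
        tfms.append(T.ScaleJitter(target_size=(1024, 1024), scale_range=(0.1, 2.0)))
        tfms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(tfms)
```

Since the jittered output size varies per sample, the paper's recipe pads or crops the result back to the fixed target size before batching.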