From 77e9a5741a29a462e836bb96884eb9f533cf394f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Feb 2022 19:10:42 +0000 Subject: [PATCH 1/3] Adding Scale Jitter in references. --- references/detection/transforms.py | 46 +++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 16b9a8826a8..2fd925169cd 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -4,7 +4,7 @@ import torchvision from torch import nn, Tensor from torchvision.transforms import functional as F -from torchvision.transforms import transforms as T +from torchvision.transforms import transforms as T, InterpolationMode def _flip_coco_person_keypoints(kps, width): @@ -282,3 +282,47 @@ def forward( image = F.to_pil_image(image) return image, target + + +class ScaleJitter(nn.Module): + """Randomly resizes the image and its bounding boxes within a specified ratio range. + The class implements the Scale Jitter augmentation as described in the paper + `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. + + Args: + scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the + range a <= scale <= b. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + """ + + def __init__( + self, + scale_range: Tuple[float, float] = (0.1, 2.0), + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ): + super().__init__() + self.scale_range = scale_range + self.interpolation = interpolation + + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if isinstance(image, torch.Tensor): + if image.ndimension() not in {2, 3}: + raise ValueError(f"image should be 2/3 dimensional. 
Got {image.ndimension()} dimensions.") + elif image.ndimension() == 2: + image = image.unsqueeze(0) + + old_width, old_height = F.get_image_size(image) + + r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + new_width = int(old_width * r) + new_height = int(old_height * r) + + image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) + + if target is not None: + target["boxes"] *= r + + return image, target From 558339bfedd6ea85f25f61019dc5473763bff202 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Feb 2022 19:17:02 +0000 Subject: [PATCH 2/3] Update documentation. --- references/detection/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 2fd925169cd..ff2744da517 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -285,7 +285,7 @@ def forward( class ScaleJitter(nn.Module): - """Randomly resizes the image and its bounding boxes within a specified ratio range. + """Randomly resizes the image and its bounding boxes within the specified scale range. The class implements the Scale Jitter augmentation as described in the paper `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. From b6a16948d0301c4ddd28825b4f4c4b74e85d51b1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 18 Feb 2022 15:49:27 +0000 Subject: [PATCH 3/3] Address review comments. --- references/detection/transforms.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index ff2744da517..5924fa0c560 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -290,6 +290,7 @@ class ScaleJitter(nn.Module): `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. 
Args: + target_size (tuple of ints): The target size for the transform provided in (height, width) format. scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the range a <= scale <= b. interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. @@ -298,10 +299,12 @@ class ScaleJitter(nn.Module): def __init__( self, + target_size: Tuple[int, int], scale_range: Tuple[float, float] = (0.1, 2.0), interpolation: InterpolationMode = InterpolationMode.BILINEAR, ): super().__init__() + self.target_size = target_size self.scale_range = scale_range self.interpolation = interpolation @@ -314,15 +317,17 @@ def forward( elif image.ndimension() == 2: image = image.unsqueeze(0) - old_width, old_height = F.get_image_size(image) - r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) - new_width = int(old_width * r) - new_height = int(old_height * r) + new_width = int(self.target_size[1] * r) + new_height = int(self.target_size[0] * r) image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) if target is not None: target["boxes"] *= r + if "masks" in target: + target["masks"] = F.resize( + target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST + ) return image, target