Skip to content

Commit

Permalink
Implement crop option in RotationTransform
Browse files Browse the repository at this point in the history
The `crop` option in the `RotationTransform` and `RandomRotation`
transformations crops the rotated image to the largest axis-aligned
rectangle (by area) that fits entirely inside the rotated image
content.

It's disabled by default to ensure backwards compatibility.
  • Loading branch information
timofurrer committed Apr 21, 2020
1 parent bb9f5d8 commit 8b39590
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 6 deletions.
68 changes: 64 additions & 4 deletions detectron2/data/transforms/transform.py
Expand Up @@ -3,7 +3,7 @@
# File: transform.py

import numpy as np
from fvcore.transforms.transform import HFlipTransform, NoOpTransform, Transform
from fvcore.transforms.transform import CropTransform, HFlipTransform, NoOpTransform, Transform
from PIL import Image

try:
Expand Down Expand Up @@ -106,7 +106,7 @@ class RotationTransform(Transform):
number of degrees counter clockwise around its center.
"""

def __init__(self, h, w, angle, expand=True, center=None, interp=None):
def __init__(self, h, w, angle, expand=True, center=None, interp=None, crop=False):
"""
Args:
h, w (int): original image size
Expand All @@ -117,6 +117,9 @@ def __init__(self, h, w, angle, expand=True, center=None, interp=None):
if left to None, the center will be fit to the center of each image
center has no effect if expand=True because it only affects shifting
interp: cv2 interpolation method, default cv2.INTER_LINEAR
crop (bool): crop rotated image to the largest possible
axis-aligned rectangle with maximal area within the rotated image.
Enabling this will remove empty borders after the rotation.
"""
super().__init__()
image_center = np.array((w / 2, h / 2))
Expand All @@ -138,6 +141,11 @@ def __init__(self, h, w, angle, expand=True, center=None, interp=None):
# Needed because of this problem https://github.com/opencv/opencv/issues/11784
self.rm_image = self.create_rotation_matrix(offset=-0.5)

if self.crop:
self.crop_transform = self.create_crop_transform()
else:
self.crop_transform = NoOpTransform()

def apply_image(self, img, interp=None):
"""
img should be a numpy array, formatted as Height * Width * Nchannels
Expand All @@ -146,7 +154,10 @@ def apply_image(self, img, interp=None):
return img
assert img.shape[:2] == (self.h, self.w)
interp = interp if interp is not None else self.interp
return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)
rotated_image = cv2.warpAffine(
img, self.rm_image, (self.bound_w, self.bound_h), flags=interp
)
return self.crop_transform.apply_image(rotated_image)

def apply_coords(self, coords):
"""
Expand All @@ -155,7 +166,8 @@ def apply_coords(self, coords):
if len(coords) == 0:
return coords
coords = np.asarray(coords, dtype=float)
return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]
rotated_coords = cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]
return self.crop_transform.apply_coords(rotated_coords)

def apply_segmentation(self, segmentation):
segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
Expand All @@ -173,6 +185,54 @@ def create_rotation_matrix(self, offset=0):
rm[:, 2] += new_center
return rm

def create_crop_transform(self):
    """
    Create a CropTransform for the largest axis-aligned rectangle
    (maximal area) that fits inside the rotated image content.

    Reads ``self.w`` / ``self.h`` (original image size), ``self.angle``
    (degrees) and ``self.bound_w`` / ``self.bound_h`` (the size of the array
    produced by ``apply_image``).

    Returns:
        CropTransform: crop centered on the rotated output image.
    """

    def _find_max_area_rect(width, height, angle):
        """
        Given a rectangle of size `width` x `height` that has been rotated by `angle` (in
        radians), computes the width and height of the largest possible
        axis-aligned rectangle (maximal area) within the rotated rectangle.
        See: https://stackoverflow.com/a/16770343/1336014
        """
        if width <= 0 or height <= 0:
            return 0, 0

        width_is_longer = width >= height
        side_long, side_short = (width, height) if width_is_longer else (height, width)

        # The solution is the same for angle, -angle and 180 - angle, so it
        # suffices to work with the absolute values of sin and cos.
        sin_a, cos_a = abs(np.sin(angle)), abs(np.cos(angle))
        if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
            # Half-constrained case: two crop corners touch the longer sides,
            # the other two lie on the mid-line parallel to the longer sides.
            x = 0.5 * side_short
            wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
        else:
            # Fully constrained case: the crop touches all four sides.
            cos_2a = cos_a * cos_a - sin_a * sin_a
            wr = (width * cos_a - height * sin_a) / cos_2a
            hr = (height * cos_a - width * sin_a) / cos_2a
        return wr, hr

    cropped_w, cropped_h = _find_max_area_rect(self.w, self.h, np.radians(self.angle))

    # The crop is applied to the output of apply_image, whose size is
    # (bound_h, bound_w); this equals (h, w) unless the canvas was expanded.
    out_w, out_h = self.bound_w, self.bound_h

    # clip the crop size to the output image size
    cropped_w = min(cropped_w, out_w)
    cropped_h = min(cropped_h, out_h)

    # create crop transformation around the center of the output image
    # NOTE(review): assumes the rotated content is centered in the output,
    # i.e. no custom `center` combined with expand=False - confirm.
    # NOTE: int() truncates, so a size that lands a hair below an integer
    # because of floating-point error (e.g. at 90 degrees) loses one pixel.
    x0 = int(out_w * 0.5 - cropped_w * 0.5)
    y0 = int(out_h * 0.5 - cropped_h * 0.5)
    return CropTransform(x0, y0, int(cropped_w), int(cropped_h))


def HFlip_rotated_box(transform, rotated_boxes):
"""
Expand Down
11 changes: 9 additions & 2 deletions detectron2/data/transforms/transform_gen.py
Expand Up @@ -220,7 +220,9 @@ class RandomRotation(TransformGen):
number of degrees counter clockwise around the given center.
"""

def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
def __init__(
self, angle, expand=True, center=None, sample_style="range", interp=None, crop=False
):
"""
Args:
angle (list[float]): If ``sample_style=="range"``,
Expand All @@ -234,6 +236,9 @@ def __init__(self, angle, expand=True, center=None, sample_style="range", interp
If ``sample_style=="choice"``, a list of centers to sample from
Default: None, which means that the center of rotation is the center of the image
center has no effect if expand=True because it only affects shifting
crop (bool): crop rotated image to the largest possible
axis-aligned rectangle with maximal area within the rotated image.
Enabling this will remove empty borders after the rotation.
"""
super().__init__()
assert sample_style in ["range", "choice"], sample_style
Expand Down Expand Up @@ -262,7 +267,9 @@ def get_transform(self, img):
if center is not None:
center = (w * center[0], h * center[1]) # Convert to absolute coordinates

return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
return RotationTransform(
h, w, angle, expand=self.expand, center=center, interp=self.interp, crop=self.crop
)


class RandomCrop(TransformGen):
Expand Down
13 changes: 13 additions & 0 deletions tests/data/test_rotation_transform.py
Expand Up @@ -56,6 +56,19 @@ def test_center_expand(self):
self.assertEqualsArrays(r1.apply_image(image), r2.apply_image(image))
self.assertEqualsArrays(r1.apply_coords(coords), r2.apply_coords(coords))

def test_crop90(self):
    """Rotate a 4x6 image by 90 degrees with crop=True and check the cropped result."""
    # `np.float` was only an alias of the builtin `float` and has been removed
    # from NumPy, so the builtin is used directly.
    image = np.array(
        [[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4]],
        dtype=float,
    )
    # there is no center pixel with 6 -> therefore 4 is completely cut
    expected_cropped_rotation_image = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
    rot = RotationTransform(
        image.shape[0], image.shape[1], 90, expand=False, center=None, crop=True
    )
    rotated_image = rot.apply_image(image)
    self.assertEqualsArrays(rotated_image, expected_cropped_rotation_image)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8b39590

Please sign in to comment.