Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions references/detection/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,78 @@ def forward(
)

return image, target


class FixedSizeCrop(nn.Module):
def __init__(self, size, fill=0, padding_mode="constant"):
super().__init__()
size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
self.padding_mode = padding_mode

def _pad(self, img, target, padding):
# Taken from the functional_tensor.py pad
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]

padding = [pad_left, pad_top, pad_right, pad_bottom]
img = F.pad(img, padding, self.fill, self.padding_mode)
if target is not None:
target["boxes"][:, 0::2] += pad_left
target["boxes"][:, 1::2] += pad_top
if "masks" in target:
target["masks"] = F.pad(target["masks"], padding, 0, "constant")

return img, target

def _crop(self, img, target, top, left, height, width):
img = F.crop(img, top, left, height, width)
if target is not None:
boxes = target["boxes"]
boxes[:, 0::2] -= left
boxes[:, 1::2] -= top
boxes[:, 0::2].clamp_(min=0, max=width)
boxes[:, 1::2].clamp_(min=0, max=height)

is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])

target["boxes"] = boxes[is_valid]
target["labels"] = target["labels"][is_valid]
if "masks" in target:
target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)

return img, target

def forward(self, img, target=None):
_, height, width = F.get_dimensions(img)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)

if new_height != height or new_width != width:
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)

r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)

img, target = self._crop(img, target, top, left, new_height, new_width)

pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
if pad_bottom != 0 or pad_right != 0:
img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])

return img, target