diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index 309990ea03a..dab6c9c6b75 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -418,16 +418,15 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
     if masks.numel() == 0:
         return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
 
-    n = masks.shape[0]
+    non_zero_xs = torch.any(masks, axis=1).float()
+    non_zero_ys = torch.any(masks, axis=2).float()
 
-    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
+    y1 = non_zero_ys.argmax(dim=1)
+    x1 = non_zero_xs.argmax(dim=1)
 
-    for index, mask in enumerate(masks):
-        y, x = torch.where(mask != 0)
+    y2 = (masks.shape[1] - 1) - non_zero_ys.flip(dims=[1]).argmax(dim=1)
+    x2 = (masks.shape[2] - 1) - non_zero_xs.flip(dims=[1]).argmax(dim=1)
 
-        bounding_boxes[index, 0] = torch.min(x)
-        bounding_boxes[index, 1] = torch.min(y)
-        bounding_boxes[index, 2] = torch.max(x)
-        bounding_boxes[index, 3] = torch.max(y)
+    bounding_boxes = torch.stack((x1, y1, x2, y2), dim=1).float()
 
     return bounding_boxes