From 8b0daac2bdd9ea2abb1631b83c40090a69755db5 Mon Sep 17 00:00:00 2001 From: Atharva Sehgal Date: Tue, 2 Jan 2024 16:39:41 -0800 Subject: [PATCH 1/2] Fixes #8184 --- torchvision/ops/boxes.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index a541f8d880a..11b2a06ba41 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -402,16 +402,15 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device, dtype=torch.float) - n = masks.shape[0] + non_zero_ys = torch.any(masks, axis=1).float() + non_zero_xs = torch.any(masks, axis=2).float() - bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float) + y1 = non_zero_ys.argmax(dim=1) + x1 = non_zero_xs.argmax(dim=1) - for index, mask in enumerate(masks): - y, x = torch.where(mask != 0) + y2 = (masks.shape[1] - 1) - non_zero_ys.flip(dims=[1]).argmax(dim=1) + x2 = (masks.shape[2] - 1) - non_zero_xs.flip(dims=[1]).argmax(dim=1) - bounding_boxes[index, 0] = torch.min(x) - bounding_boxes[index, 1] = torch.min(y) - bounding_boxes[index, 2] = torch.max(x) - bounding_boxes[index, 3] = torch.max(y) + bounding_boxes = torch.stack((x1, y1, x2, y2), dim=1).float() return bounding_boxes From a4e4278dd7df48f8a098e43dccaaed167b59ed0b Mon Sep 17 00:00:00 2001 From: Atharva Sehgal Date: Mon, 15 Jan 2024 16:52:45 +0000 Subject: [PATCH 2/2] Update to #8194 The X and Y dimensions were flipped. --- torchvision/ops/boxes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 11b2a06ba41..132bce925cf 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -402,8 +402,8 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device, dtype=torch.float) - non_zero_ys = torch.any(masks, axis=1).float() - non_zero_xs = torch.any(masks, axis=2).float() + non_zero_xs = torch.any(masks, axis=1).float() + non_zero_ys = torch.any(masks, axis=2).float() y1 = non_zero_ys.argmax(dim=1) x1 = non_zero_xs.argmax(dim=1)