Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set masks to zero where masks overlap #8213

Merged
merged 9 commits into from
Jan 19, 2024
20 changes: 10 additions & 10 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,11 @@ def test_draw_segmentation_masks(colors, alpha, device):
num_masks, h, w = 2, 100, 100
dtype = torch.uint8
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device)
masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device)
masks[0, 10:20, 10:20] = True
masks[1, 15:25, 15:25] = True

# For testing we enforce that there's no overlap between the masks. The
# current behaviour is that the last mask's color will take priority when
# masks overlap, but this makes testing slightly harder, so we don't really
# care
overlap = masks[0] & masks[1]
masks[:, overlap] = False

out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
assert out.dtype == dtype
Expand All @@ -239,12 +236,15 @@ def test_draw_segmentation_masks(colors, alpha, device):
color = torch.tensor(color, dtype=dtype, device=device)

if alpha == 1:
assert (out[:, mask] == color[:, None]).all()
assert (out[:, mask & ~overlap] == color[:, None]).all()
elif alpha == 0:
assert (out[:, mask] == img[:, mask]).all()
assert (out[:, mask & ~overlap] == img[:, mask & ~overlap]).all()

interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
interpolated_color = (img[:, mask & ~overlap] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask & ~overlap], interpolated_color, rtol=0.0, atol=1.0)

interpolated_overlap = (img[:, overlap] * (1 - alpha)).to(dtype)
torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0)


def test_draw_segmentation_masks_dtypes():
Expand Down
3 changes: 3 additions & 0 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def draw_segmentation_masks(
raise ValueError("The image and the masks must have the same height and width")

num_masks = masks.size()[0]
overlapping_masks = masks.sum(dim=0) > 1

if num_masks == 0:
warnings.warn("masks doesn't contain any mask. No mask was drawn")
Expand All @@ -315,6 +316,8 @@ def draw_segmentation_masks(
for mask, color in zip(masks, colors):
img_to_draw[:, mask] = color[:, None]

img_to_draw[:, overlapping_masks] = 0

out = image * (1 - alpha) + img_to_draw * alpha
# Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
return out.to(original_dtype)
Expand Down