Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set masks to zero where masks overlap #8213

Merged
merged 9 commits into from
Jan 19, 2024
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ We don't officially support building from source using `pip`, but _if_ you do, y
#### Other development dependencies (some of these are needed to run tests):

```
pip install expecttest flake8 typing mypy pytest pytest-mock scipy
pip install expecttest flake8 typing mypy pytest pytest-mock scipy pillow
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That shouldn't be needed here as it's already a hard-dependency of torchvision, so it should get installed when you install torchvision from source.

Curious if you encounter any issue here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm odd. I remember having to install it manually after building from source, as I got a package not installed error from the tests. I'm not familiar with the internals of python setup.py develop, but I can say that I've used pip install -e . instead before. There is some discussion on SO about dependencies not being handled correctly with the former, but I'm not sure how much weight to give that.

Happy to remove this though, I may have just done something odd / could be a personal configuration issue

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the info - let me just remove this so we can move forward with this PR and if we're getting more reports about this I'll investigate. Thank you!

NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
```

## Development Process
Expand Down
20 changes: 10 additions & 10 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,11 @@ def test_draw_segmentation_masks(colors, alpha, device):
num_masks, h, w = 2, 100, 100
dtype = torch.uint8
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device)
masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device)
masks[0, 10:20, 10:20] = True
masks[1, 15:25, 15:25] = True

# For testing we enforce that there's no overlap between the masks. The
# current behaviour is that the last mask's color will take priority when
# masks overlap, but this makes testing slightly harder, so we don't really
# care
overlap = masks[0] & masks[1]
masks[:, overlap] = False

out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
assert out.dtype == dtype
Expand All @@ -239,12 +236,15 @@ def test_draw_segmentation_masks(colors, alpha, device):
color = torch.tensor(color, dtype=dtype, device=device)

if alpha == 1:
assert (out[:, mask] == color[:, None]).all()
assert (out[:, mask & ~overlap] == color[:, None]).all()
elif alpha == 0:
assert (out[:, mask] == img[:, mask]).all()
assert (out[:, mask & ~overlap] == img[:, mask & ~overlap]).all()

interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
interpolated_color = (img[:, mask & ~overlap] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask & ~overlap], interpolated_color, rtol=0.0, atol=1.0)

interpolated_overlap = (img[:, overlap] * (1 - alpha)).to(dtype)
torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0)


def test_draw_segmentation_masks_dtypes():
Expand Down
3 changes: 3 additions & 0 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def draw_segmentation_masks(
raise ValueError("The image and the masks must have the same height and width")

num_masks = masks.size()[0]
overlapping_masks = masks.sum(dim=0) > 1

if num_masks == 0:
warnings.warn("masks doesn't contain any mask. No mask was drawn")
Expand All @@ -315,6 +316,8 @@ def draw_segmentation_masks(
for mask, color in zip(masks, colors):
img_to_draw[:, mask] = color[:, None]

img_to_draw[:, overlapping_masks] = 0

out = image * (1 - alpha) + img_to_draw * alpha
# Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
return out.to(original_dtype)
Expand Down