
Added utility to draw segmentation masks #3330

Merged — 38 commits merged into pytorch:master on Mar 22, 2021
Conversation

@oke-aditya (Contributor) commented Jan 30, 2021:

Closes #3272.

Initial Implementation

  • Adds Code
  • Adds Tests
  • Adds Docs

@oke-aditya oke-aditya marked this pull request as ready for review February 2, 2021 17:53
torchvision/utils.py (outdated review thread):
img_to_draw = torch.from_numpy(np.array(img_to_draw))

# Project the drawn image to the original one
image[: 1] = img_to_draw
Contributor Author (@oke-aditya):

I need some help here, as this projects onto a black background.

My guess is that we need an alpha channel, which would make the masks transparent?

Contributor:

I suppose there are many ways to do it. One that I had in mind originally (which might not be optimal) is to convert img_to_draw from palette to RGBA, replace the background colour with transparent, and then combine it with image to achieve the "projection". It's worth experimenting with the approach because it's likely there is a better way.
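A minimal sketch of that idea (the helper name project_masks_rgba is hypothetical; it assumes image is a uint8 (3, H, W) tensor and img_to_draw is the PIL palette image, with black as the background colour):

import numpy as np
import torch
from PIL import Image

def project_masks_rgba(image: torch.Tensor, img_to_draw: Image.Image, alpha: float = 0.6) -> torch.Tensor:
    # Convert the palette image to RGBA so the background can be made transparent.
    rgba = np.array(img_to_draw.convert("RGBA"))
    # Treat black pixels (the background colour of the palette) as fully transparent.
    background = (rgba[..., :3] == 0).all(axis=-1)
    rgba[background, 3] = 0

    overlay = torch.from_numpy(rgba).permute(2, 0, 1).float()
    weight = (overlay[3:4] / 255.0) * alpha      # per-pixel blend weight from the alpha channel
    # Combine ("project") the drawn masks onto the original image.
    out = image.float() * (1 - weight) + overlay[:3] * weight
    return out.to(torch.uint8)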

Contributor Author (@oke-aditya):

Yep, DETR has very nice visualizations, but they use matplotlib, and I'm unsure how to reproduce them.
As you pointed out before, the Mask R-CNN utils have a nice way to apply masks too.

Contributor:

I just sent it for reference, not necessarily for reproduction. :)

@datumbox (Contributor) left a comment:

@oke-aditya It's going in the right direction. I left a few comments to get your thoughts.

BTW, @fmassa brought to my attention a nice guide they have for DETR, with some visualization utils we might want to look into for inspiration.


@oke-aditya (Contributor, author):

Hi @datumbox, I just resolved the num_classes hardcoding by changing the parameters to probabilities.

@oke-aditya (Contributor, author) commented Feb 4, 2021:

I'm a bit unsure how to proceed now. One way, I think, is to create another util that applies a weighted mask, just like the Mask R-CNN Matterport benchmark.

We can probably use this and apply the mask here.

I'm unsure about the tests; it might be too lengthy and error-prone to write out a 20 x 20 x 3 tensor of probabilities by hand. I thought of just running FCN on torch.full and using its outputs (see the sketch below).

Let me know how to proceed! This part seems tricky.

This does double-check things, but it couples the tests with the models.
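For illustration, a sketch of that idea (the model choice, num_classes, and input size here are assumptions, not the final test code):

import torch
from torchvision.models.segmentation import fcn_resnet50

# Run FCN on a constant image to get realistic per-class probabilities for a test.
model = fcn_resnet50(pretrained=False, num_classes=21).eval()
img = torch.full((1, 3, 20, 20), 0.5)
with torch.no_grad():
    masks = model(img)["out"][0].softmax(0)   # (21, 20, 20) per-class probabilities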

@datumbox (Contributor) commented Feb 4, 2021:

@oke-aditya Great changes, I think we are almost there!

As I commented above, there are multiple ways to do this. Here is a hacky, quick-and-dirty approach to give you an idea. I'm sure you can do it in a much better way:

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageColor


@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")

    # Pick the most likely class per pixel and draw it as a palette image.
    classifications = masks.argmax(0).byte()
    img_to_draw = Image.fromarray(classifications.cpu().numpy())

    if colors is None:
        # Generate a deterministic RGBA palette, one colour per class.
        num_classes = masks.size(0)
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1, 0])
        colors_t = torch.as_tensor([i for i in range(num_classes)])[:, None] * palette
        color_arr = (colors_t % 255).numpy().astype("uint8")
        color_arr[1:, 3] = 255
    else:
        color_list = []
        for color in colors:
            if isinstance(color, str):
                fill_color = ImageColor.getrgb(color)
                color_list.append(fill_color)
            elif isinstance(color, tuple):
                color_list.append(color)

        color_arr = np.array(color_list).astype("uint8")

    img_to_draw.putpalette(color_arr)

    # Convert the palette image to RGBA and back to a CHW tensor.
    img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGBA')))
    img_to_draw = img_to_draw.permute((2, 0, 1))

    # Alpha-blend the drawn masks with the original image (extended to RGBA).
    alpha = 0.6
    image_rgba = torch.cat([image, torch.full(image.shape[1:], 255, dtype=torch.uint8).unsqueeze(0)])
    return (image_rgba.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)

Ping me when you have a demo you are comfortable with to discuss the last details of the API. :)

@oke-aditya (Contributor, author):

Sorry for the delay, @datumbox. Here are a few outputs for different values of alpha:

Alpha = 0.2: [image: draw_masks_util2]
Alpha = 0.3: [image: draw_masks_util3]
Alpha = 0.6: [image: draw_masks_util4]
Alpha = 0.7: [image: draw_masks_util5]

Also, I made alpha a parameter, as it is really useful for removing or keeping the background.

Another thought I had was to add an apply_mask util which can be used to project the mask.
This might be useful in other frequently used alpha-blending and alpha-masking cases.
So let me know!

def apply_mask(image, mask, alpha):
    # Blend a (4, H, W) RGBA mask tensor onto a (3, H, W) uint8 image.
    image_rgba = torch.cat([image, torch.full(image.shape[1:], 255, dtype=torch.uint8).unsqueeze(0)])
    return (image_rgba.float() * alpha + mask.float() * (1.0 - alpha)).to(dtype=torch.uint8)
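For example, a hypothetical call inside the drawing utility (assuming img_to_draw is the RGBA mask tensor built earlier):

# image: (3, H, W) uint8 original; img_to_draw: (4, H, W) RGBA overlay.
result = apply_mask(image, img_to_draw, alpha=0.6)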

@fmassa (Member) left a comment:

Hi,

The PR looks very nice, thanks a lot!

I have one suggestion which I think would make the function more generic and wouldn't involve too many changes. Let me know what you think.

Comment on lines 244 to 245
num_classes = masks.size()[0]
masks = masks.argmax(0)
Member:

Instead of being specific to semantic segmentation and not supporting instance or panoptic segmentation, I think we could make this slightly more generic while supporting all the use cases I mentioned. The idea would be to accept a mask as a [num_masks, H, W] boolean Tensor.
This way, the user can get the semantic segmentation masks to pass to this function as follows:

out.argmax(0) == torch.arange(out.shape[0])[:, None, None]
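For instance, a minimal sketch of producing such boolean masks from a torchvision segmentation model (the model choice and the random input are illustrative assumptions):

import torch
from torchvision.models.segmentation import fcn_resnet50

model = fcn_resnet50(pretrained=True).eval()
batch = torch.rand(1, 3, 224, 224)            # stand-in for a normalized input image
with torch.no_grad():
    out = model(batch)["out"][0]              # (num_classes, H, W) logits for one image

# One boolean (H, W) mask per class: True where that class wins the argmax.
bool_masks = out.argmax(0) == torch.arange(out.shape[0])[:, None, None]
print(bool_masks.shape, bool_masks.dtype)     # torch.Size([21, 224, 224]) torch.bool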

masks (Tensor): Tensor of shape (num_classes, H, W). Each value contains the probability of the predicted class.
alpha (float): Float number between 0 and 1 denoting the factor of transparency of the masks.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of the masks. The colors can
be represented as `str` or `Tuple[int, int, int]`.
Member:

If we add support for instance segmentation and panoptic segmentation, I think it would be a good idea to add an example using the output of a semantic segmentation model and an instance segmentation model (for example, from those in torchvision).

Contributor Author (@oke-aditya):

Sure, this is a TODO. I will add a GitHub gist and other minor documentation improvements for both utilities.

Can I do this in a follow-up PR, which will address all the issues mentioned in #3364?

@fmassa (Member) commented Feb 10, 2021:

Sure, adding documentation improvements in a follow-up PR is OK with me. What do you think about the other comment as well? It would be a breaking change in functionality if we support it, so it's better to do it once (especially since the branch cut is happening very soon; if we merge this now it can get integrated into the release, in which case breaking backwards compatibility is more annoying).

Contributor Author (@oke-aditya):

Definitely, I will address the other comment ASAP in this PR 😄 I understand how bad a backwards-incompatible change would be.

@oke-aditya (Contributor, author) commented Feb 10, 2021:

Slight problem while implementing: it seems that argmax is not supported for bool tensors.

I am pushing my latest changes; can someone please have a look?

Simple code to reproduce the bug:

masks = torch.tensor([
    [
        [False, False, False, False, False],
        [True, True, True, True, True],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
    ],
    [
        [True, True, True, True, True],
        [False, False, False, False, False],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [False, False, False, False, False],
    ],
    [
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [True, True, True, True, True],
    ],
], dtype=torch.bool)

masks = masks.argmax(0)
print(masks)

RuntimeError: "argmax_cpu" not implemented for 'Bool'

Maybe there is a workaround? I guess this was the only change needed to support a bool Tensor instead of a float Tensor.

The error occurs when we try to get [H, W] from the [num_masks, H, W] tensor.

@fmassa (Member) commented Feb 11, 2021:

@oke-aditya Maybe I'm missing something, but if we pass a bool tensor, we don't need to compute the argmax anymore, because the independent masks have already been computed?

EDIT: Oh, I see, you don't perform a for loop over each one of the masks as of now. Computing the argmax could be done after casting the mask to float, for example, but note that in instance segmentation each pixel can be covered by multiple masks, so the argmax wouldn't be enough to handle those. But it can be a first approximation.

@oke-aditya (Contributor, author) commented Feb 11, 2021:

Computing the argmax could be done after casting the mask to float for example, but note that in instance segmentation each pixel can be covered by multiple masks, so the argmax wouldn't be enough to handle those. But it can be a first approximation

That's possible. But may I ask why we decided to pass a bool Tensor and not a float mask Tensor? Sorry if the question is slightly dumb.

Edit: My idea was to make the initial implementation compatible with single-channel masks, which would make it compatible with Mask R-CNN. I thought that in the same function we could handle the single-channel case differently.
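To illustrate the Mask R-CNN case, a rough sketch (the pretrained torchvision model and random input are stand-ins used purely for illustration):

import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn

model = maskrcnn_resnet50_fpn(pretrained=True).eval()
img = torch.rand(3, 300, 400)                 # stand-in for a real image
with torch.no_grad():
    prediction = model([img])[0]

# Mask R-CNN returns soft masks of shape (num_instances, 1, H, W); squeezing the
# channel dimension gives the (num_masks, H, W) layout discussed in this thread.
masks = prediction["masks"].squeeze(1)        # float probabilities per instance
bool_masks = masks > 0.5                      # optional binarization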

@oke-aditya (Contributor, author) commented Feb 16, 2021:

I think, with the release cut coming quite soon, we could wait and add this in the next release (I mean after 0.9.0)?

@oke-aditya (Contributor, author):

Hey @fmassa and @datumbox, any thoughts on how to proceed further with this? I'm willing to incorporate any changes requested 😃

@fmassa (Member) left a comment:

Sorry for the delay in reviewing.

I've made a comment to unblock you, but I think in the future we might need to change the implementation to let it handle overlapping masks (which is the whole point of the function accepting a tensor with a per-object map).

The input mask doesn't need to be boolean, by the way; it can be a floating-point tensor representing probabilities for the given instance.

But to unblock things for now, let's just make the small change I proposed, and we can improve this in a follow-up PR.

raise ValueError("Pass an RGB image. Other Image formats are not supported")

num_masks = masks.size()[0]
masks = masks.argmax(0)
Member:

In order to unblock, you can do masks.to(torch.int64).argmax(0), or cast it to float if you want.
This won't handle overlapping instances very well though, and we will need to remove this in the future and probably replace it with a for loop so that overlapping masks are taken into account.

Plus, by using for loops and letting the mask be floating-point if the user wants, we can allow the user to pass heatmaps (instead of only binary maps), which would be very nice.
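A rough sketch of that for-loop idea (a hypothetical helper, not the code merged in this PR):

import torch

def overlay_masks(image, masks, colors, alpha=0.6):
    # image: uint8 (3, H, W); masks: float (num_masks, H, W) probabilities in [0, 1];
    # colors: one (r, g, b) tuple per mask.
    out = image.float()
    for mask, color in zip(masks, colors):
        color_t = torch.tensor(color, dtype=torch.float32)[:, None, None]
        weight = alpha * mask.clamp(0, 1).unsqueeze(0)     # per-pixel blend weight
        # Composite each mask in turn, so overlapping instances all remain visible.
        out = out * (1 - weight) + color_t * weight
    return out.to(torch.uint8)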

@oke-aditya oke-aditya requested a review from fmassa March 19, 2021 16:49
@oke-aditya (Contributor, author):

Extremely sorry for the delay (my health let me down 😢).

I think a boolean tensor leads to some limitations, and we end up re-casting it to int/float anyway.

I refactored to use floating-point masks; each point represents the probability of a class.
Floating-point masks make more sense, as we can either take argmax() and plot the best masks,
take topk to plot the top masks, or, as you said, cover the overlapping case.

Currently, this code accepts a floating-point tensor of shape (num_masks, H, W).

Let me know what we need to do in further PRs / this PR.
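As a hypothetical end-to-end usage of the utility as described at this point in the thread (the alpha value and the simplified input normalization are assumptions for brevity):

import torch
from torchvision.utils import draw_segmentation_masks
from torchvision.models.segmentation import fcn_resnet50

model = fcn_resnet50(pretrained=True).eval()
image = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)   # stand-in uint8 image
with torch.no_grad():
    out = model(image.unsqueeze(0).float() / 255.0)["out"][0]

masks = out.softmax(0)                        # (num_masks, H, W) probabilities
result = draw_segmentation_masks(image, masks, alpha=0.6)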

@fmassa (Member) left a comment:

Thanks!

@fmassa fmassa merged commit 19ad0bb into pytorch:master Mar 22, 2021
@oke-aditya oke-aditya deleted the add_msks branch March 22, 2021 18:13
facebook-github-bot pushed a commit that referenced this pull request Apr 1, 2021
Summary:
* add draw segm masks

* rewrites with new api

* fix flaky colors

* fix resize bug

* resize for sanity

* cleanup

* project the image

* Minor refactor to adopt num classes

* add uint8 in docstring

* adds alpha and docstring

* move code a bit down

* Minor fix

* fix type check

* Fixing resize bug.

* Fix type of alpha.

* Remove unnecessary RGBA conversions.

* update docs to supported only rgb

* minor edits

* adds tests

* shifts masks up

* change tests and implementation for bool

* change mode to L

* convert to float

* fixes docs

Reviewed By: fmassa

Differential Revision: D27433933

fbshipit-source-id: 26e72b4f8471218631b26cc555422890b0f6b81d

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

Successfully merging this pull request may close: Utility to draw Semantic Segmentation Masks.
5 participants