# Mask R-CNN for Bin Picking

This notebook is adapted from the [TorchVision 0.3 Object Detection finetuning tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html). We will be using a pre-trained [Mask R-CNN](https://arxiv.org/abs/1703.06870) model that was finetuned on a dataset generated from our "clutter generator" script.


In [None]:
# Imports
import fnmatch
import json
import multiprocessing
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
from IPython.display import display
from PIL import Image

# SDF file names of the YCB objects. `num_classes` for the network is derived
# from the length of this list, and `ycb[label]` is used to name each detection
# when the predictions are plotted, so the order of the entries matters.
ycb = [
    "003_cracker_box.sdf",
    "004_sugar_box.sdf",
    "005_tomato_soup_can.sdf",
    "006_mustard_bottle.sdf",
    "009_gelatin_box.sdf",
    "010_potted_meat_can.sdf",
]

# Download our bin-picking model

And a small set of images for testing.

In [None]:
dataset_path = "clutter_maskrcnn_data"
if not os.path.exists(dataset_path):
    !wget https://groups.csail.mit.edu/locomotion/clutter_maskrcnn_test.zip .
    !unzip -q clutter_maskrcnn_test.zip

num_images = len(fnmatch.filter(os.listdir(dataset_path), "*.png"))


def open_image(idx):
    """Load image number `idx` from the dataset directory as an RGB PIL image.

    The file is looked up as `<dataset_path>/<idx zero-padded to 5 digits>.png`.
    """
    image_path = os.path.join(dataset_path, "{:05d}.png".format(idx))
    rgb_image = Image.open(image_path).convert("RGB")
    return rgb_image


model_file = "clutter_maskrcnn_model.pt"
if not os.path.exists(model_file):
    !wget https://groups.csail.mit.edu/locomotion/clutter_maskrcnn_model.pt .

# Load the model

In [None]:
import torchvision
import torchvision.transforms.functional as Tf
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_instance_segmentation_model(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict `num_classes` classes.

    The network starts from the COCO pre-trained weights; its box predictor and
    mask predictor are then replaced by freshly constructed heads sized for
    `num_classes`.

    Args:
        num_classes: number of output classes for both replacement heads.

    Returns:
        The modified torchvision Mask R-CNN model.
    """
    # Instance segmentation model pre-trained on COCO.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT
    )
    roi_heads = model.roi_heads

    # Box head: keep the classifier's input width, swap in a new predictor.
    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: same idea, with a 256-channel hidden layer.
    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    roi_heads.mask_predictor = MaskRCNNPredictor(
        mask_in_channels, mask_hidden_channels, num_classes
    )

    return model


# One class per entry in `ycb` plus one extra class (presumably the background
# class that torchvision detection models reserve -- TODO confirm).
num_classes = len(ycb) + 1
model = get_instance_segmentation_model(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reuse `model_file` (the checkpoint name defined in the download cell) so the
# file name is only spelled out in one place.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints obtained from a trusted source.
model.load_state_dict(torch.load(model_file, map_location=device))
# Inference mode: the network returns predictions instead of training losses.
model.eval()

model.to(device)

# Evaluate the network

In [None]:
# pick one image from the test set (choose between 9950 and 9999)
img = open_image(9952)

with torch.no_grad():
    prediction = model([Tf.to_tensor(img).to(device)])

Printing the prediction shows that we have a list of dictionaries. Each element
of the list corresponds to a different image; since we have a single image,
there is a single dictionary in the list. The dictionary contains the
predictions for the image we passed. In this case, we can see that it contains
`boxes`, `labels`, `masks` and `scores` as fields.

In [None]:
# Show the raw network output for the single image that was passed in.
prediction

Let's inspect the image and the predicted segmentation masks.

The tensor we passed to the network was rescaled to 0-1 and had its channels flipped into `[C, H, W]` format, but `img` itself is still the original PIL image, so we can display it directly.

In [None]:
# `img` is the RGB PIL image returned by open_image; show it as the cell output.
img

And let's now visualize all of the predicted segmentation masks, one per row. The masks are predicted as `[N, 1, H, W]`, where `N` is the number of predictions, and are probability maps between 0-1.

In [None]:
# Draw every predicted mask in its own row.
masks = prediction[0]["masks"]
N = masks.shape[0]
if N == 0:
    # plt.subplots cannot create a grid with zero rows.
    print("No instances were detected, so there are no masks to show.")
else:
    # squeeze=False keeps `ax` a 2-D array even when N == 1; by default a
    # single subplot comes back as a bare Axes that cannot be indexed.
    fig, ax = plt.subplots(N, 1, figsize=(15, 15), squeeze=False)
    for n in range(N):
        # Scale the 0-1 mask values to 0-255 and convert to uint8 for display.
        ax[n, 0].imshow(masks[n, 0].mul(255).byte().cpu().numpy())

# Plot the object detections

In [None]:
import random

import matplotlib.patches as patches


def plot_prediction():
    """Draw the current detections on top of the current image.

    Reads the notebook-level `img` and `prediction` variables at call time.
    Every detection in `prediction[0]` is drawn as a colored bounding box with
    its class name (looked up in `ycb`) placed at the box's top-left corner.
    """
    img_np = np.array(img)
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img_np)

    cmap = plt.get_cmap("tab20b")
    colors = [cmap(i) for i in np.linspace(0, 1, 20)]

    boxes = prediction[0]["boxes"].cpu().numpy()
    labels = prediction[0]["labels"].cpu().numpy()
    num_instances = boxes.shape[0]

    # random.sample raises ValueError when asked for more items than the
    # palette holds, so fall back to sampling with replacement in that case.
    if num_instances <= len(colors):
        bbox_colors = random.sample(colors, num_instances)
    else:
        bbox_colors = random.choices(colors, k=num_instances)

    for i in range(num_instances):
        color = bbox_colors[i]
        # Each box is used as (x_min, y_min, x_max, y_max).
        bb = boxes[i, :]
        bbox = patches.Rectangle(
            (bb[0], bb[1]),
            bb[2] - bb[0],
            bb[3] - bb[1],
            linewidth=2,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(bbox)
        # Anchor the label at the box's top-left corner (x_min, y_min).
        ax.text(
            bb[0],
            bb[1],
            s=ycb[labels[i]],
            color="white",
            verticalalignment="top",
            bbox={"color": color, "pad": 0},
        )
    ax.axis("off")


plot_prediction()

# Visualize the region proposals 

Let's visualize some of the intermediate results of the networks.

TODO: would be very cool to put a slider on this so that we could slide through ALL of the boxes.  But my matplotlib non-interactive backend makes it too tricky!

In [None]:
class Inspector:
    """A helper class from Kuni to be used for torch.nn.Module.register_forward_hook.

    Pass the bound `hook` method as the forward hook; after the hooked module
    has run, its most recent output is available as the `x` attribute.
    """

    def __init__(self):
        """Start with nothing captured."""
        # Holds the last output seen by `hook`; None until the hook has fired.
        self.x = None

    def hook(self, module, input, output):
        """Forward-hook callback: remember `output`; `module` and `input` are unused."""
        self.x = output


inspector = Inspector()
# Keep the handle so the hook can be removed once the RPN output has been
# captured; otherwise every re-run of this cell would add another hook.
hook_handle = model.rpn.register_forward_hook(inspector.hook)
try:
    with torch.no_grad():
        prediction = model([Tf.to_tensor(img).to(device)])
finally:
    hook_handle.remove()

rpn_values = inspector.x

img_np = np.array(img)
fig, ax = plt.subplots(1, figsize=(12, 9))
ax.imshow(img_np)

cmap = plt.get_cmap("tab20b")
colors = [cmap(i) for i in np.linspace(0, 1, 20)]

# First element of the RPN output, then the entry for our single input image
# (presumably the per-image list of proposal boxes -- TODO confirm against the
# torchvision RegionProposalNetwork documentation).
boxes = rpn_values[0][0].cpu().numpy()
# Never try to draw more boxes than there are proposals (or palette colors).
num_to_draw = min(20, boxes.shape[0])
bbox_colors = random.sample(colors, num_to_draw)
print(f"Region proposals (drawing first {num_to_draw} out of {boxes.shape[0]})")

for i in range(num_to_draw):
    color = bbox_colors[i]
    # Each box is used as (x_min, y_min, x_max, y_max).
    bb = boxes[i, :]
    bbox = patches.Rectangle(
        (bb[0], bb[1]),
        bb[2] - bb[0],
        bb[3] - bb[1],
        linewidth=2,
        edgecolor=color,
        facecolor="none",
    )
    ax.add_patch(bbox)
ax.axis("off");

# Try a few more images

In [None]:
# pick one image from the test set (choose between 9950 and 9999)
img = open_image(9985)

with torch.no_grad():
    prediction = model([Tf.to_tensor(img).to(device)])

plot_prediction()