In [4]:
import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [2]:
class PennFudanDataset(Dataset):
    """Penn-Fudan instance-segmentation dataset in torchvision detection format.

    Each mask file is an image whose pixel values are instance ids: 0 is
    background and every other distinct value marks one object instance.

    ``dataset[idx]`` returns ``(image, target)`` where ``image`` is the RGB
    ``PIL`` image (or whatever ``transforms`` returns) and ``target`` is a dict:

    * ``boxes``    -- float32 ``(N, 4)``, ``[x_min, y_min, x_max, y_max]`` per instance
    * ``labels``   -- int64 ``(N,)``, all ones (single foreground class)
    * ``masks``    -- uint8 ``(N, H, W)``, one binary mask per instance
    * ``image_id`` -- int64 ``(1,)``, the sample index
    * ``area``     -- float32 ``(N,)``, box height * width
    * ``is_crowd`` / ``iscrowd`` -- int64 ``(N,)``, all zeros

    ``N`` may be 0 for a mask with no instances; the tensors then keep their
    trailing dimensions (e.g. ``boxes`` is ``(0, 4)``).

    Args:
        root: Dataset root directory.
        images_path: Sub-directory of ``root`` containing the images.
        masks_path: Sub-directory of ``root`` containing the instance masks.
        transforms: Callable ``(image, target) -> (image, target)`` applied to
            every sample, or ``None`` to return samples untransformed.
    """

    def __init__(self, root, images_path, masks_path, transforms):
        super().__init__()
        self.images_path = os.path.join(root, images_path)
        self.masks_path = os.path.join(root, masks_path)
        self.transforms = transforms
        # Images and masks are paired purely by position after sorting, so the
        # two directories must hold one file per sample with names that sort
        # into the same order (no stray extra files).
        self.images = sorted(os.listdir(self.images_path))
        self.masks = sorted(os.listdir(self.masks_path))

    @staticmethod
    def _masks_to_boxes(masks):
        """Return a float32 ``(N, 4)`` array of ``[x_min, y_min, x_max, y_max]`` for N binary masks."""
        boxes = np.zeros((len(masks), 4), dtype=np.float32)
        for i, instance_mask in enumerate(masks):
            # Every instance mask is non-empty: its id came from np.unique(mask).
            ys, xs = np.nonzero(instance_mask)
            boxes[i] = (xs.min(), ys.min(), xs.max(), ys.max())
        return boxes

    def __getitem__(self, idx):
        """Load sample ``idx`` and build its detection target.

        Returns:
            ``(image, target)``; see the class docstring for the keys of ``target``.
        """
        image_path = os.path.join(self.images_path, self.images[idx])
        with Image.open(image_path) as image_file:
            # convert() returns a loaded copy that stays valid after the file closes.
            image = image_file.convert('RGB')

        mask_path = os.path.join(self.masks_path, self.masks[idx])
        # The mask is deliberately NOT converted to RGB: each pixel value is an
        # instance id, with 0 being background.
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file)

        # Remove background by value rather than assuming it is the first id,
        # so a mask without any background pixels does not lose an instance.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        num_objs = len(obj_ids)

        # Split the id-encoded mask into one boolean mask per instance.
        # Assumes a single-channel (H, W) mask, giving (num_objs, H, W).
        masks = mask == obj_ids[:, None, None]

        # Built as (num_objs, 4) so the column slicing below also works when
        # there are no instances.
        boxes = torch.as_tensor(self._masks_to_boxes(masks), dtype=torch.float32)
        # There is only one foreground class, so every instance gets label 1.
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        # Box area = height * width (both factors parenthesised).
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Treat every instance as a non-crowd annotation.
        is_crowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "is_crowd": is_crowd,
            # torchvision's detection reference scripts (COCO evaluation helpers)
            # read the COCO-style key "iscrowd"; "is_crowd" is kept for callers
            # that already use it.
            "iscrowd": is_crowd,
        }

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        """Number of samples, i.e. the number of image files found."""
        return len(self.images)

In [7]:
# Load a Faster R-CNN (ResNet-50 FPN) detector pre-trained on COCO.
# torchvision >= 0.13 selects pre-trained weights through the `weights` enum and
# deprecates `pretrained=True`; fall back to the legacy flag on older releases.
try:
    from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
except ImportError:  # torchvision < 0.13
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
else:
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    )

# Replace the box classifier with a new one that has num_classes outputs.
num_classes = 2  # 1 class (person) + background
# Number of input features fed to the existing classifier.
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Swap the pre-trained head for a newly initialised one sized for num_classes.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)