In [1]:
import os
import numpy as np
import torch
from PIL import Image


class PennFudanDataset(object):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images ad masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance
        # with 0 being background
        mask = Image.open(mask_path)
        # convert the PIL Image into a numpy array
        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set
        # of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

In [2]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [3]:
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280

# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# OrderedDict[Tensor], and in featmap_names you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
                                                output_size=7,
                                                sampling_ratio=2)

# put the pieces together inside a FasterRCNN model
model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)

In [4]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model


In [5]:
import transforms as T

def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

In [6]:
from engine import train_one_epoch, evaluate
import utils

dataset = PennFudanDataset('data/PennFudanPed', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True, num_workers=4,
        collate_fn=utils.collate_fn)


In [7]:
num_classes = 2
model = get_model_instance_segmentation(num_classes)


In [8]:
inputes, labels = next(iter(data_loader))
print(inputes, labels)

(tensor([[[0.2000, 0.2000, 0.1882,  ..., 0.3255, 0.2941, 0.2588],
         [0.1922, 0.1922, 0.1843,  ..., 0.5294, 0.5176, 0.5020],
         [0.1922, 0.1961, 0.1922,  ..., 0.3647, 0.3843, 0.4039],
         ...,
         [0.4627, 0.4667, 0.4706,  ..., 0.6039, 0.5961, 0.5961],
         [0.4588, 0.4706, 0.4784,  ..., 0.6078, 0.6000, 0.6000],
         [0.4588, 0.4784, 0.4863,  ..., 0.6078, 0.6039, 0.6039]],

        [[0.2745, 0.2745, 0.2745,  ..., 0.3333, 0.3020, 0.2667],
         [0.2667, 0.2667, 0.2706,  ..., 0.5373, 0.5255, 0.5098],
         [0.2706, 0.2745, 0.2784,  ..., 0.3725, 0.3922, 0.4118],
         ...,
         [0.4863, 0.4902, 0.4941,  ..., 0.6118, 0.6039, 0.6039],
         [0.4824, 0.4941, 0.5020,  ..., 0.6157, 0.6078, 0.6078],
         [0.4824, 0.5020, 0.5098,  ..., 0.6157, 0.6118, 0.6118]],

        [[0.3294, 0.3294, 0.3255,  ..., 0.3137, 0.2824, 0.2471],
         [0.3216, 0.3216, 0.3216,  ..., 0.5176, 0.5059, 0.4902],
         [0.3137, 0.3176, 0.3294,  ..., 0.3529, 0.3725, 0

In [9]:
model.eval()
outputs = model(inputes)
print(outputs)

[{'boxes': tensor([[275.9577, 316.6336, 318.5106, 383.9306],
        [271.9060, 323.2332, 351.9729, 379.2775],
        [281.3891, 372.3487, 316.6173, 398.9269],
        [ 41.5713, 286.3236,  56.1606, 349.3644],
        [ 44.2435, 287.9150,  67.2770, 302.6497],
        [ 39.0848, 289.4987,  76.4029, 317.0681],
        [ 13.5428, 337.5920,  40.9716, 347.9232],
        [  4.4524, 318.2136,  26.1295, 345.2112],
        [  6.8647, 290.4635,  74.1044, 327.2486],
        [ 41.7460, 281.9311,  74.6048, 305.9201],
        [ 25.0568, 283.6998,  46.8219, 312.0411],
        [ 25.9458, 290.0471,  33.2346, 306.3122],
        [ 27.6217, 291.9919,  49.2796, 333.6548],
        [ 43.4185, 285.3246,  65.9154, 294.5755],
        [ 27.3065, 291.2525,  40.2406, 307.3762],
        [ 37.6051, 284.9370,  73.6713, 358.9913],
        [  7.1955, 334.7949,  37.5208, 343.9569],
        [ 27.6335, 285.6236, 100.0955, 315.7202],
        [ 45.4745, 287.9875,  64.7737, 331.1579],
        [ 26.7403, 282.2476,  64.5446, 