In [2]:
import os
import torch
import torch.utils.data
import torchvision
from PIL import Image, ImageDraw
from pycocotools.coco import COCO
import numpy as np

class CocoDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(image, target)`` pairs from a COCO annotation file.

    Every annotation of an image becomes exactly one instance: one box, one
    label, one mask, one area and one ``iscrowd`` entry, all sharing the same
    index along the first dimension of their tensors.

    Only polygon segmentations (a list of flat ``[x0, y0, x1, y1, ...]``
    coordinate lists) are rasterized. All instances are labelled ``1`` and
    marked as non-crowd.
    """

    def __init__(self, root, annotation, transforms=None):
        """
        Args:
            root: directory that the ``file_name`` of each image is joined onto.
            annotation: annotation file handed to ``pycocotools.coco.COCO``.
            transforms: optional callable ``(img, target) -> (img, target)``
                applied at the end of ``__getitem__``.
        """
        self.root = root
        self.transforms = transforms
        self.coco = COCO(annotation)
        # Sorted so that a given index always maps to the same image id.
        self.ids = list(sorted(self.coco.imgs.keys()))

    @staticmethod
    def _rasterize_annotation(annotation, width, height):
        """Draw all polygons of one annotation onto a single ``(height, width)`` uint8 mask."""
        # 8-bit canvas so the array conversion yields plain 0/1 uint8 values.
        canvas = Image.new('L', (width, height), 0)
        painter = ImageDraw.Draw(canvas)
        for polygon in annotation['segmentation']:
            painter.polygon(polygon, fill=1)
        return np.array(canvas, dtype=np.uint8)

    def __getitem__(self, index):
        """Load the image at ``index`` and build its detection/segmentation target.

        Returns:
            ``(img, target)`` where ``target`` is a dict with the keys
            ``boxes`` (N, 4) float32 in ``[xmin, ymin, xmax, ymax]`` form,
            ``labels`` (N,) int64, ``masks`` (N, H, W) uint8,
            ``image_id`` (1,), ``area`` (N,) float32 and ``iscrowd`` (N,) int64.
            N is the number of annotations and may be 0.
        """
        coco = self.coco
        img_id = self.ids[index]
        # all annotations that belong to this image id
        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        path = coco.loadImgs(img_id)[0]['file_name']
        # Force three channels so grayscale / RGBA / palette files all produce
        # the same image layout.
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        img_width, img_height = img.size

        num_objs = len(annotations)

        # One mask per annotation (NOT one per polygon): an object split into
        # several polygons must still line up with its single box and label.
        if num_objs:
            stacked = np.stack([
                self._rasterize_annotation(ann, img_width, img_height)
                for ann in annotations
            ])
            masks = torch.as_tensor(stacked, dtype=torch.uint8)
        else:
            masks = torch.zeros((0, img_height, img_width), dtype=torch.uint8)

        # COCO boxes are [x, y, width, height]; convert to corner format.
        boxes = [
            [x, y, x + w, y + h]
            for x, y, w, h in (ann['bbox'][:4] for ann in annotations)
        ]
        # The explicit reshape keeps the (0, 4) shape when there are no objects.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(num_objs, 4)
        labels = torch.ones((num_objs,), dtype=torch.int64)
        areas = torch.as_tensor([ann['area'] for ann in annotations],
                                dtype=torch.float32)
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        # target in dictionary format
        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([img_id]),
            "area": areas,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of images listed in the annotation file."""
        return len(self.ids)
    

In [None]:
# Download TorchVision repo to use some files from
# references/detection
#git clone https://github.com/pytorch/vision.git
#cd vision
#git checkout v0.3.0

#cp references/detection/utils.py ../
#cp references/detection/transforms.py ../
#cp references/detection/coco_eval.py ../
#cp references/detection/engine.py ../
#cp references/detection/coco_utils.py ../

#cd ..
#rm -rf vision

In [3]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T

# In my case, just added ToTensor
def get_transform():
    """Build the preprocessing pipeline applied to every ``(image, target)`` pair.

    The pipeline holds a single step, ``T.ToTensor()``, wrapped in ``T.Compose``.
    """
    # ToTensor turns the PIL image into a PyTorch tensor
    return T.Compose([T.ToTensor()])

### Load data 

In [4]:
# Directory containing the image files; passed to CocoDataset as `root`
train_data_dir = "<your custom images go here>"
# COCO annotation file describing those images; passed to CocoDataset as `annotation`
train_coco = "<your custom coco annotations go here>"

In [5]:
# Build the dataset from the image directory and annotation file configured above
full_dataset = CocoDataset(root=train_data_dir,
                          annotation=train_coco,
                          transforms=get_transform())

# Seed the global RNG before splitting: without it every kernel restart
# draws a different 80/20 partition, so images that were trained on in one
# run can end up in the test set of the next.
torch.manual_seed(0)

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
dataset, dataset_test = torch.utils.data.random_split(full_dataset, [train_size, test_size])


# define training and validation data loaders
# (utils.collate_fn comes from the torchvision detection reference scripts)
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


### Fine Tuning

In [6]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

      
def get_instance_segmentation_model(num_classes, hidden_layer=256, pretrained=True):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes`` classes.

    Args:
        num_classes: number of output classes, background included.
        hidden_layer: channel count of the hidden layer inside the new mask
            predictor. Defaults to 256, the value previously hard-coded.
        pretrained: forwarded to ``maskrcnn_resnet50_fpn``; ``True`` (the
            default) starts from the COCO pre-trained weights.

    Returns:
        The model with freshly initialised box and mask predictors.
    """
    # load an instance segmentation model (pre-trained on COCO by default)
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=pretrained)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

In [7]:
# Model, optimizer and learning-rate schedule used by the training loop.
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# two classes in this dataset: background and person
num_classes = 2

# build the network through the helper and place it on the chosen device
model = get_instance_segmentation_model(num_classes)
model.to(device)

# only parameters that require gradients are handed to the optimizer
params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# step schedule: multiply the learning rate by 0.1 every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

In [8]:
# Training driver: 10 passes over `data_loader`, each followed by a
# scheduler step and an evaluation pass over `data_loader_test`.
num_epochs = 10

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    # (train_one_epoch is imported from the torchvision reference `engine` module)
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate; called once per epoch, after the training pass
    lr_scheduler.step()
    # evaluate on the test dataset (the return value is not kept)
    evaluate(model, data_loader_test, device=device)

Epoch: [0]  [0/6]  eta: 0:00:05  lr: 0.001004  loss: 3.4752 (3.4752)  loss_classifier: 0.4713 (0.4713)  loss_box_reg: 0.0909 (0.0909)  loss_mask: 2.9056 (2.9056)  loss_objectness: 0.0051 (0.0051)  loss_rpn_box_reg: 0.0023 (0.0023)  time: 0.9062  data: 0.2868  max mem: 2138
Epoch: [0]  [5/6]  eta: 0:00:00  lr: 0.005000  loss: 1.1839 (2.0219)  loss_classifier: 0.1406 (0.2677)  loss_box_reg: 0.0868 (0.0901)  loss_mask: 0.9293 (1.6493)  loss_objectness: 0.0079 (0.0098)  loss_rpn_box_reg: 0.0030 (0.0051)  time: 0.5695  data: 0.0499  max mem: 2413
Epoch: [0] Total time: 0:00:03 (0.5741 s / it)
creating index...
index created!
Test:  [0/4]  eta: 0:00:01  model_time: 0.1123 (0.1123)  evaluator_time: 0.0043 (0.0043)  time: 0.2685  data: 0.1508  max mem: 2413
Test:  [3/4]  eta: 0:00:00  model_time: 0.1107 (0.1113)  evaluator_time: 0.0025 (0.0034)  time: 0.1543  data: 0.0389  max mem: 2413
Test: Total time: 0:00:00 (0.1633 s / it)
Averaged stats: model_time: 0.1107 (0.1113)  evaluator_time: 0.002

Epoch: [3]  [0/6]  eta: 0:00:04  lr: 0.000500  loss: 0.3636 (0.3636)  loss_classifier: 0.0416 (0.0416)  loss_box_reg: 0.0684 (0.0684)  loss_mask: 0.2462 (0.2462)  loss_objectness: 0.0024 (0.0024)  loss_rpn_box_reg: 0.0050 (0.0050)  time: 0.7916  data: 0.2846  max mem: 2413
Epoch: [3]  [5/6]  eta: 0:00:00  lr: 0.000500  loss: 0.2628 (0.2894)  loss_classifier: 0.0416 (0.0407)  loss_box_reg: 0.0589 (0.0623)  loss_mask: 0.1589 (0.1777)  loss_objectness: 0.0024 (0.0028)  loss_rpn_box_reg: 0.0050 (0.0059)  time: 0.5476  data: 0.0490  max mem: 2413
Epoch: [3] Total time: 0:00:03 (0.5522 s / it)
creating index...
index created!
Test:  [0/4]  eta: 0:00:01  model_time: 0.1274 (0.1274)  evaluator_time: 0.0132 (0.0132)  time: 0.2882  data: 0.1465  max mem: 2413
Test:  [3/4]  eta: 0:00:00  model_time: 0.1191 (0.1232)  evaluator_time: 0.0100 (0.0128)  time: 0.1739  data: 0.0369  max mem: 2413
Test: Total time: 0:00:00 (0.1808 s / it)
Averaged stats: model_time: 0.1191 (0.1232)  evaluator_time: 0.010

Epoch: [6]  [0/6]  eta: 0:00:04  lr: 0.000050  loss: 0.2379 (0.2379)  loss_classifier: 0.0365 (0.0365)  loss_box_reg: 0.0507 (0.0507)  loss_mask: 0.1434 (0.1434)  loss_objectness: 0.0042 (0.0042)  loss_rpn_box_reg: 0.0030 (0.0030)  time: 0.7387  data: 0.2417  max mem: 2413
Epoch: [6]  [5/6]  eta: 0:00:00  lr: 0.000050  loss: 0.2379 (0.2676)  loss_classifier: 0.0348 (0.0371)  loss_box_reg: 0.0507 (0.0642)  loss_mask: 0.1434 (0.1561)  loss_objectness: 0.0042 (0.0048)  loss_rpn_box_reg: 0.0037 (0.0054)  time: 0.5604  data: 0.0419  max mem: 2413
Epoch: [6] Total time: 0:00:03 (0.5691 s / it)
creating index...
index created!
Test:  [0/4]  eta: 0:00:01  model_time: 0.1324 (0.1324)  evaluator_time: 0.0152 (0.0152)  time: 0.2957  data: 0.1470  max mem: 2413
Test:  [3/4]  eta: 0:00:00  model_time: 0.1223 (0.1278)  evaluator_time: 0.0123 (0.0139)  time: 0.1808  data: 0.0379  max mem: 2413
Test: Total time: 0:00:00 (0.1898 s / it)
Averaged stats: model_time: 0.1223 (0.1278)  evaluator_time: 0.012

Epoch: [9]  [0/6]  eta: 0:00:04  lr: 0.000005  loss: 0.2393 (0.2393)  loss_classifier: 0.0313 (0.0313)  loss_box_reg: 0.0656 (0.0656)  loss_mask: 0.1335 (0.1335)  loss_objectness: 0.0023 (0.0023)  loss_rpn_box_reg: 0.0066 (0.0066)  time: 0.7595  data: 0.2422  max mem: 2413
Epoch: [9]  [5/6]  eta: 0:00:00  lr: 0.000005  loss: 0.2374 (0.2618)  loss_classifier: 0.0321 (0.0366)  loss_box_reg: 0.0527 (0.0630)  loss_mask: 0.1365 (0.1540)  loss_objectness: 0.0025 (0.0029)  loss_rpn_box_reg: 0.0047 (0.0053)  time: 0.5471  data: 0.0426  max mem: 2413
Epoch: [9] Total time: 0:00:03 (0.5524 s / it)
creating index...
index created!
Test:  [0/4]  eta: 0:00:01  model_time: 0.1325 (0.1325)  evaluator_time: 0.0160 (0.0160)  time: 0.3006  data: 0.1510  max mem: 2413
Test:  [3/4]  eta: 0:00:00  model_time: 0.1224 (0.1269)  evaluator_time: 0.0132 (0.0141)  time: 0.1800  data: 0.0381  max mem: 2413
Test: Total time: 0:00:00 (0.1871 s / it)
Averaged stats: model_time: 0.1224 (0.1269)  evaluator_time: 0.013

### Evaluation

In [13]:
# Take the second sample of the test split; its target is not needed here.
img, _ = dataset_test[1]
# Switch the model to evaluation mode and run it without tracking gradients.
model.eval()
with torch.no_grad():
    # the model is called with a list of image tensors on its own device
    prediction = model([img.to(device)])

In [None]:
# Show the input image that was fed to the model.
# Scale by 255, move channels last (C, H, W -> H, W, C) and cast to bytes for PIL.
Image.fromarray(
    img.mul(255).permute(1, 2, 0).byte().numpy()
)

In [None]:
# Show the first predicted mask of the first image, scaled by 255 and cast to bytes.
Image.fromarray(
    prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy()
)