In [1]:
%config Completer.use_jedi = False

In [2]:
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms
import torchvision
from pycocotools.coco import COCO

In [3]:
class simpleDataset(Dataset):
    def __init__(self, root, annotations, transforms=None):
        self.root = root
        self.transforms = transforms
        self.coco = COCO(annotations)
        self.ids = list(sorted(self.coco.imgs.keys()))

    def __getitem__(self, index):
        coco = self.coco
        
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        coco_annotation = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        num_objects = len(coco_annotation)

        boxes = []
        labels = []
        areas = []
        for i in range(num_objects):
            xmin = coco_annotation[i]['bbox'][0]
            ymin = coco_annotation[i]['bbox'][1]
            xmax = xmin + coco_annotation[i]['bbox'][2]
            ymax = ymin + coco_annotation[i]['bbox'][3]
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(coco_annotation[i]['category_id'])
            areas.append(coco_annotation[i]['area'])

        # Collect
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        area = torch.as_tensor(areas, dtype=torch.float32)
        iscrowd = torch.zeros((num_objects,), dtype=torch.int64)
        image_id = torch.tensor([img_id])

        target = {
            'boxes' : boxes,
            'labels' : labels,
            'area' : area,
            'iscrowd' : iscrowd,
            'image_id' : image_id
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.ids)

In [4]:
def get_transform():
    custom_transforms = []
    custom_transforms.append(torchvision.transforms.ToTensor())
    return torchvision.transforms.Compose(custom_transforms)

In [5]:
train_dir = 'images'
train_coco = 'coco_annotations.json'

In [6]:
dataset = simpleDataset(
    root=train_dir,
    annotations=train_coco,
    transforms=get_transform()
)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [7]:
def collate_fn(batch):
    return tuple(zip(*batch))

In [8]:
data_loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn, shuffle=True)

In [9]:
device = torch.device('cpu')
for imgs, annotations in data_loader:
    imgs = list(img.to(device) for img in imgs)
    annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
    print(annotations)
    break

[{'boxes': tensor([[280., 781., 369., 897.],
        [411., 747., 481., 920.],
        [177., 829., 253., 992.],
        [ 28., 386., 123., 453.],
        [ 56., 459., 168., 538.]]), 'labels': tensor([2, 2, 2, 2, 2]), 'area': tensor([8270., 9702., 9918., 5127., 7099.]), 'iscrowd': tensor([0, 0, 0, 0, 0]), 'image_id': tensor([46177])}]


In [10]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained pre-trained on COCO
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

# 2 classes; Only target class or background
num_classes = 6
num_epochs = 10
model = get_model_instance_segmentation(num_classes)

# move model to the right device
model.to(device)

# parameters
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

len_dataloader = len(data_loader)

for epoch in range(num_epochs):
    model.train()
    i = 0    
    for imgs, annotations in data_loader:
        i += 1
        imgs = list(img.to(device) for img in imgs)
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        print(f'Iteration: {i}/{len_dataloader}, Loss: {losses}')

Iteration: 1/42, Loss: 2.6182169914245605
Iteration: 2/42, Loss: 2.309131145477295
Iteration: 3/42, Loss: 1.7498996257781982
Iteration: 4/42, Loss: 1.1089553833007812
Iteration: 5/42, Loss: 1.7091799974441528
Iteration: 6/42, Loss: 0.7679508328437805
Iteration: 7/42, Loss: 1.3034652471542358
Iteration: 8/42, Loss: 1.0414786338806152
Iteration: 9/42, Loss: 0.39679285883903503
Iteration: 10/42, Loss: 0.31005871295928955
Iteration: 11/42, Loss: 0.37618982791900635
Iteration: 12/42, Loss: 0.11696583777666092
Iteration: 13/42, Loss: 0.12494420260190964
Iteration: 14/42, Loss: 0.5675291419029236
Iteration: 15/42, Loss: 0.3730897307395935
Iteration: 16/42, Loss: 0.5509160161018372
Iteration: 17/42, Loss: 0.5689661502838135
Iteration: 18/42, Loss: 0.19600819051265717
Iteration: 19/42, Loss: 0.3067000210285187
Iteration: 20/42, Loss: 0.30699560046195984


ValueError: All bounding boxes should have positive height and width. Found invalid box [282.8125, 149.21875, 342.96875, 107.8125] for target at index 0.