In [13]:
%config Completer.use_jedi = False

In [14]:
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms
import torchvision
from pycocotools.coco import COCO
from tqdm import tqdm

In [15]:
class simpleDataset(Dataset):
    def __init__(self, root, annotations, transforms=None):
        self.root = root
        self.transforms = transforms
        self.coco = COCO(annotations)
        self.ids = list(sorted(self.coco.imgs.keys()))

    def __getitem__(self, index):
        coco = self.coco
        
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        coco_annotation = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        num_objects = len(coco_annotation)

        boxes = []
        labels = []
        areas = []
        for i in range(num_objects):
            xmin = coco_annotation[i]['bbox'][0]
            ymin = coco_annotation[i]['bbox'][1]
            xmax = xmin + coco_annotation[i]['bbox'][2]
            ymax = ymin + coco_annotation[i]['bbox'][3]
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(coco_annotation[i]['category_id'])
            areas.append(coco_annotation[i]['area'])

        # Collect
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        area = torch.as_tensor(areas, dtype=torch.float32)
        iscrowd = torch.zeros((num_objects,), dtype=torch.int64)
        image_id = torch.tensor([img_id])

        target = {
            'boxes' : boxes,
            'labels' : labels,
            'area' : area,
            'iscrowd' : iscrowd,
            'image_id' : image_id
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.ids)

In [16]:
def get_transform():
    custom_transforms = []
    custom_transforms.append(torchvision.transforms.ToTensor())
    return torchvision.transforms.Compose(custom_transforms)

In [17]:
train_dir = 'coco/train2017'
train_coco = 'coco/annotations/instances_train2017.json'

In [18]:
dataset = simpleDataset(
    root=train_dir,
    annotations=train_coco,
    transforms=get_transform()
)

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


In [19]:
def collate_fn(batch):
    return tuple(zip(*batch))

In [20]:
data_loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn, shuffle=True)

In [21]:
device = torch.device('cpu')
for imgs, annotations in data_loader:
    imgs = list(img.to(device) for img in imgs)
    annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
    print(annotations)
    break

[{'boxes': tensor([[437., 395., 561., 520.]]), 'labels': tensor([1]), 'area': tensor([12370.]), 'iscrowd': tensor([0]), 'image_id': tensor([49836])}]


In [22]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained pre-trained on COCO
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

# 2 classes; Only target class or background
num_classes = 6
num_epochs = 2
model = get_model_instance_segmentation(num_classes)

# move model to the right device
model.to(device)

# parameters
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

len_dataloader = len(data_loader)

for epoch in range(num_epochs):
    model.train()
    i = 0    
    for imgs, annotations in data_loader:
        i += 1
        imgs = list(img.to(device) for img in imgs)
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        print(f'Iteration: {i}/{len_dataloader}, Loss: {losses}')

Iteration: 1/33, Loss: 2.432880163192749
Iteration: 2/33, Loss: 2.1718130111694336
Iteration: 3/33, Loss: 1.4464969635009766
Iteration: 4/33, Loss: 1.0309903621673584
Iteration: 5/33, Loss: 0.6299889087677002
Iteration: 6/33, Loss: 0.4648781418800354
Iteration: 7/33, Loss: 0.2493247389793396
Iteration: 8/33, Loss: 0.6392616033554077
Iteration: 9/33, Loss: 0.4481787383556366
Iteration: 10/33, Loss: 0.5074942708015442
Iteration: 11/33, Loss: 0.3465309739112854
Iteration: 12/33, Loss: 0.13202057778835297
Iteration: 13/33, Loss: 0.32141566276550293
Iteration: 14/33, Loss: 0.43179285526275635
Iteration: 15/33, Loss: 0.11503937840461731
Iteration: 16/33, Loss: 0.241445854306221
Iteration: 17/33, Loss: 0.42009228467941284
Iteration: 18/33, Loss: 0.2091321051120758
Iteration: 19/33, Loss: 1.0022833347320557
Iteration: 20/33, Loss: 0.27452021837234497
Iteration: 21/33, Loss: 0.43044283986091614
Iteration: 22/33, Loss: 0.9497078657150269
Iteration: 23/33, Loss: 0.7936545610427856
Iteration: 24/3

In [24]:
#torch.save(model, 'models/one_image_fast_rcnn_train.pth')

In [25]:
model.eval()

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
          (relu