In [1]:
import os
import numpy as np
import torch
import torch.utils.data
from PIL import Image
import json

class CreateDataset(torch.utils.data.Dataset):
    def __init__(self, info_dir, img_dir, transforms=None):
        with open(info_dir) as f:
            info = json.load(f)
            self.data_info = info["annotations"]

        self.img_dir = img_dir
        self.transforms = transforms

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, str(self.data_info[idx]["id"])+".png")
        img = Image.open(img_path)
        
        info = self.data_info[idx]

        image_id = info["id"]
        num_objs = len(info["category_id"])
        iscrowd = info["iscrowd"]

        boxes = []
        labels = []
        for i in range(num_objs):
            box = info["bbox"][i]
            xmin = np.min(box[0])
            xmax = np.max(box[0]+box[2])
            ymin = np.min(box[1])
            ymax = np.max(box[1]+box[3])
            boxes.append([xmin, ymin, xmax, ymax])

            labels.append(info["category_id"][i]-1)                 

        image_id = torch.tensor([image_id])
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.data_info)

In [2]:
dataset = CreateDataset('datasets/info_all.json', 'datasets/images')
dataset[150]

(<PIL.PngImagePlugin.PngImageFile image mode=RGB size=1080x1920 at 0x7FD7CC6796D0>,
 {'boxes': tensor([[ 417.5385, 1134.7693,  686.7692, 1320.9231],
          [ 322.1538, 1325.5385,  746.7692, 1917.8462]]),
  'labels': tensor([3, 0]),
  'image_id': tensor([20201111123018241]),
  'area': tensor([ 50118.3320, 251502.9844]),
  'iscrowd': tensor([0, 0])})

In [3]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T

#TODO
def get_transform(train):
    transforms = []
    # converts the image, a PIL image, into a PyTorch Tensor
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

In [4]:
train_dataset = CreateDataset('datasets/info_all.json', 'datasets/images', get_transform(train=True))
val_dataset = CreateDataset('datasets/info_all.json', 'datasets/images', get_transform(train=False))

torch.manual_seed(1)

indices = torch.randperm(len(train_dataset)).tolist()
split_idx = int(0.1*len(train_dataset))
train_dataset = torch.utils.data.Subset(train_dataset, indices[:-split_idx])
val_dataset = torch.utils.data.Subset(val_dataset, indices[-split_idx:])

# define training and validation data loaders
BATCH_SIZE = 16
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=utils.collate_fn, drop_last=True) 
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=utils.collate_fn, drop_last=False)


In [5]:
## test code

# data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4, collate_fn=utils.collate_fn)
# images, targets = next(iter(data_loader))
# images = list(image for image in images)
# targets = [{k: v for k, v in t.items()} for t in targets]
# output = model(images,targets)   # Returns losses and detections

# model.eval()
# x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
# predictions = model(x)

In [6]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

num_classes = 4
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=2, verbose=True)

In [7]:
LOAD_MODEL = False
LOAD_MODEL_FILE = "model/default.pth"

if LOAD_MODEL:
    print("=> Loading checkpoint")
    checkpoint = torch.load(LOAD_MODEL_FILE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

num_epochs = 10
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=10)
    lr_scheduler.step()
    evaluate(model, val_loader, device=device)



In [None]:
print("saving model")
save_model = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
torch.save(save_model, LOAD_MODEL_FILE)

In [None]:
img, _ = val_dataset[2]

model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])

print(prediction)