In [None]:
# Much of this code is adapted from the PyTorch tutorial
# "Torchvision Object Detection Finetuning Tutorial"
import os
import sys

import torch
import torchvision
import nucleus

from PIL import Image
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import detection.utils as utils
import detection.engine as engine

In [None]:
# Nucleus connection settings. Set the NUCLEUS_API_KEY and NUCLEUS_DSET_SLICE
# environment variables so that real credentials are never saved inside the
# notebook; the placeholder strings are only used when a variable is unset.
API_KEY = os.environ.get("NUCLEUS_API_KEY", "your_nucleus_api_key")
DSET_SLICE = os.environ.get("NUCLEUS_DSET_SLICE", "your_nucleus_dataset_slice")

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Object-detection dataset over Nucleus item/annotation rows.

    Each row of ``data`` is indexed as ``row["item"]`` (an object whose
    ``metadata`` mapping may hold a ``"filename"``) and
    ``row["annotations"]["box"]`` (an iterable of objects exposing ``x``,
    ``y``, ``width`` and ``height``).

    Args:
        transforms: Optional callable applied as
            ``transforms(image, target) -> (image, target)``; ``None`` skips it.
        data: Sized, indexable collection of rows as described above.
        img_dir: Directory that holds the image files. A leading ``~`` is
            expanded here because ``open()`` does not expand it.
    """

    # File name used when an item's metadata has no "filename" entry.
    DEFAULT_FILENAME = "Img-3039.jpg"

    def __init__(self, transforms, data, img_dir="~/path/to/image/"):
        self.transforms = transforms
        self.imgs = data
        self.img_dir = os.path.expanduser(img_dir)

    @staticmethod
    def _build_target(box_annotations, idx):
        """Convert box annotations into the target dict used for detection training.

        Returns a dict with:
            boxes (FloatTensor[N, 4]): boxes in [x0, y0, x1, y1] format.
            labels (Int64Tensor[N]): all ones; 0 is reserved for background.
            image_id (Int64Tensor[1]): the dataset index ``idx``.
            area (FloatTensor[N]): area of each box.
            iscrowd (Int64Tensor[N]): all zeros.
        """
        boxes = [
            [anno.x, anno.y, anno.x + anno.width, anno.y + anno.height]
            for anno in box_annotations
        ]
        num_objs = len(boxes)
        # reshape keeps the tensor two-dimensional, i.e. (0, 4), when an image
        # has no boxes; without it the column indexing below raises IndexError.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        return {
            "boxes": boxes,
            # single foreground class: horse (1) or background
            "labels": torch.ones((num_objs,), dtype=torch.int64),
            "image_id": torch.tensor([idx]),
            "area": area,
            "iscrowd": torch.zeros((num_objs,), dtype=torch.int64),
        }

    def __getitem__(self, idx):
        """Load the image and target for row ``idx``.

        Returns:
            ``(img, target)`` where ``img`` is the RGB PIL image and ``target``
            is the dict built by ``_build_target``; when ``transforms`` is set,
            both are whatever ``transforms(img, target)`` returns.
        """
        row = self.imgs[idx]
        filename = row["item"].metadata.get("filename", self.DEFAULT_FILENAME)
        img_path = os.path.join(self.img_dir, filename)
        with open(img_path, "rb") as file:
            # convert() decodes the pixel data while the file is still open.
            img = Image.open(file).convert("RGB")

        target = self._build_target(row["annotations"]["box"], idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.imgs)


In [None]:
from typing import Optional, Dict, Tuple
from torch import Tensor, nn
from torchvision.transforms import functional
from torchvision.transforms import transforms as T

def _flip_coco_person_keypoints(kps, width):
    flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    flipped_data = kps[:, flip_inds]
    flipped_data[..., 0] = width - flipped_data[..., 0]
    # Maintain COCO convention that if visibility == 0, then x, y = 0
    inds = flipped_data[..., 2] == 0
    flipped_data[inds] = 0
    return flipped_data


class ToTensor(nn.Module):
    """Two-argument ``ToTensor`` for detection pipelines.

    Converts the image with ``pil_to_tensor`` followed by
    ``convert_image_dtype`` (called with its default target dtype) and hands
    the target back untouched.
    """

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        raw_tensor = functional.pil_to_tensor(image)
        return functional.convert_image_dtype(raw_tensor), target


class RandomHorizontalFlip(T.RandomHorizontalFlip):
    """Random horizontal flip that keeps the detection target in sync.

    With probability ``self.p`` the image is flipped and, when a target is
    given, its ``"boxes"`` (and ``"masks"`` / ``"keypoints"`` if present) are
    mirrored to match. The target dict is updated in place.
    """

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        # One random draw per call decides whether anything happens at all.
        if not (torch.rand(1) < self.p):
            return image, target

        image = functional.hflip(image)
        if target is None:
            return image, target

        width, _ = functional.get_image_size(image)
        # Mirroring swaps the roles of x0 and x1: new x0 = W - x1, new x1 = W - x0.
        target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
        if "masks" in target:
            target["masks"] = target["masks"].flip(-1)
        if "keypoints" in target:
            target["keypoints"] = _flip_coco_person_keypoints(
                target["keypoints"], width
            )
        return image, target


class Compose:
    """Chain transforms that each map ``(image, target)`` to ``(image, target)``."""

    def __init__(self, transforms):
        # Sequence of callables, applied in order on every call.
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every transform in order and return the final ``(image, target)``."""
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target

In [None]:
def get_raw_horse_data():
    """Fetch the items and annotations of the configured Nucleus slice.

    Rows whose item metadata has no truthy ``"filename"`` are dropped. The
    number of remaining rows is printed, and they are returned as a list.
    """
    nucleus_slice = nucleus.NucleusClient(API_KEY).get_slice(DSET_SLICE)
    rows_with_filename = [
        row
        for row in nucleus_slice.items_and_annotations()
        if row["item"].metadata.get("filename")
    ]
    print(len(rows_with_filename))
    return rows_with_filename


def get_transform(train):
    """Build the ``(image, target)`` transform pipeline for the dataset.

    The pipeline is assembled from the ``ToTensor``, ``RandomHorizontalFlip``
    and ``Compose`` classes defined in this notebook, not from their
    ``torchvision.transforms`` namesakes. The dataset invokes the pipeline as
    ``transforms(img, target)``; the torchvision versions accept only the
    image, so they would fail on that call and could not keep the boxes in
    sync with a flipped image.

    Args:
        train: When true, a random horizontal flip (p=0.5) is appended for
            augmentation.

    Returns:
        A ``Compose`` instance callable as ``pipeline(image, target)``.
    """
    steps = [ToTensor()]
    if train:
        steps.append(RandomHorizontalFlip(0.5))
    return Compose(steps)


def get_model():
    """Return a Faster R-CNN (ResNet-50 FPN) with a fresh two-class box head.

    The detector is created with ``pretrained=True`` and its box predictor is
    replaced by a ``FastRCNNPredictor`` sized for horse + background.
    """
    # Start from the COCO-pretrained detector.
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    heads = detector.roi_heads
    # The replacement predictor must accept the same feature width as the old one.
    feature_width = heads.box_predictor.cls_score.in_features
    class_count = 2  # 1 class (horse) + background
    heads.box_predictor = FastRCNNPredictor(feature_width, class_count)
    return detector

In [None]:
def train(
    model,
    num_epochs=15,
    checkpoint_dir="./model_checkpoints",
    checkpoint_epochs=(4, 9, 14),
    num_test=50,
):
    """Fine-tune ``model`` on the Nucleus horse data, then evaluate it.

    The last ``num_test`` rows are held out for evaluation and the rest are
    used for training. Model weights are saved after each epoch listed in
    ``checkpoint_epochs``, and the evaluation output is appended to
    ``eval_output.log`` inside ``checkpoint_dir``.

    Args:
        model: Detection model to train; it is moved to CUDA when available.
        num_epochs: Number of training epochs. The default of 15 reaches all
            of the default checkpoint epochs.
        checkpoint_dir: Directory for checkpoints and the evaluation log;
            created if it does not exist.
        checkpoint_epochs: Zero-based epoch indices after which the model's
            ``state_dict`` is saved.
        num_test: Number of trailing rows held out for evaluation.

    Raises:
        ValueError: If the data has no rows left for training after holding
            out ``num_test`` rows.
    """
    raw_data = get_raw_horse_data()
    # An explicit split index avoids the negative-slice pitfall (x[:-0] is empty).
    split = len(raw_data) - num_test
    if split <= 0:
        raise ValueError(
            f"Need more than {num_test} items to split into train/test sets, "
            f"got {len(raw_data)}"
        )
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Two dataset views over the same rows: augmentation only on the training one.
    dataset = CustomDataset(get_transform(train=True), raw_data)
    dataset_test = CustomDataset(get_transform(train=False), raw_data)
    indices = list(range(len(dataset)))
    dataset = torch.utils.data.Subset(dataset, indices[:split])
    dataset_test = torch.utils.data.Subset(dataset_test, indices[split:])

    # define data loaders
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True, num_workers=4, collate_fn=utils.collate_fn
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=4, collate_fn=utils.collate_fn
    )
    model.to(device)

    # construct optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    # Create the output directory up front so a missing folder cannot fail
    # the first checkpoint write after epochs of training.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # train for num_epochs
    for epoch in range(num_epochs):
        engine.train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        lr_scheduler.step()
        if epoch in checkpoint_epochs:
            path = os.path.join(checkpoint_dir, f"model_epoch_{epoch}.pkl")
            torch.save(model.state_dict(), path)

    # engine.evaluate reports through print(), so stdout is pointed at the log
    # file for the duration of the evaluation.
    with open(os.path.join(checkpoint_dir, "eval_output.log"), "a+") as f:
        orig_target = sys.stdout
        sys.stdout = f
        try:
            engine.evaluate(model, data_loader_test, device=device)
            print("=======\n")
        finally:
            # Always restore stdout, even when evaluation raises.
            sys.stdout = orig_target

In [None]:
# Build the Faster R-CNN detector that the training cell below fine-tunes.
model = get_model()

In [None]:
# Pass the epoch count explicitly: 15 epochs reaches every checkpoint epoch
# (4, 9 and 14) that train() saves.
train(model, num_epochs=15)