In [None]:
%%shell
git clone https://github.com/pytorch/vision.git
cd vision
cp references/detection/coco_eval.py ../
cp references/detection/coco_utils.py ../
cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/engine.py ../

In [8]:
%%writefile data_setup.py
import torch
from torchvision.datasets import CocoDetection
from torchvision.transforms import transforms

class CocoDataset(CocoDetection):
  def __init__(self, root, annFile, transform=None, target_transform=None) -> None:
    super().__init__(root, annFile, transform, target_transform)

  def __getitem__(self, index: int):
    img, ori_target = super().__getitem__(index)

    num_objs = len(ori_target)
    boxes = []
    labels = []
    area = []
    iscrowd = []

    for i in range(num_objs):
      x_min = max(0, ori_target[i]['bbox'][0])
      y_min = max(0, ori_target[i]['bbox'][1])
      x_max = min(4908, x_min + ori_target[i]['bbox'][2])
      y_max = min(3264, y_min + ori_target[i]['bbox'][3])
      boxes.append([x_min, y_min, x_max, y_max])
      labels.append(ori_target[i]['category_id'])
      area.append(ori_target[i]['area'])
      iscrowd.append(ori_target[i]['iscrowd'])

    target = {}
    target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32)
    target['labels'] = torch.as_tensor(labels, dtype=torch.int64)
    target['image_id'] = torch.tensor([ori_target[0]['image_id']])
    target['area'] = torch.as_tensor(area, dtype=torch.float32)
    target['iscrowd'] = torch.as_tensor(iscrowd, dtype=torch.int64)

    return img, target


def get_transform(train):
  if train:
    return transforms.Compose([
      transforms.ToTensor(),
      transforms.RandomHorizontalFlip(0.5),
    ])
  else:
    return transforms.Compose([
      transforms.ToTensor(),
    ])

def split_dataset(dataset_full, train_ratio):
  num_train = int(len(dataset_full) * train_ratio)
  num_test = len(dataset_full) - num_train

  dataset_train, dataset_test = torch.utils.data.random_split(dataset_full, [num_train, num_test])

  print(f'Size of the train set: {len(dataset_train)}')
  print(f'Size of the test set: {len(dataset_test)}')

  return dataset_train, dataset_test


def collate_fn(batch):
  return tuple(zip(*batch))


def create_datasets(
  root: str,
  annFile: str,
  train_ratio: float,
):
  dataset_full = CocoDataset(
    root = root,
    annFile = annFile,
    transform = get_transform(train=True),
  )

  dataset_train, dataset_test = split_dataset(dataset_full, train_ratio=train_ratio)

  return dataset_train, dataset_test


def create_dataloaders(
  dataset_train: torch.utils.data.Dataset,
  dataset_test: torch.utils.data.Dataset,
  batch_size: int,
  num_workers: int,
):
  train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size = batch_size,
    shuffle = True,
    num_workers = num_workers,
    collate_fn = collate_fn,
  )

  test_dataloader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size = batch_size,
    shuffle = False,
    num_workers = num_workers,
    collate_fn = collate_fn,
  )

  return train_dataloader, test_dataloader

Overwriting data_setup.py


In [9]:
%%writefile model_builder.py
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_FasterRCNN_model(num_classes, feature_extract=True):
  weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
  model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights)

  if feature_extract:
    for param in model.parameters():
      param.requires_grad = False

  in_features = model.roi_heads.box_predictor.cls_score.in_features
  model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

  return model

Overwriting model_builder.py


In [10]:
%%writefile custom_utils.py
import torch
import torchvision
import torch.nn as nn
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
from pathlib import Path

def apply_nms(orig_prediction, iou_thresh=0.3):
  keep = torchvision.ops.nms(orig_prediction['boxes'].cpu(), orig_prediction['scores'].cpu(), iou_thresh)
  
  final_prediction = orig_prediction
  final_prediction['boxes'] = final_prediction['boxes'].cpu()[keep]
  final_prediction['scores'] = final_prediction['scores'].cpu()[keep]
  final_prediction['labels'] = final_prediction['labels'].cpu()[keep]
  
  return final_prediction


def torch_to_pil(img):
  return transforms.ToPILImage()(img).convert('RGB')


def plot_img_bbox(img, target, num_classes):
  fig, a = plt.subplots(1, 1)
  fig.set_size_inches(10, 10)
  a.imshow(img)

  for i in range(len(target['boxes'])):
    box = target['boxes'][i]
    label = int(target['labels'][i])

    cmap = plt.cm.get_cmap('hsv', num_classes+1)

    x, y, width, height  = box[0], box[1], box[2]-box[0], box[3]-box[1]

    rect = patches.Rectangle(
      (x, y),
      width, height,
      linewidth = 2,
      edgecolor = cmap(label),
      facecolor = 'none'
    )

    a.add_patch(rect)
  plt.show()


def inference_and_plot(
  dataset: torch.utils.data.Dataset,
  model: nn.Module,
  device: str,
  iou_thresh: float,
  num_classes: int,
):
  random_idx = random.randint(0, len(dataset)-1)
  img, target = dataset[random_idx]

  model.eval()
  with torch.inference_mode():
    prediction = model([img.to(device)])[0]

  nms_prediction = apply_nms(prediction, iou_thresh=iou_thresh)

  plot_img_bbox(torch_to_pil(img), nms_prediction, num_classes)


def save_model(
  model: torch.nn.Module,
  target_path: str,
  model_name: str,
):
  assert model_name.endswith('pth') or model_name.endswith('.pt'), "[Invalid model name]: model_name should end with '.pth' or '.pt'."

  target_path = Path(target_path)
  target_path.mkdir(parents=True, exist_ok=True)

  torch.save(
    obj = model.state_dict(),
    f = target_path / model_name,
  )

Overwriting custom_utils.py


In [11]:
%%writefile train.py
import argparse
import torch
from data_setup import create_datasets, create_dataloaders
from model_builder import get_FasterRCNN_model
from custom_utils import save_model
from engine import train_one_epoch, evaluate

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True)
parser.add_argument('--annfile', type=str, required=True)
parser.add_argument('--target_path', type=str, default='saved_models')
parser.add_argument('--model_name', type=str, default='faster_rcnn_v1.pth')
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--train_ratio', type=float, default=0.8)
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--num_workers', type=int, default=1)
parser.add_argument('--print_freq', type=int, default=1)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=5)
parser.add_argument('--fe', type=bool, default=True)
args = parser.parse_args()

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
ROOT = args.root
ANNFILE = args.annfile
TARGET_PATH = args.target_path
MODEL_NAME = args.model_name
NUM_CLASSES = args.num_classes
TRAIN_RATIO = args.train_ratio
BATCH_SIZE = args.batch_size
NUM_WORKERS = args.num_workers
PRINT_FREQ = args.print_freq
LEARNING_RATE = args.lr
EPOCHS = args.epochs
FEATURE_EXTRACT = args.fe

dataset_train, dataset_test = create_datasets(
  root = ROOT,
  annFile = ANNFILE,
  train_ratio = TRAIN_RATIO,
)

train_dataloader, test_dataloader = create_dataloaders(
  dataset_train = dataset_train,
  dataset_test = dataset_test,
  batch_size = BATCH_SIZE,
  num_workers = NUM_WORKERS,
)

model = get_FasterRCNN_model(num_classes=NUM_CLASSES, feature_extract=FEATURE_EXTRACT)
model.to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
  train_one_epoch(model, optimizer, train_dataloader, DEVICE, epoch, print_freq=PRINT_FREQ)
  evaluate(model, test_dataloader, device=DEVICE)

save_model(
  model=model,
  target_path=TARGET_PATH,
  model_name=MODEL_NAME,
)

Overwriting train.py


In [12]:
!python train.py --root '/content/drive/MyDrive/data/algae_toydata/algae_toydata_1' --annfile '/content/drive/MyDrive/data/algae_toydata/annotations/algae_toydata_1.json'

loading annotations into memory...
Done (t=1.39s)
creating index...
index created!
Size of the train set: 3
Size of the test set: 1
Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth
100% 167M/167M [00:03<00:00, 44.7MB/s]
Epoch: [0]  [0/2]  eta: 0:00:21  lr: 0.001000  loss: 5.2356 (5.2356)  loss_classifier: 1.5061 (1.5061)  loss_box_reg: 0.1950 (0.1950)  loss_objectness: 3.1654 (3.1654)  loss_rpn_box_reg: 0.3691 (0.3691)  time: 10.9347  data: 4.2959  max mem: 1416
Epoch: [0]  [1/2]  eta: 0:00:05  lr: 0.001000  loss: 2.2653 (3.7504)  loss_classifier: 1.5061 (1.5085)  loss_box_reg: 0.1950 (0.2171)  loss_objectness: 0.4963 (1.8308)  loss_rpn_box_reg: 0.0188 (0.1940)  time: 5.5897  data: 2.1727  max mem: 1416
Epoch: [0] Total time: 0:00:11 (5.6006 s / it)
Test:  [0/1]  eta: 0:00:01  model_time: 0.1510 (0.1510)  evaluator_time: 0.0080 (0.0080)  time: 1.7537  data: 