## Training SSD on a custom dataset

## 1. Download the image dataset and clone the GitHub repository that hosts the model code and the data-processing utilities

In [26]:
import os
if not os.path.exists('open-images-bus-trucks'):
  %pip install -q torch_snippets
  !wget --quiet https://www.dropbox.com/s/agmzwk95v96ihic/open-images-bus-trucks.tar.xz
  !tar -xf open-images-bus-trucks.tar.xz
  !rm open-images-bus-trucks.tar.xz
  !git clone https://github.com/sizhky/ssd-utils/
%cd ssd-utils

Cloning into 'ssd-utils'...
remote: Enumerating objects: 9, done.[K
remote: Counting objects:  11% (1/9)[Kremote: Counting objects:  22% (2/9)[Kremote: Counting objects:  33% (3/9)[Kremote: Counting objects:  44% (4/9)[Kremote: Counting objects:  55% (5/9)[Kremote: Counting objects:  66% (6/9)[Kremote: Counting objects:  77% (7/9)[Kremote: Counting objects:  88% (8/9)[Kremote: Counting objects: 100% (9/9)[Kremote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects:  12% (1/8)[Kremote: Compressing objects:  25% (2/8)[Kremote: Compressing objects:  37% (3/8)[Kremote: Compressing objects:  50% (4/8)[Kremote: Compressing objects:  62% (5/8)[Kremote: Compressing objects:  75% (6/8)[Kremote: Compressing objects:  87% (7/8)[Kremote: Compressing objects: 100% (8/8)[Kremote: Compressing objects: 100% (8/8), done.[K
Receiving objects:  11% (1/9)Receiving objects:  22% (2/9)Receiving objects:  33% (3/9)Receiving objects:  44% (4/9)Receiving 

## 2. Pre-process the data

In [27]:
from torch_snippets import *
import glob
import torch


DATA_ROOT = "../open-images-bus-trucks/"
IMAGE_ROOT = os.path.join(DATA_ROOT, "images")

# Raw annotations, kept untouched; `df` is the working copy used by later cells
# (train/val split).
DF_RAW = pd.read_csv(os.path.join(DATA_ROOT, "df.csv"))
df = DF_RAW.copy()

# Class-name <-> integer-target maps. Target 0 is reserved for background;
# real classes are numbered from 1 in order of first appearance in the CSV.
label2target = {label: target + 1 for target, label in enumerate(DF_RAW['LabelName'].unique())}
label2target['background'] = 0
target2label = {target: label for label, target in label2target.items()}

# Background is the target assigned above (0). Using len(label2target) here
# would point one past the last valid class index.
background_class = label2target['background']
num_classes = len(label2target)  # includes the background class

device = "cuda" if torch.cuda.is_available() else "cpu"

## 3. Prepare the dataset class

In [34]:
import collections, os, torch
from PIL import Image
from torchvision import transforms

# Standard ImageNet per-channel statistics (presumably what the pretrained
# backbone in ssd-utils was trained with - confirm against the repo).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# x -> (x - mean) / std, per channel.
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Exact inverse of `normalize`: Normalize computes (x - m') / s', and with
# m' = -m/s, s' = 1/s that is x*s + m. Deriving it from the same constants keeps
# the two from drifting apart (a hand-typed 0.255 instead of 0.225 previously
# broke the blue channel).
denormalize = transforms.Normalize(
  mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
  std=[1 / s for s in IMAGENET_STD],
)


def preprocess_image(img):
  """Turn an HWC image array into a normalized CHW float32 tensor on `device`.

  Assumes `img` is an H x W x C array already scaled to [0, 1], as produced by
  OpenDataset.__getitem__ - TODO confirm for any other caller.
  """
  chw = torch.tensor(img).permute(2, 0, 1)  # HWC -> CHW, the layout Normalize expects
  normalized = normalize(chw)
  return normalized.to(device).float()



class OpenDataset(torch.utils.data.Dataset):
  """Object-detection dataset over the Open Images bus/truck subset.

  Each item is a tuple ``(image, boxes, labels)``:
    * image  -- float array of shape (h, w, 3): the RGB image resized to
                (w, h) and scaled to [0, 1].
    * boxes  -- list of [x_min, y_min, x_max, y_max] in pixel coordinates of
                the resized image, truncated to integers.
    * labels -- list of LabelName strings, one per box.

  Args:
    df: annotation frame with ImageID, LabelName, XMin, YMin, XMax, YMax
        columns; the XMin..YMax values are treated as fractions of the image
        width/height.
    image_dir: directory holding the image files.
  """
  w, h = 300, 300

  def __init__(self, df, image_dir=IMAGE_ROOT):
    self.image_dir = image_dir
    self.files = glob.glob(self.image_dir+'/*')
    self.df = df
    self.image_infos = df.ImageID.unique()
    logger.info(f'{len(self)} items loaded')

  def __getitem__(self, ix):
    """Load image `ix` and its boxes/labels (see class docstring for format)."""
    image_id = self.image_infos[ix]
    img_path = find(image_id, self.files)
    img = Image.open(img_path).convert('RGB')
    img = np.array(img.resize((self.w, self.h), resample=Image.BILINEAR))/255.

    # Use this dataset's own annotations, not the notebook-global `df`.
    rows = self.df[self.df['ImageID'] == image_id]
    labels = rows['LabelName'].values.tolist()
    # Explicit copy: the in-place scaling below must not touch (or be blocked
    # by a read-only view of) the annotation frame.
    coords = rows[['XMin', 'YMin', 'XMax', 'YMax']].to_numpy(copy=True)
    coords[:, [0, 2]] *= self.w  # x fractions -> pixels
    coords[:, [1, 3]] *= self.h  # y fractions -> pixels
    boxes = coords.astype(np.uint32).tolist()  # absolute (integer) coordinates
    return img, boxes, labels

  def collate_fn(self, batch):
    """Collate `__getitem__` outputs into model-ready tensors on `device`.

    Returns:
      images: tensor of shape (B, 3, h, w), normalized.
      boxes:  list of B float tensors of shape (n_i, 4), coordinates as
              fractions of the image size.
      labels: list of B long tensors of integer targets (via label2target).
    """
    # Undo exactly the (w, h, w, h) scaling applied in __getitem__ instead of
    # assuming a 300x300 canvas.
    scale = torch.tensor([self.w, self.h, self.w, self.h], dtype=torch.float32, device=device)
    images, boxes, labels = [], [], []
    for img, image_boxes, image_labels in batch:
      images.append(preprocess_image(img)[None])
      # reshape keeps an empty box list as shape (0, 4) so the division broadcasts.
      box_tensor = torch.tensor(image_boxes).float().to(device).reshape(-1, 4)
      boxes.append(box_tensor / scale)
      labels.append(torch.tensor([label2target[c] for c in image_labels]).long().to(device))

    images = torch.cat(images).to(device)
    return images, boxes, labels

  def __len__(self):
    return len(self.image_infos)


## 4. Prepare the training and test (validation) datasets and dataloaders

In [35]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader


# Split by image, not by annotation row, so all boxes of one image end up in
# the same split.
trn_ids, val_ids = train_test_split(df.ImageID.unique(), test_size=0.1, random_state=99)
trn_df, val_df = df[df['ImageID'].isin(trn_ids)], df[df['ImageID'].isin(val_ids)]

train_ds = OpenDataset(trn_df)
test_ds = OpenDataset(val_df)

# Training batches are shuffled each epoch; the seeded generator keeps the
# order reproducible across runs.
train_loader = DataLoader(
  train_ds, batch_size=4, collate_fn=train_ds.collate_fn, drop_last=True,
  shuffle=True, generator=torch.Generator().manual_seed(99),
)
# Validation keeps a fixed order and keeps the final partial batch so every
# validation image is scored.
test_loader = DataLoader(test_ds, batch_size=4, collate_fn=test_ds.collate_fn, drop_last=False)

## 5. Define functions to train on a batch of data and to compute the loss on the validation data

In [36]:
def train_batch(inputs, model, criterion, optimizer):
  """Run one optimization step on a single batch.

  Args:
    inputs: (images, boxes, labels) triple, e.g. from OpenDataset.collate_fn.
    model: module whose forward returns a pair, unpacked here as
      (box regressions, class scores).
    criterion: callable(regr, clss, boxes, labels) -> scalar loss tensor.
    optimizer: optimizer over `model`'s parameters.

  Returns:
    The batch loss tensor (already back-propagated and applied).
  """
  model.train()
  images, boxes, labels = inputs
  _regr, _clss = model(images)
  loss = criterion(_regr, _clss, boxes, labels)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  return loss


@torch.no_grad()
def validate_batch(inputs, model, criterion):
  """Score one batch without updating the model.

  Switches `model` to eval mode and, via the decorator, runs without autograd
  tracking. Returns the loss tensor produced by `criterion`.
  """
  model.eval()
  images, gt_boxes, gt_labels = inputs
  pred_regr, pred_clss = model(images)
  return criterion(pred_regr, pred_clss, gt_boxes, gt_labels)

## 6. Import the model

In [31]:
from model import SSD300, MultiBoxLoss
from detect import *

## 7. Initialize the model, optimizer, and loss function

In [37]:
from torch_snippets.torch_loader import Report

n_epochs = 3
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
SEED = 99  # same value as the train/val split's random_state

# Seed torch before building the model so any random initialization inside
# SSD300 is reproducible across runs.
torch.manual_seed(SEED)

model = SSD300(num_classes, device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy, device=device)

log = Report(n_epochs=n_epochs)
logs_to_print = 5  # NOTE(review): not used by any cell shown here - remove if unused elsewhere


Loaded base model.



## 8. Train the model over multiple epochs

In [None]:
# Each epoch: one optimization pass over train_loader, then one loss-only pass
# over test_loader. `pos` is a fractional epoch (epoch + fraction of the
# current loader consumed so far) and is the x-position passed to log.record.
# Note: `loss`, `ix`, `inputs`, `pos` and `_n` are module-level names here and
# remain in the notebook namespace after the loop.
for epoch in range(n_epochs):
  _n = len(train_loader)
  for ix, inputs in enumerate(train_loader):
    loss = train_batch(inputs, model, criterion, optimizer)
    pos = (epoch + (ix+1)/_n)
    # end='\r' presumably overwrites the progress line in place - confirm in torch_snippets.Report
    log.record(pos, trn_loss=loss.item(), end='\r')

  # Validation: no parameter updates (validate_batch runs under torch.no_grad()).
  _n = len(test_loader)
  for ix, inputs in enumerate(test_loader):
    loss = validate_batch(inputs, model, criterion)
    pos = (epoch + (ix+1)/_n)
    log.record(pos, val_loss=loss.item(), end='\r')

EPOCH: 0.107  trn_loss: 3.401  (5388.98s - 145487.71s remaining)