In [1]:
#@ Downloading necessary libraries and dependencies:
import os

if not os.path.exists('open-images-bus-trucks'):
  !pip install -q torch_snippets
  !wget --quiet https://www.dropbox.com/s/agmzwk95v96ihic/open-images-bus-trucks.tar.xz
  !tar -xf open-images-bus-trucks.tar.xz
  !rm open-images-bus-trucks.tar.xz
  !git clone https://github.com/sizhky/ssd-utils/
%cd ssd-utils

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/78.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.0/103.0 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.7/82.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.4/119.4 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m218.7/218.7 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m15.3 MB/s[0m eta [36m

In [2]:
#@ Data Processing:
from torch_snippets import *
DATA_ROOT = '../open-images-bus-trucks/'
IMAGE_ROOT=f'{DATA_ROOT}/images'
DF_RAW=pd.read_csv(f'{DATA_ROOT}/df.csv')
df=DF_RAW.copy()
df=df[df['ImageID'].isin(df['ImageID'].unique().tolist())]
label2target={l:t+1 for t,l in enumerate(DF_RAW['LabelName'].unique())}
label2target['background']=0
target2label={t:l for l, t in label2target.items()}
background_class=label2target['background']
num_classes=len(label2target)

In [3]:
import torch
device='cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
#@ Preparing Data:
import collections
from PIL import Image
from torchvision import transforms
import glob

normalize=transforms.Normalize(
             mean=[0.485, 0.456, 0.406],
             std=[0.229, 0.224, 0.225]
)

denormalize=transforms.Normalize(
    mean=[-0.485/0.229, 0.456/0.224, 0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.255]
)

def preprocess_image(img):
  img=torch.tensor(img).permute(2, 0, 1)
  img=normalize(img)
  return img.to(device).float()

class OpenDataset(torch.utils.data.Dataset):
  w, h= 300, 300
  def __init__(self, df, image_dir=IMAGE_ROOT):
    self.image_dir=image_dir
    self.files=glob.glob(self.image_dir+'/')
    self.df=df
    self.image_infos=df.ImageID.unique()
    logger.info(f'{len(self)} items loaded')

  def __getitem___(self, ix):
    image_id=self.image_infos[ix]
    img_path=find(image_id, self.files)
    img=Image.open(img_path).convert("RGB")
    img=np.array(img.resize((self.w, self.h), resample=Image.BILINEAR))/255.
    data=df[df['ImageID']==image_id]
    labels=data['LabelName'].values.tolist()
    data=data[['XMin', 'YMin', 'XMax', 'YMax']].values
    data[:, [0, 2]] *= self.w
    data[:, [1, 3]] *= self.h
    boxes=data.astype(np.uint32).tolist()
    return img, boxes, labels

  def collate_fn(self, batch):
    images, boxes, labels=[], [], []
    for item in batch:
      img, image_boxes, image_labels=item
      img=preprocess_image(img)[None]
      images.append(img)

      boxes.append(torch.tensor(image_boxes).float().to(device)/300)
      labels.append(torch.tensor([label2target[c] for c in image_labels]).long().to(device))
      images=torch.cat(images).to(device)

      return images, boxes, labels


  def __len__(self):
    return len(self.image_infos)


In [7]:
from torch.utils.data import DataLoader

In [9]:
from sklearn.model_selection import train_test_split

train_ids, val_ids=train_test_split(df.ImageID.unique(), test_size=0.1, random_state=99)
train_df, val_df=df[df['ImageID'].isin(train_ids)], df[df['ImageID'].isin(val_ids)]

train_ds=OpenDataset(train_df)
test_ds=OpenDataset(val_df)

train_loader=DataLoader(train_ds, batch_size=4, collate_fn=train_ds.collate_fn, drop_last=True)
test_loader=DataLoader(train_ds, batch_size=4, collate_fn=test_ds.collate_fn, drop_last=True)

In [10]:
def train_batch(inputs, model, criterion, optimizer):
  model.train()
  N= len(train_loader)
  images, boxes, labels=inputs
  _regr, _clss=model(images)
  loss=criterion(_regr, _clss, boxes, labels)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  return loss

@torch.no_grad()
def validate_batch(inputs, model, criterion):
  model.eval()
  images, boxes, labels=inputs
  _regr, _clss=model(images)
  loss=criterion(_regr, _clss, boxes, labels)
  return loss


In [23]:
#@ Import Model:
from model import SSD300, MultiBoxLoss
from detect import *

In [25]:
#@ training:
n_epochs=5

model=SSD300(num_classes, device)
optimizer=torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion=MultiBoxLoss(priors_cxcy=model.priors_cxcy, device=device)

# log=Report(n_epochs=n_epochs)
# logs_to_print=5

for epoch in range(n_epochs):
  _n= len(train_loader)
  for ix, inputs in enumerate(train_loader):
    loss=train_batch(inputs, model, criterion, optimizer)
    pos=(epoch + (ix+1)/_n)
    log.record(pos, train_loss=loss.item(), end='\r')

for epoch in range(n_epochs):
  _n= len(test_loader)
  for ix, inputs in enumerate(test_loader):
    loss=validate_batch(inputs, model, criterion)
    pos=(epoch + (ix+1)/_n)
    log.record(pos, val_loss=loss.item(), end='\r')


Loaded base model.



NotImplementedError: Subclasses of Dataset should implement __getitem__.