In [1]:
#@ Downloading necessary libraries and dependencies:
import os

if not os.path.exists('open-images-bus-trucks'):
  !pip install -q torch_snippets
  !wget --quiet https://www.dropbox.com/s/agmzwk95v96ihic/open-images-bus-trucks.tar.xz
  !tar -xf open-images-bus-trucks.tar.xz
  !rm open-images-bus-trucks.tar.xz
  !git clone https://github.com/sizhky/ssd-utils/
%cd ssd-utils

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/78.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.0/103.0 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.7/82.7 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.4/119.4 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m218.7/218.7 kB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m14.7 MB/s[0m eta [36m

In [3]:
#@ Data Processing:
from torch_snippets import *

# Dataset locations, relative to the working directory.
DATA_ROOT = '../open-images-bus-trucks/'
IMAGE_ROOT = f'{DATA_ROOT}/images'

# Annotation table as read from disk, plus a working copy.
DF_RAW = pd.read_csv(f'{DATA_ROOT}/df.csv')
df = DF_RAW.copy()
# NOTE(review): the mask tests each ImageID against the frame's own unique
# ImageIDs, so it looks like it keeps every row -- confirm whether a subset
# (e.g. a slice of the unique IDs) was intended.
df = df[df['ImageID'].isin(df['ImageID'].unique().tolist())]

# Class name -> integer id. Real classes are numbered from 1 in order of first
# appearance in DF_RAW; id 0 is assigned to 'background'.
label2target = {
    name: idx
    for idx, name in enumerate(DF_RAW['LabelName'].unique(), start=1)
}
label2target['background'] = 0

# Reverse lookup: integer id -> class name.
target2label = {idx: name for name, idx in label2target.items()}

background_class = label2target['background']
num_classes = len(label2target)

In [5]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [7]:
#@ Preparing Data:
import collections
from PIL import Image
from torchvision import transforms
import glob

# Per-channel statistics shared by both transforms, so the inverse below
# cannot drift out of sync with the forward transform.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Forward transform: (x - mean) / std per channel.
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Inverse transform. Normalize computes (y - m') / s', so choosing
# m' = -mean / std and s' = 1 / std gives y * std + mean, which undoes
# `normalize` on every channel.
denormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD],
)

def preprocess_image(img):
  """Turn an image array into a normalized float tensor on `device`.

  The input is wrapped in a tensor and its last axis is moved to the front
  (presumably height x width x channel -> channel x height x width; confirm
  against the caller). The result is passed through `normalize`, moved to
  `device` and cast to float32.
  """
  channels_first = torch.tensor(img).permute(2, 0, 1)
  normalized = normalize(channels_first)
  on_device = normalized.to(device)
  return on_device.float()

class OpenDataset(torch.utils.data.Dataset):
  w, h= 300, 300
  def __init__(self, df, image_dir=IMAGE_ROOT):
    """Index the image files and the annotation rows for this dataset.

    Args:
      df: annotation frame; its `ImageID` column defines the samples.
      image_dir: directory that holds the image files.
    """
    self.image_dir=image_dir
    # '/*' lists the entries inside the directory. A bare trailing '/' matches
    # only the directory itself, leaving no file paths to look up later.
    self.files=glob.glob(self.image_dir+'/*')
    self.df=df
    # One sample per distinct image.
    self.image_infos=df.ImageID.unique()
    logger.info(f'{len(self)} items loaded')

  def __getitem__(self, ix):
    """Load one sample: the resized image with its boxes and labels.

    Args:
      ix: position into `self.image_infos` (the unique ImageIDs).

    Returns:
      A tuple `(img, boxes, labels)`:
        img: the RGB image resized to (w, h), as an array divided by 255.
        boxes: list of [XMin, YMin, XMax, YMax] rows, x values multiplied by
          `w` and y values by `h`, then cast to unsigned integers.
        labels: list of the `LabelName` values for the same rows.
    """
    image_id=self.image_infos[ix]
    img_path=find(image_id, self.files)
    img=Image.open(img_path).convert("RGB")
    img=np.array(img.resize((self.w, self.h), resample=Image.BILINEAR))/255.
    # Read from the frame this dataset was built with rather than whatever
    # the notebook-global `df` happens to hold.
    data=self.df[self.df['ImageID']==image_id]
    labels=data['LabelName'].values.tolist()
    # Scale a private float copy: the array behind `.values` may be read-only
    # (pandas copy-on-write), which would make in-place scaling fail.
    boxes=data[['XMin', 'YMin', 'XMax', 'YMax']].to_numpy(dtype=float, copy=True)
    boxes[:, [0, 2]] *= self.w
    boxes[:, [1, 3]] *= self.h
    boxes=boxes.astype(np.uint32).tolist()
    return img, boxes, labels

  # The method was previously misspelled with three trailing underscores, so
  # `dataset[ix]` never reached it. The old name is kept as an alias for any
  # code that called it explicitly.
  __getitem___ = __getitem__
