https://medium.com/codex/implementing-r-cnn-object-detection-on-voc2012-with-pytorch-b05d3c623afe

In [29]:
import os
import random
import tarfile
import urllib
import urllib.request
import xml.etree.ElementTree as ET
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np

Download data

In [4]:
class VOCDataset:
    """Base class for downloading and parsing Pascal VOC detection data.

    Subclasses are expected to set ``train_dir``, ``test_dir``,
    ``train_data_link`` and ``test_data_link`` and then call
    :meth:`common_init` to build the label tables.
    """

    def __init__(self):
        # Archives and extracted folders are kept under ./data.
        self.root = Path("data")
        self.root.mkdir(parents=True, exist_ok=True)
        # Placeholders; concrete dataset subclasses overwrite these.
        self.train_dir = None
        self.test_dir = None
        self.train_data_link = None
        self.test_data_link = None

    def common_init(self):
        """Build the class-name tables shared by all VOC subclasses."""
        # Class names by integer id; index 0 is the "no object" slot.
        self.label_type = ['none','aeroplane',"Bicycle",'bird',"Boat","Bottle","Bus","Car","Cat","Chair",'cow',"Diningtable","Dog","Horse","Motorbike",'person', "Pottedplant",'sheep',"Sofa","Train","TVmonitor"]
        # Alternative spellings aligned index-for-index with label_type
        # (presumably for display -- not used inside this class).
        self.convert_id = ['background','Aeroplane',"Bicycle",'Bird',"Boat","Bottle","Bus","Car","Cat","Chair",'Cow',"Dining table","Dog","Horse","Motorbike",'Person', "Potted plant",'Sheep',"Sofa","Train","TV/monitor"]
        # Lower-cased class name -> integer class id.
        self.convert_labels = {
            name.lower(): class_id for class_id, name in enumerate(self.label_type)
        }

        self.num_classes = len(self.label_type)     # 20 class + 1 background

    @staticmethod
    def _extract_tar(archive_path, dest_dir):
        """Extract ``archive_path`` into ``dest_dir``, always closing the archive."""
        with tarfile.open(archive_path) as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: where the interpreter
                # supports extraction filters, refuse members that would
                # escape dest_dir or create special files.
                tar.extractall(dest_dir, filter="data")
            else:
                # NOTE(review): older interpreters extract without any
                # filtering; only point the data links at trusted hosts.
                tar.extractall(dest_dir)

    def _move_validation_split(self, validation_size):
        """Move a seeded random sample of training images into ``test_dir``."""
        print("[*] Moving validation data...")
        test_annotation_dir = self.test_dir / "Annotations"
        test_img_dir = self.test_dir / "JPEGImages"
        test_annotation_dir.mkdir(parents=True, exist_ok=True)
        test_img_dir.mkdir(parents=True, exist_ok=True)

        # A fixed seed over a sorted listing keeps the split reproducible
        # for a given set of files.
        random.seed(731)
        val_file_names = random.sample(
            sorted(os.listdir(self.train_dir / "JPEGImages")), validation_size
        )

        for file_name in val_file_names:
            img_name = Path(file_name).stem
            # move image
            os.rename(self.train_dir / "JPEGImages" / f"{img_name}.jpg", test_img_dir / f"{img_name}.jpg")
            # move annotation
            os.rename(self.train_dir / "Annotations" / f"{img_name}.xml", test_annotation_dir / f"{img_name}.xml")

    def _download_test_set(self, extract_dir):
        """Download the separate test archive and extract it into ``extract_dir``."""
        # Use one archive name for both download and extraction, and keep it
        # under self.root next to the training archive.
        test_tar = self.root / "voctest.tar"
        print("[*] Downloading validation dataset...")
        urllib.request.urlretrieve(self.test_data_link, test_tar)

        print("[*] Extracting validation dataset...")
        self._extract_tar(test_tar, extract_dir)

    def download_dataset(self, validation_size=5000):
        """Make sure the training and test/validation data exist under ``self.root``.

        The training archive is downloaded to ``voctrain.tar`` and extracted
        to ``VOCtrain`` unless those already exist. If ``VOCtest`` does not
        exist yet, the test data is either downloaded from
        ``test_data_link`` or, when that link is ``None``, created by moving
        ``validation_size`` randomly chosen images (and their annotations)
        out of the training folders.

        Args:
            validation_size: Number of training images to move into the
                validation split when no test link is configured.

        Raises:
            ValueError: If ``validation_size`` exceeds the number of files
                in the training ``JPEGImages`` folder (from ``random.sample``).
        """
        # download voc train dataset
        train_tar = self.root / "voctrain.tar"
        if train_tar.exists():
            print("[*] Dataset already exists!")
        else:
            print("[*] Downloading dataset...")
            print(self.train_data_link)
            urllib.request.urlretrieve(self.train_data_link, train_tar)

        train_extract_dir = self.root / "VOCtrain"
        if train_extract_dir.exists():
            print("[*] Dataset is already extracted!")
        else:
            print("[*] Extracting dataset...")
            self._extract_tar(train_tar, train_extract_dir)

        # download test dataset
        test_extract_dir = self.root / "VOCtest"
        if test_extract_dir.exists():
            print("[*] Test dataset already exist!")
        elif self.test_data_link is None:
            # No separate test archive: carve a validation set out of train.
            self._move_validation_split(validation_size)
        else:
            self._download_test_set(test_extract_dir)

    def read_xml(self, xml_path):
        """Parse one VOC annotation file into a list of box dictionaries.

        Args:
            xml_path: Path of the annotation XML file.

        Returns:
            One dict per ``<object>`` element with integer keys ``x1``,
            ``x2``, ``y1``, ``y2`` and ``class`` (the integer class id).

        Raises:
            KeyError: If an object name is not one of the known classes.
        """
        def _coord(bndbox, tag):
            # Go through float so values written like "48.0" still parse.
            return int(float(bndbox.find(tag).text))

        # Passing the path lets ElementTree open and close the file itself
        # and honour the encoding declared in the XML.
        root = ET.parse(xml_path).getroot()

        object_list = []
        for _object in root.findall("object"):
            bndbox = _object.find("bndbox")
            # convert_labels keys are lower-cased, so normalise the lookup.
            class_name = _object.find("name").text.strip().lower()
            object_list.append({
                'x1': _coord(bndbox, "xmin"),
                'x2': _coord(bndbox, "xmax"),
                'y1': _coord(bndbox, "ymin"),
                'y2': _coord(bndbox, "ymax"),
                'class': self.convert_labels[class_name],
            })

        return object_list


In [5]:
class VOC2007(VOCDataset):
    """Pascal VOC 2007: train/val and test are separate downloadable archives."""

    def __init__(self):
        super().__init__()

        # Where each archive comes from.
        self.train_data_link = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar'
        self.test_data_link = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar'

        # Where the extracted image/annotation folders are looked up.
        self.train_dir = self.root / 'VOCtrain/VOCdevkit/VOC2007'
        self.test_dir = self.root / 'VOCtest/VOCdevkit/VOC2007'

        # Must be called by every subclass: builds the label tables.
        self.common_init()
    
class VOC2012(VOCDataset):
    """Pascal VOC 2012: only a train/val archive is configured (no test link)."""

    def __init__(self):
        super().__init__()

        # The original site goes down frequently, so a mirror is used instead of
        # http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
        self.train_data_link = 'http://pjreddie.com/media/files/VOCtrainval_11-May-2012.tar'
        # No separate test archive is configured for this year.
        self.test_data_link = None

        # Where the extracted image/annotation folders are looked up.
        self.train_dir = self.root / 'VOCtrain/VOCdevkit/VOC2012'
        self.test_dir = self.root / 'VOCtest/VOCdevkit/VOC2012'

        # Must be called by every subclass: builds the label tables.
        self.common_init()

In [6]:
# Create the VOC2012 helper and make sure its data is on disk.
# VOC2012 sets test_data_link to None, so download_dataset() builds the
# validation set by moving images (default validation_size=5000) out of the
# training folder instead of downloading a test archive.
voc_dataset = VOC2012()
voc_dataset.download_dataset()

[*] Dataset already exists!
[*] Dataset is already extracted!
[*] Test dataset already exist!


In [7]:
# Count the annotation files in the training and validation folders.
train_annotation_dir = voc_dataset.train_dir / "Annotations"
valid_annotation_dir = voc_dataset.test_dir / "Annotations"
train_data_num = len(os.listdir(train_annotation_dir))
valid_data_num = len(os.listdir(valid_annotation_dir))
print("train data num:", train_data_num)
print("valid data num:", valid_data_num)

train data num: 12125
valid data num: 5000


Show some data

In [36]:

def plot_img(img_path):
    """Display the image stored at ``img_path`` with matplotlib.

    Args:
        img_path: Path of the image file (string or path-like).

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``img_path``.
    """
    print("plotting:", img_path)
    # cv2.imread signals failure by returning None rather than raising, so
    # check explicitly instead of letting cvtColor fail with an opaque error.
    img_bgr = cv2.imread(str(img_path))
    if img_bgr is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    # OpenCV loads BGR; matplotlib expects RGB.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    print("img shape:", img_rgb.shape)
    plt.imshow(img_rgb)
    plt.show()
    
    
def plot_bounding_box(img_path):
    """Show an image, then the same image with its annotated boxes drawn.

    The annotation file is located by replacing ``JPEGImages`` with
    ``Annotations`` and ``.jpg`` with ``.xml`` in ``img_path``.

    Args:
        img_path: Path of the JPEG image (string or path-like).

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``img_path``.
    """
    # Work on a string so the .replace() calls below also accept Path input.
    img_path = str(img_path)

    # img
    print("plotting:", img_path)
    # cv2.imread returns None on failure instead of raising.
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    xml_path = img_path.replace("JPEGImages", "Annotations").replace(".jpg", ".xml")
    # Passing the path lets ElementTree open and close the file itself.
    root = ET.parse(xml_path).getroot()

    # print image shape
    size = root.find("size")
    w, h = size.find("width").text, size.find("height").text
    print("width, height:", w, h)

    # bounding box settings
    box_img = img_bgr.copy()
    bbox_color = (0, 69, 255)    # bgr
    bbox_thickness = 2
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_color = bbox_color
    # This value lands in putText's thickness position (the argument right
    # after the colour), so it is named for what it actually controls.
    font_thickness = 1

    # plot box
    for _object in root.findall("object"):
        class_name = _object.find("name").text
        bndbox = _object.find("bndbox")
        xmin = int(bndbox.find("xmin").text)
        ymin = int(bndbox.find("ymin").text)
        xmax = int(bndbox.find("xmax").text)
        ymax = int(bndbox.find("ymax").text)

        cv2.rectangle(box_img, (xmin, ymin), (xmax, ymax), bbox_color, bbox_thickness)
        # Label is drawn slightly above the top-left corner of the box.
        cv2.putText(box_img, class_name, (xmin, ymin-5), font, font_scale, font_color, font_thickness)

    box_img_rgb = cv2.cvtColor(box_img, cv2.COLOR_BGR2RGB)
    # Original first, annotated copy second.
    plt.imshow(img_rgb)
    plt.show()
    plt.imshow(box_img_rgb)
    plt.show()

In [None]:
# Pick one random training image and show it with its bounding boxes.
train_jpeg_dir = voc_dataset.train_dir / "JPEGImages"
img_name = random.choice(os.listdir(train_jpeg_dir))
img_path = str(train_jpeg_dir / img_name)
plot_bounding_box(img_path)