In [93]:
import torch
import os
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
from PIL import Image

In [4]:
import pickle
# Load the precomputed anchor boxes that are later passed to VOCDataset(anchors=...).
# VOCDataset reads anchor[0] and anchor[1] and multiplies them by the grid size, so these
# are presumably (width, height) pairs normalised to the image size — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load a file you trust.
with open('anchors_VOC0712trainval.pickle', 'rb') as handle:
    anchors = pickle.load(handle)

In [9]:
def IoU(box1, box2, midpoint=True):
    """Intersection-over-union of two axis-aligned boxes.

    Args:
        box1, box2: Indexable boxes. With ``midpoint=True`` the first four
            entries are read as (x_center, y_center, width, height); with
            ``midpoint=False`` each box must unpack into exactly
            (xmin, ymin, xmax, ymax).
        midpoint: Selects which of the two encodings above is used.

    Returns:
        intersection / (area1 + area2 - intersection + 1e-6); the 1e-6 term
        keeps the division defined when both boxes have zero area.
    """

    def _to_corners(box):
        # Centre/size -> corner coordinates; only the first four entries are read.
        cx, cy, width, height = box[0], box[1], box[2], box[3]
        return cx - width/2, cy - height/2, cx + width/2, cy + height/2

    if midpoint:
        left_a, top_a, right_a, bottom_a = _to_corners(box1)
        left_b, top_b, right_b, bottom_b = _to_corners(box2)
    else:
        left_a, top_a, right_a, bottom_a = box1
        left_b, top_b, right_b, bottom_b = box2

    # Extent of the overlap along each axis, floored at zero for disjoint boxes.
    overlap_w = max(min(right_a, right_b) - max(left_a, left_b), 0)
    overlap_h = max(min(bottom_a, bottom_b) - max(top_a, top_b), 0)
    overlap = overlap_w * overlap_h

    area_a = (right_a - left_a) * (bottom_a - top_a)
    area_b = (right_b - left_b) * (bottom_b - top_b)
    union = area_a + area_b - overlap

    return overlap / (union + 1e-6)

In [155]:
class VOCDataset(Dataset):
    """Pascal VOC detection dataset that yields anchor-grid training targets.

    Each sample is an image plus, for every grid size in ``scales``, a tensor of
    shape ``(len(anchors) * (5 + num_classes), scale, scale)``. For anchor ``a``
    the channels starting at ``a * (5 + num_classes)`` hold, in order:
    objectness, x-centre offset inside the cell, y-centre offset inside the
    cell, width and height (both in cell units), then a one-hot class vector.
    Tensors are indexed ``[channel, cell_x, cell_y]``.

    Objectness values: ``1`` = this anchor is responsible for an object,
    ``-1`` = anchor overlaps an object above ``threshold_ignore_prediction``
    without being the best match, ``0`` = no object.

    Args:
        devkit_path: Directory containing the ``VOC2007`` / ``VOC2012`` folders.
        subsets: ``(year_folder, image_set)`` pairs; the image ids are read from
            ``<devkit>/<year_folder>/ImageSets/Main/<image_set>.txt``.
        anchors: Sequence of anchors; ``anchor[0]`` and ``anchor[1]`` are
            multiplied by the grid size, so they are presumably width/height
            normalised to the image size (TODO confirm against the pickle).
        scales: Grid sizes to build targets for.
        threshold_ignore_prediction: IoU above which a non-best anchor is
            marked ``-1``.
        transforms: Optional callable invoked as
            ``transforms(image=..., bboxes=...)`` and expected to return a
            mapping with ``'image'`` and ``'bboxes'`` entries.
        dtype, device: Stored on the instance; not used by this class.

    The ``object_placed`` / ``object_not_placed`` counters are debugging aids
    updated as samples are fetched.
    """

    def __init__(self, devkit_path, 
                 subsets = [('VOC2007', 'trainval'), ('VOC2012', 'trainval')], 
                 anchors = [], scales = [13], 
                 threshold_ignore_prediction = 0.5,
                 transforms = None,
                 dtype=None, device=None):
        super().__init__()
        self.devkit_path = devkit_path
        # Copy the list arguments so the shared default lists are never aliased.
        self.subsets = list(subsets)
        self.anchors = anchors
        self.scales = list(scales)
        self.threshold_ignore_prediction = threshold_ignore_prediction
        self.transforms = transforms
        self.dtype = dtype
        self.device = device

        self.object_placed = 0
        self.object_not_placed = 0

        # One list of image ids per subset, in the same order as self.subsets.
        self.all_labels = []
        for subset in self.subsets:
            subset_path = os.path.join(self.devkit_path, subset[0], 'ImageSets', 'Main', '{}.txt'.format(subset[1]))
            print(os.path.exists(subset_path), subset_path)
            with open(subset_path, 'r') as file:
                subset_labels = file.read().splitlines()
            self.all_labels.append(subset_labels)

        self.classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                        'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
                        'tvmonitor']

    def _locate(self, idx):
        """Map a flat dataset index to ``(subset_idx, index_within_subset)``.

        Raises:
            IndexError: if ``idx`` is negative or past the last sample.
        """
        if idx < 0:
            raise IndexError("Index out of range.")
        for subset_idx, subset_labels in enumerate(self.all_labels):
            if idx < len(subset_labels):
                return subset_idx, idx
            idx -= len(subset_labels)
        raise IndexError("Index out of range.")

    def _place_box(self, target, scale, img_w, img_h, box_xc, box_yc, box_w, box_h, class_idx):
        """Write one object (pixel-space centre/size) into one scale's target tensor."""
        num_anchors = len(self.anchors)
        if num_anchors == 0:
            # No anchors means the target has no channels; nothing can be placed.
            return

        attrs_per_anchor = 5 + len(self.classes)
        cell_w = img_w / scale
        cell_h = img_h / scale

        # Grid cell containing the box centre; clamped so a centre lying exactly
        # on the right/bottom image edge still maps to the last cell.
        cell_x = min(int(box_xc / cell_w), scale - 1)
        cell_y = min(int(box_yc / cell_h), scale - 1)

        # Per-scale values are kept in fresh locals so that processing one scale
        # never alters the inputs used for the next one.
        rel_xc = (box_xc % cell_w) / cell_w
        rel_yc = (box_yc % cell_h) / cell_h
        rel_w = box_w / cell_w
        rel_h = box_h / cell_h
        bndbox = torch.tensor([rel_xc, rel_yc, rel_w, rel_h])

        # Compare shapes only: every anchor is centred on the object's centre.
        ious = torch.empty(num_anchors)
        for i, anchor in enumerate(self.anchors):
            anchor_box = torch.tensor([rel_xc, rel_yc, anchor[0] * scale, anchor[1] * scale])
            ious[i] = IoU(bndbox, anchor_box)

        ranking = torch.argsort(ious, descending=True).tolist()
        base = ranking[0] * attrs_per_anchor

        if target[base, cell_x, cell_y] == 1:
            # The best anchor of this cell already belongs to another object.
            # =========== TEST ==========
            self.object_not_placed += 1
            # =========== TEST ==========
            return

        target[base, cell_x, cell_y] = 1
        target[base + 1:base + 5, cell_x, cell_y] = bndbox
        # Mark the class for this cell only, not for the whole class plane.
        target[base + 5 + class_idx, cell_x, cell_y] = 1

        # =========== TEST ==========
        self.object_placed += 1
        # =========== TEST ==========

        # Remaining anchors that overlap strongly are flagged as "ignore", but an
        # anchor already assigned to another object keeps its positive label.
        for anchor_idx in ranking[1:]:
            if ious[anchor_idx] > self.threshold_ignore_prediction:
                other = anchor_idx * attrs_per_anchor
                if target[other, cell_x, cell_y] == 0:
                    target[other, cell_x, cell_y] = -1

    def __getitem__(self, idx):
        """Return the image and its target tensor(s) for flat index ``idx``.

        Returns:
            ``(image, target)`` when a single scale is configured, otherwise
            ``(image, [target_per_scale, ...])``. Without ``transforms`` the
            image is a PIL image, which the default DataLoader collate function
            cannot batch.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        subset_idx, local_idx = self._locate(idx)
        subset_dir = self.subsets[subset_idx][0]
        image_id = self.all_labels[subset_idx][local_idx]

        image_path = os.path.join(self.devkit_path, subset_dir, 'JPEGImages', '{}.jpg'.format(image_id))
        annotation_path = os.path.join(self.devkit_path, subset_dir, 'Annotations', '{}.xml'.format(image_id))

        # Image.open is lazy and keeps the file open; copy the pixels out and
        # close the file instead of leaking one handle per sample.
        with Image.open(image_path) as image_file:
            image = image_file.copy()

        # parse annotations
        root = ET.parse(annotation_path).getroot()
        img_w = int(root.find("./size/width").text)
        img_h = int(root.find("./size/height").text)

        bboxes = []
        for item in root.findall('./object'):
            label = item.find("name").text
            bndbox = item.find("bndbox")
            # int(float(...)) also accepts coordinates written with a decimal part.
            xmin, ymin, xmax, ymax = (int(float(bndbox.find(tag).text))
                                      for tag in ("xmin", "ymin", "xmax", "ymax"))
            bboxes.append([xmin, ymin, xmax, ymax, label])

        if self.transforms is not None:
            transformed = self.transforms(image=image, bboxes=bboxes)
            image = transformed['image']
            bboxes = transformed['bboxes']
            # NOTE(review): cell sizes below still use the width/height from the
            # annotation file. If the transform changes the image geometry,
            # img_w/img_h must be updated to the transformed size — TODO.

        attrs_per_anchor = 5 + len(self.classes)
        gt_out = [torch.zeros(len(self.anchors) * attrs_per_anchor, scale, scale) for scale in self.scales]

        for x_min, y_min, x_max, y_max, label in bboxes:
            obj_w = x_max - x_min
            obj_h = y_max - y_min
            obj_xc = x_max - obj_w / 2
            obj_yc = y_max - obj_h / 2
            class_idx = self.classes.index(label)

            for scale_idx, scale in enumerate(self.scales):
                self._place_box(gt_out[scale_idx], scale, img_w, img_h,
                                obj_xc, obj_yc, obj_w, obj_h, class_idx)

        # Parenthesised on purpose: without them the conditional expression binds
        # tighter than the commas and a 3-tuple is returned.
        return (image, gt_out) if len(self.scales) > 1 else (image, gt_out[0])

    def __len__(self):
        """Total number of samples across all subsets."""
        return sum(len(subset_labels) for subset_labels in self.all_labels)

In [157]:
# Build the dataset with the default subsets (VOC2007 + VOC2012 trainval), a single
# 13x13 target grid and the anchors loaded from the pickle file above.
# No `transforms` are passed here.
train_set = VOCDataset(devkit_path = '../../datasets/VOCdevkit/', scales=[13], anchors=anchors)

True ../../datasets/VOCdevkit/VOC2007\ImageSets\Main\trainval.txt
True ../../datasets/VOCdevkit/VOC2012\ImageSets\Main\trainval.txt


In [145]:
# Debug counter: objects written into a target grid so far (0 until samples are fetched).
# NOTE(review): this cell's execution count (145) is lower than that of the cell that
# builds train_set (157), so the value shown may come from an older instance — re-run in order.
train_set.object_placed

0

In [146]:
# Batch the dataset in its stored order (no shuffling), 64 samples per batch.
train_loader = DataLoader(train_set, batch_size=64, shuffle=False)

In [147]:
# Walk the loader once and discard every batch; the only purpose is to trigger
# VOCDataset.__getitem__ for each sample so the debug counters get updated.
# NOTE(review): the output printed below comes from print statements that are
# commented out in the current class definition, i.e. it is stale.
for img in train_loader:
    del img

0 0
True ../../datasets/VOCdevkit/VOC2007\JPEGImages\000005.jpg
True ../../datasets/VOCdevkit/VOC2007\Annotations\000005.xml
0 1
True ../../datasets/VOCdevkit/VOC2007\JPEGImages\000007.jpg
True ../../datasets/VOCdevkit/VOC2007\Annotations\000007.xml
0 2
True ../../datasets/VOCdevkit/VOC2007\JPEGImages\000009.jpg
True ../../datasets/VOCdevkit/VOC2007\Annotations\000009.xml
0 3
True ../../datasets/VOCdevkit/VOC2007\JPEGImages\000012.jpg
True ../../datasets/VOCdevkit/VOC2007\Annotations\000012.xml
0 4
True ../../datasets/VOCdevkit/VOC2007\JPEGImages\000016.jpg
True ../../datasets/VOCdevkit/VOC2007\Annotations\000016.xml
0 5
True ../../datasets/VOCdevkit/VOC2007\JPEGImages\000017.jpg
True ../../datasets/VOCdevkit/VOC2007\Annotations\000017.xml
0 6
True ../../datasets/VOCdevkit/VOC2007\JPEGImages\000019.jpg
True ../../datasets/VOCdevkit/VOC2007\Annotations\000019.xml
0 7
True ../../datasets/VOCdevkit/VOC2007\JPEGImages\000020.jpg
True ../../datasets/VOCdevkit/VOC2007\Annotations\000020.xml


In [148]:
# (objects placed, objects skipped because their best anchor/cell was already taken, total)
train_set.object_placed, train_set.object_not_placed, train_set.object_placed + train_set.object_not_placed

(45653, 1570, 47223)