In [14]:
import torch
import os
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import pandas as pd
from sklearn.cluster import KMeans

I'll compute anchor box widths and heights relative to the image size, so the anchors do not depend on image resolution.
I need only 5 anchor boxes, not 15, because the best-matching anchor for each object is the same no matter what the grid size is.

In [25]:
class VOCDataset(Dataset):
    """Normalized object sizes read from Pascal VOC XML annotations.

    Item ``i`` is a list with one ``(width, height)`` tuple per ``<object>`` in the
    ``i``-th image's annotation file. Width and height are divided by the image
    width and height, so they are fractions of the image size. Indices run through
    the requested subsets in order, as if their image-id lists were concatenated.

    Args:
        devkit_path: Root directory containing the per-year folders (e.g. ``VOC2007``).
        subsets: Sequence of ``(year_dir, split)`` pairs. For each pair the image ids
            are read from ``<devkit_path>/<year_dir>/ImageSets/Main/<split>.txt``
            (one id per line), and annotations from
            ``<devkit_path>/<year_dir>/Annotations/<id>.xml``.
            ``None`` (the default) means VOC2007 + VOC2012 ``trainval``.
    """

    def __init__(self, devkit_path, subsets=None):
        super().__init__()
        if subsets is None:
            # Built per call: a list literal as the default argument would be one
            # mutable object shared by every instance.
            subsets = [('VOC2007', 'trainval'), ('VOC2012', 'trainval')]
        self.devkit_path = devkit_path
        self.subsets = subsets

        # One list of image ids per subset, in the same order as self.subsets.
        self.all_labels = []
        for subset in self.subsets:
            year_dir, split = subset[0], subset[1]
            subset_path = os.path.join(self.devkit_path, year_dir, 'ImageSets', 'Main', '{}.txt'.format(split))
            # Sanity-check output: whether each image-list file exists, and its path.
            print(os.path.exists(subset_path), subset_path)
            with open(subset_path, 'r') as file:
                self.all_labels.append(file.read().splitlines())

    def __getitem__(self, idx):
        """Return ``[(w, h), ...]`` for every object of the ``idx``-th image.

        ``w``/``h`` are the bounding-box width/height divided by the image
        width/height taken from the annotation's ``<size>`` element.

        Raises:
            IndexError: if ``idx`` is negative or not smaller than ``len(self)``.
                ``IndexError`` (not a bare ``Exception``) lets Python's sequence
                iteration (``for item in dataset``) stop cleanly at the end.
        """
        if idx < 0:
            raise IndexError("Index out of range.")

        # Translate the global index into (subset, position within that subset).
        subset_idx = 0
        for subset_labels in self.all_labels:
            if idx < len(subset_labels):
                break
            subset_idx += 1
            idx -= len(subset_labels)
        if subset_idx >= len(self.all_labels):
            raise IndexError("Index out of range.")

        image_id = self.all_labels[subset_idx][idx]
        annotation_path = os.path.join(self.devkit_path, self.subsets[subset_idx][0],
                                       'Annotations', '{}.xml'.format(image_id))

        root = ET.parse(annotation_path).getroot()
        img_w = int(root.find("./size/width").text)
        img_h = int(root.find("./size/height").text)

        boxes = []
        for item in root.findall('./object'):
            bndbox = item.find("bndbox")
            xmin = int(bndbox.find("xmin").text)
            ymin = int(bndbox.find("ymin").text)
            xmax = int(bndbox.find("xmax").text)
            ymax = int(bndbox.find("ymax").text)
            boxes.append(((xmax - xmin) / img_w, (ymax - ymin) / img_h))
        return boxes

    def __len__(self):
        """Total number of images across all subsets."""
        return sum(len(subset_labels) for subset_labels in self.all_labels)

In [27]:
VOC_DEVKIT_PATH = '../../datasets/VOCdevkit/'
train_set = VOCDataset(devkit_path=VOC_DEVKIT_PATH)

True ../../datasets/VOCdevkit/VOC2007\ImageSets\Main\trainval.txt
True ../../datasets/VOCdevkit/VOC2012\ImageSets\Main\trainval.txt


In [59]:
# Flatten the per-image lists into one list of (w, h) tuples, one entry per object.
boxes = [wh for image_idx in range(len(train_set)) for wh in train_set[image_idx]]

In [29]:
# One row per annotated object: width and height as fractions of the image size.
df = pd.DataFrame(data=boxes, columns=["box_w", "box_h"])

In [30]:
# Per-object (box_w, box_h) pairs, each divided by the image width/height.
df

Unnamed: 0,box_w,box_h
0,0.122000,0.341333
1,0.176000,0.288000
2,0.124000,0.346667
3,0.108000,0.280000
4,0.070000,0.090667
...,...,...
47218,0.868263,0.228000
47219,0.802000,0.808000
47220,0.766000,0.272000
47221,0.812000,0.575682


In [56]:
# Cluster the normalized (box_w, box_h) pairs into 5 anchor shapes.
# - random_state is fixed so the anchors are reproducible on Restart & Run All.
# - n_init is explicit because scikit-learn changed its default (10 -> 'auto').
# NOTE(review): this is plain Euclidean k-means. YOLOv2-style anchor clustering
# uses d = 1 - IoU (cf. IoU_for_anchors below); confirm Euclidean is intended.
kmeans = KMeans(n_clusters=5, max_iter=300, tol=1e-10, n_init=10, verbose=0, random_state=0)

In [57]:
# Fit on both columns (box_w, box_h); the fitted estimator is the cell output.
kmeans.fit(X=df)

In [58]:
# Anchor (w, h) candidates and the number of k-means iterations it took.
centers, n_iter = kmeans.cluster_centers_, kmeans.n_iter_
centers, n_iter

(array([[0.72248613, 0.45577772],
        [0.3748161 , 0.74453631],
        [0.84177217, 0.86174343],
        [0.09752187, 0.14036605],
        [0.24777899, 0.38633005]]),
 40)

In [63]:
def IoU_for_anchors(box1, box2):
    """
    IoU of two boxes compared by size only, as if both shared the same corner.

    Considers only width and height of boxes; (x, y) positions are discarded,
    which suits comparing an object's size with an anchor shape.

    Args:
        box1, box2: sequences whose first two items are (width, height).

    Returns:
        Intersection over union, in [0, 1] for non-negative sizes. Returns 0.0
        when the union area is not positive (e.g. two zero-area boxes), where
        the ratio is undefined and would otherwise raise ZeroDivisionError.
    """
    w1, h1 = box1[0], box1[1]
    w2, h2 = box2[0], box2[1]

    # Aligned at a common corner, the overlap is the smaller extent on each axis.
    intersection = min(w1, w2) * min(h1, h2)
    union = w1 * h1 + w2 * h2 - intersection
    if union <= 0:
        return 0.0
    return intersection / union

In [69]:
# Sanity check: a box compared with itself must give IoU == 1.
IoU_for_anchors(box1=[3.0, 3.5], box2=[3.0, 3.5])

1.0

In [70]:
def IoU(box1, box2):
    x1 = box1[0]
    y1 = box1[1]
    w1 = box1[2]
    h1 = box1[3]

    x2 = box2[0]
    y2 = box2[1]
    w2 = box2[2]
    h2 = box2[3]