In [1]:
import torch
import os
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import pandas as pd

I'll get anchor boxes' width and height relatively to the size of the image to make them more universal.
I need 5 anchor boxes, not 15, cause the best anchor for each object each the same no matter what the grid size is

In [2]:
class VOCDataset(Dataset):
    def __init__(self, devkit_path, 
                 subsets = [('VOC2007', 'trainval'), ('VOC2012', 'trainval') ]):
        super().__init__()
        self.devkit_path = devkit_path
        self.subsets = subsets

        self.all_labels = []
        for subset in self.subsets:
            subset_path = os.path.join(self.devkit_path, subset[0], 'ImageSets', 'Main', '{}.txt'.format(subset[1]))
            print(os.path.exists(subset_path), subset_path)
            with open(subset_path, 'r') as file:
                subset_labels = file.read().splitlines()
            self.all_labels.append(subset_labels)

    def __getitem__(self, idx):
        # get paths
        subset_idx = 0
        for subset_labels in self.all_labels:
            if idx < len(subset_labels):
                break
            else:
                subset_idx += 1
                idx -= len(subset_labels)

        if idx < 0 or subset_idx >= len(self.subsets):
            raise Exception("Index out of range.")

        # print(subset_idx, idx)
        annotation_path = os.path.join(self.devkit_path, self.subsets[subset_idx][0], 'Annotations', '{}.xml'.format(self.all_labels[subset_idx][idx]))

        # print(os.path.exists(annotation_path), annotation_path)
        
        # parse annotations
        tree = ET.parse(annotation_path)
        root = tree.getroot()

        img_w = int(root.find("./size/width").text)
        img_h = int(root.find("./size/height").text)
        img_d = int(root.find("./size/depth").text)

        boxes = []
        for i, item in enumerate(root.findall('./object')):
            label = item.find("name").text
            bndbox = item.find("bndbox")
            xmin = int(bndbox.find("xmin").text)
            ymin = int(bndbox.find("ymin").text)
            xmax = int(bndbox.find("xmax").text)
            ymax = int(bndbox.find("ymax").text)
        
            obj_w = (xmax - xmin) / img_w
            obj_h = (ymax - ymin) / img_h

            boxes.append((obj_w, obj_h))

        return boxes
        
    def __len__(self):
        summed_len = 0
        for _subset in self.all_labels:
            summed_len += len(_subset)
        return summed_len

In [3]:
train_set = VOCDataset(devkit_path = '../../datasets/VOCdevkit/')

True ../../datasets/VOCdevkit/VOC2007\ImageSets\Main\trainval.txt
True ../../datasets/VOCdevkit/VOC2012\ImageSets\Main\trainval.txt


In [4]:
boxes = []
for idx in range(len(train_set)):
    boxes += train_set[idx]

In [5]:
df = pd.DataFrame(boxes, columns = ["box_w", "box_h"])

In [6]:
df

Unnamed: 0,box_w,box_h
0,0.122000,0.341333
1,0.176000,0.288000
2,0.124000,0.346667
3,0.108000,0.280000
4,0.070000,0.090667
...,...,...
47218,0.868263,0.228000
47219,0.802000,0.808000
47220,0.766000,0.272000
47221,0.812000,0.575682


In [26]:
def distance_f(box, centroid):
    iou = IoU([0, 0] + list(box), [0, 0] + list(centroid), midpoint=True)
    return 1 - iou

In [27]:
def IoU(box1, box2, midpoint=True):
    if midpoint:
        x1 = box1[0]
        y1 = box1[1]
        w1 = box1[2]
        h1 = box1[3]
    
        x2 = box2[0]
        y2 = box2[1]
        w2 = box2[2]
        h2 = box2[3]
    
        xmin1 = x1 - w1/2
        xmin2 = x2 - w2/2
        ymin1 = y1 - h1/2
        ymin2 = y2 - h2/2
    
        xmax1 = x1 + w1/2
        xmax2 = x2 + w2/2
        ymax1 = y1 + h1/2
        ymax2 = y2 + h2/2
    else:
        xmin1, ymin1, xmax1, ymax1 = box1
        xmin2, ymin2, xmax2, ymax2 = box2
    
    xmin_i = max(xmin1, xmin2)
    xmax_i = min(xmax1, xmax2)
    ymin_i = max(ymin1, ymin2)
    ymax_i = min(ymax1, ymax2)

    intersection = max(xmax_i-xmin_i, 0) * max(ymax_i-ymin_i, 0)

    area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
    area2 = (xmax2 - xmin2) * (ymax2 - ymin2)

    return intersection / (area1 + area2 - intersection + 1e-6)

In [44]:
from pyclustering.cluster.kmeans import kmeans
from pyclustering.utils.metric import type_metric, distance_metric

metric = distance_metric(type_metric.USER_DEFINED, func=distance_f)

# create K-Means algorithm with specific distance metric
start_centers = torch.rand(5, 2).tolist()
kmeans_instance = kmeans(df, start_centers, metric=metric, tolerance=0.00001)

# run cluster analysis and obtain results
kmeans_instance.process()
centers = kmeans_instance.get_centers()

In [43]:
start_centers, centers

([[0.38347065448760986, 0.9537832736968994],
  [0.8005502820014954, 0.17438530921936035],
  [0.9535838961601257, 0.8640784025192261],
  [0.32188349962234497, 0.49408531188964844],
  [0.568794846534729, 0.9568654298782349]],
 [[0.32882816561531697, 0.6554346360225188],
  [0.07522471620262354, 0.1241032298491792],
  [0.7862514755140577, 0.8175396064741282],
  [0.1944210942223786, 0.3173615394924991],
  [0.6152915161928911, 0.3970012983608503]])

In [46]:
import pickle
with open("anchors_VOC0712trainval.pickle", "wb") as file:
    pickle.dump(centers, file, protocol=pickle.HIGHEST_PROTOCOL)