In [10]:
# k-means ++ for YOLOv2 anchors
# 通过k-means ++ 算法获取YOLOv2需要的anchors的尺寸
import numpy as np
from collections import Counter
import xml.etree.ElementTree as ET
import math
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

VOC_LABELS = {
    'none': (0, 'Background'),
    'aeroplane': (1, 'Vehicle'),
    'bicycle': (2, 'Vehicle'),
    'bird': (3, 'Animal'),
    'boat': (4, 'Vehicle'),
    'bottle': (5, 'Indoor'),
    'bus': (6, 'Vehicle'),
    'car': (7, 'Vehicle'),
    'cat': (8, 'Animal'),
    'chair': (9, 'Indoor'),
    'cow': (10, 'Animal'),
    'diningtable': (11, 'Indoor'),
    'dog': (12, 'Animal'),
    'horse': (13, 'Animal'),
    'motorbike': (14, 'Vehicle'),
    'person': (15, 'Person'),
    'pottedplant': (16, 'Indoor'),
    'sheep': (17, 'Animal'),
    'sofa': (18, 'Indoor'),
    'train': (19, 'Vehicle'),
    'tvmonitor': (20, 'Indoor'),
}

# 定义Box类，描述bounding box的坐标
class Box():
    def __init__(self, x, y, w, h):
        self.x = x
        self.y = y
        self.w = w
        self.h = h


# 计算两个box在某个轴上的重叠部分
# x1是box1的中心在该轴上的坐标
# len1是box1在该轴上的长度
# x2是box2的中心在该轴上的坐标
# len2是box2在该轴上的长度
# 返回值是该轴上重叠的长度
def overlap(x1, len1, x2, len2):
    len1_half = len1 / 2
    len2_half = len2 / 2

    left = max(x1 - len1_half, x2 - len2_half)
    right = min(x1 + len1_half, x2 + len2_half)

    return right - left


# 计算box a 和box b 的交集面积
# a和b都是Box类型实例
# 返回值area是box a 和box b 的交集面积
def box_intersection(a, b):
    w = overlap(a.x, a.w, b.x, b.w)
    h = overlap(a.y, a.h, b.y, b.h)
    if w < 0 or h < 0:
        return 0

    area = w * h
    return area


# 计算 box a 和 box b 的并集面积
# a和b都是Box类型实例
# 返回值u是box a 和box b 的并集面积
def box_union(a, b):
    i = box_intersection(a, b)
    u = a.w * a.h + b.w * b.h - i
    return u


# 计算 box a 和 box b 的 iou
# a和b都是Box类型实例
# 返回值是box a 和box b 的iou
def box_iou(a, b):
    return box_intersection(a, b) / box_union(a, b)


# 使用k-means ++ 初始化 centroids，减少随机初始化的centroids对最终结果的影响
# boxes是所有bounding boxes的Box对象列表
# n_anchors是k-means的k值
# 返回值centroids 是初始化的n_anchors个centroid
def init_centroids(boxes,n_anchors):
    centroids = []
    boxes_num = len(boxes)

    centroid_index = np.random.choice(boxes_num, 1)
    centroids.append(boxes[centroid_index])

    print(centroids[0].w,centroids[0].h)

    for centroid_index in range(0,n_anchors-1):

        sum_distance = 0
        distance_thresh = 0
        distance_list = []
        cur_sum = 0

        for box in boxes:
            min_distance = 1
            for centroid_i, centroid in enumerate(centroids):
                distance = (1 - box_iou(box, centroid))
                if distance < min_distance:
                    min_distance = distance
            sum_distance += min_distance
            distance_list.append(min_distance)

        distance_thresh = sum_distance*np.random.random()

        for i in range(0,boxes_num):
            cur_sum += distance_list[i]
            if cur_sum > distance_thresh:
                centroids.append(boxes[i])
                print(boxes[i].w, boxes[i].h)
                break

    return centroids

def read_xml(filename):
    tree = ET.parse(filename)
    root = tree.getroot()
    size = root.find('size')
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]

    labels = []
    labels_text = []
    gt_boxes = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))
        labels_text.append(label.encode('ascii'))
        bbox = obj.find('bndbox')
        gt_boxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]
                       ))

    return labels, gt_boxes
# 进行 k-means 计算新的centroids
# boxes是所有bounding boxes的Box对象列表
# n_anchors是k-means的k值
# centroids是所有簇的中心
# 返回值new_centroids 是计算出的新簇中心
# 返回值groups是n_anchors个簇包含的boxes的列表
# 返回值loss是所有box距离所属的最近的centroid的距离的和
def do_kmeans(n_anchors, boxes, centroids):
    loss = 0
    groups = []
    new_centroids = []
    for i in range(n_anchors):
        groups.append([])
        new_centroids.append(Box(0, 0, 0, 0))

    for box in boxes:
        min_distance = 1
        group_index = 0
        for centroid_index, centroid in enumerate(centroids):
            distance = (1 - box_iou(box, centroid))
            if distance < min_distance:
                min_distance = distance
                group_index = centroid_index
        groups[group_index].append(box)
        loss += min_distance
        new_centroids[group_index].w += box.w
        new_centroids[group_index].h += box.h

    for i in range(n_anchors):
        new_centroids[i].w /= len(groups[i])
        new_centroids[i].h /= len(groups[i])

    return new_centroids, groups, loss


# 计算给定bounding boxes的n_anchors数量的centroids
# label_path是训练集列表文件地址
# n_anchors 是anchors的数量
# loss_convergence是允许的loss的最小变化值
# grid_size * grid_size 是栅格数量
# iterations_num是最大迭代次数
# plus = 1时启用k means ++ 初始化centroids

def compute_centroids(label_path,n_anchors,loss_convergence,grid_size,iterations_num,plus):

    boxes = []
    label_files = []
    files = os.listdir(label_path)
    for i in range(len(files)):
        file_path = os.path.join(label_path, files[i])
        _, get_box = read_xml(file_path)
        if len(get_box) > 0:
            for j in range(len(get_box)):
                boxes.append(Box(0, 0, float(get_box[j][2]), float(get_box[j][3])))      
    
    if plus:
        centroids = init_centroids(boxes, n_anchors)
    else:
        centroid_indices = np.random.choice(len(boxes), n_anchors)
        centroids = []
        for centroid_index in centroid_indices:
            centroids.append(boxes[centroid_index])

    # iterate k-means
    centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)
    iterations = 1
    total=[]
    for i in range(10):
        print('=========={}======='.format(i+1))
        while (True):
            centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids)
            iterations = iterations + 1
            print("loss = %f" % loss)
            if abs(old_loss - loss) < loss_convergence or iterations > iterations_num:
                break
            old_loss = loss
        total.append(centroids)

        #for centroid in centroids:
            #print(centroid.w * grid_size, centroid.h * grid_size)

    # print result
    #for centroid in centroids:
    #    print("k-means result: \n")
    #    print(centroid.w * grid_size, centroid.h * grid_size)
    print("=======over=========")
    return centroids, total
    
path = '/nfshome/xueqin/udalearn/data/VOCdevkit/VOC2007/Annotations/000005.xml'
label_path = "/nfshome/xueqin/udalearn/data/VOCdevkit/VOC2007/Annotations/"
n_anchors = 6
loss_convergence = 1e-4
grid_size = 1
iterations_num = 1000
plus = 0
the_anchors, total_ = compute_centroids(label_path,n_anchors,loss_convergence,grid_size,iterations_num,plus)

loss = 4158.306930
loss = 4021.184227
loss = 3958.991334
loss = 3927.354150
loss = 3907.209933
loss = 3900.909612
loss = 3899.721260
loss = 3899.741615
loss = 3901.462522
loss = 3903.844605
loss = 3906.048160
loss = 3908.072426
loss = 3909.544024
loss = 3911.541824
loss = 3913.401712
loss = 3914.846055
loss = 3915.761729
loss = 3916.499804
loss = 3917.264625
loss = 3918.173495
loss = 3919.013751
loss = 3919.605457
loss = 3920.089293
loss = 3920.647802
loss = 3921.322590
loss = 3921.800744
loss = 3922.654736
loss = 3923.197516
loss = 3923.735468
loss = 3924.028980
loss = 3924.211074
loss = 3924.567076
loss = 3924.841397
loss = 3925.219423
loss = 3925.491810
loss = 3925.617980
loss = 3925.576462
loss = 3925.454677
loss = 3925.370672
loss = 3925.394629
loss = 3925.461390
loss = 3925.407152
loss = 3925.334300
loss = 3925.294471
loss = 3925.293184
loss = 3925.306431
loss = 3925.321484
loss = 3925.332383
loss = 3925.344306
loss = 3925.344306
loss = 3925.344306
loss = 3925.344306
loss = 3925.

In [11]:
arr=[]
index=1
for the_anchor in the_anchors:
    val=the_anchor.w * the_anchor.h
    add_part = (str(index), val,the_anchor.w , the_anchor.h, the_anchor.w/the_anchor.h, the_anchor.h/the_anchor.w)
    arr.append(add_part)
    index+=1

arr = sorted(arr, key=lambda x:x[1])
for va in arr:
    print(va)

('4', 0.08778434045679562, 0.6071455279124731, 0.14458533649851857, 4.199219247372945, 0.23813950667748668)
('1', 0.16385631379275817, 0.26400014809241884, 0.620667507108355, 0.42534874964274577, 2.3510119656867663)
('2', 0.2481028803324041, 0.7214684486827769, 0.34388597420355477, 2.0979874225859576, 0.4766472807388952)
('5', 0.4299693294597133, 0.5549865799615662, 0.774738245904053, 0.7163536625379123, 1.3959585220199469)
('6', 0.49720328373985917, 0.8565614195626663, 0.5804642520482837, 1.4756488733632067, 0.6776679859625837)
('3', 0.829547318163543, 0.9089480067199897, 0.9126455111079782, 0.9959485864522583, 1.0040678942696968)


In [None]:

('5', 0.08785858876957825, 0.6069205891184852, 0.1447612592896007, 4.192562237278665, 0.23851762798137646)
('2', 0.16393243378912314, 0.2641232585929535, 0.6206664065195532, 0.42554785601181505, 2.3499119684725556)
('1', 0.24843886141344285, 0.7219757514396125, 0.34410970301711413, 2.0980976273246945, 0.4766222443495683)
('3', 0.430052173806052, 0.5551009626284619, 0.774727847290535, 0.7165109200215574, 1.3956521415890122)
('6', 0.497337624091014, 0.8565951265820061, 0.5805982414066436, 1.4753663815906355, 0.677797740600454)
('4', 0.829594303550739, 0.9089977637031678, 0.9126472436753345, 0.9960012151491634, 1.004014839329526)

('1', 0.0989036053342131, 0.6444302403814556, 0.15347449442420547, 4.1989403047013285)
('3', 0.16036208808236854, 0.24611916716734425, 0.6515627772026925, 0.3777366905825866)
('5', 0.2214060979833464, 0.5622134505326127, 0.3938114567938519, 1.4276208597631381)
('6', 0.42921240305244773, 0.8761708338443062, 0.48987296366534605, 1.788567442645921)
('4', 0.46158703120076555, 0.5813498215368733, 0.7939918687521064, 0.7321861147653359)
('2', 0.8175683709617362, 0.9176844917580006, 0.8909035494274586, 1.030060428368203)

In [None]:
('1', 0.07438879940989732, 0.6151127941892652, 0.12093521726847464, 5.086299988395627, 0.19660657104014626)
('4', 0.12639558664414777, 0.3919999850238866, 0.32243773334952, 1.215738558113984, 0.8225452695613551)
('7', 0.14827410746550607, 0.21591303384861948, 0.6867306934766324, 0.31440714081897436, 3.1805893383820063)
('5', 0.2423000438643753, 0.8284194039409104, 0.29248475194052564, 2.832350741174236, 0.35306361795623514)
('2', 0.31614826511065275, 0.6037036649091843, 0.5236812089888692, 1.1528075755760332, 0.8674474571355253)
('9', 0.3696494284741175, 0.4411455437226479, 0.8379307775724, 0.5264701518670872, 1.8994429151464225)
('3', 0.5255499557270419, 0.9091088208214282, 0.5780935611780551, 1.5725980738633742, 0.635890388408856)
('8', 0.593623103597078, 0.7003314178940292, 0.8476316904104712, 0.826221371637119, 1.2103293794235157)
('6', 0.876065723118035, 0.9521819848454999, 0.9200612247039988, 1.0349115464047896, 0.9662661543142795)