-
-
Notifications
You must be signed in to change notification settings - Fork 16.4k
/
autoanchor.py
175 lines (148 loc) · 7.75 KB
/
autoanchor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Ultralytics YOLOv5 🚀, AGPL-3.0 license
"""AutoAnchor utils."""
import random
import numpy as np
import torch
import yaml
from tqdm import tqdm
from utils import TryExcept
from utils.general import LOGGER, TQDM_BAR_FORMAT, colorstr
# Log prefix (built with colorstr) that the functions in this module prepend to their messages
PREFIX = colorstr("AutoAnchor: ")
def check_anchor_order(m):
    """Checks and corrects anchor order against stride in YOLOv5 Detect() module if necessary.

    Compares the direction in which the mean anchor area changes from the first to the last
    output layer with the direction in which `m.stride` changes. If the two directions
    disagree, `m.anchors` is flipped in place along its first dimension.

    Arguments:
        m: object exposing `anchors` and `stride` tensors (the Detect() module)

    Return:
        None, `m.anchors` is modified in place when a reversal is required
    """
    a = m.anchors.prod(-1).mean(-1).view(-1)  # mean anchor area per output layer
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
    # A zero area delta (e.g. one output layer) is falsy, so nothing is reordered in that case
    if da and (da.sign() != ds.sign()):  # anchor order disagrees with stride order
        LOGGER.info(f"{PREFIX}Reversing anchor order")
        m.anchors[:] = m.anchors.flip(0)
@TryExcept(f"{PREFIX}ERROR")
def check_anchors(dataset, model, thr=4.0, imgsz=640):
    """Evaluates anchor fit to dataset and adjusts if necessary, supporting customizable threshold and image size.

    The current anchors (scaled by the model strides) are scored against the dataset label
    sizes. When the best possible recall is not above 0.98, new anchors are computed with
    `kmean_anchors` and written into the model only if they score a higher recall.

    Arguments:
        dataset: object exposing `shapes` and `labels` (columns 3:5 of each label array are
            used as width/height, presumably normalized to the image - confirm against loader)
        model: model whose last layer (`model.model[-1]`, or `model.module.model[-1]` when
            wrapped) exposes `anchors` and `stride`
        thr: anchor-label wh ratio threshold, default=4.0
        imgsz: image size the label sizes are scaled to

    Return:
        None, the detection layer's anchors are updated in place when better ones are found
    """
    m = model.module.model[-1] if hasattr(model, "module") else model.model[-1]  # Detect()
    # Rescale every image shape so that its longest side equals imgsz
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        """Computes ratio metric, anchors above threshold, and best possible recall for YOLOv5 anchor evaluation."""
        r = wh[:, None] / k[None]
        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1 / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1 / thr).float().mean()  # best possible recall
        return bpr, aat

    stride = m.stride.to(m.anchors.device).view(-1, 1, 1)  # model strides
    anchors = m.anchors.clone() * stride  # current anchors
    bpr, aat = metric(anchors.cpu().view(-1, 2))
    s = f"\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). "
    if bpr > 0.98:  # threshold to recompute
        LOGGER.info(f"{s}Current anchors are a good fit to dataset ✅")
    else:
        LOGGER.info(f"{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...")
        na = m.anchors.numel() // 2  # number of anchors
        anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(anchors)[0]
        if new_bpr > bpr:  # replace anchors
            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchors[:] = anchors.clone().view_as(m.anchors)
            check_anchor_order(m)  # must be in pixel-space (not grid-space)
            m.anchors /= stride
            s = f"{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)"
        else:
            s = f"{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)"
        LOGGER.info(s)
def kmean_anchors(dataset="./data/coco128.yaml", n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """
    Creates kmeans-evolved anchors from training dataset.

    Label sizes are clustered with k-means (random initialisation is used as a fallback) and
    the resulting anchors are then refined for `gen` generations by random mutation, keeping
    a mutation only when it improves the fitness.

    Arguments:
        dataset: path to data.yaml, or a loaded dataset
        n: number of anchors
        img_size: image size used for training
        thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
        gen: generations to evolve anchors using genetic algorithm
        verbose: print all results

    Return:
        k: kmeans evolved anchors, float32 array of shape (n, 2) sorted by area, small to large

    Usage:
        from utils.autoanchor import *; _ = kmean_anchors()
    """
    from scipy.cluster.vq import kmeans

    npr = np.random
    thr = 1 / thr  # from here on thr is the inverse ratio, compared against min(r, 1 / r)

    def metric(k, wh):  # compute metrics
        """Computes ratio metric, anchors above threshold, and best possible recall for YOLOv5 anchor evaluation."""
        r = wh[:, None] / k[None]
        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        """Evaluates fitness of YOLOv5 anchors by computing recall and ratio metrics for an anchor evolution process."""
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k, verbose=True):
        """Sorts and logs kmeans-evolved anchor metrics and best possible recall values for YOLOv5 anchor evaluation."""
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        s = (
            f"{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n"
            f"{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, "
            f"past_thr={x[x > thr].mean():.3f}-mean: "
        )
        s += ", ".join("%i,%i" % (round(w), round(h)) for w, h in k)
        if verbose:
            LOGGER.info(s)
        return k

    if isinstance(dataset, str):  # *.yaml file
        with open(dataset, errors="ignore") as f:
            data_dict = yaml.safe_load(f)  # model dict
        from utils.dataloaders import LoadImagesAndLabels

        dataset = LoadImagesAndLabels(data_dict["train"], augment=True, rect=True)

    # Get label wh: rescale every image shape so that its longest side equals img_size
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        LOGGER.info(f"{PREFIX}WARNING ⚠️ Extremely small objects found: {i} of {len(wh0)} labels are <3 pixels in size")
    wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32)  # keep labels with at least one side >= 2 pixels

    # Kmeans init
    try:
        LOGGER.info(f"{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...")
        # Explicit raises instead of assert so the checks still run under `python -O`
        if n > len(wh):  # apply overdetermined constraint
            raise ValueError(f"cannot fit {n} anchors to {len(wh)} points")
        s = wh.std(0)  # sigmas for whitening
        k = kmeans(wh / s, n, iter=30)[0] * s  # points
        if len(k) != n:  # kmeans may return fewer points than requested if wh is insufficient or too similar
            raise ValueError(f"kmeans returned {len(k)} anchors, expected {n}")
    except Exception:
        # Deliberate best-effort fallback: any kmeans failure switches to random initial anchors
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init")
        k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size  # random init
    wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
    k = print_results(k, verbose=False)

    # Evolve
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, anchor array shape, mutation prob, sigma
    pbar = tqdm(range(gen), bar_format=TQDM_BAR_FORMAT)  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k * v).clip(min=2.0)  # mutated candidate, sides floored at 2 pixels
        fg = anchor_fitness(kg)
        if fg > f:  # keep the mutation only when fitness improves
            f, k = fg, kg
            pbar.desc = f"{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}"
            if verbose:
                print_results(k, verbose)

    # The final summary is logged regardless of `verbose` (print_results defaults to verbose=True)
    return print_results(k).astype(np.float32)