-
Notifications
You must be signed in to change notification settings - Fork 22
/
regnet.py
40 lines (36 loc) · 2.39 KB
/
regnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
@author: Signatrix GmbH
Implementation of paradigm described in paper: Designing Network Design Spaces published by Facebook AI Research (FAIR)
"""
import numpy as np
from src.anynet import AnyNetXe
class RegNetX(AnyNetXe):
def __init__(self, initial_width, slope, quantized_param, network_depth, bottleneck_ratio, group_width, stride,
se_ratio=None):
# We need to derive block width and number of blocks from initial parameters.
parameterized_width = initial_width + slope * np.arange(network_depth) # From equation 2
parameterized_block = np.log(parameterized_width / initial_width) / np.log(quantized_param) # From equation 3
parameterized_block = np.round(parameterized_block)
quantized_width = initial_width * np.power(quantized_param, parameterized_block)
# We need to convert quantized_width to make sure that it is divisible by 8
quantized_width = 8 * np.round(quantized_width / 8)
ls_block_width, ls_num_blocks = np.unique(quantized_width.astype(np.int), return_counts=True)
# At this points, for each stage, the above-calculated block width could be incompatible to group width
# due to bottleneck ratio. Hence, we need to adjust the formers.
# Group width could be swapped to number of groups, since their multiplication is block width
ls_group_width = np.array([min(group_width, block_width // bottleneck_ratio) for block_width in ls_block_width])
ls_block_width = np.round(ls_block_width // bottleneck_ratio / group_width) * group_width
ls_group_width = ls_group_width.astype(np.int) * bottleneck_ratio
ls_bottleneck_ratio = [bottleneck_ratio for _ in range(len(ls_block_width))]
# print (ls_num_blocks)
# print (ls_block_width)
# print (ls_bottleneck_ratio)
# print (ls_group_width)
super(RegNetX, self).__init__(ls_num_blocks, ls_block_width.astype(np.int).tolist(), ls_bottleneck_ratio,
ls_group_width.tolist(), stride, se_ratio)
class RegNetY(RegNetX):
# RegNetY = RegNetX + SE
def __init__(self, initial_width, slope, quantized_param, network_depth, bottleneck_ratio, group_width, stride,
se_ratio):
super(RegNetY, self).__init__(initial_width, slope, quantized_param, network_depth, bottleneck_ratio,
group_width, stride, se_ratio)