-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathsettings.py
115 lines (94 loc) · 2.45 KB
/
settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Default feature-extraction settings (feature='mels': mel-spectrogram
# parameters such as sample rate, number of mel bands, FFT size and hop).
default_feature_settings = {
    'feature': 'mels',
    'samplerate': 16000,
    'n_mels': 32,
    'fmin': 0,
    'fmax': 8000,
    'n_fft': 512,
    'hop_length': 256,
    'augmentations': 5,
}
# Default training hyperparameters: schedule, sample counts and optimizer.
default_training_settings = {
    'epochs': 50,
    'batch': 50,
    'train_samples': 36000,
    'val_samples': 3000,
    'augment': 0,
    'learning_rate': 0.01,
    'nesterov_momentum': 0.9,
}
# Default model-architecture settings, including the voting and input
# normalization options. Dimension settings ('5x5', '4x2') are kept as
# strings here and converted to tuples when settings are loaded.
default_model_settings = {
    'model': 'sbcnn',
    'frames': 72,
    'conv_block': 'conv',
    'conv_size': '5x5',
    'downsample_size': '4x2',
    'filters': 24,
    'n_stages': 3,
    'n_blocks_per_stage': 1,
    'dropout': 0.5,
    'voting': 'mean',
    'voting_overlap': 0.0,
    'normalize': 'meanstd',
    'fully_connected': 64,
}
# Every known setting name, drawn from all three default groups.
names = (
    set(default_feature_settings)
    | set(default_training_settings)
    | set(default_model_settings)
)
def populate_defaults():
    """Return a flat dict mapping every setting name to its default value.

    Lookup precedence is model settings, then training settings, then feature
    settings. Key *presence* decides which group supplies a value, not a
    comparison against None. A default that is legitimately None is therefore
    kept as-is rather than being replaced by a lower-precedence group.

    Returns
    -------
    dict
        One entry for each name in ``names``.
    """
    from collections import ChainMap

    # ChainMap searches its mappings left to right, giving the same
    # model > training > feature precedence as before.
    layered = ChainMap(
        default_model_settings,
        default_training_settings,
        default_feature_settings,
    )
    return {n: layered[n] for n in names}


# Flat name -> default value table, built once at import time.
defaults = populate_defaults()
def test_no_overlapping_settings():
    """Check that no setting name is defined in more than one settings group."""
    groups = (
        default_feature_settings,
        default_training_settings,
        default_model_settings,
    )
    group_sizes = [len(group) for group in groups]
    # The union is only as large as the summed sizes when the groups share no key.
    assert len(names) == sum(group_sizes)


# Run at import time so an overlapping definition is caught as soon as the
# module loads (unless assertions are disabled with -O).
test_no_overlapping_settings()
def parse_dimensions(s):
    """Parse an 'AxB'-style dimension string into a tuple of ints.

    ``'5x5'`` -> ``(5, 5)`` and ``'2x3x4'`` -> ``(2, 3, 4)``. The separator is
    accepted in either case (``'4X2'`` works too). A value that is already a
    tuple or list (e.g. a setting that was parsed earlier or loaded from a
    saved config) is converted element-wise to a tuple of ints, so parsing
    twice is harmless.

    Parameters
    ----------
    s : str or tuple or list
        The dimension specification.

    Returns
    -------
    tuple of int

    Raises
    ------
    ValueError
        If any piece is not an integer.
    """
    if isinstance(s, (tuple, list)):
        pieces = s
    else:
        pieces = s.lower().split('x')
    return tuple(int(d) for d in pieces)
# Setting name -> function converting its string form into the real value.
# Every entry here takes an 'AxB' dimension string and yields an int tuple.
# NOTE(review): 'pool' and 'kernel' are not keys of any default_*_settings
# dict in this file, so load_settings() never looks them up -- confirm
# whether they are still needed.
parsers = dict.fromkeys(
    ('pool', 'kernel', 'conv_size', 'downsample_size'),
    parse_dimensions,
)
def test_parse_dimensions():
    """Check that parse_dimensions turns 'AxB' strings into integer tuples."""
    cases = {
        '3x3': (3, 3),
        '4x2': (4, 2),
    }
    for text, expected in cases.items():
        result = parse_dimensions(text)
        assert result == expected, (result, '!=', expected)


# Self-check when the module is imported.
test_parse_dimensions()
def load_settings(args):
    """Build the complete settings dict from *args*, falling back to defaults.

    Parameters
    ----------
    args : mapping
        Setting name -> raw value. A key that is absent, or whose value is
        None, takes the built-in default. add_arguments() registers every
        command-line option with ``default=None``, so an omitted option shows
        up as None when the parsed namespace is passed in as a dict.

    Returns
    -------
    dict
        One entry per name in ``names``. Values for keys listed in ``parsers``
        are converted (e.g. ``'5x5'`` -> ``(5, 5)``); all others pass through
        unchanged.
    """
    settings = {}
    for key in names:
        raw = args.get(key)
        # None means "not provided": previously it overrode the default and
        # could crash the dimension parser.
        if raw is None:
            raw = defaults[key]
        parser = parsers.get(key)
        settings[key] = parser(raw) if parser is not None else raw
    return settings
def test_settings_empty():
    """Loading with no overrides must yield every setting at its parsed default."""
    settings = load_settings({})
    # Every known setting is present, and nothing extra.
    assert set(settings) == set(names), set(names) ^ set(settings)
    for key in names:
        parser = parsers.get(key, lambda x: x)
        expected = parser(defaults[key])
        assert settings[key] == expected, (key, settings[key], '!=', expected)


# Self-check when the module is imported.
test_settings_empty()
def add_arguments(parser):
    """Register one ``--<name>`` option per known setting on *parser*.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend; presumably a fresh ArgumentParser (confirm at call
        site).

    Each option's ``type`` is the Python type of its default value. Every
    option defaults to None, so an omitted option can be recognized as "not
    given". The help text shows the actual built-in default; previously
    ``'%(default)s'`` rendered as "None" for every option. Options are added
    in sorted order so ``--help`` output is stable between runs (``names`` is
    a set).
    """
    for name in sorted(names):
        default_value = defaults[name]
        parser.add_argument(
            '--{}'.format(name),
            default=None,
            type=type(default_value),
            # argparse %-formats help strings, so escape any literal '%'.
            help=str(default_value).replace('%', '%%'),
        )