-
Notifications
You must be signed in to change notification settings - Fork 43
/
exp3.ctc.config
238 lines (211 loc) · 11.4 KB
/
exp3.ctc.config
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
#!crnn/rnn.py
# kate: syntax python;
import os
import numpy
from subprocess import check_output, CalledProcessError
from Pretrain import WrapEpochValue
# task
# NOTE(review): `config` is not defined or imported in this file; presumably it is
# injected by the program that executes this config (see the `#!crnn/rnn.py` header) - confirm.
use_tensorflow = True
task = "train"
device = "gpu"
multiprocessing = True
update_on_device = True

# Debug mode is enabled when the DEBUG environment variable parses to a non-zero integer.
debug_mode = False
if int(os.environ.get("DEBUG", "0")):
    print("** DEBUG MODE")
    debug_mode = True

# The beam size is read from `config` when it has a "beam_size" entry; otherwise it is 12.
if config.has("beam_size"):
    beam_size = config.int("beam_size", 0)
    print("** beam_size %i" % beam_size)
else:
    beam_size = 12

# data
num_inputs = 40
num_outputs = {"classes": (10025, 1), "data": (num_inputs, 2)}  # see vocab
EpochSplit = 20
def get_dataset(key, subset=None, train_partition_epoch=None):
    """
    Build the option dict for one part of the LibriSpeech corpus.

    Keys starting with "train" get "partition_epoch" and a "laplace" sequence ordering;
    the key "train" itself additionally gets an "epoch_wise_filter" for epochs (1, 5).
    All other keys get a fixed random seed and the "sorted_reverse" sequence ordering.

    :param str key: corpus prefix, e.g. "train" or "dev"; stored as "prefix"
    :param int|None subset: if truthy, stored as "fixed_random_subset"
    :param int|None train_partition_epoch: stored as "partition_epoch", for training keys only
    :return: a newly created dict with the dataset options
    :rtype: dict[str]
    """
    d = {
        'class': 'LibriSpeechCorpus',
        'path': 'base/dataset/tars',
        "use_zip": True,
        "use_cache_manager": True,
        "prefix": key,
        "bpe": {
            'bpe_file': 'base/dataset/trans.bpe.codes',
            'vocab_file': 'base/dataset/trans.bpe.vocab',
            'seq_postfix': [0],
            'unknown_label': '<unk>'},
        "audio": {
            "norm_mean": "base/dataset/stats.mean.txt",
            "norm_std_dev": "base/dataset/stats.std_dev.txt"},
    }
    if key.startswith("train"):
        d["partition_epoch"] = train_partition_epoch
        if key == "train":
            d["epoch_wise_filter"] = {
                (1, 5): {
                    'max_mean_len': 75,  # chars, should be around 14 bpe labels
                    'subdirs': ['train-clean-100', 'train-clean-360']}}
        #d["audio"]["random_permute"] = True
        num_seqs = 281241  # total
        d["seq_ordering"] = "laplace:%i" % (num_seqs // 1000)
    else:
        d["fixed_random_seed"] = 1
        d["seq_ordering"] = "sorted_reverse"
    if subset:
        d["fixed_random_subset"] = subset  # faster
    return d
# Dataset option dicts: "train" is given EpochSplit as train_partition_epoch,
# "dev" is restricted to a fixed random subset of 3000.
train = get_dataset(key="train", train_partition_epoch=EpochSplit)
dev = get_dataset(key="dev", subset=3000)
window = 1
cache_size = "0"
# network
# (also defined by num_inputs & num_outputs)
target = "classes"
AttNumHeads = 1
EncKeyTotalDim, EncValueTotalDim = 1024, 2048
# Per-head sizes: the total key/value dims divided by the number of attention heads.
EncKeyPerHeadDim = EncKeyTotalDim // AttNumHeads
EncValuePerHeadDim = EncValueTotalDim // AttNumHeads
# Half of the encoder value dim; used as n_out of each LSTM direction in the network.
LstmDim = EncValueTotalDim // 2
# Network definition: maps each layer name to its option dict.
# Overview: "source" clips the input, lstm0..lstm5 (forward + backward each) with max-pool
# layers in between form the encoder, "output" is a recurrent decoder with attention over
# the encoder, "decision" evaluates the edit distance, and "ctc" adds a CTC loss on the encoder.
network = {
# Clip the input elementwise to the range [-3, 3].
"source": {"class": "eval", "eval": "tf.clip_by_value(source(0), -3.0, 3.0)"},
# Encoder LSTM stack. Every level has a forward (direction 1) and a backward (direction -1)
# layer with LstmDim units. A max-pool layer follows each of lstm0..lstm4: pool_size 2 after
# lstm0..lstm2 and pool_size 1 after lstm3 and lstm4. lstm1..lstm5 use dropout 0.3.
# custom_construction_algo (defined later in this file) rewrites these layers for pretraining.
"lstm0_fw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": 1, "from": ["source"] },
"lstm0_bw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": -1, "from": ["source"] },
"lstm0_pool": {"class": "pool", "mode": "max", "padding": "same", "pool_size": (2,), "from": ["lstm0_fw", "lstm0_bw"], "trainable": False},
"lstm1_fw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": 1, "from": ["lstm0_pool"], "dropout": 0.3 },
"lstm1_bw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": -1, "from": ["lstm0_pool"], "dropout": 0.3 },
"lstm1_pool": {"class": "pool", "mode": "max", "padding": "same", "pool_size": (2,), "from": ["lstm1_fw", "lstm1_bw"], "trainable": False},
"lstm2_fw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": 1, "from": ["lstm1_pool"], "dropout": 0.3 },
"lstm2_bw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": -1, "from": ["lstm1_pool"], "dropout": 0.3 },
"lstm2_pool": {"class": "pool", "mode": "max", "padding": "same", "pool_size": (2,), "from": ["lstm2_fw", "lstm2_bw"], "trainable": False},
"lstm3_fw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": 1, "from": ["lstm2_pool"], "dropout": 0.3 },
"lstm3_bw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": -1, "from": ["lstm2_pool"], "dropout": 0.3 },
"lstm3_pool": {"class": "pool", "mode": "max", "padding": "same", "pool_size": (1,), "from": ["lstm3_fw", "lstm3_bw"], "trainable": False},
"lstm4_fw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": 1, "from": ["lstm3_pool"], "dropout": 0.3 },
"lstm4_bw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": -1, "from": ["lstm3_pool"], "dropout": 0.3 },
"lstm4_pool": {"class": "pool", "mode": "max", "padding": "same", "pool_size": (1,), "from": ["lstm4_fw", "lstm4_bw"], "trainable": False},
"lstm5_fw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": 1, "from": ["lstm4_pool"], "dropout": 0.3 },
"lstm5_bw" : { "class": "rec", "unit": "nativelstm2", "n_out" : LstmDim, "direction": -1, "from": ["lstm4_pool"], "dropout": 0.3 },
# Encoder output: takes both directions of the last LSTM level as input.
"encoder": {"class": "copy", "from": ["lstm5_fw", "lstm5_bw"]}, # dim: EncValueTotalDim
# Projections of the encoder that the decoder reads via "base:...": attention keys (enc_ctx),
# a sigmoid gate per head (inv_fertility), and the values split into heads (enc_value).
"enc_ctx": {"class": "linear", "activation": None, "with_bias": True, "from": ["encoder"], "n_out": EncKeyTotalDim}, # preprocessed_attended in Blocks
"inv_fertility": {"class": "linear", "activation": "sigmoid", "with_bias": False, "from": ["encoder"], "n_out": AttNumHeads},
"enc_value": {"class": "split_dims", "axis": "F", "dims": (AttNumHeads, EncValuePerHeadDim), "from": ["encoder"]}, # (B, enc-T, H, D'/H)
# Decoder: a "rec" layer whose "unit" is itself a dict of layers.
# NOTE(review): presumably "prev:" refers to a layer's output from the previous decoder step
# and "base:" to a layer outside of the unit - confirm against the framework documentation.
"output": {"class": "rec", "from": [], 'cheating': config.bool("cheating", False), "unit": {
# Label choice from output_prob, using beam_size; "end" compares the chosen label with 0.
'output': {'class': 'choice', 'target': target, 'beam_size': beam_size, 'cheating': config.bool("cheating", False), 'from': ["output_prob"], "initial_output": 0},
"end": {"class": "compare", "from": ["output"], "value": 0},
'target_embed': {'class': 'linear', 'activation': None, "with_bias": False, 'from': ['output'], "n_out": 621, "initial_output": 0}, # feedback_input
# Attention energy: linear(tanh(enc_ctx + weight_feedback + s_transformed)), then a softmax
# over the spatial axis gives att_weights.
"weight_feedback": {"class": "linear", "activation": None, "with_bias": False, "from": ["prev:accum_att_weights"], "n_out": EncKeyTotalDim},
"s_transformed": {"class": "linear", "activation": None, "with_bias": False, "from": ["s"], "n_out": EncKeyTotalDim},
"energy_in": {"class": "combine", "kind": "add", "from": ["base:enc_ctx", "weight_feedback", "s_transformed"], "n_out": EncKeyTotalDim},
"energy_tanh": {"class": "activation", "activation": "tanh", "from": ["energy_in"]},
"energy": {"class": "linear", "activation": None, "with_bias": False, "from": ["energy_tanh"], "n_out": AttNumHeads}, # (B, enc-T, H)
"att_weights": {"class": "softmax_over_spatial", "from": ["energy"]}, # (B, enc-T, H)
# Accumulated attention weights: previous value + att_weights * inv_fertility * 0.5.
"accum_att_weights": {"class": "eval", "from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
"eval": "source(0) + source(1) * source(2) * 0.5", "out_type": {"dim": AttNumHeads, "shape": (None, AttNumHeads)}},
# Attention context: att_weights applied to enc_value, then all non-batch axes merged.
"att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"}, # (B, H, V)
"att": {"class": "merge_dims", "axes": "except_batch", "from": ["att0"]}, # (B, H*V)
# Decoder state (LSTMBlock cell fed with the previous embedding and previous attention context),
# followed by the readout (linear, then reduce_out with mode "max" and 2 pieces) and the softmax.
"s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["prev:target_embed", "prev:att"], "n_out": 1000}, # transform
"readout_in": {"class": "linear", "from": ["s", "prev:target_embed", "att"], "activation": None, "n_out": 1000}, # merge + post_merge bias
"readout": {"class": "reduce_out", "mode": "max", "num_pieces": 2, "from": ["readout_in"]},
# Cross-entropy loss with label smoothing 0.1; custom_construction_algo sets the smoothing
# to 0 in the networks it returns for pretraining.
"output_prob": {
"class": "softmax", "from": ["readout"], "dropout": 0.3,
"target": target, "loss": "ce", "loss_opts": {"label_smoothing": 0.1}}
}, "target": target, "max_seq_len": "max_len_from('base:encoder')"},
# Decision on the decoder output, with an edit-distance loss against the target.
"decision": {
"class": "decide", "from": ["output"], "loss": "edit_distance", "target": target,
"loss_opts": {
#"debug_print": True
}
},
# Additional softmax on the encoder with a CTC loss on the same target.
"ctc": {"class": "softmax", "from": ["encoder"], "loss": "ctc", "target": target,
"loss_opts": {"beam_width": 1, "ctc_opts": {"ignore_longer_outputs_than_inputs": True}}}
}
# "decision" is the layer defined in the network dict that is used as search output.
debug_print_layer_output_template = True
search_output_layer = "decision"

# trainer
batching = "random"
batch_size, max_seqs = 20000, 200
log_batch_size = True
max_seq_length = dict(classes=75)
#chunking = ""  # no chunking
truncation = -1
def custom_construction_algo(idx, net_dict):
    """
    Derive the network for pretrain step ``idx`` from the full network definition.

    The schedule, as implemented below:

    * idx 0 and idx 1 both use 2 LSTM layers. idx 0 additionally sets the dropout of the last
      LSTM layer to 0, sets "#config" (max_seq_length 60) and sets "#repetition" to 10.
    * Every further step uses one more LSTM layer. While layers are still being added, the pool
      sizes are changed such that their product is 2 ** 5.
    * One step after all LSTM layers are present, the original pool sizes are kept.
    * After that, None is returned.

    Label smoothing of "output_prob" is set to 0 in every returned network.

    For debugging, use: python3 ./crnn/Pretrain.py config... Maybe set repetitions=1
    in the ``pretrain`` dict.

    :param int idx: pretrain step index, starting at 0
    :param dict[str,dict[str]] net_dict: full network definition; it is modified in place
    :return: the modified ``net_dict`` (same object), or None if there is no further step
    :rtype: dict[str,dict[str]]|None
    """
    # We will first construct layer-by-layer, starting with 2 layers.
    # Initially, we will use a higher reduction factor, and at the end, we will reduce it.
    # Also, we will initially have no label smoothing.

    # Count the LSTM layers of the full network (lstm0_fw, lstm1_fw, ...).
    orig_num_lstm_layers = 0
    while "lstm%i_fw" % orig_num_lstm_layers in net_dict:
        orig_num_lstm_layers += 1
    assert orig_num_lstm_layers >= 2
    # Reduction factor of the full network: the product of all pool sizes.
    orig_red_factor = 1
    for i in range(orig_num_lstm_layers - 1):
        orig_red_factor *= net_dict["lstm%i_pool" % i]["pool_size"][0]
    num_lstm_layers = idx + 2  # idx starts at 0. start with 2 layers
    if idx == 0:
        net_dict["lstm%i_fw" % (orig_num_lstm_layers - 1)]["dropout"] = 0
        net_dict["lstm%i_bw" % (orig_num_lstm_layers - 1)]["dropout"] = 0
    if idx >= 1:
        num_lstm_layers -= 1  # repeat like idx=0, but now with dropout
    # We will start with a higher reduction factor initially, for better convergence.
    red_factor = 2 ** 5
    if num_lstm_layers == orig_num_lstm_layers + 1:
        # Use original reduction factor now.
        num_lstm_layers = orig_num_lstm_layers
        red_factor = orig_red_factor
    if num_lstm_layers > orig_num_lstm_layers:
        # Finish. This will also use label-smoothing then.
        return None
    # Use label smoothing only at the very end.
    net_dict["output"]["unit"]["output_prob"]["loss_opts"]["label_smoothing"] = 0
    # Other options during pretraining.
    if idx == 0:
        net_dict["#config"] = {"max_seq_length": {"classes": 60}}
        net_dict["#repetition"] = 10
    # Leave the last lstm layer as-is, but only modify its source.
    net_dict["lstm%i_fw" % (orig_num_lstm_layers - 1)]["from"] = ["lstm%i_pool" % (num_lstm_layers - 2)]
    net_dict["lstm%i_bw" % (orig_num_lstm_layers - 1)]["from"] = ["lstm%i_pool" % (num_lstm_layers - 2)]
    if red_factor > orig_red_factor:
        for i in range(num_lstm_layers - 2):
            net_dict["lstm%i_pool" % i]["pool_size"] = (2,)
        # Increase last pool-size to get the initial reduction factor.
        assert red_factor % (2 ** (num_lstm_layers - 2)) == 0
        last_pool_size = red_factor // (2 ** (num_lstm_layers - 2))
        # Increase last pool-size to get the same encoder-seq-length folding.
        net_dict["lstm%i_pool" % (num_lstm_layers - 2)]["pool_size"] = (last_pool_size,)
    # Delete non-used lstm layers. This is not explicitly necessary but maybe nicer.
    for i in range(num_lstm_layers - 1, orig_num_lstm_layers - 1):
        del net_dict["lstm%i_fw" % i]
        del net_dict["lstm%i_bw" % i]
        del net_dict["lstm%i_pool" % i]
    return net_dict
# Pretraining uses custom_construction_algo to build the networks, with "repetitions" set to 5.
pretrain = dict(repetitions=5, construction_algo=custom_construction_algo)
model = "net-model/network"
cleanup_old_models = True
num_epochs = 250
# Optimizer.
adam = True
optimizer_epsilon = 1e-8
gradient_clip = 0
#gradient_clip_global_norm = 1.0
gradient_noise = 0.0

# Numerics checks and diagnostics.
#debug_add_check_numerics_ops = True
#debug_add_check_numerics_on_output = True
stop_on_nonfinite_train_score = False
tf_log_memory_usage = True

# Learning rate: ten linearly spaced warmup values from 0.0003 up to learning_rate.
learning_rate = 0.0008
learning_rates = list(numpy.linspace(start=0.0003, stop=learning_rate, num=10))  # warmup
learning_rate_control = "newbob_multi_epoch"
#learning_rate_control_error_measure = "dev_score_output"
learning_rate_control_min_num_epochs_per_new_lr = 3
learning_rate_control_relative_error_relative_lr = True
use_learning_rate_control_always = True
# Newbob learning-rate control; newbob_multi_num_epochs is tied to EpochSplit.
learning_rate_file = "newbob.data"
newbob_learning_rate_decay = 0.9
newbob_multi_update_interval = 1
newbob_multi_num_epochs = EpochSplit

# log
#log = "| /u/zeyer/dotfiles/system-tools/bin/mt-cat.py >> log/crnn.seq-train.%s.log" % task
log_verbosity = 5
log = "log/crnn.%s.log" % (task,)