-
Notifications
You must be signed in to change notification settings - Fork 475
/
expert.py
408 lines (363 loc) · 16.8 KB
/
expert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ expert.py ]
# Synopsis [ the speech separation downstream wrapper ]
# Source [ Reference some code from https://github.com/funcwj/uPIT-for-speech-separation and https://github.com/asteroid-team/asteroid ]
# Author [ Zili Huang ]
# Copyright [ Copyright(c), Johns Hopkins University ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import os
import math
import random
import h5py
import numpy as np
from pathlib import Path
from collections import defaultdict
import librosa
# -------------#
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_sequence, pad_sequence
import torch.nn.functional as F
# -------------#
from .model import SepRNN
from .dataset import SeparationDataset
from asteroid.metrics import get_metrics
from .loss import MSELoss, SISDRLoss
# Module-level default device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def match_length(feat_list, length_list, max_diff=5):
    """Trim or zero-pad each feature sequence along dim 0 to a target length.

    Upstream features and STFT targets can disagree by a few frames; this
    aligns every feature tensor with its matching target length.

    Args:
        feat_list: list of tensors, each shaped ``(T_i, ...)`` (time on dim 0).
        length_list: target lengths, one per tensor (ints or 0-d integer tensors).
        max_diff: mismatches of ``max_diff`` frames or more are rejected.
            Defaults to 5, the tolerance the original check used.

    Returns:
        A new list of tensors whose dim 0 equals the matching target length.
        Tensors that already match are returned as-is; padded tensors are
        zero-filled with the same dtype and device as the input.

    Raises:
        ValueError: if the two lists differ in length, or a sequence is off by
            ``max_diff`` frames or more.
    """
    if len(feat_list) != len(length_list):
        raise ValueError(
            "feat_list and length_list differ in length: {} vs {}".format(
                len(feat_list), len(length_list)))
    new_feat_list = []
    for feat, length in zip(feat_list, length_list):
        # Accept 0-d tensors as well as Python ints.
        length = int(length)
        cur_len = feat.size(0)
        if abs(cur_len - length) >= max_diff:
            raise ValueError(
                "feature length {} differs from target length {} by {} or more frames".format(
                    cur_len, length, max_diff))
        if cur_len == length:
            new_feat_list.append(feat)
        elif cur_len > length:
            new_feat_list.append(feat[:length])
        else:
            # new_zeros keeps dtype and device and avoids a CPU round trip.
            padded = feat.new_zeros((length, *feat.shape[1:]))
            padded[:cur_len] = feat
            new_feat_list.append(padded)
    return new_feat_list
# We cannot guarantee the predicted STFT feature is always valid.
# In our experiments, we often observe an impulse at the end of the signal.
# This function is used to suppress the impulse.
def postprocess(x, pad_zeros=True, window_size=512, max_trailing_zeros=2048):
    """Suppress a spurious impulse near the end of a reconstructed 1-D waveform.

    Let ``p`` be one past the last non-zero sample. Inside the window
    ``x[p - window_size:p]``, the first sample whose magnitude exceeds the
    peak magnitude of everything before the window marks the start of the
    impulse. From that sample to the end, the signal is replaced with zeros,
    or with low-level Gaussian noise when ``pad_zeros`` is False.

    Args:
        x: 1-D numpy waveform.
        pad_zeros: fill the suppressed tail with zeros (True) or with
            N(0, 0.01^2) noise (False).
        window_size: length of the tail window searched for an impulse.
        max_trailing_zeros: if more than this many trailing samples are
            already zero, the signal is returned unchanged with a warning.

    Returns:
        ``x`` itself when nothing is changed; otherwise a modified copy.
        The input array is never modified in place.
    """
    nonzero = np.flatnonzero(x)
    if nonzero.size == 0:
        # Nothing to suppress; np.max on an empty index set would raise here.
        print("Warning: the predicted signal is all zeros")
        return x
    p = int(nonzero[-1]) + 1  # x[p:] == 0
    if p < x.shape[0] - max_trailing_zeros:
        print("Warning: the predicted signal is 0 from sample {} to {}".format(p, x.shape[0]))
        return x
    start_p = p - window_size
    if start_p <= 0:  # the wav length too short
        print("Warning: the length of wav is too short")
        return x
    max_value = np.max(np.abs(x[:start_p]))
    invalid = np.nonzero(np.abs(x[start_p:p]) > max_value)[0]
    if invalid.size == 0:
        return x
    # np.nonzero returns indices in ascending order, so the first is the minimum.
    invalid_pos = int(invalid[0]) + start_p
    n_replaced = x.shape[0] - invalid_pos
    z = np.copy(x)
    if pad_zeros:
        z[invalid_pos:] = 0
        print("Set from {} to {} 0, {} samples".format(invalid_pos, x.shape[0], n_replaced))
    else:
        z[invalid_pos:] = np.random.normal(loc=0.0, scale=0.01, size=(n_replaced,))
        print("Set from {} to {} Gaussian noise, {} samples".format(invalid_pos, x.shape[0], n_replaced))
    return z
class DownstreamExpert(nn.Module):
"""
Used to handle downstream-specific operations
eg. downstream forward, metric computation, contents to log
"""
def __init__(self, upstream_dim, upstream_rate, downstream_expert, expdir, **kwargs):
super(DownstreamExpert, self).__init__()
self.upstream_dim = upstream_dim
self.upstream_rate = upstream_rate
self.datarc = downstream_expert["datarc"]
self.loaderrc = downstream_expert["loaderrc"]
self.modelrc = downstream_expert["modelrc"]
self.expdir = expdir
self.train_dataset = SeparationDataset(
data_dir=self.loaderrc["train_dir"],
rate=self.datarc['rate'],
src=self.datarc['src'],
tgt=self.datarc['tgt'],
n_fft=self.datarc['n_fft'],
hop_length=self.upstream_rate,
win_length=self.datarc['win_length'],
window=self.datarc['window'],
center=self.datarc['center'],
)
self.dev_dataset = SeparationDataset(
data_dir=self.loaderrc["dev_dir"],
rate=self.datarc['rate'],
src=self.datarc['src'],
tgt=self.datarc['tgt'],
n_fft=self.datarc['n_fft'],
hop_length=self.upstream_rate,
win_length=self.datarc['win_length'],
window=self.datarc['window'],
center=self.datarc['center'],
)
self.test_dataset = SeparationDataset(
data_dir=self.loaderrc["test_dir"],
rate=self.datarc['rate'],
src=self.datarc['src'],
tgt=self.datarc['tgt'],
n_fft=self.datarc['n_fft'],
hop_length=self.upstream_rate,
win_length=self.datarc['win_length'],
window=self.datarc['window'],
center=self.datarc['center'],
)
if self.modelrc["model"] == "SepRNN":
self.model = SepRNN(
input_dim=self.upstream_dim,
num_bins=int(self.datarc['n_fft'] / 2 + 1),
rnn=self.modelrc["rnn"],
num_spks=self.datarc['num_speakers'],
num_layers=self.modelrc["rnn_layers"],
hidden_size=self.modelrc["hidden_size"],
dropout=self.modelrc["dropout"],
non_linear=self.modelrc["non_linear"],
bidirectional=self.modelrc["bidirectional"]
)
else:
raise ValueError("Model type not defined.")
self.loss_type = self.modelrc["loss_type"]
if self.modelrc["loss_type"] == "MSE":
self.objective = MSELoss(self.datarc['num_speakers'], self.modelrc["mask_type"])
elif self.modelrc["loss_type"] == "SISDR":
self.objective = SISDRLoss(self.datarc['num_speakers'],
n_fft=self.datarc['n_fft'],
hop_length=self.upstream_rate,
win_length=self.datarc['win_length'],
window=self.datarc['window'],
center=self.datarc['center'])
else:
raise ValueError("Loss type not defined.")
self.register_buffer("best_score", torch.ones(1) * -10000)
def _get_train_dataloader(self, dataset):
return DataLoader(
dataset,
batch_size=self.loaderrc["train_batchsize"],
shuffle=True,
num_workers=self.loaderrc["num_workers"],
drop_last=False,
pin_memory=True,
collate_fn=dataset.collate_fn,
)
def _get_eval_dataloader(self, dataset):
return DataLoader(
dataset,
batch_size=self.loaderrc["eval_batchsize"],
shuffle=False,
num_workers=self.loaderrc["num_workers"],
drop_last=False,
pin_memory=True,
collate_fn=dataset.collate_fn,
)
def get_dataloader(self, mode):
"""
Args:
mode: string
'train', 'dev' or 'test'
Return:
a torch.utils.data.DataLoader returning each batch in the format of:
[wav1, wav2, ...], your_other_contents1, your_other_contents2, ...
where wav1, wav2 ... are in variable length
each wav is torch.FloatTensor in cpu with:
1. dim() == 1
2. sample_rate == 16000
3. directly loaded by torchaudio
"""
if mode == "train":
return self._get_train_dataloader(self.train_dataset)
elif mode == "dev":
return self._get_eval_dataloader(self.dev_dataset)
elif mode == "test":
return self._get_eval_dataloader(self.test_dataset)
def forward(self, mode, features, uttname_list, source_attr, source_wav, target_attr, target_wav_list, feat_length, wav_length, records, **kwargs):
"""
Args:
mode: string
'train', 'dev' or 'test' for this forward step
features:
list of unpadded features [feat1, feat2, ...]
each feat is in torch.FloatTensor and already
put in the device assigned by command-line args
uttname_list:
list of utterance names
source_attr:
source_attr is a dict containing the STFT information
for the mixture. source_attr['magnitude'] stores the STFT
magnitude, source_attr['phase'] stores the STFT phase and
source_attr['stft'] stores the raw STFT feature. The shape
is [bs, max_length, feat_dim]
source_wav:
source_wav contains the raw waveform for the mixture,
and it has the shape of [bs, max_wav_length]
target_attr:
similar to source_attr, it contains the STFT information
for individual sources. It only has two keys ('magnitude' and 'phase')
target_attr['magnitude'] is a list of length n_srcs, and
target_attr['magnitude'][i] has the shape [bs, max_length, feat_dim]
target_wav_list:
target_wav_list contains the raw waveform for the individual
sources, and it is a list of length n_srcs. target_wav_list[0]
has the shape [bs, max_wav_length]
feat_length:
length of STFT features
wav_length:
length of raw waveform
records:
defaultdict(list), by appending contents into records,
these contents can be averaged and logged on Tensorboard
later by self.log_records every log_step
Return:
loss:
the loss to be optimized, should not be detached
"""
# match the feature length to STFT feature length
features = match_length(features, feat_length)
features = pack_sequence(features)
mask_list = self.model(features)
# evaluate the separation quality of predict sources
if mode == 'dev' or mode == 'test':
if mode == 'dev':
COMPUTE_METRICS = ["si_sdr"]
elif mode == 'test':
COMPUTE_METRICS = ["si_sdr", "stoi", "pesq"]
predict_stfts = [torch.squeeze(m * source_attr['stft'].to(device)) for m in mask_list]
predict_stfts_np = [np.transpose(s.data.cpu().numpy()) for s in predict_stfts]
assert len(wav_length) == 1
# reconstruct the signal using iSTFT
predict_srcs_np = [postprocess(librosa.istft(stft_mat,
hop_length=self.upstream_rate,
win_length=self.datarc['win_length'],
window=self.datarc['window'],
center=self.datarc['center'],
length=wav_length[0])) for stft_mat in predict_stfts_np]
predict_srcs_np = np.stack(predict_srcs_np, 0)
gt_srcs_np = torch.cat(target_wav_list, 0).data.cpu().numpy()
mix_np = source_wav.data.cpu().numpy()
utt_metrics = get_metrics(
mix_np,
gt_srcs_np,
predict_srcs_np,
sample_rate = self.datarc['rate'],
metrics_list = COMPUTE_METRICS,
compute_permutation=True,
)
for metric in COMPUTE_METRICS:
input_metric = "input_" + metric
assert metric in utt_metrics and input_metric in utt_metrics
imp = utt_metrics[metric] - utt_metrics[input_metric]
if metric not in records:
records[metric] = []
if metric == "si_sdr":
records[metric].append(imp)
elif metric == "stoi" or metric == "pesq":
records[metric].append(utt_metrics[metric])
else:
raise ValueError("Metric type not defined.")
assert 'batch_id' in kwargs
if kwargs['batch_id'] % 1000 == 0: # Save the prediction every 1000 examples
records['mix'].append(mix_np)
records['hypo'].append(predict_srcs_np)
records['ref'].append(gt_srcs_np)
records['uttname'].append(uttname_list[0])
if self.loss_type == "MSE": # mean square loss
loss = self.objective.compute_loss(mask_list, feat_length, source_attr, target_attr)
elif self.loss_type == "SISDR": # end-to-end SI-SNR loss
loss = self.objective.compute_loss(mask_list, feat_length, source_attr, wav_length, target_wav_list)
else:
raise ValueError("Loss type not defined.")
records["loss"].append(loss.item())
return loss
# interface
def log_records(
self, mode, records, logger, global_step, batch_ids, total_batch_num, **kwargs
):
"""
Args:
mode: string
'train':
records and batchids contain contents for `log_step` batches
`log_step` is defined in your downstream config
eg. downstream/example/config.yaml
'dev' or 'test' :
records and batchids contain contents for the entire evaluation dataset
records:
defaultdict(list), contents already appended
logger:
Tensorboard SummaryWriter
please use f'{prefix}your_content_name' as key name
to log your customized contents
global_step:
The global_step when training, which is helpful for Tensorboard logging
batch_ids:
The batches contained in records when enumerating over the dataloader
total_batch_num:
The total amount of batches in the dataloader
Return:
a list of string
Each string is a filename we wish to use to save the current model
according to the evaluation result, like the best.ckpt on the dev set
You can return nothing or an empty list when no need to save the checkpoint
"""
if mode == 'train':
avg_loss = np.mean(records["loss"])
logger.add_scalar(
f"separation_stft/{mode}-loss", avg_loss, global_step=global_step
)
return []
else:
if mode == 'dev':
COMPUTE_METRICS = ["si_sdr"]
elif mode == 'test':
COMPUTE_METRICS = ["si_sdr", "stoi", "pesq"]
avg_loss = np.mean(records["loss"])
logger.add_scalar(
f"separation_stft/{mode}-loss", avg_loss, global_step=global_step
)
with (Path(self.expdir) / f"{mode}_metrics.txt").open("w") as output:
for metric in COMPUTE_METRICS:
avg_metric = np.mean(records[metric])
if mode == "test" or mode == "dev":
print("Average {} of {} utts: {:.4f}".format(metric, len(records[metric]), avg_metric))
print(metric, avg_metric, file=output)
logger.add_scalar(
f'separation_stft/{mode}-'+metric,
avg_metric,
global_step=global_step
)
save_ckpt = []
assert 'si_sdr' in records
if mode == "dev" and np.mean(records['si_sdr']) > self.best_score:
self.best_score = torch.ones(1) * np.mean(records['si_sdr'])
save_ckpt.append(f"best-states-{mode}.ckpt")
for s in ['mix', 'ref', 'hypo', 'uttname']:
assert s in records
for i in range(len(records['uttname'])):
utt = records['uttname'][i]
mix_wav = records['mix'][i][0, :]
mix_wav = librosa.util.normalize(mix_wav, norm=np.inf, axis=None)
logger.add_audio('step{:06d}_{}_mix.wav'.format(global_step, utt), mix_wav, global_step=global_step, sample_rate=self.datarc['rate'])
for j in range(records['ref'][i].shape[0]):
ref_wav = records['ref'][i][j, :]
hypo_wav = records['hypo'][i][j, :]
ref_wav = librosa.util.normalize(ref_wav, norm=np.inf, axis=None)
hypo_wav = librosa.util.normalize(hypo_wav, norm=np.inf, axis=None)
logger.add_audio('step{:06d}_{}_ref_s{}.wav'.format(global_step, utt, j+1), ref_wav, global_step=global_step, sample_rate=self.datarc['rate'])
logger.add_audio('step{:06d}_{}_hypo_s{}.wav'.format(global_step, utt, j+1), hypo_wav, global_step=global_step, sample_rate=self.datarc['rate'])
return save_ckpt