-
Notifications
You must be signed in to change notification settings - Fork 55
/
train.py
297 lines (252 loc) · 10.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
from models.generator import TSCNet
from models import discriminator
import os
from data import dataloader
import torch.nn.functional as F
import torch
from utils import power_compress, power_uncompress
import logging
from torchinfo import summary
import argparse
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
# Command-line configuration. It is parsed at import time because Trainer
# reads the module-level ``args`` object directly.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=120, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--log_interval", type=int, default=500)
parser.add_argument("--decay_epoch", type=int, default=30, help="epoch from which to start lr decay")
parser.add_argument("--init_lr", type=float, default=5e-4, help="initial learning rate")
parser.add_argument("--cut_len", type=int, default=16000*2, help="cut length, default is 2 seconds in denoise "
                                                                 "and dereverberation")
parser.add_argument("--data_dir", type=str, default='dir to VCTK-DEMAND dataset',
                    help="dir of VCTK+DEMAND dataset")
parser.add_argument("--save_model_dir", type=str, default='./saved_model',
                    help="dir of saved model")
# ``type=list`` would split a command-line string into single characters
# (e.g. "0.1" -> ['0', '.', '1']), which cannot be multiplied with the loss
# tensors.  Take four floats instead:
#   --loss_weights 0.1 0.9 0.2 0.05
parser.add_argument("--loss_weights", type=float, nargs=4, default=[0.1, 0.9, 0.2, 0.05],
                    metavar=("RI", "MAG", "TIME", "GAN"),
                    help="weights of RI components, magnitude, time loss, and Metric Disc")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
def ddp_setup(rank, world_size):
    """
    Join the NCCL process group, using localhost:12355 as the rendezvous point.

    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # Bind this process to its own GPU before creating the process group.
    # Otherwise every device-less CUDA call (``.cuda()``, ``.to("cuda")``)
    # lands on GPU 0 in every process.
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
class Trainer:
    """
    Adversarial training loop for a TSCNet generator and a metric discriminator.

    Both networks are wrapped in DistributedDataParallel and live on the GPU
    identified by ``gpu_id``.  Hyper-parameters (learning rate, cut length,
    loss weights, ...) are read from the module-level ``args``.
    """

    def __init__(self, train_ds, test_ds, gpu_id: int):
        """
        Args:
            train_ds: Iterable of training batches; ``batch[0]`` is the clean
                waveform and ``batch[1]`` the noisy one.
            test_ds: Iterable of evaluation batches with the same layout.
            gpu_id: Index of the CUDA device owned by this process.
        """
        self.n_fft = 400
        self.hop = 100
        self.train_ds = train_ds
        self.test_ds = test_ds
        self.gpu_id = gpu_id

        # Make ``gpu_id`` the current device so that helpers relying on the
        # default CUDA device agree with the explicit placement below.
        torch.cuda.set_device(gpu_id)
        self.device = torch.device("cuda", gpu_id)
        # The analysis/synthesis window never changes, so build it once
        # instead of on every STFT/iSTFT call.
        self.window = torch.hamming_window(self.n_fft).to(self.device)

        num_freq = self.n_fft // 2 + 1
        num_frames = args.cut_len // self.hop + 1

        self.model = TSCNet(num_channel=64, num_features=num_freq).to(self.device)
        summary(self.model, [(1, 2, num_frames, num_freq)])
        self.discriminator = discriminator.Discriminator(ndf=16).to(self.device)
        summary(
            self.discriminator,
            [
                (1, 1, num_freq, num_frames),
                (1, 1, num_freq, num_frames),
            ],
        )

        # Optimizers are created on the raw modules; DDP shares the same
        # parameter tensors, so wrapping afterwards is fine.
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.init_lr)
        self.optimizer_disc = torch.optim.AdamW(
            self.discriminator.parameters(), lr=2 * args.init_lr
        )
        self.model = DDP(self.model, device_ids=[gpu_id])
        self.discriminator = DDP(self.discriminator, device_ids=[gpu_id])

    def _stft(self, audio):
        """Return the one-sided STFT of ``audio`` with real/imag on a trailing axis."""
        # ``return_complex`` is mandatory for real inputs on current PyTorch.
        # ``view_as_real`` restores the (..., 2) layout the old default gave.
        spec = torch.stft(
            audio,
            self.n_fft,
            self.hop,
            window=self.window,
            onesided=True,
            return_complex=True,
        )
        return torch.view_as_real(spec)

    def forward_generator_step(self, clean, noisy):
        """
        Run the generator on one batch and collect everything the losses need.

        Args:
            clean: Clean reference waveforms, one row per utterance.
            noisy: Noisy input waveforms, same shape as ``clean``.

        Returns:
            dict with the estimated and reference real/imaginary/magnitude
            spectra (``est_*`` / ``clean_*``), the reconstructed waveform
            ``est_audio`` and the normalised reference waveform ``clean``.
        """
        # Normalization: scale every utterance by the inverse RMS of its noisy
        # version; the same factor is applied to the clean reference.
        c = torch.sqrt(
            noisy.size(-1) / torch.sum((noisy**2.0), dim=-1, keepdim=True)
        )
        noisy = noisy * c
        clean = clean * c

        noisy_spec = power_compress(self._stft(noisy)).permute(0, 1, 3, 2)
        clean_spec = power_compress(self._stft(clean))
        # Channel 0 / 1 of the compressed spectrum hold the real / imaginary part.
        clean_real = clean_spec[:, 0, :, :].unsqueeze(1)
        clean_imag = clean_spec[:, 1, :, :].unsqueeze(1)

        est_real, est_imag = self.model(noisy_spec)
        est_real, est_imag = est_real.permute(0, 1, 3, 2), est_imag.permute(0, 1, 3, 2)
        est_mag = torch.sqrt(est_real**2 + est_imag**2)
        clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)

        est_spec_uncompress = power_uncompress(est_real, est_imag).squeeze(1)
        # NOTE(review): power_uncompress is defined outside this file.  If it
        # returns real/imag stacked on a trailing axis, convert to a complex
        # tensor, since current PyTorch's istft only accepts complex input.
        if not torch.is_complex(est_spec_uncompress):
            est_spec_uncompress = torch.view_as_complex(
                est_spec_uncompress.contiguous()
            )
        est_audio = torch.istft(
            est_spec_uncompress,
            self.n_fft,
            self.hop,
            window=self.window,
            onesided=True,
        )

        return {
            "est_real": est_real,
            "est_imag": est_imag,
            "est_mag": est_mag,
            "clean_real": clean_real,
            "clean_imag": clean_imag,
            "clean_mag": clean_mag,
            "est_audio": est_audio,
            # Normalised reference, on the same scale as ``est_audio``.
            "clean": clean,
        }

    def calculate_generator_loss(self, generator_outputs):
        """
        Combine the generator losses using ``args.loss_weights``.

        The weights apply, in order, to the real/imaginary MSE, the magnitude
        MSE, the time-domain L1 loss and the adversarial (metric) loss.

        Args:
            generator_outputs: dict from ``forward_generator_step`` plus a
                ``one_labels`` tensor with one entry per utterance.

        Returns:
            Scalar loss tensor.
        """
        predict_fake_metric = self.discriminator(
            generator_outputs["clean_mag"], generator_outputs["est_mag"]
        )
        # The generator is rewarded when the discriminator predicts 1.
        gen_loss_GAN = F.mse_loss(
            predict_fake_metric.flatten(), generator_outputs["one_labels"].float()
        )
        loss_mag = F.mse_loss(
            generator_outputs["est_mag"], generator_outputs["clean_mag"]
        )
        loss_ri = F.mse_loss(
            generator_outputs["est_real"], generator_outputs["clean_real"]
        ) + F.mse_loss(generator_outputs["est_imag"], generator_outputs["clean_imag"])
        time_loss = torch.mean(
            torch.abs(generator_outputs["est_audio"] - generator_outputs["clean"])
        )
        return (
            args.loss_weights[0] * loss_ri
            + args.loss_weights[1] * loss_mag
            + args.loss_weights[2] * time_loss
            + args.loss_weights[3] * gen_loss_GAN
        )

    def calculate_discriminator_loss(self, generator_outputs):
        """
        Compute the metric-discriminator loss for one batch.

        Args:
            generator_outputs: dict from ``forward_generator_step`` plus
                ``one_labels``.

        Returns:
            Scalar loss tensor, or ``None`` when ``batch_pesq`` yields no score.
        """
        length = generator_outputs["est_audio"].size(-1)
        est_audio_list = list(generator_outputs["est_audio"].detach().cpu().numpy())
        clean_audio_list = list(generator_outputs["clean"].cpu().numpy()[:, :length])
        pesq_score = discriminator.batch_pesq(clean_audio_list, est_audio_list)

        # The calculation of PESQ can be None due to silent part
        if pesq_score is None:
            return None

        # Detach the estimate: this loss must only update the discriminator.
        predict_enhance_metric = self.discriminator(
            generator_outputs["clean_mag"], generator_outputs["est_mag"].detach()
        )
        predict_max_metric = self.discriminator(
            generator_outputs["clean_mag"], generator_outputs["clean_mag"]
        )
        return F.mse_loss(
            predict_max_metric.flatten(), generator_outputs["one_labels"]
        ) + F.mse_loss(predict_enhance_metric.flatten(), pesq_score)

    def _generator_pass(self, batch):
        """Move ``batch`` to the GPU, run the generator and attach the labels."""
        clean = batch[0].to(self.device)
        noisy = batch[1].to(self.device)
        generator_outputs = self.forward_generator_step(clean, noisy)
        # Size the labels from the batch itself: the final batch of a loader
        # may hold fewer than ``args.batch_size`` utterances.
        generator_outputs["one_labels"] = torch.ones(clean.size(0), device=self.device)
        return generator_outputs

    def train_step(self, batch):
        """
        Update the generator, then the discriminator, on one batch.

        Returns:
            ``(generator_loss, discriminator_loss)`` as Python floats; the
            discriminator loss is 0.0 when no PESQ score was available.
        """
        # Train generator
        generator_outputs = self._generator_pass(batch)
        loss = self.calculate_generator_loss(generator_outputs)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Train Discriminator
        discrim_loss_metric = self.calculate_discriminator_loss(generator_outputs)
        if discrim_loss_metric is None:
            return loss.item(), 0.0
        self.optimizer_disc.zero_grad()
        discrim_loss_metric.backward()
        self.optimizer_disc.step()
        return loss.item(), discrim_loss_metric.item()

    @torch.no_grad()
    def test_step(self, batch):
        """
        Evaluate both losses on one batch without updating any weights.

        Returns:
            ``(generator_loss, discriminator_loss)`` as Python floats; the
            discriminator loss is 0.0 when no PESQ score was available.
        """
        generator_outputs = self._generator_pass(batch)
        loss = self.calculate_generator_loss(generator_outputs)
        discrim_loss_metric = self.calculate_discriminator_loss(generator_outputs)
        if discrim_loss_metric is None:
            return loss.item(), 0.0
        return loss.item(), discrim_loss_metric.item()

    def test(self):
        """
        Evaluate on ``self.test_ds`` and log the average losses.

        Returns:
            Average generator loss over all batches, or NaN when the loader
            yields no batch at all.
        """
        self.model.eval()
        self.discriminator.eval()
        gen_loss_total = 0.0
        disc_loss_total = 0.0
        num_batches = 0
        for batch in self.test_ds:
            loss, disc_loss = self.test_step(batch)
            gen_loss_total += loss
            disc_loss_total += disc_loss
            num_batches += 1

        if num_batches == 0:
            # Nothing to average; report it instead of failing on an
            # undefined step counter.
            logging.warning("GPU: {}, test set is empty".format(self.gpu_id))
            return float("nan")

        gen_loss_avg = gen_loss_total / num_batches
        disc_loss_avg = disc_loss_total / num_batches

        template = "GPU: {}, Generator loss: {}, Discriminator loss: {}"
        logging.info(template.format(self.gpu_id, gen_loss_avg, disc_loss_avg))
        return gen_loss_avg

    def train(self):
        """
        Train for ``args.epochs`` epochs.

        After every epoch the model is evaluated and, on GPU 0 only, the
        generator weights are saved to ``args.save_model_dir``.  Both learning
        rates are halved every ``args.decay_epoch`` epochs.
        """
        scheduler_G = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=args.decay_epoch, gamma=0.5
        )
        scheduler_D = torch.optim.lr_scheduler.StepLR(
            self.optimizer_disc, step_size=args.decay_epoch, gamma=0.5
        )
        template = "GPU: {}, Epoch {}, Step {}, loss: {}, disc_loss: {}"
        for epoch in range(args.epochs):
            self.model.train()
            self.discriminator.train()

            # A distributed sampler reuses the same shuffle every epoch unless
            # it is told the epoch number.  The loader is built outside this
            # file, so only call ``set_epoch`` when the sampler provides it.
            sampler = getattr(self.train_ds, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)

            for step, batch in enumerate(self.train_ds, start=1):
                loss, disc_loss = self.train_step(batch)
                if (step % args.log_interval) == 0:
                    logging.info(
                        template.format(self.gpu_id, epoch, step, loss, disc_loss)
                    )

            gen_loss = self.test()
            if self.gpu_id == 0:
                # Only the saving rank needs the directory.  ``exist_ok``
                # removes the check-then-create race between processes.
                os.makedirs(args.save_model_dir, exist_ok=True)
                path = os.path.join(
                    args.save_model_dir,
                    "CMGAN_epoch_" + str(epoch) + "_" + str(gen_loss)[:5],
                )
                torch.save(self.model.module.state_dict(), path)
            scheduler_G.step()
            scheduler_D.step()
def main(rank: int, world_size: int, args):
    """
    Per-process entry point: join the process group, train, then leave it.

    Args:
        rank: Index of this process; also used as its GPU index.
        world_size: Total number of processes.
        args: Parsed command-line arguments.
    """
    ddp_setup(rank, world_size)
    try:
        if rank == 0:
            print(args)
            available_gpus = [
                torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
            ]
            print(available_gpus)
        # NOTE(review): the literal 2 is passed through to load_data unchanged;
        # its meaning is defined in the dataloader module -- confirm there.
        train_ds, test_ds = dataloader.load_data(
            args.data_dir, args.batch_size, 2, args.cut_len
        )
        trainer = Trainer(train_ds, test_ds, rank)
        trainer.train()
    finally:
        # Always leave the process group, even when training raises.
        destroy_process_group()
if __name__ == "__main__":
    # Launch one training process per visible CUDA device.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # Spawning zero processes would return immediately without training
        # anything; fail loudly instead.
        raise RuntimeError(
            "No CUDA devices found: NCCL-based distributed training needs at least one GPU."
        )
    mp.spawn(main, args=(world_size, args), nprocs=world_size)