"""Mask estimation on CHiME-3: a BLSTM mask estimator trained with padertorch and Sacred."""
import torch
import torch.nn as nn

import sacred

import padertorch as pt
import padertorch.train.optimizer as pt_opt
from padertorch.train.trainer import Trainer
from padertorch.contrib.jensheit.data import SequenceProvider, MaskTransformer, STFT
from paderbox.database import keys

ex = sacred.Experiment()
provider_config = dict(
    transform=dict(
        factory=MaskTransformer,
        stft=dict(factory=STFT),
    ),
    database=dict(factory='paderbox.database.chime.Chime3'),
    audio_keys=[keys.OBSERVATION, keys.SPEECH_IMAGE, keys.NOISE_IMAGE],
    batch_size=None,
)
provider_config = SequenceProvider.get_config(provider_config)
provider = SequenceProvider.from_config(provider_config)
train_iterator = provider.get_train_iterator()
test_iterator = provider.get_eval_iterator()
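
# Hedged sanity check: each example yielded by the iterators is expected to
# carry the keys used further down, e.g. 'observation_abs' with shape
# (seq, batch, 513). Uncomment to inspect a single example:
# example = next(iter(train_iterator))
# print({k: getattr(v, 'shape', type(v)) for k, v in example.items()})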
# Build the network.
@ex.config
def config():  # Sacred tracks every entry defined in this function.
    # batch_size = 6
    # layer_size = 1  # Variable hidden layer size
    # resume = False  # padertorch: continue from checkpoints
    foo = {
        'use_pt': True,  # Use the trainer from padertorch
        'epochs': 10,  # Use resume=True to train for more epochs
        'storage_dir': 'net/vol/zhenyuw/mask_estimator',
        'input_size': 513,  # STFT bins (FFT size 1024)
        'hidden_size': 256,
        'output_size': 1026,  # speech mask + noise mask, 513 bins each
        'num_layers': 1,
        'lr': 0.001,
        'p_out': 0.5,  # dropout probability
        'batch_first': False,
        'bidirectional': True,
    }
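
# Sacred injects top-level config entries by name, so `foo` (as a whole dict)
# can be requested in captured functions such as `main` below.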
class MaskEstimator(pt.Model):
    """BLSTM mask estimator: one BLSTM layer followed by three FF layers."""
    # Possible weight init: Uniform(-0.04, 0.04) via torch.distributions.
    loss_function = nn.BCELoss()

    def __init__(self, foo):
        super().__init__()
        self.blstm = nn.LSTM(
            input_size=foo['input_size'],
            hidden_size=foo['hidden_size'],
            num_layers=foo['num_layers'],
            batch_first=foo['batch_first'],  # input is (seq, batch, input_size)
            bidirectional=foo['bidirectional'],
        )
        self.dropout = nn.Dropout(foo['p_out'])
        # Three FF layers; the BLSTM output is 2 * hidden_size wide.
        self.fc1 = nn.Linear(2 * foo['hidden_size'], 513)
        self.fc2 = nn.Linear(513, 513)
        self.fc3 = nn.Linear(513, foo['output_size'])
    def forward(self, example):
        # Noisy observation with shape (seq, batch, input_size).
        x = example['observation_abs']
        # r_out: (seq, batch, 2 * hidden_size)
        r_out, _ = self.blstm(x)
        # Dropout, then tanh on the BLSTM output.
        out_blstm = torch.tanh(self.dropout(r_out))
        # Dropout, then ReLU; out_fc1: (seq, batch, 513)
        out_fc1 = torch.relu(self.fc1(self.dropout(out_blstm)))
        # Dropout, then ReLU; out_fc2: (seq, batch, 513)
        out_fc2 = torch.relu(self.fc2(self.dropout(out_fc1)))
        # No dropout before the output layer; sigmoid maps to [0, 1].
        out_fc3 = torch.sigmoid(self.fc3(out_fc2))  # (seq, batch, 1026)
        # Split into speech and noise masks, 513 bins each.
        speech_mask_estimate, noise_mask_estimate = out_fc3.split(513, dim=-1)
        return speech_mask_estimate, noise_mask_estimate
    def review(self, example, output):
        speech_mask_estimate, noise_mask_estimate = output
        # nn.BCELoss reduces to a scalar mean by default.
        loss_speech_mask = self.loss_function(
            speech_mask_estimate, example['speech_mask_target'])
        loss_noise_mask = self.loss_function(
            noise_mask_estimate, example['noise_mask_target'])
        loss = loss_speech_mask + loss_noise_mask
        return {
            'scalars': {
                'speechloss': loss_speech_mask,
                'noiseloss': loss_noise_mask,
            },
            'loss': loss,
        }
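
# Hedged shape check (uncomment to verify; the `foo` values mirror the Sacred
# config above and are otherwise an assumption here): both masks should come
# back as (seq, batch, 513) for a (seq, batch, 513) input.
# foo = dict(input_size=513, hidden_size=256, output_size=1026, num_layers=1,
#            p_out=0.5, batch_first=False, bidirectional=True)
# net = MaskEstimator(foo)
# speech, noise = net({'observation_abs': torch.rand(100, 4, 513)})
# assert speech.shape == noise.shape == (100, 4, 513)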
# The model is instantiated inside main() with the Sacred config.
def train(net, train_iterator, gpu=False):
    # Unused stub for a manual training loop: pick the device and move the net.
    device = torch.device('cuda' if gpu and torch.cuda.is_available() else 'cpu')
    net.to(device)


def validate(net, testloader, gpu=False):
    # Unused stub, kept symmetric with train().
    device = torch.device('cuda' if gpu and torch.cuda.is_available() else 'cpu')
    net.to(device)
@ex.automain
def main(_config, foo):
    # `foo` is injected by Sacred; `use_pt` and `storage_dir` are nested in
    # it, so they cannot be injected as top-level names.
    if foo['use_pt']:
        model = MaskEstimator(foo)
        optimizer = pt_opt.Adam()  # default lr matches foo['lr'] = 0.001
        trainer = Trainer(
            model,
            storage_dir=foo['storage_dir'],
            optimizer=optimizer,
            lr_scheduler=None,
            loss_weights=None,
            summary_trigger=(1000, 'iteration'),
            checkpoint_trigger=(1000, 'iteration'),
            keep_all_checkpoints=False,
            max_trigger=(foo['epochs'], 'epoch'),
            virtual_minibatch_size=1,
        )
        try:
            trainer.train(
                train_iterator,
                test_iterator,
                resume=False,
                device='cpu',
            )
        except Exception:
            # Make the failure stand out in the run log before re-raising.
            print('#' * 1000)
            raise
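
# Usage sketch (assuming this file is saved as maskestimation.py): Sacred's
# command line can override nested config entries with dotted names, e.g.
#
#     python maskestimation.py with foo.epochs=20 foo.hidden_size=512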