-
Notifications
You must be signed in to change notification settings - Fork 13
/
seld.py
142 lines (124 loc) · 5.9 KB
/
seld.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
import torch.nn as nn
import torch.nn.functional as F
from methods.utils.model_utilities import (DoubleConv, PositionalEncoding,
init_layer)
class EINV2(nn.Module):
def __init__(self, cfg, dataset):
super().__init__()
self.pe_enable = False # Ture | False
if cfg['data']['audio_feature'] == 'logmel&intensity':
self.f_bins = cfg['data']['n_mels']
self.in_channels = 7
self.downsample_ratio = 2 ** 2
self.sed_conv_block1 = nn.Sequential(
DoubleConv(in_channels=4, out_channels=64),
nn.AvgPool2d(kernel_size=(2, 2)),
)
self.sed_conv_block2 = nn.Sequential(
DoubleConv(in_channels=64, out_channels=128),
nn.AvgPool2d(kernel_size=(2, 2)),
)
self.sed_conv_block3 = nn.Sequential(
DoubleConv(in_channels=128, out_channels=256),
nn.AvgPool2d(kernel_size=(1, 2)),
)
self.sed_conv_block4 = nn.Sequential(
DoubleConv(in_channels=256, out_channels=512),
nn.AvgPool2d(kernel_size=(1, 2)),
)
self.doa_conv_block1 = nn.Sequential(
DoubleConv(in_channels=self.in_channels, out_channels=64),
nn.AvgPool2d(kernel_size=(2, 2)),
)
self.doa_conv_block2 = nn.Sequential(
DoubleConv(in_channels=64, out_channels=128),
nn.AvgPool2d(kernel_size=(2, 2)),
)
self.doa_conv_block3 = nn.Sequential(
DoubleConv(in_channels=128, out_channels=256),
nn.AvgPool2d(kernel_size=(1, 2)),
)
self.doa_conv_block4 = nn.Sequential(
DoubleConv(in_channels=256, out_channels=512),
nn.AvgPool2d(kernel_size=(1, 2)),
)
self.stitch = nn.ParameterList([
nn.Parameter(torch.FloatTensor(64, 2, 2).uniform_(0.1, 0.9)),
nn.Parameter(torch.FloatTensor(128, 2, 2).uniform_(0.1, 0.9)),
nn.Parameter(torch.FloatTensor(256, 2, 2).uniform_(0.1, 0.9)),
])
if self.pe_enable:
self.pe = PositionalEncoding(pos_len=100, d_model=512, pe_type='t', dropout=0.0)
self.sed_trans_track1 = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
self.sed_trans_track2 = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
self.doa_trans_track1 = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
self.doa_trans_track2 = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
self.fc_sed_track1 = nn.Linear(512, 14, bias=True)
self.fc_sed_track2 = nn.Linear(512, 14, bias=True)
self.fc_doa_track1 = nn.Linear(512, 3, bias=True)
self.fc_doa_track2 = nn.Linear(512, 3, bias=True)
self.final_act_sed = nn.Sequential() # nn.Sigmoid()
self.final_act_doa = nn.Tanh()
self.init_weight()
def init_weight(self):
init_layer(self.fc_sed_track1)
init_layer(self.fc_sed_track2)
init_layer(self.fc_doa_track1)
init_layer(self.fc_doa_track2)
def forward(self, x):
"""
x: waveform, (batch_size, num_channels, data_length)
"""
x_sed = x[:, :4]
x_doa = x
# cnn
x_sed = self.sed_conv_block1(x_sed)
x_doa = self.doa_conv_block1(x_doa)
x_sed = torch.einsum('c, nctf -> nctf', self.stitch[0][:, 0, 0], x_sed) + \
torch.einsum('c, nctf -> nctf', self.stitch[0][:, 0, 1], x_doa)
x_doa = torch.einsum('c, nctf -> nctf', self.stitch[0][:, 1, 0], x_sed) + \
torch.einsum('c, nctf -> nctf', self.stitch[0][:, 1, 1], x_doa)
x_sed = self.sed_conv_block2(x_sed)
x_doa = self.doa_conv_block2(x_doa)
x_sed = torch.einsum('c, nctf -> nctf', self.stitch[1][:, 0, 0], x_sed) + \
torch.einsum('c, nctf -> nctf', self.stitch[1][:, 0, 1], x_doa)
x_doa = torch.einsum('c, nctf -> nctf', self.stitch[1][:, 1, 0], x_sed) + \
torch.einsum('c, nctf -> nctf', self.stitch[1][:, 1, 1], x_doa)
x_sed = self.sed_conv_block3(x_sed)
x_doa = self.doa_conv_block3(x_doa)
x_sed = torch.einsum('c, nctf -> nctf', self.stitch[2][:, 0, 0], x_sed) + \
torch.einsum('c, nctf -> nctf', self.stitch[2][:, 0, 1], x_doa)
x_doa = torch.einsum('c, nctf -> nctf', self.stitch[2][:, 1, 0], x_sed) + \
torch.einsum('c, nctf -> nctf', self.stitch[2][:, 1, 1], x_doa)
x_sed = self.sed_conv_block4(x_sed)
x_doa = self.doa_conv_block4(x_doa)
x_sed = x_sed.mean(dim=3) # (N, C, T)
x_doa = x_doa.mean(dim=3) # (N, C, T)
# transformer
if self.pe_enable:
x_sed = self.pe(x_sed)
if self.pe_enable:
x_doa = self.pe(x_doa)
x_sed = x_sed.permute(2, 0, 1) # (T, N, C)
x_doa = x_doa.permute(2, 0, 1) # (T, N, C)
x_sed_1 = self.sed_trans_track1(x_sed).transpose(0, 1) # (N, T, C)
x_sed_2 = self.sed_trans_track2(x_sed).transpose(0, 1) # (N, T, C)
x_doa_1 = self.doa_trans_track1(x_doa).transpose(0, 1) # (N, T, C)
x_doa_2 = self.doa_trans_track2(x_doa).transpose(0, 1) # (N, T, C)
# fc
x_sed_1 = self.final_act_sed(self.fc_sed_track1(x_sed_1))
x_sed_2 = self.final_act_sed(self.fc_sed_track2(x_sed_2))
x_sed = torch.stack((x_sed_1, x_sed_2), 2)
x_doa_1 = self.final_act_doa(self.fc_doa_track1(x_doa_1))
x_doa_2 = self.final_act_doa(self.fc_doa_track2(x_doa_2))
x_doa = torch.stack((x_doa_1, x_doa_2), 2)
output = {
'sed': x_sed,
'doa': x_doa,
}
return output