baseline_model.py
from torch import nn
from torch.nn import Sequential
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from hw_asr.base import BaseModel


class BaselineModel(BaseModel):
    def __init__(self, n_feats, n_class, rnn_dim, rnn_layers, hidden_dim, hidden_layers, *args, **kwargs):
        super().__init__(n_feats, n_class, *args, **kwargs)
        is_bidirectional = kwargs.get("bidirectional", False)
        dropout = kwargs.get("dropout", 0.1)

        # Recurrent encoder over spectrogram frames.
        self.rnn = nn.GRU(
            n_feats, hidden_size=rnn_dim, num_layers=rnn_layers,
            bidirectional=is_bidirectional, batch_first=True
        )

        # A bidirectional GRU concatenates forward and backward states,
        # doubling the feature dimension seen by the classifier head.
        cur_dim = 2 * rnn_dim if is_bidirectional else rnn_dim

        # MLP head: (hidden_layers - 1) hidden blocks, then a final
        # projection to n_class logits per frame.
        layers = []
        for _ in range(hidden_layers - 1):
            layers.extend([
                nn.Linear(cur_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            cur_dim = hidden_dim
        layers.append(nn.Linear(cur_dim, n_class))
        self.classifier = Sequential(*layers)

    def forward(self, spectrogram, *args, **kwargs):
        # spectrogram: (batch, time, feature); spectrogram_length must be
        # a CPU tensor or list for pack_padded_sequence.
        packed = pack_padded_sequence(
            spectrogram, kwargs["spectrogram_length"], batch_first=True, enforce_sorted=False
        )
        out, _ = self.rnn(packed)
        out_padded, _ = pad_packed_sequence(out, batch_first=True)
        return {"logits": self.classifier(out_padded)}

    def transform_input_lengths(self, input_lengths):
        # The GRU preserves the time dimension, so output lengths
        # equal input lengths.
        return input_lengths
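

if __name__ == "__main__":
    # Illustrative smoke test: a minimal sketch of how the model might be
    # constructed and called. All hyperparameter values and batch shapes
    # below are arbitrary example assumptions, not values from this repo's
    # training configs.
    import torch

    model = BaselineModel(
        n_feats=128, n_class=28, rnn_dim=256, rnn_layers=2,
        hidden_dim=512, hidden_layers=2, bidirectional=True,
    )
    spectrogram = torch.randn(4, 100, 128)     # (batch, time, feature)
    lengths = torch.tensor([100, 80, 60, 50])  # CPU tensor, as packing requires
    out = model(spectrogram, spectrogram_length=lengths)
    print(out["logits"].shape)                 # torch.Size([4, 100, 28])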