In [1]:
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from model_ae import Encoder

In [2]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, num_attn_heads, attn_hidden_size, dropout_prob, with_focus_attn):
        super(MultiHeadedAttention, self).__init__()
        self.num_attn_heads = num_attn_heads
        self.hidden_size = attn_hidden_size
        self.dropout_prob = dropout_prob
        self.with_focus_attn = with_focus_attn
        
        self.attn_head_size = int(self.hidden_size / self.num_attn_heads)
        self.all_head_size = self.num_attn_heads * self.attn_head_size

        self.query = nn.Linear(self.hidden_size, self.all_head_size)
        self.key = nn.Linear(self.hidden_size, self.all_head_size)
        self.value = nn.Linear(self.hidden_size, self.all_head_size)

        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_prob)

        self.softmax = nn.Softmax(dim=-1)
        
        if(with_focus_attn == True):
            self.tanh = nn.Tanh()
            self.sigmoid = nn.Sigmoid()
            
            self.linear_focus_query = nn.Linear(num_attn_heads * self.attn_head_size, 
                                                num_attn_heads * self.attn_head_size)
            self.linear_focus_global = nn.Linear(num_attn_heads * self.attn_head_size, 
                                                 num_attn_heads * self.attn_head_size)
            
            up = torch.randn(num_attn_heads, 1, self.attn_head_size)
            self.up = Variable(up, requires_grad=True).cuda()
            torch.nn.init.xavier_uniform_(self.up)
            
            uz = torch.randn(num_attn_heads, 1, self.attn_head_size)
            self.uz = Variable(uz, requires_grad=True).cuda()
            torch.nn.init.xavier_uniform_(self.uz)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attn_heads, self.attn_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        key_len = hidden_states.size(1)
        
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        
        if(self.with_focus_attn == True):
            glo = torch.mean(mixed_query_layer, dim=1, keepdim=True)
            
            c = self.tanh(self.linear_focus_query(mixed_query_layer) + self.linear_focus_global(glo))
            c = self.transpose_for_scores(c)
            
            p = c * self.up
            p = p.sum(3).squeeze()
            z = c * self.uz
            z = z.sum(3).squeeze()
            
            P = self.sigmoid(p) * key_len
            Z = self.sigmoid(z) * key_len
            
            j = torch.arange(start=0, end=key_len, dtype=P.dtype).unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda()
            P = P.unsqueeze(-1)
            Z = Z.unsqueeze(-1)
            
            G = -(j - P)**2 * 2 / (Z**2)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attn_head_size)
        
        if(self.with_focus_attn == True):
            attention_scores = attention_scores + G
            
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.o_proj(context_layer)

        return attention_output

In [3]:
class CLDNN(nn.Module):
    def __init__(self, conv_dim, checkpoint=None, hidden_size=64, num_layers=2,
                 bidirectional=True, with_focus_attn=False):
        super(CLDNN, self).__init__()
        self.conv_dim = conv_dim
        if(conv_dim == '1d'):
            self.encoder = Encoder(conv_dim)
            if checkpoint:
                self.encoder.load_state_dict(torch.load(checkpoint))
            self.attn = MultiHeadedAttention(num_attn_heads=4, attn_hidden_size=8, dropout_prob=0.1,
                                             with_focus_attn=with_focus_attn)
            self.lstm = nn.LSTM(8, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional) 
            self.fc = nn.Sequential(
                nn.Linear(hidden_size*2 if bidirectional else hidden_size, 1),
                nn.Sigmoid()
            )
        elif(conv_dim == '2d'):
            self.encoder = Encoder(conv_dim)
            if checkpoint:
                self.encoder.load_state_dict(torch.load(checkpoint))
            self.attn = MultiHeadedAttention(num_attn_heads=4, attn_hidden_size=176, dropout_prob=0.1, 
                                             with_focus_attn=with_focus_attn)
            self.gap = nn.AdaptiveAvgPool2d((1, 11))
            self.lstm = nn.LSTM(11, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional) 
            self.fc = nn.Sequential(
                nn.Linear(hidden_size*2 if bidirectional else hidden_size, 1),
                nn.Sigmoid()
            )
        else:
            raise ValueError("Convolution dimension not found: %s" % (conv_dim))
            
    def forward(self, x):
        if(self.conv_dim == '1d'):
            out = self.encoder(x)  # (batch, 1, 40, 100) -> (batch, 8, 1, 100)
            out = torch.squeeze(out, 2)  # (batch, 8, 1, 100) -> (batch, 8, 100)
            out = out.permute(0, 2, 1)  # (batch, 8, 100) -> (batch, 100, 8)
            h = out
            out = self.attn(out) # (batch, 100, 8) -> (batch, 100, 8)
            out = h + out
            out = out.permute(1, 0, 2)  # (batch, 100, 8) -> (100, batch, 8)
            out, _ = self.lstm(out)  # (100, batch, 8) -> (100, batch, num_directions*hidden_size)
            out = out[-1]  # (100, batch, num_directions*hidden_size) -> (batch, num_directions*hidden_size)
            out = self.fc(out)  # (batch, num_directions*hidden_size) -> (batch, 1)
        elif(self.conv_dim == '2d'):
            out = self.encoder(x)  # (batch, 1, 128, 100) -> (batch, 16, 11, 8)
            out = out.permute(0, 3, 1, 2)  # (batch, 16, 11, 8) -> (batch, 8, 16, 11)
            h = out
            new_out_shape = out.size()[:2] + (out.size()[2] * out.size()[3],)
            out = out.view(*new_out_shape)  # (batch, 8, 16, 11) -> (batch, 8, 176)
            out = self.attn(out)  # (batch, 8, 176) -> (batch, 8, 176)
            out = out.view(h.size())  # (batch, 8, 176) -> (batch, 8, 16, 11)
            out = h + out
            out = self.gap(out)  # (batch, 8, 16, 11) -> (batch, 8, 1, 11)
            out = torch.squeeze(out, 2)  # (batch, 8, 1, 11) -> (batch, 8, 11)
            out = out.permute(1, 0, 2)  # (batch, 8, 11) -> (8, batch, 11)
            out, _ = self.lstm(out)  # (8, batch, 11) -> (8, batch, num_directions*hidden_size)
            out = out[-1]  # (8, batch, num_directions*hidden_size) -> (batch, num_directions*hidden_size)
            out = self.fc(out)  # (batch, num_directions*hidden_size) -> (batch, 1)
        return out

In [4]:
from preprocessing import convert_spectrograms, convert_tensor
import os, glob
import numpy as np

In [5]:
sample_data_repo = os.path.join('.', 'wav_data', 'pretrain')

samples_data = glob.glob(os.path.join(sample_data_repo, '**', '*wav'), recursive=True)
samples_data = sorted(samples_data)


np.random.seed(42)
idx = np.random.permutation(len(samples_data))
train_idx = idx[:int(len(samples_data)*0.8)]
eval_idx = idx[int(len(samples_data)*0.8):]

In [6]:
train_samples = list(np.array(samples_data)[train_idx])
eval_samples = list(np.array(samples_data)[eval_idx])

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
#X_train = convert_spectrograms(train_samples, conv_dim=args.conv_dim)
X_eval = convert_spectrograms(eval_samples, conv_dim='2d')

#X_train = convert_tensor(X_train, device=device)
X_eval = convert_tensor(X_eval, device=device)

100%|██████████| 288/288 [00:50<00:00,  5.73it/s]


In [9]:
model = CLDNN(conv_dim='2d', checkpoint='./output/aae_2d_step_100.pt', with_focus_attn=True)

In [10]:
model = model.to(device)

In [11]:
out = model(X_eval[:1])

In [12]:
out

tensor([[0.5003]], device='cuda:0', grad_fn=<SigmoidBackward>)

In [13]:
from torch.utils.data import TensorDataset, DataLoader

In [14]:
y = np.zeros(len(X_eval))
y[int(len(X_eval)/2):] = 1

In [15]:
y = torch.tensor(y, device=device).float()
y = y.view(-1, 1)

In [16]:
train_ds = TensorDataset(X_eval, y)

In [17]:
train_dataloader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

In [18]:
import torch.optim as optim

loss_func = nn.BCELoss()
optimizer = optim.Adam(model.parameters())

In [19]:
def train(dataloader, epochs):
    for epoch in range(epochs):
        for x_batch, y_batch in dataloader:
            optimizer.zero_grad()
            
            outputs = model(x_batch)
            
            loss = loss_func(outputs, y_batch)
            loss.backward()
            
            optimizer.step()
            
        for param_group in optimizer.param_groups:
            lr = param_group['lr']
        print('epoch: {:3d},    lr={:6f},    loss={:5f}'.format(epoch+1, lr, loss.item()))

In [20]:
train(train_dataloader, 10)

epoch:   1,    lr=0.001000,    loss=0.681675
epoch:   2,    lr=0.001000,    loss=0.691006
epoch:   3,    lr=0.001000,    loss=0.669674
epoch:   4,    lr=0.001000,    loss=0.696231
epoch:   5,    lr=0.001000,    loss=0.685858
epoch:   6,    lr=0.001000,    loss=0.610714
epoch:   7,    lr=0.001000,    loss=0.703751
epoch:   8,    lr=0.001000,    loss=0.589321
epoch:   9,    lr=0.001000,    loss=0.600516
epoch:  10,    lr=0.001000,    loss=0.586961
