In [1]:
import os
import glob

import math
import librosa
import numpy as np
from scipy.io.wavfile import write
from tqdm import tqdm
from preprocessing import preprocessing, get_mrcg, get_mfcc

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import TensorDataset, DataLoader
from src.utils.optimization import WarmupLinearSchedule

device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")
print(device)

cuda:0


In [2]:
sr = 16000
frame_length = 0.025  # 25ms
frame_stride = 0.01  # 10ms
n_fft = int(round(sr * frame_length))
hop_length = int(round(sr * frame_stride))
n_mfcc = 40

In [3]:
sample_data_repo = os.path.join('.', 'Data', 'sample_data')

samples_0 = glob.glob(os.path.join(sample_data_repo, '0', '**', '*wav'), recursive=True)
samples_0 = sorted(samples_0)

samples_20 = glob.glob(os.path.join(sample_data_repo, '20', '**', '*wav'), recursive=True)
samples_20 = sorted(samples_20)

samples_40 = glob.glob(os.path.join(sample_data_repo, '40', '**', '*wav'), recursive=True)
samples_40 = sorted(samples_40)

samples_60 = glob.glob(os.path.join(sample_data_repo, '60', '**', '*wav'), recursive=True)
samples_60 = sorted(samples_60)

samples_80 = glob.glob(os.path.join(sample_data_repo, '80', '**', '*wav'), recursive=True)
samples_80 = sorted(samples_80)

samples_100 = glob.glob(os.path.join(sample_data_repo, '100', '**', '*wav'), recursive=True)
samples_100 = sorted(samples_100)

samples_120 = glob.glob(os.path.join(sample_data_repo, '120', '**', '*wav'), recursive=True)
samples_120 = sorted(samples_120)

samples_140 = glob.glob(os.path.join(sample_data_repo, '140', '**', '*wav'), recursive=True)
samples_140 = sorted(samples_140)

samples_160 = glob.glob(os.path.join(sample_data_repo, '160', '**', '*wav'), recursive=True)
samples_160 = sorted(samples_160)

samples_180 = glob.glob(os.path.join(sample_data_repo, '180', '**', '*wav'), recursive=True)
samples_180 = sorted(samples_180)

In [4]:
sample_vad_seg_repo = os.path.join('.', 'Data', 'binary_segment')   # 적절하게 변경 필요

samples_vad_seg_0 = glob.glob(os.path.join(sample_vad_seg_repo, '0', '*[npy|npz]'), recursive=True)
samples_vad_seg_0 = sorted(samples_vad_seg_0)   

samples_vad_seg_20 = glob.glob(os.path.join(sample_vad_seg_repo, '20', '*[npy|npz]'), recursive=True)
samples_vad_seg_20 = sorted(samples_vad_seg_20)   

samples_vad_seg_40 = glob.glob(os.path.join(sample_vad_seg_repo, '40', '*[npy|npz]'), recursive=True)
samples_vad_seg_40 = sorted(samples_vad_seg_40)   

samples_vad_seg_60 = glob.glob(os.path.join(sample_vad_seg_repo, '60', '*[npy|npz]'), recursive=True)
samples_vad_seg_60 = sorted(samples_vad_seg_60)   

samples_vad_seg_80 = glob.glob(os.path.join(sample_vad_seg_repo, '80', '*[npy|npz]'), recursive=True)
samples_vad_seg_80 = sorted(samples_vad_seg_80)   

samples_vad_seg_100 = glob.glob(os.path.join(sample_vad_seg_repo, '100', '*[npy|npz]'), recursive=True)
samples_vad_seg_100 = sorted(samples_vad_seg_100)   

samples_vad_seg_120 = glob.glob(os.path.join(sample_vad_seg_repo, '120', '*[npy|npz]'), recursive=True)
samples_vad_seg_120 = sorted(samples_vad_seg_120)   

samples_vad_seg_140 = glob.glob(os.path.join(sample_vad_seg_repo, '140', '*[npy|npz]'), recursive=True)
samples_vad_seg_140 = sorted(samples_vad_seg_140)   

samples_vad_seg_160 = glob.glob(os.path.join(sample_vad_seg_repo, '160', '*[npy|npz]'), recursive=True)
samples_vad_seg_160 = sorted(samples_vad_seg_160)   

samples_vad_seg_180 = glob.glob(os.path.join(sample_vad_seg_repo, '180', '*[npy|npz]'), recursive=True)
samples_vad_seg_180 = sorted(samples_vad_seg_180)   

In [6]:
def generatio_tensor_instances(array_2d, seq_len, hop, label):
    """
    array_2d : ndarray.  STFT magnitude or phase.
    dest_path : file path
    seq_len : number of frames in an instance.
    label : segmented labels.  0 and 1's. The same length as original wav file of the audio sample. 
    """
    row_size, col_size = array_2d.shape[0], array_2d.shape[1]
    ratio = len(label)/col_size  # ratio : how many data points per frame 
    stack_array = []    # 4D tensor that will hold the instances
    label_array = []

    j=0
    while j <= (col_size - (seq_len+1)): 
        context_frame = array_2d[:, j:(j+seq_len)]
        # seg_label = round( label[int(j*ratio):int((j+seq_len)*ratio)].mean() ) 
        threshold = 0.5  # if greater than the threshold, then speech 
        seg_label = 1 if label[int(j*ratio):int((j+seq_len)*ratio)].mean() > threshold else 0
        
        stack_array.append(context_frame[:,:,np.newaxis])   # make context_frame to 3d tensor & append 
        label_array.append(seg_label)
            
        j = j+hop
        
    return np.stack(stack_array, axis=0), label_array

In [7]:
no_samples = len(samples_0) 

mfcc_instances_0 = []    # elements are ndarrays
label_instances_0 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_0[i])
    if('npy' in samples_vad_seg_0[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_0[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)

    mfcc_instances_0.append(mfcc_instances_sub)
    label_instances_0.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_0).shape, np.concatenate(label_instances_0).shape

((282, 40, 50, 1), (282,))

In [8]:
no_samples = len(samples_20) 

mfcc_instances_20 = []    # elements are ndarrays
label_instances_20 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_20[i])
    if('npy' in samples_vad_seg_20[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_20[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)
    
    mfcc_instances_20.append(mfcc_instances_sub)
    label_instances_20.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_20).shape, np.concatenate(label_instances_20).shape

((1296, 40, 50, 1), (1296,))

In [9]:
no_samples = len(samples_40) 

mfcc_instances_40 = []    # elements are ndarrays
label_instances_40 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_40[i])
    if('npy' in samples_vad_seg_40[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_40[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)
    
    mfcc_instances_40.append(mfcc_instances_sub)
    label_instances_40.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_40).shape, np.concatenate(label_instances_40).shape

((1404, 40, 50, 1), (1404,))

In [10]:
no_samples = len(samples_60) 

mfcc_instances_60 = []    # elements are ndarrays
label_instances_60 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_60[i])
    if('npy' in samples_vad_seg_60[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_60[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)
    
    mfcc_instances_60.append(mfcc_instances_sub)
    label_instances_60.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_60).shape, np.concatenate(label_instances_60).shape

((333, 40, 50, 1), (333,))

In [11]:
no_samples = len(samples_80) 

mfcc_instances_80 = []    # elements are ndarrays
label_instances_80 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_80[i])
    if('npy' in samples_vad_seg_80[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_80[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)
    
    mfcc_instances_80.append(mfcc_instances_sub)
    label_instances_80.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_80).shape, np.concatenate(label_instances_80).shape

((1692, 40, 50, 1), (1692,))

In [12]:
no_samples = len(samples_100) 

mfcc_instances_100 = []    # elements are ndarrays
label_instances_100 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_100[i])
    if('npy' in samples_vad_seg_100[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_100[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)
    
    mfcc_instances_100.append(mfcc_instances_sub)
    label_instances_100.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_100).shape, np.concatenate(label_instances_100).shape

((1836, 40, 50, 1), (1836,))

In [13]:
no_samples = len(samples_140) 

mfcc_instances_140 = []    # elements are ndarrays
label_instances_140 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_140[i])
    if('npy' in samples_vad_seg_140[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_140[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)
    
    mfcc_instances_140.append(mfcc_instances_sub)
    label_instances_140.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_140).shape, np.concatenate(label_instances_140).shape

((1476, 40, 50, 1), (1476,))

In [14]:
no_samples = len(samples_160) 

mfcc_instances_160 = []    # elements are ndarrays
label_instances_160 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_160[i])
    if('npy' in samples_vad_seg_160[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_160[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)
    
    mfcc_instances_160.append(mfcc_instances_sub)
    label_instances_160.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_160).shape, np.concatenate(label_instances_160).shape

((1836, 40, 50, 1), (1836,))

In [15]:
no_samples = len(samples_180) 

mfcc_instances_180 = []    # elements are ndarrays
label_instances_180 = []         # elements are lists

for i in range(0, no_samples):
    voice_noise_label = np.load(samples_vad_seg_180[i])
    if('npy' in samples_vad_seg_180[i].split('/')[-1]):
        label = voice_noise_label[0]        # use the left channel label.  this take care of 0 degree problem
    else:                                   # npz file
        label = voice_noise_label["label"]    
    mfcc = get_mfcc(samples_180[i], sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)
    
    # generate instances with 1.16 sec duration (100 frames), at every 0.116 sec apart (10 hops)
    mfcc_instances_sub, label_sub = generatio_tensor_instances(mfcc, 50, 10, label)
    
    mfcc_instances_180.append(mfcc_instances_sub)
    label_instances_180.append(np.array(label_sub))
    
np.concatenate(mfcc_instances_180).shape, np.concatenate(label_instances_180).shape

((349, 40, 50, 1), (349,))

In [16]:
mfcc_instances_0 = np.concatenate(mfcc_instances_0)
mfcc_instances_20 = np.concatenate(mfcc_instances_20)
mfcc_instances_40 = np.concatenate(mfcc_instances_40)
mfcc_instances_60 = np.concatenate(mfcc_instances_60)
mfcc_instances_80 = np.concatenate(mfcc_instances_80)
mfcc_instances_100= np.concatenate(mfcc_instances_100)
mfcc_instances_140 = np.concatenate(mfcc_instances_140)
mfcc_instances_160 = np.concatenate(mfcc_instances_160)
mfcc_instances_180 = np.concatenate(mfcc_instances_180)

In [17]:
label_instances_0 = np.concatenate(label_instances_0)
label_instances_20 = np.concatenate(label_instances_20)
label_instances_40 = np.concatenate(label_instances_40)
label_instances_60 = np.concatenate(label_instances_60)
label_instances_80 = np.concatenate(label_instances_80)
label_instances_100= np.concatenate(label_instances_100)
label_instances_140 = np.concatenate(label_instances_140)
label_instances_160 = np.concatenate(label_instances_160)
label_instances_180 = np.concatenate(label_instances_180)

In [18]:
mfcc_instances = np.concatenate([mfcc_instances_0, mfcc_instances_20, mfcc_instances_40, mfcc_instances_60,mfcc_instances_80, 
                                  mfcc_instances_100, mfcc_instances_140, mfcc_instances_160, mfcc_instances_180], axis=0)

In [19]:
label_instances = np.concatenate([label_instances_0, label_instances_20, label_instances_40, label_instances_60,label_instances_80, 
                                  label_instances_100, label_instances_140, label_instances_160, label_instances_180], axis=0)

In [20]:
X = mfcc_instances
y = label_instances

In [21]:
num_eval = len(label_instances_160) + len(label_instances_180)
num_eval

2185

In [22]:
X_train = X[:-num_eval]
X_eval = X[-num_eval:]
y_train = y[:-num_eval]
y_eval = y[-num_eval:]

In [23]:
X_train.shape, y_train.shape, X_eval.shape, y_eval.shape

((8319, 40, 50, 1), (8319,), (2185, 40, 50, 1), (2185,))

In [24]:
X_train = X_train.transpose(0, 3, 1, 2)
X_eval = X_eval.transpose(0, 3, 1, 2)

In [25]:
X_train = torch.tensor(X_train).to(device)
X_eval = torch.tensor(X_eval).to(device)
y_train = torch.tensor(y_train).to(device)
y_eval = torch.tensor(y_eval).to(device)

In [26]:
y_train = y_train.float().unsqueeze(-1)
y_eval = y_eval.float().unsqueeze(-1)

In [27]:
train_ds = TensorDataset(X_train, y_train)
eval_ds = TensorDataset(X_eval, y_eval)

train_dataloader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=0, drop_last=True)
eval_dataloader = DataLoader(eval_ds, batch_size=128, num_workers=0, drop_last=True)

In [28]:
class Encoder(nn.Module):
    def __init__(self, conv_dim):
        super(Encoder, self).__init__()
        self.conv_dim = conv_dim
        if(conv_dim == '1d'):
            self.conv1 = nn.Sequential(
                nn.Conv1d(1, 4, (11, 1)), # (1, 40, 100) -> (4, 30, 100)
                nn.BatchNorm2d(4),
                nn.ReLU()
            )
            self.conv2 = nn.Sequential(
                nn.Conv1d(4, 4, (11, 1)), # (4, 30, 100) -> (4, 20, 100)
                nn.BatchNorm2d(4),
                nn.ReLU()
            )
            self.conv3 = nn.Sequential(
                nn.Conv1d(4, 8, (11, 1)), # (4, 20, 100) -> (8, 10, 100)
                nn.BatchNorm2d(8),
                nn.ReLU()
            )
            self.conv4 = nn.Sequential(
                nn.Conv1d(8, 8, (10, 1)), # (8, 10, 100) -> (8, 1, 100)
                nn.BatchNorm2d(8),
                nn.ReLU()
            )
        elif(conv_dim == '2d'):
            self.conv1 = nn.Sequential(
                nn.Conv2d(1, 4, (5, 3), padding=(0, 1)), # (1, 128, 50) -> (4, 124, 50)
                nn.BatchNorm2d(4),
                nn.ReLU(),
                nn.Conv2d(4, 4, (5, 3), padding=(0, 1)),  # (4, 124, 50) -> (4, 120, 50)
                nn.BatchNorm2d(4),
                nn.ReLU(),
                nn.MaxPool2d(2, 2)  # (4, 120, 50) -> (4, 60, 25)
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(4, 8, (5, 3), padding=(0, 1)), # (4, 60, 25) -> (8, 56, 25)
                nn.BatchNorm2d(8),
                nn.ReLU(),
                nn.Conv2d(8, 8, (5, 3), padding=(0, 1)),  # (8, 56, 25) -> (8, 52, 25)
                nn.BatchNorm2d(8),
                nn.ReLU(),
                nn.MaxPool2d(2, 2)  # (8, 52, 25) -> (8, 26, 12)
            )
            self.conv3 = nn.Sequential(
                nn.Conv2d(8, 16, (5, 3), padding=0), # (8, 26, 12) -> (16, 22, 10)
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d((2, 1), (2, 1)),  # (16, 22, 10) -> (16, 11, 10)
                nn.Conv2d(16, 16, 3, padding=(1, 0)),  # (16, 11, 10) -> (16, 11, 8)
                nn.BatchNorm2d(16),
                nn.ReLU()
            )
        else:
            raise ValueError("Convolution dimension not found: %s" % (conv_dim))

    def forward(self, x):
        if(self.conv_dim == '1d'):
            out = self.conv1(x)
            out = self.conv2(out)
            out = self.conv3(out)
            out = self.conv4(out)
            # out = out.contiguous().view(x.size()[0], -1)  # (800,)
        elif(self.conv_dim == '2d'):
            out = self.conv1(x)
            out = self.conv2(out)
            out = self.conv3(out)
            # out = out.contiguous().view(x.size()[0], -1)  # (1408,)
        return out

In [29]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, num_attn_heads, attn_hidden_size, dropout_prob, with_focus_attn):
        super(MultiHeadedAttention, self).__init__()
        self.num_attn_heads = num_attn_heads
        self.hidden_size = attn_hidden_size
        self.dropout_prob = dropout_prob
        self.with_focus_attn = with_focus_attn
        
        self.attn_head_size = int(self.hidden_size / self.num_attn_heads)
        self.all_head_size = self.num_attn_heads * self.attn_head_size

        self.query = nn.Linear(self.hidden_size, self.all_head_size)
        self.key = nn.Linear(self.hidden_size, self.all_head_size)
        self.value = nn.Linear(self.hidden_size, self.all_head_size)

        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_prob)

        self.softmax = nn.Softmax(dim=-1)
        
        if(with_focus_attn == True):
            self.tanh = nn.Tanh()
            self.sigmoid = nn.Sigmoid()
            
            self.linear_focus_query = nn.Linear(num_attn_heads * self.attn_head_size, 
                                                num_attn_heads * self.attn_head_size)
            self.linear_focus_global = nn.Linear(num_attn_heads * self.attn_head_size, 
                                                 num_attn_heads * self.attn_head_size)
            
            up = torch.randn(num_attn_heads, 1, self.attn_head_size)
            self.up = Variable(up, requires_grad=True).cuda()
            torch.nn.init.xavier_uniform_(self.up)
            
            uz = torch.randn(num_attn_heads, 1, self.attn_head_size)
            self.uz = Variable(uz, requires_grad=True).cuda()
            torch.nn.init.xavier_uniform_(self.uz)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attn_heads, self.attn_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        key_len = hidden_states.size(1)
        
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        
        if(self.with_focus_attn == True):
            glo = torch.mean(mixed_query_layer, dim=1, keepdim=True)
            
            c = self.tanh(self.linear_focus_query(mixed_query_layer) + self.linear_focus_global(glo))
            c = self.transpose_for_scores(c)
            
            p = c * self.up
            p = p.sum(3).squeeze()
            z = c * self.uz
            z = z.sum(3).squeeze()
            
            P = self.sigmoid(p) * key_len
            Z = self.sigmoid(z) * key_len
            
            j = torch.arange(start=0, end=key_len, dtype=P.dtype).unsqueeze(0).unsqueeze(0).unsqueeze(0).to('cuda')
            P = P.unsqueeze(-1)
            Z = Z.unsqueeze(-1)
            
            G = -(j - P)**2 * 2 / (Z**2)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attn_head_size)
        
        if(self.with_focus_attn == True):
            attention_scores = attention_scores + G
            
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.o_proj(context_layer)

        return attention_output

In [30]:
class CLDNN(nn.Module):
    def __init__(self, conv_dim, checkpoint=None, hidden_size=64, num_layers=2,
                 bidirectional=True, with_focus_attn=False):
        super(CLDNN, self).__init__()
        self.conv_dim = conv_dim
        if(conv_dim == '1d'):
            self.encoder = Encoder(conv_dim)
            if checkpoint:
                self.encoder.load_state_dict(torch.load(checkpoint))
            self.attn = MultiHeadedAttention(num_attn_heads=4, attn_hidden_size=8, dropout_prob=0.1,
                                             with_focus_attn=with_focus_attn)
            self.lstm = nn.LSTM(8, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional)
            self.fc = nn.Sequential(
                nn.Linear(hidden_size*2 if bidirectional else hidden_size, 1),
                nn.Sigmoid()
            )
        elif(conv_dim == '2d'):
            self.encoder = Encoder(conv_dim)
            if checkpoint:
                self.encoder.load_state_dict(torch.load(checkpoint))
            self.attn = MultiHeadedAttention(num_attn_heads=4, attn_hidden_size=176, dropout_prob=0.1, 
                                             with_focus_attn=with_focus_attn)
            self.gap = nn.AdaptiveAvgPool2d((1, 11))
            self.lstm = nn.LSTM(11, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional)
            self.fc = nn.Sequential(
                nn.Linear(hidden_size*2 if bidirectional else hidden_size, 1),
                nn.Sigmoid()
            )
        else:
            raise ValueError("Convolution dimension not found: %s" % (conv_dim))
            
    def forward(self, x):
        if(self.conv_dim == '1d'):
            out = self.encoder(x)
            out = torch.squeeze(out, 2)
            out = out.permute(0, 2, 1)  
            h = out
            out = self.attn(out)  
            out = h + out
            out = out.permute(1, 0, 2)  
            self.lstm.flatten_parameters()
            out, _ = self.lstm(out)  
            out = out[-1]  
            out = self.fc(out)  
        elif(self.conv_dim == '2d'):
            out = self.encoder(x)  
            out = out.permute(0, 3, 1, 2)  
            h = out
            new_out_shape = out.size()[:2] + (out.size()[2] * out.size()[3],)
            out = out.view(*new_out_shape)  
            out = self.attn(out)  
            out = out.view(h.size())  
            out = h + out
            out = self.gap(out)  
            out = torch.squeeze(out, 2)  
            out = out.permute(1, 0, 2)  
            self.lstm.flatten_parameters()
            out, _ = self.lstm(out)
            out = out[-1]
            out = self.fc(out)
        return out

In [31]:
model = CLDNN(conv_dim='1d').to(device)

In [32]:
loss_func = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0003)

In [33]:
def train(train_dataloader, eval_dataloader, epochs):
        print('Start training')
        softmax = nn.Softmax(dim=1)
        for epoch in range(epochs):
            model.train()
            train_loss = 0
            nb_train_steps = 0
            correct = 0
            num_samples = 0

            for X_batch, y_batch in train_dataloader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)

                optimizer.zero_grad()

                outputs = model(X_batch)

                loss = loss_func(outputs, y_batch)
                loss.backward()

                optimizer.step()

                train_loss += loss.mean().item()
                nb_train_steps += 1

                outputs = (outputs >= 0.5).float()
                correct += (outputs == y_batch).float().sum()
                num_samples += len(X_batch)

            train_loss = train_loss / nb_train_steps
            train_accuracy = correct / num_samples

            model.eval()
            eval_loss = 0
            nb_eval_steps = 0
            correct = 0
            num_samples = 0

            for X_batch, y_batch in eval_dataloader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                with torch.no_grad():
                    outputs = model(X_batch)

                tmp_eval_loss = loss_func(outputs, y_batch)
                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1

                outputs = (outputs >= 0.5).float()
                correct += (outputs == y_batch).float().sum()
                num_samples += len(X_batch)

            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = correct / num_samples

            for param_group in optimizer.param_groups:
                lr = param_group['lr']
            print('epoch: {:3d},    lr={:6f},    loss={:5f},    train_acc={:5f},    eval_loss={:5f},    eval_acc={:5f}'
                  .format(epoch+1, lr, train_loss, train_accuracy, eval_loss, eval_accuracy))

In [34]:
train(train_dataloader, eval_dataloader, 100)

Start training
epoch:   1,    lr=0.000300,    loss=0.678607,    train_acc=0.545654,    eval_loss=0.632309,    eval_acc=0.581801
epoch:   2,    lr=0.000300,    loss=0.555032,    train_acc=0.731445,    eval_loss=0.570106,    eval_acc=0.724265
epoch:   3,    lr=0.000300,    loss=0.487330,    train_acc=0.786255,    eval_loss=0.574816,    eval_acc=0.739430
epoch:   4,    lr=0.000300,    loss=0.406753,    train_acc=0.832886,    eval_loss=0.500673,    eval_acc=0.757812
epoch:   5,    lr=0.000300,    loss=0.353022,    train_acc=0.854980,    eval_loss=0.481345,    eval_acc=0.763787
epoch:   6,    lr=0.000300,    loss=0.313275,    train_acc=0.869629,    eval_loss=0.554657,    eval_acc=0.744945
epoch:   7,    lr=0.000300,    loss=0.283488,    train_acc=0.877563,    eval_loss=0.442086,    eval_acc=0.791360
epoch:   8,    lr=0.000300,    loss=0.268458,    train_acc=0.884766,    eval_loss=0.411509,    eval_acc=0.808824
epoch:   9,    lr=0.000300,    loss=0.242499,    train_acc=0.896729,    eval_loss

RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR