In [1]:
import time
import pickle
import numpy as np

from tqdm import tqdm, tqdm_notebook
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchaudio
import copy

from matplotlib import colors, pyplot as plt
%matplotlib inline

# в sklearn не все гладко, чтобы в colab удобно выводить картинки 
# мы будем игнорировать warnings
import warnings
warnings.filterwarnings(action='ignore', category=DeprecationWarning)


In [2]:
# First checking if GPU is available
train_on_gpu=torch.cuda.is_available()

if(train_on_gpu):
    print('Training on GPU.')
else:
    print('No GPU available, training on CPU.')
       
DEVICE = torch.device('cuda' if train_on_gpu else 'cpu')

No GPU available, training on CPU.


In [3]:
SAMPLE_RATE = 48000
N_FFT = SAMPLE_RATE * 64 // 1000 + 4
HOP_LENGTH = SAMPLE_RATE * 16 // 1000 + 4

In [4]:
# разные режимы датасета 
DATA_MODES = ['train', 'test']
# работаем на видеокарте
DEVICE = torch.device("cuda")

In [5]:
class SpeechDataset(Dataset):
    """
    Датасет с аудио, который их обрезает/паддит то заданной длины, применяет оконное преобразование Фурье,
    нормализует и приводит к тензору.
    """
    def __init__(self, noisy_files, clean_files, n_fft=64, hop_length=16):
        super().__init__()
        # список файлов для загрузки
        self.noisy_files = sorted(noisy_files)
        self.clean_files = sorted(clean_files)
        self.labels = [path.parent.name for path in self.noisy_files]
        
        # параметры stft
        self.n_fft = n_fft
        self.hop_length = hop_length
        
        # размер датасета
        self.len_ = len(self.noisy_files)
        
        # будем обрезать/паддить waveform аудиофайлов до этого размера
        self.max_len = 165000

    
    def __len__(self):
        return self.len_
      
    def load_sample(self, file):
        waveform, _ = torchaudio.load(file)
        return waveform
  
    def __getitem__(self, index):
        # для преобразования изображений в тензоры PyTorch и нормализации входа
        x_clean = self.load_sample(self.clean_files[index])
        x_noisy = self.load_sample(self.noisy_files[index])
        
        # padding/cutting
        x_clean = self._prepare_sample(x_clean)
        x_noisy = self._prepare_sample(x_noisy)
        
        # оконное преобразование Фурье
        x_noisy_stft = torch.stft(input=x_noisy, n_fft=self.n_fft, 
                                  hop_length=self.hop_length, normalized=True)
        x_clean_stft = torch.stft(input=x_clean, n_fft=self.n_fft, 
                                  hop_length=self.hop_length, normalized=True)
        
        return x_noisy_stft, x_clean_stft
        
    def _prepare_sample(self, waveform):
        waveform = waveform.numpy()
        current_len = waveform.shape[1]
        
        output = np.zeros((1, self.max_len), dtype='float32')
        output[0, -current_len:] = waveform[0, :self.max_len]
        output = torch.from_numpy(output)
        
        return output

In [6]:
TEST_NOISY_DIR = Path('/home/philipp/Projects/DCUnet/data/test/noisy_testset')
TEST_CLEAN_DIR = Path('/home/philipp/Projects/DCUnet/data/test/clean_testset')

# TRAIN_NOISY_DIR = Path('/home/philipp/Projects/DCUnet/data/train/noisy_trainset')
# TRAIN_CLEAN_DIR = Path('/home/philipp/Projects/DCUnet/data/train/clean_trainset')

TRAIN_NOISY_DIR = TEST_NOISY_DIR
TRAIN_CLEAN_DIR = TEST_CLEAN_DIR

In [7]:
test_noisy_files = sorted(list(TEST_NOISY_DIR.rglob('*.wav')))[:100]
test_clean_files = sorted(list(TEST_CLEAN_DIR.rglob('*.wav')))[:100]

train_noisy_files = sorted(list(TRAIN_NOISY_DIR.rglob('*.wav')))[:100]
train_clean_files = sorted(list(TRAIN_CLEAN_DIR.rglob('*.wav')))[:100]

In [8]:
test_dataset = SpeechDataset(test_noisy_files, test_clean_files, N_FFT, HOP_LENGTH)
train_dataset = SpeechDataset(train_noisy_files, train_clean_files, N_FFT, HOP_LENGTH)

In [9]:
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=1, num_workers=1, shuffle=True)

In [10]:
class CConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        
        self.real_conv = nn.Conv2d(in_channels=self.in_channels, 
                                   out_channels=self.out_channels, 
                                   kernel_size=self.kernel_size, 
                                   padding=self.padding, 
                                   stride=self.stride)
        
        self.im_conv = nn.Conv2d(in_channels=self.in_channels, 
                                 out_channels=self.out_channels, 
                                 kernel_size=self.kernel_size, 
                                 padding=self.padding, 
                                 stride=self.stride)
        
        # Glorot initialization.
        nn.init.xavier_uniform_(self.real_conv.weight)
        nn.init.xavier_uniform_(self.im_conv.weight)
        
        
    def forward(self, x):
        x_real = x[..., 0]
        x_im = x[..., 1]
        
        c_real = self.real_conv(x_real) - self.im_conv(x_im)
        c_im = self.im_conv(x_real) + self.real_conv(x_im)
        
        output = torch.stack([c_real, c_im], dim=-1)
        return output

In [11]:
class CConvTranspose2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding=0, padding=0):
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.output_padding = output_padding
        self.padding = padding
        self.stride = stride
        
        self.real_convt = nn.ConvTranspose2d(in_channels=self.in_channels, 
                                            out_channels=self.out_channels, 
                                            kernel_size=self.kernel_size, 
                                            output_padding=self.output_padding,
                                            padding=self.padding,
                                            stride=self.stride)
        
        self.im_convt = nn.ConvTranspose2d(in_channels=self.in_channels, 
                                            out_channels=self.out_channels, 
                                            kernel_size=self.kernel_size, 
                                            output_padding=self.output_padding, 
                                            padding=self.padding,
                                            stride=self.stride)
        
        
        # Glorot initialization.
        nn.init.xavier_uniform_(self.real_convt.weight)
        nn.init.xavier_uniform_(self.im_convt.weight)
        
        
    def forward(self, x):
        x_real = x[..., 0]
        x_im = x[..., 1]
        
        ct_real = self.real_convt(x_real) - self.im_convt(x_im)
        ct_im = self.im_convt(x_real) + self.real_convt(x_im)
        
        output = torch.stack([ct_real, ct_im], dim=-1)
        return output

In [12]:
class CBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()
        
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        
        self.real_b = nn.BatchNorm2d(num_features=self.num_features, eps=self.eps, momentum=self.momentum,
                                      affine=self.affine, track_running_stats=self.track_running_stats)
        self.im_b = nn.BatchNorm2d(num_features=self.num_features, eps=self.eps, momentum=self.momentum,
                                    affine=self.affine, track_running_stats=self.track_running_stats) 
        
    def forward(self, x):
        x_real = x[..., 0]
        x_im = x[..., 1]
        
        n_real = self.real_b(x_real)
        n_im = self.im_b(x_im)  
        
        output = torch.stack([n_real, n_im], dim=-1)
        return output

In [13]:
class Encoder(nn.Module):
    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45, padding=(0,0)):
        super().__init__()
        
        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding

        self.cconv = CConv2d(in_channels=self.in_channels, out_channels=self.out_channels, 
                             kernel_size=self.filter_size, stride=self.stride_size, padding=self.padding)
        
        self.cbn = CBatchNorm2d(num_features=self.out_channels) 
        
        self.leaky_relu = nn.LeakyReLU()
            
    def forward(self, x):
        
        conved = self.cconv(x)
        normed = self.cbn(conved)
        acted = self.leaky_relu(normed)
        
        return acted

In [14]:
class Decoder(nn.Module):
    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45,
                 output_padding=(0,0), padding=(0,0), last_layer=False):
        super().__init__()
        
        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.output_padding = output_padding
        self.padding = padding
        
        self.last_layer = last_layer
        
        self.cconvt = CConvTranspose2d(in_channels=self.in_channels, out_channels=self.out_channels, 
                             kernel_size=self.filter_size, stride=self.stride_size, output_padding=self.output_padding, padding=self.padding)
        
        self.cbn = CBatchNorm2d(num_features=self.out_channels) 
        
        self.leaky_relu = nn.LeakyReLU()
            
    def forward(self, x):
        
        conved = self.cconvt(x)
        normed = self.cbn(conved)
        if not self.last_layer:
            output = self.leaky_relu(normed)
        else:
            m_phase = normed / torch.abs(normed)
            m_mag = torch.tanh(torch.abs(normed))
            output = m_phase * m_mag
            
        return output

In [15]:
class DCUnet10(nn.Module):
    def __init__(self, n_fft=64, hop_length=16):
        super().__init__()
        
        # for istft
        self.n_fft = n_fft
        self.hop_length = hop_length
        
        # downsampling/encoding
        self.downsample0 = Encoder(filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45)
        self.downsample1 = Encoder(filter_size=(7,5), stride_size=(2,2), in_channels=45, out_channels=90)
        self.downsample2 = Encoder(filter_size=(5,3), stride_size=(2,2), in_channels=90, out_channels=90)
        self.downsample3 = Encoder(filter_size=(5,3), stride_size=(2,2), in_channels=90, out_channels=90)
        self.downsample4 = Encoder(filter_size=(5,3), stride_size=(2,1), in_channels=90, out_channels=90)
        
        # upsampling/decoding
        self.upsample0 = Decoder(filter_size=(5,3), stride_size=(2,1), in_channels=90, out_channels=90)
        self.upsample1 = Decoder(filter_size=(5,3), stride_size=(2,2), in_channels=180, out_channels=90)
        self.upsample2 = Decoder(filter_size=(5,3), stride_size=(2,2), in_channels=180, out_channels=90)
        self.upsample3 = Decoder(filter_size=(7,5), stride_size=(2,2), in_channels=180, out_channels=45)
        self.upsample4 = Decoder(filter_size=(7,5), stride_size=(2,2), in_channels=90, output_padding=(0,1),
                                 out_channels=1, last_layer=True)
        
        
    def forward(self, x, is_istft=False):
        # downsampling/encoding
        d0 = self.downsample0(x)
        d1 = self.downsample1(d0) 
        d2 = self.downsample2(d1)        
        d3 = self.downsample3(d2)        
        d4 = self.downsample4(d3)
        
        # upsampling/decoding
        u0 = self.upsample0(d4)
        c0 = torch.cat((u0, d3), dim=1)
        
        u1 = self.upsample1(c0)
        c1 = torch.cat((u1, d2), dim=1)
        
        u2 = self.upsample2(c1)
        c2 = torch.cat((u2, d1), dim=1)
        
        u3 = self.upsample3(c2)
        c3 = torch.cat((u3, d0), dim=1)
        
        u4 = self.upsample4(c3)
        
        # u4 - the mask
        output = u4 * x
        if is_istft:
            output = torchaudio.functional.istft(output, n_fft=self.n_fft, hop_length=self.hop_length, normalized=True)
        
        return output

In [16]:
dcunet10 = DCUnet10(N_FFT, HOP_LENGTH)

In [17]:
class sdr_loss(nn.Module):
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps
    def forward(self, y_pred, y_true):
        num = (y_pred * y_true).mean()
        den = torch.norm(y_pred) * torch.norm(y_true)
        return (-1) * num/den

In [18]:
class wsdr_loss(nn.Module):
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.sdr = sdr_loss()
    def forward(self, x, y_pred, y_true):
        z_pred = x - y_pred
        z_true = x - y_true
        sdr_y = self.sdr(y_pred, y_true)
        sdr_z = self.sdr(z_pred, z_true)
        alpha = (torch.norm(y_true) ** 2) / ((torch.norm(y_true) ** 2) * (torch.norm(z_true) ** 2))
        
        return alpha * sdr_y + (1 - alpha) * sdr_z

In [19]:
train_on_gpu = False

In [20]:
def train(net, train_loader, test_loader, loss_fn, optimizer, scheduler, epochs, train_on_gpu=False):
    net.train()
    losses = []
    if(train_on_gpu):
        net = net.cuda()
        
    for e in tqdm(range(epochs)):
        # batch loop
        counter = 0
        for noisy_x, clean_x in tqdm(train_loader):
            counter += 1

            if(train_on_gpu):
                noisy_x, clean_x = noisy_x.cuda(), clean_x.cuda()

            # zero  gradients
            net.zero_grad()

            # get the output from the model
            pred_x = net(noisy_x)

            # calculate the loss and perform backprop
            loss = loss_fn(noisy_x, pred_x, clean_x)
            print(loss)
            losses.append(loss)
            loss.backward()
            
            optimizer.step()

            # loss stats
            if counter % 20 == 0:
                # Get validation loss
                val_losses = []
                net.eval()
                for noisy_val, clean_val in test_loader:

                    if(train_on_gpu):
                        noisy_val, clean_val = noisy_val.cuda(), clean_val.cuda()

                    pred_val = net(noisy_val)

                    # calculate the loss
                    val_loss = loss(n_val, pred_val, clean_val)
                    val_losses.append(val_loss.item())

                net.train()
                print("Epoch: {}/{}...".format(e+1, epochs),
                      "Step: {}...".format(counter),
                      "Loss: {:.6f}...".format(loss.item()),
                      "Val Loss: {:.6f}".format(np.mean(val_losses)))
      
    scheduler.step()
    return losses

In [21]:
loss_fn = wsdr_loss()
optimizer = torch.optim.Adam(dcunet10.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

In [None]:
ls = train(dcunet10, train_loader, test_loader, loss_fn, optimizer, scheduler, epochs=1)

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A

tensor(-6.3210e-07, grad_fn=<AddBackward0>)



  1%|          | 1/100 [00:03<04:59,  3.02s/it][A

tensor(-3.6107e-07, grad_fn=<AddBackward0>)



  2%|▏         | 2/100 [00:05<04:49,  2.96s/it][A

tensor(-5.9302e-07, grad_fn=<AddBackward0>)



  3%|▎         | 3/100 [00:08<04:40,  2.89s/it][A

tensor(-1.0415e-06, grad_fn=<AddBackward0>)



  4%|▍         | 4/100 [00:11<04:41,  2.94s/it][A

tensor(-2.9863e-07, grad_fn=<AddBackward0>)



  5%|▌         | 5/100 [00:14<04:33,  2.88s/it][A

tensor(-5.9556e-07, grad_fn=<AddBackward0>)



  6%|▌         | 6/100 [00:17<04:26,  2.83s/it][A

tensor(-9.5648e-07, grad_fn=<AddBackward0>)



  7%|▋         | 7/100 [00:20<04:31,  2.92s/it][A

tensor(-6.3801e-07, grad_fn=<AddBackward0>)



  8%|▊         | 8/100 [00:23<04:40,  3.05s/it][A

tensor(-1.1442e-06, grad_fn=<AddBackward0>)



  9%|▉         | 9/100 [00:26<04:41,  3.09s/it][A

tensor(-7.2000e-07, grad_fn=<AddBackward0>)



 10%|█         | 10/100 [00:29<04:39,  3.10s/it][A

tensor(-1.3022e-06, grad_fn=<AddBackward0>)



 11%|█         | 11/100 [00:32<04:35,  3.10s/it][A

tensor(-6.7683e-07, grad_fn=<AddBackward0>)



 12%|█▏        | 12/100 [00:36<04:31,  3.08s/it][A

tensor(-5.4090e-07, grad_fn=<AddBackward0>)



 13%|█▎        | 13/100 [00:39<04:27,  3.08s/it][A

tensor(-1.1459e-06, grad_fn=<AddBackward0>)



 14%|█▍        | 14/100 [00:42<04:26,  3.09s/it][A

tensor(-1.4401e-06, grad_fn=<AddBackward0>)



 15%|█▌        | 15/100 [00:45<04:21,  3.08s/it][A

tensor(-1.0991e-06, grad_fn=<AddBackward0>)



 16%|█▌        | 16/100 [00:48<04:17,  3.06s/it][A

tensor(-1.3447e-06, grad_fn=<AddBackward0>)



 17%|█▋        | 17/100 [00:51<04:14,  3.06s/it][A

tensor(-9.8020e-07, grad_fn=<AddBackward0>)



 18%|█▊        | 18/100 [00:54<04:13,  3.09s/it][A

tensor(-1.0819e-06, grad_fn=<AddBackward0>)



 19%|█▉        | 19/100 [00:57<04:10,  3.09s/it][A

tensor(-1.2726e-06, grad_fn=<AddBackward0>)


In [None]:
plt.plot(ls)