In [1]:
import time
import pickle
import numpy as np

from tqdm import tqdm, tqdm_notebook
from PIL import Image
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchaudio
import copy

from matplotlib import colors, pyplot as plt
%matplotlib inline

# в sklearn не все гладко, чтобы в colab удобно выводить картинки 
# мы будем игнорировать warnings
import warnings
warnings.filterwarnings(action='ignore', category=DeprecationWarning)


In [30]:
SAMPLE_RATE = 48000
N_FFT = SAMPLE_RATE * 64 // 1000
HOP_LENGTH = SAMPLE_RATE * 16 // 1000

In [31]:
# разные режимы датасета 
DATA_MODES = ['train', 'test']
# работаем на видеокарте
DEVICE = torch.device("cuda")

In [32]:
class SpeechDataset(Dataset):
    """
    Датасет с аудио, который их обрезает/паддит то заданной длины, применяет оконное преобразование Фурье,
    нормализует и приводит к тензору.
    """
    def __init__(self, noisy_files, clean_files, n_fft=64, hop_length=16):
        super().__init__()
        # список файлов для загрузки
        self.noisy_files = sorted(noisy_files)
        self.clean_files = sorted(clean_files)
        self.labels = [path.parent.name for path in self.noisy_files]
        
        # параметры stft
        self.n_fft = n_fft
        self.hop_length = hop_length
        
        # размер датасета
        self.len_ = len(self.noisy_files)
        
        # будем обрезать/паддить waveform аудиофайлов до этого размера
        self.max_len = 164981

    def __len__(self):
        return self.len_
      
    def load_sample(self, file):
        waveform, _ = torchaudio.load(file)
        return waveform
  
    def __getitem__(self, index):
        # для преобразования изображений в тензоры PyTorch и нормализации входа
        x_clean = self.load_sample(self.clean_files[index])
        x_noisy = self.load_sample(self.noisy_files[index])
        
        # padding/cutting
        x_clean = self._prepare_sample(x_clean)
        x_noisy = self._prepare_sample(x_noisy)
        
        # оконное преобразование Фурье
        x_noisy_stft = torch.stft(input=x_noisy, n_fft=self.n_fft, 
                                  hop_length=self.hop_length, normalized=True)
        x_clean_stft = torch.stft(input=x_clean, n_fft=self.n_fft, 
                                  hop_length=self.hop_length, normalized=True)
        
        return x_noisy_stft, x_clean_stft
        
    def _prepare_sample(self, waveform):
        waveform = waveform.numpy()
        current_len = waveform.shape[1]
        
        output = np.zeros((1, self.max_len), dtype='float32')
        output[0, -current_len:] = waveform[0, :self.max_len]
        output = torch.from_numpy(output)
        
        return output

In [33]:
TEST_NOISY_DIR = Path('/home/philipp/Projects/DCUnet/data/test/noisy_testset')
TEST_CLEAN_DIR = Path('/home/philipp/Projects/DCUnet/data/test/clean_testset')

In [34]:
test_noisy_files = sorted(list(TEST_NOISY_DIR.rglob('*.wav')))[:5]
test_clean_files = sorted(list(TEST_CLEAN_DIR.rglob('*.wav')))[:5]

In [35]:
test_dataset = SpeechDataset(test_noisy_files, test_clean_files, N_FFT, HOP_LENGTH)

In [36]:
test_loader = DataLoader(test_dataset, batch_size=2, num_workers=4, shuffle=True)

In [37]:
a = None
for x_noisy, x_clean in test_loader:
    a = x_noisy
    break

In [38]:
# a.shape
a_r = a[..., 0]
a_im = a[..., 1]
a_r.shape

torch.Size([2, 1, 1537, 215])

In [39]:
torch.stack([a_r, a_im], dim=-1).shape

torch.Size([2, 1, 1537, 215, 2])

In [40]:
class CConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        
        self.real_conv = nn.Conv2d(in_channels=self.in_channels, 
                                   out_channels=self.out_channels, 
                                   kernel_size=self.kernel_size, 
                                   padding=self.padding, 
                                   stride=self.stride)
        
        self.im_conv = nn.Conv2d(in_channels=self.in_channels, 
                                 out_channels=self.out_channels, 
                                 kernel_size=self.kernel_size, 
                                 padding=self.padding, 
                                 stride=self.stride)
        
        # Glorot initialization.
        nn.init.xavier_uniform_(self.real_conv.weight)
        nn.init.xavier_uniform_(self.im_conv.weight)
        
        
    def forward(self, x):
        x_real = x[..., 0]
        x_im = x[..., 1]
        
        c_real = self.real_conv(x_real) - self.im_conv(x_im)
        c_im = self.im_conv(x_real) + self.real_conv(x_im)
        
        output = torch.stack([c_real, c_im], dim=-1)
        return output

In [41]:
class CConvTranspose2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        
        self.real_convt = nn.ConvTranspose2d(in_channels=self.in_channels, 
                                            out_channels=self.out_channels, 
                                            kernel_size=self.kernel_size, 
                                            padding=self.padding, 
                                            stride=self.stride)
        
        self.im_convt = nn.ConvTranspose2d(in_channels=self.in_channels, 
                                            out_channels=self.out_channels, 
                                            kernel_size=self.kernel_size, 
                                            padding=self.padding, 
                                            stride=self.stride)
        
        
        # Glorot initialization.
        nn.init.xavier_uniform_(self.real_convt.weight)
        nn.init.xavier_uniform_(self.im_convt.weight)
        
        
    def forward(self, x):
        x_real = x[..., 0]
        x_im = x[..., 1]
        
        ct_real = self.real_convt(x_real) - self.im_convt(x_im)
        ct_im = self.im_convt(x_real) + self.real_convt(x_im)
        
        output = torch.stack([ct_real, ct_im], dim=-1)
        return output

In [42]:
class CBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()
        
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        
        self.real_b = nn.BatchNorm2d(num_features=self.num_features, eps=self.eps, momentum=self.momentum,
                                      affine=self.affine, track_running_stats=self.track_running_stats)
        self.im_b = nn.BatchNorm2d(num_features=self.num_features, eps=self.eps, momentum=self.momentum,
                                    affine=self.affine, track_running_stats=self.track_running_stats) 
        
    def forward(self, x):
        x_real = x[..., 0]
        x_im = x[..., 1]
        
        n_real = self.real_b(x_real)
        n_im = self.im_b(x_im)  
        
        output = torch.stack([n_real, n_im], dim=-1)
        return output

In [43]:
class Encoder(nn.Module):
    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45, padding=0):
        super().__init__()
        
        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.cconv = CConv2d(in_channels=self.in_channels, out_channels=self.out_channels, 
                             kernel_size=self.filter_size, stride=self.stride_size)
        
        self.cbn = CBatchNorm2d(num_features=self.out_channels) 
        
        self.leaky_relu = nn.LeakyReLU()
            
    def forward(self, x):
        
        conved = self.cconv(x)
        normed = self.cbn(conved)
        acted = self.leaky_relu(normed)
        
        return acted

In [44]:
class Decoder(nn.Module):
    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, 
                 out_channels=45, last_layer=False, padding=(1,1)):
        super().__init__()
        
        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.last_layer = last_layer
        
        self.cconvt = CConvTranspose2d(in_channels=self.in_channels, out_channels=self.out_channels, 
                             kernel_size=self.filter_size, stride=self.stride_size, padding=padding)
        
        self.cbn = CBatchNorm2d(num_features=self.out_channels) 
        
        self.leaky_relu = nn.LeakyReLU()
            
    def forward(self, x):
        
        conved = self.cconvt(x)
        normed = self.cbn(conved)
        if not self.last_layer:
            output = self.leaky_relu(normed)
        else:
            m_phase = normed / torch.abs(normed)
            m_mag = nn.Tanh(torch.abs(normed))
            output = m_phase * m_mag
            
        return output

In [45]:
class DCUnet10(nn.Module):
    def __init__(self, n_fft=64, hop_length=16):
        super().__init__()
        
        # for istft
        self.n_fft = n_fft
        self.hop_length = hop_length
        
        # downsampling/encoding
        self.downsample0 = Encoder(filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45)
        self.downsample1 = Encoder(filter_size=(7,5), stride_size=(2,2), in_channels=45, out_channels=90)
        self.downsample2 = Encoder(filter_size=(5,3), stride_size=(2,2), in_channels=90, out_channels=90)
        self.downsample3 = Encoder(filter_size=(5,3), stride_size=(2,2), in_channels=90, out_channels=90)
        self.downsample4 = Encoder(filter_size=(5,3), stride_size=(2,1), in_channels=90, out_channels=90)
        
        # upsampling/decoding
        self.upsample0 = Decoder(filter_size=(5,3), stride_size=(2,1), in_channels=90, out_channels=90,)
        self.upsample1 = Decoder(filter_size=(5,3), stride_size=(2,2), in_channels=180, out_channels=90)
        self.upsample2 = Decoder(filter_size=(5,3), stride_size=(2,2), in_channels=180, out_channels=90)
        self.upsample3 = Decoder(filter_size=(7,3), stride_size=(2,2), in_channels=180, out_channels=45)
        self.upsample4 = Decoder(filter_size=(7,3), stride_size=(2,2), in_channels=90, 
                                 out_channels=1, last_layer=True)
        
        
    def forward(self, x):
        # downsampling/encoding
        print(x.shape)
        d0 = self.downsample0(x)
        print(d0.shape)
        d1 = self.downsample1(d0) 
        print(d1.shape)
        d2 = self.downsample2(d1)        
        print(d2.shape)
        d3 = self.downsample3(d2)        
        print(d3.shape)
        d4 = self.downsample4(d3)
        print(d4.shape)
        # upsampling/decoding
        u0 = self.upsample0(d4)
        print(u0.shape)
        c0 = torch.cat((u0, d3), dim=1)
        u1 = self.upsample1(c0)
        
        c1 = torch.cat((u1, d2), dim=1)
        u2 = self.upsample2(c1)
        
        c2 = torch.cat((u2, d1), dim=1)
        u3 = self.upsample3(c2)
        
        c3 = torch.cat((u3, d0), dim=1)
        u4 = self.upsample4(c3)
        
        # u4 - the mask
        output = u4 * x
        istft = torch.istft(input=output, n_fft=self.n_fft, hop_length=self.hop_length, normalized=True)
        
        return istft

In [46]:
dcunet10 = DCUnet10(N_FFT, HOP_LENGTH)

In [47]:
a_n, a_c = None, None
for x_noisy, x_clean in test_loader:
    a_n = x_noisy
    a_c = x_clean
    break

In [48]:
dcunet10(a_n)

torch.Size([2, 1, 1537, 215, 2])
torch.Size([2, 45, 766, 106, 2])
torch.Size([2, 90, 380, 51, 2])
torch.Size([2, 90, 188, 25, 2])
torch.Size([2, 90, 92, 12, 2])
torch.Size([2, 90, 44, 10, 2])
torch.Size([2, 90, 89, 10, 2])


RuntimeError: Sizes of tensors must match except in dimension 1. Got 89 and 92 in dimension 2