In [1]:
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F

In [2]:
import sys
sys.path.append('../wave_wizard/')

In [3]:
from src.dataset import get_loader

In [4]:
SR = 22050
SECS = 11
length=SR*SECS
sample_rate=22050

In [53]:
config = dict({
    'dataset': {
        'json_dir': '../dataset/',
        'length': SR*SECS,
        'sample_rate': SR,
        'num_samples': 10_000,

    },
    'dataloader': {
        'batch_size': 64,
    }
})
num_samples = config['dataset'].pop('num_samples')
num_samples = config['dataset'].pop('a', None)
loader = get_loader(config)

In [46]:
len(loader)

449182

# Basic Conv

In [6]:
class BasicConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 activation=nn.PReLU()):
        super(BasicConv, self).__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride, padding=padding,
            bias=bias
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.activation = activation
        

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x

In [7]:
class BasicDeConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 activation=nn.PReLU()):
        super(BasicDeConv, self).__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride, padding=padding,
            bias=bias
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.activation = activation
        

    def forward(self, x):
        x = self.deconv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x

# Gated Conv

In [16]:
class GatedConv(nn.Module):

    def __init__(self,
                in_channels, out_channels,
                kernel_size=3, stride=1, padding=1,
                dilation=1, groups=1, bias=True,
                batch_norm=True,
                activation=nn.LeakyReLU(0.2, inplace=True)):
        super(GatedConv, self).__init__()
        self.batch_norm = batch_norm
        self.activation = activation
        self.conv = nn.Conv1d(
                in_channels, out_channels,
                kernel_size=kernel_size,
                stride=stride, padding=padding,
            
                bias=bias
            )
        self.mask_conv = nn.Conv1d(
                in_channels, out_channels,
                kernel_size=kernel_size,
                stride=stride, padding=padding,
                bias=bias
            )
        self.batch_norm = nn.BatchNorm1d(out_channels)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight)
    
    def gated(self, mask):
        return self.sigmoid(mask)

    def forward(self, input):
        x = self.conv(input)
        mask = self.mask_conv(input)
        x = x * self.gated(mask)
        x = self.batch_norm(x)
        return x

In [32]:
class GatedDeConv(nn.Module):

    def __init__(self,
                in_channels, out_channels,
                kernel_size=3, stride=1, padding=1,
                dilation=1, groups=1, bias=True,
                batch_norm=True,
                activation=nn.LeakyReLU(0.2, inplace=True)):
        super(GatedDeConv, self).__init__()
        self.batch_norm = batch_norm
        self.activation = activation
        self.deconv = nn.ConvTranspose1d(
                in_channels, out_channels,
                kernel_size=kernel_size,
                stride=stride, padding=padding,
                bias=bias
            )
        self.mask_deconv = nn.ConvTranspose1d(
                in_channels, out_channels,
                kernel_size=kernel_size,
                stride=stride, padding=padding,
                bias=bias
            )
        self.batch_norm = nn.BatchNorm1d(out_channels)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose1d):
                nn.init.kaiming_normal_(m.weight)
    
    def gated(self, mask):
        return self.sigmoid(mask)

    def forward(self, input):
        x = self.deconv(input)
        mask = self.mask_deconv(input)
        x = x * self.gated(mask)
        x = self.batch_norm(x)
        return x

In [33]:
for noisy, clean in loader:
    break

In [34]:
import math

class GateWave(nn.Module):
    def __init__(self,
                 depth=3, scale=2, init_hidden=32,
                 kernel_size=7, stride=1, padding=2,
                 encoder_class=BasicConv,
                 decoder_class=BasicDeConv):
        super(GateWave, self).__init__()
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        in_channels = 1
        out_channels = 1
        encoders = []
        decoders = []
        
        hidden = init_hidden
        in_ch = in_channels
        for i in range(depth):
            
            encoder = encoder_class(in_channels, hidden, kernel_size, stride, padding)
            encoders.append(encoder)
            
            decoder = decoder_class(hidden, out_channels, kernel_size, stride, padding)
            decoders.append(decoder)
            out_channels = hidden
            in_channels = hidden
            hidden = int(hidden * scale)
            
            
        self.encoder = nn.Sequential(*encoders)
        self.decoder = nn.Sequential(*decoders[::-1])
    
    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.
        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
#         length = math.ceil(length * self.resample)
        for idx in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)
        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
#         length = int(math.ceil(length / self.resample))
        return int(length)
    
    
    def forward(self, x):
        length = x.shape[-1]
        x = F.pad(x, (0, self.valid_length(length) - length))
        latent = self.encoder(x)
        output = self.decoder(latent)
        return output

In [35]:
model = GateWave(encoder_class=GatedConv, decoder_class=GatedDeConv)

In [37]:
out = model(noisy)

In [38]:
loss = nn.L1Loss()

In [40]:
for noisy, clean in loader:
    pred = model(noisy)
    l = loss(clean, pred)
    break

In [41]:
l

tensor(0.2287, grad_fn=<MeanBackward0>)

In [None]:
import py

In [18]:
loss = STFTLoss()

In [22]:
loss(clean, out)

RuntimeError: stft(torch.FloatTensor[32, 1, 243574], n_fft=1024, hop_length=120, win_length=600, window=torch.FloatTensor{[600]}, normalized=0, onesided=None, return_complex=1) : expected a 1D or 2D tensor

In [20]:
depth=3
scale=2
hidden=32
kernel_size=7
stride=1
padding=2
encoder_class=BasicConv
decoder_class=BasicDeConv

In [21]:
# in_channels = 1
# out_channels = 1
# encoders = []
# decoders = []

# in_ch = in_channels
# for i in range(depth):

#     encoder = encoder_class(in_channels, hidden, kernel_size, stride, padding)
#     encoders.append(encoder)

#     decoder = decoder_class(hidden, out_channels, kernel_size, stride, padding)
#     decoders.append(decoder)
#     out_channels = hidden
#     in_channels = hidden
#     hidden *= scale