In [55]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.datasets import make_classification

In [5]:
NUM_CLASSES = 10
# Synthetic data; each row of X is later treated as a 16000-sample "waveform"
# and each y entry as its speaker label (see the mixing cell).
X, y = make_classification(
    n_samples=100,
    n_features=16000,
    n_informative=1000,
    n_classes=NUM_CLASSES,
    random_state=914,
)
X, y = X.astype(np.float32), y.astype(np.int64)

In [6]:
# Class labels (cast to int64 above), one per row of X
y

array([1, 3, 1, 3, 5, 8, 1, 2, 0, 0, 9, 5, 8, 2, 9, 6, 3, 4, 3, 9, 9, 3,
       4, 0, 2, 2, 0, 6, 6, 1, 9, 3, 1, 6, 4, 7, 0, 8, 7, 1, 2, 6, 8, 0,
       1, 0, 4, 7, 6, 7, 4, 5, 3, 7, 8, 7, 3, 7, 8, 5, 4, 7, 2, 5, 9, 4,
       4, 8, 1, 2, 7, 8, 5, 4, 5, 9, 8, 2, 6, 3, 3, 5, 6, 0, 1, 9, 6, 8,
       0, 7, 0, 5, 9, 4, 2, 5, 9, 1, 6, 2], dtype=int64)

In [7]:
# Feature matrix (cast to float32 above), shape (100, 16000)
X

array([[ -0.24934824,  -1.2389227 ,  -0.12173659, ...,  14.727759  ,
         -0.7366306 ,  -1.351909  ],
       [  1.0690868 ,  -1.1818848 ,  -0.6803679 , ..., -16.5411    ,
          0.24864252,   0.07018664],
       [ -0.39488146,  -0.43264174,   0.5022811 , ...,   9.970029  ,
         -0.19629164,  -0.8747166 ],
       ...,
       [ -1.7917298 ,   1.3505819 ,   0.046682  , ...,  16.731117  ,
         -1.1025734 ,   1.4075642 ],
       [ -0.8673802 ,   0.49094677,   1.4046149 , ...,   7.9407005 ,
          0.27935445,  -0.5779842 ],
       [ -0.44752482,  -0.56108904,  -1.0531406 , ..., -14.808728  ,
          0.23746619,  -1.4545943 ]], dtype=float32)

In [28]:
def joint_different_speakers(waveforms, speakers, num_mix, rng=None):
    """
    Yield (target, interferer) pairs whose speaker labels differ.

    For every waveform, ``num_mix`` interferers are drawn without replacement
    from the waveforms whose speaker label differs from the current one.

    Parameters
    ----------
    waveforms: sequence of array-like
        Indexable collection of signals; ``waveforms[j]`` is one waveform.
    speakers: sequence
        Speaker label for each waveform (same length as ``waveforms``).
    num_mix: int
        Number of interferers drawn per target waveform.
    rng: numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state
        (previous behavior); pass a seeded generator for reproducible pairs.

    Yields
    ------
    tuple
        ``(trg_wav, itf_wav, trg_spk, itf_spk)``.

    Raises
    ------
    ValueError
        From ``choice(..., replace=False)`` when fewer than ``num_mix``
        waveforms belong to a different speaker.
    """
    sampler = np.random if rng is None else rng
    for i, trg_wav in enumerate(waveforms):
        trg_spk = speakers[i]
        # Indices of every waveform spoken by someone else.
        candidates = [k for k, spk in enumerate(speakers) if spk != trg_spk]
        for j in sampler.choice(candidates, num_mix, replace=False):
            yield trg_wav, waveforms[j], trg_spk, speakers[j]

def mix_speakers_by_snr(waveforms, speakers, num_mix, snr):
    """
    Mix every waveform with ``num_mix`` interferers from other speakers at a fixed SNR.

    The interferer is rescaled so that
    ``10 * log10(P_target / P_interferer_scaled) == snr``, where ``P`` is the
    mean squared amplitude of the signal.

    Parameters
    ----------
    waveforms: sequence of array-like
        Target signals; interferers are drawn from the same collection.
    speakers: sequence
        Speaker label for each waveform.
    num_mix: int
        Number of mixtures produced per target waveform.
    snr: float
        Target-to-interferer ratio in dB.

    Returns
    -------
    list
        One ``[mix_wav, itf_spk, trg_spk]`` entry per mixture
        (note: interferer label comes before the target label).
    """
    pairs = joint_different_speakers(waveforms, speakers, num_mix)
    power_ratio = 10 ** (snr / 10)
    mixed_data = []
    for trg_wav, itf_wav, trg_spk, itf_spk in tqdm(pairs, total=len(speakers) * num_mix):
        itf_power = np.mean(np.square(itf_wav))
        if itf_power > 0:
            # Power the interferer must have to reach the requested SNR.
            wanted_itf_power = np.mean(np.square(trg_wav)) / power_ratio
            scale = np.sqrt(wanted_itf_power / itf_power)
        else:
            # Silent interferer: there is nothing to scale, and 0/0 would
            # turn the whole mixture into NaN.
            scale = 0.0

        mix_wav = trg_wav + scale * itf_wav
        mixed_data.append([mix_wav, itf_spk, trg_spk])
    return mixed_data

In [48]:
# Interferer selection (joint_different_speakers) draws from the global NumPy
# RNG; seed it here so re-running this cell rebuilds the same mixtures.
MIX_SEED = 914
np.random.seed(MIX_SEED)

mixed_data = mix_speakers_by_snr(X, y, num_mix=2, snr=5)
mix_wavs = np.stack([data[0] for data in mixed_data])
itf_spks = np.stack([data[1] for data in mixed_data])
trg_spks = np.stack([data[2] for data in mixed_data])
X_mix_wavs = torch.tensor(mix_wavs, dtype=torch.float)

# Multi-hot targets: both the target and the interfering speaker are marked.
# Explicit int64 because plain `int` maps to int32 on some platforms (Windows).
rows = np.arange(len(mix_wavs))
onehot_np = np.zeros((len(mix_wavs), NUM_CLASSES), dtype=np.int64)
onehot_np[rows, itf_spks] = 1
onehot_np[rows, trg_spks] = 1
y_onehot = torch.from_numpy(onehot_np)

  0%|          | 0/200 [00:00<?, ?it/s]

In [49]:
class SVD(nn.Module):
    """
    Singular value decomposition layer (thin SVD computed from the
    eigendecomposition of A^T A).

    Singular values are returned in descending order, like ``torch.linalg.svd``.

    Examples
    --------
    >>> A = torch.rand(126, 100, 20)
    >>> U, S, V = SVD()(A)
    >>> A_ = torch.matmul(U, torch.matmul(S, V.transpose(-1, -2)))
    >>> print(torch.dist(A_, A))
    """
    def __init__(self):
        super().__init__()

    def forward(self, A):
        """
        Inputs
        ------
        A: [..., m, n]

        Outputs
        -------
        U: [..., m, n]
        S: [..., n, n]  diagonal matrix of singular values (descending)
        V: [..., n, n]
        """
        return self.svd_(A)

    @staticmethod
    def svd_(A):
        """
        Parameters
        ----------
        A: torch.FloatTensor
            A tensor of shape [..., m, n] (a batch of matrices or one matrix).

        Returns
        -------
        U: [..., m, n]
            Columns for numerically zero singular values are set to 0, so
            ``U @ S @ V^T`` still reconstructs ``A``.
        S: [..., n, n]
        V: [..., n, n]
            Orthonormal columns (right singular vectors).

        References
        ----------
        1. https://www.youtube.com/watch?v=pSbafxDHdgE&t=205s
        2. https://www2.math.ethz.ch/education/bachelor/lectures/hs2014/other/linalg_INFK/svdneu.pdf
        """
        # A^T A is symmetric positive semi-definite: eigh returns real
        # eigenvalues and orthonormal eigenvectors (eig would use a complex,
        # non-symmetric solver and return values in no particular order).
        ATA = torch.matmul(A.transpose(-1, -2), A)
        eigvals, eigvecs = torch.linalg.eigh(ATA)

        # eigh sorts ascending; flip to the conventional descending order.
        eigvals = eigvals.flip(-1)
        V = eigvecs.flip(-1)

        # Round-off can make tiny eigenvalues slightly negative -> sqrt NaN.
        s = torch.sqrt(eigvals.clamp_min(0))
        S = torch.diag_embed(s)

        # U = A V S^-1, inverting only singular values above a relative
        # tolerance; a (near-)zero singular value would give inf/NaN.
        m, n = A.shape[-2], A.shape[-1]
        tol = s[..., :1] * max(m, n) * torch.finfo(s.dtype).eps
        keep = s > tol
        inv_s = keep.to(s.dtype) / torch.where(keep, s, torch.ones_like(s))
        U = torch.matmul(A, V) * inv_s.unsqueeze(-2)
        return U, S, V

In [57]:
class SincConv(nn.Module):
    """This function implements SincConv (SincNet).
    M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with
    SincNet", in Proc. of  SLT 2018 (https://arxiv.org/abs/1808.00158)
    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size: int
        Kernel size of the convolutional filters. Must be odd.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
        Only 1 is supported (the filters have a single input channel).
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        a decimation in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is the same as the input shape.
        "causal" results in causal (dilated) convolutions.
    padding_mode : str
        Padding mode passed to ``F.pad`` when ``padding="same"``. See the
        torch.nn.functional.pad documentation for accepted values.
    sample_rate : int
        Sampling rate of the input signals.
    min_low_hz : float
        Lowest possible frequency (in Hz) for a filter.
    min_band_hz : float
        Lowest possible value (in Hz) for a filter bandwidth.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16000])
    >>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11)
    >>> out_tensor = conv(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 16000, 25])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        padding_mode="reflect",
        sample_rate=16000,
        min_low_hz=50,
        min_band_hz=50,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # input shape inference
        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        # The filters are built as (out_channels, 1, kernel_size), so any other
        # channel count would only fail later inside F.conv1d.
        if in_channels != 1:
            raise ValueError(
                "SincConv only supports a single input channel. Got %s."
                % (in_channels)
            )

        # Validate here as well: _check_input_shape is skipped when
        # in_channels is given, and an even kernel otherwise fails later
        # with an opaque reshape error in _get_sinc_filters.
        self._check_kernel_size()

        # Initialize Sinc filters
        self._init_sinc_conv()

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input to convolve, (batch, time) or (batch, time, 1).

        Returns
        -------
        torch.Tensor
            (batch, time_out, out_channels). For 2d input with
            out_channels == 1 the channel axis is squeezed away.
        """
        x = x.transpose(1, -1)
        # Kept for backward compatibility; the filters themselves follow the
        # module's device/dtype through registered buffers.
        self.device = x.device

        unsqueeze = x.ndim == 2
        if unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )

        elif self.padding == "causal":
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))

        elif self.padding == "valid":
            pass

        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got %s."
                % (self.padding)
            )

        sinc_filters = self._get_sinc_filters()

        wx = F.conv1d(
            x,
            sinc_filters,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
        )

        if unsqueeze:
            wx = wx.squeeze(1)

        wx = wx.transpose(1, -1)

        return wx

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""
        if len(shape) not in (2, 3):
            raise ValueError(
                "sincconv expects 2d or 3d inputs. Got " + str(len(shape))
            )

        self._check_kernel_size()
        return 1

    def _check_kernel_size(self):
        """Raises ValueError unless kernel_size is odd.

        The filter is assembled as left half + centre tap + mirrored right
        half, which only has kernel_size taps when kernel_size is odd.
        """
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )

    def _get_sinc_filters(self):
        """Creates the sinc band-pass filters, shape (out_channels, 1, kernel_size)."""
        # Computing the low frequencies of the filters
        low = self.min_low_hz + torch.abs(self.low_hz_)

        # Setting minimum band and minimum freq
        high = torch.clamp(
            low + self.min_band_hz + torch.abs(self.band_hz_),
            self.min_low_hz,
            self.sample_rate / 2,
        )
        band = (high - low)[:, 0]

        # Passing from n_ to the corresponding f_times_t domain.
        # n_ and window_ are buffers, so they already live on the module's
        # device and dtype.
        f_times_t_low = torch.matmul(low, self.n_)
        f_times_t_high = torch.matmul(high, self.n_)

        # Left part of the filters.
        band_pass_left = (
            (torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
            / (self.n_ / 2)
        ) * self.window_

        # Central element of the filter
        band_pass_center = 2 * band.view(-1, 1)

        # Right part of the filter (sinc filters are symmetric)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        # Combining left, central, and right part of the filter
        band_pass = torch.cat(
            [band_pass_left, band_pass_center, band_pass_right], dim=1
        )

        # Amplitude normalization
        band_pass = band_pass / (2 * band[:, None])

        # Setting up the filter coefficients
        filters = band_pass.view(self.out_channels, 1, self.kernel_size)

        return filters

    def _init_sinc_conv(self):
        """Initializes the parameters and buffers of the sinc_conv layer."""

        # Initialize filterbanks such that they are equally spaced in Mel scale
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)

        mel = torch.linspace(
            self._to_mel(self.min_low_hz),
            self._to_mel(high_hz),
            self.out_channels + 1,
        )

        hz = self._to_hz(mel)

        # Learnable lower frequencies and bandwidths of the filters
        self.low_hz_ = nn.Parameter(hz[:-1].unsqueeze(1))
        self.band_hz_ = nn.Parameter((hz[1:] - hz[:-1]).unsqueeze(1))

        # Hamming window (left half only)
        n_lin = torch.linspace(
            0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
        )
        window = 0.54 - 0.46 * torch.cos(
            2 * math.pi * n_lin / self.kernel_size
        )

        # Time axis  (only half is needed due to symmetry)
        n = (self.kernel_size - 1) / 2.0
        n_ = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate

        # Non-persistent buffers follow .to()/.cuda()/.double() without
        # changing the state_dict (still only low_hz_ and band_hz_).
        self.register_buffer("window_", window, persistent=False)
        self.register_buffer("n_", n_, persistent=False)

    def _to_mel(self, hz):
        """Converts frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    def _to_hz(self, mel):
        """Converts frequency in the mel scale to Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        """Pads the time axis (using ``self.padding_mode``) so the length is
        unchanged after a stride-1 convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.
        """

        # Detecting input shape
        L_in = x.shape[-1]

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Computes the symmetric [left, right] padding for a "same"-style convolution.

    Arguments
    ---------
    L_in : int
        Input length. Kept for interface compatibility; the padding depends
        only on the kernel geometry.
    stride: int
        Convolution stride, must be >= 1.
    kernel_size : int
        Kernel size.
    dilation : int
        Kernel dilation.

    Returns
    -------
    list of int
        ``[left, right]`` padding (both sides equal).

    Raises
    ------
    ValueError
        If ``stride < 1``.
    """
    if stride < 1:
        raise ValueError("stride must be >= 1. Got %s." % (stride,))

    if stride > 1:
        # Half the (dilated) kernel on each side. For dilation == 1 this is the
        # previous kernel_size // 2; dilation used to be ignored here, which
        # under-padded dilated strided convolutions.
        pad = dilation * (kernel_size // 2)
    else:
        # A valid stride-1 convolution drops dilation * (kernel_size - 1)
        # samples; give back half on each side.
        pad = dilation * (kernel_size - 1) // 2
    return [pad, pad]

In [61]:
class DemixingNet(nn.Module):
    """Network applied to the mixed waveforms; currently a single SincConv filterbank.

    Parameters
    ----------
    out_channels : int, default 20
        Number of sinc band-pass filters.
    kernel_size : int, default 11
        Filter length (must be odd, as required by SincConv).
    sample_rate : int, default 16000
        Sampling rate forwarded to SincConv.
    """

    def __init__(self, out_channels=20, kernel_size=11, sample_rate=16000):
        super().__init__()
        self.conv = SincConv(
            in_channels=1,
            out_channels=out_channels,
            kernel_size=kernel_size,
            sample_rate=sample_rate,
        )

    def forward(self, wavs):
        """Filters ``wavs`` with the sinc filterbank.

        ``wavs`` is expected to be (batch, time), as in the cell that calls this
        model; the result is (batch, time, out_channels) with SincConv's
        default "same" padding.
        """
        return self.conv(wavs)

In [62]:
# Inspection-only forward pass on a fresh, untrained model: no_grad avoids
# keeping the autograd graph for the large (mixtures x time x filters) output.
with torch.no_grad():
    output = DemixingNet()(X_mix_wavs)

In [63]:
# SincConv output layout: (num_mixtures, time, out_channels)
output.shape

torch.Size([200, 16000, 20])