# keras-sincnet

Speaker Recognition from Raw Waveform with SincNet  
*Mirco Ravanelli, Yoshua Bengio*  
https://arxiv.org/pdf/1808.00158.pdf

## To-Do

- [ ] Reproduce the paper's results and compare them with a comparable 1D CNN
- [ ] Add skip connections and Q-RNN (PASE+)
- [ ] Add theory behind SincNet, MFCC and Q-RNN


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K

import numpy as np

tf.__version__

'2.4.1'

## Slow implementation

In [None]:
class SincConv(Layer):
    '''
    Sinc-based convolution Keras layer (reference version, one filter at a time)

    Each output channel is a band-pass FIR filter built as the difference of
    two windowed sinc low-pass filters. Only the lower cutoff (`b1`) and the
    bandwidth (`band`) of every filter are trainable; both are stored
    normalised by `sample_freq`.

    Parameters
    ----------
    nb_filters : `int`
        Number of filters (= number of output channels).
    kernel_size : `int`
        Convolution filter width/length (will be increased by one if even,
        because the filter is assembled as left half + centre tap + right
        half).
    sample_freq : `int`
        Sample rate of input audio.
    **kwargs
        Forwarded to `tensorflow.keras.layers.Layer` (e.g. `name`).

    Notes
    -----
    The time axis and the Hamming window are fixed constants, so
    `get_weights()` returns exactly `[b1, band]`.

    Reference
    ---------
    Mirco Ravanelli, Yoshua Bengio,
    "Speaker Recognition from raw waveform with SincNet".
    https://arxiv.org/abs/1808.00158
    '''

    @staticmethod
    def sinc(band, t_right):
        '''Return the symmetric sinc of frequency `band` sampled on
        `[-t_right reversed, 0, t_right]`.'''
        arg = 2 * np.pi * band * t_right
        y_right = K.sin(arg) / arg
        y_left = K.reverse(y_right, 0)
        # sinc(0) == 1. Use a plain tensor: creating a variable on every
        # forward pass is wasteful and is rejected inside `tf.function`.
        center = tf.ones((1,), dtype=y_right.dtype)
        return K.concatenate([y_left, center, y_right])

    @staticmethod
    def hz_to_mel(hz):
        '''Convert a frequency in Hz to the mel scale.'''
        return 2595.0 * np.log10(1.0 + hz / 700.0)

    @staticmethod
    def mel_to_hz(mels):
        '''Convert a mel-scale value back to Hz.'''
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

    def __init__(self, nb_filters, kernel_size, sample_freq, **kwargs):
        super(SincConv, self).__init__(**kwargs)

        self.nb_filters = nb_filters
        self.kernel_size = kernel_size
        self.sample_freq = sample_freq

        # The filter has 2 * half + 1 taps, so an even size cannot be built
        # and would not match the window length.
        if self.kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        # Lower bounds (in Hz) added to the learned cutoff and bandwidth
        self.min_freq = 50.0
        self.min_band = 50.0

        # Set trainable parameters
        self.b1 = self.add_weight(
            name='b1',
            shape=(self.nb_filters,),
            initializer="zeros",
            trainable=True)
        self.band = self.add_weight(
            name='band',
            shape=(self.nb_filters,),
            initializer="zeros",
            trainable=True)

        # Initialize weights with cutoff frequencies of the mel-scale filter-bank
        low_freq_mel = self.hz_to_mel(50)
        high_freq_mel = self.hz_to_mel(self.sample_freq / 2)
        mel_points = np.linspace(low_freq_mel, high_freq_mel, num=self.nb_filters)
        hz_points = self.mel_to_hz(mel_points)

        b1 = np.roll(hz_points, 1)
        b1[0] = 30
        b2 = np.roll(hz_points, -1)
        b2[-1] = (self.sample_freq / 2) - 100

        self.set_weights([b1 / self.sample_freq, (b2 - b1) / self.sample_freq])

        # Right half of the time axis (in seconds), excluding t = 0.
        # Stored as a constant so that it is not picked up as a trainable
        # weight by Keras.
        half = (self.kernel_size - 1) // 2
        t_right_linspace = np.linspace(1, half, half)
        self.t_right = tf.constant(t_right_linspace / self.sample_freq, dtype="float32")

        # Hamming window (constant, not trainable)
        n = np.linspace(0, self.kernel_size, num=self.kernel_size)
        window = 0.54 - 0.46 * np.cos(2 * np.pi * n / self.kernel_size)
        self.window = tf.constant(window, dtype="float32")

    @property
    def beg_freq(self):
        '''Normalised lower cutoff of every filter, derived from the current
        value of `b1`.'''
        return K.abs(self.b1) + self.min_freq / self.sample_freq

    @property
    def end_freq(self):
        '''Normalised upper cutoff of every filter, derived from the current
        values of `b1` and `band`.'''
        return self.beg_freq + (K.abs(self.band) + self.min_band / self.sample_freq)

    def call(self, X):
        # The cutoffs must be derived from the weights inside `call`;
        # otherwise the filters are frozen at their initial value and the
        # weights receive no gradient.
        beg_freq = self.beg_freq
        end_freq = self.end_freq

        filters = []
        for i in range(self.nb_filters):
            low_pass1 = 2 * beg_freq[i] * self.sinc(beg_freq[i] * self.sample_freq, self.t_right)
            low_pass2 = 2 * end_freq[i] * self.sinc(end_freq[i] * self.sample_freq, self.t_right)
            band_pass = low_pass2 - low_pass1
            band_pass = band_pass / K.max(band_pass)

            filters.append(band_pass * self.window)

        filters = K.stack(filters)

        # TF convolution assumes data is stored as NWC
        filters = K.transpose(filters)
        filters = K.reshape(filters, (self.kernel_size, 1, self.nb_filters))

        return K.conv1d(X, filters)

    def compute_output_shape(self, input_shape):
        # `K.conv1d` is called with its defaults ("valid" padding, stride 1)
        width = input_shape[1]
        out_width_size = None if width is None else width - self.kernel_size + 1
        return (input_shape[0], out_width_size, self.nb_filters)


# Smoke test: feed a 63-sample ramp shaped (batch=1, width=63, channels=1)
# through a single 9-tap sinc filter at a 400 Hz sample rate and print the
# output with the channel axis before the time axis.
X = np.arange(63, dtype=np.single).reshape(1, -1, 1)
sinc_layer = SincConv(nb_filters=1, kernel_size=9, sample_freq=400)
y = sinc_layer(X)
print(np.transpose(y.numpy(), (0, 2, 1)))

[[[0.0931545  0.11644317 0.1397316  0.16302049 0.18630917 0.20959736
   0.23288602 0.25617468 0.27946335 0.30275205 0.3260407  0.34932888
   0.37261754 0.39590624 0.41919488 0.44248354 0.46577224 0.4890609
   0.51234955 0.5356382  0.5589269  0.58221555 0.60550326 0.6287919
   0.65208066 0.67536926 0.69865793 0.7219466  0.74523526 0.768524
   0.7918126  0.81510127 0.83839    0.8616786  0.88496727 0.90825593
   0.9315446  0.95483327 0.97812194 1.0014106  1.0246992  1.0479879
   1.0712767  1.0945653  1.117854   1.1411407  1.1644312  1.187718
   1.2110087  1.2342954  1.2575841  1.2808727  1.3041613  1.32745
   1.3507388 ]]]


## Fast implementation

In [55]:
class SincConvFast(Layer):
    '''
    Sinc-based convolution Keras layer (vectorised version)

    All band-pass filters are built at once: only the left half of each
    filter is computed and then mirrored around the centre tap.

    Parameters
    ----------
    nb_filters : `int`
        Number of filters (= number of output channels).
    kernel_size : `int`
        Convolution filter width/length (will be increased by one if even).
    sample_freq : `int`
        Sample rate of input audio.
    stride : `int`
        Convolution stride param. Defaults to 1.
    padding : `string`
        Convolution padding param. Defaults to "VALID".
    min_low_hz : `int`
        Minimum lower cutoff frequency of a band-pass filter, in Hz.
        Defaults to 50.
    min_band_hz : `int`
        Minimum bandwidth of a band-pass filter, in Hz. Defaults to 50.
    **kwargs
        Forwarded to `tensorflow.keras.layers.Layer` (e.g. `name`).

    Notes
    -----
    The time axis and the Hamming window are fixed constants, so
    `get_weights()` returns exactly `[low_hz, band_hz]`.

    Reference
    ---------
    Mirco Ravanelli, Yoshua Bengio,
    "Speaker Recognition from raw waveform with SincNet".
    https://arxiv.org/abs/1808.00158
    '''

    @staticmethod
    def hz_to_mel(hz):
        '''Convert a frequency in Hz to the mel scale.'''
        return 2595.0 * np.log10(1.0 + hz / 700.0)

    @staticmethod
    def mel_to_hz(mels):
        '''Convert a mel-scale value back to Hz.'''
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

    def __init__(self, nb_filters, kernel_size, sample_freq,
                 stride=1, padding="VALID", min_low_hz=50, min_band_hz=50,
                 **kwargs):
        super(SincConvFast, self).__init__(**kwargs)

        self.nb_filters = nb_filters
        self.kernel_size = kernel_size
        self.sample_freq = sample_freq
        self.stride = stride
        self.padding = padding
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Force filter size to be odd for later optimizations with symmetry
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        # Set trainable parameters (both expressed in Hz)
        self.low_hz = self.add_weight(
            name='low_hz',
            shape=(self.nb_filters,),
            initializer="zeros",
            trainable=True)
        self.band_hz = self.add_weight(
            name='band_hz',
            shape=(self.nb_filters,),
            initializer="zeros",
            trainable=True)

        # Initialize weights with frequencies of the mel-scale filter-bank
        low_freq_mel = self.hz_to_mel(30)
        high_freq_mel = self.hz_to_mel(self.sample_freq / 2 - (self.min_low_hz + self.min_band_hz))
        mel_points = np.linspace(low_freq_mel, high_freq_mel, num=self.nb_filters + 1)
        hz_points = self.mel_to_hz(mel_points)
        self.set_weights([hz_points[:-1], np.diff(hz_points)])

        # Left half of the time axis, pre-multiplied by 2 * pi, shape (1, half).
        # Stored as a constant so that it is not picked up as a trainable
        # weight by Keras.
        t_linspace = np.arange(-(self.kernel_size - 1) / 2, 0)
        t = 2 * np.pi * t_linspace / self.sample_freq
        self.t = tf.constant(t.reshape(1, -1), dtype="float32")

        # Left half of the Hamming window. It must be a constant: a
        # `tf.Variable` attribute is trainable by default and would be
        # updated by the optimizer.
        n = np.linspace(0, (self.kernel_size / 2) - 1, num=int((self.kernel_size / 2)))
        window = 0.54 - 0.46 * np.cos(2 * np.pi * n / self.kernel_size)
        self.window = tf.constant(window, dtype="float32")

    def call(self, X):
        # Cutoff frequencies in Hz, kept above the configured minimums and
        # with the upper cutoff clipped to the Nyquist frequency
        low = self.min_low_hz + tf.abs(self.low_hz)
        high = tf.clip_by_value(low + self.min_band_hz + tf.abs(self.band_hz), self.min_low_hz, self.sample_freq / 2)
        band = high - low

        # Outer products: (nb_filters, 1) x (1, half) -> (nb_filters, half)
        low_times_t = tf.linalg.matmul(tf.reshape(low, (-1, 1)), self.t)
        high_times_t = tf.linalg.matmul(tf.reshape(high, (-1, 1)), self.t)

        # Difference of two sinc low-pass filters, written with a shared
        # denominator, then windowed
        band_pass_left = ((tf.sin(high_times_t) - tf.sin(low_times_t)) / (self.t / 2)) * self.window
        band_pass_center = tf.reshape(2 * band, (-1, 1))
        band_pass_right = tf.reverse(band_pass_left, [1])

        filters = tf.concat([band_pass_left,
                             band_pass_center,
                             band_pass_right], axis=1)
        # Normalise so that the centre tap of every filter equals 1
        filters = filters / (2 * band[:, None])

        # TF convolution assumes data is stored as NWC
        filters = tf.transpose(filters)
        filters = tf.reshape(filters, (self.kernel_size, 1, self.nb_filters))

        return tf.nn.conv1d(X, filters, self.stride, self.padding)

    def compute_output_shape(self, input_shape):
        # Must mirror the stride and padding actually passed to
        # `tf.nn.conv1d` in `call`.
        width = input_shape[1]
        if width is None:
            return (input_shape[0], None, self.nb_filters)

        stride = self.stride
        if isinstance(stride, (list, tuple)):
            # `tf.nn.conv1d` accepts a list of length 1 or 3
            stride = stride[0] if len(stride) == 1 else stride[1]

        padding = str(self.padding).upper()
        if padding == "SAME":
            out_width_size = -(-width // stride)
        elif padding == "VALID":
            out_width_size = max((width - self.kernel_size) // stride + 1, 0)
        else:
            raise ValueError("Unsupported padding: {!r}".format(self.padding))

        return (input_shape[0], out_width_size, self.nb_filters)


# Smoke test: feed a 63-sample ramp shaped (batch=1, width=63, channels=1)
# through two 9-tap sinc filters at a 400 Hz sample rate and print the
# output with the channel axis before the time axis.
X = np.arange(63, dtype=np.single).reshape(1, -1, 1)
sinc_layer = SincConvFast(nb_filters=2, kernel_size=9, sample_freq=400)
y = sinc_layer(X)

print(np.transpose(y.numpy(), (0, 2, 1)))

[[[-0.07339406 -0.09174238 -0.11009077 -0.12843938 -0.146788
   -0.1651365  -0.18348466 -0.20183326 -0.22018176 -0.23853038
   -0.256879   -0.27522764 -0.2935753  -0.31192377 -0.33027253
   -0.348621   -0.36696953 -0.38531825 -0.40366676 -0.42201528
   -0.44036403 -0.45871246 -0.47706097 -0.49540973 -0.5137573
   -0.53210604 -0.5504545  -0.568803   -0.5871527  -0.6054993
   -0.623848   -0.64219654 -0.66054505 -0.67889357 -0.697242
   -0.715591   -0.73394    -0.752288   -0.770636   -0.788985
   -0.8073345  -0.82568055 -0.8440305  -0.8623776  -0.8807265
   -0.8990751  -0.917425   -0.9357721  -0.95412105 -0.972469
   -0.99081707 -1.0091665  -1.0275155  -1.0458635  -1.0642116 ]
  [-0.18435074 -0.23043844 -0.27652615 -0.3226139  -0.36870134
   -0.41478932 -0.46087682 -0.50696474 -0.55305207 -0.59913963
   -0.64522773 -0.6913156  -0.737403   -0.7834905  -0.8295781
   -0.8756664  -0.92175376 -0.96784145 -1.013929   -1.0600163
   -1.1061046  -1.1521919  -1.1982797  -1.2443671  -1.2904555
   -1