# keras-sincnet

## To-Do

- [ ] Fast implementation, code clean-up, and documentation

---

- [ ] Reproduce the results and compare with a similar 1D CNN

---

- [ ] Add skip connections and Q-RNN (PASE+)
- [ ] Add theoretical explanations for MFCC and Q-RNN


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K

import numpy as np

In [None]:
class SincConv(Layer):
    '''
    Sinc-based convolution Keras layer

    Every kernel is a band-pass filter obtained as the difference of two
    sinc low-pass filters, normalized by its maximum and multiplied by a
    Hamming window. Only two values per filter are learned: the lower
    cutoff (`b1`) and the bandwidth (`band`), both stored divided by the
    sampling frequency.

    Parameters
    ----------
    nb_filters : int
        Number of band-pass filters (output channels).
    kernel_size : int
        Length of every filter in samples. Must be a positive odd number,
        because the kernel is built as a mirrored half plus a center tap.
    sample_freq : int or float
        Sampling frequency of the input signal, in Hz.
    **kwargs
        Forwarded to the base `Layer` constructor (e.g. `name`).

    Shapes
    ------
    Input:  (batch, steps, 1)
    Output: (batch, steps - kernel_size + 1, nb_filters)

    Raises
    ------
    ValueError
        If `kernel_size` is not a positive odd integer.

    Reference
    ---------
    Mirco Ravanelli, Yoshua Bengio,
    "Speaker Recognition from raw waveform with SincNet".
    https://arxiv.org/abs/1808.00158
    '''

    @staticmethod
    def sinc(band, t_right):
        '''
        Evaluate sin(2*pi*band*t) / (2*pi*band*t) on a symmetric time axis.

        Parameters
        ----------
        band : tensor
            Frequency in Hz. Either a scalar or any shape that broadcasts
            against `t_right` (e.g. `(nb_filters, 1)`).
        t_right : tensor
            Strictly positive time stamps (right half of the axis), 1-D.

        Returns
        -------
        tensor
            Values for [-t_right reversed, 0, t_right] along the last axis;
            the center tap (t = 0) is set to 1.
        '''
        arg = 2 * np.pi * band * t_right
        y_right = K.sin(arg) / arg
        # The function is even, so the left half is the mirrored right half.
        y_left = K.reverse(y_right, -1)
        # Build the center tap from a reduction of `y_right` so that it has
        # the right leading dimensions even when `t_right` is empty, and so
        # that no new variable is created on every call.
        center = K.ones_like(K.sum(y_right, axis=-1, keepdims=True))
        return K.concatenate([y_left, center, y_right], axis=-1)

    @staticmethod
    def hz_to_mel(hz):
        '''Convert a frequency in Hz to the mel scale.'''
        return 2595.0 * np.log10(1.0 + hz / 700.0)

    @staticmethod
    def mel_to_hz(mels):
        '''Convert a mel-scale value back to a frequency in Hz.'''
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

    def __init__(self, nb_filters, kernel_size, sample_freq, **kwargs):
        super(SincConv, self).__init__(**kwargs)

        # An even size would yield a kernel of kernel_size - 1 taps, which
        # cannot be multiplied by the kernel_size-long window.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got "
                "{}".format(kernel_size))

        self.nb_filters = nb_filters
        self.kernel_size = kernel_size
        self.sample_freq = sample_freq

        # Lower bounds (in Hz) added to the learned cutoff and bandwidth.
        self.min_freq = 50.0
        self.min_band = 50.0

        # Set trainable parameters
        self.b1 = self.add_weight(
            name='b1',
            shape=(self.nb_filters,),
            initializer="zeros",
            trainable=True)
        self.band = self.add_weight(
            name='band',
            shape=(self.nb_filters,),
            initializer="zeros",
            trainable=True)

        # Initialize weights with cutoff frequencies of the mel-scale filter-bank
        # (an alternative would be starting every filter at [0, Nyquist]).
        low_freq_mel = self.hz_to_mel(50)
        high_freq_mel = self.hz_to_mel(self.sample_freq / 2)
        mel_points = np.linspace(low_freq_mel, high_freq_mel, num=self.nb_filters)
        hz_points = self.mel_to_hz(mel_points)

        b1 = np.roll(hz_points, 1)
        b1[0] = 30
        b2 = np.roll(hz_points, -1)
        b2[-1] = (self.sample_freq / 2) - 100

        # Order matches creation order of the weights: [b1, band].
        self.set_weights([b1 / self.sample_freq, (b2 - b1) / self.sample_freq])

        # Right half of the time axis, in seconds. Stored as a constant (not
        # a variable) so Keras does not track it as a trainable weight.
        half_size = int((self.kernel_size - 1) / 2)
        t_right_linspace = np.linspace(1, (self.kernel_size - 1) / 2, half_size)
        self.t_right = K.constant(t_right_linspace / self.sample_freq)

        # Hamming window, likewise a fixed constant rather than a weight.
        n = np.linspace(0, self.kernel_size, num=self.kernel_size)
        window = 0.54 - 0.46 * np.cos(2 * np.pi * n / self.kernel_size)
        self.window = K.constant(window, dtype="float32")

    def call(self, X):
        # Cutoffs are derived from the weights on every call so gradients
        # reach `b1` / `band` and weight updates change the filters.
        beg_freq = K.abs(self.b1) + self.min_freq / self.sample_freq
        end_freq = beg_freq + (K.abs(self.band) + self.min_band / self.sample_freq)

        # Shape (nb_filters, 1) so all filters are built in one broadcast.
        beg_freq = K.expand_dims(beg_freq, -1)
        end_freq = K.expand_dims(end_freq, -1)

        low_pass1 = 2 * beg_freq * self.sinc(beg_freq * self.sample_freq, self.t_right)
        low_pass2 = 2 * end_freq * self.sinc(end_freq * self.sample_freq, self.t_right)
        band_pass = low_pass2 - low_pass1
        # Normalize every filter by its own maximum.
        band_pass = band_pass / K.max(band_pass, axis=-1, keepdims=True)

        # (nb_filters, kernel_size), window broadcast over filters.
        filters = band_pass * self.window

        # TF convolution assumes data is stored as NWC
        filters = K.transpose(filters)
        filters = K.reshape(filters, (self.kernel_size, 1, self.nb_filters))

        return K.conv1d(X, filters)

    def compute_output_shape(self, input_shape):
        '''Return (batch, steps - kernel_size + 1, nb_filters).'''
        input_shape = tf.TensorShape(input_shape).as_list()
        steps = input_shape[1]
        # "valid" padding, stride 1, no dilation; keep unknown lengths unknown.
        out_width_size = None if steps is None else steps - self.kernel_size + 1
        return (input_shape[0], out_width_size, self.nb_filters)

    def get_config(self):
        '''Return the constructor arguments so the layer can be re-created.'''
        config = super(SincConv, self).get_config()
        config.update({
            'nb_filters': self.nb_filters,
            'kernel_size': self.kernel_size,
            'sample_freq': self.sample_freq,
        })
        return config


# Smoke test: run a ramp signal (1 batch, 63 steps, 1 channel) through a
# single 9-tap sinc filter at a 400 Hz sampling rate and print the result
# with the filter axis before the time axis.
X = np.arange(63, dtype=np.float32).reshape(1, 63, 1)
sinc_layer = SincConv(nb_filters=1, kernel_size=9, sample_freq=400)
y = sinc_layer(X)
print(np.transpose(y.numpy(), (0, 2, 1)))

[[[0.0931545  0.11644317 0.1397316  0.16302049 0.18630917 0.20959736
   0.23288602 0.25617468 0.27946335 0.30275205 0.3260407  0.34932888
   0.37261754 0.39590624 0.41919488 0.44248354 0.46577224 0.4890609
   0.51234955 0.5356382  0.5589269  0.58221555 0.60550326 0.6287919
   0.65208066 0.67536926 0.69865793 0.7219466  0.74523526 0.768524
   0.7918126  0.81510127 0.83839    0.8616786  0.88496727 0.90825593
   0.9315446  0.95483327 0.97812194 1.0014106  1.0246992  1.0479879
   1.0712767  1.0945653  1.117854   1.1411407  1.1644312  1.187718
   1.2110087  1.2342954  1.2575841  1.2808727  1.3041613  1.32745
   1.3507388 ]]]
