In [29]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [30]:
from typing import Optional

## Preprocessing

In [2]:
def waveform_to_log_mel_spectrogram_patches(waveform, params):
    """
    Compute log mel spectrogram patches of a 1-D waveform.

    Args:
        waveform: 1-D tensor/array of audio samples, shape [<# samples>].
        params: configuration object. The attributes read here are
            sample_rate, stft_window_seconds, stft_hop_seconds,
            tflite_compatible, mel_bands, mel_min_hz, mel_max_hz,
            log_offset, patch_window_seconds and patch_hop_seconds.

    Returns:
        The log mel spectrogram framed along the time axis into patches,
        shape [<# patches>, <# frames per patch>, params.mel_bands].
    """
    with tf.name_scope('log_mel_features'):
        # Waveform has shape [<# samples>]

        # Convert waveform into spectrogram using STFT
        window_length_samples = int(round(params.sample_rate*params.stft_window_seconds))
        hop_length_samples = int(round(params.sample_rate*params.stft_hop_seconds))
        # FFT size is the next power of two at or above the window length.
        fft_length = 2**int(np.ceil(np.log(window_length_samples)/np.log(2.0)))
        num_spectrogram_bins = fft_length//2+1
        # NOTE(review): the TFLite helper in this notebook frames the signal
        # with pad_end=True, while tf.signal.stft below is left at its default
        # pad_end; the two branches may emit different frame counts - confirm.
        if params.tflite_compatible:
            magnitude_spectrogram = _tflite_stft_magnitude(
                signal = waveform,
                frame_length = window_length_samples,
                frame_step = hop_length_samples,
                fft_length = fft_length
            )
        else:
            magnitude_spectrogram = tf.abs(tf.signal.stft(
                signals = waveform,
                frame_length = window_length_samples,
                frame_step = hop_length_samples,
                fft_length = fft_length
            ))
        # magnitude_spectrogram has shape [<# STFT frames>, num_spectrogram_bins]

        # Convert spectrogram into log mel spectrogram.
        linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins = params.mel_bands,
            num_spectrogram_bins = num_spectrogram_bins,
            sample_rate = params.sample_rate,
            lower_edge_hertz = params.mel_min_hz,
            upper_edge_hertz = params.mel_max_hz
        )
        mel_spectrogram = tf.matmul(magnitude_spectrogram, linear_to_mel_weight_matrix)
        # log_offset keeps the logarithm finite for all-zero mel bins.
        log_mel_spectrogram = tf.math.log(mel_spectrogram + params.log_offset)
        # log_mel_spectrogram has shape [<# STFT frames>, params.mel_bands]

        # Frame spectrogram (shape [<# STFT frames>, params.mel_bands]) into patches
        # (the input examples). Only complete frames are emitted, so if there is
        # less than params.patch_window_seconds of waveform then nothing is emitted
        # (to avoid this, zero-pad before processing)
        # The spectrogram advances one row per STFT hop, so its "sample rate"
        # is the audio sample rate divided by the hop length.
        spectrogram_sample_rate = params.sample_rate/hop_length_samples
        patch_window_length_samples = int(
            round(spectrogram_sample_rate*params.patch_window_seconds)
        )
        patch_hop_length_samples = int(
            round(spectrogram_sample_rate*params.patch_hop_seconds)
        )
        features = tf.signal.frame(
            signal = log_mel_spectrogram,
            frame_length = patch_window_length_samples,
            frame_step = patch_hop_length_samples,
            axis = 0
        )
        # features has shape [<# patches>, <# frames per patch>, params.mel_bands]
        return features

In [32]:
class Spectrogram(keras.Model):
    """
    Keras model that converts a waveform into an STFT spectrogram.

    Args:
        sample_rate: audio sample rate; stored on the instance but not used
            by the STFT itself.
        n_fft: FFT length passed to tf.signal.stft.
        win_length: analysis window length in samples; defaults to n_fft.
        hop_length: step between successive frames in samples; defaults to
            half of the resolved window length.
        pad_end: forwarded to tf.signal.stft.
        power: exponent applied to the STFT magnitude. 1 (or None) returns
            the magnitude, 2 the squared magnitude; any other value raises
            the magnitude to that power.
    """
    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad_end: bool = False,
                 power: float = 2.0) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        # Derive the default hop from the *resolved* window length: the raw
        # win_length argument may be None, and None // 2 raises TypeError.
        self.hop_length = hop_length if hop_length is not None else self.win_length//2
        self.pad_end = pad_end
        self.power = power

    def call(self, waveform: tf.Tensor) -> tf.Tensor:
        """Return the spectrogram of `waveform` with `self.power` applied."""
        magnitude = tf.abs(tf.signal.stft(
                signals = waveform,
                frame_length = self.win_length,
                frame_step = self.hop_length,
                fft_length = self.n_fft,
                pad_end = self.pad_end
            ))
        if self.power is None or self.power == 1:
            return magnitude
        if self.power == 2:
            # A plain multiply is enough for the common power-spectrogram case.
            return magnitude*magnitude
        # Other exponents used to be silently ignored; honour them.
        return tf.pow(magnitude, self.power)

In [33]:
# Spectrogram front-end for 16 kHz audio: 512-point FFT, 480-sample window,
# 160-sample hop, with the tail of the signal padded.
f_spec = Spectrogram(
    sample_rate = 16000,
    n_fft = 512,
    win_length = 480,
    hop_length = 160,
    pad_end = True,
)
f_spec(wav).shape

TensorShape([100, 257])

In [37]:
class MelSpectrogram(keras.Model):
    """
    Keras model that converts a waveform into a (log) mel spectrogram.

    Wraps a `Spectrogram` model and projects its output onto a mel filterbank
    that is built once, at construction time, with
    tf.signal.linear_to_mel_weight_matrix.

    Args:
        sample_rate: audio sample rate, used to place the mel bands.
        n_fft: FFT length; the filterbank expects n_fft // 2 + 1 input bins.
        win_length: analysis window length, forwarded to `Spectrogram`.
        hop_length: frame step, forwarded to `Spectrogram`.
        f_min: lower edge of the lowest mel band.
        f_max: upper edge of the highest mel band.
        pad_end: forwarded to `Spectrogram`.
        n_mels: number of mel bands in the output.
        power: spectrogram exponent, forwarded to `Spectrogram`.
        power_to_db: when True, return log(mel + log_offset). Despite the
            name this is a natural logarithm, not a decibel conversion.
        log_offset: small constant added before the logarithm so that empty
            mel bins stay finite.
    """
    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = 160,
                 f_min: float = 0.0,
                 f_max: float = 3800,
                 pad_end: bool = False,
                 n_mels: int = 128,
                 power: float = 2.0,
                 power_to_db: bool = True,
                 log_offset: float = 1e-6) -> None:
        super().__init__()
        num_spectrogram_bins = n_fft//2+1
        self.spec = Spectrogram(sample_rate,
                                n_fft,
                                win_length,
                                hop_length,
                                pad_end,
                                power)
        self.linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins = n_mels,
            num_spectrogram_bins = num_spectrogram_bins,
            sample_rate = sample_rate,
            lower_edge_hertz = f_min,
            upper_edge_hertz = f_max
        )
        self.power_to_db = power_to_db
        self.log_offset = log_offset

    def call(self, waveform: tf.Tensor) -> tf.Tensor:
        """Return the mel spectrogram of `waveform`, log-compressed if enabled."""
        spectrogram = self.spec(waveform)
        mel_spectrogram = tf.matmul(spectrogram, self.linear_to_mel_weight_matrix)
        if self.power_to_db:
            # Log mel spectrogram
            mel_spectrogram = tf.math.log(mel_spectrogram + self.log_offset)
        return mel_spectrogram

In [38]:
# 40-band log mel front-end on the magnitude (power = 1) spectrogram, using
# the same framing as f_spec.
f_mel = MelSpectrogram(
    sample_rate = 16000,
    n_fft = 512,
    win_length = 480,
    hop_length = 160,
    pad_end = True,
    n_mels = 40,
    power = 1,
    power_to_db = True,
)
f_mel(wav).shape

TensorShape([100, 40])

In [39]:
# Print the Keras summary (layers and parameter counts) of the mel front-end.
f_mel.summary()

Model: "mel_spectrogram_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 spectrogram_2 (Spectrogram  multiple                  0         
 )                                                               
                                                                 
Total params: 0 (0.00 Byte)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [43]:
# 200 random stand-in waveforms of 16000 samples each, used below as the
# calibration set for TFLite quantization. A seeded Generator keeps the
# notebook reproducible, and float32 matches the dtype of `wav`, the input the
# spectrogram models are exercised with (np.random.rand would give float64).
test_specs = np.random.default_rng(seed = 1).random((200, 16000), dtype = np.float32)
test_specs.shape

(200, 16000)

In [44]:
def representative_data_gen():
    """
    Yield calibration samples for the TFLite converter.

    Takes the first 100 rows of the module-level `test_specs`, one row per
    batch, and yields each as a single-element list (one entry per model
    input).
    """
    for input_value in tf.data.Dataset.from_tensor_slices(test_specs).batch(1).take(100):
        # Cast explicitly: NumPy random data defaults to float64, while the
        # TFLite quantization guide feeds float32 calibration samples.
        yield [tf.cast(input_value, tf.float32)]

# Calling the function only creates the generator; nothing is iterated here.
representative_data_gen()

<generator object representative_data_gen at 0x78944c943140>

In [42]:
def representative_data_gen():
    """
    Yield calibration samples for the TFLite converter.

    Takes the first 100 rows of the module-level `test_specs`, one row per
    batch, and yields each as a single-element list (one entry per model
    input).
    """
    for input_value in tf.data.Dataset.from_tensor_slices(test_specs).batch(1).take(100):
        # Cast explicitly: NumPy random data defaults to float64, while the
        # TFLite quantization guide feeds float32 calibration samples.
        yield [tf.cast(input_value, tf.float32)]

# Build a TFLite converter for the f_spec spectrogram model.
converter = tf.lite.TFLiteConverter.from_keras_model(f_spec)
# Enable the default optimizations and hand the converter the calibration
# generator it needs for quantization.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# NOTE(review): the recorded run of this cell failed here with
# "RuntimeError: The size_splits must sum to the dimension of value along
# axis. Node number 1 (SPLIT_V) failed to prepare." Possibly the calibration
# batches of shape (1, 16000) do not match the 1-D waveform f_spec was called
# with - TODO confirm before relying on this conversion.
tflite_model_quant = converter.convert()



RuntimeError: The size_splits must sum to the dimension of value along axis.Node number 1 (SPLIT_V) failed to prepare.

In [None]:
def Wav2LogMelSpectrogram(waveform):
    """
    Unfinished stub: derives the STFT sizes for a log mel front-end.

    Only the size computations exist so far; `waveform` is not used yet and
    the function falls through without a return statement (returns None).

    NOTE(review): reads a module-level `params` object that is not defined in
    the visible notebook cells - confirm where it is meant to come from.
    """
    # Window and hop lengths converted from seconds to samples.
    window_length_samples = int(round(params.sample_rate*params.stft_window_seconds))
    hop_length_samples = int(round(params.sample_rate * params.stft_hop_seconds))
    # FFT size is the next power of two at or above the window length.
    fft_length = 2**int(np.ceil(np.log(window_length_samples)/np.log(2.0)))
    num_spectrogram_bins = fft_length//2+1

In [3]:
# Random float32 test waveform of 16000 samples. A seeded Generator is used so
# that every "Restart & Run All" sees the same signal.
wav = np.random.default_rng(seed = 0).random(16000, dtype = np.float32)

# STFT framing shared by the step-by-step cells.
WIN_LENGTH = 480
HOP_LENGTH = 160
NFFT = 512
# A real-input FFT of length NFFT yields NFFT // 2 + 1 frequency bins.
NUM_SPECTROGRAM_BINS = NFFT//2+1

In [4]:
# Reference magnitude spectrogram computed directly with tf.signal.stft.
magnitude_spectrogram = tf.abs(
    tf.signal.stft(wav, WIN_LENGTH, HOP_LENGTH, fft_length = NFFT, pad_end = True)
)
magnitude_spectrogram.shape

TensorShape([100, 257])

In [9]:
import torch
import torchaudio

In [10]:
# Same STFT settings through torchaudio, for comparison with the TensorFlow
# result. The recorded output is (257, 101): frequency-major, and one frame
# more than the (100, 257) TensorFlow spectrogram.
# NOTE(review): presumably the extra frame comes from torchaudio's default
# centre padding - confirm against the torchaudio documentation.
spec = torchaudio.transforms.Spectrogram(
    n_fft = NFFT,
    win_length = WIN_LENGTH,
    hop_length = HOP_LENGTH,
    power = 1
)(torch.from_numpy(wav))
spec.shape

torch.Size([257, 101])

In [5]:
# 40-band mel filterbank for the NUM_SPECTROGRAM_BINS-wide spectrogram; the
# band edges are left at the library defaults.
linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
    num_mel_bins = 40, num_spectrogram_bins = NUM_SPECTROGRAM_BINS, sample_rate = 16000
)
linear_to_mel_weight_matrix.shape

TensorShape([257, 40])

In [6]:
# Project every spectrogram frame onto the mel filterbank:
# [frames, bins] x [bins, mel bands] -> [frames, mel bands].
mel_spectrogram = tf.matmul(magnitude_spectrogram, linear_to_mel_weight_matrix)
mel_spectrogram.shape

TensorShape([100, 40])

In [27]:
def _tflite_stft_magnitude(signal, frame_length, frame_step, fft_length, pad_end = True):
    """
    TF-Lite-compatible version of tf.abs(tf.signal.stft()).

    The STFT is spelled out as framing, Hann windowing and a real-input DFT
    done with two real matrix multiplications, so no complex-valued or FFT
    ops are needed.

    Args:
        signal: 1-D waveform tensor/array.
        frame_length: analysis window length in samples.
        frame_step: hop between successive frames in samples.
        fft_length: DFT size; frames are zero-padded up to this length.
        pad_end: forwarded to tf.signal.frame. Defaults to True, which is the
            framing this helper has always used.

    Returns:
        STFT magnitudes of shape [<# frames>, fft_length // 2 + 1].
    """
    def _hann_window():
        """Periodic Hann window as a [1, frame_length] float32 constant."""
        # Integer sample indices are used on purpose: np.arange with a
        # fractional step (0, 1.0, 1.0/frame_length) can return
        # frame_length + 1 values for some lengths (e.g. 49), which would
        # make the reshape below fail.
        sample_index = np.arange(frame_length)
        window = 0.5 - 0.5*np.cos(2*np.pi*sample_index/frame_length)
        return tf.reshape(
            tf.constant(window.astype(np.float32), name = 'hann_window'),
            [1, frame_length]
        )

    def _dft_matrix(dft_length):
        """Calculate the full DFT matrix in numpy"""
        omega = (0 + 1j)*2.0*np.pi/float(dft_length)
        return np.exp(omega*np.outer(np.arange(dft_length), np.arange(dft_length)))

    def _rdft(framed_signal, fft_length):
        """
        Implement real-input DFT by matmul.
        """
        # We are right-multiplying by the DFT matrix, and we are keeping only the
        # first half ("positive frequencies"). So discard the second half of rows,
        # but transpose the array for right-multiplication. The DFT matrix is
        # symmetric, so we could have done it more directly, but this reflects our
        # intention better
        complex_dft_matrix_kept_values = _dft_matrix(fft_length)[:(
            fft_length//2+1
        ), :].transpose()
        real_dft_matrix = tf.constant(
            np.real(complex_dft_matrix_kept_values).astype(np.float32),
            name = 'real_dft_matrix'
        )
        imag_dft_matrix = tf.constant(
            np.imag(complex_dft_matrix_kept_values).astype(np.float32),
            name = 'imaginary_dft_matrix'
        )
        signal_frame_length = tf.shape(framed_signal)[-1]
        half_pad = (fft_length - signal_frame_length)//2
        padded_frames = tf.pad(
            framed_signal,
            [
                # Don't add any padding in the frame dimension
                [0, 0],
                # Pad before and after the signal within each frame
                [half_pad, fft_length - signal_frame_length - half_pad]
            ],
            mode = 'CONSTANT',
            constant_values = 0.0
        )
        real_stft = tf.matmul(padded_frames, real_dft_matrix)
        imag_stft = tf.matmul(padded_frames, imag_dft_matrix)
        return real_stft, imag_stft

    def _complex_abs(real, imag):
        """Magnitude of a complex value given as separate real/imag parts."""
        return tf.sqrt(tf.add(real*real, imag*imag))

    framed_signal = tf.signal.frame(signal, frame_length, frame_step, pad_end = pad_end)
    windowed_signal = framed_signal*_hann_window()
    real_stft, imag_stft = _rdft(windowed_signal, fft_length)
    stft_magnitude = _complex_abs(real_stft, imag_stft)
    return stft_magnitude

In [15]:
# Step 1 of the manual STFT: slice the waveform into overlapping frames of
# WIN_LENGTH samples every HOP_LENGTH samples, padding the tail (pad_end).
framed_signal = tf.signal.frame(wav, WIN_LENGTH, HOP_LENGTH, pad_end = True)
framed_signal.shape

TensorShape([100, 480])

In [17]:
def _hann_window(frame_length):
    """
    Build a periodic Hann window as a [1, frame_length] float32 constant.

    The leading axis of size 1 lets the window broadcast over a
    [frames, frame_length] batch of frames.
    """
    # Integer sample indices are used on purpose: np.arange with a fractional
    # step (0, 1.0, 1.0/frame_length) can return frame_length + 1 values for
    # some lengths (e.g. 49), which would make the reshape below fail.
    sample_index = np.arange(frame_length)
    window = 0.5 - 0.5*np.cos(2*np.pi*sample_index/frame_length)
    return tf.reshape(
        tf.constant(window.astype(np.float32), name = 'hann_window'),
        [1, frame_length]
    )
# Step 2: build the Hann window for WIN_LENGTH-sample frames.
han_window = _hann_window(WIN_LENGTH)
han_window.shape

TensorShape([1, 480])

In [18]:
# Step 3: apply the window to every frame; the [1, WIN_LENGTH] window
# broadcasts across the frame axis.
windowed_signal = framed_signal*han_window
windowed_signal.shape

TensorShape([100, 480])

In [19]:
def _dft_matrix(dft_length):
    """
    Build the full (dft_length x dft_length) complex DFT matrix with NumPy.

    Entry [j, k] equals exp(2*pi*1j * j * k / dft_length).
    """
    omega = (0 + 1j)*2.0*np.pi/float(dft_length)
    bin_index = np.arange(dft_length)
    # Broadcasting a column against a row gives the table of j*k products.
    index_products = bin_index[:, np.newaxis]*bin_index[np.newaxis, :]
    return np.exp(omega*index_products)

# Step 4: full complex DFT matrix for an NFFT-point transform.
dft_mat = _dft_matrix(NFFT)
dft_mat.shape

(512, 512)

In [25]:
# Sanity check: every DFT matrix entry is a complex exponential, so the
# magnitudes in this 5x5 corner should all be 1.
np.abs(dft_mat[:5, :5])

array([[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]])

In [None]:
def _rdft(framed_signal, fft_length):
    """
    Real-input DFT of every frame, computed with two real matmuls.

    Args:
        framed_signal: batch of frames; the two-entry padding list below
            treats it as 2-D, [frames, frame_length].
        fft_length: DFT size; each frame is zero-padded up to this length.

    Returns:
        Tuple (real_stft, imag_stft) holding the real and imaginary parts of
        the first fft_length // 2 + 1 DFT bins of each frame.
    """
    # Only the first half of the spectrum ("positive frequencies") is kept:
    # take those rows of the full DFT matrix, then transpose so the frames can
    # be right-multiplied by it. The DFT matrix is symmetric, so the columns
    # could have been sliced directly, but this form states the intent.
    kept_bins = fft_length//2+1
    kept_dft = _dft_matrix(fft_length)[:kept_bins, :].transpose()
    real_dft_matrix = tf.constant(
        np.real(kept_dft).astype(np.float32), name = 'real_dft_matrix'
    )
    imag_dft_matrix = tf.constant(
        np.imag(kept_dft).astype(np.float32), name = 'imaginary_dft_matrix'
    )

    # Split the zero padding around each frame; an odd remainder goes after.
    frame_width = tf.shape(framed_signal)[-1]
    pad_before = (fft_length - frame_width)//2
    pad_after = fft_length - frame_width - pad_before
    paddings = [
        [0, 0],                   # no padding along the frame axis
        [pad_before, pad_after],  # padding inside each frame
    ]
    padded_frames = tf.pad(
        framed_signal, paddings, mode = 'CONSTANT', constant_values = 0.0
    )
    return (
        tf.matmul(padded_frames, real_dft_matrix),
        tf.matmul(padded_frames, imag_dft_matrix),
    )

In [28]:
# Run the TFLite-friendly STFT end to end. Note that this rebinds
# magnitude_spectrogram, which an earlier cell computed with tf.signal.stft.
magnitude_spectrogram = _tflite_stft_magnitude(wav, WIN_LENGTH, HOP_LENGTH, NFFT)
magnitude_spectrogram.shape

TensorShape([100, 257])

In [None]:
class Prep_TF(keras.Model):
    def __init__(self):
        self.prep =