In [1]:
import os, fnmatch
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout, \
    Lambda, Input, Multiply, Layer, Conv1D, Concatenate
from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, \
    EarlyStopping, ModelCheckpoint
import tensorflow as tf
import soundfile as sf
from wavinfo import WavInfoReader
from random import shuffle, seed
import numpy as np
from DTLN_model import DTLN_model

E:\conda\envs\tf\lib\site-packages\numpy\.libs\libopenblas.PYQHXLVVQ7VESDPUVUADXEVJOBGHJPAY.gfortran-win_amd64.dll
E:\conda\envs\tf\lib\site-packages\numpy\.libs\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll
  stacklevel=1)


In [2]:
weights_file='weights/DTLN_model.h5'
numLayer=2
numUnits = 128
blockLen = 512
block_shift = 128
num_elements_first_core = 2 + numLayer * 3 + 2

In [3]:
class InstantLayerNormalization(Layer):
    '''
    Class implementing instant layer normalization. It can also be called 
    channel-wise layer normalization and was proposed by 
    Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2) 
    '''

    def __init__(self, **kwargs):
        '''
            Constructor
        '''
        super(InstantLayerNormalization, self).__init__(**kwargs)
        self.epsilon = 1e-7 
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        '''
        Method to build the weights.
        '''
        shape = input_shape[-1:]
        # initialize gamma
        self.gamma = self.add_weight(shape=shape,
                             initializer='ones',
                             trainable=True,
                             name='gamma')
        # initialize beta
        self.beta = self.add_weight(shape=shape,
                             initializer='zeros',
                             trainable=True,
                             name='beta')
 

    def call(self, inputs):
        '''
        Method to call the Layer. All processing is done here.
        '''

        # calculate mean of each frame
        mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
        # calculate variance of each frame
        variance = tf.math.reduce_mean(tf.math.square(inputs - mean), 
                                       axis=[-1], keepdims=True)
        # calculate standard deviation
        std = tf.math.sqrt(variance + self.epsilon)
        # normalize each frame independently 
        outputs = (inputs - mean) / std
        # scale with gamma
        outputs = outputs * self.gamma
        # add the bias beta
        outputs = outputs + self.beta
        # return output
        return outputs

In [9]:
class DTLN_model():
    '''
    Class to create and train the DTLN model
    '''
    
    def __init__(self):
        '''
        Constructor
        '''

        # defining default cost function
        self.cost_function = self.snr_cost
        # empty property for the model
        self.model = []
        # defining default parameters
        self.fs = 16000
        self.batchsize = 32
        self.len_samples = 4
        self.activation = 'sigmoid'
        self.numUnits = 128
        self.numLayer = 2
        self.blockLen = 512
        self.block_shift = 128
        self.dropout = 0.25
        self.lr = 1e-3
        self.max_epochs = 200
        self.encoder_size = 256
        self.eps = 1e-7
        # reset all seeds to 42 to reduce invariance between training runs
        os.environ['PYTHONHASHSEED']=str(42)
        seed(42)
        np.random.seed(42)
        tf.random.set_seed(42)
        # some line to correctly find some libraries in TF 2.x
        physical_devices = tf.config.experimental.list_physical_devices('GPU')
        if len(physical_devices) > 0:
            for device in physical_devices:
                tf.config.experimental.set_memory_growth(device, enable=True)
        

    @staticmethod
    def snr_cost(s_estimate, s_true):
        '''
        Static Method defining the cost function. 
        The negative signal to noise ratio is calculated here. The loss is 
        always calculated over the last dimension. 
        '''

        # calculating the SNR
        snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \
            (tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True)+1e-7)
        # using some more lines, because TF has no log10
        num = tf.math.log(snr) 
        denom = tf.math.log(tf.constant(10, dtype=num.dtype))
        loss = -10*(num / (denom))
        # returning the loss
        return loss
        

    def lossWrapper(self):
        '''
        A wrapper function which returns the loss function. This is done to
        to enable additional arguments to the loss function if necessary.
        '''
        def lossFunction(y_true,y_pred):
            # calculating loss and squeezing single dimensions away
            loss = tf.squeeze(self.cost_function(y_pred,y_true))
            # calculate mean over batches
            loss = tf.reduce_mean(loss)
            # return the loss
            return loss
        # returning the loss function as handle
        return lossFunction
    
    

    '''
    In the following some helper layers are defined.
    '''  
    
    def segment(self, x):
        '''
        Method for an STFT helper layer used with a Lambda layer. The layer
        calculates the STFT on the last dimension and returns the magnitude and
        phase of the STFT.
        '''
        
        # creating frames from the continuous waveform
        frames = tf.signal.frame(x, self.blockLen, self.block_shift)
        
        return frames

    def tftLayer(self, x):
        '''
        Method for an STFT helper layer used with a Lambda layer. The layer
        calculates the STFT on the last dimension and returns the magnitude and
        phase of the STFT.
        '''
        
        # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
        stft_dat = tf.signal.rfft(x)
        # calculating magnitude and phase from the complex signal
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        # returning magnitude and phase as list
        return [mag, phase]
    
    def stftLayer(self, x):
        '''
        Method for an STFT helper layer used with a Lambda layer. The layer
        calculates the STFT on the last dimension and returns the magnitude and
        phase of the STFT.
        '''
        
        # creating frames from the continuous waveform
        frames = tf.signal.frame(x, self.blockLen, self.block_shift)
        # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
        stft_dat = tf.signal.rfft(frames)
        # calculating magnitude and phase from the complex signal
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        # returning magnitude and phase as list
        return [mag, phase]
    
    def fftLayer(self, x):
        '''
        Method for an fft helper layer used with a Lambda layer. The layer
        calculates the rFFT on the last dimension and returns the magnitude and
        phase of the STFT.
        '''
        
        # expanding dimensions
        frame = tf.expand_dims(x, axis=1)
        # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
        stft_dat = tf.signal.rfft(frame)
        # calculating magnitude and phase from the complex signal
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        # returning magnitude and phase as list
        return [mag, phase]

 
        
    def ifftLayer(self, x):
        '''
        Method for an inverse FFT layer used with an Lambda layer. This layer
        calculates time domain frames from magnitude and phase information. 
        As input x a list with [mag,phase] is required.
        '''
        
        # calculating the complex representation
        s1_stft = (tf.cast(x[0], tf.complex64) * 
                    tf.exp( (1j * tf.cast(x[1], tf.complex64))))
        # returning the time domain frames
        return tf.signal.irfft(s1_stft)  
    
    
    def overlapAddLayer(self, x):
        '''
        Method for an overlap and add helper layer used with a Lambda layer.
        This layer reconstructs the waveform from a framed signal.
        '''

        # calculating and returning the reconstructed waveform
        return tf.signal.overlap_and_add(x, self.block_shift)
    
        

    def seperation_kernel(self, num_layer, mask_size, x, stateful=False):
        '''
        Method to create a separation kernel. 
        !! Important !!: Do not use this layer with a Lambda layer. If used with
        a Lambda layer the gradients are updated correctly.

        Inputs:
            num_layer       Number of LSTM layers
            mask_size       Output size of the mask and size of the Dense layer
        '''

        # creating num_layer number of LSTM layers
        for idx in range(num_layer):
            x = LSTM(self.numUnits, return_sequences=True, stateful=stateful)(x)
            # using dropout between the LSTM layer for regularization 
            if idx<(num_layer-1):
                x = Dropout(self.dropout)(x)
        # creating the mask with a Dense and an Activation layer
        mask = Dense(mask_size)(x)
        mask = Activation(self.activation)(mask)
        # returning the mask
        return mask
    
    def seperation_kernel_with_states(self, num_layer, mask_size, x, 
                                      in_states):
        '''
        Method to create a separation kernel, which returns the LSTM states. 
        !! Important !!: Do not use this layer with a Lambda layer. If used with
        a Lambda layer the gradients are updated correctly.

        Inputs:
            num_layer       Number of LSTM layers
            mask_size       Output size of the mask and size of the Dense layer
        '''
        
        states_h = []
        states_c = []
        # creating num_layer number of LSTM layers
        for idx in range(num_layer):
            in_state = [in_states[:,idx,:, 0], in_states[:,idx,:, 1]]
            x, h_state, c_state = LSTM(self.numUnits, return_sequences=True, 
                     unroll=True, return_state=True)(x, initial_state=in_state)
            # using dropout between the LSTM layer for regularization 
            if idx<(num_layer-1):
                x = Dropout(self.dropout)(x)
            states_h.append(h_state)
            states_c.append(c_state)
        # creating the mask with a Dense and an Activation layer
        mask = Dense(mask_size)(x)
        mask = Activation(self.activation)(mask)
        out_states_h = tf.reshape(tf.stack(states_h, axis=0), 
                                  [1,num_layer,self.numUnits])
        out_states_c = tf.reshape(tf.stack(states_c, axis=0), 
                                  [1,num_layer,self.numUnits])
        out_states = tf.stack([out_states_h, out_states_c], axis=-1)
        # returning the mask and states
        return mask, out_states

    def build_DTLN_model(self, norm_stft=False):
        '''
        Method to build and compile the DTLN model. The model takes time domain 
        batches of size (batchsize, len_in_samples) and returns enhanced clips 
        in the same dimensions. As optimizer for the Training process the Adam
        optimizer with a gradient norm clipping of 3 is used. 
        The model contains two separation cores. The first has an STFT signal 
        transformation and the second a learned transformation based on 1D-Conv 
        layer. 
        '''
        # input layer for time signal
        
        farend_dat = Input(batch_shape=(2, 64000))
        nearend_dat = Input(batch_shape=(2, 64000))
        
        # calculate Segment
        farend_frames = Lambda(self.segment)(farend_dat)
        nearend_frames = Lambda(self.segment)(nearend_dat)
        print('farend_frames:',farend_frames.shape)
        print('nearend_frames:',nearend_frames.shape)
        
        # calculate STFT
        farend_mag,farend_angle = Lambda(self.tftLayer)(farend_frames)
        nearend_mag,nearend_angle = Lambda(self.tftLayer)(nearend_frames)
        print('farend_mag:',farend_mag.shape)
        print('nearend_mag:',nearend_mag.shape)
        
        # normalizing log magnitude stfts to get more robust against level variations

        if norm_stft:
            farend_mag_norm = InstantLayerNormalization()(tf.math.log(farend_mag + 1e-7))
        else:
            # behaviour like in the paper
            farend_mag_norm = farend_mag
        
        if norm_stft:
            nearend_mag_norm = InstantLayerNormalization()(tf.math.log(nearend_mag + 1e-7))
        else:
            # behaviour like in the paper
            nearend = nearend_mag
        
        print('farend_mag_norm:',farend_mag_norm.shape)
        print('nearend_mag_norm:',nearend_mag_norm.shape)
        
        spectra = Concatenate(axis=-1)([nearend_mag_norm,farend_mag_norm])
        print('spectra:',spectra.shape)
        
        # predicting mask with separation kernel  
        a
        print('mask_1:',mask_1.shape)
        
        # multiply mask with magnitude
        estimated_mag = Multiply()([nearend_mag, mask_1])
        print('estimated_mag:',estimated_mag.shape)
        
        # transform frames back to time domain
        estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,nearend_angle])
        print('estimated_frames_1:',estimated_frames_1.shape)
        
        # encode time domain frames to feature domain
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
        encoded_frames_ = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(farend_frames)
        print('encoded_frames:',encoded_frames.shape)
        print('encoded_frames_:',encoded_frames_.shape)
        
        # normalize the input to the separation kernel
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        encoded_frames_norm_ = InstantLayerNormalization()(encoded_frames_)
        print('encoded_frames_norm:',encoded_frames_norm.shape)
        print('encoded_frames_norm_:',encoded_frames_norm_.shape)
        
        feature = Concatenate(axis=-1)([encoded_frames_norm,encoded_frames_norm_])
        print('feature:',feature.shape)
        
        # predict mask based on the normalized feature frames
        mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, feature)
        print('mask_2:',mask_2.shape)
        
        # multiply encoded frames with the mask
        estimated = Multiply()([encoded_frames, mask_2]) 
        print('estimated:',estimated.shape)
        
        # decode the frames back to time domain
        decoded_frames = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
        print('decoded_frames:',decoded_frames.shape)
        
        # create waveform with overlap and add procedure
        estimated_sig = Lambda(self.overlapAddLayer)(decoded_frames)
        print('estimated_sig:',estimated_sig.shape)

        
        # create the model
        self.model = Model(inputs=[farend_dat,nearend_dat], outputs=estimated_sig)
        # show the model summary
        print(self.model.summary())
    
    def build_DTLN_model_stateful(self, norm_stft=False):
        '''
        Method to build stateful DTLN model for real time processing. The model 
        takes one time domain frame of size (1, blockLen) and one enhanced frame. 
         
        '''
        
        # input layer for time signal
        farend_dat = Input(batch_shape=(1, self.blockLen))
        nearend_dat = Input(batch_shape=(1, self.blockLen))
        
        # calculate Segment
        farend_frames = Lambda(self.segment)(farend_dat)
        nearend_frames = Lambda(self.segment)(nearend_dat)
        
        # calculate STFT
        farend_mag,farend_angle = Lambda(self.tftLayer)(farend_frames)
        nearend_mag,nearend_angle = Lambda(self.tftLayer)(nearend_frames)
        
        # normalizing log magnitude stfts to get more robust against level variations

        if norm_stft:
            farend_mag_norm = InstantLayerNormalization()(tf.math.log(farend_mag + 1e-7))
        else:
            # behaviour like in the paper
            farend_mag_norm = farend_mag
        
        if norm_stft:
            nearend_mag_norm = InstantLayerNormalization()(tf.math.log(nearend_mag + 1e-7))
        else:
            # behaviour like in the paper
            nearend_mag_norm = nearend_mag
    
        
        spectra = Concatenate(axis=-1)([nearend_mag_norm,farend_mag_norm])
        
        # predicting mask with separation kernel  
        mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), spectra)
        
        # multiply mask with magnitude
        estimated_mag = Multiply()([nearend_mag, mask_1])
        
        # transform frames back to time domain
        estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,nearend_angle])
        
        # encode time domain frames to feature domain
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
        encoded_frames_ = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(farend_frames)
        
        # normalize the input to the separation kernel
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        encoded_frames_norm_ = InstantLayerNormalization()(encoded_frames_)
        
        feature = Concatenate(axis=-1)([encoded_frames_norm,encoded_frames_norm_])
        
        # predict mask based on the normalized feature frames
        mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, feature)
        
        # multiply encoded frames with the mask
        estimated = Multiply()([encoded_frames, mask_2]) 
        
        # decode the frames back to time domain
        decoded_frames = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
        
        # create waveform with overlap and add procedure
        estimated_sig = Lambda(self.overlapAddLayer)(decoded_frames)

        
        # create the model
        self.model = Model(inputs=[farend_dat,nearend_dat], outputs=estimated_sig)
        # show the model summary
        print(self.model.summary())
    
    def create_tf_lite_model(self, weights_file, target_name, norm_stft,use_dynamic_range_quant=False):
        '''
        Method to create a tf lite model folder from a weights file. 
        The conversion creates two models, one for each separation core. 
        Tf lite does not support complex numbers yet. Some processing must be 
        done outside the model.
        For further information and how real time processing can be 
        implemented see "real_time_processing_tf_lite.py".
        
        The conversion only works with TF 2.3.

        '''
        # check for type
        if  norm_stft:
            num_elements_first_core = 4 + self.numLayer * 3 + 2
        else:
            num_elements_first_core = self.numLayer * 3 + 2
        # build model    
        self.build_DTLN_model_stateful(norm_stft=norm_stft)
        # load weights
        self.model.load_weights(weights_file)
        
        #### Model 1 ##########################
        farend_mag = Input(batch_shape=(1, 1, (self.blockLen//2+1)))
        nearend_mag = Input(batch_shape=(1, 1, (self.blockLen//2+1)))
        states_in_1 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))
        # normalizing log magnitude stfts to get more robust against level variations
        if norm_stft:
            farend_mag_norm = InstantLayerNormalization()(tf.math.log(farend_mag + 1e-7))
        else:
            # behaviour like in the paper
            farend_mag_norm = farend_mag
        
        if norm_stft:
            nearend_mag_norm = InstantLayerNormalization()(tf.math.log(nearend_mag + 1e-7))
        else:
            # behaviour like in the paper
            nearend_mag_norm = nearend_mag
        
        spectra = Concatenate(axis=-1)([nearend_mag_norm,farend_mag_norm])
        # predicting mask with separation kernel
        mask_1, states_out_1 = self.seperation_kernel_with_states(self.numLayer, (self.blockLen//2+1), spectra, states_in_1)
        
        model_1 = Model(inputs=[farend_mag, nearend_mag, states_in_1], outputs=[mask_1, states_out_1])
        
        #### Model 2 ###########################
        
        estimated_frames_1 = Input(batch_shape=(1, 1, (self.blockLen)))
        farend_mag_seg = Input(batch_shape=(1, 1, (self.blockLen)))
        states_in_2 = Input(batch_shape=(1, self.numLayer, self.numUnits, 2))
        
        # encode time domain frames to feature domain
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
        encoded_frames_ = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(farend_mag_seg)
        
        # normalize the input to the separation kernel
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        encoded_frames_norm_ = InstantLayerNormalization()(encoded_frames_)
        
        feature = Concatenate(axis=-1)([encoded_frames_norm,encoded_frames_norm_])
        # predict mask based on the normalized feature frames
        mask_2, states_out_2 = self.seperation_kernel_with_states(self.numLayer, 
                                                    self.encoder_size, 
                                                    feature, 
                                                    states_in_2)
        # multiply encoded frames with the mask
        estimated = Multiply()([encoded_frames, mask_2]) 
        # decode the frames back to time domain
        decoded_frame = Conv1D(self.blockLen, 1, padding='causal',
                               use_bias=False)(estimated)
        
        model_2 = Model(inputs=[estimated_frames_1, farend_mag_seg, states_in_2], 
                        outputs=[decoded_frame, states_out_2])
        
        # set weights to submodels
        weights = self.model.get_weights()
        print(type(weights))
        print(len(weights))
        model_1.set_weights(weights[:num_elements_first_core])
        model_2.set_weights(weights[num_elements_first_core:])
        # convert first model
        converter = tf.lite.TFLiteConverter.from_keras_model(model_1)
        if use_dynamic_range_quant:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        with tf.io.gfile.GFile(target_name + '_1.tflite', 'wb') as f:
              f.write(tflite_model)
        # convert second model    
        converter = tf.lite.TFLiteConverter.from_keras_model(model_2)
        if use_dynamic_range_quant:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        with tf.io.gfile.GFile(target_name + '_2.tflite', 'wb') as f:
              f.write(tflite_model)
              
        print('TF lite conversion complete!')

In [10]:
modelTrainer = DTLN_model()

In [11]:
target_folder='weights/tflite_weight'

In [12]:
modelTrainer.create_tf_lite_model(weights_file, 
                                  target_folder,
                                  norm_stft=True,
                                  use_dynamic_range_quant=bool(False))

Model: "functional_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_10 (InputLayer)           [(1, 512)]           0                                            
__________________________________________________________________________________________________
input_9 (InputLayer)            [(1, 512)]           0                                            
__________________________________________________________________________________________________
lambda_7 (Lambda)               (1, 1, 512)          0           input_10[0][0]                   
__________________________________________________________________________________________________
lambda_6 (Lambda)               (1, 1, 512)          0           input_9[0][0]                    
_______________________________________________________________________________________

<class 'list'>
27
TF lite conversion complete!


In [9]:
weights = modelTrainer.model.get_weights()

In [10]:
model= modelTrainer.model

In [11]:
for i in weights:
    print(i.shape)

(257,)
(257,)
(257,)
(257,)
(514, 512)
(128, 512)
(512,)
(128, 512)
(128, 512)
(512,)
(128, 257)
(257,)
(1, 512, 256)
(1, 512, 256)
(256,)
(256,)
(256,)
(256,)
(512, 512)
(128, 512)
(512,)
(128, 512)
(128, 512)
(512,)
(128, 256)
(256,)
(1, 256, 512)
