In [3]:
import os, fnmatch
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout, \
    Lambda, Input, Multiply, Layer, Conv1D, Concatenate
from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, \
    EarlyStopping, ModelCheckpoint
import tensorflow as tf
import soundfile as sf
from wavinfo import WavInfoReader
from random import shuffle, seed
import numpy as np
from pathlib import Path
from scipy.io import wavfile
from utils import preprocess

In [5]:
class audio_generator():
    '''
    Class to create a Tensorflow dataset based on an iterator from a large scale
    audio dataset. This audio generator only supports single channel audio files.

    The near-end and far-end paths point to files loaded with ``np.load``. Each
    entry of those arrays is a string whose last two characters are a 1-based
    segment index and whose remaining characters are the path of an audio file;
    the generator reads that file and keeps only the indexed segment of
    ``len_of_samples`` samples.
    '''

    def __init__(self, path_to_nearend_signal, path_to_farend_signal, path_to_rirs, len_of_samples, fs, train_flag=False):
        '''
        Constructor of the audio generator class.
        Inputs:
            path_to_nearend_signal  file (loaded with np.load) listing near-end
                                    "<audio path><segment index>" entries
            path_to_farend_signal   file (loaded with np.load) listing far-end
                                    "<audio path><segment index>" entries
            path_to_rirs            file (loaded with np.load) listing paths of
                                    room impulse response audio files
            len_of_samples          length of audio snippets in samples
            fs                      sampling rate every audio file must have
            train_flag              flag to activate shuffling of the entries
        '''
        # set inputs to properties
        self.path_to_nearend_signal = path_to_nearend_signal
        self.path_to_farend_signal = path_to_farend_signal
        self.path_to_rirs = path_to_rirs

        self.len_of_samples = len_of_samples
        self.fs = fs
        self.train_flag = train_flag
        # count the number of samples in your data set (depending on your disk,
        #                                               this can take some time)
        #self.count_samples()
        # create iterable tf.data.Dataset object
        self.create_tf_data_obj()

    def count_samples(self):
        '''
        Method to list the data of the dataset and count the number of samples.

        Sets ``self.file_names`` (the .wav files found) and
        ``self.total_samples`` (sum over files of whole ``len_of_samples``
        snippets).

        NOTE(review): this lists ``path_to_nearend_signal`` as a directory of
        .wav files, whereas ``create_generator`` loads the same path with
        ``np.load``. Its call in ``__init__`` is commented out; reconcile the two
        interpretations before relying on this method.
        '''

        # list .wav files in directory
        self.file_names = fnmatch.filter(os.listdir(self.path_to_nearend_signal), '*.wav')
        # count the number of samples contained in the dataset
        self.total_samples = 0
        for file in self.file_names:
            info = WavInfoReader(os.path.join(self.path_to_nearend_signal, file))
            self.total_samples = self.total_samples + \
                int(np.fix(info.data.frame_count/self.len_of_samples))

    @staticmethod
    def _split_entry(entry):
        '''
        Split a dataset entry into (audio file path, 1-based segment index).

        The last two characters of ``entry`` are parsed as the integer segment
        index; everything before them is the audio file path.
        '''
        return entry[:-2], int(entry[-2:])

    @staticmethod
    def _extract_segment(signal, segment_index, len_of_samples):
        '''
        Return the ``segment_index``-th (1-based) block of ``len_of_samples``
        samples of ``signal``.

        The result is shorter than ``len_of_samples`` (possibly empty) when the
        requested block runs past the end of ``signal``.
        '''
        start = len_of_samples * (segment_index - 1)
        return signal[start:start + len_of_samples]

    def create_generator(self):
        '''
        Method to create the iterator.

        Pairs near-end and far-end entries index by index, reads the indexed
        segment of each audio file, draws two RIRs at random and passes
        everything to ``preprocess``.

        Yields:
            ({"input_1": <far-end input>, "input_2": <near-end input>}, target)
            where the three arrays are the outputs of ``preprocess`` cast to
            float32.
        Raises:
            ValueError  if a file's sampling rate differs from ``self.fs`` or a
                        near-end/far-end file is not single channel.
        '''
        near_speechs = np.load(self.path_to_nearend_signal)
        far_speechs = np.load(self.path_to_farend_signal)
        rirs = np.load(self.path_to_rirs)
        if self.train_flag:
            # Shuffle each list independently before pairing, so near/far
            # combinations change every time the generator is restarted and
            # every entry of the longer list can be drawn.
            shuffle(near_speechs)
            shuffle(far_speechs)
        # zip stops at the shorter of the two lists
        for nearend, farend in zip(near_speechs, far_speechs):
            nearend_path, nearend_time = self._split_entry(nearend)
            farend_path, farend_time = self._split_entry(farend)
            # read the audio files
            nearend_signal, fs_1 = sf.read(nearend_path)
            farend_signal, fs_2 = sf.read(farend_path)
            # check if the audio matches the specifications before using it
            if fs_1 != self.fs or fs_2 != self.fs:
                raise ValueError('Sampling rates do not match.')
            if nearend_signal.ndim != 1 or farend_signal.ndim != 1:
                raise ValueError('Too many audio channels. The DTLN audio_generator '
                                 'only supports single channel audio data.')
            # each signal is cut with its OWN segment index
            nearend_signal = self._extract_segment(
                nearend_signal, nearend_time, self.len_of_samples)
            farend_signal = self._extract_segment(
                farend_signal, farend_time, self.len_of_samples)

            # np.random.choice samples with replacement by default, so both
            # RIRs may be the same file
            selected_rirs = np.random.choice(rirs, 2)
            # NOTE(review): sf.read returns a (data, samplerate) tuple and the
            # whole tuple is handed to preprocess; confirm preprocess expects
            # that rather than only the sample array.
            nearend_rir = sf.read(selected_rirs[0])
            farend_rir = sf.read(selected_rirs[1])
            (input_nearend_signal, input_farend_signal,
             output_discarded_nearend_speech) = preprocess(
                nearend_signal, farend_signal, nearend_rir, farend_rir, self.fs)
            yield ({"input_1": input_farend_signal.astype('float32'),
                    "input_2": input_nearend_signal.astype('float32')},
                   output_discarded_nearend_speech.astype('float32'))

    def create_tf_data_obj(self):
        '''
        Method to create the tf.data.Dataset from ``create_generator``.

        Stores the dataset in ``self.tf_data_set``; its elements are
        ({"input_1": float32, "input_2": float32}, float32).
        '''

        # creating the tf.data.Dataset from the iterator
        self.tf_data_set = tf.data.Dataset.from_generator(
                        self.create_generator,
                        output_types=({"input_1": tf.float32, "input_2": tf.float32}, tf.float32),
                        args=None
                        )