In [40]:
import os
import time

import numpy as np
import soundfile as sf
import tensorflow.lite as tflite

In [41]:
# Frame geometry in samples: block_len sizes the rolling buffers that are
# transformed with rfft, block_shift is how far they advance per iteration.
block_len, block_shift = 512, 128

In [42]:
def _load_interpreter(model_path):
    """Create a TFLite interpreter for ``model_path`` and allocate its tensors."""
    interpreter = tflite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return interpreter


# One interpreter per model file; the processing loop invokes them in this order.
interpreter_1 = _load_interpreter('./weights/model_1.tflite')
interpreter_2 = _load_interpreter('./weights/model_2.tflite')

In [43]:
# Tensor descriptors of both interpreters; the later cells read the 'index'
# and 'shape' entries from these lists.
input_details_1, output_details_1 = (
    interpreter_1.get_input_details(),
    interpreter_1.get_output_details(),
)
input_details_2, output_details_2 = (
    interpreter_2.get_input_details(),
    interpreter_2.get_output_details(),
)

In [44]:
# Zero-filled float32 state tensors shaped like each model's input at index 2.
# The processing loop feeds them in and replaces them with the model's
# returned states on every block.
states_1 = np.zeros(input_details_1[2]['shape'], dtype=np.float32)
states_2 = np.zeros(input_details_2[2]['shape'], dtype=np.float32)

In [45]:
# Load the far-end and near-end recordings. The far-end rate gets its own name
# so it is not silently overwritten by the near-end rate; `fs` keeps holding
# the near-end rate, which the following cells use.
farend, farend_fs = sf.read('farend_speech_fileid_0.wav')
nearend, fs = sf.read('nearend_mic_fileid_0.wav')

# Both signals are sliced with the same sample indices in the processing loop,
# so they have to share one sampling rate.
if farend_fs != fs:
    raise ValueError(
        'Far-end and near-end files have different sampling rates: '
        '{} vs {}.'.format(farend_fs, fs)
    )

# The loop copies slices of these arrays into 1-D buffers, so multi-channel
# (2-D) data cannot be processed.
if farend.ndim != 1 or nearend.ndim != 1:
    raise ValueError('Only single channel files are allowed.')

# The number of blocks is derived from the near-end length alone; trim both
# signals to their common length so every block slice of the far-end signal
# is complete instead of failing (or broadcasting a short remainder).
common_len = min(len(farend), len(nearend))
farend = farend[:common_len]
nearend = nearend[:common_len]

In [46]:
# Stop early unless the sampling rate is 16 kHz. `fs` is the rate returned by
# the most recent sf.read call.
if not (fs == 16000):
    raise ValueError('This model only supports 16k sampling rate.')

In [47]:
# Output signal, one float64 sample per near-end sample; filled block by block
# in the processing loop.
out_file = np.zeros(len(nearend), dtype=np.float64)

In [48]:
# Three independent rolling float32 buffers of block_len samples each:
# far-end input, near-end input and the overlap-add output accumulator.
farend_in_buffer, nearend_in_buffer, out_buffer = (
    np.zeros(block_len, dtype=np.float32) for _ in range(3)
)

In [49]:
# Number of block_shift-sized steps the loop will take over the near-end signal.
num_blocks = (len(nearend) - block_len + block_shift) // block_shift
# Per-block processing durations in seconds, appended by the loop.
time_array = list()

In [50]:
# Run both models over the signals one block_shift-sized step at a time and
# overlap-add the second model's output into out_file.
for idx in range(num_blocks):
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short intervals; time.time() can jump and may be too coarse for
    # sub-millisecond blocks.
    start_time = time.perf_counter()
    block_start = idx * block_shift
    block_end = block_start + block_shift
    # shift values and write to buffer
    farend_in_buffer[:-block_shift] = farend_in_buffer[block_shift:]
    farend_in_buffer[-block_shift:] = farend[block_start:block_end]
    nearend_in_buffer[:-block_shift] = nearend_in_buffer[block_shift:]
    nearend_in_buffer[-block_shift:] = nearend[block_start:block_end]
    # calculate fft of input block (the buffers are already 1-D)
    farend_in_block_fft = np.fft.rfft(farend_in_buffer).astype("complex64")
    nearend_in_block_fft = np.fft.rfft(nearend_in_buffer).astype("complex64")
    # magnitudes reshaped to the (1, 1, bins) layout passed to the first model
    farend_in_mag = np.reshape(np.abs(farend_in_block_fft), (1, 1, -1)).astype('float32')
    nearend_in_mag = np.reshape(np.abs(nearend_in_block_fft), (1, 1, -1)).astype('float32')
    # set tensors to the first model
    interpreter_1.set_tensor(input_details_1[0]['index'], farend_in_mag)
    interpreter_1.set_tensor(input_details_1[1]['index'], nearend_in_mag)
    interpreter_1.set_tensor(input_details_1[2]['index'], states_1)
    # run calculation
    interpreter_1.invoke()
    # first model returns a mask and its updated states (fed back next block)
    out_mask = interpreter_1.get_tensor(output_details_1[0]['index'])
    states_1 = interpreter_1.get_tensor(output_details_1[1]['index'])
    # apply the mask to the complex near-end spectrum and go back to time domain
    estimated_complex = nearend_in_block_fft * out_mask
    estimated_block = np.fft.irfft(estimated_complex)
    # reshape the time domain block
    estimated_block = np.reshape(estimated_block, (1, 1, -1)).astype('float32')
    # set tensors to the second model: masked estimate, far-end buffer, states
    interpreter_2.set_tensor(input_details_2[0]['index'], estimated_block)
    interpreter_2.set_tensor(
        input_details_2[1]['index'],
        np.reshape(farend_in_buffer, (1, 1, -1)).astype('float32'),
    )
    interpreter_2.set_tensor(input_details_2[2]['index'], states_2)
    # run calculation
    interpreter_2.invoke()
    # second model returns the output block and its updated states
    out_block = interpreter_2.get_tensor(output_details_2[0]['index'])
    states_2 = interpreter_2.get_tensor(output_details_2[1]['index'])
    # overlap-add: shift the accumulator, clear the freed tail, add the new block
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = 0.0
    out_buffer += np.squeeze(out_block)
    # the oldest block_shift samples of the accumulator are final; store them
    out_file[block_start:block_end] = out_buffer[:block_shift]
    time_array.append(time.perf_counter() - start_time)

In [51]:
# Slice of out_file that skips the first (block_len - block_shift) samples and
# spans at most len(nearend) samples.
# NOTE(review): no later cell visible here reads predicted_speech; the write
# cell saves out_file instead - confirm which one is meant to be exported.
predicted_speech = out_file[
    slice(
        block_len - block_shift,
        block_len - block_shift + len(nearend),
    )
]

In [52]:
# write to .wav file
# Create the target directory first so a missing 'samples/' folder does not
# discard the finished run.
# NOTE(review): this saves the full out_file buffer, not the trimmed
# predicted_speech computed in the preceding cell - confirm that is intended.
os.makedirs('samples', exist_ok=True)
sf.write('samples/tfwav.wav', out_file, fs)
print('Processing Time [ms]:')
# Mean per-block duration in milliseconds; time_array is empty when the input
# is too short for a single block, and np.stack would raise on an empty list.
if time_array:
    print(np.mean(np.stack(time_array))*1000)
else:
    print(float('nan'))
print('Processing finished.')

Processing Time [ms]:
0.8687379846220797
Processing finished.
