In [9]:
import tensorflow as tf
import tensorflow_io as tfio
import json

In [4]:
# Download the sample clip (or reuse the copy already cached under
# ./datasets/test_data) and keep the local path for the cells below.
testing_wav_file_name = tf.keras.utils.get_file(
    fname='miaow_16k.wav',
    origin='https://storage.googleapis.com/audioset/miaow_16k.wav',
    cache_dir='./',
    cache_subdir='datasets/test_data',
)


In [5]:
@tf.function
def load_wav_16k_mono(filename):
    """Read a WAV file and return its audio as a mono float waveform at 16 kHz.

    Args:
        filename: Path of the WAV file to read.

    Returns:
        A 1-D float tensor holding the single-channel audio resampled to 16 kHz.
    """
    raw_bytes = tf.io.read_file(filename)
    # desired_channels=1 makes the decoder return exactly one channel.
    audio, source_rate = tf.audio.decode_wav(raw_bytes, desired_channels=1)
    # Drop the trailing channel axis so only the sample axis remains.
    mono_audio = tf.squeeze(audio, axis=-1)
    # The sample rate is cast to int64 before being handed to the resampler.
    source_rate = tf.cast(source_rate, dtype=tf.int64)
    return tfio.audio.resample(mono_audio, rate_in=source_rate, rate_out=16000)

@tf.function
def frame_16k_mono(filename, frame_length=15600, frame_step=15600):
    """Load a WAV file as 16 kHz mono audio and split it into fixed-size frames.

    Args:
        filename: Path of the WAV file to read.
        frame_length: Number of samples per frame. The default of 15600
            matches the input shape reported for the TFLite model loaded
            later in this notebook.
        frame_step: Number of samples between the starts of consecutive
            frames. With the default (equal to frame_length) the frames do
            not overlap.

    Returns:
        A 2-D tensor of shape (num_frames, frame_length). tf.signal.frame is
        called with its default pad_end=False, so trailing samples that do
        not fill a whole frame are dropped.
    """
    wav = load_wav_16k_mono(filename)
    return tf.signal.frame(wav, frame_length, frame_step)

In [6]:
# Split the downloaded test clip into fixed-size frames. For the unframed
# waveform, call load_wav_16k_mono(testing_wav_file_name) instead.
frames = frame_16k_mono(testing_wav_file_name)



In [17]:
# Create the TFLite interpreter for the YAMNet model file and allocate its tensors.
tflite_interpreter = tf.lite.Interpreter(model_path='models/yamnet/tfhub/cpu.tflite')
tflite_interpreter.allocate_tensors()

# Check input/output details.
input_details = tflite_interpreter.get_input_details()
output_details = tflite_interpreter.get_output_details()

print("== Input details ==")
print("name:", input_details[0]['name'])
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("name:", output_details[0]['name'])
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])

# Run prediction (optional). The input must have the model input's shape and
# dtype; frames[1] is the second frame of the test clip, so this assumes the
# clip produced at least two frames.
tflite_interpreter.set_tensor(input_details[0]['index'], frames[1])
tflite_interpreter.invoke()
output_array = tflite_interpreter.get_tensor(output_details[0]['index'])
print("\n")

# get_tensor_details() gives a list of dictionaries, one per tensor in the
# model. The keys seen in the output are: name, index, shape, shape_signature,
# dtype, quantization, quantization_parameters (scales, zero_points,
# quantized_dimension) and sparsity_parameters.
tensor_details = tflite_interpreter.get_tensor_details()

# The loop variable is deliberately not called `dict`: notebook variables
# persist across cells, so shadowing the builtin would break later dict(...) calls.
for tensor_detail in tensor_details:
    for key, value in tensor_detail.items():
        print(f'{key}: {value},')
    print("\n")

# See note below.


== Input details ==
name: waveform_binary
shape: [15600]
type: <class 'numpy.float32'>

== Output details ==
name: tower0/network/layer32/final_output
shape: [  1 521]
type: <class 'numpy.float32'>


name: waveform_binary,
index: 0,
shape: [15600],
shape_signature: [15600],
dtype: <class 'numpy.float32'>,
quantization: (0.0, 0),
quantization_parameters: {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0},
sparsity_parameters: {},


name: stft/frame/zeros_like,
index: 1,
shape: [1],
shape_signature: [1],
dtype: <class 'numpy.int32'>,
quantization: (0.0, 0),
quantization_parameters: {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0},
sparsity_parameters: {},


name: stft/frame/concat,
index: 2,
shape: [1],
shape_signature: [1],
dtype: <class 'numpy.int32'>,
quantization: (0.0, 0),
quantization_parameters: {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantiz