In [1]:
# changed model input from dynamic to fixed 15600 samples using this: https://medium.com/@antonyharfield/converting-the-yamnet-audio-detection-model-for-tensorflow-lite-inference-43d049bd357c
# cloned aheartman's repo instead of the tensorflow model repo like he recommends in the Medium article

# applied transfer learning following this: https://www.tensorflow.org/tutorials/audio/transfer_learning_audio

# quantized using this: https://www.tensorflow.org/lite/performance/post_training_quantization#integer_only

# compiled the TFLite model for the Edge TPU using this: https://coral.ai/docs/edgetpu/compiler/#usage

In [1]:
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_io as tfio
import tensorflow_hub as hub

In [2]:
# Two YAMNet copies are loaded side by side:
#   * yamnet_model  - the fixed-input model created in aheartman's repo (local SavedModel)
#   * yamnet_model2 - the published model fetched from TensorFlow Hub
FIXED_INPUT_YAMNET_DIR = "models/yamnet/tf"
TFHUB_YAMNET_URL = "https://tfhub.dev/google/yamnet/1"

yamnet_model = tf.keras.models.load_model(FIXED_INPUT_YAMNET_DIR)
yamnet_model2 = hub.load(TFHUB_YAMNET_URL)



In [9]:
# Download the sample clip (or reuse the cached copy) under ./test_data.
testing_wav_file_name = tf.keras.utils.get_file(
    fname='miaow_16k.wav',
    origin='https://storage.googleapis.com/audioset/miaow_16k.wav',
    cache_dir='./',
    cache_subdir='test_data',
)

# Show where the file ended up on disk.
print(testing_wav_file_name)

./test_data\miaow_16k.wav


In [10]:
# Utility functions for loading audio files and making sure the sample rate is correct.

@tf.function
def load_wav_16k_mono(filename):
    """Read a WAV file and return it as a single-channel float waveform at 16 kHz.

    Args:
        filename: Path of the WAV file to read.

    Returns:
        The decoded audio with the channel axis removed, resampled to 16000 Hz.
    """
    raw_bytes = tf.io.read_file(filename)
    audio, source_rate = tf.audio.decode_wav(raw_bytes, desired_channels=1)
    # Only one channel was requested above, so the trailing channel axis can go.
    mono = tf.squeeze(audio, axis=-1)
    # The resampler is given the source rate as int64, hence the cast.
    return tfio.audio.resample(
        mono,
        rate_in=tf.cast(source_rate, dtype=tf.int64),
        rate_out=16000,
    )

In [11]:
# The class map CSV ships inside the fixed-input model's assets folder.
class_map_path = "models/yamnet/tf/assets/yamnet_class_map.csv"
class_map = pd.read_csv(class_map_path)
class_names = class_map['display_name'].tolist()

# Preview the first 20 display names, one per line.
for display_name in class_names[:20]:
    print(display_name)
print('...')

Speech
Child speech, kid speaking
Conversation
Narration, monologue
Babbling
Speech synthesizer
Shout
Bellow
Whoop
Yell
Children shouting
Screaming
Whispering
Laughter
Baby laughter
Giggle
Snicker
Belly laugh
Chuckle, chortle
Crying, sobbing
...


In [12]:
testing_wav_data = load_wav_16k_mono(testing_wav_file_name)

# Look at the first 15600 samples, first as a flat vector and then
# reshaped to (1, 15600).
first_window = testing_wav_data[:15600]
print(first_window)
print(tf.reshape(first_window, (1, 15600)))

tf.Tensor(
[-1.8715975e-08  6.4694795e-08 -1.3643448e-07 ...  2.4044040e-01
  1.7158213e-01  4.3755710e-02], shape=(15600,), dtype=float32)
tf.Tensor(
[[-1.8715975e-08  6.4694795e-08 -1.3643448e-07 ...  2.4044040e-01
   1.7158213e-01  4.3755710e-02]], shape=(1, 15600), dtype=float32)


In [15]:
# Call on the fixed-input model, kept for reference:
# scores, other = yamnet_model(tf.reshape(testing_wav_data[15600:31200],(1,15600)))

# Run the TF Hub model on the whole waveform.
scores, embeddings, spectrogram = yamnet_model2(testing_wav_data)

# Average the scores along axis 0 (presumably one row per analysis frame),
# then pick the index of the highest averaged score.
class_scores = tf.reduce_mean(scores, axis=0)
top_class = tf.argmax(class_scores)
inferred_class = class_names[top_class]

print(
    f'The main sound is: {inferred_class}',
    f'The embeddings shape: {embeddings.shape}',
    sep='\n',
)

The main sound is: Animal
The embeddings shape: (13, 1024)


In [33]:
# Indices of the five highest averaged scores, best first.
ascending_order = np.argsort(class_scores)
top5 = np.flip(ascending_order)[:5]
print(top5)

# Print the display name behind each of those indices.
for i in top5:
    inferred_class = class_names[i]
    print(f'The main sound is: {inferred_class}')

[67 93 81 94 96]
The main sound is: Animal
The main sound is: Fowl
The main sound is: Livestock, farm animals, working animals
The main sound is: Chicken, rooster
The main sound is: Crowing, cock-a-doodle-doo


In [35]:
# `other` is not assigned by any live cell above, so this cell raised NameError
# on a fresh kernel (Restart & Run All). Recreate it here by calling the
# fixed-input model on one 15600-sample window, mirroring the commented-out
# call in the inference cell.
# NOTE(review): the window originally used to produce the recorded output is
# not known; 15600:31200 is taken from that commented-out call - confirm.
_fixed_scores, other = yamnet_model(
    tf.reshape(testing_wav_data[15600:31200], (1, 15600))
)
print(other)

tf.Tensor(
[[-2.688576   -3.395819   -3.3842807  ... -1.9334265  -1.6974654
  -1.9440112 ]
 [-1.7409291  -1.591264   -1.7010896  ... -2.055423   -1.6811252
  -1.8205918 ]
 [-1.6530831  -2.106609   -3.3813334  ... -2.2581878  -2.081408
  -1.6031291 ]
 ...
 [-2.439605   -1.66035    -1.6649653  ... -1.3464409  -1.5554621
  -1.2106203 ]
 [-2.2034495  -3.5700097  -2.8465383  ... -1.5448197  -1.2740288
  -1.0537565 ]
 [-2.8868978  -2.0260997  -1.8647212  ... -1.101402   -0.80602986
  -1.2401828 ]], shape=(96, 64), dtype=float32)


In [8]:
# Inspect which attributes the hub model's 'serving_default' signature exposes.
serving_signature = yamnet_model2.signatures['serving_default']
print(dir(serving_signature))

['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_arg_keywords', '_as_name_attr_list', '_attrs', '_build_call_outputs', '_call_flat', '_call_impl', '_call_with_flat_signature', '_call_with_structured_signature', '_captured_closures', '_captured_inputs', '_delayed_rewrite_functions', '_experimental_with_cancellation_manager', '_filtered_call', '_first_order_tape_functions', '_flat_signature_summary', '_func_graph', '_function_spec', '_garbage_collector', '_get_gradient_function', '_higher_order_tape_functions', '_inference_function', '_initialize_function_spec', '_ndarray_singleton', '_ndarrays_list', '_num_positional_args', '_output_shapes', '_pre_initialized_function_spec', '_selec