<a href="https://colab.research.google.com/github/shaja-asm/cry-detection/blob/main/tf_lite_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import numpy as np
import librosa
import librosa.display
# import matplotlib.pyplot as plt
# import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, LSTM
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.utils import Sequence
import datetime
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TensorBoard, EarlyStopping
from tensorflow.keras.optimizers import Adam
from scipy.ndimage import zoom
import ctypes
from kerastuner.tuners import RandomSearch

# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#     try:
#         for gpu in gpus:
#             tf.config.experimental.set_memory_growth(gpu, True)
#     except RuntimeError as e:
#         print(e)
# print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

In [2]:
# Root directory of the dataset; the spectrogram cache, the test folder and
# the TFLite C library are all looked up relative to it further down.
AUDIO_PATH = 'CryCorpusFinal'
# Folders holding the .wav recordings for the positive (cry, label 1) and
# negative (notcry, label 0) classes.
CRY_FOLDER = os.path.join(AUDIO_PATH, 'cry/augmented')
NOTCRY_FOLDER = os.path.join(AUDIO_PATH, 'notcry')
# Size every spectrogram is resized to before it is fed to the model.
IMG_SIZE = (64, 64)
# Number of samples per batch produced by the data generators.
BATCH_SIZE = 32
# Upper bound on training epochs passed to model.fit.
EPOCHS = 25
MODEL = 'cnn' # Choice: 'cnn' or 'lstm'

In [3]:
def load_audio_files(folder):
    """Return the paths of the ``.wav`` files located directly in ``folder``.

    Only names ending in lowercase ``.wav`` are kept; sub-directories are not
    searched. Paths are returned in the order ``os.listdir`` yields them,
    each joined with ``folder``.
    """
    return [
        os.path.join(folder, entry)
        for entry in os.listdir(folder)
        if entry.endswith('.wav')
    ]

def compute_spectrogram(y, sr, n_fft=2048, hop_length=512):
    """Return the dB-scaled STFT magnitude spectrogram of signal ``y``.

    Args:
        y: Audio time series passed to ``librosa.stft``.
        sr: Sample rate of ``y``. Accepted for call-site symmetry but not
            used by the computation.
        n_fft: FFT window size forwarded to ``librosa.stft``.
        hop_length: Hop between frames forwarded to ``librosa.stft``.

    Returns:
        The magnitude spectrogram converted with ``librosa.amplitude_to_db``
        using ``ref=np.max``.
    """
    magnitude = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))
    return librosa.amplitude_to_db(magnitude, ref=np.max)

def save_spectrogram_to_disk(D_dB, save_path):
    """Save the spectrogram array ``D_dB`` to ``save_path`` via ``np.save``.

    Missing parent directories of ``save_path`` are created first.

    Args:
        D_dB: Array to store.
        save_path: Destination file path; may be a bare filename, in which
            case the file is written to the current working directory.
    """
    parent_dir = os.path.dirname(save_path)
    # dirname() is '' for a bare filename, and os.makedirs('') raises, so only
    # create directories when there is actually a directory component.
    if parent_dir:
        # exist_ok=True avoids failing when the directory already exists or is
        # created concurrently between a check and the creation.
        os.makedirs(parent_dir, exist_ok=True)
    np.save(save_path, D_dB)


In [18]:
cry_files = load_audio_files(CRY_FOLDER)
notcry_files = load_audio_files(NOTCRY_FOLDER)

# Directory where the computed spectrograms are cached as .npy files.
spectrogram_dir = '{0}/spectrograms'.format(AUDIO_PATH)

# `data` collects the paths of the saved spectrograms (not the arrays), and
# `labels` the matching class ids: 1 for cry, 0 for notcry.
data = []
labels = []

# Both classes go through the same pipeline; only the filename prefix and the
# label differ, so they share one loop body.
for prefix, label, files in (('cry', 1, cry_files), ('notcry', 0, notcry_files)):
    for idx, file in enumerate(files):
        # Resample everything to 22050 Hz and peak-normalise before the STFT.
        y, sr = librosa.load(file, sr=22050)
        y = librosa.util.normalize(y)
        D_dB = compute_spectrogram(y, sr)
        save_path = os.path.join(spectrogram_dir, f'{prefix}_{idx}.npy')
        save_spectrogram_to_disk(D_dB, save_path)
        data.append(save_path)
        labels.append(label)

data = np.array(data)
labels = np.array(labels)



In [19]:
# Split data
# Hold out 20% of the (spectrogram path, label) pairs for validation; the
# fixed random_state keeps the split reproducible between runs.
# NOTE(review): the split is not stratified (no `stratify=labels`), so the
# class ratio of the two subsets may differ — confirm this is intended.
# NOTE(review): the cry files come from an 'augmented' folder; if that folder
# holds several variants of one recording, variants may land on both sides of
# the split — verify against how the augmented set was produced.
X_train, X_val, y_train, y_val = train_test_split(data, labels, test_size=0.2, random_state=42)

# Improved Data Generator
class OnTheFlyDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that loads saved spectrogram ``.npy`` files per batch.

    Each item is a tuple ``(X, y)``: ``X`` holds the spectrograms of one batch
    resized to ``img_size`` and ``y`` holds the matching labels.

    Args:
        file_paths: Indexable collection of ``.npy`` spectrogram paths.
        labels: Indexable collection of labels aligned with ``file_paths``.
        batch_size: Maximum number of samples per batch.
        img_size: ``(rows, cols)`` every spectrogram is resized to.
        shuffle: Shuffle the sample order on construction and every time
            ``on_epoch_end`` is called.
        augment: Apply random flips and additive noise to each sample.
        is_lstm: Drop the channel axis so batches are 3-D instead of 4-D.
        **kwargs: Forwarded unchanged to the ``Sequence`` base class.
    """

    def __init__(self, file_paths, labels, batch_size, img_size, shuffle=True, augment=False, is_lstm=False, **kwargs):
        # Initialise the Keras base class; skipping this triggers the
        # "super not called" warning seen during training.
        super().__init__(**kwargs)
        self.file_paths = file_paths
        self.labels = labels
        self.batch_size = batch_size
        self.img_size = img_size
        self.shuffle = shuffle
        self.augment = augment
        self.is_lstm = is_lstm
        self.indices = np.arange(len(self.file_paths))
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch. Round up so the trailing partial batch
        # is included; rounding down silently drops those samples, which makes
        # model.predict() return fewer rows than there are labels.
        return int(np.ceil(len(self.file_paths) / self.batch_size))

    def __getitem__(self, index):
        # Slice the (possibly shuffled) index array for this batch; the last
        # batch may contain fewer than batch_size entries.
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_file_paths = [self.file_paths[i] for i in batch_indices]
        batch_labels = [self.labels[i] for i in batch_indices]

        X, y = self.__data_generation(batch_file_paths, batch_labels)
        return X, y

    def on_epoch_end(self):
        # Shuffle indices at the end of each epoch
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __data_generation(self, batch_file_paths, batch_labels):
        """Load, resize and optionally augment the files of one batch."""
        # Create empty arrays for data and labels
        X = np.empty((len(batch_file_paths), *self.img_size, 1), dtype=np.float32)
        y = np.empty((len(batch_file_paths),), dtype=int)

        for i, file_path in enumerate(batch_file_paths):
            # Load data from file
            D_dB = np.load(file_path)
            D_dB = D_dB[..., np.newaxis]  # Add channel dimension

            # Resize both spectrogram axes to img_size; the channel axis keeps
            # a zoom factor of 1.
            zoom_factors = [self.img_size[0] / D_dB.shape[0], self.img_size[1] / D_dB.shape[1], 1]
            D_dB = zoom(D_dB, zoom_factors, order=3)  # Cubic interpolation

            # Augmentation: each transform is applied independently with
            # probability 0.5.
            if self.augment:
                if np.random.rand() > 0.5:
                    D_dB = np.flip(D_dB, axis=1)  # Flip left-right
                if np.random.rand() > 0.5:
                    D_dB = np.flip(D_dB, axis=0)  # Flip up-down
                if np.random.rand() > 0.5:
                    # Additive uniform noise in [-0.2, 0.2], drawn per element.
                    D_dB = D_dB + np.random.uniform(-0.2, 0.2, size=D_dB.shape)

            X[i,] = D_dB
            y[i] = batch_labels[i]

        if self.is_lstm:
            # Reshape to (batch_size, time_steps, features) for LSTM
            # NOTE(review): this is a plain reshape, not a transpose, so it
            # only drops the channel axis when img_size is square; for a
            # non-square img_size the elements would be re-interpreted —
            # confirm the intended layout before using non-square sizes.
            X = X.reshape(len(batch_file_paths), self.img_size[1], self.img_size[0])

        return X, y

# Training batches are shuffled and augmented; validation batches keep the
# original order and are left unmodified so results map back onto y_val.
# NOTE(review): `is_lstm` is left at its default (False) for both generators
# regardless of MODEL — if MODEL == 'lstm' is expected to receive 3-D batches,
# it presumably needs `is_lstm=True` here; confirm against the model builder.
train_generator = OnTheFlyDataGenerator(X_train, y_train, BATCH_SIZE, IMG_SIZE, shuffle=True, augment=True)
val_generator = OnTheFlyDataGenerator(X_val, y_val, BATCH_SIZE, IMG_SIZE, shuffle=False, augment=False)

# l2 regularization
# Shared L2 kernel regulariser (factor 0.001) referenced by the model layers.
l2_regularizer = tf.keras.regularizers.l2(0.001)

    # Third Conv Block
    model.add(Conv2D(hp.Int('filters_3', min_value=128, max_value=512, step=128), (3, 3), activation='relu', kernel_regularizer=l2_regularizer))
    model.add(BatchNormalization())
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(hp.Float('dropout_3', min_value=0.2, max_value=0.5, step=0.1)))

# --- Callbacks -------------------------------------------------------------
# TensorBoard logs go into a fresh, timestamped directory for every run.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    profile_batch='500,520',
)
# Keep only the weights of the epoch with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    'cry_detection_model.keras',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)
# Halve the learning rate after 5 epochs without val_loss improvement,
# never going below 1e-6.
lr_callback = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)
# Stop after 10 epochs without val_loss improvement and roll back to the
# best weights seen.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# Both classes are weighted equally.
class_weights = {0: 1.0, 1: 1.0}

# --- Compile and train -----------------------------------------------------
optimizer = Adam(learning_rate=1e-4)
model.compile(
    loss='binary_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    callbacks=[
        tensorboard_callback,
        checkpoint_callback,
        lr_callback,
        early_stopping_callback,
    ],
    class_weight=class_weights,
)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
2024-08-14 12:58:35.535648: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:103] Profiler session initializing.
2024-08-14 12:58:35.535723: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:118] Profiler session started.
2024-08-14 12:58:35.540533: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:130] Profiler session tear down.


Epoch 1/25


  self._warn_if_super_not_called()


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 532ms/step - accuracy: 0.5910 - loss: 1.4495 - val_accuracy: 0.6064 - val_loss: 0.9590 - learning_rate: 1.0000e-04
Epoch 2/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 616ms/step - accuracy: 0.8003 - loss: 0.8413 - val_accuracy: 0.6365 - val_loss: 0.9630 - learning_rate: 1.0000e-04
Epoch 3/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 537ms/step - accuracy: 0.8633 - loss: 0.6978 - val_accuracy: 0.8614 - val_loss: 0.6663 - learning_rate: 1.0000e-04
Epoch 4/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 510ms/step - accuracy: 0.9020 - loss: 0.6153 - val_accuracy: 0.9357 - val_loss: 0.5231 - learning_rate: 1.0000e-04
Epoch 5/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 524ms/step - accuracy: 0.9165 - loss: 0.5758 - val_accuracy: 0.8695 - val_loss: 0.7338 - learning_rate: 1.0000e-04
Epoch 6/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[

2024-08-14 13:03:14.225486: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:103] Profiler session initializing.
2024-08-14 13:03:14.225560: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:118] Profiler session started.


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 491ms/step - accuracy: 0.9490 - loss: 0.5127 - val_accuracy: 0.9478 - val_loss: 0.4930 - learning_rate: 1.0000e-04
Epoch 9/25
[1m15/63[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m26s[0m 550ms/step - accuracy: 0.9255 - loss: 0.5408

2024-08-14 13:03:32.903861: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:68] Profiler session collecting data.


[1m16/63[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m30s[0m 649ms/step - accuracy: 0.9256 - loss: 0.5410

2024-08-14 13:03:33.725551: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:130] Profiler session tear down.
2024-08-14 13:03:33.739643: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:147] Collecting XSpace to repository: logs/fit/20240814-125835/plugins/profile/2024_08_14_13_03_33/TEC-LAP-47.xplane.pb


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 616ms/step - accuracy: 0.9396 - loss: 0.5197 - val_accuracy: 0.9277 - val_loss: 0.5982 - learning_rate: 1.0000e-04
Epoch 10/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 605ms/step - accuracy: 0.9560 - loss: 0.4766 - val_accuracy: 0.9759 - val_loss: 0.4340 - learning_rate: 1.0000e-04
Epoch 11/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 595ms/step - accuracy: 0.9593 - loss: 0.4867 - val_accuracy: 0.9317 - val_loss: 0.5883 - learning_rate: 1.0000e-04
Epoch 12/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 588ms/step - accuracy: 0.9516 - loss: 0.4936 - val_accuracy: 0.9237 - val_loss: 0.6275 - learning_rate: 1.0000e-04
Epoch 13/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 610ms/step - accuracy: 0.9563 - loss: 0.4689 - val_accuracy: 0.9518 - val_loss: 0.4935 - learning_rate: 1.0000e-04
Epoch 14/25
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━

In [20]:
# val_generator is created with shuffle=False, so predictions come back in the
# same order as y_val.
y_pred = model.predict(val_generator)
# Threshold the (n, 1) sigmoid output at 0.5 and flatten it to 1-D labels so
# it has the same shape as y_val.
y_pred = (y_pred > 0.5).astype(int).ravel()
# If the generator yields fewer samples than y_val holds (e.g. because a
# trailing partial batch is not produced), score only the samples that were
# actually predicted instead of failing on mismatched lengths.
y_true = y_val[:len(y_pred)]
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

print(f'Accuracy: {acc}')
print(f'F1 Score: {f1}')

# NOTE(review): this writes to the same path the ModelCheckpoint callback
# uses, so the checkpointed file is overwritten with the model's current
# weights — confirm that is intended.
model.save('cry_detection_model.keras')

[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 309ms/step
Accuracy: 0.9819277108433735
F1 Score: 0.9833024118738405


In [21]:
import pathlib

# Export the trained Keras model twice: once as a plain TFLite model and once
# with the converter's default optimizations enabled. The same converter
# object is reused, so the statement order below matters.

# Create directory for TFLite models
tflite_models_dir = pathlib.Path("tflite_models")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

# Convert to TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Allow Select TF Ops for both CNN and LSTM models
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

# Disable experimental lowering of tensor list ops
# NOTE(review): this sets a private (underscore-prefixed) converter attribute,
# which may change between TensorFlow releases — verify on upgrade.
converter._experimental_lower_tensor_list_ops = False

# Convert the model
tflite_model = converter.convert()

# Save the model
tflite_model_file = tflite_models_dir / "cry_detection_model.tflite"
tflite_model_file.write_bytes(tflite_model)

# Apply optimizations and convert again
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# NOTE(review): despite the `fp16` variable names, no float16 target type is
# configured on the converter here, so this is presumably the default
# (dynamic-range) quantization, matching the "_quant" file name — confirm.
tflite_fp16_model = converter.convert()
tflite_model_fp16_file = tflite_models_dir / "cry_detection_model_quant.tflite"
tflite_model_fp16_file.write_bytes(tflite_fp16_model)

print("TFLite conversion successful!")



INFO:tensorflow:Assets written to: /tmp/tmpyeyequh6/assets


INFO:tensorflow:Assets written to: /tmp/tmpyeyequh6/assets


Saved artifact at '/tmp/tmpyeyequh6'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32, name='keras_tensor_17')
Output Type:
  TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)
Captures:
  139628529365968: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628247370560: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384618800: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628383380480: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628247370032: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384617920: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384981536: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384978192: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384977136: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384986816: TensorSpec(shape=(), dtype=tf.resource, name=None)
  1396283849861

W0000 00:00:1723621405.668818  109388 tf_tfl_flatbuffer_helpers.cc:392] Ignored output_format.
W0000 00:00:1723621405.670454  109388 tf_tfl_flatbuffer_helpers.cc:395] Ignored drop_control_dependency.
2024-08-14 13:13:25.673627: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpyeyequh6
2024-08-14 13:13:25.675883: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-08-14 13:13:25.675906: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpyeyequh6
2024-08-14 13:13:25.701262: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-08-14 13:13:25.807451: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpyeyequh6
2024-08-14 13:13:25.833167: I tensorflow/cc/saved_model/loader.cc:462] SavedModel load for tags { serve }; Status: success: OK. Took 159956 microseconds.


INFO:tensorflow:Assets written to: /tmp/tmp7vwy0olh/assets


INFO:tensorflow:Assets written to: /tmp/tmp7vwy0olh/assets


Saved artifact at '/tmp/tmp7vwy0olh'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32, name='keras_tensor_17')
Output Type:
  TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)
Captures:
  139628529365968: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628247370560: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384618800: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628383380480: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628247370032: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384617920: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384981536: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384978192: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384977136: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139628384986816: TensorSpec(shape=(), dtype=tf.resource, name=None)
  1396283849861

W0000 00:00:1723621407.306275  109388 tf_tfl_flatbuffer_helpers.cc:392] Ignored output_format.
W0000 00:00:1723621407.306327  109388 tf_tfl_flatbuffer_helpers.cc:395] Ignored drop_control_dependency.
2024-08-14 13:13:27.306551: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmp7vwy0olh
2024-08-14 13:13:27.308194: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-08-14 13:13:27.308221: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmp7vwy0olh
2024-08-14 13:13:27.321664: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-08-14 13:13:27.456026: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmp7vwy0olh
2024-08-14 13:13:27.477812: I tensorflow/cc/saved_model/loader.cc:462] SavedModel load for tags { serve }; Status: success: OK. Took 171267 microseconds.


696352

In [22]:
# Initialize the TFLite interpreter
# Loads the optimized model written by the conversion step; `interpreter`,
# `input_details` and `output_details` are module-level and are read by the
# predict() function defined below.
interpreter = tf.lite.Interpreter(model_path="tflite_models/cry_detection_model_quant.tflite")
# Tensors must be allocated before any set_tensor()/invoke() call.
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

def preprocess_audio(file_path, img_size, is_lstm=False, sr=22050):
    """Turn an audio file into a resized dB spectrogram ready for the model.

    Args:
        file_path: Path of the audio file to load.
        img_size: ``(rows, cols)`` the spectrogram is resized to.
        is_lstm: When False (default) a trailing channel axis is appended;
            when True the 2-D spectrogram is returned as is.
        sr: Sample rate the audio is resampled to on load. Defaults to 22050,
            the rate used when the training spectrograms were generated, so
            inference inputs match the training inputs. Pass ``None`` to keep
            the file's native rate.

    Returns:
        Array of shape ``img_size`` (``is_lstm=True``) or ``img_size + (1,)``.
    """
    y, sr = librosa.load(file_path, sr=sr)
    y = librosa.util.normalize(y)
    D = librosa.stft(y, n_fft=2048, hop_length=512)
    D_dB = librosa.amplitude_to_db(np.abs(D), ref=np.max)

    # Calculate zoom factors for resizing
    zoom_factors = [img_size[0] / D_dB.shape[0], img_size[1] / D_dB.shape[1]]
    D_dB_resized = zoom(D_dB, zoom_factors, order=3)  # Cubic interpolation

    # Add channel dimension for the CNN input; the LSTM input stays 2-D.
    if not is_lstm:
        D_dB_resized = D_dB_resized[..., np.newaxis]

    return D_dB_resized


def predict(file_path, img_size=IMG_SIZE, is_lstm=False):
    """Run the module-level TFLite ``interpreter`` on one audio file.

    Args:
        file_path: Path of the audio file to classify.
        img_size: Spectrogram size expected by the model.
        is_lstm: Forwarded to ``preprocess_audio`` to select the input layout.

    Returns:
        The raw output tensor of the model as a NumPy array.
    """
    input_data = preprocess_audio(file_path, img_size, is_lstm)
    # Prepend the batch axis and cast to the float32 dtype the model expects.
    input_data = np.expand_dims(input_data, axis=0).astype(np.float32)

    # Set the tensor to point to the input data to be inferred
    interpreter.set_tensor(input_details[0]['index'], input_data)

    # Run inference
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])

    return output_data


def process_folder(folder_path, img_size=IMG_SIZE, is_lstm=False):
    """Classify every ``.wav`` file in ``folder_path`` and score the results.

    The ground truth is derived from the file name: names containing
    ``_cry.wav`` count as 'Cry', everything else as 'Not Cry'.

    Args:
        folder_path: Directory containing the test ``.wav`` files.
        img_size: Spectrogram size expected by the model.
        is_lstm: Forwarded to ``predict`` to select the input layout.

    Returns:
        ``(results, accuracy)`` where ``results`` is a list of
        ``(file_name, predicted_label)`` tuples and ``accuracy`` is the
        percentage of correct predictions (0 when no ``.wav`` file is found).
    """
    correct_predictions = 0
    total_files = 0
    results = []

    for file_name in os.listdir(folder_path):
        if not file_name.endswith('.wav'):
            continue

        file_path = os.path.join(folder_path, file_name)
        prediction = predict(file_path, img_size, is_lstm)
        # The model emits a single score per file; take it as a plain float
        # instead of relying on the truthiness of a one-element array.
        score = float(np.ravel(prediction)[0])
        prediction_label = 'Cry' if score > 0.5 else 'Not Cry'
        results.append((file_name, prediction_label))
        ground_truth = 'Cry' if '_cry.wav' in file_name else 'Not Cry'

        if prediction_label == ground_truth:
            correct_predictions += 1

        total_files += 1

    accuracy = (correct_predictions / total_files) * 100 if total_files > 0 else 0

    return results, accuracy

# Evaluate the TFLite model on the held-out test recordings stored under
# <AUDIO_PATH>/Test and report the per-file predictions plus overall accuracy.
folder_path = '{0}/Test'.format(AUDIO_PATH)
predictions, accuracy = process_folder(folder_path)

for file_name, prediction_label in predictions:
    print(f"File: {file_name}, Prediction: {prediction_label}")

# `accuracy` is already expressed as a percentage by process_folder.
print(f"Prediction Accuracy: {accuracy:.2f}%")


File: P19_612_notcry.wav, Prediction: Not Cry
File: P26_829_cry.wav, Prediction: Not Cry
File: P29_2405_cry.wav, Prediction: Cry
File: P29_62_cry.wav, Prediction: Cry
File: P26_7_cry.wav, Prediction: Cry
File: P36_14_notcry.wav, Prediction: Not Cry
File: P29_35_cry.wav, Prediction: Cry
File: P29_1714_cry.wav, Prediction: Cry
File: P29_724_cry.wav, Prediction: Cry
File: P26_9_cry.wav, Prediction: Cry
File: P29_348_cry.wav, Prediction: Not Cry
File: P20_388_cry.wav, Prediction: Not Cry
File: P26_824_cry.wav, Prediction: Cry
File: P29_773_cry.wav, Prediction: Cry
File: P29_1564_cry.wav, Prediction: Not Cry
File: P19_607_notcry.wav, Prediction: Not Cry
File: P29_1873_cry.wav, Prediction: Not Cry
File: P20_895_cry.wav, Prediction: Not Cry
File: P20_802_cry.wav, Prediction: Cry
File: P29_2090_cry.wav, Prediction: Cry
File: P20_919_cry.wav, Prediction: Not Cry
File: P17_41_cry.wav, Prediction: Not Cry
File: P29_1452_cry.wav, Prediction: Cry
File: P36_52_notcry.wav, Prediction: Not Cry
File: P

In [23]:
# Load the TensorFlow Lite C library and build an interpreter through ctypes.
# NOTE: `model` and `interpreter` are rebound here to C handles and therefore
# shadow the Keras model / tf.lite interpreter created in earlier cells.
lib = ctypes.cdll.LoadLibrary('{0}/libtensorflowlite_c.so'.format(AUDIO_PATH))

# Define types for the C API functions
# Pointer-returning functions need an explicit restype; ctypes would otherwise
# treat the result as a C int and truncate 64-bit addresses.
lib.TfLiteModelCreate.restype = ctypes.POINTER(ctypes.c_void_p)
lib.TfLiteInterpreterCreate.restype = ctypes.POINTER(ctypes.c_void_p)
lib.TfLiteInterpreterOptionsCreate.restype = ctypes.POINTER(ctypes.c_void_p)
lib.TfLiteInterpreterOptionsSetNumThreads.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int]
lib.TfLiteInterpreterOptionsDelete.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
lib.TfLiteInterpreterDelete.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
lib.TfLiteModelDelete.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
lib.TfLiteInterpreterGetInputTensor.restype = ctypes.POINTER(ctypes.c_void_p)
lib.TfLiteInterpreterGetOutputTensor.restype = ctypes.POINTER(ctypes.c_void_p)

model_path = b"tflite_models/cry_detection_model_quant.tflite"
with open(model_path, 'rb') as f:
    model_data = f.read()

# `model_data` stays referenced at module level, which keeps the buffer handed
# to the C library alive for as long as the model handle is in use.
model = lib.TfLiteModelCreate(ctypes.c_char_p(model_data), ctypes.c_size_t(len(model_data)))
# A NULL handle is falsy in ctypes; fail here rather than crash later inside
# the native library.
if not model:
    raise RuntimeError("TfLiteModelCreate failed for {0!r}".format(model_path))

# Create interpreter options and set number of threads
options = lib.TfLiteInterpreterOptionsCreate()
lib.TfLiteInterpreterOptionsSetNumThreads(options, 2)

# Create the interpreter with the custom options
interpreter = lib.TfLiteInterpreterCreate(model, options)
if not interpreter:
    raise RuntimeError("TfLiteInterpreterCreate returned NULL")

# Allocate tensors
status = lib.TfLiteInterpreterAllocateTensors(interpreter)
# The C API reports success as status 0 (kTfLiteOk).
if status != 0:
    raise RuntimeError("TfLiteInterpreterAllocateTensors failed with status {0}".format(status))

# Get input and output tensor details
input_tensor = lib.TfLiteInterpreterGetInputTensor(interpreter, 0)
output_tensor = lib.TfLiteInterpreterGetOutputTensor(interpreter, 0)

# def preprocess_audio(file_path, img_size):
#     y, sr = librosa.load(file_path, sr=None)
#     y = librosa.util.normalize(y)
#     D = librosa.stft(y, n_fft=2048, hop_length=512)
#     D_dB = librosa.amplitude_to_db(np.abs(D), ref=np.max)

#     # Rescale the spectrogram to the target img_size
#     # zoom_factors = [img_size[0] / D_dB.shape[0], img_size[1] / D_dB.shape[1]]
#     # D_dB_resized = zoom(D_dB, zoom_factors).astype(np.float32)

#     # Resize using TensorFlow
#     # D_dB_resized = tf.image.resize(D_dB[..., np.newaxis], img_size).numpy()
#     # D_dB_resized = np.squeeze(D_dB_resized, axis=-1).astype(np.float32)

#     # Convert the spectrogram to an image
#     D_dB_img = Image.fromarray(D_dB)

#     # Resize the image using PIL with LANCZOS resampling
#     D_dB_resized = D_dB_img.resize(img_size, Image.Resampling.LANCZOS)

#     # Convert back to NumPy array
#     D_dB_resized = np.array(D_dB_resized).astype(np.float32)

#     return D_dB_resized

def preprocess_audio(file_path, img_size, is_lstm=False, sr=22050):
    """Turn an audio file into a resized dB spectrogram ready for the model.

    Args:
        file_path: Path of the audio file to load.
        img_size: ``(rows, cols)`` the spectrogram is resized to.
        is_lstm: When False (default) a trailing channel axis is appended;
            when True the 2-D spectrogram is returned as is.
        sr: Sample rate the audio is resampled to on load. Defaults to 22050,
            the rate used when the training spectrograms were generated, so
            inference inputs match the training inputs. Pass ``None`` to keep
            the file's native rate.

    Returns:
        Array of shape ``img_size`` (``is_lstm=True``) or ``img_size + (1,)``.
    """
    y, sr = librosa.load(file_path, sr=sr)
    y = librosa.util.normalize(y)
    D = librosa.stft(y, n_fft=2048, hop_length=512)
    D_dB = librosa.amplitude_to_db(np.abs(D), ref=np.max)

    # Calculate zoom factors for resizing
    zoom_factors = [img_size[0] / D_dB.shape[0], img_size[1] / D_dB.shape[1]]
    D_dB_resized = zoom(D_dB, zoom_factors, order=3)  # Cubic interpolation

    # Add channel dimension for CNN, keep the 2-D shape for LSTM
    if not is_lstm:
        D_dB_resized = D_dB_resized[..., np.newaxis]

    return D_dB_resized


def predict(file_path, img_size=(64, 64), is_lstm=False):
    """Run the TFLite C-API interpreter on one audio file.

    Uses the module-level ``lib``, ``interpreter``, ``input_tensor`` and
    ``output_tensor`` handles.

    Args:
        file_path: Path of the audio file to classify.
        img_size: Spectrogram size expected by the model.
        is_lstm: Forwarded to ``preprocess_audio`` to select the input layout.

    Returns:
        A float32 array with the single output score of the model.

    Raises:
        RuntimeError: If a C API call reports a non-zero status.
    """
    input_data = preprocess_audio(file_path, img_size, is_lstm)
    # Prepend the batch axis; force a C-contiguous float32 buffer because the
    # raw memory is handed to the C library below.
    input_data = np.expand_dims(input_data, axis=0).astype(np.float32, order='C')

    # Copy the input data into the interpreter's input tensor
    status = lib.TfLiteTensorCopyFromBuffer(
        input_tensor,
        input_data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_size_t(input_data.nbytes),
    )
    if status != 0:
        raise RuntimeError("TfLiteTensorCopyFromBuffer failed with status {0}".format(status))

    # Run inference
    status = lib.TfLiteInterpreterInvoke(interpreter)
    if status != 0:
        raise RuntimeError("TfLiteInterpreterInvoke failed with status {0}".format(status))

    # Extract output data: the model produces one score per input.
    output_size = 1
    output_data = np.empty(output_size, dtype=np.float32)
    status = lib.TfLiteTensorCopyToBuffer(
        output_tensor,
        output_data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_size_t(output_data.nbytes),
    )
    if status != 0:
        raise RuntimeError("TfLiteTensorCopyToBuffer failed with status {0}".format(status))

    return output_data


def process_folder(folder_path, img_size=IMG_SIZE, is_lstm=False):
    """Classify every ``.wav`` file in ``folder_path`` and score the results.

    The ground truth is derived from the file name: names containing
    ``_cry.wav`` count as 'Cry', everything else as 'Not Cry'.

    Args:
        folder_path: Directory containing the test ``.wav`` files.
        img_size: Spectrogram size expected by the model.
        is_lstm: Forwarded to ``predict`` to select the input layout.

    Returns:
        ``(results, accuracy, f1, false_negative_percentage)`` where
        ``results`` is a list of ``(file_name, predicted_label)`` tuples,
        ``accuracy`` is the percentage of correct predictions, ``f1`` is the
        F1 score of the 'Cry' class, and ``false_negative_percentage`` is the
        percentage of actual 'Cry' files predicted as 'Not Cry'. Each metric
        is 0 when its denominator is 0.
    """
    correct_predictions = 0
    total_files = 0
    results = []

    # Counters for precision / recall / F1 of the 'Cry' class
    true_positives = 0
    false_positives = 0
    false_negatives = 0

    for file_name in os.listdir(folder_path):
        if not file_name.endswith('.wav'):
            continue

        file_path = os.path.join(folder_path, file_name)
        prediction = predict(file_path, img_size, is_lstm)
        # Take the single score as a plain float instead of relying on the
        # truthiness of a one-element array.
        score = float(np.ravel(prediction)[0])
        prediction_label = 'Cry' if score > 0.5 else 'Not Cry'
        results.append((file_name, prediction_label))
        ground_truth = 'Cry' if '_cry.wav' in file_name else 'Not Cry'

        if prediction_label == ground_truth:
            correct_predictions += 1
            if prediction_label == 'Cry':
                true_positives += 1
        elif prediction_label == 'Cry':
            false_positives += 1
        else:
            # Predicted 'Not Cry' for a file whose ground truth is 'Cry'.
            false_negatives += 1

        total_files += 1

    accuracy = (correct_predictions / total_files) * 100 if total_files > 0 else 0

    # Calculate precision, recall, F1 score
    predicted_cry = true_positives + false_positives
    actual_cry = true_positives + false_negatives
    precision = true_positives / predicted_cry if predicted_cry > 0 else 0
    recall = true_positives / actual_cry if actual_cry > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # Share of actual 'Cry' files that the model missed.
    false_negative_percentage = (false_negatives / actual_cry) * 100 if actual_cry > 0 else 0

    return results, accuracy, f1, false_negative_percentage


folder_path = '{0}/Test'.format(AUDIO_PATH)
try:
    # `test_f1` avoids rebinding the `f1_score` name imported from sklearn.
    predictions, accuracy, test_f1, false_negative_percentage = process_folder(folder_path)

    for file_name, prediction_label in predictions:
        print(f"File: {file_name}, Prediction: {prediction_label}")

    print(f"Prediction Accuracy: {accuracy:.2f}%")
    print(f"F1 Score: {test_f1:.2f}")
    print(f"False Negative Percentage: {false_negative_percentage:.2f}%")
finally:
    # Clean up the native handles even when evaluation fails part-way.
    lib.TfLiteInterpreterDelete(interpreter)
    lib.TfLiteInterpreterOptionsDelete(options)
    lib.TfLiteModelDelete(model)

print("All operations completed successfully.")


File: P19_612_notcry.wav, Prediction: Not Cry
File: P26_829_cry.wav, Prediction: Not Cry
File: P29_2405_cry.wav, Prediction: Cry
File: P29_62_cry.wav, Prediction: Cry
File: P26_7_cry.wav, Prediction: Cry
File: P36_14_notcry.wav, Prediction: Not Cry
File: P29_35_cry.wav, Prediction: Cry
File: P29_1714_cry.wav, Prediction: Cry
File: P29_724_cry.wav, Prediction: Cry
File: P26_9_cry.wav, Prediction: Cry
File: P29_348_cry.wav, Prediction: Not Cry
File: P20_388_cry.wav, Prediction: Not Cry
File: P26_824_cry.wav, Prediction: Cry
File: P29_773_cry.wav, Prediction: Cry
File: P29_1564_cry.wav, Prediction: Not Cry
File: P19_607_notcry.wav, Prediction: Not Cry
File: P29_1873_cry.wav, Prediction: Not Cry
File: P20_895_cry.wav, Prediction: Not Cry
File: P20_802_cry.wav, Prediction: Cry
File: P29_2090_cry.wav, Prediction: Cry
File: P20_919_cry.wav, Prediction: Not Cry
File: P17_41_cry.wav, Prediction: Not Cry
File: P29_1452_cry.wav, Prediction: Cry
File: P36_52_notcry.wav, Prediction: Not Cry
File: P