# AudioEmbeddingExtractor Class

The `AudioEmbeddingExtractor` class is designed to handle the extraction of audio embeddings using the VGGish model from TensorFlow Hub, save these embeddings to a CSV file, and compute cosine similarity between the embeddings.

## Functions

- **__init__(self, audio_dir, embedding_file, vggish_model_url='https://tfhub.dev/google/vggish/1')**
  - Initializes the class with the directory of audio files, the file to save embeddings, and the VGGish model URL.

- **load_and_preprocess_audio(self, audio_file)**
  - Loads and preprocesses an audio file, normalizing the waveform.

- **get_embeddings(self, audio_file)**
  - Extracts the VGGish embeddings from an audio file and returns the mean of the embedding batch.

- **extract_embeddings(self)**
  - Extracts embeddings for all audio files in the specified directory and saves them to a CSV file.

- **compute_cosine_similarity(self)**
  - Computes and prints the cosine similarity matrix between the extracted embeddings.


In [3]:
from tqdm import tqdm
import tensorflow as tf
import numpy as np
import pandas as pd
import tensorflow_hub as hub
import librosa
import gc
import os
import csv
import glob
from sklearn.metrics.pairwise import cosine_similarity

class AudioEmbeddingExtractor:
    def __init__(self, audio_dir, embedding_file, vggish_model_url='https://tfhub.dev/google/vggish/1'):
        self.audio_dir = audio_dir
        self.embedding_file = embedding_file
        self.vggish_model = hub.load(vggish_model_url)
    
    def load_and_preprocess_audio(self, audio_file):
        """Loads and preprocesses an audio file."""
        waveform, sr = librosa.load(audio_file, sr=16000, mono=True)
        waveform = waveform / np.max(np.abs(waveform))
        return waveform
    
    def get_embeddings(self, audio_file):
        """Extracts embeddings from an audio file."""
        try:
            waveform = self.load_and_preprocess_audio(audio_file)
            waveform_tensor = tf.convert_to_tensor(waveform, dtype=tf.float32)
            embeddings = self.vggish_model(waveform_tensor)

            if isinstance(embeddings, dict):
                embeddings = embeddings['embedding']

            embedding_batch = embeddings.numpy()
            return np.mean(embedding_batch, axis=0)
        
        except Exception as e:
            print(f"Error processing {audio_file}: {e}")
            return None
    
    def extract_embeddings(self):
        """Extracts embeddings for all audio files in the directory and saves them to a CSV file."""
        audio_files = glob.glob(os.path.join(self.audio_dir, "*.wav"))

        with open(self.embedding_file, 'w', newline='') as f:
            writer = csv.writer(f)

            for audio_file in tqdm(audio_files, desc="Extracting embeddings"):
                embedding = self.get_embeddings(audio_file)
                
                if embedding is not None:
                    writer.writerow(embedding)

                gc.collect()

    def compute_cosine_similarity(self):
        """Computes and prints the cosine similarity between the extracted embeddings."""
        embeddings = pd.read_csv(self.embedding_file)
        embeddings.dropna(inplace=True)
        similarity_matrix = cosine_similarity(embeddings.T)
        print(similarity_matrix)


if __name__ == "__main__":
    audio_dir = "../../Datasets/ml-100k/Audio/"
    embedding_file = f"{audio_dir}embeddings.csv"

    extractor = AudioEmbeddingExtractor(audio_dir, embedding_file)
    extractor.extract_embeddings()
    extractor.compute_cosine_similarity()


Extracting embeddings: 100%|██████████| 132/132 [08:41<00:00,  3.95s/it]

[[ 1.         -0.84605974 -0.57733514 ...  0.97031921 -0.81763315
  -0.83027508]
 [-0.84605974  1.          0.31645597 ... -0.79676824  0.68237485
   0.80003714]
 [-0.57733514  0.31645597  1.         ... -0.65401476  0.50326707
   0.27250846]
 ...
 [ 0.97031921 -0.79676824 -0.65401476 ...  1.         -0.76700101
  -0.77202247]
 [-0.81763315  0.68237485  0.50326707 ... -0.76700101  1.
   0.75856906]
 [-0.83027508  0.80003714  0.27250846 ... -0.77202247  0.75856906
   1.        ]]



