In [1]:
%reload_ext autoreload
%autoreload 2

In [None]:
#%%writefile ../speakerlib.py
#!/usr/bin/env python

'''
    pip install speechbrain faiss-cpu torchaudio numpy mangorest
'''

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
import os, numpy as np, torchaudio, faiss, pickle, mangorest 
from   speechbrain.pretrained import EncoderClassifier

# Speaker-embedding model, created once at import time and shared by every
# function below. The checkpoint files are kept under the local "tmp_model"
# directory passed as savedir.
classifier = EncoderClassifier.from_hparams(
    savedir="tmp_model",
    source="speechbrain/spkrec-ecapa-voxceleb",
)

# Directory to store speaker embeddings (one "<speaker_name>.npy" file each).
# "~" is a shell convention that os.makedirs()/open() do not understand, so it
# is expanded explicitly; without this a literal "~" directory would be
# created under the current working directory.
EMBEDDINGS_DIR = os.path.expanduser("~/data/speaker_embeddings")
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

# Embedding storage: pickle file holding the faiss index and the speaker names
INDEX_FILE = os.path.join(EMBEDDINGS_DIR, "speaker_index.pkl")
EMBEDDING_DIM = 192  # Dimension of embeddings from ECAPA-TDNN


def extract_embedding(audio_path, target_sr=16000):
    """Extract the speaker embedding from an audio file.

    The waveform is downmixed to a single channel and resampled to
    ``target_sr`` before it is handed to the model, so recordings with
    different channel counts or sample rates produce comparable embeddings.

    Args:
        audio_path: Path of the audio file; read with ``torchaudio.load``.
        target_sr: Sample rate (Hz) the model input is resampled to. The
            default follows the 16 kHz stated on the model card for
            ``speechbrain/spkrec-ecapa-voxceleb`` -- adjust if another
            checkpoint is loaded.

    Returns:
        A numpy array: the output of ``classifier.encode_batch`` with the
        leading batch axis squeezed away.

    Note:
        Embeddings saved before resampling/downmixing was added may differ
        from freshly extracted ones for non-16 kHz or multi-channel files;
        re-enroll such speakers.
    """
    signal, fs = torchaudio.load(audio_path)

    # torchaudio returns (channels, samples) and encode_batch treats the first
    # axis as the batch, so a stereo file would yield one embedding per
    # channel. Average the channels into a single mono signal instead.
    if signal.dim() > 1 and signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    # encode_batch does no rate conversion itself; the sample rate returned by
    # torchaudio.load was previously ignored.
    if fs != target_sr:
        signal = torchaudio.functional.resample(signal, fs, target_sr)

    # .cpu() is a no-op on CPU tensors and keeps .numpy() working if the model
    # happens to run on another device.
    embedding = classifier.encode_batch(signal).squeeze(0).detach().cpu().numpy()
    return embedding


def add_speaker(audio_path, speaker_name):
    """Enroll a speaker by saving the embedding of one recording.

    The embedding returned by ``extract_embedding(audio_path)`` is written to
    ``<EMBEDDINGS_DIR>/<speaker_name>.npy`` with ``np.save``; an existing file
    of the same name is overwritten. A confirmation line is printed.
    """
    vector = extract_embedding(audio_path)
    target_file = os.path.join(EMBEDDINGS_DIR, speaker_name + ".npy")
    np.save(target_file, vector)
    print(f"Speaker '{speaker_name}' added to the library.")


def load_library():
    """Build a FAISS index from all stored speaker embeddings.

    Every ``*.npy`` file in ``EMBEDDINGS_DIR`` is loaded, flattened to a
    single vector and added to an ``IndexFlatL2``; the file name without its
    ``.npy`` suffix is used as the speaker name. The index and the list of
    names (in the same order as the index rows) are pickled to
    ``INDEX_FILE``. Files whose content does not hold exactly
    ``EMBEDDING_DIM`` values are reported and skipped.

    Returns:
        None. Progress is reported with ``print``.
    """
    suffix = ".npy"
    embeddings = []
    speaker_names = []

    # sorted() makes the row order of the index reproducible across runs.
    for file in sorted(os.listdir(EMBEDDINGS_DIR)):
        if not file.endswith(suffix):
            continue

        # Stored arrays carry extra singleton axes (e.g. (1, 192)); FAISS needs
        # a 2-D (n, dim) matrix, so each embedding is flattened to 1-D first.
        vector = np.load(os.path.join(EMBEDDINGS_DIR, file))
        vector = vector.astype("float32").reshape(-1)
        if vector.size != EMBEDDING_DIM:
            print(f"Skipping '{file}': expected {EMBEDDING_DIM} values, found {vector.size}.")
            continue

        # Strip only the trailing suffix; str.replace would also remove any
        # ".npy" occurring inside the speaker name.
        speaker_names.append(file[: -len(suffix)])
        embeddings.append(vector)

    if embeddings:
        matrix = np.stack(embeddings)  # shape (n_speakers, EMBEDDING_DIM)
        index = faiss.IndexFlatL2(EMBEDDING_DIM)
        index.add(matrix)
        # NOTE(review): pickling the index object relies on the installed
        # faiss build supporting pickle -- confirm; faiss.write_index is the
        # library's own persistence API.
        with open(INDEX_FILE, "wb") as f:
            pickle.dump({"index": index, "speaker_names": speaker_names}, f)
        print("Library loaded successfully.")
    else:
        print("No speakers in the library.")


def recognize_speaker(audio_path, top_k=5):
    """Print the enrolled speakers closest to the voice in ``audio_path``.

    Args:
        audio_path: Path of the audio file to identify.
        top_k: Number of nearest neighbours requested from the index.

    Returns:
        None. One line per match is printed (rank, speaker name, distance as
        reported by the index). If ``INDEX_FILE`` does not exist a hint is
        printed and nothing else happens.
    """
    if not os.path.exists(INDEX_FILE):
        print("Speaker library not found. Please load the library first.")
        return

    # SECURITY: pickle.load can execute arbitrary code. INDEX_FILE must only
    # ever be a file written by load_library() on a trusted machine.
    with open(INDEX_FILE, "rb") as f:
        data = pickle.load(f)
    index = data["index"]
    speaker_names = data["speaker_names"]

    # Extract embedding for input audio as a single-row float32 query matrix
    query_embedding = extract_embedding(audio_path).astype("float32").reshape(1, -1)

    # Perform a similarity search
    distances, indices = index.search(query_embedding, top_k)

    # Display results. When top_k exceeds the number of stored speakers FAISS
    # pads the result with label -1; without the lower bound check that would
    # index speaker_names[-1] and print the last speaker with a bogus distance.
    for rank, (idx, distance) in enumerate(zip(indices[0], distances[0]), start=1):
        if 0 <= idx < len(speaker_names):
            print(f"Rank {rank}: {speaker_names[idx]} (distance: {distance:.2f})")


# Example usage: runs only when executed as a script and
# mangorest.mango.inJupyter() is falsy.
if __name__ == "__main__" and not mangorest.mango.inJupyter():
    # Enroll two sample speakers
    add_speaker(audio_path="path_to_speaker1_audio.wav", speaker_name="Speaker1")
    add_speaker(audio_path="path_to_speaker2_audio.wav", speaker_name="Speaker2")

    # Build the index from the stored embeddings
    load_library()

    # Look up an unknown recording, asking for the three closest speakers
    recognize_speaker("path_to_unknown_audio.wav", 3)
