In [7]:
import torch
import torchaudio
from speechbrain.pretrained import EncoderClassifier

# Use a GPU if available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the pre-trained ECAPA-TDNN speaker-recognition model from SpeechBrain
# (fetched from Hugging Face on first use and stored under `savedir`).
try:
    classifier = EncoderClassifier.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir="pretrained_models/spkrec-ecapa-voxceleb",
        run_opts={"device": device},
    )
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure you have a stable internet connection for the initial download.")
    # Nothing below can run without the model. Raise SystemExit with a non-zero
    # status rather than calling the interactive-only `exit()` helper, which
    # exits with status 0 (reporting success) and, under IPython/Jupyter, may
    # shut the kernel down instead of just stopping this cell.
    raise SystemExit(1) from e

# Cosine similarity over the last (embedding) axis; used to compare voiceprints.
cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

def create_voiceprint(audio_file_path):
    """
    Load an audio file and compute its L2-normalised ECAPA-TDNN speaker embedding.

    The audio is down-mixed to mono and resampled to 16 kHz before being passed
    to the module-level ``classifier``.

    Args:
        audio_file_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The squeezed embedding tensor with unit L2 norm along its last axis, or
        ``None`` if loading or encoding failed (the error is printed, not raised).
    """
    target_sample_rate = 16000
    try:
        # torchaudio.load returns (signal, sample_rate) with signal shaped
        # (channels, time).
        signal, fs = torchaudio.load(audio_file_path)

        # Down-mix to mono first: resampling is linear, so averaging channels
        # beforehand gives the same result (up to float rounding) while only
        # one channel has to be resampled.
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        if fs != target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=fs, new_freq=target_sample_rate
            )
            signal = resampler(signal)

        with torch.no_grad():
            embedding = classifier.encode_batch(signal)
            # Normalise along the last (feature) axis, then squeeze away the
            # singleton leading axes. encode_batch is expected to return
            # (batch, 1, emb_dim); TODO confirm for other SpeechBrain versions.
            embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
            return embedding.squeeze()
    except Exception as e:
        # Best-effort: report and let the caller skip this file.
        print(f"Could not process file {audio_file_path}: {e}")
        return None

def identify_speaker(unknown_segment_path, voiceprint_database, similarity_threshold=0.50):
    """
    Compare an unknown audio segment against a database of known voiceprints.

    Every enrolled speaker is scored with cosine similarity, and each score is
    printed. The best-scoring speaker is reported only if that score is strictly
    greater than ``similarity_threshold``.

    Args:
        unknown_segment_path: Path to the audio segment to identify.
        voiceprint_database: Mapping of speaker name -> embedding tensor, as
            produced by ``create_voiceprint``.
        similarity_threshold: Minimum (exclusive) similarity for a match.

    Returns:
        A ``(speaker, score)`` tuple:
          * ``("Error", 0.0)`` if the segment could not be embedded;
          * ``("Unknown", 0.0)`` if the database is empty;
          * otherwise the best speaker name (or ``"Unknown"`` when below the
            threshold) together with the true best similarity score, which
            may be negative.
    """
    unknown_embedding = create_voiceprint(unknown_segment_path)
    if unknown_embedding is None:
        return "Error", 0.0

    # Start from -inf (not 0) so negative similarities are tracked correctly
    # and a negative threshold can still be met.
    best_speaker = None
    best_score = float("-inf")

    for speaker_name, known_embedding in voiceprint_database.items():
        score = cosine_similarity(unknown_embedding, known_embedding).item()
        print(f"  Comparing with {speaker_name}, Score: {score:.2f}")
        if score > best_score:
            best_speaker, best_score = speaker_name, score

    if best_speaker is None:
        # Empty database (or only NaN scores): nothing to match against.
        return "Unknown", 0.0

    identified_speaker = best_speaker if best_score > similarity_threshold else "Unknown"
    return identified_speaker, best_score

# --- Main execution block ---
if __name__ == "__main__":
    # 1. ENROLLMENT: Define the known speakers and their sample audio files.
    # Replace these paths with the actual paths to your audio files.
    # The audio should be a clear recording of the person's voice.
    known_speakers = {
        # Example: "Priya": "/path/to/priya_sample.wav"
        "spk1": "/Users/sujanh/Downloads/data2/spk1.mp3",
        "spk2": "/Users/sujanh/Downloads/data2/spk2.mp3",
        "spk3": "/Users/sujanh/Downloads/data2/spk3.mp3",
        "spk4": "/Users/sujanh/Downloads/data2/spk4.mp3",
    }

    print("--- Creating Voiceprint Database ---")
    voiceprint_db = {}
    for name, path in known_speakers.items():
        print(f"Processing enrollment for {name}...")
        embedding = create_voiceprint(path)
        if embedding is None:
            # Make enrollment failures visible instead of silently dropping them.
            print(f"  Skipping {name}: no voiceprint could be created.")
            continue
        voiceprint_db[name] = embedding
    print("--- Voiceprint Database Created ---\n")

    # 2. IDENTIFICATION: List the diarized audio segments you want to identify.
    # These would be the outputs from your diarization tool.
    segments_dir = (
        "/Users/sujanh/Documents/github/NewIdea/Assembly titanet/"
        "Gender and Politics Panel Discussion_merged_speakers"
    )
    diarized_segments = [
        f"{segments_dir}/SPEAKER_{label}_merged.mp3" for label in "ABCDE"
    ]

    if not voiceprint_db:
        # Without any enrolled voiceprints every segment would be "Unknown".
        print("No voiceprints were enrolled; skipping identification.")
    else:
        print("--- Identifying Speakers in Diarized Segments ---")
        for segment_path in diarized_segments:
            print(f"\nIdentifying speaker for [{segment_path}]...")
            speaker, score = identify_speaker(segment_path, voiceprint_db)
            print(f"==> Result: '{segment_path}' is identified as {speaker} with a confidence score of {score:.2f}\n")

Using device: cpu
--- Creating Voiceprint Database ---
Processing enrollment for spk1...
Processing enrollment for spk2...
Processing enrollment for spk3...
Processing enrollment for spk4...
--- Voiceprint Database Created ---

--- Identifying Speakers in Diarized Segments ---

Identifying speaker for [/Users/sujanh/Documents/github/NewIdea/Assembly titanet/Gender and Politics Panel Discussion_merged_speakers/SPEAKER_A_merged.mp3]...
  Comparing with spk1, Score: 0.92
  Comparing with spk2, Score: 0.35
  Comparing with spk3, Score: 0.28
  Comparing with spk4, Score: 0.20
==> Result: '/Users/sujanh/Documents/github/NewIdea/Assembly titanet/Gender and Politics Panel Discussion_merged_speakers/SPEAKER_A_merged.mp3' is identified as spk1 with a confidence score of 0.92


Identifying speaker for [/Users/sujanh/Documents/github/NewIdea/Assembly titanet/Gender and Politics Panel Discussion_merged_speakers/SPEAKER_B_merged.mp3]...
  Comparing with spk1, Score: 0.35
  Comparing with spk2, Score