In [1]:
import os
import warnings
from pathlib import Path

import chromadb
import numpy as np
import soundfile as sf
import torch
import wespeaker

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_audio_path(audio_dir, extensions=('.wav', '.mp3')):
    """
    Recursively finds all audio files in the specified directory.

    Args:
        audio_dir: Root directory to search (str or Path).
        extensions: File suffixes to collect, including the leading dot.
            Matching is case-insensitive, so "x.WAV" is found as well as
            "x.wav". Defaults to ('.wav', '.mp3').

    Returns:
        Sorted list of Path objects for every regular file under `audio_dir`
        whose suffix is in `extensions`. Sorting makes the order reproducible
        across runs and platforms, which matters because downstream code
        derives record IDs from list position.

    Raises:
        FileNotFoundError: if `audio_dir` is not an existing directory. This
            prevents a mistyped path from silently producing an empty dataset.
    """
    audio_dir = Path(audio_dir)
    if not audio_dir.is_dir():
        raise FileNotFoundError(f"Audio directory not found: {audio_dir}")

    wanted = {ext.lower() for ext in extensions}
    return sorted(
        path for path in audio_dir.rglob('*')
        # is_file() excludes directories whose names happen to end in .wav/.mp3
        if path.suffix.lower() in wanted and path.is_file()
    )

In [3]:
def extract_embeddings(audio_files, device, pretrain_dir):
    """
    Extracts one speaker embedding per audio file using the WeSpeaker model.

    Args:
        audio_files: Iterable of paths (str or Path) to audio files readable by
            `soundfile`.
        device: Device handed to `model.set_device` (e.g. the torch.device
            chosen in the notebook).
        pretrain_dir: Model directory passed to `wespeaker.load_model_local`.

    Returns:
        List of dicts {'file_path': str, 'embedding': np.ndarray}, in the same
        order as `audio_files`. Files containing no samples are skipped, and a
        warning is emitted for each one.
    """
    model = wespeaker.load_model_local(pretrain_dir)
    model.set_device(device)

    embeddings = []

    for file_path in audio_files:
        # soundfile returns a (frames,) array for mono audio and a
        # (frames, channels) array otherwise, as float64 by default.
        # NOTE(review): .mp3 decoding needs libsndfile >= 1.1.0; confirm the
        # installed version if mp3 inputs are expected.
        data, sample_rate = sf.read(file_path)
        pcm = torch.from_numpy(data).float()

        # Bring the waveform into (channels, samples) layout.
        if pcm.ndim == 1:
            pcm = pcm.unsqueeze(0)
        elif pcm.ndim == 2:
            pcm = pcm.transpose(0, 1)

        if pcm.shape[1] == 0:
            # Nothing to embed. Skip the file, but say so rather than silently
            # shrinking the dataset.
            warnings.warn(f"Skipping empty audio file: {file_path}")
            continue

        # NOTE(review): sf.read yields samples scaled to [-1, 1], and
        # multi-channel audio is passed through unchanged. Confirm both match
        # what WeSpeaker's extract_embedding_from_pcm expects.
        embedding = model.extract_embedding_from_pcm(pcm, sample_rate)

        embeddings.append({
            'file_path': str(file_path),
            'embedding': embedding.cpu().numpy(),
        })

    return embeddings

In [4]:
def assign_labels(embeddings):
    """
    Tags each embedding record with its class label.

    The label is the name of the directory that directly contains the record's
    audio file. Every dict in `embeddings` is modified in place (its 'label'
    key is set); nothing is returned.
    """
    for record in embeddings:
        record.update(label=Path(record['file_path']).parent.name)

In [5]:
def save_to_chromadb(embeddings, db_path, split,
                     collection_name="gender_embeddings", batch_size=1000):
    """
    Stores embeddings in a persistent ChromaDB collection.

    Args:
        embeddings: List of dicts with keys 'file_path', 'embedding' and
            'label' (as produced by extract_embeddings + assign_labels).
        db_path: Directory of the persistent ChromaDB store.
        split: Split name (e.g. "train" / "test"). Used as the ID prefix and
            stored in each record's metadata.
        collection_name: Collection to write into; it is created if missing.
        batch_size: Maximum number of records sent per write call. ChromaDB
            rejects single writes above a backend-dependent maximum.

    Record IDs are f"{split}_{i}", where i is the record's position in
    `embeddings`. Writes use `upsert`, so re-running the cell overwrites
    records with the same IDs instead of keeping data from a previous run.

    Raises:
        ValueError: if `batch_size` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not embeddings:
        # ChromaDB raises on an empty ID list; report the no-op instead.
        warnings.warn(f"No embeddings to store for split {split!r}")
        return

    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_or_create_collection(name=collection_name)

    for start in range(0, len(embeddings), batch_size):
        batch = embeddings[start:start + batch_size]
        collection.upsert(
            # Offset by `start` so IDs stay positional across batches.
            ids=[f"{split}_{start + offset}" for offset in range(len(batch))],
            embeddings=[item['embedding'] for item in batch],
            metadatas=[
                {
                    "file_path": item['file_path'],
                    "label": item['label'],
                    "split": split,
                }
                for item in batch
            ],
        )

In [6]:
# Root of the local speech-disorder dataset; adjust for your machine.
DATA_ROOT = Path("D:/PPSpeech/interp_dev/SpeechDisorder")
TRAIN_AUDIO_DIR = DATA_ROOT / "SpeechDisorderDataset"
TEST_AUDIO_DIR = DATA_ROOT / "TESTSpeechDisorderDataset"

train_audio_files = get_audio_path(TRAIN_AUDIO_DIR)
test_audio_files = get_audio_path(TEST_AUDIO_DIR)

# Fail fast on a wrong path instead of building an empty index later.
assert train_audio_files, f"No audio files found under {TRAIN_AUDIO_DIR}"
assert test_audio_files, f"No audio files found under {TEST_AUDIO_DIR}"

In [7]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [11]:
# Embed every test-split file with the pretrained VoxBlink2 SimAM-ResNet34 model.
test_embeddings = extract_embeddings(
    audio_files=test_audio_files,
    device=device,
    pretrain_dir="D:/PPSpeech/interp_dev/SpeechDisorder/voxblink/voxblink2_samresnet34",
)

In [9]:
# Embed every training-split file with the same pretrained model directory.
train_embeddings = extract_embeddings(
    audio_files=train_audio_files,
    device=device,
    pretrain_dir="D:/PPSpeech/interp_dev/SpeechDisorder/voxblink/voxblink2_samresnet34",
)

In [12]:
# Label both splits (in place) by the name of each file's parent folder.
for split_embeddings in (train_embeddings, test_embeddings):
    assign_labels(split_embeddings)

In [13]:
# Raw string: the backslashes are Windows path separators, not escape
# sequences. The non-raw literal triggers invalid-escape warnings ("\P", "\i",
# ...), which future Python versions will turn into errors. The path value is
# unchanged.
CHROMADB_PATH = r"D:\PPSpeech\interp_dev\SpeechDisorder\chromadb"

save_to_chromadb(train_embeddings, CHROMADB_PATH, split="train")
save_to_chromadb(test_embeddings, CHROMADB_PATH, split="test")