In [20]:
import chromadb
import numpy as np
import torch
import wespeaker
import pandas as pd
import os
from pathlib import Path
import soundfile as sf

In [8]:
# Load the train split manifest and pull out the audio file names it lists
# in its "filename" column.
df_train = pd.read_csv("dataset/train.csv")
train_audio_files = df_train["filename"].tolist()

# Same for the test split.
df_test = pd.read_csv("dataset/test.csv")
test_audio_files = df_test["filename"].tolist()

In [21]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Path handed to `wespeaker.load_model_local` when the model is loaded.
pretrain_dir = "voxblink2_samresnet34"

In [24]:
def extract_embeddings(df, audio_files, device, pretrain_dir, data_dir="dataset"):
    """
    Extracts embeddings from audio files using the WeSpeaker model

    Args:
        df: DataFrame with a "filename" column and an "age" column. Every
            requested file must match exactly one row.
        audio_files: File names to embed, relative to ``data_dir``.
        device: Device passed to the model's ``set_device``.
        pretrain_dir: Path passed to ``wespeaker.load_model_local``.
        data_dir: Directory prefix joined to each file name with "/".
            Defaults to "dataset".

    Returns:
        A list with one dict per entry of ``audio_files`` (same order), with
        keys "file_path", "embedding" (NumPy array) and "label" (the "age"
        value of the matching row).

    Raises:
        ValueError: If a requested file is missing from ``df`` or matches
            more than one row. Raised before the model is loaded.
    """
    audio_files = list(audio_files)
    if not audio_files:
        # Nothing to embed: skip loading the model altogether.
        return []

    # Build the filename -> label mapping once instead of scanning the whole
    # frame for every file (which is quadratic in the number of files).
    rows = df[df["filename"].isin(set(audio_files))]
    duplicated = rows.loc[rows["filename"].duplicated(), "filename"].unique().tolist()
    if duplicated:
        raise ValueError(
            f"{len(duplicated)} filename(s) match more than one row in df, "
            f"e.g. {duplicated[:5]}"
        )
    # .tolist() yields plain Python scalars (not NumPy scalar types), like the
    # per-row `.item()` lookup it replaces.
    label_by_file = dict(zip(rows["filename"].tolist(), rows["age"].tolist()))
    missing = [name for name in audio_files if name not in label_by_file]
    if missing:
        raise ValueError(
            f"{len(missing)} audio file(s) have no row in df, e.g. {missing[:5]}"
        )

    model = wespeaker.load_model_local(pretrain_dir)
    model.set_device(device)

    embeddings = []

    for audio_file in audio_files:
        file_path = f"{data_dir}/{audio_file}"

        data, sample_rate = sf.read(file_path)
        pcm = torch.from_numpy(data).float()

        # soundfile returns (frames,) for mono and (frames, channels) for
        # multi-channel audio; reshape both to (channels, frames).
        if pcm.dim() == 1:
            pcm = pcm.unsqueeze(0)
        elif pcm.dim() == 2:
            pcm = pcm.transpose(0, 1)

        embedding = model.extract_embedding_from_pcm(pcm, sample_rate)

        embeddings.append({
            "file_path": file_path,
            "embedding": embedding.cpu().numpy(),
            "label": label_by_file[audio_file],
        })

    return embeddings

In [27]:
# Embed every file listed in the train manifest.
train_embeddings = extract_embeddings(
    df=df_train,
    audio_files=train_audio_files,
    device=device,
    pretrain_dir=pretrain_dir,
)

In [25]:
# Embed every file listed in the test manifest.
test_embeddings = extract_embeddings(
    df=df_test,
    audio_files=test_audio_files,
    device=device,
    pretrain_dir=pretrain_dir,
)

In [29]:
def save_to_chromadb(embeddings, db_path, split, batch_size=5000):
    """
    Stores embeddings in ChromaDB

    Args:
        embeddings: Sequence of dicts with keys "embedding", "file_path" and
            "label".
        db_path: Directory of the persistent Chroma database.
        split: Split name; used as the id prefix and stored in the metadata.
        batch_size: Maximum number of records sent to Chroma per call.
            Defaults to 5000. Chroma rejects calls larger than the client's
            maximum batch size, so keep this at or below that limit.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.

    Records are written with ``upsert`` under the deterministic ids
    ``"{split}_{i}"``, so re-running the cell overwrites the same records
    instead of colliding with ids already in the collection. An empty
    ``embeddings`` sequence writes nothing.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    embeddings = list(embeddings)

    client = chromadb.PersistentClient(path=db_path)
    # NOTE(review): the collection is named "gender_embeddings" while the
    # labels produced in this notebook come from the "age" column - confirm
    # which one is intended.
    collection = client.get_or_create_collection(name="gender_embeddings")

    # Write in chunks so large datasets stay under Chroma's per-call limit.
    # Ids keep counting across chunks, so they match a single-call numbering.
    for start in range(0, len(embeddings), batch_size):
        batch = embeddings[start:start + batch_size]
        collection.upsert(
            ids=[f"{split}_{i}" for i in range(start, start + len(batch))],
            embeddings=[item['embedding'] for item in batch],
            metadatas=[{
                "file_path": item['file_path'], "label": item['label'],
                "split": split
            }
                for item in batch]
        )

In [30]:
# Persist both splits into the same on-disk Chroma database; the split name
# distinguishes their ids and metadata.
save_to_chromadb(embeddings=train_embeddings, db_path="chromaDB", split="train")
save_to_chromadb(embeddings=test_embeddings, db_path="chromaDB", split="test")