In [8]:
import os
import sqlite3
import faiss
import numpy as np
from bitarray import bitarray
from lhotse import CutSet
from tqdm import tqdm
from util import *
import faiss, numpy as np, multiprocessing as mp
import librosa
from sklearn.decomposition import PCA

In [9]:
# Source directory of the raw audio corpus. Override with the LONGSPEECH_IN_DIR
# environment variable instead of editing the notebook; the alternative VoxPopuli
# source was:
#   "../datasets/LongSpeechSource/voxpopuli"
IN_DIR = os.environ.get("LONGSPEECH_IN_DIR", "/mnt/d/voicedata/CommenVoice/delta")
# Directory where metadata and processed audio files are saved
# (override with LONGSPEECH_OUT_DIR).
OUT_DIR = os.environ.get("LONGSPEECH_OUT_DIR", '../datasets/LongSpeech')

In [10]:
def build_feature(cuts: CutSet, batch_size: int = 100, dim: int = 384,
                  out_dir: str = None):
    """Embed each cut's transcript and persist vectors and durations as float32 memmaps.

    Texts (the first supervision's text, or "" for cuts without supervisions) are
    embedded in batches with ``get_sentence_embeddings`` (presumably provided by
    ``from util import *``). Results are written to two files under ``out_dir``:
    ``vecs.f32`` with shape (n, dim) and ``durs.f32`` with shape (n,), where n is the
    number of cuts. Row i of both files corresponds to ``string_ids[i]``.

    Args:
        cuts: Cuts to embed; materialised with ``to_eager()``.
        batch_size: Number of texts per embedding call; must be positive.
        dim: Embedding width that ``get_sentence_embeddings`` is expected to return.
        out_dir: Output directory. Defaults to the notebook-level ``OUT_DIR`` and is
            created if missing. Existing files there are overwritten (mode "w+").

    Returns:
        (vec_mm, dur_mm, string_ids): the flushed vector memmap, the duration memmap,
        and the list of cut ids in row order.

    Raises:
        ValueError: if ``cuts`` is empty, ``batch_size`` is not positive, or an
            embedding batch does not hold ``len(batch) * dim`` values.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if out_dir is None:
        out_dir = OUT_DIR

    cut_list = cuts.to_eager()
    n = len(cut_list)
    if n == 0:
        # np.memmap cannot map a zero-byte file; fail with a clear message instead.
        raise ValueError("build_feature: `cuts` is empty; nothing to embed")

    os.makedirs(out_dir, exist_ok=True)
    vec_mm = np.memmap(os.path.join(out_dir, "vecs.f32"), dtype="float32", mode="w+", shape=(n, dim))
    dur_mm = np.memmap(os.path.join(out_dir, "durs.f32"), dtype="float32", mode="w+", shape=(n,))

    string_ids = []
    for start in tqdm(range(0, n, batch_size), desc="Get Embedding"):
        cut_batch = cut_list[start:start + batch_size]
        stop = start + len(cut_batch)

        texts = [c.supervisions[0].text if c.supervisions else "" for c in cut_batch]
        vec_np = np.asarray(get_sentence_embeddings(texts), dtype="float32")

        # Validate before writing: a width mismatch would otherwise be an opaque
        # broadcast error, and a single returned row would silently fill the batch.
        expected = (len(cut_batch), dim)
        if vec_np.size != expected[0] * expected[1]:
            raise ValueError(
                f"get_sentence_embeddings returned shape {vec_np.shape}, expected {expected}"
            )

        vec_mm[start:stop] = vec_np.reshape(expected)
        dur_mm[start:stop] = [c.duration for c in cut_batch]
        string_ids.extend(c.id for c in cut_batch)

    vec_mm.flush()
    dur_mm.flush()
    return vec_mm, dur_mm, string_ids

In [4]:
# Load the raw CommonVoice cut manifest and embed every cut's transcript.
# build_feature (re)creates vecs.f32 / durs.f32 under OUT_DIR (memmap mode "w+").
cuts = CutSet.from_jsonl(os.path.join(OUT_DIR, "commonvoice_raw_cuts.jsonl"))
vec_mm, dur_mm, string_ids = build_feature(cuts)

Get Embedding: 100%|██████████| 2/2 [00:01<00:00,  1.84it/s]


In [5]:
def build_hnsw_index(vec_mm: np.memmap,
                     dim: int = 384,
                     m: int = 32,
                     ef_c: int = 200,
                     n_threads: int = mp.cpu_count(),
                     out_path: str = "cache_hnsw.faiss"):
    """Build an inner-product HNSW index over the rows of ``vec_mm`` and save it.

    The rows of ``vec_mm`` are L2-normalised *in place* (``faiss.normalize_L2``
    mutates its argument). After this call the caller's array holds unit vectors,
    so inner-product search ranks by cosine similarity.

    Args:
        vec_mm: 2-D float32 array (e.g. the memmap from ``build_feature``) with
            ``dim`` columns; must be writable because it is normalised in place.
        dim: Vector width; must equal ``vec_mm.shape[1]``.
        m: HNSW graph degree (``M``).
        ef_c: ``efConstruction`` used while adding vectors.
        n_threads: Sets faiss's global OpenMP thread count. The default is
            evaluated once, when the function is defined.
        out_path: File name of the index, relative to ``OUT_DIR``.

    Returns:
        The path the index was written to: ``os.path.join(OUT_DIR, out_path)``.

    Raises:
        ValueError: if ``vec_mm`` is not 2-D with ``dim`` columns.
    """
    if vec_mm.ndim != 2 or vec_mm.shape[1] != dim:
        raise ValueError(f"vec_mm must have shape (n, {dim}), got {vec_mm.shape}")

    faiss.omp_set_num_threads(n_threads)
    faiss.normalize_L2(vec_mm)

    # The metric must be given to the constructor. IndexHNSWFlat(d, M) builds its
    # flat storage with METRIC_L2. Assigning index.metric_type afterwards leaves the
    # storage (and thus graph construction and search) on L2; faiss then merely
    # negates the returned L2 distances.
    index = faiss.IndexHNSWFlat(dim, m, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = ef_c
    index.add(vec_mm)

    index_file = os.path.join(OUT_DIR, out_path)
    faiss.write_index(index, index_file)
    return index_file

In [6]:
# Build and save the HNSW index. Side effect: vec_mm is L2-normalised in place.
index_path = build_hnsw_index(vec_mm)

In [16]:
def get_speaker_embedding_ids(ids, neighs, cuts, cache=None):
    """Order neighbouring cuts along the first principal component of their speaker embeddings.

    For each entry of ``neighs`` up to the first -1 (faiss's "no result" padding),
    the corresponding cut's source audio is loaded with ``librosa.load`` and embedded
    with ``get_speaker_embedding`` (presumably from ``util``). The embeddings are then
    projected onto their first principal component, and positions are sorted by that
    projection so that similar voices end up adjacent.

    Args:
        ids: Sequence mapping row index -> cut id.
        neighs: Row indices into ``ids``, possibly padded with trailing -1.
        cuts: Object supporting ``cuts[cut_id]`` (e.g. a lhotse CutSet).
        cache: Optional dict {cut_id: flattened embedding}. When given, embeddings
            found there are reused and newly computed ones are stored, so repeated
            calls over overlapping neighbour lists do not reload audio.

    Returns:
        np.ndarray of ints: a permutation of ``range(k)``, where k is the number of
        valid leading entries of ``neighs``. Each element is a *position* in
        ``neighs``, not a row index.
    """
    speaker_embeddings = []
    for idx in neighs:
        if idx == -1:
            break
        cut_id = ids[idx]
        if cache is not None and cut_id in cache:
            embedding = cache[cut_id]
        else:
            # NOTE(review): this loads the whole source file (at librosa's default
            # sample rate), not just the cut's own time span. That is fine if every
            # recording holds exactly one cut; confirm for multi-cut recordings.
            cut_pth = cuts[cut_id].recording.sources[0].source
            audio, sr = librosa.load(cut_pth)
            embedding = get_speaker_embedding(audio, sr).flatten()
            if cache is not None:
                cache[cut_id] = embedding
        speaker_embeddings.append(embedding)

    # PCA needs at least two samples to define a direction (it raises on zero);
    # with 0 or 1 valid neighbours the only ordering is the identity.
    if len(speaker_embeddings) < 2:
        return np.arange(len(speaker_embeddings))

    spk_emb_np = np.array(speaker_embeddings)
    pc1 = PCA(n_components=1, svd_solver="auto").fit_transform(spk_emb_np).ravel()
    return np.argsort(pc1)

# Smoke test: order five sample rows (every other cut) by speaker-embedding PC1.
# The printed array is a permutation of positions 0..4 within the list passed in.
print(get_speaker_embedding_ids(string_ids, [0, 2, 4, 6, 8], cuts))

[0 3 4 1 2]


In [21]:
def greedy_cluster(index_path: str,
                   vec_mm: np.memmap,
                   dur_mm: np.memmap,
                   ids,
                   cuts,
                   bucket_min: int = 300,
                   bucket_avg: int = 600,
                   k_neigh: int = 1024,
                   ef_s: int = 96):
    """Greedily pack cuts into duration buckets of text-similar, speaker-ordered cuts.

    Seeds are visited from the longest to the shortest duration. For each seed
    that is still unassigned:
    - The faiss index is searched, restricted to unassigned rows, for up to
      ``k_neigh`` neighbours of the seed's vector.
    - The neighbours are re-ordered with ``get_speaker_embedding_ids``.
    - Neighbours are added in that order until the bucket's total duration
      reaches ``bucket_avg``.

    A bucket ending below ``bucket_min`` is discarded and its members are released
    for later seeds.

    Args:
        index_path: Path of a faiss HNSW index built over the rows of ``vec_mm``.
        vec_mm: (N, dim) float32 query vectors, row-aligned with the index.
        dur_mm: (N,) cut durations (seconds, per the summary print).
        ids: Row index -> cut id.
        cuts: Passed through to ``get_speaker_embedding_ids``.
        bucket_min: Minimum total duration for a bucket to be kept.
        bucket_avg: Duration at which a bucket stops growing.
        k_neigh: Number of neighbours retrieved per seed.
        ef_s: HNSW ``efSearch``.

    Returns:
        list[list]: one list of cut ids per kept bucket, in creation order.
        Also prints a summary line (bucket count, kept, total, and discarded share).
    """
    index = faiss.read_index(index_path)

    params = faiss.SearchParametersHNSW()
    params.efSearch = ef_s

    N = len(vec_mm)
    # Boolean mask of rows already placed in a bucket.
    assigned = np.zeros(N, dtype=bool)

    order = np.argsort(-dur_mm)  # longest cuts seed first
    buckets = []

    for seed in tqdm(order, desc="Clustering (Optimized)"):
        if assigned[seed]:
            continue

        unassigned_indices = np.flatnonzero(~assigned).astype(np.int64)

        # Every bucket member is an unassigned row, and failed buckets are rolled
        # back. So once the unassigned rows cannot add up to bucket_min, no later
        # seed can succeed; stop instead of paying for a search and speaker
        # embeddings per remaining seed. The small relative slack absorbs
        # float32-vs-float64 summation differences.
        remaining_dur = float(np.sum(dur_mm[unassigned_indices], dtype=np.float64))
        if remaining_dur < bucket_min * (1 - 1e-3):
            break

        # Restrict the search to unassigned rows. IDSelectorBatch tests membership
        # with a hash set; IDSelectorArray scans the whole id array on every test.
        # Keep `selector` bound to a local: assigning params.sel stores only the
        # C++ pointer, so the Python object must stay alive during the search.
        selector = faiss.IDSelectorBatch(unassigned_indices)
        params.sel = selector
        _, neighs = index.search(vec_mm[seed:seed + 1], k_neigh, params=params)
        neighbours = neighs[0]

        # Positions into `neighbours`, ordered along the speaker-embedding PC1.
        speaker_order = get_speaker_embedding_ids(ids, neighbours.tolist(), cuts)

        cluster = []
        total_dur = 0
        for pos in speaker_order:
            idx = neighbours[pos]
            if idx == -1:
                break
            if assigned[idx]:
                print("Warning: Already assigned index", idx)
                continue

            cluster.append(int(idx))
            assigned[idx] = True
            total_dur += dur_mm[idx]
            if total_dur >= bucket_avg:
                break

        if total_dur < bucket_min:
            # Too short: release the members so later seeds can use them.
            assigned[cluster] = False
        else:
            buckets.append((cluster, dur_mm[cluster].sum()))

    # Re-check against the recomputed sum, which can differ slightly from the
    # running float32 total used above.
    final_buckets = [b for b in buckets if b[1] >= bucket_min]
    final_clusters = [c for c, _ in final_buckets]
    final_duration = sum(sec for _, sec in final_buckets)

    total_all = dur_mm.sum()
    loss = 1 - final_duration / total_all if total_all > 0 else 0.0
    # Summary: bucket count, kept duration, total duration, discarded fraction.
    print(f"桶数 {len(final_clusters)}, 最终时长 {final_duration:.2f}s, 总时长 {dur_mm.sum():.2f}s, 丢弃比例 {loss:.2%}")

    return [[ids[i] for i in cluster] for cluster in final_clusters]
# Run the clustering. As the cell's last expression, the returned list of cut-id
# lists (one per bucket) is displayed as the cell output.
greedy_cluster(index_path, vec_mm, dur_mm, string_ids, cuts)

Clustering (Optimized):   1%|          | 1/168 [01:03<2:56:23, 63.37s/it]

[ 12  92  24 115   5   9 120 138 129 119 130  51  95  30  10  74 134 126
  88  39 135 128 162 159 110 112 148 103  65  47  71  69  45  60  70  83
  17  16 105  78  86  75  77   2  34  33   0 102 140 163  81 109  55  98
  76  73  46 141  20 157  85 104  26  23  52 133  84 142 164 137  40  63
  97 114 149   1 143 122 116   4  57 165 158  27 131  14 144 111  62 145
   6 150  91   3 160 154  18  28  31  96 147  68  19 152 107 118  25 127
 125 153 132  50  80 146  42  93  66 156  72 166  64  89  21 117  15 151
 108  43  49 101  35 113 106   8  99  53  61  56  11  13  44  32  37 124
  94  41  67  59  54  22  90 100  79 136 123   7  36  82  48 155 121 139
 167  29  87 161  38  58]


Clustering (Optimized): 100%|██████████| 168/168 [01:29<00:00,  1.87it/s]

[14 44 20 22 11 39  2 63 29 35 27 64 42 51  7 59 46 21 54 52 10  4 45 66
  9  0 19 48 26 58 12 33  8 61 36 37 31  5  6 24 62 65 49 57 40  1 32 15
 47 25 17 34 13 28  3 30 60 38 16 43 23 50 67 18 56 41 55 53]
桶数 2, 最终时长 1027.54s, 总时长 1027.54s, 丢弃比例 0.00%





[['common_voice_en_42788438-96',
  'common_voice_en_42807958-114',
  'common_voice_en_42788567-98',
  'common_voice_en_42791979-48',
  'common_voice_en_42814003-56',
  'common_voice_en_42814559-63',
  'common_voice_en_42814462-62',
  'common_voice_en_42814024-57',
  'common_voice_en_42807943-113',
  'common_voice_en_42797769-55',
  'common_voice_en_42808017-116',
  'common_voice_en_42807963-115',
  'common_voice_en_42788519-97',
  'common_voice_en_42815757-73',
  'common_voice_en_42798328-2',
  'common_voice_en_42787815-33',
  'common_voice_en_42806811-20',
  'common_voice_en_42789325-44',
  'common_voice_en_42814267-53',
  'common_voice_en_42791629-11',
  'common_voice_en_42814135-32',
  'common_voice_en_42713893-146',
  'common_voice_en_42815717-72',
  'common_voice_en_42812836-49',
  'common_voice_en_42812198-59',
  'common_voice_en_42752135-7',
  'common_voice_en_42799557-8',
  'common_voice_en_42840257-136',
  'common_voice_en_42805882-14',
  'common_voice_en_42767185-39',
  'comm