# Importing

In [None]:
# Optional Colab step: mount Google Drive so that ROOT below resolves.
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

ROOT = 'drive/MyDrive/diarization' # path to the diarization project folder

Mounted at /content/drive


In [None]:
import os
import sys
import glob2
from tqdm import tqdm

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
import torchaudio

In [None]:
!pip install https://github.com/pyannote/pyannote-audio/archive/develop.zip

In [None]:
from pyannote.core import Annotation, Timeline, Segment, SlidingWindow
from pyannote.database.util import load_rttm
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection, OverlappedSpeechDetection
from pyannote.metrics.diarization import DiarizationErrorRate, JaccardErrorRate

In [None]:
!pip install onnxruntime
!pip install speechbrain

In [None]:
from drive.MyDrive.diarization.data_io import load_audio
from drive.MyDrive.diarization.backend import transform_embeddings, prepare_plda
from drive.MyDrive.diarization.embedder.brno.wrapper import prepare_model_brno
from drive.MyDrive.diarization.embedder.clova.wrapper import prepare_model_clova
from drive.MyDrive.diarization.embedder.speechbrain.wrapper import prepare_model_speechbrain
from drive.MyDrive.diarization.segmentation import split_segments, split_overlap_part
from drive.MyDrive.diarization.features import extract_embeddings
from sklearn.cluster import AgglomerativeClustering
from drive.MyDrive.diarization.clustering import VB_diarization, VB_diarization_UP

# Voice activity and overlapped speech detection

## Functions

In [None]:
# Global configuration shared by the cells below.
device = 'cuda:0'
SAMPLE_RATE = 16000
np.random.seed(0)

# Passed to get_annotations(), which uses it to instantiate the VAD/OSD pipelines.
# Those pipelines are bypassed while get_annotations() runs with VAD_ORACLE = True,
# so in that mode these values do not influence the saved annotations.
# Parameters: https://huggingface.co/pyannote/segmentation#reproducible-research
HYPER_PARAMETERS = {
  # onset/offset activation thresholds
  "onset": 0.5, "offset": 0.5,
  # remove speech regions shorter than that many seconds.
  "min_duration_on": 0.1,
  # fill non-speech regions shorter than that many seconds.
  "min_duration_off": 0.1
}

In [None]:
def get_annotations(data_root, dataset_name, data_type, HYPER_PARAMETERS):
    """Build and save per-recording annotation dictionaries for one dataset split.

    Four dictionaries keyed by recording URI are written with ``np.save`` into
    ``{ROOT}/annotations/{dataset_name}/``: speech regions (``uri2vad``),
    overlapped-speech regions (``uri2osd``), audio paths (``uri2path``) and
    reference annotations (``uri2ann_ref``).

    Args:
        data_root: Directory of the unpacked split. For ``'ami'`` this is the
            directory that holds the reference ``*.rttm`` files.
        dataset_name: ``'ami'``, ``'aishell'`` or ``'voxconverse'``.
        data_type: Split name; used in the output file names and, for
            ``'voxconverse'``, to locate the audio and RTTM sub-directories.
        HYPER_PARAMETERS: Passed to ``instantiate`` of the pyannote VAD and
            OSD pipelines.

    Returns:
        ``(vad_model, osd_model)``, or ``None`` (after printing a message)
        when ``dataset_name`` is not recognised.
    """
    # --- 1. Map every recording URI to its audio file -----------------------
    if dataset_name == 'ami':
        # AMI recordings are enumerated through their reference RTTM files.
        uris = [
            os.path.splitext(os.path.basename(rttm_path))[0]
            for rttm_path in glob2.glob(os.path.join(data_root, '*.rttm'))
        ]
        uri2path = {uri: f"amicorpus/{uri}/audio/{uri}.Mix-Headset.wav" for uri in uris}
    else:
        if dataset_name == 'aishell':
            audio_pattern = 'wav/*.flac'
        elif dataset_name == 'voxconverse':
            audio_pattern = f'voxconverse_{data_type}_wav/*.wav' # data_type \in {test, dev}
            if data_type == 'dev':
                audio_pattern = 'audio/*.wav'
        else:
            print("unknown dataset_name")
            return

        wav_list = glob2.glob(os.path.join(data_root, audio_pattern))
        uri2path = {os.path.splitext(os.path.basename(wav))[0]: wav for wav in wav_list}

    # --- 2. Load the reference RTTM of every recording ----------------------
    if dataset_name == 'aishell':
        rttm_root, rttm_subdir = data_root, 'TextGrid/'
    elif dataset_name == 'voxconverse':
        # VoxConverse references live in the separately cloned 'voxconverse' repo.
        rttm_root, rttm_subdir = 'voxconverse', data_type + '/' # data_type \in {test, dev}
    else:  # 'ami'
        rttm_root, rttm_subdir = data_root, ""

    uri2ann_ref = {}
    for uri in uri2path:
        uri2ann_ref.update(load_rttm(os.path.join(rttm_root, f'{rttm_subdir}{uri}.rttm')))

    # --- 3. Pretrained VAD / OSD pipelines -----------------------------------
    # They are only applied when VAD_ORACLE is False, but are always returned.
    vad_osd_joint = Model.from_pretrained(f"{ROOT}/pretrained/frontend/pytorch_model.bin")
    vad_model = VoiceActivityDetection(segmentation=vad_osd_joint)
    vad_model.instantiate(HYPER_PARAMETERS)
    osd_model = OverlappedSpeechDetection(segmentation=vad_osd_joint)
    osd_model.instantiate(HYPER_PARAMETERS)

    # --- 4. Speech and overlapped-speech regions ------------------------------
    VAD_ORACLE = True
    uri2vad = {}
    uri2osd = {}
    for uri, wav_path in uri2path.items():
        if VAD_ORACLE:
            ann_ref = uri2ann_ref[uri]
            vad = ann_ref.get_timeline().support()
            # Overlap must come from the reference annotation itself: support()
            # has already merged intersecting segments, so asking the merged
            # timeline for its overlap always yields an empty timeline.
            osd = ann_ref.get_overlap()
        else:
            vad = vad_model(wav_path).get_timeline().support()
            osd = osd_model(wav_path).get_timeline().support()

        uri2vad[uri] = vad
        uri2osd[uri] = osd

    # --- 5. Persist -----------------------------------------------------------
    out_dir = f'{ROOT}/annotations/{dataset_name}'
    os.makedirs(out_dir, exist_ok=True)  # np.save does not create missing directories
    np.save(f'{out_dir}/{dataset_name}_{data_type}_uri2vad.npy', uri2vad)
    np.save(f'{out_dir}/{dataset_name}_{data_type}_uri2osd.npy', uri2osd)
    np.save(f'{out_dir}/{dataset_name}_{data_type}_uri2path.npy', uri2path)
    np.save(f'{out_dir}/{dataset_name}_{data_type}_uri2ann_ref.npy', uri2ann_ref)

    return vad_model, osd_model

## Getting Aishell4 Annotations



In [None]:
# Download and unpack the AISHELL-4 test split, build its annotation files,
# then delete the unpacked audio to free disk space.
!wget https://www.openslr.org/resources/111/test.tar.gz # AISHELL
!tar xfvz test.tar.gz # for AISHELL
!rm -r test.tar.gz
data_root_aishell4_test = 'test'

get_annotations(data_root_aishell4_test, 'aishell', 'test', HYPER_PARAMETERS)

!rm -r test

In [None]:
# Download and unpack the AISHELL-4 train_M split, build its annotation files,
# then delete the unpacked audio.
!wget https://www.openslr.org/resources/111/train_M.tar.gz # AISHELL
!tar xfvz train_M.tar.gz # for AISHELL
!rm -r train_M.tar.gz
data_root_aishell4_train_M = 'train_M'

get_annotations(data_root_aishell4_train_M, 'aishell', 'train_M', HYPER_PARAMETERS)

!rm -r train_M

In [None]:
# Download and unpack the AISHELL-4 train_S split, build its annotation files,
# then delete the unpacked audio.
!wget https://www.openslr.org/resources/111/train_S.tar.gz # AISHELL
!tar xfvz train_S.tar.gz # for AISHELL
!rm -r train_S.tar.gz
data_root_aishell4_train_S = 'train_S'

get_annotations(data_root_aishell4_train_S, 'aishell', 'train_S', HYPER_PARAMETERS)

!rm -r train_S

In [None]:
# Download and unpack the AISHELL-4 train_L split, build its annotation files,
# then delete the unpacked audio.
!wget https://www.openslr.org/resources/111/train_L.tar.gz # AISHELL
!tar xfvz train_L.tar.gz # for AISHELL
!rm -r train_L.tar.gz
data_root_aishell4_train_L = 'train_L'

get_annotations(data_root_aishell4_train_L, 'aishell', 'train_L', HYPER_PARAMETERS)

!rm -r train_L

## Getting Voxconverse Annotations

In [None]:
# Download and unpack the VoxConverse test audio and clone the repository that
# holds the reference RTTM files (get_annotations reads them from 'voxconverse/').
# The cloned repository is kept here; the dev cell below removes it.
!wget https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_test_wav.zip # VOXCONVERSE
!unzip voxconverse_test_wav.zip -d voxconverse_test_wav # for VOXCONVERSE
!rm -r voxconverse_test_wav.zip
!git clone https://github.com/joonson/voxconverse 
data_root_voxconverse_test = 'voxconverse_test_wav'

get_annotations(data_root_voxconverse_test, 'voxconverse', 'test', HYPER_PARAMETERS)

!rm -r voxconverse_test_wav

In [None]:
!wget https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_dev_wav.zip # VOXCONVERSE
!unzip voxconverse_dev_wav.zip -d voxconverse_dev_wav # for VOXCONVERSE
!rm -r voxconverse_dev_wav.zip
!git clone https://github.com/joonson/voxconverse 
data_root_voxconverse_test = 'voxconverse_dev_wav'

get_annotations(data_root_voxconverse_test, 'voxconverse', 'dev', HYPER_PARAMETERS)

!rm -r voxconverse_dev_wav
!rm -r voxconverse 

## Getting Ami Annotations

In [None]:
!wget https://github.com/pyannote/AMI-diarization-setup/blob/main/pyannote/download_ami.sh # AMI
!git clone https://github.com/pyannote/AMI-diarization-setup ami # for AMI
!sh ami/pyannote/download_ami.sh

data_root_ami_test = 'ami/only_words/rttms/test'
get_annotations(data_root_ami_test, 'ami', 'test', HYPER_PARAMETERS)

data_root_ami_dev = 'ami/only_words/rttms/dev'
get_annotations(data_root_ami_dev, 'ami', 'dev', HYPER_PARAMETERS)

data_root_ami_train = 'ami/only_words/rttms/train'
get_annotations(data_root_ami_train, 'ami', 'train', HYPER_PARAMETERS)

!rm -r ami
!rm -r amicorpus

# Embeddings extraction

## Extraction functions

In [None]:
def get_embeddings(embedder_name, dataset_name, data_type, uri2path, uri2vad, uri2osd, device='cuda:0', skip_overlap=True):
    """Extract sliding-window speaker embeddings for every recording and save them.

    The result, a dict ``uri -> (embeddings, segments)``, is written with
    ``np.save`` to ``{ROOT}/embeddings/skip_overlap=<bool>/{dataset_name}/``.
    Recordings that fail are skipped; their 1-based positions in ``uri2path``
    are printed at the end.

    Args:
        embedder_name: ``'brno'``, ``'clova'`` or ``'speechbrain'``.
        dataset_name: Dataset name, used in the output path.
        data_type: Split name, used in the output file name.
        uri2path: Mapping from recording URI to audio file path.
        uri2vad: Mapping from recording URI to its speech timeline.
        uri2osd: Mapping from recording URI to its overlapped-speech timeline;
            only read when ``skip_overlap`` is true.
        device: Device string forwarded to the model and the extractor.
        skip_overlap: When true, overlapped-speech regions are removed from
            the speech timeline before windowing.

    Raises:
        ValueError: If ``embedder_name`` is not one of the supported names.
    """
    # Fail fast: previously an unknown name left the model undefined and every
    # recording was silently recorded as an "error".
    if embedder_name not in ('brno', 'clova', 'speechbrain'):
        raise ValueError(f'unknown embedder_name: {embedder_name!r}')

    if embedder_name == 'brno':
        model_path_brno = f'{ROOT}/pretrained/brno/VBx/models/ResNet101_16kHz/nnet/raw_81.pth'
        backend = 'onnx' if model_path_brno.endswith('onnx') else 'pytorch'
        emb_model = prepare_model_brno(model_path_brno, device, backend)
    elif embedder_name == 'clova':
        model_path_clova = f'{ROOT}/pretrained/clova/baseline_v2_ap.model'
        emb_model = prepare_model_clova(model_path_clova, device)
    else:  # 'speechbrain'
        model_path_speechbrain = f'{ROOT}/pretrained/speechbrain/embedding_model.ckpt'
        emb_model = prepare_model_speechbrain(model_path_speechbrain, device)

    # Window length and hop passed to split_segments.
    win_size = 2.0
    step_size = 1.0

    uri2data = {}
    errors = []
    for it, (uri, wav_path) in enumerate(tqdm(uri2path.items()), start=1):
        try:
            vad = uri2vad[uri]
            if skip_overlap:
                osd = uri2osd[uri]
                vad = vad.extrude(osd).support() # exclude segments with overlapped speech

            waveform = load_audio(wav_path)

            segments = split_segments(vad, win_size, step_size)
            embeddings = extract_embeddings(emb_model, waveform, segments, device, batch_size=1)

            uri2data[uri] = (embeddings, segments)
        except Exception as exc:
            # Keep going on per-file failures, but say why; KeyboardInterrupt
            # is deliberately not caught so a long run can still be stopped.
            tqdm.write(f'{uri}: {exc!r}')
            errors.append(it)
    print(f'{embedder_name} {dataset_name} {data_type} errors: {errors}')

    out_dir = f'{ROOT}/embeddings/skip_overlap={str(skip_overlap)}/{dataset_name}'
    os.makedirs(out_dir, exist_ok=True)  # do not lose the extraction to a missing directory
    np.save(f'{out_dir}/{dataset_name}_{data_type}_uri2data_{embedder_name}.npy', uri2data)

In [None]:
def get_all_embeddings(dataset_name, data_type, skip_overlap=True):
    """Run ``get_embeddings`` with every embedder for one dataset split.

    The annotation dictionaries previously saved under
    ``{ROOT}/annotations/{dataset_name}/`` are loaded once and reused for the
    'brno', 'clova' and 'speechbrain' embedders.

    Args:
        dataset_name: Dataset name used in the annotation file paths.
        data_type: Split name used in the annotation file names.
        skip_overlap: Forwarded to ``get_embeddings``.
    """
    prefix = f'{ROOT}/annotations/{dataset_name}/{dataset_name}_{data_type}'
    # The annotation files do not depend on the embedder, so load them a single time.
    uri2vad = np.load(f'{prefix}_uri2vad.npy', allow_pickle=True).item()
    uri2osd = np.load(f'{prefix}_uri2osd.npy', allow_pickle=True).item()
    uri2path = np.load(f'{prefix}_uri2path.npy', allow_pickle=True).item()

    for embedder_name in ['brno', 'clova', 'speechbrain']:
        get_embeddings(embedder_name, dataset_name, data_type, uri2path, uri2vad, uri2osd, skip_overlap=skip_overlap)

In [None]:
# Forwarded to get_all_embeddings by the extraction cells below; when True,
# overlapped-speech regions are removed before embeddings are extracted.
SKIP_OVERLAP = False # True

## Getting Aishell4 Embeddings

In [None]:
# Re-download the AISHELL-4 test split, extract embeddings with every embedder,
# then delete the unpacked audio.
!wget https://www.openslr.org/resources/111/test.tar.gz # AISHELL
!tar xfvz test.tar.gz # for AISHELL
!rm -r test.tar.gz

dataset_name, data_type = 'aishell', 'test'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP)

!rm -r test

In [None]:
# Re-download the AISHELL-4 train_M split, extract embeddings with every embedder,
# then delete the unpacked audio.
!wget https://www.openslr.org/resources/111/train_M.tar.gz # AISHELL
!tar xfvz train_M.tar.gz # for AISHELL
!rm -r train_M.tar.gz

dataset_name, data_type = 'aishell', 'train_M'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP)

!rm -r train_M

In [None]:
# Re-download the AISHELL-4 train_L split, extract embeddings with every embedder,
# then delete the unpacked audio.
!wget https://www.openslr.org/resources/111/train_L.tar.gz # AISHELL
!tar xfvz train_L.tar.gz # for AISHELL
!rm -r train_L.tar.gz

dataset_name, data_type = 'aishell', 'train_L'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP)

!rm -r train_L

In [None]:
# Re-download the AISHELL-4 train_S split, extract embeddings with every embedder,
# then delete the unpacked audio.
!wget https://www.openslr.org/resources/111/train_S.tar.gz # AISHELL
!tar xfvz train_S.tar.gz # for AISHELL
!rm -r train_S.tar.gz

dataset_name, data_type = 'aishell', 'train_S'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP)

!rm -r train_S

In [None]:
# Build an AISHELL dev set from a mix of the train_L, train_M and train_S splits.

n_per_split = 7  # recordings taken from the head of each split (same count as the earlier `i > 6` cut-off)

# Reference annotations are shared by every embedder / skip_overlap setting, so
# they are accumulated across all combinations. Otherwise the single output file
# would only hold the keys of whichever combination happened to be written last.
dev_uri2ann_ref = {}

for skip_overlap in [True, False]:
    for embedder_name in ['brno', 'clova', 'speechbrain']:
        dev_uri2data = {}
        for train_type in ['train_L', 'train_M', 'train_S']:
            train_uri2ann_ref = np.load(
                f'{ROOT}/annotations/aishell/aishell_{train_type}_uri2ann_ref.npy',
                allow_pickle=True
            ).item()
            train_uri2data = np.load(
                f'{ROOT}/embeddings/skip_overlap={skip_overlap}/aishell/aishell_{train_type}_uri2data_{embedder_name}.npy',
                allow_pickle=True
            ).item()
            for key in list(train_uri2data)[:n_per_split]:
                # Check the reference first so that the embeddings and the
                # annotations never get out of sync for a recording.
                if key not in train_uri2ann_ref:
                    continue
                dev_uri2data[key] = train_uri2data[key]
                dev_uri2ann_ref[key] = train_uri2ann_ref[key]
        np.save(f'{ROOT}/embeddings/skip_overlap={skip_overlap}/aishell/aishell_dev_uri2data_{embedder_name}.npy', dev_uri2data)
        np.save(f'{ROOT}/annotations/aishell/aishell_dev_uri2ann_ref.npy', dev_uri2ann_ref)

## Getting Voxconverse Embeddings

In [None]:
# Re-download the VoxConverse test audio, extract embeddings with every embedder,
# then delete the unpacked audio. The reference repository is not needed here.
!wget https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_test_wav.zip # VOXCONVERSE
!unzip voxconverse_test_wav.zip -d voxconverse_test_wav # for VOXCONVERSE
!rm -r voxconverse_test_wav.zip
# !git clone https://github.com/joonson/voxconverse

dataset_name, data_type = 'voxconverse', 'test'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP)

!rm -r voxconverse_test_wav

In [None]:
# Re-download the VoxConverse dev audio, extract embeddings with every embedder,
# then delete the unpacked audio.
!wget https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_dev_wav.zip # VOXCONVERSE
!unzip voxconverse_dev_wav.zip -d voxconverse_dev_wav # for VOXCONVERSE
!rm -r voxconverse_dev_wav.zip

dataset_name, data_type = 'voxconverse', 'dev'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP) # with skip_overlap=True this is not computed (speechbrain) for 70 -- NOTE(review): "70" presumably identifies a recording; confirm

!rm -r voxconverse_dev_wav

## Getting Ami Embeddings

In [None]:
# Clone the AMI setup repository and run its download script; the audio is
# used by the three extraction cells below and removed afterwards.
!git clone https://github.com/pyannote/AMI-diarization-setup ami # for AMI
!sh ami/pyannote/download_ami.sh

In [None]:
# Extract AMI test-split embeddings with every embedder.
dataset_name, data_type = 'ami', 'test'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP)

In [None]:
# Extract AMI dev-split embeddings with every embedder.
dataset_name, data_type = 'ami', 'dev'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP)

100%|██████████| 18/18 [12:17<00:00, 40.99s/it]


brno ami dev errors: []
Embedding size is 512, encoder ASP.


100%|██████████| 18/18 [04:22<00:00, 14.58s/it]


clova ami dev errors: []


100%|██████████| 18/18 [04:56<00:00, 16.45s/it]

speechbrain ami dev errors: []





In [None]:
# Extract AMI train-split embeddings with every embedder.
dataset_name, data_type = 'ami', 'train'

get_all_embeddings(dataset_name, data_type, SKIP_OVERLAP)

100%|██████████| 136/136 [1:50:14<00:00, 48.64s/it]


brno ami train errors: [133]
Embedding size is 512, encoder ASP.


100%|██████████| 136/136 [39:42<00:00, 17.52s/it]


clova ami train errors: [133]


100%|██████████| 136/136 [43:45<00:00, 19.30s/it]


speechbrain ami train errors: [5, 31, 37, 85, 127, 133]


In [None]:
# Remove the cloned AMI setup repository and the downloaded AMI audio.
!rm -r ami
!rm -r amicorpus

# Clustering

## Clustering Functions

In [None]:
def get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=True):
    """Score hypothesis annotations against references and return the metric report.

    Args:
        metric: ``'der'`` (diarization error rate) or ``'jer'`` (Jaccard error rate).
        uri2ann_ref: Mapping from recording URI to its reference annotation.
        uri2ann_hyp: Mapping from recording URI to its hypothesis annotation;
            every key must also be present in ``uri2ann_ref``.
        skip_overlap: Forwarded to the metric constructor.

    Returns:
        The result of ``report(display=False)`` on the accumulated metric, or
        ``None`` (after printing ``"unknown metric"``) for any other ``metric``.
    """
    if metric == 'der':
        scorer = DiarizationErrorRate(collar=0, skip_overlap=skip_overlap)
    elif metric == 'jer':
        scorer = JaccardErrorRate(collar=0, skip_overlap=skip_overlap)
    else:
        print("unknown metric")
        return

    # Calling the metric object accumulates the per-recording result inside it.
    for uri in uri2ann_hyp:
        reference = uri2ann_ref[uri]
        scorer(reference, uri2ann_hyp[uri])

    return scorer.report(display=False)

### Mean-Shift Clustering

In [None]:
# !pip install optuna
import optuna
import traceback
from sklearn.cluster import MeanShift

def get_mean_shift_clustering(embedder_name, uri2data, uri2ann_ref, params, skip_overlap=True, metric='der', return_ann_hyp=False):
    """Cluster every recording's embeddings with Mean-Shift and score the result.

    Args:
        embedder_name: Embedder name, forwarded to ``transform_embeddings``.
        uri2data: Mapping ``uri -> (embeddings, segments)``.
        uri2ann_ref: Mapping from recording URI to its reference annotation.
        params: Keyword arguments for ``sklearn.cluster.MeanShift``.
        skip_overlap: Forwarded to ``get_clst_metric_res``.
        metric: ``'der'`` or ``'jer'``.
        return_ann_hyp: When true, also return the hypothesis annotations.

    Returns:
        The metric report, or ``(report, uri2ann_hyp)`` when
        ``return_ann_hyp`` is true. Recordings whose clustering fails are
        reported on stdout and left out of the hypotheses.
    """
    uri2ann_hyp = {}
    for uri in uri2data:
        embeddings_raw, segments = uri2data[uri]
        embeddings = transform_embeddings(embeddings_raw, embedder_name)
        # Replace non-finite values so the clusterer does not reject the input.
        embeddings[np.isnan(embeddings)] = 1
        embeddings[np.isinf(embeddings)] = 100
        try:
            msh = MeanShift(**params)
            msh.fit(embeddings)
            labels = msh.labels_
        except Exception:
            # Best effort per recording; KeyboardInterrupt is intentionally
            # not swallowed so a long search can be stopped.
            print('Ошибка:', traceback.format_exc())
            continue
        ann_hyp = Annotation(uri=uri)
        for segment, label in zip(segments, labels):
            ann_hyp[segment] = str(label)
        uri2ann_hyp[uri] = split_overlap_part(ann_hyp.support())

    report = get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap)
    if not return_ann_hyp:
        return report
    return report, uri2ann_hyp


def get_mean_shift_params(embedder_name, uri2data, uri2ann_ref, skip_overlap=True, metric='der', return_ann_hyp=False):
    """Tune the Mean-Shift bandwidth with Optuna by minimising the chosen error rate.

    Args:
        embedder_name: Embedder name, forwarded to ``get_mean_shift_clustering``.
        uri2data: Mapping ``uri -> (embeddings, segments)``.
        uri2ann_ref: Mapping from recording URI to its reference annotation.
        skip_overlap: Forwarded to the scoring step.
        metric: ``'der'`` or ``'jer'``.
        return_ann_hyp: Accepted for signature compatibility; not used.

    Returns:
        ``(best_params, best_value)`` of the Optuna study.
    """

    def mean_shift_obj(trial):
        # MeanShift has no `distance_threshold` argument; its tunable kernel
        # width is `bandwidth`. Suggesting the wrong name made every fit raise
        # TypeError inside get_mean_shift_clustering, so no trial ever clustered.
        trial.suggest_float('bandwidth', 0.1, 2.5)

        report = get_mean_shift_clustering(embedder_name, uri2data, uri2ann_ref,  trial.params, skip_overlap=skip_overlap, metric=metric)
        if metric == 'der':
            return report.loc['TOTAL', 'diarization error rate'].values[0]
        elif metric == 'jer':
            return report.loc['TOTAL', 'jaccard error rate'].values[0]
        else:
            print("unknown metric")
            return

    study = optuna.create_study(direction='minimize')
    study.optimize(mean_shift_obj, n_trials=15, timeout=1200)
    return study.best_params, study.best_value

### Agglomerative Hierarchical Clustering

In [None]:
# !pip install optuna
import optuna
import traceback

def get_ahc_clustering(embedder_name, uri2data, uri2ann_ref, params, skip_overlap=True, metric='der', return_hyp=False):
    """Cluster every recording's embeddings with cosine AHC and score the result.

    Args:
        embedder_name: Embedder name, forwarded to ``transform_embeddings``.
        uri2data: Mapping ``uri -> (embeddings, segments)``.
        uri2ann_ref: Mapping from recording URI to its reference annotation.
        params: Extra keyword arguments for ``AgglomerativeClustering``
            (e.g. ``distance_threshold`` and ``linkage``).
        skip_overlap: Forwarded to ``get_clst_metric_res``.
        metric: ``'der'`` or ``'jer'``.
        return_hyp: When true, also return the hypothesis annotations.

    Returns:
        The metric report, or ``(report, uri2ann_hyp)`` when ``return_hyp`` is
        true. Recordings whose clustering fails are reported on stdout and
        left out of the hypotheses.
    """
    uri2ann_hyp = {}
    for uri in uri2data:
        embeddings_raw, segments = uri2data[uri]
        embeddings = transform_embeddings(embeddings_raw, embedder_name)
        # Replace non-finite values so the clusterer does not reject the input.
        embeddings[np.isnan(embeddings)] = 1
        embeddings[np.isinf(embeddings)] = 100
        try:
            # scikit-learn renamed `affinity` to `metric` (added in 1.2, old
            # name removed in 1.4). Try the new keyword first and fall back for
            # older installations, so the call works on either side.
            try:
                ahc = AgglomerativeClustering(n_clusters=None, metric='cosine', **params)
            except TypeError:
                ahc = AgglomerativeClustering(n_clusters=None, affinity='cosine', **params)
            ahc.fit(embeddings)
            labels = ahc.labels_
        except Exception:
            # Best effort per recording; KeyboardInterrupt is intentionally
            # not swallowed so a long search can be stopped.
            print('Ошибка:', traceback.format_exc())
            continue
        ann_hyp = Annotation(uri=uri)
        for segment, label in zip(segments, labels):
            ann_hyp[segment] = str(label)
        uri2ann_hyp[uri] = split_overlap_part(ann_hyp.support())

    if return_hyp:
        return get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap), uri2ann_hyp
    return get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap)


def get_ahc_params(embedder_name, uri2data, uri2ann_ref, skip_overlap=True, metric='der'):
    """Search AHC hyper-parameters with Optuna, minimising the chosen error rate.

    The search covers ``distance_threshold`` in [0.5, 1] and ``linkage`` in
    {'complete', 'average', 'single'}, with at most 15 trials or 1200 seconds.

    Returns:
        ``(best_params, best_value)`` of the Optuna study.
    """
    total_column = {
        'der': 'diarization error rate',
        'jer': 'jaccard error rate',
    }

    def ahc_obj(trial):
        # Register the search space; the sampled values are read back via trial.params.
        trial.suggest_float('distance_threshold', 0.5, 1)
        trial.suggest_categorical('linkage', ['complete', 'average', 'single'])

        report = get_ahc_clustering(
            embedder_name,
            uri2data,
            uri2ann_ref,
            trial.params,
            skip_overlap=skip_overlap,
            metric=metric,
        )

        column = total_column.get(metric)
        if column is None:
            print("unknown metric")
            return
        return report.loc['TOTAL', column].values[0]

    study = optuna.create_study(direction='minimize')
    study.optimize(ahc_obj, n_trials=15, timeout=1200)
    return study.best_params, study.best_value

### Spectral Clustering

In [None]:
!pip install spectralcluster

In [None]:
from spectralcluster.autotune import AutoTune
from spectralcluster.constraint import ConstraintName, ConstraintOptions
from spectralcluster.laplacian import LaplacianType
from spectralcluster.refinement import RefinementName, RefinementOptions, ThresholdType, SymmetrizeType
from spectralcluster.spectral_clusterer import SpectralClusterer

def get_spectrul_clustering(embedder_name, uri2data, uri2ann_ref, skip_overlap=True, metric='der', return_hyp=False):
    """Cluster every recording's embeddings with auto-tuned spectral clustering and score the result.

    Args:
        embedder_name: Embedder name, forwarded to ``transform_embeddings``.
        uri2data: Mapping ``uri -> (embeddings, segments)``.
        uri2ann_ref: Mapping from recording URI to its reference annotation.
        skip_overlap: Forwarded to ``get_clst_metric_res``.
        metric: ``'der'`` or ``'jer'``.
        return_hyp: When true, also return the hypothesis annotations.

    Returns:
        The metric report, or ``(report, uri2ann_hyp)`` when ``return_hyp`` is
        true. Recordings whose clustering fails are reported on stdout and
        left out of the hypotheses.
    """
    uri2ann_hyp = {}
    for uri in uri2data:
        embeddings_raw, segments = uri2data[uri]
        embeddings = transform_embeddings(embeddings_raw, embedder_name)
        # Replace non-finite values so the clusterer does not reject the input.
        embeddings[np.isnan(embeddings)] = 1
        embeddings[np.isinf(embeddings)] = 100
        try:
            # Affinity refinement: row-wise percentile thresholding, then symmetrisation.
            TURNTODIARIZE_REFINEMENT_SEQUENCE = [
                RefinementName.RowWiseThreshold, RefinementName.Symmetrize
            ]

            turntodiarize_refinement_options = RefinementOptions(
                thresholding_soft_multiplier=0.01,
                thresholding_type=ThresholdType.Percentile,
                thresholding_with_binarization=True,
                thresholding_preserve_diagonal=True,
                symmetrize_type=SymmetrizeType.Average,
                refinement_sequence=TURNTODIARIZE_REFINEMENT_SEQUENCE)

            # The thresholding percentile is searched between 0.40 and 0.95.
            turntodiarize_auto_tune = AutoTune(
                p_percentile_min=0.40,
                p_percentile_max=0.95,
                init_search_step=0.05,
                search_level=1)

            turntodiarize_clusterer = SpectralClusterer(
                min_clusters=2,
                max_clusters=30,
                refinement_options=turntodiarize_refinement_options,
                autotune=turntodiarize_auto_tune,
                laplacian_type=LaplacianType.GraphCut,
                row_wise_renorm=True,
                custom_dist="cosine")

            labels = turntodiarize_clusterer.predict(embeddings)
        except Exception:
            # Best effort per recording; KeyboardInterrupt is intentionally
            # not swallowed so a long evaluation can be stopped.
            print('Ошибка:', traceback.format_exc())
            continue
        ann_hyp = Annotation(uri=uri)
        for segment, label in zip(segments, labels):
            ann_hyp[segment] = str(label)
        uri2ann_hyp[uri] = split_overlap_part(ann_hyp.support())

    if return_hyp:
        return get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap), uri2ann_hyp
    return get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap)

In [None]:
# import optuna
# import traceback
# from sklearn.cluster import SpectralClustering
# from sklearn.metrics.pairwise import cosine_similarity

# def get_spectrul_clustering(embedder_name, uri2data, uri2ann_ref, ahc_labels, skip_overlap=True, metric='der'):
#     uri2ann_hyp = {}
#     for uri in uri2data:
#         embeddings_raw, segments = uri2data[uri]
#         embeddings = transform_embeddings(embeddings_raw, embedder_name)
#         embeddings[np.isnan(embeddings)] = 1
#         embeddings[np.isinf(embeddings)] = 100
#         try:
#             X = 1 + cosine_similarity(embeddings)
#             # X = cosine_similarity(embeddings)
#             # print("nans:", X[np.isnan(X)].sum())
#             # print("infs:", X[np.isinf(X)].sum())
#             # X = embeddings
#             # print(len(np.unique(ahc_labels[uri])))
#             spect = SpectralClustering(
#                         n_clusters=len(np.unique(ahc_labels[uri])),
#                         affinity='precomputed',
#                         # affinity='cosine',
#                         random_state=17
#                     ).fit(X)
                    
#             labels = spect.labels_
#         except:
#             print('Ошибка:', traceback.format_exc())
#             continue
#         ann_hyp = Annotation(uri=uri)
#         for segment, label in zip(segments, labels):
#             ann_hyp[segment] = str(label)
#         uri2ann_hyp[uri] = split_overlap_part(ann_hyp.support())

#     if metric == 'der':
#         der_metric = DiarizationErrorRate(collar=0, skip_overlap=skip_overlap)
#         for uri, ann_hyp in uri2ann_hyp.items():
#             ann_ref = uri2ann_ref[uri]
#             der_metric(ann_ref, ann_hyp)
            
#         return der_metric.report(display=False)

#     elif metric == 'jer':
#         jer_metric = JaccardErrorRate(collar=0, skip_overlap=skip_overlap)
#         for uri, ann_hyp in uri2ann_hyp.items():
#             ann_ref = uri2ann_ref[uri]
#             jer_metric(ann_ref, ann_hyp)

#         return jer_metric.report(display=False)

#     else:
#         print("unknown metric")
#         return

### Vbx Clustering

In [None]:
!pip install optuna
import optuna
import traceback
def get_vbx_clustering(embedder_name, uri2data, uri2ann_ref, params, skip_overlap=True, metric='der', return_hyp=False):
    """Cluster every recording's embeddings with ``VB_diarization`` and score the result.

    Args:
        embedder_name: Embedder name, forwarded to ``transform_embeddings``
            and ``prepare_plda``.
        uri2data: Mapping ``uri -> (embeddings, segments)``.
        uri2ann_ref: Mapping from recording URI to its reference annotation.
        params: Extra keyword arguments for ``VB_diarization``
            (e.g. ``loopProb`` and ``Fb``).
        skip_overlap: Forwarded to ``get_clst_metric_res``.
        metric: ``'der'`` or ``'jer'``.
        return_hyp: When true, also return the hypothesis annotations.

    Returns:
        The metric report, or ``(report, uri2ann_hyp)`` when ``return_hyp`` is
        true. Recordings whose clustering fails are reported on stdout and
        left out of the hypotheses.
    """
    # The PLDA model and the derived VB inputs depend only on the embedder, so
    # they are prepared once instead of once per recording.
    plda_mu, plda_tr, plda_psi = prepare_plda(embedder_name)
    lda_dim = 128
    mean = np.zeros(lda_dim)
    invW = np.eye(lda_dim)
    V = np.diag(np.sqrt(plda_psi[:lda_dim]))

    uri2ann_hyp = {}
    for uri in tqdm(uri2data):
        embeddings_raw, segments = uri2data[uri]
        embeddings = transform_embeddings(embeddings_raw, embedder_name)
        # Replace non-finite values before projecting.
        embeddings[np.isnan(embeddings)] = 1
        embeddings[np.isinf(embeddings)] = 100

        # Centre, apply the PLDA transform and keep the first lda_dim dimensions.
        features = (embeddings - plda_mu).dot(plda_tr.T)[:, :lda_dim]
        try:
            np.random.seed(0)  # re-seeded per recording so each run is repeatable
            q, sp, L = VB_diarization(
                              features, mean, invW, V,
                              pi=None,
                              gamma=None,
                              maxIters=50,
                              epsilon=1e-6,
                              maxSpeakers=30,
                              Fa=0.3,
                              **params
                          )
            labels = np.argmax(q, axis=1)
            assert labels.shape == (features.shape[0],)
        except Exception:
            # Best effort per recording; KeyboardInterrupt is intentionally
            # not swallowed so a long search can be stopped.
            print('Ошибка:', traceback.format_exc())
            continue
        ann_hyp = Annotation(uri=uri)
        for segment, label in zip(segments, labels):
            ann_hyp[segment] = str(label)
        uri2ann_hyp[uri] = split_overlap_part(ann_hyp.support())

    if return_hyp:
        return get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap), uri2ann_hyp
    return get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap)


def get_vbx_params(embedder_name, uri2data, uri2ann_ref, skip_overlap=True, metric='der'):
    """Search VBx hyper-parameters with Optuna, minimising the chosen error rate.

    The search covers ``loopProb`` in [0.9, 0.99] and ``Fb`` in [4, 10] over
    10 trials.

    Returns:
        ``(best_params, best_value)`` of the Optuna study.
    """
    total_column = {
        'der': 'diarization error rate',
        'jer': 'jaccard error rate',
    }

    def vbx_obj(trial):
        # Register the search space; the sampled values are read back via trial.params.
        trial.suggest_float('loopProb', 0.9, 0.99)
        trial.suggest_float('Fb', 4, 10)

        report = get_vbx_clustering(
            embedder_name,
            uri2data,
            uri2ann_ref,
            trial.params,
            skip_overlap=skip_overlap,
            metric=metric,
        )

        column = total_column.get(metric)
        if column is None:
            print("unknown metric")
            return
        return report.loc['TOTAL', column].values[0]

    study = optuna.create_study(direction='minimize')
    study.optimize(vbx_obj, n_trials=10)
    return study.best_params, study.best_value

### Vbx with AHC Initialization

In [None]:
import optuna
import traceback

def get_ahc_vbx_clustering(embedder_name, uri2data, uri2ann_ref, ahc_params, vbx_params, skip_overlap=True, metric='der', return_hyp=False):
    """Run ``VB_diarization`` initialised from a cosine AHC clustering and score the result.

    Args:
        embedder_name: Embedder name, forwarded to ``transform_embeddings``
            and ``prepare_plda``.
        uri2data: Mapping ``uri -> (embeddings, segments)``.
        uri2ann_ref: Mapping from recording URI to its reference annotation.
        ahc_params: Dict with ``'distance_threshold'`` and ``'linkage'``.
        vbx_params: Dict with ``'loopProb'`` and ``'Fb'``.
        skip_overlap: Forwarded to ``get_clst_metric_res``.
        metric: ``'der'`` or ``'jer'``.
        return_hyp: When true, also return the hypothesis annotations.

    Returns:
        The metric report, or ``(report, uri2ann_hyp)`` when ``return_hyp`` is
        true. Recordings whose clustering fails are reported on stdout and
        left out of the hypotheses.
    """
    # The PLDA model and the derived VB inputs depend only on the embedder, so
    # they are prepared once instead of once per recording.
    plda_mu, plda_tr, plda_psi = prepare_plda(embedder_name)
    lda_dim = 128
    mean = np.zeros(lda_dim)
    invW = np.eye(lda_dim)
    V = np.diag(np.sqrt(plda_psi[:lda_dim]))

    uri2ann_hyp = {}
    for uri in tqdm(uri2data):
        embeddings_raw, segments = uri2data[uri]
        embeddings = transform_embeddings(embeddings_raw, embedder_name)

        # Replace non-finite values before clustering and projecting.
        embeddings[np.isnan(embeddings)] = 1
        embeddings[np.isinf(embeddings)] = 100

        # Centre, apply the PLDA transform and keep the first lda_dim dimensions.
        features = (embeddings - plda_mu).dot(plda_tr.T)[:, :lda_dim]

        try:
            ahc_kwargs = dict(
                n_clusters=None,
                distance_threshold=ahc_params['distance_threshold'],
                linkage=ahc_params['linkage'],
            )
            # scikit-learn renamed `affinity` to `metric` (added in 1.2, old
            # name removed in 1.4). Try the new keyword first and fall back for
            # older installations.
            try:
                ahc = AgglomerativeClustering(metric='cosine', **ahc_kwargs)
            except TypeError:
                ahc = AgglomerativeClustering(affinity='cosine', **ahc_kwargs)

            ahc.fit(embeddings)
            labels = ahc.labels_

            # One-hot AHC assignments initialise the VB responsibilities.
            maxSpeakers = len(np.unique(labels))
            I = np.eye(maxSpeakers)
            q = I[labels]

            np.random.seed(0)  # re-seeded per recording so each run is repeatable
            # NOTE(review): pi is passed as per-cluster counts, not normalised
            # priors -- confirm VB_diarization expects or normalises this.
            q, sp, L = VB_diarization(
                            features, mean, invW, V,
                            pi=np.sum(q, axis=0),
                            gamma=q,
                            maxSpeakers=maxSpeakers,
                            maxIters=50,
                            epsilon=1e-6,
                            loopProb=vbx_params['loopProb'],
                            Fa=0.3,
                            Fb=vbx_params['Fb']
                      )
            labels = np.argmax(q, axis=1)
            assert labels.shape == (features.shape[0],)
        except Exception:
            # Best effort per recording; KeyboardInterrupt is intentionally
            # not swallowed so a long grid search can be stopped.
            print('Ошибка:', traceback.format_exc())
            continue

        ann_hyp = Annotation(uri=uri)
        for segment, label in zip(segments, labels):
            ann_hyp[segment] = str(label)
        uri2ann_hyp[uri] = split_overlap_part(ann_hyp.support())

    if return_hyp:
        return get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap), uri2ann_hyp
    return get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap)


def get_ahc_vbx_params(embedder_name, uri2data, uri2ann_ref, ahc_params, skip_overlap=True, metric='der'):
    """Grid-search the VBx hyperparameters used on top of an AHC initialization.

    Every ``(loopProb, Fb)`` pair from the fixed grids below is passed to
    ``get_ahc_vbx_clustering``; the pair whose report has the lowest value in
    the ``TOTAL`` row of the selected metric column is kept.

    Args:
        embedder_name: forwarded unchanged to ``get_ahc_vbx_clustering``.
        uri2data: forwarded unchanged to ``get_ahc_vbx_clustering``.
        uri2ann_ref: forwarded unchanged to ``get_ahc_vbx_clustering``.
        ahc_params: forwarded unchanged to ``get_ahc_vbx_clustering``.
        skip_overlap: forwarded unchanged to ``get_ahc_vbx_clustering``.
        metric: ``'der'`` (reads the ``'diarization error rate'`` column) or
            ``'jer'`` (reads the ``'jaccard error rate'`` column).

    Returns:
        ``(best_params, best_score)`` where ``best_params`` is a dict with the
        keys ``'loopProb'`` and ``'Fb'``. ``best_params`` is ``None`` only if
        no grid point produced a score that compares lower than infinity
        (for example when every score is NaN).

    Raises:
        ValueError: if ``metric`` is neither ``'der'`` nor ``'jer'``. The check
            runs before any clustering, so a typo does not cost a full
            clustering pass first.
    """
    metric_columns = {
        'der': 'diarization error rate',
        'jer': 'jaccard error rate',
    }
    if metric not in metric_columns:
        raise ValueError(f"unknown metric: {metric!r} (expected 'der' or 'jer')")
    score_column = metric_columns[metric]

    # Start from +inf rather than 100: error rates are percentages that can
    # exceed 100, and a fixed cap of 100 would then leave best_params unset.
    best_score, best_params = float('inf'), None
    for lp in [0.9, 0.92, 0.95, 0.97, 0.99]:
        for fb in [1, 4, 6, 7.5, 8, 10]:
            # A fresh dict per grid point, so the stored best is never mutated
            # by later iterations.
            vbx_params = {
                'loopProb': lp,
                'Fb': fb
            }
            report = get_ahc_vbx_clustering(
                embedder_name, uri2data, uri2ann_ref, ahc_params, vbx_params,
                skip_overlap=skip_overlap, metric=metric
            )
            cur_score = report.loc['TOTAL', score_column].values[0]
            if cur_score < best_score:
                best_score = cur_score
                best_params = vbx_params
    return best_params, best_score

## Hyperparameters

### Agglomerative Clustering

In [None]:
# Tune the agglomerative-clustering hyperparameters on the dev files for every
# metric / overlap setting / dataset / embedder combination and store the
# result of each search under {ROOT}/best_params/.
for metric in ['der', 'jer']:
    for skip_overlap in [True, False]:
        # Create the output folder before any search starts, so a missing
        # directory cannot make np.save fail after the search has finished.
        out_dir = f'{ROOT}/best_params/{metric}/skip_overlap={skip_overlap}/ahc'
        os.makedirs(out_dir, exist_ok=True)

        for dataset_name in ['ami', 'voxconverse', 'aishell']:
            # The .npy files hold pickled dicts (hence allow_pickle / .item());
            # only load files that come from a trusted source.
            uri2ann_ref = np.load(
                f'{ROOT}/annotations/{dataset_name}/{dataset_name}_dev_uri2ann_ref.npy',
                allow_pickle=True
            ).item()

            for embedder_name in ['brno', 'clova', 'speechbrain']:
                uri2data = np.load(
                    f'{ROOT}/embeddings/skip_overlap={skip_overlap}/{dataset_name}/{dataset_name}_dev_uri2data_{embedder_name}.npy',
                    allow_pickle=True
                ).item()

                best_params, best_value = get_ahc_params(
                    embedder_name, uri2data, uri2ann_ref,
                    skip_overlap=skip_overlap, metric=metric
                )
                # Keep the achieved score next to the parameters that produced it.
                best_params['value'] = best_value
                print(f'{metric}, {skip_overlap}, {dataset_name}, {embedder_name}, best_value={best_value}')
                np.save(f'{out_dir}/{dataset_name}_{embedder_name}.npy', best_params)

der, True, aishell, brno, best_value=37.41933110767699
der, True, aishell, clova, best_value=14.409575968141485
der, True, aishell, speechbrain, best_value=31.939281741294078
der, False, ami, brno, best_value=18.773439489103694
der, False, ami, clova, best_value=21.945025857407405
der, False, ami, speechbrain, best_value=20.378038924662654
der, False, voxconverse, brno, best_value=7.287597980697067
der, False, voxconverse, clova, best_value=7.39385624766375
der, False, voxconverse, speechbrain, best_value=6.6731247861032115
der, False, aishell, brno, best_value=24.83460864544305
der, False, aishell, clova, best_value=26.55927343340595
der, False, aishell, speechbrain, best_value=26.74626813157529
jer, True, ami, brno, best_value=10.240671502608457
jer, True, ami, clova, best_value=11.68315288626
jer, True, ami, speechbrain, best_value=9.65622659287975
jer, True, voxconverse, brno, best_value=10.759226823803877
jer, True, voxconverse, clova, best_value=15.91106572313371
jer, True, voxco

### Spectral Clustering

### VBx Clustering

In [None]:
# Tune the VBx hyperparameters on the dev files for every
# metric / overlap setting / dataset / embedder combination and store the
# result of each search under {ROOT}/best_params/.
for metric in ['der', 'jer']:
    for skip_overlap in [True, False]:
        # Create the output folder before any search starts, so a missing
        # directory cannot make np.save fail after the search has finished.
        out_dir = f'{ROOT}/best_params/{metric}/skip_overlap={skip_overlap}/vbx'
        os.makedirs(out_dir, exist_ok=True)

        for dataset_name in ['ami', 'voxconverse', 'aishell']:
            # The .npy files hold pickled dicts (hence allow_pickle / .item());
            # only load files that come from a trusted source.
            uri2ann_ref = np.load(
                f'{ROOT}/annotations/{dataset_name}/{dataset_name}_dev_uri2ann_ref.npy',
                allow_pickle=True
            ).item()

            for embedder_name in ['brno', 'clova', 'speechbrain']:
                uri2data = np.load(
                    f'{ROOT}/embeddings/skip_overlap={skip_overlap}/{dataset_name}/{dataset_name}_dev_uri2data_{embedder_name}.npy',
                    allow_pickle=True
                ).item()

                best_params, best_value = get_vbx_params(
                    embedder_name, uri2data, uri2ann_ref,
                    skip_overlap=skip_overlap, metric=metric
                )
                # Keep the achieved score next to the parameters that produced it.
                best_params['value'] = best_value
                # Include the metric in the log line (as the sibling cells do);
                # without it the 'der' and 'jer' result lines look identical.
                print(f'{metric}, {skip_overlap}, {dataset_name}, {embedder_name}, best_value={best_value}')
                np.save(f'{out_dir}/{dataset_name}_{embedder_name}.npy', best_params)

100%|██████████| 18/18 [01:52<00:00,  6.26s/it]
100%|██████████| 18/18 [01:45<00:00,  5.85s/it]
100%|██████████| 18/18 [01:41<00:00,  5.61s/it]
100%|██████████| 18/18 [01:24<00:00,  4.70s/it]
100%|██████████| 18/18 [01:32<00:00,  5.13s/it]
100%|██████████| 18/18 [01:27<00:00,  4.85s/it]
100%|██████████| 18/18 [01:45<00:00,  5.88s/it]
100%|██████████| 18/18 [01:27<00:00,  4.88s/it]
100%|██████████| 18/18 [01:39<00:00,  5.50s/it]
100%|██████████| 18/18 [01:41<00:00,  5.62s/it]


True, ami, brno, best_value=6.534745201006527


100%|██████████| 18/18 [01:19<00:00,  4.43s/it]
100%|██████████| 18/18 [01:44<00:00,  5.78s/it]
100%|██████████| 18/18 [01:22<00:00,  4.61s/it]
100%|██████████| 18/18 [01:22<00:00,  4.58s/it]
100%|██████████| 18/18 [01:49<00:00,  6.06s/it]
100%|██████████| 18/18 [01:17<00:00,  4.31s/it]
100%|██████████| 18/18 [01:23<00:00,  4.66s/it]
100%|██████████| 18/18 [01:25<00:00,  4.76s/it]
100%|██████████| 18/18 [01:18<00:00,  4.33s/it]
100%|██████████| 18/18 [01:13<00:00,  4.08s/it]


True, ami, clova, best_value=7.16059010548501


100%|██████████| 18/18 [01:45<00:00,  5.85s/it]
100%|██████████| 18/18 [01:33<00:00,  5.19s/it]
100%|██████████| 18/18 [01:43<00:00,  5.76s/it]
100%|██████████| 18/18 [01:39<00:00,  5.51s/it]
100%|██████████| 18/18 [01:40<00:00,  5.58s/it]
100%|██████████| 18/18 [02:00<00:00,  6.71s/it]
100%|██████████| 18/18 [01:35<00:00,  5.32s/it]
100%|██████████| 18/18 [01:37<00:00,  5.42s/it]
100%|██████████| 18/18 [02:17<00:00,  7.64s/it]
100%|██████████| 18/18 [01:34<00:00,  5.23s/it]


True, ami, speechbrain, best_value=6.84425658958871


100%|██████████| 216/216 [04:29<00:00,  1.25s/it]
100%|██████████| 216/216 [04:57<00:00,  1.38s/it]
100%|██████████| 216/216 [04:22<00:00,  1.21s/it]
100%|██████████| 216/216 [04:48<00:00,  1.33s/it]
100%|██████████| 216/216 [05:32<00:00,  1.54s/it]
100%|██████████| 216/216 [04:22<00:00,  1.22s/it]
100%|██████████| 216/216 [04:54<00:00,  1.36s/it]
100%|██████████| 216/216 [04:51<00:00,  1.35s/it]
100%|██████████| 216/216 [04:26<00:00,  1.23s/it]
100%|██████████| 216/216 [04:33<00:00,  1.27s/it]


True, voxconverse, brno, best_value=10.814278642162868


100%|██████████| 216/216 [03:26<00:00,  1.05it/s]
100%|██████████| 216/216 [03:59<00:00,  1.11s/it]
100%|██████████| 216/216 [04:57<00:00,  1.38s/it]
100%|██████████| 216/216 [04:40<00:00,  1.30s/it]
100%|██████████| 216/216 [03:54<00:00,  1.09s/it]
100%|██████████| 216/216 [04:12<00:00,  1.17s/it]
100%|██████████| 216/216 [04:23<00:00,  1.22s/it]
100%|██████████| 216/216 [04:12<00:00,  1.17s/it]
100%|██████████| 216/216 [03:50<00:00,  1.07s/it]
100%|██████████| 216/216 [04:12<00:00,  1.17s/it]


True, voxconverse, clova, best_value=14.258528906015833


100%|██████████| 215/215 [05:12<00:00,  1.45s/it]
100%|██████████| 215/215 [03:30<00:00,  1.02it/s]
100%|██████████| 215/215 [04:36<00:00,  1.29s/it]
100%|██████████| 215/215 [05:22<00:00,  1.50s/it]
100%|██████████| 215/215 [05:37<00:00,  1.57s/it]
100%|██████████| 215/215 [04:57<00:00,  1.38s/it]
100%|██████████| 215/215 [04:12<00:00,  1.17s/it]
100%|██████████| 215/215 [04:53<00:00,  1.36s/it]
100%|██████████| 215/215 [04:56<00:00,  1.38s/it]
100%|██████████| 215/215 [04:20<00:00,  1.21s/it]


True, voxconverse, speechbrain, best_value=9.244423711960316


100%|██████████| 21/21 [01:28<00:00,  4.23s/it]
100%|██████████| 21/21 [01:41<00:00,  4.81s/it]
100%|██████████| 21/21 [01:38<00:00,  4.71s/it]
100%|██████████| 21/21 [01:37<00:00,  4.63s/it]
100%|██████████| 21/21 [01:38<00:00,  4.67s/it]
100%|██████████| 21/21 [01:35<00:00,  4.53s/it]
100%|██████████| 21/21 [01:40<00:00,  4.78s/it]
100%|██████████| 21/21 [01:35<00:00,  4.55s/it]
100%|██████████| 21/21 [01:40<00:00,  4.77s/it]
100%|██████████| 21/21 [01:44<00:00,  4.96s/it]


True, aishell, brno, best_value=14.565468776088467


100%|██████████| 21/21 [01:29<00:00,  4.26s/it]
100%|██████████| 21/21 [01:38<00:00,  4.71s/it]
100%|██████████| 21/21 [01:35<00:00,  4.55s/it]
100%|██████████| 21/21 [01:44<00:00,  4.97s/it]
100%|██████████| 21/21 [01:34<00:00,  4.49s/it]
100%|██████████| 21/21 [01:36<00:00,  4.59s/it]
100%|██████████| 21/21 [01:28<00:00,  4.19s/it]
100%|██████████| 21/21 [01:32<00:00,  4.38s/it]
100%|██████████| 21/21 [01:27<00:00,  4.17s/it]
100%|██████████| 21/21 [01:41<00:00,  4.85s/it]


True, aishell, clova, best_value=20.21210057225468


100%|██████████| 21/21 [02:08<00:00,  6.12s/it]
100%|██████████| 21/21 [02:06<00:00,  6.01s/it]
100%|██████████| 21/21 [01:57<00:00,  5.61s/it]
100%|██████████| 21/21 [01:39<00:00,  4.71s/it]
100%|██████████| 21/21 [02:25<00:00,  6.95s/it]
100%|██████████| 21/21 [01:45<00:00,  5.04s/it]
100%|██████████| 21/21 [02:34<00:00,  7.35s/it]
100%|██████████| 21/21 [02:41<00:00,  7.68s/it]
100%|██████████| 21/21 [01:40<00:00,  4.79s/it]
100%|██████████| 21/21 [01:35<00:00,  4.56s/it]


True, aishell, speechbrain, best_value=16.04158464983188


100%|██████████| 18/18 [01:37<00:00,  5.39s/it]
100%|██████████| 18/18 [01:22<00:00,  4.61s/it]
100%|██████████| 18/18 [01:35<00:00,  5.31s/it]
100%|██████████| 18/18 [01:51<00:00,  6.19s/it]
100%|██████████| 18/18 [01:36<00:00,  5.38s/it]
100%|██████████| 18/18 [01:46<00:00,  5.91s/it]
100%|██████████| 18/18 [01:34<00:00,  5.23s/it]
100%|██████████| 18/18 [01:34<00:00,  5.23s/it]
100%|██████████| 18/18 [01:36<00:00,  5.34s/it]
100%|██████████| 18/18 [01:32<00:00,  5.15s/it]


False, ami, brno, best_value=19.273809355943712


100%|██████████| 18/18 [01:18<00:00,  4.34s/it]
100%|██████████| 18/18 [01:17<00:00,  4.30s/it]
100%|██████████| 18/18 [01:38<00:00,  5.45s/it]
100%|██████████| 18/18 [01:16<00:00,  4.23s/it]
100%|██████████| 18/18 [01:18<00:00,  4.38s/it]
100%|██████████| 18/18 [01:17<00:00,  4.30s/it]
100%|██████████| 18/18 [01:14<00:00,  4.12s/it]
100%|██████████| 18/18 [01:14<00:00,  4.15s/it]
100%|██████████| 18/18 [01:14<00:00,  4.16s/it]
100%|██████████| 18/18 [01:18<00:00,  4.35s/it]


False, ami, clova, best_value=19.573061652976012


100%|██████████| 18/18 [01:31<00:00,  5.08s/it]
100%|██████████| 18/18 [01:34<00:00,  5.23s/it]
100%|██████████| 18/18 [01:43<00:00,  5.77s/it]
100%|██████████| 18/18 [02:04<00:00,  6.92s/it]
100%|██████████| 18/18 [01:36<00:00,  5.39s/it]
100%|██████████| 18/18 [02:03<00:00,  6.87s/it]
100%|██████████| 18/18 [01:50<00:00,  6.14s/it]
100%|██████████| 18/18 [02:09<00:00,  7.17s/it]
100%|██████████| 18/18 [01:51<00:00,  6.17s/it]
100%|██████████| 18/18 [01:29<00:00,  4.95s/it]


False, ami, speechbrain, best_value=19.455946395687647


100%|██████████| 216/216 [04:28<00:00,  1.24s/it]
100%|██████████| 216/216 [03:44<00:00,  1.04s/it]
100%|██████████| 216/216 [04:36<00:00,  1.28s/it]
100%|██████████| 216/216 [04:32<00:00,  1.26s/it]
100%|██████████| 216/216 [04:54<00:00,  1.36s/it]
100%|██████████| 216/216 [03:30<00:00,  1.03it/s]
100%|██████████| 216/216 [04:13<00:00,  1.18s/it]
100%|██████████| 216/216 [03:45<00:00,  1.04s/it]
100%|██████████| 216/216 [04:53<00:00,  1.36s/it]
100%|██████████| 216/216 [03:46<00:00,  1.05s/it]


False, voxconverse, brno, best_value=15.913122698043855


100%|██████████| 216/216 [03:25<00:00,  1.05it/s]
100%|██████████| 216/216 [04:09<00:00,  1.15s/it]
100%|██████████| 216/216 [04:21<00:00,  1.21s/it]
100%|██████████| 216/216 [03:22<00:00,  1.06it/s]
100%|██████████| 216/216 [04:35<00:00,  1.28s/it]
100%|██████████| 216/216 [04:45<00:00,  1.32s/it]
100%|██████████| 216/216 [03:24<00:00,  1.05it/s]
100%|██████████| 216/216 [04:41<00:00,  1.30s/it]
100%|██████████| 216/216 [04:52<00:00,  1.35s/it]
100%|██████████| 216/216 [03:58<00:00,  1.11s/it]


False, voxconverse, clova, best_value=21.92072986253154


100%|██████████| 215/215 [04:27<00:00,  1.24s/it]
100%|██████████| 215/215 [04:21<00:00,  1.22s/it]
100%|██████████| 215/215 [05:03<00:00,  1.41s/it]
100%|██████████| 215/215 [03:53<00:00,  1.09s/it]
100%|██████████| 215/215 [04:15<00:00,  1.19s/it]
100%|██████████| 215/215 [04:13<00:00,  1.18s/it]
100%|██████████| 215/215 [05:13<00:00,  1.46s/it]
100%|██████████| 215/215 [04:30<00:00,  1.26s/it]
100%|██████████| 215/215 [04:57<00:00,  1.38s/it]
100%|██████████| 215/215 [04:31<00:00,  1.26s/it]


False, voxconverse, speechbrain, best_value=14.253370899577375


100%|██████████| 21/21 [01:47<00:00,  5.13s/it]
100%|██████████| 21/21 [01:36<00:00,  4.59s/it]
100%|██████████| 21/21 [01:39<00:00,  4.72s/it]
100%|██████████| 21/21 [01:38<00:00,  4.70s/it]
100%|██████████| 21/21 [01:35<00:00,  4.55s/it]
100%|██████████| 21/21 [01:36<00:00,  4.58s/it]
100%|██████████| 21/21 [01:36<00:00,  4.58s/it]
100%|██████████| 21/21 [01:36<00:00,  4.61s/it]
100%|██████████| 21/21 [01:36<00:00,  4.60s/it]
100%|██████████| 21/21 [01:28<00:00,  4.23s/it]


False, aishell, brno, best_value=22.693405535316806


100%|██████████| 21/21 [01:37<00:00,  4.63s/it]
100%|██████████| 21/21 [01:25<00:00,  4.08s/it]
100%|██████████| 21/21 [01:29<00:00,  4.26s/it]
100%|██████████| 21/21 [01:31<00:00,  4.37s/it]
100%|██████████| 21/21 [01:33<00:00,  4.45s/it]
100%|██████████| 21/21 [01:42<00:00,  4.86s/it]
100%|██████████| 21/21 [01:13<00:00,  3.52s/it]
100%|██████████| 21/21 [01:37<00:00,  4.62s/it]
100%|██████████| 21/21 [01:26<00:00,  4.14s/it]
100%|██████████| 21/21 [01:35<00:00,  4.54s/it]


False, aishell, clova, best_value=27.987202158050202


100%|██████████| 21/21 [01:32<00:00,  4.42s/it]
100%|██████████| 21/21 [01:52<00:00,  5.36s/it]
100%|██████████| 21/21 [01:55<00:00,  5.50s/it]
100%|██████████| 21/21 [01:52<00:00,  5.38s/it]
100%|██████████| 21/21 [01:35<00:00,  4.57s/it]
100%|██████████| 21/21 [01:50<00:00,  5.25s/it]
100%|██████████| 21/21 [01:50<00:00,  5.25s/it]
100%|██████████| 21/21 [02:05<00:00,  5.97s/it]
100%|██████████| 21/21 [01:34<00:00,  4.49s/it]
100%|██████████| 21/21 [02:20<00:00,  6.71s/it]


False, aishell, speechbrain, best_value=24.29056633218791


100%|██████████| 18/18 [01:39<00:00,  5.54s/it]
100%|██████████| 18/18 [01:38<00:00,  5.46s/it]
100%|██████████| 18/18 [01:37<00:00,  5.41s/it]
100%|██████████| 18/18 [01:33<00:00,  5.19s/it]
100%|██████████| 18/18 [01:31<00:00,  5.10s/it]
100%|██████████| 18/18 [01:42<00:00,  5.68s/it]
100%|██████████| 18/18 [01:19<00:00,  4.43s/it]
100%|██████████| 18/18 [01:23<00:00,  4.66s/it]
100%|██████████| 18/18 [01:17<00:00,  4.28s/it]
100%|██████████| 18/18 [01:47<00:00,  5.95s/it]


True, ami, brno, best_value=13.818480513488199


100%|██████████| 18/18 [01:16<00:00,  4.22s/it]
100%|██████████| 18/18 [01:12<00:00,  4.05s/it]
100%|██████████| 18/18 [01:15<00:00,  4.18s/it]
100%|██████████| 18/18 [01:15<00:00,  4.21s/it]
100%|██████████| 18/18 [01:14<00:00,  4.13s/it]
100%|██████████| 18/18 [01:21<00:00,  4.53s/it]
100%|██████████| 18/18 [01:09<00:00,  3.88s/it]
100%|██████████| 18/18 [01:27<00:00,  4.86s/it]
100%|██████████| 18/18 [01:19<00:00,  4.41s/it]
100%|██████████| 18/18 [01:48<00:00,  6.02s/it]


True, ami, clova, best_value=15.961791339390896


100%|██████████| 18/18 [01:40<00:00,  5.57s/it]
100%|██████████| 18/18 [01:35<00:00,  5.29s/it]
100%|██████████| 18/18 [01:35<00:00,  5.33s/it]
100%|██████████| 18/18 [01:37<00:00,  5.42s/it]
100%|██████████| 18/18 [01:40<00:00,  5.59s/it]
100%|██████████| 18/18 [02:01<00:00,  6.77s/it]
100%|██████████| 18/18 [02:14<00:00,  7.47s/it]
100%|██████████| 18/18 [01:44<00:00,  5.79s/it]
100%|██████████| 18/18 [01:38<00:00,  5.46s/it]
100%|██████████| 18/18 [01:38<00:00,  5.50s/it]


True, ami, speechbrain, best_value=14.22638398369516


100%|██████████| 216/216 [03:43<00:00,  1.04s/it]
100%|██████████| 216/216 [05:00<00:00,  1.39s/it]
100%|██████████| 216/216 [04:50<00:00,  1.34s/it]
100%|██████████| 216/216 [03:55<00:00,  1.09s/it]
100%|██████████| 216/216 [04:34<00:00,  1.27s/it]
100%|██████████| 216/216 [04:36<00:00,  1.28s/it]
100%|██████████| 216/216 [04:13<00:00,  1.17s/it]
100%|██████████| 216/216 [05:00<00:00,  1.39s/it]
100%|██████████| 216/216 [03:47<00:00,  1.05s/it]
100%|██████████| 216/216 [04:57<00:00,  1.38s/it]


True, voxconverse, brno, best_value=52.752594856154744


100%|██████████| 216/216 [04:48<00:00,  1.34s/it]
100%|██████████| 216/216 [04:05<00:00,  1.14s/it]
100%|██████████| 216/216 [04:10<00:00,  1.16s/it]
100%|██████████| 216/216 [04:26<00:00,  1.23s/it]
100%|██████████| 216/216 [04:17<00:00,  1.19s/it]
100%|██████████| 216/216 [05:02<00:00,  1.40s/it]
100%|██████████| 216/216 [05:12<00:00,  1.45s/it]
100%|██████████| 216/216 [03:53<00:00,  1.08s/it]
100%|██████████| 216/216 [04:49<00:00,  1.34s/it]
100%|██████████| 216/216 [04:51<00:00,  1.35s/it]


True, voxconverse, clova, best_value=59.65005218092349


100%|██████████| 215/215 [04:47<00:00,  1.34s/it]
100%|██████████| 215/215 [04:40<00:00,  1.31s/it]
100%|██████████| 215/215 [05:15<00:00,  1.47s/it]
100%|██████████| 215/215 [04:18<00:00,  1.20s/it]
100%|██████████| 215/215 [04:53<00:00,  1.36s/it]
100%|██████████| 215/215 [04:23<00:00,  1.23s/it]
100%|██████████| 215/215 [05:15<00:00,  1.47s/it]
100%|██████████| 215/215 [04:59<00:00,  1.39s/it]
100%|██████████| 215/215 [04:58<00:00,  1.39s/it]
100%|██████████| 215/215 [04:40<00:00,  1.31s/it]


True, voxconverse, speechbrain, best_value=49.17715319879631


100%|██████████| 21/21 [01:46<00:00,  5.07s/it]
100%|██████████| 21/21 [01:36<00:00,  4.61s/it]
100%|██████████| 21/21 [01:39<00:00,  4.76s/it]
100%|██████████| 21/21 [01:37<00:00,  4.64s/it]
100%|██████████| 21/21 [01:37<00:00,  4.64s/it]
100%|██████████| 21/21 [01:26<00:00,  4.11s/it]
100%|██████████| 21/21 [01:36<00:00,  4.60s/it]
100%|██████████| 21/21 [01:41<00:00,  4.82s/it]
100%|██████████| 21/21 [01:37<00:00,  4.65s/it]
100%|██████████| 21/21 [01:20<00:00,  3.83s/it]


True, aishell, brno, best_value=32.748309693237914


100%|██████████| 21/21 [01:38<00:00,  4.68s/it]
100%|██████████| 21/21 [01:36<00:00,  4.59s/it]
100%|██████████| 21/21 [01:31<00:00,  4.38s/it]
100%|██████████| 21/21 [01:35<00:00,  4.56s/it]
100%|██████████| 21/21 [01:42<00:00,  4.87s/it]
100%|██████████| 21/21 [01:28<00:00,  4.23s/it]
100%|██████████| 21/21 [01:32<00:00,  4.41s/it]
100%|██████████| 21/21 [01:37<00:00,  4.65s/it]
100%|██████████| 21/21 [01:34<00:00,  4.50s/it]
100%|██████████| 21/21 [01:28<00:00,  4.22s/it]


True, aishell, clova, best_value=46.4922354369577


100%|██████████| 21/21 [01:33<00:00,  4.44s/it]
100%|██████████| 21/21 [02:02<00:00,  5.82s/it]
100%|██████████| 21/21 [01:36<00:00,  4.62s/it]
100%|██████████| 21/21 [01:41<00:00,  4.84s/it]
100%|██████████| 21/21 [02:03<00:00,  5.89s/it]
100%|██████████| 21/21 [01:45<00:00,  5.00s/it]
100%|██████████| 21/21 [01:39<00:00,  4.76s/it]
100%|██████████| 21/21 [02:03<00:00,  5.89s/it]
100%|██████████| 21/21 [01:47<00:00,  5.10s/it]
100%|██████████| 21/21 [01:51<00:00,  5.29s/it]


True, aishell, speechbrain, best_value=39.05840589014548


100%|██████████| 18/18 [01:31<00:00,  5.09s/it]
100%|██████████| 18/18 [01:47<00:00,  5.97s/it]
100%|██████████| 18/18 [01:42<00:00,  5.68s/it]
100%|██████████| 18/18 [01:44<00:00,  5.82s/it]
100%|██████████| 18/18 [01:30<00:00,  5.02s/it]
100%|██████████| 18/18 [01:43<00:00,  5.74s/it]
100%|██████████| 18/18 [01:25<00:00,  4.76s/it]
100%|██████████| 18/18 [01:40<00:00,  5.59s/it]
100%|██████████| 18/18 [01:35<00:00,  5.31s/it]
100%|██████████| 18/18 [01:36<00:00,  5.34s/it]


False, ami, brno, best_value=24.61082290196697


100%|██████████| 18/18 [01:22<00:00,  4.61s/it]
100%|██████████| 18/18 [01:26<00:00,  4.82s/it]
100%|██████████| 18/18 [01:15<00:00,  4.17s/it]
100%|██████████| 18/18 [01:08<00:00,  3.80s/it]
100%|██████████| 18/18 [01:10<00:00,  3.93s/it]
100%|██████████| 18/18 [01:18<00:00,  4.36s/it]
100%|██████████| 18/18 [01:19<00:00,  4.42s/it]
100%|██████████| 18/18 [01:19<00:00,  4.43s/it]
100%|██████████| 18/18 [01:15<00:00,  4.18s/it]
100%|██████████| 18/18 [01:13<00:00,  4.08s/it]


False, ami, clova, best_value=33.11171001261658


100%|██████████| 18/18 [01:46<00:00,  5.92s/it]
100%|██████████| 18/18 [01:44<00:00,  5.83s/it]
100%|██████████| 18/18 [01:31<00:00,  5.11s/it]
100%|██████████| 18/18 [01:22<00:00,  4.57s/it]
100%|██████████| 18/18 [01:24<00:00,  4.67s/it]
100%|██████████| 18/18 [02:11<00:00,  7.33s/it]
100%|██████████| 18/18 [01:30<00:00,  5.04s/it]
100%|██████████| 18/18 [01:38<00:00,  5.47s/it]
100%|██████████| 18/18 [02:08<00:00,  7.16s/it]
100%|██████████| 18/18 [02:05<00:00,  6.94s/it]


False, ami, speechbrain, best_value=23.401463724039672


100%|██████████| 216/216 [05:10<00:00,  1.44s/it]
100%|██████████| 216/216 [04:31<00:00,  1.26s/it]
100%|██████████| 216/216 [04:44<00:00,  1.32s/it]
100%|██████████| 216/216 [04:00<00:00,  1.11s/it]
100%|██████████| 216/216 [04:40<00:00,  1.30s/it]
100%|██████████| 216/216 [05:00<00:00,  1.39s/it]
100%|██████████| 216/216 [04:55<00:00,  1.37s/it]
100%|██████████| 216/216 [04:44<00:00,  1.32s/it]
100%|██████████| 216/216 [03:59<00:00,  1.11s/it]
100%|██████████| 216/216 [05:15<00:00,  1.46s/it]


False, voxconverse, brno, best_value=51.825062868715065


100%|██████████| 216/216 [04:50<00:00,  1.35s/it]
100%|██████████| 216/216 [04:34<00:00,  1.27s/it]
100%|██████████| 216/216 [03:57<00:00,  1.10s/it]
100%|██████████| 216/216 [04:42<00:00,  1.31s/it]
100%|██████████| 216/216 [03:54<00:00,  1.08s/it]
100%|██████████| 216/216 [04:03<00:00,  1.13s/it]
100%|██████████| 216/216 [04:37<00:00,  1.28s/it]
100%|██████████| 216/216 [05:07<00:00,  1.43s/it]
100%|██████████| 216/216 [04:49<00:00,  1.34s/it]
100%|██████████| 216/216 [04:49<00:00,  1.34s/it]


False, voxconverse, clova, best_value=61.64083502139015


100%|██████████| 215/215 [04:40<00:00,  1.31s/it]
100%|██████████| 215/215 [04:17<00:00,  1.20s/it]
100%|██████████| 215/215 [04:00<00:00,  1.12s/it]
100%|██████████| 215/215 [05:07<00:00,  1.43s/it]
100%|██████████| 215/215 [04:55<00:00,  1.37s/it]
100%|██████████| 215/215 [04:52<00:00,  1.36s/it]
100%|██████████| 215/215 [05:06<00:00,  1.43s/it]
100%|██████████| 215/215 [03:48<00:00,  1.06s/it]
100%|██████████| 215/215 [05:29<00:00,  1.53s/it]
100%|██████████| 215/215 [04:57<00:00,  1.38s/it]


False, voxconverse, speechbrain, best_value=49.754395141455504


100%|██████████| 21/21 [01:39<00:00,  4.74s/it]
100%|██████████| 21/21 [01:37<00:00,  4.62s/it]
100%|██████████| 21/21 [01:43<00:00,  4.95s/it]
100%|██████████| 21/21 [01:38<00:00,  4.69s/it]
100%|██████████| 21/21 [01:29<00:00,  4.24s/it]
100%|██████████| 21/21 [01:42<00:00,  4.86s/it]
100%|██████████| 21/21 [01:41<00:00,  4.84s/it]
100%|██████████| 21/21 [01:37<00:00,  4.66s/it]
100%|██████████| 21/21 [01:38<00:00,  4.68s/it]
100%|██████████| 21/21 [01:37<00:00,  4.64s/it]


False, aishell, brno, best_value=42.954312343929104


100%|██████████| 21/21 [01:35<00:00,  4.56s/it]
100%|██████████| 21/21 [01:24<00:00,  4.03s/it]
100%|██████████| 21/21 [01:22<00:00,  3.92s/it]
100%|██████████| 21/21 [01:24<00:00,  4.01s/it]
100%|██████████| 21/21 [01:29<00:00,  4.25s/it]
100%|██████████| 21/21 [01:35<00:00,  4.53s/it]
100%|██████████| 21/21 [01:40<00:00,  4.76s/it]
100%|██████████| 21/21 [01:21<00:00,  3.88s/it]
100%|██████████| 21/21 [01:23<00:00,  4.00s/it]
100%|██████████| 21/21 [01:32<00:00,  4.39s/it]


False, aishell, clova, best_value=61.08097195007674


100%|██████████| 21/21 [01:44<00:00,  4.97s/it]
100%|██████████| 21/21 [01:30<00:00,  4.33s/it]
100%|██████████| 21/21 [01:40<00:00,  4.78s/it]
100%|██████████| 21/21 [01:56<00:00,  5.56s/it]
100%|██████████| 21/21 [02:31<00:00,  7.24s/it]
100%|██████████| 21/21 [01:31<00:00,  4.35s/it]
100%|██████████| 21/21 [01:39<00:00,  4.74s/it]
100%|██████████| 21/21 [02:18<00:00,  6.59s/it]
100%|██████████| 21/21 [02:21<00:00,  6.74s/it]
100%|██████████| 21/21 [01:36<00:00,  4.58s/it]


False, aishell, speechbrain, best_value=41.09684183040109


### VBx clustering + Agglomerative clustering initialization

In [None]:
# Tune the VBx hyperparameters (with AHC initialization) on the dev files for
# every metric / overlap setting / dataset / embedder combination.
#
# The search is resumable: a combination is skipped when its result file
# already exists, so an interrupted run can simply be restarted. Delete the
# corresponding file under {ROOT}/best_params/.../ahc_vbx/ to force a
# recomputation (e.g. after the tuned AHC parameters have changed).
for metric in ['der', 'jer']:
    for skip_overlap in [True, False]:
        params_root = f'{ROOT}/best_params/{metric}/skip_overlap={skip_overlap}'
        # Create the output folder before any search starts, so a missing
        # directory cannot make np.save fail after the search has finished.
        os.makedirs(f'{params_root}/ahc_vbx', exist_ok=True)

        for dataset_name in ['ami', 'voxconverse', 'aishell']:
            # Only embedders without a stored result still need to be searched.
            pending_embedders = [
                embedder_name
                for embedder_name in ['brno', 'clova', 'speechbrain']
                if not os.path.exists(f'{params_root}/ahc_vbx/{dataset_name}_{embedder_name}.npy')
            ]
            if not pending_embedders:
                continue

            # The .npy files hold pickled dicts (hence allow_pickle / .item());
            # only load files that come from a trusted source.
            uri2ann_ref = np.load(
                f'{ROOT}/annotations/{dataset_name}/{dataset_name}_dev_uri2ann_ref.npy',
                allow_pickle=True
            ).item()

            for embedder_name in pending_embedders:
                uri2data = np.load(
                    f'{ROOT}/embeddings/skip_overlap={skip_overlap}/{dataset_name}/{dataset_name}_dev_uri2data_{embedder_name}.npy',
                    allow_pickle=True
                ).item()
                # AHC parameters stored earlier for the same combination.
                ahc_params = np.load(
                    f'{params_root}/ahc/{dataset_name}_{embedder_name}.npy',
                    allow_pickle=True
                ).item()

                best_params, best_value = get_ahc_vbx_params(
                    embedder_name, uri2data, uri2ann_ref, ahc_params,
                    skip_overlap=skip_overlap, metric=metric
                )
                # Keep the achieved score next to the parameters that produced it.
                best_params['value'] = best_value
                print(f'{metric}, {skip_overlap}, {dataset_name}, {embedder_name}, best_value={best_value}')
                np.save(f'{params_root}/ahc_vbx/{dataset_name}_{embedder_name}.npy', best_params)


100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 21/21 [00:22<00:00,  1.05s/it]
100%|██████████| 21/21 [00:24<00:00,  1.18s/it]
100%|██████████| 21/21 [00:22<00:00,  1.06s/it]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 21/21 [00:21<00:00,  1.02s/it]
100%|██████████| 21/21 [00:22<00:00,  1.05s/it]
100%|██████████| 21/21 [00:23<00:00,  1.13s/it]
100%|██████████| 21/21 [00:22<00:00,  1.05s/it]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 21/21 [00:21<00:00,  1.04s/it]
100%|██████████| 21/21 [00:21<00:00,  1.03s/it]
100%|██████████| 21/21 [00:26<00:00,  1.25s/it]
100%|██████████| 21/21 [00:22<00:00,  1.09s/it]
100%|██████████| 21/21 [00:22<00:00,  1.05s/it]
100%|██████████| 21/21 [00:38<00:00,  1.81s/it]
100%|██████████| 21/21 [00:45<00:00,  2.18s/it]
100%|██████████| 21/21 [00:32<00:00,  1.

jer, True, aishell, brno, best_value=22.062033282386952


100%|██████████| 21/21 [00:27<00:00,  1.30s/it]
100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
100%|██████████| 21/21 [00:25<00:00,  1.24s/it]
100%|██████████| 21/21 [00:25<00:00,  1.21s/it]
100%|██████████| 21/21 [00:25<00:00,  1.19s/it]
100%|██████████| 21/21 [00:25<00:00,  1.19s/it]
100%|██████████| 21/21 [00:26<00:00,  1.28s/it]
100%|██████████| 21/21 [00:26<00:00,  1.25s/it]
100%|██████████| 21/21 [00:25<00:00,  1.20s/it]
100%|██████████| 21/21 [00:25<00:00,  1.20s/it]
100%|██████████| 21/21 [00:25<00:00,  1.20s/it]
100%|██████████| 21/21 [00:24<00:00,  1.18s/it]
100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
100%|██████████| 21/21 [00:25<00:00,  1.21s/it]
100%|██████████| 21/21 [00:25<00:00,  1.22s/it]
100%|██████████| 21/21 [00:24<00:00,  1.19s/it]
100%|██████████| 21/21 [00:24<00:00,  1.18s/it]
100%|██████████| 21/21 [00:24<00:00,  1.19s/it]
100%|██████████| 21/21 [00:25<00:00,  1.24s/it]
100%|██████████| 21/21 [00:25<00:00,  1.20s/it]
100%|██████████| 21/21 [00:25<00:00,  1.

jer, True, aishell, clova, best_value=20.83794965121084


100%|██████████| 21/21 [00:33<00:00,  1.59s/it]
100%|██████████| 21/21 [00:28<00:00,  1.35s/it]
100%|██████████| 21/21 [00:28<00:00,  1.37s/it]
100%|██████████| 21/21 [00:28<00:00,  1.34s/it]
100%|██████████| 21/21 [00:28<00:00,  1.34s/it]
100%|██████████| 21/21 [00:28<00:00,  1.33s/it]
100%|██████████| 21/21 [00:33<00:00,  1.59s/it]
100%|██████████| 21/21 [00:28<00:00,  1.35s/it]
100%|██████████| 21/21 [00:28<00:00,  1.36s/it]
100%|██████████| 21/21 [00:28<00:00,  1.33s/it]
100%|██████████| 21/21 [00:27<00:00,  1.32s/it]
100%|██████████| 21/21 [00:47<00:00,  2.26s/it]
100%|██████████| 21/21 [00:38<00:00,  1.83s/it]
100%|██████████| 21/21 [00:31<00:00,  1.51s/it]
100%|██████████| 21/21 [00:30<00:00,  1.44s/it]
100%|██████████| 21/21 [00:38<00:00,  1.82s/it]
100%|██████████| 21/21 [00:38<00:00,  1.83s/it]
100%|██████████| 21/21 [00:47<00:00,  2.26s/it]
100%|██████████| 21/21 [00:34<00:00,  1.63s/it]
100%|██████████| 21/21 [00:39<00:00,  1.86s/it]
100%|██████████| 21/21 [00:55<00:00,  2.

jer, True, aishell, speechbrain, best_value=17.016627904819952


100%|██████████| 18/18 [00:22<00:00,  1.22s/it]
100%|██████████| 18/18 [00:20<00:00,  1.16s/it]
100%|██████████| 18/18 [00:21<00:00,  1.17s/it]
100%|██████████| 18/18 [00:20<00:00,  1.15s/it]
100%|██████████| 18/18 [00:20<00:00,  1.12s/it]
100%|██████████| 18/18 [00:19<00:00,  1.06s/it]
100%|██████████| 18/18 [00:20<00:00,  1.13s/it]
100%|██████████| 18/18 [00:20<00:00,  1.15s/it]
100%|██████████| 18/18 [00:19<00:00,  1.07s/it]
100%|██████████| 18/18 [00:19<00:00,  1.09s/it]
100%|██████████| 18/18 [00:19<00:00,  1.06s/it]
100%|██████████| 18/18 [00:19<00:00,  1.06s/it]
100%|██████████| 18/18 [00:19<00:00,  1.07s/it]
100%|██████████| 18/18 [00:19<00:00,  1.08s/it]
100%|██████████| 18/18 [00:20<00:00,  1.11s/it]
100%|██████████| 18/18 [00:19<00:00,  1.10s/it]
100%|██████████| 18/18 [00:19<00:00,  1.07s/it]
100%|██████████| 18/18 [00:19<00:00,  1.06s/it]
100%|██████████| 18/18 [00:19<00:00,  1.07s/it]
100%|██████████| 18/18 [00:18<00:00,  1.04s/it]
100%|██████████| 18/18 [00:18<00:00,  1.

jer, False, ami, brno, best_value=20.598105630872755


100%|██████████| 18/18 [00:24<00:00,  1.34s/it]
100%|██████████| 18/18 [00:22<00:00,  1.27s/it]
100%|██████████| 18/18 [00:22<00:00,  1.25s/it]
100%|██████████| 18/18 [00:22<00:00,  1.26s/it]
100%|██████████| 18/18 [00:22<00:00,  1.26s/it]
100%|██████████| 18/18 [00:22<00:00,  1.25s/it]
100%|██████████| 18/18 [00:22<00:00,  1.27s/it]
100%|██████████| 18/18 [00:22<00:00,  1.22s/it]
100%|██████████| 18/18 [00:21<00:00,  1.22s/it]
100%|██████████| 18/18 [00:21<00:00,  1.21s/it]
100%|██████████| 18/18 [00:21<00:00,  1.20s/it]
100%|██████████| 18/18 [00:21<00:00,  1.19s/it]
100%|██████████| 18/18 [00:21<00:00,  1.21s/it]
100%|██████████| 18/18 [00:22<00:00,  1.23s/it]
100%|██████████| 18/18 [00:21<00:00,  1.19s/it]
100%|██████████| 18/18 [00:21<00:00,  1.20s/it]
100%|██████████| 18/18 [00:21<00:00,  1.20s/it]
100%|██████████| 18/18 [00:21<00:00,  1.19s/it]
100%|██████████| 18/18 [00:21<00:00,  1.21s/it]
100%|██████████| 18/18 [00:21<00:00,  1.19s/it]
100%|██████████| 18/18 [00:21<00:00,  1.

jer, False, ami, clova, best_value=23.116644040399052


100%|██████████| 18/18 [00:32<00:00,  1.80s/it]
100%|██████████| 18/18 [00:30<00:00,  1.71s/it]
100%|██████████| 18/18 [00:30<00:00,  1.70s/it]
100%|██████████| 18/18 [00:30<00:00,  1.69s/it]
100%|██████████| 18/18 [00:30<00:00,  1.72s/it]
100%|██████████| 18/18 [00:30<00:00,  1.67s/it]
100%|██████████| 18/18 [00:31<00:00,  1.74s/it]
100%|██████████| 18/18 [00:29<00:00,  1.61s/it]
100%|██████████| 18/18 [00:29<00:00,  1.64s/it]
100%|██████████| 18/18 [00:30<00:00,  1.69s/it]
100%|██████████| 18/18 [00:29<00:00,  1.65s/it]
100%|██████████| 18/18 [00:29<00:00,  1.66s/it]
100%|██████████| 18/18 [00:31<00:00,  1.76s/it]
100%|██████████| 18/18 [00:29<00:00,  1.64s/it]
100%|██████████| 18/18 [00:29<00:00,  1.62s/it]
100%|██████████| 18/18 [00:29<00:00,  1.62s/it]
100%|██████████| 18/18 [00:29<00:00,  1.62s/it]
100%|██████████| 18/18 [00:29<00:00,  1.63s/it]
100%|██████████| 18/18 [00:31<00:00,  1.74s/it]
100%|██████████| 18/18 [01:15<00:00,  4.17s/it]
100%|██████████| 18/18 [00:29<00:00,  1.

jer, False, ami, speechbrain, best_value=20.586756721427545


100%|██████████| 216/216 [00:44<00:00,  4.88it/s]
100%|██████████| 216/216 [00:44<00:00,  4.85it/s]
100%|██████████| 216/216 [00:44<00:00,  4.89it/s]
100%|██████████| 216/216 [00:45<00:00,  4.79it/s]
100%|██████████| 216/216 [00:45<00:00,  4.77it/s]
100%|██████████| 216/216 [00:47<00:00,  4.59it/s]
100%|██████████| 216/216 [00:45<00:00,  4.74it/s]
100%|██████████| 216/216 [00:43<00:00,  4.97it/s]
100%|██████████| 216/216 [00:44<00:00,  4.90it/s]
100%|██████████| 216/216 [00:44<00:00,  4.85it/s]
100%|██████████| 216/216 [00:45<00:00,  4.71it/s]
100%|██████████| 216/216 [00:45<00:00,  4.70it/s]
100%|██████████| 216/216 [00:43<00:00,  4.95it/s]
100%|██████████| 216/216 [00:48<00:00,  4.43it/s]
100%|██████████| 216/216 [00:45<00:00,  4.77it/s]
100%|██████████| 216/216 [00:45<00:00,  4.80it/s]
100%|██████████| 216/216 [00:47<00:00,  4.59it/s]
100%|██████████| 216/216 [00:45<00:00,  4.73it/s]
100%|██████████| 216/216 [00:45<00:00,  4.78it/s]
100%|██████████| 216/216 [00:46<00:00,  4.64it/s]


jer, False, voxconverse, brno, best_value=12.81348744348178


100%|██████████| 216/216 [00:46<00:00,  4.67it/s]
100%|██████████| 216/216 [00:44<00:00,  4.87it/s]
100%|██████████| 216/216 [00:45<00:00,  4.70it/s]
100%|██████████| 216/216 [00:46<00:00,  4.67it/s]
100%|██████████| 216/216 [00:45<00:00,  4.78it/s]
100%|██████████| 216/216 [00:46<00:00,  4.69it/s]
100%|██████████| 216/216 [00:44<00:00,  4.91it/s]
100%|██████████| 216/216 [00:43<00:00,  4.99it/s]
100%|██████████| 216/216 [00:44<00:00,  4.87it/s]
100%|██████████| 216/216 [00:43<00:00,  4.94it/s]
100%|██████████| 216/216 [00:45<00:00,  4.79it/s]
100%|██████████| 216/216 [00:45<00:00,  4.76it/s]
100%|██████████| 216/216 [00:45<00:00,  4.78it/s]
100%|██████████| 216/216 [00:46<00:00,  4.65it/s]
100%|██████████| 216/216 [00:45<00:00,  4.75it/s]
100%|██████████| 216/216 [00:44<00:00,  4.81it/s]
100%|██████████| 216/216 [00:44<00:00,  4.89it/s]
100%|██████████| 216/216 [00:43<00:00,  4.95it/s]
100%|██████████| 216/216 [00:43<00:00,  4.94it/s]
100%|██████████| 216/216 [00:44<00:00,  4.86it/s]


jer, False, voxconverse, clova, best_value=18.397437415138263


100%|██████████| 215/215 [00:43<00:00,  4.90it/s]
100%|██████████| 215/215 [00:44<00:00,  4.81it/s]
100%|██████████| 215/215 [00:45<00:00,  4.76it/s]
100%|██████████| 215/215 [00:44<00:00,  4.86it/s]
100%|██████████| 215/215 [00:44<00:00,  4.81it/s]
100%|██████████| 215/215 [00:42<00:00,  5.02it/s]
100%|██████████| 215/215 [00:44<00:00,  4.89it/s]
100%|██████████| 215/215 [00:44<00:00,  4.83it/s]
100%|██████████| 215/215 [00:48<00:00,  4.40it/s]
100%|██████████| 215/215 [00:44<00:00,  4.80it/s]
100%|██████████| 215/215 [00:44<00:00,  4.88it/s]
100%|██████████| 215/215 [00:44<00:00,  4.87it/s]
100%|██████████| 215/215 [00:44<00:00,  4.85it/s]
100%|██████████| 215/215 [00:52<00:00,  4.08it/s]
100%|██████████| 215/215 [00:50<00:00,  4.29it/s]
100%|██████████| 215/215 [00:48<00:00,  4.47it/s]
100%|██████████| 215/215 [00:47<00:00,  4.55it/s]
100%|██████████| 215/215 [00:49<00:00,  4.37it/s]
100%|██████████| 215/215 [00:44<00:00,  4.85it/s]
100%|██████████| 215/215 [00:51<00:00,  4.21it/s]


jer, False, voxconverse, speechbrain, best_value=12.050475208550104


100%|██████████| 21/21 [00:27<00:00,  1.30s/it]
100%|██████████| 21/21 [00:24<00:00,  1.18s/it]
100%|██████████| 21/21 [00:24<00:00,  1.18s/it]
100%|██████████| 21/21 [00:24<00:00,  1.15s/it]
100%|██████████| 21/21 [00:24<00:00,  1.15s/it]
100%|██████████| 21/21 [00:24<00:00,  1.16s/it]
100%|██████████| 21/21 [00:27<00:00,  1.31s/it]
100%|██████████| 21/21 [00:24<00:00,  1.17s/it]
100%|██████████| 21/21 [00:23<00:00,  1.14s/it]
100%|██████████| 21/21 [00:38<00:00,  1.85s/it]
100%|██████████| 21/21 [00:40<00:00,  1.92s/it]
100%|██████████| 21/21 [00:38<00:00,  1.81s/it]
100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
100%|██████████| 21/21 [00:24<00:00,  1.16s/it]
100%|██████████| 21/21 [00:38<00:00,  1.83s/it]
100%|██████████| 21/21 [00:37<00:00,  1.78s/it]
100%|██████████| 21/21 [00:35<00:00,  1.70s/it]
100%|██████████| 21/21 [01:01<00:00,  2.91s/it]
100%|██████████| 21/21 [00:26<00:00,  1.25s/it]
100%|██████████| 21/21 [00:38<00:00,  1.83s/it]
100%|██████████| 21/21 [00:38<00:00,  1.

jer, False, aishell, brno, best_value=26.160833512779426


100%|██████████| 21/21 [00:39<00:00,  1.86s/it]
100%|██████████| 21/21 [00:37<00:00,  1.78s/it]
100%|██████████| 21/21 [00:37<00:00,  1.78s/it]
100%|██████████| 21/21 [00:37<00:00,  1.80s/it]
100%|██████████| 21/21 [00:37<00:00,  1.79s/it]
100%|██████████| 21/21 [00:36<00:00,  1.75s/it]
100%|██████████| 21/21 [00:38<00:00,  1.82s/it]
100%|██████████| 21/21 [00:37<00:00,  1.78s/it]
100%|██████████| 21/21 [00:36<00:00,  1.76s/it]
100%|██████████| 21/21 [00:35<00:00,  1.71s/it]
100%|██████████| 21/21 [00:35<00:00,  1.69s/it]
100%|██████████| 21/21 [00:52<00:00,  2.50s/it]
100%|██████████| 21/21 [00:38<00:00,  1.82s/it]
100%|██████████| 21/21 [00:37<00:00,  1.76s/it]
100%|██████████| 21/21 [00:36<00:00,  1.75s/it]
100%|██████████| 21/21 [00:44<00:00,  2.11s/it]
100%|██████████| 21/21 [00:44<00:00,  2.11s/it]
100%|██████████| 21/21 [00:52<00:00,  2.51s/it]
100%|██████████| 21/21 [00:38<00:00,  1.82s/it]
100%|██████████| 21/21 [00:36<00:00,  1.74s/it]
100%|██████████| 21/21 [00:45<00:00,  2.

jer, False, aishell, clova, best_value=24.904783202152558


100%|██████████| 21/21 [00:31<00:00,  1.50s/it]
100%|██████████| 21/21 [00:27<00:00,  1.29s/it]
100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
100%|██████████| 21/21 [00:26<00:00,  1.25s/it]
100%|██████████| 21/21 [00:26<00:00,  1.25s/it]
100%|██████████| 21/21 [00:30<00:00,  1.46s/it]
100%|██████████| 21/21 [00:27<00:00,  1.30s/it]
100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
100%|██████████| 21/21 [00:26<00:00,  1.26s/it]
100%|██████████| 21/21 [00:26<00:00,  1.24s/it]
100%|██████████| 21/21 [00:26<00:00,  1.24s/it]
100%|██████████| 21/21 [00:30<00:00,  1.45s/it]
100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
100%|██████████| 21/21 [00:26<00:00,  1.28s/it]
100%|██████████| 21/21 [00:25<00:00,  1.23s/it]
100%|██████████| 21/21 [00:32<00:00,  1.56s/it]
100%|██████████| 21/21 [00:39<00:00,  1.87s/it]
100%|██████████| 21/21 [00:30<00:00,  1.46s/it]
100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
100%|██████████| 21/21 [00:26<00:00,  1.

jer, False, aishell, speechbrain, best_value=28.248225771634676


## Test Clustering

In [None]:
def get_value(metric, report):
    """Extract the aggregate ('TOTAL' row) score for ``metric`` from a report.

    Parameters
    ----------
    metric : str
        Either ``'der'`` (diarization error rate) or ``'jer'`` (Jaccard error rate).
    report : object indexable via ``.loc[row, column]``
        Must contain a ``'TOTAL'`` row and the column named after the metric.
        NOTE(review): presumably the DataFrame returned by a pyannote.metrics
        report, whose columns are two-level — confirm against the caller.

    Returns
    -------
    The first value stored under the metric's column in the ``'TOTAL'`` row.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names. Previously the function
        fell through and returned ``None``, which silently polluted result tables.
    """
    metric_columns = {
        'der': 'diarization error rate',
        'jer': 'jaccard error rate',
    }
    if metric not in metric_columns:
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {sorted(metric_columns)}"
        )
    # `.values[0]` picks the first sub-column of the selected metric column.
    return report.loc['TOTAL', metric_columns[metric]].values[0]

In [None]:
import pandas as pd


def _load_best_params(metric, skip_overlap, clst, dataset_name, embedder_name):
    """Load the tuned parameters for one (metric, overlap, clustering, dataset, embedder) setup.

    The stored dict also carries the score reached during tuning under 'value';
    it is dropped here so that only the parameters are forwarded to the clustering call.
    """
    # NOTE: allow_pickle=True executes pickled content — only load files produced by this project.
    params = np.load(
        f'{ROOT}/best_params/{metric}/skip_overlap={str(skip_overlap)}/{clst}/{dataset_name}_{embedder_name}.npy',
        allow_pickle=True
    ).item()
    params.pop('value', None)
    return params


def _save_hypotheses(uri2ann_hyp, metric, skip_overlap, clst, dataset_name, embedder_name):
    """Store the hypothesis annotations of one run, creating the target directory if needed."""
    out_dir = f'{ROOT}/ann_hyp/{metric}/skip_overlap={str(skip_overlap)}/{clst}'
    # np.save does not create parent directories, so make sure they exist first.
    os.makedirs(out_dir, exist_ok=True)
    np.save(f'{out_dir}/{dataset_name}_{embedder_name}.npy', uri2ann_hyp)


# Evaluate every clustering method on the test split of every dataset/embedder pair
# using the previously tuned parameters, and store the resulting hypotheses.
it = 1  # running number of the configuration, printed only for progress tracking
# Only 'der' is processed here; the JER result cells further down read ann_hyp/jer/...,
# which this loop only writes when 'jer' is included: ['der', 'jer']
for metric in ['der']:
    for skip_overlap in [True, False]:
        for dataset_name in ['ami', 'voxconverse', 'aishell']:
            # Reference annotations depend on the dataset only, so load them once per dataset.
            uri2ann_ref = np.load(
                    f'{ROOT}/annotations/{dataset_name}/{dataset_name}_test_uri2ann_ref.npy',
                    allow_pickle=True
                ).item()
            for embedder_name in ['brno', 'clova', 'speechbrain']:
                print(f'{it}: {metric}, {skip_overlap}, {dataset_name}, {embedder_name}')
                it += 1
                uri2data = np.load(
                        f'{ROOT}/embeddings/skip_overlap={str(skip_overlap)}/{dataset_name}/{dataset_name}_test_uri2data_{embedder_name}.npy',
                        allow_pickle=True
                    ).item()
                run_id = (metric, skip_overlap)
                target = (dataset_name, embedder_name)

                # AHC clustering
                ahc_params = _load_best_params(*run_id, 'ahc', *target)
                report, uri2ann_hyp = get_ahc_clustering(embedder_name, uri2data, uri2ann_ref, ahc_params, skip_overlap=skip_overlap, metric=metric, return_hyp=True)
                print("AHC:", get_value(metric, report))
                _save_hypotheses(uri2ann_hyp, *run_id, 'ahc', *target)

                # spectral clustering (no tuned parameters are loaded for it)
                report, uri2ann_hyp = get_spectrul_clustering(embedder_name, uri2data, uri2ann_ref, skip_overlap=skip_overlap, metric=metric, return_hyp=True)
                print("SpectrulClustering:", get_value(metric, report), end=", ")
                _save_hypotheses(uri2ann_hyp, *run_id, 'spectrul', *target)

                # VBx clustering
                vbx_params = _load_best_params(*run_id, 'vbx', *target)
                report, uri2ann_hyp = get_vbx_clustering(embedder_name, uri2data, uri2ann_ref, vbx_params, skip_overlap=skip_overlap, metric=metric, return_hyp=True)
                print("Vbx:", get_value(metric, report), end=", ")
                _save_hypotheses(uri2ann_hyp, *run_id, 'vbx', *target)

                # AHC + VBx clustering: receives the AHC parameters loaded above
                # together with its own tuned parameters.
                ahc_vbx_params = _load_best_params(*run_id, 'ahc_vbx', *target)
                report, uri2ann_hyp = get_ahc_vbx_clustering(embedder_name, uri2data, uri2ann_ref, ahc_params, ahc_vbx_params, skip_overlap=skip_overlap, metric=metric, return_hyp=True)
                print("AHC_Vbx:", get_value(metric, report))
                _save_hypotheses(uri2ann_hyp, *run_id, 'ahc_vbx', *target)

# Performance metrics

In [None]:
# Iteration order shared by the row labels (`indexes`) and by get_results(), so the
# labels cannot drift away from the order in which the scores are appended.
RESULT_DATASETS = ['ami', 'aishell', 'voxconverse']
RESULT_EMBEDDERS = ['brno', 'clova', 'speechbrain']

indexes = [
    f'{dataset_name}_{embedder_name}'
    for dataset_name in RESULT_DATASETS
    for embedder_name in RESULT_EMBEDDERS
]

def get_results(metric, skip_overlap):
    """Re-score the stored hypotheses of every clustering method.

    For each dataset/embedder pair (in the order of ``indexes``) the saved
    hypothesis annotations of each clustering method are loaded from
    ``{ROOT}/ann_hyp/...`` and scored against the dataset's reference annotations
    with ``get_clst_metric_res``.

    Parameters
    ----------
    metric : str
        Metric name, used both in the file path and for scoring ('der' or 'jer'
        in this notebook).
    skip_overlap : bool
        Selects the ``skip_overlap=...`` sub-directory and is forwarded to the scorer.

    Returns
    -------
    dict
        Maps 'ahc', 'vbx', 'ahc_vbx' and 'spectrul' to a list of scores, one per
        entry of ``indexes``.
    """
    # Key order defines the column order of the DataFrames built from this dict.
    d = {clst: [] for clst in ('ahc', 'vbx', 'ahc_vbx', 'spectrul')}
    for dataset_name in RESULT_DATASETS:
        # NOTE: allow_pickle=True executes pickled content — only load trusted files.
        uri2ann_ref = np.load(
            f'{ROOT}/annotations/{dataset_name}/{dataset_name}_test_uri2ann_ref.npy',
            allow_pickle=True
        ).item()
        for embedder_name in RESULT_EMBEDDERS:
            for clst in ['ahc', 'spectrul', 'vbx', 'ahc_vbx']:
                uri2ann_hyp = np.load(
                    f'{ROOT}/ann_hyp/{metric}/skip_overlap={str(skip_overlap)}/{clst}/{dataset_name}_{embedder_name}.npy',
                    allow_pickle=True
                ).item()

                report = get_clst_metric_res(metric, uri2ann_ref, uri2ann_hyp, skip_overlap=skip_overlap)
                d[clst].append(get_value(metric, report))
    return d

### Diarization Error Rate (DER)

In [None]:
# DER table for skip_overlap=True: one row per entry of `indexes`, one column per clustering method.
# NOTE(review): skip_overlap is forwarded to get_clst_metric_res — presumably it excludes
# overlapped speech from scoring; confirm there.
import pandas as pd
d = get_results('der', True)
# Keep the raw score lists on disk so the table can be rebuilt without re-scoring.
np.save(f'{ROOT}/der_True.npy', d)
df = pd.DataFrame(d, index=indexes)
# Trailing expression: rendered as the cell output.
df

Unnamed: 0,ahc,vbx,ahc_vbx,spectrul
ami_brno,5.737887,6.265993,3.235103,3.569658
ami_clova,5.051559,7.725323,3.703079,3.526255
ami_speechbrain,6.550472,5.514422,2.933022,3.376954
aishell_brno,22.434028,8.029049,21.730432,9.375953
aishell_clova,10.374267,9.244718,6.931529,13.335255
aishell_speechbrain,22.966473,6.588633,22.35063,10.622688
voxconverse_brno,4.655823,11.948572,3.411799,14.594538
voxconverse_clova,7.222782,15.360753,6.510049,14.39004
voxconverse_speechbrain,4.874131,11.671924,3.529485,13.84818


In [None]:
# DER table for skip_overlap=False: one row per entry of `indexes`, one column per clustering method.
# NOTE(review): skip_overlap is forwarded to get_clst_metric_res — presumably False keeps
# overlapped speech in the scoring; confirm there.
import pandas as pd
d = get_results('der', False)
# Keep the raw score lists on disk so the table can be rebuilt without re-scoring.
np.save(f'{ROOT}/der_False.npy', d)
df = pd.DataFrame(d, index=indexes)
# Trailing expression: rendered as the cell output.
df

Unnamed: 0,ahc,vbx,ahc_vbx,spectrul
ami_brno,19.981081,21.599858,17.912833,17.737017
ami_clova,22.782156,22.099241,19.286966,17.807702
ami_speechbrain,21.469331,19.646412,17.248639,17.62736
aishell_brno,13.194266,12.88094,9.86914,13.880085
aishell_clova,16.963574,14.230856,14.741873,17.62204
aishell_speechbrain,16.853835,11.345442,11.536344,15.141068
voxconverse_brno,8.719711,15.83669,6.2862,17.022805
voxconverse_clova,9.100599,19.466095,7.222013,16.866941
voxconverse_speechbrain,8.377691,15.478372,7.144937,16.350528


### Jaccard Error Rate (JER)

In [None]:
# JER table for skip_overlap=True: one row per entry of `indexes`, one column per clustering method.
# Reads hypotheses from ann_hyp/jer/..., so those files must already exist under ROOT.
import pandas as pd
d = get_results('jer', True)
# Keep the raw score lists on disk so the table can be rebuilt without re-scoring.
np.save(f'{ROOT}/jer_True.npy', d)
df = pd.DataFrame(d, index=indexes)
# Trailing expression: rendered as the cell output.
df

Unnamed: 0,ahc,vbx,ahc_vbx,spectrul
ami_brno,11.157421,17.253626,9.866726,13.352156
ami_clova,13.061978,19.8951,11.032688,13.053777
ami_speechbrain,10.327392,18.692219,8.496096,12.974931
aishell_brno,16.251517,21.767584,14.218811,34.581305
aishell_clova,18.37138,31.400866,15.090445,39.720742
aishell_speechbrain,16.028252,24.228364,12.092401,35.440133
voxconverse_brno,17.767039,60.581274,15.860347,57.402045
voxconverse_clova,24.282323,66.064301,23.545159,57.417113
voxconverse_speechbrain,14.805727,57.740555,14.814383,56.462588


In [None]:
# JER table for skip_overlap=False: one row per entry of `indexes`, one column per clustering method.
# Reads hypotheses from ann_hyp/jer/..., so those files must already exist under ROOT.
import pandas as pd
d = get_results('jer', False)
# Keep the raw score lists on disk so the table can be rebuilt without re-scoring.
np.save(f'{ROOT}/jer_False.npy', d)
df = pd.DataFrame(d, index=indexes)
# Trailing expression: rendered as the cell output.
df

Unnamed: 0,ahc,vbx,ahc_vbx,spectrul
ami_brno,23.849321,30.040257,22.918917,24.847648
ami_clova,27.590221,35.263031,25.38364,24.783689
ami_speechbrain,26.458441,29.921606,22.270399,24.755405
aishell_brno,18.922774,26.64072,15.962324,37.450323
aishell_clova,20.113635,45.441097,14.677727,42.081272
aishell_speechbrain,21.954604,23.170361,18.623527,38.354064
voxconverse_brno,18.491692,59.155198,18.90091,59.101058
voxconverse_clova,23.148184,68.111908,27.563112,59.196431
voxconverse_speechbrain,17.954066,56.90699,17.793737,58.262182
