# Fine-tuning the pyannote.audio diarization model

In this notebook, we evaluate the pretrained pyannote.audio 3.1.1 speaker diarization pipeline on our dataset of Khanty elicitation sessions. It reaches 38.9% DER (diarization error rate) on the test set. We then fine-tune the segmentation model and tune the pipeline hyperparameters on the development set to improve the diarization result. The fine-tuned pipeline reaches 26.3% DER and 28.9% JER (Jaccard error rate) on the test set. Both error rates are unweighted averages of per-file scores. This notebook is adapted from one of the recipes provided by the pyannote.audio library, https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/adapting_pretrained_pipeline.ipynb (see it for a more detailed explanation of the code).

In [None]:
# Installing libraries
!pip install -qq pyannote.audio==3.1.1
!pip install -qq ipython==7.34.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m208.7/208.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.5/58.5 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.1/48.1 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.4/51.4 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.1/119.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Mount Google Drive so the corpus under MyDrive/diarizationcorpora becomes readable
# from /content/gdrive.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Register the pyannote.database configuration (database.yml lists the .lst, .rttm and
# .uem files of the train, development and test subsets) and fetch the protocol.
# FileFinder resolves each file's "audio" key. pyannote warns that the protocol does
# not define a label 'scope' and falls back to 'file'.
from pyannote.database import registry, FileFinder

registry.load_database("/content/gdrive/MyDrive/diarizationcorpora/database.yml")
preprocessors = {"audio": FileFinder()}
dataset = registry.get_protocol(
    "MyDatabase.SpeakerDiarization.MyProtocol",
    preprocessors,
)

'MyDatabase.SpeakerDiarization.MyProtocol' found in /content/gdrive/MyDrive/diarizationcorpora/database.yml does not define the 'scope' of speaker labels (file, database, or global). Setting it to 'file'.


In [None]:
# Log in to the Hugging Face Hub interactively. The stored token lets the later
# `from_pretrained(..., use_auth_token=True)` calls authenticate without a token
# being written into the notebook.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Loading the pretrained diarization pipeline from the Hugging Face Hub.
# Authentication reuses the token stored by notebook_login(); access tokens must never be
# pasted into the notebook, because they leak through shared copies and version control.
# NOTE(review): a token was previously hardcoded here — revoke it on huggingface.co.
from pyannote.audio import Pipeline
pretrained_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=True)

In [None]:
# Evaluating the performance of the pretrained diarization pipeline on our test data (DER)
from pyannote.audio import Audio
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.database.util import load_rttm

corpus_dir = '/content/gdrive/MyDrive/diarizationcorpora/'
# Loop-invariant loader: downmixes to mono and resamples to 16 kHz.
io = Audio(mono='downmix', sample_rate=16000)

metric = DiarizationErrorRate()
der_metric = []
for file in dataset.test():
    # apply pretrained pipeline to the in-memory waveform
    waveform, sample_rate = io(corpus_dir + 'audio_final/' + file['uri'] + '.wav')
    diarization = pretrained_pipeline({"waveform": waveform, "sample_rate": sample_rate})

    # NOTE(review): assumes each .rttm holds exactly one uri — popitem() takes an arbitrary entry.
    _, groundtruth = load_rttm(corpus_dir + file['uri'] + '.rttm').popitem()

    der = metric(groundtruth, diarization)
    der_metric.append(der)

if not der_metric:
    raise ValueError(f"{dataset.name} has no test files to evaluate")
# Unweighted mean of per-file DERs (each file counts equally, regardless of duration).
# `abs(metric)` would give the DER accumulated over all test files instead.
der_test = sum(der_metric) / len(der_metric)
print(f"The pretrained pipeline reaches a Diarization Error Rate (DER) of {100 * der_test:.1f}% on {dataset.name} test set.")

The pretrained pipeline reaches a Diarization Error Rate (DER) of 38.9% on MyDatabase.SpeakerDiarization.MyProtocol test set.


In [None]:
# Load the pretrained segmentation model that will be fine-tuned on our data
# (authenticates with the token stored by notebook_login()).
from pyannote.audio import Model

segmentation_checkpoint = "pyannote/segmentation-3.0"
model = Model.from_pretrained(segmentation_checkpoint, use_auth_token=True)

In [None]:
# Attach a segmentation training task to the model. The task reuses the pretrained
# model's duration and its number of speaker classes, and uses BCE for both the
# segmentation and VAD losses.
from pyannote.audio.tasks import Segmentation

chunk_duration = model.specifications.duration
num_output_speakers = len(model.specifications.classes)

task = Segmentation(
    dataset,
    duration=chunk_duration,
    max_num_speakers=num_output_speakers,
    batch_size=32,
    num_workers=2,
    loss="bce",
    vad_loss="bce",
)
model.task = task
model.setup(stage="fit")

In [None]:
# Fine-tuning the segmentation model with Adam (lr=1e-4), keeping the checkpoint with the
# best value of the task's validation metric.
from types import MethodType

from torch.optim import Adam
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    RichProgressBar,
)

# Seed python/numpy/torch (and dataloader workers) so training runs are repeatable.
SEED = 42
seed_everything(SEED, workers=True)

MAX_EPOCHS = 5
EARLY_STOPPING_PATIENCE = 10
# NOTE: with MAX_EPOCHS <= EARLY_STOPPING_PATIENCE, early stopping can never trigger and
# training always runs MAX_EPOCHS epochs; raise MAX_EPOCHS to let early stopping decide.


def configure_optimizers(self):
    """Override the model's optimizer with Adam at a small fine-tuning learning rate."""
    return Adam(self.parameters(), lr=1e-4)


model.configure_optimizers = MethodType(configure_optimizers, model)

# Name of the validation metric to monitor and whether it is minimized or maximized.
monitor, direction = task.val_monitor
checkpoint = ModelCheckpoint(
    monitor=monitor,
    mode=direction,
    save_top_k=1,
    every_n_epochs=1,
    save_last=False,
    save_weights_only=False,
    filename="{epoch}",
    verbose=False,
)
early_stopping = EarlyStopping(
    monitor=monitor,
    mode=direction,
    min_delta=0.0,
    patience=EARLY_STOPPING_PATIENCE,
    strict=True,
    verbose=False,
)

callbacks = [RichProgressBar(), checkpoint, early_stopping]

trainer = Trainer(devices=1, accelerator="gpu",
                  callbacks=callbacks,
                  max_epochs=MAX_EPOCHS,
                  gradient_clip_val=0.5)
trainer.fit(model)

In [None]:
# Path of the best checkpoint (by the monitored validation metric) saved during training.
finetuned_model = checkpoint.best_model_path
# ModelCheckpoint leaves best_model_path as "" when nothing was saved; fail here instead
# of handing an empty path to the pipelines below.
if not finetuned_model:
    raise RuntimeError("No checkpoint was saved during training; check the ModelCheckpoint monitor.")

# Download the checkpoint from the Colab VM so it is kept after the session ends.
from google.colab import files
files.download(finetuned_model)

In [None]:
# Show the hyperparameters the pretrained pipeline was instantiated with, for comparison
# with the values tuned below.
pretrained_hyperparameters = pretrained_pipeline.parameters(
    instantiated=True,
)
pretrained_hyperparameters

{'segmentation': {'min_duration_off': 0.0},
 'clustering': {'method': 'centroid',
  'min_cluster_size': 12,
  'threshold': 0.7045654963945799}}

In [None]:
# Tuning the segmentation threshold on the development set.
# NOTE: OracleClustering is meant to isolate segmentation quality from clustering quality
# (as in the pyannote tutorial this notebook adapts) — confirm against the pyannote docs.
from itertools import islice

from pyannote.audio.pipelines import SpeakerDiarization
from pyannote.pipeline import Optimizer

pipeline = SpeakerDiarization(
    segmentation=finetuned_model,
    clustering="OracleClustering",
)
# as reported in the technical report, min_duration_off can safely be set to 0.0
pipeline.freeze({"segmentation": {"min_duration_off": 0.0}})

optimizer = Optimizer(pipeline)
dev_set = list(dataset.development())

# Number of optimizer trials; 50 should give slightly better results at a higher cost.
N_TUNING_ITERATIONS = 22
iterations = optimizer.tune_iter(dev_set, show_progress=False)
for iteration in islice(iterations, N_TUNING_ITERATIONS):
    print(f"Best segmentation threshold so far: {iteration['params']['segmentation']['threshold']}")

Best segmentation threshold so far: 0.6219445152738883
Best segmentation threshold so far: 0.6219445152738883
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segmentation threshold so far: 0.7059725947191872
Best segme

In [None]:
# Read the tuned threshold from the segmentation optimizer. This must run before the
# clustering-tuning cell rebinds `optimizer`.
segmentation_best = optimizer.best_params["segmentation"]
best_segmentation_threshold = segmentation_best["threshold"]

In [None]:
# Display the tuned segmentation threshold
best_segmentation_threshold

0.7059725947191872

In [None]:
# Tuning the clustering threshold on the development set. The pipeline uses the
# pretrained pipeline's embedding and clustering, and the segmentation threshold is
# frozen at the tuned value.
from itertools import islice

pipeline = SpeakerDiarization(
    segmentation=finetuned_model,
    embedding=pretrained_pipeline.embedding,
    embedding_exclude_overlap=pretrained_pipeline.embedding_exclude_overlap,
    clustering=pretrained_pipeline.klustering,
)

# NOTE(review): min_cluster_size=15 differs from the pretrained pipeline's 12 (see the
# pretrained hyperparameters printed above) — confirm this is intentional.
pipeline.freeze({
    "segmentation": {
        "threshold": best_segmentation_threshold,
        "min_duration_off": 0.0,
    },
    "clustering": {
        "method": "centroid",
        "min_cluster_size": 15,
    },
})

optimizer = Optimizer(pipeline)
# Number of optimizer trials; 50 should give slightly better results at a higher cost.
N_TUNING_ITERATIONS = 22
iterations = optimizer.tune_iter(dev_set, show_progress=True)
for iteration in islice(iterations, N_TUNING_ITERATIONS):
    print(f"Best clustering threshold so far: {iteration['params']['clustering']['threshold']}")


In [None]:
# Read the tuned clustering threshold from the clustering optimizer instead of re-typing it,
# so the evaluation always uses the value that was actually tuned.
# The run behind the reported results found 0.593157817904834.
best_clustering_threshold = optimizer.best_params["clustering"]["threshold"]

In [None]:
# Testing the performance of the fine-tuned diarization pipeline on the test set (DER)
import torch
from pyannote.audio import Audio
from pyannote.audio.pipelines import SpeakerDiarization
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

finetuned_pipeline = SpeakerDiarization(
    segmentation=finetuned_model,
    embedding=pretrained_pipeline.embedding,
    embedding_exclude_overlap=pretrained_pipeline.embedding_exclude_overlap,
    clustering=pretrained_pipeline.klustering,
)

# Use the thresholds tuned on the development set (the reported results used
# 0.7059725947191872 for segmentation and 0.593157817904834 for clustering).
finetuned_pipeline.instantiate({
    "segmentation": {
        "threshold": best_segmentation_threshold,
        "min_duration_off": 0.0,
    },
    "clustering": {
        "method": "centroid",
        "min_cluster_size": 15,
        "threshold": best_clustering_threshold,
    },
})
finetuned_pipeline = finetuned_pipeline.to(torch.device('cuda:0'))

corpus_dir = '/content/gdrive/MyDrive/diarizationcorpora/'
# Loop-invariant loader: downmixes to mono and resamples to 16 kHz.
io = Audio(mono='downmix', sample_rate=16000)

metric = DiarizationErrorRate()
der_metric = []
for file in dataset.test():
    waveform, sample_rate = io(corpus_dir + 'audio_final/' + file['uri'] + '.wav')
    diarization = finetuned_pipeline({"waveform": waveform, "sample_rate": sample_rate})
    # NOTE(review): assumes each .rttm holds exactly one uri — popitem() takes an arbitrary entry.
    _, groundtruth = load_rttm(corpus_dir + file['uri'] + '.rttm').popitem()

    der = metric(groundtruth, diarization)
    print(f"{file['uri']}: {der}")
    der_metric.append(der)

if not der_metric:
    raise ValueError(f"{dataset.name} has no test files to evaluate")
# Unweighted mean of per-file DERs; `abs(metric)` would give the accumulated DER instead.
der_test = sum(der_metric) / len(der_metric)
print(f"The fine-tuned pipeline reaches a Diarization Error Rate (DER) of {100 * der_test:.1f}% on {dataset.name} test set.")

0.302837686064196
0.2424699739069177
0.26000320976463054
0.27901151689785364
0.24322383634936465
0.25237569660388465
The pretrained pipeline reaches a Diarization Error Rate (DER) of 26.3% on MyDatabase.SpeakerDiarization.MyProtocol test set.


In [None]:
# Testing the performance of the fine-tuned diarization pipeline on the test set (JER)
import torch
from pyannote.audio import Audio
from pyannote.audio.pipelines import SpeakerDiarization
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import JaccardErrorRate

finetuned_pipeline = SpeakerDiarization(
    segmentation=finetuned_model,
    embedding=pretrained_pipeline.embedding,
    embedding_exclude_overlap=pretrained_pipeline.embedding_exclude_overlap,
    clustering=pretrained_pipeline.klustering,
)

# Use the thresholds tuned on the development set (the reported results used
# 0.7059725947191872 for segmentation and 0.593157817904834 for clustering).
finetuned_pipeline.instantiate({
    "segmentation": {
        "threshold": best_segmentation_threshold,
        "min_duration_off": 0.0,
    },
    "clustering": {
        "method": "centroid",
        "min_cluster_size": 15,
        "threshold": best_clustering_threshold,
    },
})
finetuned_pipeline = finetuned_pipeline.to(torch.device('cuda:0'))

corpus_dir = '/content/gdrive/MyDrive/diarizationcorpora/'
# Loop-invariant loader: downmixes to mono and resamples to 16 kHz.
io = Audio(mono='downmix', sample_rate=16000)

metric = JaccardErrorRate()
jer_metric = []
for file in dataset.test():
    waveform, sample_rate = io(corpus_dir + 'audio_final/' + file['uri'] + '.wav')
    diarization = finetuned_pipeline({"waveform": waveform, "sample_rate": sample_rate})
    # NOTE(review): assumes each .rttm holds exactly one uri — popitem() takes an arbitrary entry.
    _, groundtruth = load_rttm(corpus_dir + file['uri'] + '.rttm').popitem()

    jer = metric(groundtruth, diarization)
    print(f"{file['uri']}: {jer}")
    jer_metric.append(jer)

if not jer_metric:
    raise ValueError(f"{dataset.name} has no test files to evaluate")
# Unweighted mean of per-file JERs; `abs(metric)` would give the accumulated JER instead.
jer_test = sum(jer_metric) / len(jer_metric)
print(f"The fine-tuned pipeline reaches a Jaccard Error Rate (JER) of {100 * jer_test:.1f}% on {dataset.name} test set.")

0.31782997060578594
0.25501198746748016
0.29619764755978145
0.2980510361107672
0.2575446188582464
0.31003656409448366
The pretrained pipeline reaches a Jaccard Error Rate (DER) of 28.9% on MyDatabase.SpeakerDiarization.MyProtocol test set.
