Use this notebook to generate voice activity detection (VAD) results on a selected dataset.

## Basic Imports

In [1]:
from pyannote.audio import Pipeline
from pyannote.core import SlidingWindowFeature, Annotation, Segment
from pyannote.metrics.detection import DetectionAccuracy, DetectionErrorRate
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio import Model
from sklearn.metrics import roc_auc_score
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

import numpy as np
from dataclasses import dataclass
import os
import torch
from enum import Enum
import pandas as pd

import yaml

  from .autonotebook import tqdm as notebook_tqdm
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [allow_tf32, disable_jit_profiling]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []


## Testing pre-trained models on DIHARD III

To use pyannote, first sign up for an access token with Hugging Face. Save the token under the `AUTH_TOKEN` key in `../configs/config.yaml`; the cell below loads it from that file.

Follow this [link](https://huggingface.co/pyannote/segmentation-3.0#requirements) for more info.

In [2]:
# For pyannote/segmentation-3.0 usage.
# The Hugging Face token is read from a local YAML file so that it never has
# to be hardcoded in the notebook itself.
with open("../configs/config.yaml", 'r') as file:
    # safe_load() returns None for an empty file; fall back to an empty
    # mapping so the lookup below yields None instead of raising
    # AttributeError.
    config = yaml.safe_load(file) or {}

# None when the key is absent from the config file.
AUTH_TOKEN = config.get('AUTH_TOKEN')

In [3]:
# Locations of the DIHARD 3 eval data -- change these to match the local setup.
# The audio and label directories share the same parent data directory.
_DIHARD_EVAL_DATA_DIR = "../../original_dihard3_dataset/third_dihard_challenge_eval/data/"

# Directory scanned for the original ".flac" recordings.
DIHARD_FLAC_FILEPATH = _DIHARD_EVAL_DATA_DIR + "flac/"
# Directory scanned for the reference ".lab" label files.
DIHARD_LAB_FILEPATH = _DIHARD_EVAL_DATA_DIR + "sad/"
# Directory scanned for the enhanced ".wav" recordings.
DIHARD_ENHANCED_FILEPATH = "../../speech_enhanced_dihard3/MP-SENet_dihard3/eval"

## `VADResults` class

Generic VAD wrapper class to calculate metrics. It uses pyannote's underlying utilities to compute the metrics. As long as another VAD model (e.g. Silero) generates speech timestamps and they are converted into an annotation object, the `VADResults` class can generate the relevant metrics for it.

In [4]:
class VADType(str, Enum):
    """Names of the VAD back-ends that can be selected in this notebook.

    Because ``str`` is mixed in, every member compares equal to its plain
    string value, so either the member or the raw string can be used in
    comparisons.
    """

    # Silero VAD back-end.
    SILERO = "silero"
    # pyannote back-end.
    PYANNOTE = "pyannote"

def load_lab_file(lab_file_path):
    """Read a ``.lab`` label file into an ``Annotation`` of speech segments.

    Each non-empty line must hold three whitespace-separated fields:
    ``start end label``. Only lines whose label is exactly ``speech`` are
    kept; every other label is ignored.

    Args:
        lab_file_path: Path of the label file to read.

    Returns:
        An ``Annotation`` with one ``"speech"``-labelled segment per kept line.

    Raises:
        ValueError: If a non-empty line does not have exactly three fields,
            or if a start/end value of a speech line is not a number.
    """
    lab_annotation = Annotation()
    with open(lab_file_path, "r") as lab_file:
        for line_number, line in enumerate(lab_file, start=1):
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline at the end of the
                # file) carry no segment; skip them instead of failing.
                continue
            if len(fields) != 3:
                raise ValueError(
                    f"{lab_file_path}:{line_number}: expected 'start end label', "
                    f"got {line.strip()!r}"
                )
            start, end, label = fields
            if label == 'speech':  # Only marking speech segments
                lab_annotation[Segment(float(start), float(end))] = "speech"
    return lab_annotation

def annotation_to_frame_labels(annotation, total_duration, frame_duration):
    """Convert an annotation into a binary per-frame label array.

    Frame ``i`` covers the time span ``[i * frame_duration,
    (i + 1) * frame_duration)`` measured from time zero. A frame is set to 1
    when it falls inside any segment of ``annotation.get_timeline()``.

    Args:
        annotation: Object whose ``get_timeline()`` yields segments exposing
            ``start`` and ``end`` attributes (in the same unit as the
            durations below).
        total_duration: Time span to cover, starting at zero.
        frame_duration: Length of one frame; must be positive.

    Returns:
        A 1-D float ``numpy`` array of zeros and ones with
        ``floor(total_duration / frame_duration)`` entries. Segment parts
        outside ``[0, total_duration)`` are ignored.

    Raises:
        ValueError: If ``frame_duration`` is not positive.
    """
    if frame_duration <= 0:
        raise ValueError("frame_duration must be positive")

    def to_frame(timestamp):
        # Plain int() truncation is fragile: 0.29 / 0.01 evaluates to
        # 28.999999999999996 and would land one frame short. A tiny nudge
        # before flooring absorbs that representation error.
        return int(np.floor(timestamp / frame_duration + 1e-6))

    num_frames = max(to_frame(total_duration), 0)
    labels = np.zeros(num_frames)

    for segment in annotation.get_timeline():
        # Clamp at zero: a negative index would otherwise be interpreted
        # relative to the end of the array.
        start_frame = max(to_frame(segment.start), 0)
        end_frame = to_frame(segment.end)
        if end_frame > start_frame:
            # Slicing past num_frames is clipped by numpy automatically.
            labels[start_frame:end_frame] = 1

    return labels

@dataclass
class VADResults:
    """Runs one VAD model on one recording and scores it against a reference.

    Typical use: construct with the audio and label paths, call
    ``load_audio_and_vad()`` to produce the hypothesis, then ``calcMetrics()``
    to fill in the metric attributes.

    Attributes:
        sound_file_path: Path of the audio file handed to the VAD model.
        label_file_path: Path of the reference label file read by
            ``load_lab_file``.
        detection_accuracy: Value returned by pyannote's ``DetectionAccuracy``.
        detection_error_rate_value: Value returned by pyannote's
            ``DetectionErrorRate``.
        missed_detection_rate: Missed duration as a percentage of the
            reference total.
        false_alarm_rate: False-alarm duration as a percentage of the
            reference total.
        roc_auc: Frame-level ROC-AUC of hypothesis vs. reference; ``nan`` when
            the reference frames contain a single class only.
        vad_result: Hypothesis annotation, ``None`` until
            ``load_audio_and_vad()`` has run.
    """

    sound_file_path: str
    label_file_path: str

    # metrics
    detection_accuracy: float = 0.0
    detection_error_rate_value: float = 0.0
    missed_detection_rate: float = 0.0
    false_alarm_rate: float = 0.0
    roc_auc: float = 0.0

    # Both back-ends store an annotation here (calcMetrics() calls
    # get_timeline() on it), hence the Annotation type.
    vad_result: "Annotation" = None

    # Frame length in seconds used for the frame-level ROC-AUC. It carries no
    # type annotation, so it is a plain class attribute and not a dataclass
    # field / constructor argument.
    frame_duration = 0.01

    def load_audio_and_vad(self, vad_model: VADType):
        """Run the selected VAD model on ``sound_file_path``.

        The resulting speech regions are stored in ``self.vad_result``.

        Args:
            vad_model: A ``VADType`` member or its string value
                (``"pyannote"`` / ``"silero"``).

        Raises:
            ValueError: If ``vad_model`` does not name a known ``VADType``.
        """
        # Normalise so that raw strings and enum members behave the same and
        # unknown names fail loudly instead of leaving vad_result untouched.
        vad_model = VADType(vad_model)

        if vad_model == VADType.PYANNOTE:
            model = Model.from_pretrained(
                "pyannote/segmentation-3.0", use_auth_token=AUTH_TOKEN
            )
            pipeline = VoiceActivityDetection(segmentation=model)
            HYPER_PARAMETERS = {
                # remove speech regions shorter than that many seconds.
                "min_duration_on": 0.0,
                # fill non-speech regions shorter than that many seconds.
                "min_duration_off": 0.0,
            }
            pipeline.instantiate(HYPER_PARAMETERS)
            if torch.cuda.is_available():
                pipeline.to(torch.device('cuda'))
            self.vad_result = pipeline(self.sound_file_path)
        elif vad_model == VADType.SILERO:
            model = load_silero_vad()
            wav = read_audio(self.sound_file_path)
            speech_timestamps = get_speech_timestamps(
                wav,
                model,
                return_seconds=True,  # Return speech timestamps in seconds (default is samples)
            )
            annotation = Annotation()
            for ts in speech_timestamps:
                # One "speech" segment per detected region.
                annotation[Segment(ts['start'], ts['end'])] = "speech"
            # Assigned after the loop so that a recording without any
            # detected speech still gets an (empty) annotation, not None.
            self.vad_result = annotation

    @staticmethod
    def _timeline_end(annotation):
        """Return the end time of the last segment, or 0.0 if there is none."""
        timeline = annotation.get_timeline()
        return timeline.extent().end if len(timeline) > 0 else 0.0

    def calcMetrics(self):
        """Score ``self.vad_result`` against the reference labels.

        Fills ``detection_accuracy``, ``detection_error_rate_value``,
        ``missed_detection_rate``, ``false_alarm_rate`` and ``roc_auc``.

        Raises:
            RuntimeError: If ``load_audio_and_vad()`` has not produced a
                result yet.
        """
        if self.vad_result is None:
            raise RuntimeError(
                "load_audio_and_vad() must be called before calcMetrics()"
            )

        # Load ground truth labels from the .lab file
        ground_truth = load_lab_file(self.label_file_path)

        detection_accuracy = DetectionAccuracy()
        detection_error_rate = DetectionErrorRate()

        # Detection accuracy and detection error rate
        self.detection_accuracy = detection_accuracy(ground_truth, self.vad_result)
        self.detection_error_rate_value = detection_error_rate(
            ground_truth, self.vad_result
        )

        # Missed detection and false alarm rates, both expressed as a
        # percentage of the "total" component reported by the metric.
        detailed_metrics = detection_error_rate.compute_components(
            ground_truth, self.vad_result
        )

        missed_detection_duration = detailed_metrics["miss"]
        false_alarm_duration = detailed_metrics["false alarm"]
        total_reference_duration = detailed_metrics["total"]

        if total_reference_duration != 0:
            self.missed_detection_rate = (
                missed_detection_duration / total_reference_duration
            ) * 100
            self.false_alarm_rate = (
                false_alarm_duration / total_reference_duration
            ) * 100
        else:
            self.missed_detection_rate = 0
            self.false_alarm_rate = 0

        # Frame-level ROC-AUC. Frames are indexed from time zero, so the
        # label arrays must reach the latest end time found in either
        # annotation; using only the hypothesis extent *duration* would cut
        # off trailing frames whenever the hypothesis starts after zero or
        # ends before the reference does.
        total_duration = max(
            self._timeline_end(ground_truth),
            self._timeline_end(self.vad_result),
        )
        ground_truth_labels = annotation_to_frame_labels(
            ground_truth, total_duration, self.frame_duration
        )
        vad_labels = annotation_to_frame_labels(
            self.vad_result, total_duration, self.frame_duration
        )
        if np.unique(ground_truth_labels).size < 2:
            # ROC-AUC is undefined when the reference holds a single class.
            self.roc_auc = float("nan")
        else:
            self.roc_auc = roc_auc_score(ground_truth_labels, vad_labels)

In [5]:
# Get list of test files and corresponding label files.
# os.listdir() returns entries in arbitrary order, so each list is sorted to
# get a deterministic order.
flac_files = sorted(f for f in os.listdir(DIHARD_FLAC_FILEPATH) if f.endswith(".flac"))
lab_files = sorted(f for f in os.listdir(DIHARD_LAB_FILEPATH) if f.endswith(".lab"))
enhanced_vad_wav_files = sorted(
    f for f in os.listdir(DIHARD_ENHANCED_FILEPATH) if f.endswith(".wav")
)

# Ensure matching files: the three lists are later paired position by
# position, so a missing file in any directory would silently shift or
# truncate the pairing. Fail early instead.
if not (len(flac_files) == len(lab_files) == len(enhanced_vad_wav_files)):
    raise ValueError(
        "File count mismatch: "
        f"{len(flac_files)} .flac, {len(lab_files)} .lab, "
        f"{len(enhanced_vad_wav_files)} enhanced .wav"
    )

# Same count does not guarantee same recordings; report audio/label pairs
# whose base names differ so a misaligned pairing does not go unnoticed.
mismatched_pairs = [
    (flac_name, lab_name)
    for flac_name, lab_name in zip(flac_files, lab_files)
    if os.path.splitext(flac_name)[0] != os.path.splitext(lab_name)[0]
]
if mismatched_pairs:
    print(
        f"WARNING: {len(mismatched_pairs)} audio/label pairs have different "
        f"base names, e.g. {mismatched_pairs[0]}"
    )

In [None]:
# One list of metric rows per (model, audio variant) combination
results_pyannote_original = []
results_pyannote_enhanced = []
results_silero_original = []
results_silero_enhanced = []

# Define the VAD models
vad_models = ["pyannote", "silero"]


def _score_recording(sound_path, lab_path, file_name, vad_model):
    """Run ``vad_model`` on one audio file and return its metrics as a row.

    ``file_name`` is stored under the ``'flac_file'`` key of the returned
    dict; the remaining keys hold the metrics computed by ``VADResults``.
    """
    scored = VADResults(sound_file_path=sound_path, label_file_path=lab_path)
    scored.load_audio_and_vad(vad_model=vad_model)
    scored.calcMetrics()
    return {
        'flac_file': file_name,
        'vad_model': vad_model,
        'detection_accuracy': scored.detection_accuracy,
        'detection_error_rate': scored.detection_error_rate_value,
        'missed_detection_rate': scored.missed_detection_rate,
        'false_alarm_rate': scored.false_alarm_rate,
        'roc_auc': scored.roc_auc,
    }


for flac_file, enhanced_vad_file, lab_file in zip(
    flac_files, enhanced_vad_wav_files, lab_files
):
    # File paths
    flac_path = os.path.join(DIHARD_FLAC_FILEPATH, flac_file)
    enhanced_vad_path = os.path.join(DIHARD_ENHANCED_FILEPATH, enhanced_vad_file)
    lab_path = os.path.join(DIHARD_LAB_FILEPATH, lab_file)

    for vad_model in vad_models:
        # Anything that is not "pyannote" is filed under the Silero lists.
        is_pyannote = vad_model == "pyannote"

        # Original recording
        original_row = _score_recording(flac_path, lab_path, flac_file, vad_model)
        original_rows = (
            results_pyannote_original if is_pyannote else results_silero_original
        )
        original_rows.append(original_row)
        print(f"Appended VAD result for {flac_file} using {vad_model}")

        # Enhanced recording, scored against the same label file
        enhanced_row = _score_recording(
            enhanced_vad_path, lab_path, enhanced_vad_file, vad_model
        )
        enhanced_rows = (
            results_pyannote_enhanced if is_pyannote else results_silero_enhanced
        )
        enhanced_rows.append(enhanced_row)
        print(f"Appended Enhanced VAD result for {enhanced_vad_file} using {vad_model}")

# Convert lists to DataFrames and save as CSV files
df_pyannote_original = pd.DataFrame(results_pyannote_original)
df_pyannote_enhanced = pd.DataFrame(results_pyannote_enhanced)
df_silero_original = pd.DataFrame(results_silero_original)
df_silero_enhanced = pd.DataFrame(results_silero_enhanced)

for results_frame, csv_name in (
    (df_pyannote_original, "pyannote_original_vad_results.csv"),
    (df_pyannote_enhanced, "pyannote_enhanced_vad_results.csv"),
    (df_silero_original, "silero_original_vad_results.csv"),
    (df_silero_enhanced, "silero_enhanced_vad_results.csv"),
):
    results_frame.to_csv(csv_name, index=False)

print("All 4 CSV files have been saved successfully.")
