In [None]:
import subprocess
from pyannote.database import registry, FileFinder

In [None]:
# Load the pyannote.database configuration file, then fetch the
# "AMI-SDM.SpeakerDiarization.mini" protocol. FileFinder is passed as the
# preprocessor for each file's "audio" key.
# NOTE(review): the configuration path is an absolute, machine-specific location.
registry.load_database("/work/proy/AMI-diarization-setup/pyannote/database.yml")
dataset = registry.get_protocol("AMI-SDM.SpeakerDiarization.mini", {"audio": FileFinder()})

In [None]:
import subprocess
from pathlib import Path

#Function to convert to WAV if necessary
def ensure_wav(audio_path):
    """Return the path of a WAV version of ``audio_path``, converting with ffmpeg if needed.

    Args:
        audio_path: Location of the input audio file (``str`` or ``pathlib.Path``).

    Returns:
        pathlib.Path: ``audio_path`` itself when its extension is already
        ``.wav`` (compared case-insensitively); otherwise the sibling path
        whose extension has been replaced by ``.wav``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
    """
    audio_path = Path(audio_path)

    # Compare case-insensitively so that e.g. "x.WAV" is not re-encoded onto a
    # path that differs from the input only by letter case.
    if audio_path.suffix.lower() == '.wav':
        return audio_path

    wav_path = audio_path.with_suffix('.wav')
    # -y (overwrite) is placed before the input/output files, and check=True
    # makes a failed conversion raise here instead of handing back a path to a
    # file that was never written.
    subprocess.run(
        ['ffmpeg', '-y', '-i', str(audio_path), str(wav_path)],
        check=True,
    )
    return wav_path

# Iterate through the test subset and make sure every recording has a WAV version.
# ensure_wav() returns the path of the WAV file to use for that recording.
for file in dataset.test():
    audio_path = file['audio']
    wav_audio_path = ensure_wav(audio_path)
    print(f'Processed file path: {wav_audio_path}')


In [None]:
# Notebook parameters. The trailing "#@param" annotations are left untouched.
# Default speaker count; the main processing loop later rebinds num_speakers
# to the number of distinct reference labels it finds for each file.
num_speakers = 4 #@param {type:"integer"}

language = 'English' #@param ['any', 'English']

model_size = 'large' #@param ['tiny', 'base', 'small', 'medium', 'large']


# Build the checkpoint name: append '.en' for English unless the size is 'large'
# (presumably because no English-only 'large' checkpoint exists - confirm
# against Whisper's model list).
model_name = model_size
if language == 'English' and model_size != 'large':
  model_name += '.en'

In [None]:
# Install Whisper and pyannote.audio straight from their GitHub repositories.
# -q together with the /dev/null redirect hides pip's standard output.
# NOTE(review): neither install pins a tag or commit, so results can change between runs.
!pip install -q git+https://github.com/openai/whisper.git > /dev/null
!pip install -q git+https://github.com/pyannote/pyannote-audio > /dev/null

In [None]:
import whisper
import datetime

import subprocess

import torch
import pyannote.audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
# Speaker-embedding model called by segment_embedding(); it is placed on the
# CUDA device with index 2.
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda:2"))

from pyannote.audio import Audio
from pyannote.core import Segment

import wave
import contextlib

from sklearn.cluster import AgglomerativeClustering
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import os
from pyannote.core import Segment, Annotation  # ADDITION
from pyannote.metrics.diarization import DiarizationErrorRate  # ADDITION
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score  # ADDITION
from sklearn.metrics import confusion_matrix  # ADDITION
from scipy.optimize import linear_sum_assignment  # ADDITION


In [None]:
# Load the Whisper checkpoint chosen in the parameter cell.
# model_name (not the bare model_size) is used so that the '.en' suffix
# selected for English-only transcription actually takes effect.
# Passing device= loads the weights directly on the target GPU instead of
# loading them on Whisper's default device first and moving them afterwards.
device = torch.device("cuda:2")
model = whisper.load_model(model_name, device=device)

In [None]:
def get_audio_duration(wav_audio_path):
    """Return the length of a WAV file in seconds.

    The duration is the number of audio frames divided by the frame rate,
    both read from the file header via the standard-library ``wave`` module.

    Args:
        wav_audio_path: Path of the WAV file (``str`` or path-like).

    Returns:
        float: Duration in seconds.
    """
    with wave.open(str(wav_audio_path), 'r') as wav_reader:
        frame_count = wav_reader.getnframes()
        frame_rate = wav_reader.getframerate()
    return frame_count / float(frame_rate)

In [None]:
def transcribe_audio(wav_audio_path):
    """Transcribe a recording with the notebook-level Whisper ``model``.

    Args:
        wav_audio_path: Path of the audio file; it is converted to ``str``
            before being handed to ``model.transcribe``.

    Returns:
        The ``"segments"`` entry of the transcription result.
    """
    transcription = model.transcribe(str(wav_audio_path))
    whisper_segments = transcription["segments"]
    return whisper_segments

In [None]:
audio = Audio()


def segment_embedding(segment, wav_audio_path, duration):
    """Compute the speaker embedding of one transcribed segment.

    Args:
        segment: Mapping with ``"start"`` and ``"end"`` times of the segment.
        wav_audio_path: Path of the audio file the segment belongs to.
        duration: Total duration of the audio file; the segment end is
            clamped to it so the crop never reaches past the file.

    Returns:
        The output of ``embedding_model`` for the cropped waveform, with
        singleton dimensions squeezed away.
    """
    # Clamp the end time so the crop stays within the file bounds.
    clip = Segment(segment["start"], min(duration, segment["end"]))
    waveform, _sample_rate = audio.crop(wav_audio_path, clip)
    # waveform[None] adds a leading axis before the model is called.
    return embedding_model(waveform[None]).squeeze()

In [None]:
def perform_clustering(embeddings, num_speakers=4, metric='euclidean', linkage='average'):
    """Group segment embeddings into speakers with agglomerative clustering.

    Args:
        embeddings: Array-like with one row per segment.
        num_speakers: Number of clusters to produce.
        metric: Distance metric forwarded to ``AgglomerativeClustering``.
            Defaults to ``'euclidean'``, the value this notebook has always
            used; pass ``'cosine'`` to compare embeddings by angle instead.
        linkage: Linkage criterion forwarded to ``AgglomerativeClustering``.

    Returns:
        numpy.ndarray: One integer cluster label per input row.
    """
    n_samples = len(embeddings)
    # With zero or one segment there is nothing to cluster: put whatever is
    # there in cluster 0 rather than passing it to the estimator.
    if n_samples < 2:
        return np.zeros(n_samples, dtype=int)

    clustering = AgglomerativeClustering(
        n_clusters=num_speakers,
        metric=metric,
        linkage=linkage,
    ).fit(embeddings)
    return clustering.labels_

# from sklearn.cluster import SpectralClustering

# def perform_clustering(embeddings, num_speakers=4):
#     clustering = SpectralClustering(
#         n_clusters=4,  # Number of speakers
#         affinity='nearest_neighbors',
#         assign_labels='kmeans',
#         random_state=42
#         ).fit(embeddings)
#     return clustering.labels_


In [None]:
def plot_pca(embeddings, labels, num_speakers, title="Speaker Diarization Clusters (PCA Visualization)"):
    """Scatter-plot embeddings in 2-D (via PCA), coloured by label.

    Args:
        embeddings: Array with one row per segment.
        labels: One label per row of ``embeddings`` (any array-like).
        num_speakers: Accepted for call compatibility; not used, since the
            colours are derived from the distinct values found in ``labels``.
        title: Title shown above the plot.
    """
    # A plain list would make `labels == label` a single boolean instead of an
    # element-wise mask, so normalise to an array first.
    labels = np.asarray(labels)

    # Reduce the embeddings to two principal components for plotting.
    pca = PCA(n_components=2, random_state=42)
    embeddings_2d = pca.fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(8, 6))

    unique_labels = np.unique(labels)
    # plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib releases
    # have deprecated and removed.
    colors = plt.get_cmap('tab10', len(unique_labels))

    for idx, label in enumerate(unique_labels):
        mask = labels == label
        ax.scatter(
            embeddings_2d[mask, 0],
            embeddings_2d[mask, 1],
            label=f'Label {label}',
            color=colors(idx % 10),
            alpha=0.7
        )

    ax.set_title(title)
    ax.set_xlabel("Principal Component 1")
    ax.set_ylabel("Principal Component 2")
    ax.legend(loc='best', title="Labels")
    plt.show()
    # The main loop draws several figures per file; release each one after display.
    plt.close(fig)

In [None]:
# Main processing loop: for every test recording, transcribe it, embed each
# transcribed segment, cluster the embeddings into speakers, compare the clusters
# with the reference annotation, plot them and write a speaker-labelled transcript.
# NOTE: `segments`, `mapped_labels` and `wav_audio_path` stay bound to the last
# processed recording once the loop has finished.
for file in dataset.test():
    # Ensure WAV file for the current audio
    wav_audio_path = ensure_wav(file['audio'])
    print(f"Processing file: {wav_audio_path}")

    # Get the audio duration
    duration = get_audio_duration(wav_audio_path)

    # Transcribe the audio file using Whisper
    segments = transcribe_audio(wav_audio_path)
    if not segments:
        # Nothing to embed or cluster for this recording.
        print(f"No segments transcribed for {wav_audio_path.stem}; skipping")
        continue

    # Get ground truth annotation
    ground_truth = file['annotation']

    # Collect one embedding and one reference label per transcribed segment.
    # The embedding width is taken from the model output instead of being
    # hard-coded.
    segment_embeddings = []
    true_labels = []
    for segment in segments:
        embedding = segment_embedding(segment, wav_audio_path, duration)
        segment_embeddings.append(np.asarray(embedding, dtype=np.float64))

        # Reference speaker for this segment: the one who talks longest inside
        # it. (Taking labels()[0] would pick the alphabetically first
        # overlapping speaker, however briefly they speak.)
        seg = Segment(segment['start'], segment['end'])
        dominant_speaker = ground_truth.crop(seg).argmax()
        true_labels.append('Unknown' if dominant_speaker is None else dominant_speaker)

    # Stack into an (n_segments, dim) matrix and replace NaN values with 0.
    embeddings = np.nan_to_num(np.stack(segment_embeddings))

    # Map ground truth labels to integers; sorting keeps the mapping identical
    # from run to run.
    unique_speakers = sorted(set(true_labels), key=str)
    speaker_to_int = {speaker: idx for idx, speaker in enumerate(unique_speakers)}
    true_labels_int = np.array([speaker_to_int[speaker] for speaker in true_labels])

    # Determine the number of unique speakers ('Unknown', if present, counts as one)
    num_speakers = len(unique_speakers)
    print(f"Number of speakers in ground truth: {num_speakers}")

    # Perform clustering to assign speaker labels
    labels = perform_clustering(embeddings, num_speakers=num_speakers)

    # Compute clustering metrics before mapping
    ari = adjusted_rand_score(true_labels_int, labels)
    nmi = normalized_mutual_info_score(true_labels_int, labels)

    print(f"Adjusted Rand Index (ARI) before mapping: {ari:.4f}")
    print(f"Normalized Mutual Information (NMI) before mapping: {nmi:.4f}")

    # Hungarian algorithm on the negated confusion matrix: find the one-to-one
    # assignment of predicted clusters to reference speakers that maximises
    # the number of agreeing segments.
    confusion = confusion_matrix(true_labels_int, labels)
    row_ind, col_ind = linear_sum_assignment(-confusion)

    # Create a mapping from predicted labels to true labels
    label_mapping = {col_ind[i]: row_ind[i] for i in range(len(col_ind))}

    # Map the predicted labels to the ground truth labels
    mapped_labels = np.array([label_mapping[label] for label in labels])

    # Compute clustering metrics after mapping
    ari_mapped = adjusted_rand_score(true_labels_int, mapped_labels)
    nmi_mapped = normalized_mutual_info_score(true_labels_int, mapped_labels)

    print(f"Adjusted Rand Index (ARI) after mapping: {ari_mapped:.4f}")
    print(f"Normalized Mutual Information (NMI) after mapping: {nmi_mapped:.4f}")

    # Plot PCA visualization with predicted labels
    plot_pca(embeddings, labels, num_speakers=num_speakers, title=f"PCA Clustering for {wav_audio_path.stem} (Predicted Labels)")

    # Plot PCA visualization with mapped labels
    plot_pca(embeddings, mapped_labels, num_speakers=num_speakers, title=f"PCA Clustering for {wav_audio_path.stem} (Mapped Labels)")

    # Plot PCA visualization with ground truth labels
    plot_pca(embeddings, true_labels_int, num_speakers=num_speakers, title=f"PCA Clustering for {wav_audio_path.stem} (Ground Truth Labels)")

    # Assign the speaker labels to each segment using mapped labels
    for segment, mapped_label in zip(segments, mapped_labels):
        segment["speaker"] = f'SPEAKER {mapped_label + 1}'

    # Save the transcript with speaker labels and timestamps; a new header line
    # is written whenever the speaker changes. UTF-8 is requested explicitly so
    # non-ASCII transcript text does not depend on the platform default.
    transcript_file_path = f"{wav_audio_path.stem}_transcript.txt"
    with open(transcript_file_path, "w", encoding="utf-8") as f:
        for i, segment in enumerate(segments):
            if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
                f.write(f"\n{segment['speaker']} {str(datetime.timedelta(seconds=round(segment['start'])))}\n")
            f.write(f"{segment['text']} ")

    # Only the transcript is written to disk; embeddings are not persisted.
    print(f"Saved transcript for {wav_audio_path.stem}")
    
    

#Calculating DER

In [None]:
# Debug aid: print every test file object followed by its attribute dictionary.
for file in dataset.test():
    print(file)
    print(file.__dict__)

In [None]:
# Print the reference annotation of every test file, headed by the file's URI.
# `ground_truth` remains bound to the last file's annotation after the loop.
for file in dataset.test():
    ground_truth = file['annotation']  # This is the ground truth annotation
    print(f"Ground truth annotation for file {file['uri']}:")
    print(ground_truth)

In [None]:
# def merge_segments(segments, labels):
#     merged_segments = []
#     merged_labels = []

#     start, end, current_label = segments[0]['start'], segments[0]['end'], labels[0]
#     for i in range(1, len(segments)):
#         if labels[i] == current_label:
#             end = segments[i]['end']
#         else:
#             merged_segments.append({'start': start, 'end': end})
#             merged_labels.append(current_label)
#             start, end, current_label = segments[i]['start'], segments[i]['end'], labels[i]

#     merged_segments.append({'start': start, 'end': end})
#     merged_labels.append(current_label)

#     return merged_segments, merged_labels

In [None]:
# def convert_predictions_to_annotation(merged_segments, merged_labels):
#     """
#     Convert merged segments and labels to pyannote.core.annotation.Annotation.
#     Args:
#         merged_segments (list): List of merged segment dictionaries with 'start' and 'end' keys.
#         merged_labels (list): Cluster labels for each merged segment.
#     Returns:
#         Annotation: Predicted annotation.
#     """
#     annotation = Annotation()
#     for segment, label in zip(merged_segments, merged_labels):
#         annotation[Segment(segment['start'], segment['end'])] = f"SPEAKER {label + 1}"
#     return annotation


In [None]:
# The hypothesis held in memory (`segments` + `mapped_labels`) belongs to the
# recording the main processing loop handled last, identified by
# `wav_audio_path`. Scoring any other test file against it would give
# meaningless numbers, so only the matching file is evaluated here.
processed_audio_base = Path(wav_audio_path).with_suffix('')

for file in dataset.test():
    # Compare paths without their extension so a converted WAV still matches
    # the original audio entry.
    if Path(file['audio']).with_suffix('') != processed_audio_base:
        print(f"Skipping {file['uri']}: no predictions in memory for this file")
        continue

    ground_truth = file['annotation']  # Ground truth annotation

    # Create predicted annotation using mapped labels
    predicted_annotation = Annotation()
    for segment, label in zip(segments, mapped_labels):
        predicted_annotation[Segment(segment['start'], segment['end'])] = f"SPEAKER {label + 1}"

    # Evaluate DER once; detailed=True returns the error rate together with
    # its components.
    metric = DiarizationErrorRate()
    detailed_der = metric(ground_truth, predicted_annotation, detailed=True)
    der = detailed_der['diarization error rate']
    print(f"DER for {file['uri']}: {der:.2%}")

    # The components are durations in seconds, not rates: divide by the total
    # reference speech duration before formatting them as percentages.
    total = detailed_der['total']
    if total > 0:
        print(f"False Alarm: {detailed_der['false alarm'] / total:.2%}")
        print(f"Missed Detection: {detailed_der['missed detection'] / total:.2%}")
        print(f"Confusion: {detailed_der['confusion'] / total:.2%}")
    else:
        print("Reference contains no speech; component rates are undefined")

In [None]:
from pyannote.core import notebook

# Plot the reference and the predicted annotation, restricted to the window
# set on notebook.crop. `ground_truth` and `predicted_annotation` are the
# variables left bound by the earlier evaluation cells.
notebook.crop = Segment(0,60)  # Focus on a smaller time window for analysis
notebook.plot_annotation(ground_truth, legend=True)
notebook.plot_annotation(predicted_annotation, legend=True)


In [None]:
# Print the extent of the timeline of each annotation, to check that the
# reference and the prediction span comparable time ranges.
print(f"Ground truth extent: {ground_truth.get_timeline().extent()}")
print(f"Prediction extent: {predicted_annotation.get_timeline().extent()}")

In [None]:
# Share of the predicted timeline that is also covered by the reference timeline.
reference_timeline = ground_truth.get_timeline()
hypothesis_timeline = predicted_annotation.get_timeline()

# Restrict the reference timeline to the regions where a prediction exists.
overlap = reference_timeline.crop(hypothesis_timeline, mode='intersection')

overlap_duration = overlap.duration()
prediction_duration = hypothesis_timeline.duration()

# Guard against an empty prediction to avoid dividing by zero.
if prediction_duration > 0:
    coverage = overlap_duration / prediction_duration
else:
    coverage = 0
print(f"Overlap coverage: {coverage:.2%}")
