## Load Libraries

In [3]:
import numpy as np
import pandas as pd
import soundfile as sf
from tqdm import tqdm
from IPython.display import Audio
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate
import re


## Work on CAM++ Diarization file

In [4]:
# Read can_plus csv file
can_plus = pd.read_csv('Diarization results/can_plus_diarized_output.csv')
sortformer_diarization = pd.read_csv('Diarization results/sortformer_diarization.csv')

# Add 'audio_keep' column to can_plus if 'audio_id' in can_plus exists in sortformer_diarization
can_plus['audio_keep'] = can_plus['audio_id'].isin(sortformer_diarization['audio_id'])

# Copy 'transcript' from sortformer_diarization to can_plus where 'audio_id' matches
can_plus['transcript'] = can_plus['audio_id'].map(sortformer_diarization.set_index('audio_id')['transcript'])

# Filter can_plus to keep only rows where 'audio_keep' is True
can_plus = can_plus[can_plus['audio_keep']].reset_index(drop=True)





In [5]:
# Remove the {'text': ...} wrapper if present in each pred_segment
for i, row in can_plus.iterrows():
    seg = row['pred_segment']
    # If seg is a string, try to eval to dict/list
    if isinstance(seg, str):
        try:
            seg = eval(seg)
        except Exception:
            seg = seg
    # If seg is a dict with 'text' key, extract it
    if isinstance(seg, dict) and 'text' in seg:
        segments_list = seg['text']
    elif isinstance(seg, list):
        segments_list = seg
    else:
        segments_list = []
    formatted_segments = []
    for s in segments_list:
        # If segment is a list/tuple of length 3, convert to tuple
        if isinstance(s, (list, tuple)) and len(s) == 3:
            start, end, speaker = s
            # Convert numpy types to float/int/str
            try:
                start = float(start)
                end = float(end)
                speaker = str(speaker)
            except Exception:
                pass
            formatted_segments.append((start, end, speaker))
        # If segment is a string, try to split
        elif isinstance(s, str):
            parts = s.split()
            if len(parts) >= 3:
                formatted_segments.append((float(parts[0]), float(parts[1]), parts[2]))
    formatted_segments.sort(key=lambda x: x[0])
    can_plus.at[i, 'pred_segment'] = formatted_segments

In [6]:
def convert_time_to_seconds(timestamp):
    # Split the timestamp into minutes, seconds, and milliseconds
    minutes, seconds, milliseconds = map(float, timestamp.split(':'))
    # Convert the time to seconds (including fractional part from milliseconds)
    total_seconds = minutes * 60 + seconds + milliseconds / 1000
    return total_seconds


def extract_segments(transcript):
    # Regular expression to match the timestamp and speaker tag
    timestamp_pattern = r'(\d{2}:\d{2}:\d{2})'
    speaker_pattern = r'\[([^\]]+)\]'

    lines = transcript.strip().splitlines()
    segments = []

    start_time = None
    speaker_tag = None

    for i in range(len(lines)):
        if re.match(timestamp_pattern, lines[i]):  # Line is a timestamp
            if start_time and speaker_tag:
                # If we have both start and speaker, the current timestamp is the end time
                end_time = convert_time_to_seconds(lines[i])
                segments.append((start_time, end_time, speaker_tag))
                start_time = None
                speaker_tag = None
            # Set the new start time, converting to seconds
            start_time = convert_time_to_seconds(lines[i])
        elif re.match(speaker_pattern, lines[i]):  # Line contains a speaker tag
            speaker_tag = re.findall(speaker_pattern, lines[i])[0]

    return segments

In [7]:
#ensure new line before speaker tags
can_plus['transcript'] = can_plus['transcript'].apply(lambda x: str(x).replace('[', '\r\n['))
can_plus['ref_segments'] = can_plus['transcript'].apply(lambda x: extract_segments(x))

# Save the DataFrame to a CSV file
can_plus.to_csv('/home/kelechi/Dialect-Classification/Diarization results/CAM_Plus_plus_diarization.csv', index=False)

## Read other model files

In [8]:
# Read csv files
assemblyai = pd.read_csv('Diarization results/assemblyai_diarization_der_0.1272_30.csv')
deepgram = pd.read_csv('Diarization results/deepgram_diarization_der_0.1421_30.csv')
sortformer = pd.read_csv('Diarization results/sortformer_diarization.csv')
pyannote = pd.read_csv('Diarization results/pyannote_diarization_der_0.2130_30.csv')
soniox = pd.read_csv('Diarization results/soniox_diarization_der_0.2005_30.csv')
reverb = pd.read_csv('Diarization results/reverb_diarization_der_0.2687_30.csv')
cam = pd.read_csv('Diarization results/CAM_Plus_plus_diarization.csv')

# Select only rows [0:31] of reverb
reverb = reverb.iloc[0:31]
# Filter out rows where 'audio_id' is nan in reverb
reverb = reverb[reverb['audio_id'].notna()]

#Rename 'pred_segment' to 'pred_segments' in cam
cam.rename(columns={'pred_segment': 'pred_segments'}, inplace=True)

## Edit speaker tagging in pred_segments

In [9]:
import ast

def parse_segments(cell):
    # If already a list of tuples, return as is
    if isinstance(cell, list) and all(isinstance(x, tuple) for x in cell):
        return cell
    # If it's a list of single characters, join and eval
    if isinstance(cell, list):
        cell = ''.join(cell)
    # If it's a string, eval
    if isinstance(cell, str):
        try:
            return ast.literal_eval(cell)
        except Exception:
            return []
    return []

def map_speaker_labels(segments):
    mapped = []
    for seg in segments:
        if len(seg) == 3:
            start, end, speaker = seg
            if speaker == '0':
                speaker = 'Speaker A'
            elif speaker == '1':
                speaker = 'Speaker B'
            mapped.append((start, end, speaker))
        else:
            mapped.append(seg)
    return mapped

# Apply both functions to the column
cam['pred_segments'] = cam['pred_segments'].apply(parse_segments).apply(map_speaker_labels)

## Check if audio_id is same for all files

In [10]:
# Verify that 'audio_id' columns are the same across all dataframes, if successful print a message
def verify_audio_ids(*dfs):
    audio_ids = [set(df['audio_id']) for df in dfs]
    if not all(audio_ids[0] == audio_id for audio_id in audio_ids):
        raise ValueError("Audio IDs do not match across all dataframes.")
        # Print all mismatched audio IDs
    else:
        print("All audio IDs match across the dataframes.")
verify_audio_ids(assemblyai, deepgram, sortformer, pyannote, soniox, reverb)

All audio IDs match across the dataframes.


## DER Matrics

In [11]:
def create_pyannote_annotation(segments_list):
    annotation = Annotation()
    for start, end, speaker_tag in segments_list:
        segment = Segment(start, end)
        annotation[segment] = speaker_tag
    return annotation

der_metric = DiarizationErrorRate()

## Calculate absolute DER for all domains

In [12]:
def compute_der_for_dataset(df, ref_col='ref_segments', pred_col='pred_segments'):
    results = []
    der_metric = DiarizationErrorRate()
    for i, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
        if not (isinstance(row[ref_col], (str, list)) and isinstance(row[pred_col], (str, list))):
            continue
        ref_annotation = create_pyannote_annotation(eval(row[ref_col]) if isinstance(row[ref_col], str) else row[ref_col])
        pred_annotation = create_pyannote_annotation(eval(row[pred_col]) if isinstance(row[pred_col], str) else row[pred_col])
        der = der_metric(ref_annotation, pred_annotation)
        results.append({'audio_id': row['audio_id'], 'DER': der})
    abs_der = abs(der_metric)
    print(f"Absolute DER for dataset: {100 * abs_der:.2f}%")
    return pd.DataFrame(results), abs_der

datasets = {
    'assemblyai': assemblyai,
    'deepgram': deepgram,
    'sortformer': sortformer,
    'pyannote': pyannote,
    'soniox': soniox,
    'reverb': reverb,
    'cam': cam,
}

der_results_all = {}
abs_ders_all = {}

for name, df in datasets.items():
    print(f"\nProcessing {name} (ALL DOMAIN)...")
    der_df, abs_der = compute_der_for_dataset(df)
    der_results_all[name] = der_df
    abs_ders_all[name] = abs_der

abs_der_df_all = pd.DataFrame.from_dict(abs_ders_all, orient='index', columns=['Absolute DER (All Domain)']).reset_index().rename(columns={'index': 'model'})



Processing assemblyai (ALL DOMAIN)...


Processing: 100%|██████████| 30/30 [00:00<00:00, 132.64it/s]


Absolute DER for dataset: 12.72%

Processing deepgram (ALL DOMAIN)...


Processing: 100%|██████████| 30/30 [00:00<00:00, 128.55it/s]


Absolute DER for dataset: 14.21%

Processing sortformer (ALL DOMAIN)...


Processing: 100%|██████████| 30/30 [00:00<00:00, 38.42it/s]


Absolute DER for dataset: 26.82%

Processing pyannote (ALL DOMAIN)...


Processing: 100%|██████████| 30/30 [00:00<00:00, 147.51it/s]


Absolute DER for dataset: 21.30%

Processing soniox (ALL DOMAIN)...


Processing: 100%|██████████| 30/30 [00:00<00:00, 180.19it/s]


Absolute DER for dataset: 20.05%

Processing reverb (ALL DOMAIN)...


Processing: 100%|██████████| 30/30 [00:00<00:00, 148.82it/s]


Absolute DER for dataset: 20.23%

Processing cam (ALL DOMAIN)...


Processing: 100%|██████████| 30/30 [00:00<00:00, 113.57it/s]

Absolute DER for dataset: 19.58%





## Absolute DER for Medical Domain Datasets

In [13]:
def filter_by_domain(df, domain='OSCE-Doctor-Patient'):
    return df[df['domain'] == domain].reset_index(drop=True)

filtered_datasets_medical = {name: filter_by_domain(df) for name, df in datasets.items()}

der_results_medical = {}
abs_ders_medical = {}
medical_audio = []

for name, df in filtered_datasets_medical.items():
    print(f"\nProcessing {name} (MEDICAL DOMAIN)...")
    der_df, abs_der = compute_der_for_dataset(df)
    der_df['model'] = name
    medical_audio.append(der_df)
    der_results_medical[name] = der_df
    abs_ders_medical[name] = abs_der

medical_audio = pd.concat(medical_audio, ignore_index=True)
abs_der_df_medical = pd.DataFrame.from_dict(abs_ders_medical, orient='index', columns=['Medical Absolute DER']).reset_index().rename(columns={'index': 'model'})



Processing assemblyai (MEDICAL DOMAIN)...


Processing: 100%|██████████| 9/9 [00:00<00:00, 62.15it/s]


Absolute DER for dataset: 25.68%

Processing deepgram (MEDICAL DOMAIN)...


Processing: 100%|██████████| 9/9 [00:00<00:00, 71.46it/s]


Absolute DER for dataset: 29.35%

Processing sortformer (MEDICAL DOMAIN)...


Processing: 100%|██████████| 9/9 [00:00<00:00, 33.47it/s]


Absolute DER for dataset: 39.64%

Processing pyannote (MEDICAL DOMAIN)...


Processing: 100%|██████████| 9/9 [00:00<00:00, 78.39it/s]


Absolute DER for dataset: 31.46%

Processing soniox (MEDICAL DOMAIN)...


Processing: 100%|██████████| 9/9 [00:00<00:00, 87.07it/s]


Absolute DER for dataset: 42.16%

Processing reverb (MEDICAL DOMAIN)...


Processing: 100%|██████████| 9/9 [00:00<00:00, 78.83it/s]


Absolute DER for dataset: 31.46%

Processing cam (MEDICAL DOMAIN)...


Processing: 100%|██████████| 9/9 [00:00<00:00, 66.90it/s]

Absolute DER for dataset: 34.64%





## Absolute DER for Non-Medical Domain Files

In [15]:
# Filter all datasets for 'Chit-Chat-NG' domain
def filter_by_domain(df, domain='OSCE-Doctor-Patient'):
    return df[df['domain'] != domain].reset_index(drop=True)

filtered_datasets_non_medical = {name: filter_by_domain(df) for name, df in datasets.items()}

der_results_non_medical = {}
abs_ders_non_medical = {}
non_medical_audio = []

for name, df in filtered_datasets_non_medical.items():
    print(f"\nProcessing {name} (NON-MEDICAL DOMAIN)...")
    der_df, abs_der = compute_der_for_dataset(df)
    der_df['model'] = name
    non_medical_audio.append(der_df)
    der_results_non_medical[name] = der_df
    abs_ders_non_medical[name] = abs_der

non_medical_audio = pd.concat(non_medical_audio, ignore_index=True)
abs_der_df_non_medical = pd.DataFrame.from_dict(abs_ders_non_medical, orient='index', columns=['Non-Medical Absolute DER']).reset_index().rename(columns={'index': 'model'})



Processing assemblyai (NON-MEDICAL DOMAIN)...


Processing: 100%|██████████| 21/21 [00:00<00:00, 235.51it/s]


Absolute DER for dataset: 9.91%

Processing deepgram (NON-MEDICAL DOMAIN)...


Processing: 100%|██████████| 21/21 [00:00<00:00, 206.38it/s]


Absolute DER for dataset: 10.92%

Processing sortformer (NON-MEDICAL DOMAIN)...


Processing: 100%|██████████| 21/21 [00:00<00:00, 41.14it/s]


Absolute DER for dataset: 24.04%

Processing pyannote (NON-MEDICAL DOMAIN)...


Processing: 100%|██████████| 21/21 [00:00<00:00, 230.91it/s]


Absolute DER for dataset: 19.09%

Processing soniox (NON-MEDICAL DOMAIN)...


Processing: 100%|██████████| 21/21 [00:00<00:00, 327.16it/s]


Absolute DER for dataset: 15.24%

Processing reverb (NON-MEDICAL DOMAIN)...


Processing: 100%|██████████| 21/21 [00:00<00:00, 237.66it/s]


Absolute DER for dataset: 17.68%

Processing cam (NON-MEDICAL DOMAIN)...


Processing: 100%|██████████| 21/21 [00:00<00:00, 159.48it/s]

Absolute DER for dataset: 16.30%





## View all doamin, medical doamain and general domain files

In [16]:
# --- DISPLAY ---
# print("\nAbsolute DERs for ALL DOMAIN datasets:")
# print(abs_der_df_all)

# print("\nAbsolute DERs for MEDICAL DOMAIN datasets:")
# print(abs_der_df_medical)

# print("\nAbsolute DERs for NON-MEDICAL DOMAIN datasets:")
# print(abs_der_df_non_medical)

# Print a joint DataFrame with all absolute DERs, including all domains
all_abs_der_df = pd.merge(abs_der_df_all, abs_der_df_medical, on='model', how='outer')
all_abs_der_df = pd.merge(all_abs_der_df, abs_der_df_non_medical, on='model', how='outer')
print("\nAll Absolute DERs across domains:")
print(all_abs_der_df)

# Save all absolute DERs to a CSV file
all_abs_der_df.to_csv('Diarization results/absolute_der_all_domains_and_models.csv', index=False)


All Absolute DERs across domains:
        model  Absolute DER (All Domain)  Medical Absolute DER  \
0  assemblyai                   0.127220              0.256756   
1         cam                   0.195766              0.346428   
2    deepgram                   0.142089              0.293510   
3    pyannote                   0.213007              0.314621   
4      reverb                   0.202347              0.314621   
5      soniox                   0.200471              0.421604   
6  sortformer                   0.268240              0.396392   

   Non-Medical Absolute DER  
0                  0.099077  
1                  0.163033  
2                  0.109191  
3                  0.190930  
4                  0.176771  
5                  0.152428  
6                  0.240398  
