In [1]:
import lhotse
from data.local_datasets import build_dataset, TS_ASR_Dataset, TS_ASR_Random_Dataset, DataCollator, get_text_norm, TS_ASR_HEAT_Dataset, LhotseLongFormDataset
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio
from tqdm import tqdm
from txt_norm import get_text_norm
from transformers.models.whisper import WhisperFeatureExtractor
from tqdm import tqdm
from collections import namedtuple
from typing import List, Tuple
from lhotse import SupervisionSegment

from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.core import Annotation, Segment

In [2]:
class dotdict(dict):
    """A ``dict`` whose items can also be used as attributes.

    ``d.key`` is ``d.get('key')``, so an unknown key gives ``None`` instead of
    raising ``AttributeError``; ``d.key = v`` stores ``d['key'] = v`` and
    ``del d.key`` removes the item (``KeyError`` if it is absent).
    """

    def __getattr__(self, name):
        # Only invoked when ordinary attribute lookup fails, i.e. for data keys.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]

In [None]:
data_args = {
    'train_cutsets': ['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_train_sc_cutset_30s.jsonl.gz'],
    # 'train_cutsets': ['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_train_sc_cutset.jsonl.gz'],
    # 'dev_cutsets': ['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_dev_sc_cutset.jsonl.gz'],
    'eval_cutsets': ['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_test_sc_cutset.jsonl.gz'],
    'eval_diar_cutsets': ['/export/fs06/xhe69/TS-ASR-Whisper/diar_exp/diarizen_large/diarized_cutsets/ami-sdm_test_sc_cutset.jsonl.gz'],
    # NOTE: the dev set deliberately points at the *test* cutset here.
    'dev_cutsets':['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_test_sc_cutset.jsonl.gz'],
    # This key used to appear twice in this literal: once with the path below and
    # once (further down) as None. In a dict literal the last occurrence wins, so
    # the effective value was always None; it is now defined once, explicitly.
    # To really use diarized dev cutsets, set it to:
    # ['/export/fs06/xhe69/TS-ASR-Whisper/diar_exp/diarizen_large/diarized_cutsets/ami-sdm_test_sc_cutset.jsonl.gz']
    'dev_diar_cutsets': None,
    'do_augment': False,
    'dataset_weights': None,
    'use_timestamps': True,
    'musan_noises': None,
    'train_text_norm': "whisper_nsf",
    'empty_transcripts_ratio': 0.0,
    'train_with_diar_outputs': None,
    'audio_path_prefix': None,
    'audio_path_prefix_replacement': None,
    'vad_from_alignments': False,
    'random_sentence_l_crop_p': 0.0,
    'random_sentence_r_crop_p': 0.0,
    'max_l_crop': 0,
    'max_r_crop': 0,
    'cache_features_for_dev': False,
    'eval_text_norm': 'whisper_nsf',
    'use_heat_diar': False,
    # 'oracle_heat_assignment_method': 'keepchannel',
    # 'num_heat_channels': 2,
    # Diarization mode for the hypothesis dataset; the oracle copy overrides it with 'oracle'.
    'diar_type': 'hard',
}

# Attribute access (data_args.x) returns None for keys not listed above.
data_args = dotdict(data_args)
# NOTE(review): train_cutsets is not referenced by any of the cells shown here;
# loading the manifest may be unnecessary for the DER evaluation — confirm before removing.
train_cutsets = [lhotse.load_manifest(cutset) for cutset in data_args.train_cutsets]

In [4]:
def plot_diarization(diarization_array, frame_rate=1, speaker_labels=None, title="Speaker Diarization"):
    """
    Draw a speaker-activity matrix as a greyscale heatmap, one row per speaker.

    Parameters:
    - diarization_array: 2-D array indexed as [speaker, frame]
    - frame_rate: frames per second; only used to turn x-tick positions into seconds
    - speaker_labels: optional y-tick labels; defaults to "Speaker 0", "Speaker 1", ...
    - title: figure title
    """
    n_spk, n_frames = diarization_array.shape
    if speaker_labels is None:
        row_names = [f"Speaker {k}" for k in range(n_spk)]
    else:
        row_names = speaker_labels

    fig, ax = plt.subplots(figsize=(15, n_spk * 0.5 + 1))
    heatmap = ax.imshow(diarization_array, aspect='auto', interpolation='none', cmap='Greys')
    fig.colorbar(heatmap, ax=ax, label="Activity (1=active, 0=inactive)")

    # Ten evenly spaced frame positions on the x-axis, labelled in seconds.
    tick_frames = np.linspace(0, n_frames - 1, 10, dtype=int)
    ax.set_xticks(tick_frames)
    ax.set_xticklabels((tick_frames / frame_rate).round(2))
    ax.set_xlabel("Time (s)")

    ax.set_yticks(np.arange(n_spk))
    ax.set_yticklabels(row_names)

    ax.set_title(title)
    fig.tight_layout()
    plt.show()

In [None]:
# Decoding options and shared components handed to build_dataset.
dec_args = dotdict({'soft_vad_temp': None})
container = dotdict({
    'feature_extractor': WhisperFeatureExtractor.from_pretrained('openai/whisper-large-v3-turbo'),
})

text_norm = get_text_norm(data_args.eval_text_norm)

# Hypothesis side: uses data_args as configured above.
dev_dataset = build_dataset(
    data_args.eval_cutsets, data_args, dec_args, text_norm, container,
    data_args.eval_diar_cutsets,
)

# Reference side: same configuration, only diar_type switched to 'oracle'.
data_args_oracle = dotdict(data_args, diar_type='oracle')
dev_dataset_oracle = build_dataset(
    data_args_oracle.eval_cutsets, data_args_oracle, dec_args, text_norm, container,
    data_args_oracle.eval_diar_cutsets,
)

loading configuration file preprocessor_config.json from cache at /home/xhe69/.cache/huggingface/hub/models--openai--whisper-large-v3-turbo/snapshots/41f01f3fe87f28c78e2fbf8b568835947dd65ed9/preprocessor_config.json
Feature extractor WhisperFeatureExtractor {
  "chunk_length": 30,
  "feature_extractor_type": "WhisperFeatureExtractor",
  "feature_size": 128,
  "hop_length": 160,
  "n_fft": 400,
  "n_samples": 480000,
  "nb_max_frames": 3000,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "WhisperProcessor",
  "return_attention_mask": false,
  "sampling_rate": 16000
}

Using LhotseLongFormDataset
Using LhotseLongFormDataset


In [None]:
# Sanity check: the two datasets should differ only in their diarization mode.
# NOTE(review): the recorded output is True / False rather than 'hard' / 'oracle',
# so the dataset seems to store diar_type in its own converted form — confirm in the dataset class.
for ds in (dev_dataset, dev_dataset_oracle):
    print(ds.diar_type)

True
False


In [9]:
def make_annotation(activity, frame_duration=0.1, threshold=0.5):
    """Convert a frame-level speaker-activity matrix into a pyannote ``Annotation``.

    Each maximal run of active frames in row ``s`` becomes one segment labelled
    ``f"speaker{s}"``; a run covering frames ``[a, b)`` maps to
    ``Segment(a * frame_duration, b * frame_duration)``.

    Parameters:
    - activity: 2-D array-like indexed as [speaker, frame]; anything ``np.asarray``
      accepts (numpy array, nested lists, CPU torch tensor).
    - frame_duration: seconds per frame. NOTE(review): the default of 0.1 s is not
      derived from anything visible here — confirm it matches the mask's frame
      rate, otherwise all times (and the DER collar's effect) are mis-scaled.
    - threshold: a frame is active when its value is strictly greater than this.
      For 0/1 or boolean masks this is identical to the old ``v == 1`` / ``v == 0``
      logic, which left segments open (or never opened them) for any other value.

    Returns:
    - pyannote.core.Annotation with one track per speaker label.

    Raises:
    - ValueError if ``activity`` is not 2-D.
    """
    activity = np.asarray(activity)
    if activity.ndim != 2:
        raise ValueError(
            f"activity must be 2-D (num_speakers, num_frames), got shape {activity.shape}"
        )

    annotation = Annotation()
    for s, row in enumerate(activity > threshold):
        label = f"speaker{s}"
        # Pad with an inactive frame on both sides so every run has a rising and a
        # falling edge; this also closes runs that reach the final frame.
        padded = np.concatenate(([False], row, [False])).astype(np.int8)
        edges = np.diff(padded)
        starts = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1)  # exclusive end frame of each run
        for start_frame, end_frame in zip(starts, ends):
            segment = Segment(int(start_frame) * frame_duration, int(end_frame) * frame_duration)
            # Use the label as the track name: with the default track, two speakers
            # active over the exact same span would overwrite each other.
            annotation[segment, label] = label
    return annotation


In [10]:
# DER of the hypothesis masks (data_args.diar_type) against the oracle masks.
# One metric instance accumulates every call (keyed by `uri`), so abs(metric) at the
# end is the global, duration-weighted DER rather than nothing / a mean of per-item values.
diarizationErrorRate = DiarizationErrorRate(collar=0.25)

# zip() below pairs items purely by position, so both datasets must line up.
if len(dev_dataset) != len(dev_dataset_oracle):
    raise ValueError(
        f"dataset length mismatch: {len(dev_dataset)} hypothesis vs {len(dev_dataset_oracle)} oracle items"
    )

processed_ids = set()
for batch, batch_oracle in tqdm(zip(dev_dataset, dev_dataset_oracle), total=len(dev_dataset)):
    speaker_id = batch['transcript'].split(',')[0]
    oracle_id = batch_oracle['transcript'].split(',')[0]
    if oracle_id != speaker_id:
        raise ValueError(
            f"dataset misalignment: hypothesis item {speaker_id!r} paired with oracle item {oracle_id!r}"
        )
    # Only the first item seen for each id is scored (later ones are still loaded by
    # the iteration above, which is where most of the per-item time goes).
    if speaker_id in processed_ids:
        continue
    processed_ids.add(speaker_id)

    reference = make_annotation(batch_oracle['spk_mask'])
    hypothesis = make_annotation(batch['spk_mask'])

    der = diarizationErrorRate(reference, hypothesis, uri=speaker_id)
    # tqdm.write keeps per-item lines from breaking the progress bar.
    tqdm.write(f"{speaker_id} {der}")

overall_der = abs(diarizationErrorRate)
print(f"Overall DER: {overall_der}")

  1%|██▌                                                                                                                                                                             | 1/69 [00:56<1:04:10, 56.62s/it]

TS3003c-0 0.1234163221973015


  7%|████████████▉                                                                                                                                                                     | 5/69 [01:59<27:11, 25.49s/it]

IS1009b-1 0.1135787791460515


 14%|█████████████████████████▋                                                                                                                                                       | 10/69 [02:53<17:44, 18.03s/it]

EN2002b-2 0.15161682116601236


 20%|███████████████████████████████████▉                                                                                                                                             | 14/69 [03:24<11:17, 12.32s/it]

ES2004a-3 0.1769086147422933


 26%|██████████████████████████████████████████████▏                                                                                                                                  | 18/69 [04:31<18:51, 22.18s/it]

TS3003d-4 0.18100728317930712


 32%|████████████████████████████████████████████████████████▍                                                                                                                        | 22/69 [05:02<10:13, 13.05s/it]

IS1009a-5 0.15880582147547104


 39%|█████████████████████████████████████████████████████████████████████▎                                                                                                           | 27/69 [06:01<13:00, 18.58s/it]

EN2002a-6 0.1715271810588181


 45%|███████████████████████████████████████████████████████████████████████████████▌                                                                                                 | 31/69 [07:18<16:33, 26.15s/it]

ES2004b-7 0.11675064279297054


 54%|██████████████████████████████████████████████████████████████████████████████████████████████▉                                                                                  | 37/69 [08:26<10:24, 19.51s/it]

EN2002d-8 1.2057572232348275


 59%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                       | 41/69 [09:23<09:27, 20.25s/it]

IS1009d-9 0.9561271237606263


 65%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                             | 45/69 [10:17<07:53, 19.72s/it]

TS3003b-10 1.019516022315432


 71%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                   | 49/69 [11:13<06:41, 20.05s/it]

IS1009c-11 0.870826845453156


 77%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                         | 53/69 [12:25<06:38, 24.91s/it]

EN2002c-12 1.225895849244713


 83%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                              | 57/69 [13:35<05:03, 25.31s/it]

ES2004c-13 0.998436551934226


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                  | 62/69 [14:31<02:09, 18.48s/it]

ES2004d-14 1.8535932266191573


 91%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 63/69 [14:35<01:23, 13.90s/it]


{'session_id': 'IS1009b28', 'start_time': Dec
imal('660'), 'end_time': Decimal('690'), 'speaker': '0', 'words': 'of course and it is not to o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that becau
se it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that beca
use it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that bec
ause it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that'}, {'session_id': 'IS1009b28', 'start_time': 
Decimal('660'), 'end_time': Decimal('690'), 'speaker': '1', 'words': 'of course and it is not to o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that be
cause it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that b
ecause it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that 
because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that'}, {'session_id': 'IS1009b28', 'start_time
': Decimal('660'), 'end_time': Decimal('690'), 'speaker': '2', 'words': 'of course and it is not to o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that
 because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do tha
t because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do th
at because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that'}, {'session_id': 'IS1009b28', 'start_t
ime': Decimal('660'), 'end_time': Decimal('690'), 'speaker': '3', 'words': 'of course and it is not to o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do t
hat because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do 
that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do
 that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that because it is o bad for me be able do that'}