In [2]:
from collections import namedtuple
from typing import List, Set, Tuple

import lhotse
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio
from lhotse import SupervisionSegment
from tqdm import tqdm
from transformers.models.whisper import WhisperFeatureExtractor

from data.local_datasets import build_dataset, TS_ASR_Dataset, TS_ASR_Random_Dataset, DataCollator, get_text_norm, TS_ASR_HEAT_Dataset
from txt_norm import get_text_norm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class dotdict(dict):
    """Dictionary whose entries can also be read, written and deleted as attributes.

    Reading an attribute that is not a key yields ``None`` (same as ``dict.get``);
    deleting one that is not a key raises ``KeyError`` (same as ``dict.__delitem__``).
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so dict methods still work.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

In [4]:
# Dataset options for the cells below, wrapped in dotdict so each entry can be
# read as an attribute (e.g. data_args.train_cutsets).
data_args = dotdict(
    train_cutsets=['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_train_sc_cutset_30s.jsonl.gz'],
    # train_cutsets=['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_train_sc_cutset.jsonl.gz'],
    dev_cutsets=['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_dev_sc_cutset.jsonl.gz'],
    eval_cutsets=['/export/fs06/xhe69/TS-ASR-Whisper/data/manifests/ami-sdm_test_sc_cutset.jsonl.gz'],
    do_augment=False,
    dataset_weights=None,
    use_timestamps=True,
    musan_noises=None,
    train_text_norm="whisper_nsf",
    empty_transcripts_ratio=0.0,
    train_with_diar_outputs=None,
    audio_path_prefix=None,
    audio_path_prefix_replacement=None,
    vad_from_alignments=False,
    random_sentence_l_crop_p=0.0,
    random_sentence_r_crop_p=0.0,
    max_l_crop=0,
    max_r_crop=0,
    cache_features_for_dev=False,
    eval_text_norm='whisper_nsf',
    dev_diar_cutsets=None,
    use_heat_diar=True,
    oracle_heat_assignment_method='keepchannel',
    num_heat_channels=2,
)


In [5]:
# Load one manifest per path listed in data_args.train_cutsets.
train_cutsets = list(map(lhotse.load_manifest, data_args.train_cutsets))

In [6]:
# Build the training dataset from the loaded cut sets; every option is taken
# from data_args so this cell mirrors the configuration cell above it.
train_dataset = TS_ASR_HEAT_Dataset(
    train_cutsets,
    do_augment=data_args.do_augment,
    dataset_weights=data_args.dataset_weights,
    use_timestamps=data_args.use_timestamps,
    musan_noises=data_args.musan_noises,
    text_norm=get_text_norm(data_args.train_text_norm),
    # Note the spelling: the keyword is singular, the config key is plural.
    empty_transcript_ratio=data_args.empty_transcripts_ratio,
    train_with_diar_outputs=data_args.train_with_diar_outputs,
    audio_path_prefix=data_args.audio_path_prefix,
    audio_path_prefix_replacement=data_args.audio_path_prefix_replacement,
    vad_from_alignments=data_args.vad_from_alignments,
    random_sentence_l_crop_p=data_args.random_sentence_l_crop_p,
    random_sentence_r_crop_p=data_args.random_sentence_r_crop_p,
    max_l_crop=data_args.max_l_crop,
    max_r_crop=data_args.max_r_crop,
    oracle_heat_assignment_method=data_args.oracle_heat_assignment_method,
)

In [7]:
# Display the 'all_supervisions' entry of training sample 9735.
train_dataset[9735]['all_supervisions']

[SupervisionSegment(id='EN2009d-1868-0', recording_id='EN2009d', start=0.0, duration=0.44, channel=[0], text='and', language='English', speaker='FEE083', gender='F', custom={'end_': 0.98, 'text_': 'and then'}, alignment=None),
 SupervisionSegment(id='EN2009d-1152-0', recording_id='EN2009d', start=0.84, duration=0.08, channel=[0], text='i', language='English', speaker='FEE096', gender='F', custom={'end_': 28.58, 'text_': "i just wanna know it works the way you expect it to work this is not a a gesture of mistrust this is just experience that um if anything can screw up it will i i'm trying to choose my language carefully because we're being recorded but you'll hear more choice language at the point when we've all done a lot of work and and then we discover we can't use the session because of some thing we didn't think about"}, alignment=None),
 SupervisionSegment(id='EN2009d-2222-0', recording_id='EN2009d', start=5.37, duration=0.1, channel=[0], text='yeah', language='English', speaker=

In [8]:
def plot_diarization(diarization_array, frame_rate=1, speaker_labels=None, title="Speaker Diarization"):
    """
    Draw a (rows x frames) activity matrix as a greyscale heatmap and show it.

    Parameters:
    - diarization_array: 2-D array exposing ``.shape`` as (num_speakers, num_frames)
    - frame_rate: int or float, frames per second; tick positions are divided by it
      to produce the x-axis labels
    - speaker_labels: list of str, one label per row (defaults to "Speaker 0", "Speaker 1", ...)
    - title: str, plot title
    """
    n_rows, n_cols = diarization_array.shape

    # Figure height grows with the number of rows so each row stays readable.
    fig, ax = plt.subplots(figsize=(15, n_rows * 0.5 + 1))
    image = ax.imshow(diarization_array, aspect='auto', interpolation='none', cmap='Greys')
    fig.colorbar(image, ax=ax, label="Activity (1=active, 0=inactive)")

    # X-axis: ten evenly spaced frame indices, labelled as frame index / frame_rate.
    tick_positions = np.linspace(0, n_cols - 1, 10, dtype=int)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels((tick_positions / frame_rate).round(2))
    ax.set_xlabel("Time (s)")

    # Y-axis: one tick per row.
    if speaker_labels is None:
        speaker_labels = [f"Speaker {row}" for row in range(n_rows)]
    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels(speaker_labels)

    ax.set_title(title)
    fig.tight_layout()
    plt.show()

In [9]:
# Show the oracle_heat_assignment_method attribute stored on the training dataset.
train_dataset.oracle_heat_assignment_method

'keepchannel'

In [None]:
plot_idx = 154  # should be even, 5842 is example of 3 overlap
# 186 good plot

# Fetch each sample exactly once: indexing the dataset repeatedly re-runs its
# __getitem__ for every plot/print below.
# NOTE(review): plot_idx / plot_idx + 1 are treated as stream 1 / stream 2 of the
# same example (hence "should be even") - confirm against the dataset's indexing.
train_stream1 = train_dataset[plot_idx]
train_stream2 = train_dataset[plot_idx + 1]

stno_labels = ['Silence', 'Target', 'NonTarget', 'Overlap']

plot_diarization(train_stream1['heat_output'], title='Heat Output', speaker_labels=['Stream 1', 'Stream 2'])
plot_diarization(train_stream1['vad_mask'], title='STNO Mask Stream 1', speaker_labels=stno_labels)
plot_diarization(train_stream2['vad_mask'], title='STNO Mask Stream 2', speaker_labels=stno_labels)

print(train_stream1['all_supervisions'])
print()
print(train_stream1['heat_assignment'][0])
print()
print(train_stream1['heat_assignment'][1])
print()
print(train_stream1['transcript'])
print()
print(train_stream2['transcript'])
# print(f'Transcript 1: {train_dataset[plot_idx]['transcript']}')
# print(f'Transcript 2: {train_dataset[plot_idx+1]['transcript']}')


In [96]:
# Collector for the indices of training samples flagged by the scan below.
overlap_indices = []

# The scan is disabled (commented out) in this saved version of the cell: it visits
# every second sample and keeps those whose 'heat_output' contains a 2 or a 3.
# for plot_idx in tqdm(range(0, len(train_dataset), 2)):
#     if 2 in train_dataset[plot_idx]['heat_output'] or 3 in train_dataset[plot_idx]['heat_output']:
#         overlap_indices.append(plot_idx)
# Result noted from an earlier run: 25.74% of ami-sdm samples have >2 people overlapping,
# but the overlapping time itself is likely low.

0.257396449704142

In [16]:
# Decoding options and shared objects handed to build_dataset, both wrapped in
# dotdict for attribute-style access.
dec_args = dotdict(soft_vad_temp=None)
container = dotdict(
    feature_extractor=WhisperFeatureExtractor.from_pretrained('openai/whisper-large-v3-turbo'),
)

text_norm = get_text_norm(data_args.eval_text_norm)

dev_dataset = build_dataset(
    data_args.dev_cutsets,
    data_args,
    dec_args,
    text_norm,
    container,
    data_args.dev_diar_cutsets,
)

loading configuration file preprocessor_config.json from cache at /home/xhe69/.cache/huggingface/hub/models--openai--whisper-large-v3-turbo/snapshots/41f01f3fe87f28c78e2fbf8b568835947dd65ed9/preprocessor_config.json
Feature extractor WhisperFeatureExtractor {
  "chunk_length": 30,
  "feature_extractor_type": "WhisperFeatureExtractor",
  "feature_size": 128,
  "hop_length": 160,
  "n_fft": 400,
  "n_samples": 480000,
  "nb_max_frames": 3000,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "WhisperProcessor",
  "return_attention_mask": false,
  "sampling_rate": 16000
}

Using LhotseLongFormHeatDataset


In [None]:
val_idx = 4

# Fetch each dev sample once instead of re-indexing the dataset for every plot.
# NOTE(review): val_idx / val_idx + 1 are treated as stream 1 / stream 2 of the
# same example - confirm against the dataset's indexing.
dev_stream1 = dev_dataset[val_idx]
dev_stream2 = dev_dataset[val_idx + 1]

dev_stno_labels = ['Silence', 'Target', 'NonTarget', 'Overlap']

plot_diarization(dev_stream1['heat_mask'], title='Heat Output', speaker_labels=['Stream 1', 'Stream 2'])
plot_diarization(dev_stream1['vad_mask'], title='STNO Mask Stream 1', speaker_labels=dev_stno_labels)
plot_diarization(dev_stream2['vad_mask'], title='STNO Mask Stream 2', speaker_labels=dev_stno_labels)

In [26]:
# Display the first supervision of dev sample 0 to see which fields it carries.
dev_dataset[0]['all_supervisions'][0]

SupervisionSegment(id='TS3004d-294', recording_id='TS3004d', start=6.75, duration=0.1299999999999999, channel=[0], text='OH', language='English', speaker='MTD013PM', gender='M', custom={'end_': 6.88, 'text_': 'OH'}, alignment={'word': [AlignmentItem(symbol='OH', start=6.75, duration=0.13, score=None)]})

In [48]:
# One boundary of a supervision: its timestamp, whether it opens ('start') or
# closes ('end') the segment, and the id of the segment it belongs to.
Event = namedtuple("Event", ["time", "type", "segment_id"])


def compute_supervision_overlaps(
    supervisions: List["SupervisionSegment"],
) -> List[Tuple[float, float, int, Set[str]]]:
    """
    Find the time intervals during which more than one supervision is active.

    Each supervision spans ``[seg.start, seg.end_)``; ``end_`` is read from the
    segment and, when it is missing or None, ``seg.end`` is used instead.
    Segments that touch only at a boundary (one ends exactly where another
    starts) are not reported as overlapping. Segments with no positive length
    (``end <= start``) are ignored.

    Args:
        supervisions: objects exposing ``id``, ``start`` and ``end_`` (or ``end``).

    Returns:
        List of ``(start_time, end_time, num_active, active_ids)`` tuples in
        chronological order, where ``active_ids`` is the set of segment ids
        active over that interval and ``num_active == len(active_ids)``.
        Note that this counts active *segments* by id; two simultaneous
        segments from the same speaker count as two.
    """
    events = []
    for seg in supervisions:
        start = seg.start
        end = getattr(seg, "end_", None)
        if end is None:
            end = seg.end
        if end <= start:
            # A zero-length segment would have its 'end' sorted before its own
            # 'start' and therefore never be removed from the active set.
            continue
        events.append(Event(time=start, type='start', segment_id=seg.id))
        events.append(Event(time=end, type='end', segment_id=seg.id))

    # Sort by time; at equal times process 'end' before 'start' so that
    # back-to-back segments are not counted as overlapping.
    events.sort(key=lambda e: (e.time, 0 if e.type == 'end' else 1))

    active_ids = set()
    overlaps = []
    last_time = None

    for event in events:
        current_time = event.time

        # The active set was constant over (last_time, current_time); report
        # that interval if more than one segment was active during it.
        if last_time is not None and current_time > last_time and len(active_ids) > 1:
            overlaps.append((last_time, current_time, len(active_ids), active_ids.copy()))

        if event.type == 'start':
            active_ids.add(event.segment_id)
        else:
            active_ids.discard(event.segment_id)

        last_time = current_time

    return overlaps

In [None]:
# Display the 'all_supervisions' entry of dev sample 4 (the sample plotted via val_idx = 4).
dev_dataset[4]['all_supervisions']

In [53]:
# Visit every second dev sample and print each interval in which more than one
# supervision is active, together with the ids involved.
for sample_idx in range(0, len(dev_dataset), 2):
    print(f'Sample #{sample_idx}')
    sample_supervisions = dev_dataset[sample_idx]['all_supervisions']
    for t_start, t_end, n_active, ids in compute_supervision_overlaps(sample_supervisions):
        length = t_end - t_start
        print(f"Overlap from {t_start:.2f}s to {t_end:.2f}s ({length:.2f}s) with {n_active} speakers {ids}")

Sample #0
Overlap from 90.71s to 90.93s (0.22s) with 2 speakers {'TS3004d-818', 'TS3004d-304'}
Overlap from 94.76s to 95.15s (0.39s) with 2 speakers {'TS3004d-306', 'TS3004d-820'}
Overlap from 97.32s to 97.79s (0.47s) with 2 speakers {'TS3004d-568', 'TS3004d-820'}
Overlap from 105.40s to 106.50s (1.10s) with 2 speakers {'TS3004d-569', 'TS3004d-307'}
Overlap from 111.62s to 112.43s (0.81s) with 2 speakers {'TS3004d-821', 'TS3004d-570'}
Overlap from 112.43s to 112.83s (0.40s) with 3 speakers {'TS3004d-821', 'TS3004d-308', 'TS3004d-570'}
Overlap from 112.83s to 113.56s (0.73s) with 2 speakers {'TS3004d-821', 'TS3004d-308'}
Overlap from 114.66s to 114.72s (0.06s) with 2 speakers {'TS3004d-309', 'TS3004d-571'}
Overlap from 115.99s to 116.15s (0.16s) with 2 speakers {'TS3004d-309', 'TS3004d-822'}
Overlap from 116.33s to 116.98s (0.65s) with 2 speakers {'TS3004d-572', 'TS3004d-822'}
Overlap from 122.51s to 122.85s (0.34s) with 2 speakers {'TS3004d-310', 'TS3004d-822'}
Overlap from 134.36s to 