In [None]:
import pandas as pd
import numpy as np
import os
from rapidfuzz import process, fuzz
from collections import defaultdict
from typing import List, Dict

### Helper functions

In [None]:
def map_speakers(transcript, speaker_changes):
    """
    Map Whisper-style transcript speaker labels (e.g., "SPEAKER_00", "SPEAKER_01")
    to real speaker names (from diarization output) based on maximum temporal overlap.

    Each diarization change point `[timestamp, speaker_name]` becomes an interval
    that lasts until the next change point (the last one extends to infinity).
    For every transcript speaker label, the overlap with each real speaker's
    intervals is summed, and the real speaker with the largest total wins.

    Args:
        transcript (list[dict]):
            List of transcript segments, each containing:
            - 'start' (float): Start time in seconds.
            - 'end' (float): End time in seconds.
            - 'speaker' (str, optional): Whisper-assigned speaker label
              (e.g., "SPEAKER_00"). Segments without a label are ignored.

        speaker_changes (list[list]):
            List of diarization speaker changes as `[timestamp, speaker_name]`.
            Timestamps may be numbers or numeric strings. The list is sorted by
            timestamp internally, so input order does not matter.

    Returns:
        dict:
            Mapping of Whisper speaker labels to real speaker names. A label whose
            segments overlap no diarization interval (zero-length segments, or
            segments ending before the first change point) is absent, so callers
            should use `.get` rather than indexing.

    Example:
        >>> transcript = [{"start": 0.0, "end": 5.0, "speaker": "SPEAKER_00"}]
        >>> speaker_changes = [[0.0, "John"], [10.0, "Mary"]]
        >>> map_speakers(transcript, speaker_changes)
        {'SPEAKER_00': 'John'}
    """
    # Sort by time: intervals are built from adjacent change points, so an
    # out-of-order list would otherwise produce negative or wrong spans.
    changes = sorted(
        ((float(timestamp), speaker_name) for timestamp, speaker_name in speaker_changes),
        key=lambda change: change[0],
    )

    # Intervals [(start, end, speaker_name)]; each lasts until the next change.
    intervals = []
    for i, (start, speaker_name) in enumerate(changes):
        end = changes[i + 1][0] if i + 1 < len(changes) else float('inf')
        intervals.append((start, end, speaker_name))

    # Track overlaps: { transcript_speaker -> { real_speaker -> total_overlap } }
    overlaps = defaultdict(lambda: defaultdict(float))

    for seg in transcript:
        t_speaker = seg.get("speaker")
        if t_speaker is None:
            # An unlabeled segment cannot be attributed to any Whisper speaker.
            continue
        t_start, t_end = seg["start"], seg["end"]

        for i_start, i_end, i_speaker in intervals:
            if i_start >= t_end:
                # Intervals are sorted by start, so no later one can overlap.
                break
            overlap = min(t_end, i_end) - max(t_start, i_start)
            if overlap > 0:
                overlaps[t_speaker][i_speaker] += overlap

    # Best match per transcript speaker; on ties, the real speaker first seen wins.
    return {
        t_speaker: max(speaker_dict.items(), key=lambda item: item[1])[0]
        for t_speaker, speaker_dict in overlaps.items()
    }


def merge_consecutive_whisper(data):
    """
    Merge consecutive Whisper transcript segments from the same speaker.

    Adjacent segments sharing the same 'speaker' are combined into one segment.
    The merged segment keeps the first segment's 'start', takes the last
    segment's 'end', and joins the texts with a single space.

    The input is never modified: each output segment is a shallow copy. The
    word-level 'words' list is dropped from every output segment, because the
    first segment's words would not cover a merged span.

    Args:
        data (Iterable[dict] | None):
            Transcript segments (e.g. a list, or a 1-D object array), each containing:
            - 'start' (float): Segment start time.
            - 'end' (float): Segment end time.
            - 'speaker' (str): Speaker label. A missing label is treated as None,
              so consecutive unlabeled segments are merged together.
            - 'text' (str): Segment text.
            Optionally may include 'words' (list), which is dropped.

    Returns:
        list[dict]:
            New list of merged segments; empty for empty or None input.

    Example:
        >>> data = [
        ...     {"start": 0, "end": 2, "speaker": "SPEAKER_00", "text": "Hello"},
        ...     {"start": 2, "end": 4, "speaker": "SPEAKER_00", "text": "world"},
        ...     {"start": 4, "end": 6, "speaker": "SPEAKER_01", "text": "Hi"}
        ... ]
        >>> merge_consecutive_whisper(data)
        [{'start': 0, 'end': 4, 'speaker': 'SPEAKER_00', 'text': 'Hello world'},
         {'start': 4, 'end': 6, 'speaker': 'SPEAKER_01', 'text': 'Hi'}]
    """
    if data is None:
        return []

    merged = []
    for segment in data:
        if merged and segment.get("speaker") == merged[-1].get("speaker"):
            # Same speaker as the previous run → extend it.
            last = merged[-1]
            last["end"] = segment["end"]
            last["text"] += " " + segment["text"]
        else:
            # New speaker → start a new segment from a copy, so later merges
            # never mutate the caller's dicts.
            merged.append({key: value for key, value in segment.items() if key != "words"})

    return merged


def standardize_speaker_changes(speaker_changes, threshold=80):
    """
    Clean and standardize a list of speaker change events using fuzzy name matching.

    The function:
      1. Converts timestamps to floats and normalizes speaker names (lowercase,
         spaces removed). Missing names (None/NaN) become the placeholder 'Other'.
      2. Uses fuzzy matching (`fuzz.ratio`) to merge similar speaker names. Each
         name is compared with the names seen before it; if the best score is at
         least `threshold`, it takes that name's standardized form
         (e.g. "john" ≈ "jon").
      3. Merges consecutive entries belonging to the same standardized speaker,
         keeping the timestamp of the first entry in each run.

    Args:
        speaker_changes (list[list] | array-like):
            `[timestamp, speaker_name]` pairs in chronological order.
        threshold (float):
            Minimum `fuzz.ratio` score (0-100) for two normalized names to be
            treated as the same speaker. Defaults to 80.

    Returns:
        list[tuple]:
            Standardized speaker change tuples `(start_time, standardized_name)`
            with float start times. Empty input gives an empty list.

    Example:
        >>> speaker_changes = [
        ...     ["0.0", "John"],
        ...     ["5.0", "Jon"],
        ...     ["10.0", "Mary"]
        ... ]
        >>> standardize_speaker_changes(speaker_changes)
        [(0.0, 'john'), (10.0, 'mary')]
    """
    # Nothing to standardize; building the merged frame below would also fail.
    if len(speaker_changes) == 0:
        return []

    df = pd.DataFrame(speaker_changes, columns=['timestamp', 'speaker'])
    df['timestamp'] = df['timestamp'].astype(float)
    # Missing names stay None here and become 'Other' after the fuzzy mapping.
    df['speaker'] = df['speaker'].map(
        lambda name: None if pd.isna(name) else str(name).lower().replace(' ', '')
    )

    # 1️⃣ Standardize speaker names using fuzzy matching against earlier names
    standard_names = {}
    for name in df['speaker'].dropna().unique():
        if standard_names:
            match, score, _ = process.extractOne(name, list(standard_names), scorer=fuzz.ratio)
            standard_names[name] = standard_names[match] if score >= threshold else name
        else:
            standard_names[name] = name
    speaker_std = df['speaker'].map(standard_names).fillna('Other')

    # 2️⃣ Keep only the first row of each run of identical consecutive speakers
    run_starts = speaker_std != speaker_std.shift()
    return [
        (float(start), str(speaker))
        for start, speaker in zip(df['timestamp'][run_starts], speaker_std[run_starts])
    ]


### Setup

In [None]:
# Placeholder paths — set these before running.
whisper_path = 'path_to_whisper_transcriptions'
speaker_changes_path = 'path_to_zoom_diarization'
save_loc = 'save_loc_path'
os.makedirs(save_loc, exist_ok=True)

whisper_files = os.listdir(whisper_path)
speaker_changes_files = os.listdir(speaker_changes_path)
# Only files present in both folders can be aligned. Sorted so the processing
# (and printed output) order is reproducible; set and os.listdir order are not.
common = sorted(set(whisper_files) & set(speaker_changes_files))

### 1. Map speaker names across videos

In [None]:
NAME_MATCH_THRESHOLD = 80  # minimum token_sort_ratio score to treat two names as one person

# Collect every raw diarization name across all videos. Only the diarization
# files are needed here, so the Whisper transcripts are not loaded.
all_speaker = []
for file in common:
    speaker_changes = np.load(os.path.join(speaker_changes_path, file), allow_pickle=True)
    all_speaker.extend(str(change[1]) for change in speaker_changes)
# Keep only the text before the first '-', '|' or ',' of each name.
all_speaker = [x.split('-')[0].split('|')[0].split(',')[0].strip() for x in all_speaker]

# Greedy fuzzy clustering over the sorted unique names: each name joins the
# best-matching canonical name if the score reaches the threshold, otherwise
# it becomes a new canonical name itself.
unique_names = []  # canonical names found so far
mapping = {}       # cleaned name -> canonical name

for name in map(str, np.unique(all_speaker)):
    if unique_names:
        match, score, _ = process.extractOne(name, unique_names, scorer=fuzz.token_sort_ratio)
        if score >= NAME_MATCH_THRESHOLD:
            mapping[name] = str(match)
            continue
    unique_names.append(name)
    mapping[name] = name

### 2. Map transcript speakers to canonical names and save

In [None]:
MAX_SPEAKERS = 100  # files with more distinct diarization names than this are skipped


def _name_key(name):
    """Reduce a speaker name to a key shared by both naming schemes in this notebook.

    Applies the cross-video cleaning (keep the text before the first '-', '|'
    or ',') plus the compaction used by `standardize_speaker_changes`
    (lowercase, spaces removed). This makes a per-file standardized name and a
    `mapping` key derived from the same raw name compare equal.
    """
    cleaned = str(name).split('-')[0].split('|')[0].split(',')[0].strip()
    return cleaned.lower().replace(' ', '')


# `mapping` is keyed by cleaned display names ("John Doe"), but
# standardize_speaker_changes emits compact names ("johndoe"). Indexing canonical
# names by the shared key avoids KeyErrors caused purely by formatting.
# If several keys collide, the first in `mapping` order wins.
canonical_by_key = {}
for raw_name, canonical_name in mapping.items():
    canonical_by_key.setdefault(_name_key(raw_name), canonical_name)


def _resolve_speaker(std_name):
    """Return the cross-video canonical name for a per-file standardized name.

    Falls back to `std_name` itself when no canonical name is known
    (e.g. the 'Other' placeholder).
    """
    if std_name in mapping:
        return mapping[std_name]
    return canonical_by_key.get(_name_key(std_name), std_name)


for file in common:
    # NOTE(review): allow_pickle=True executes pickled data — only use on trusted files.
    speaker_changes = standardize_speaker_changes(
        np.load(os.path.join(speaker_changes_path, file), allow_pickle=True)
    )
    # Check the skip condition before loading the (unused if skipped) transcript.
    if len({name for _, name in speaker_changes}) > MAX_SPEAKERS:
        print("Skipping: ", file)
        continue

    whisper = np.load(os.path.join(whisper_path, file), allow_pickle=True)
    whisper_transcript = merge_consecutive_whisper(list(whisper))
    transcript_map = map_speakers(whisper_transcript, speaker_changes)

    unmapped = set()
    for segment in whisper_transcript:
        label = segment.get('speaker')
        if label in transcript_map:
            segment['speaker'] = _resolve_speaker(transcript_map[label])
        else:
            # No diarization interval overlapped this Whisper speaker (or the
            # segment has no label): there is no evidence for a real name, so
            # the segment is left unchanged.
            unmapped.add(label)
    if unmapped:
        print(f"{file}: no diarization overlap for {sorted(map(str, unmapped))}; labels left unchanged")

    np.save(os.path.join(save_loc, file), whisper_transcript)
        

KeyError: '02:00'