In [2]:
from collections import defaultdict
import json

In [10]:
# Load the reference segments of a single recording (EN2002a) for the sweep-line inspection below.
ref_path = "/export/fs06/xhe69/TS-ASR-Whisper/exp/heat_oraclediar_md_beam5ctc02lp01/test/0/wer/EN2002a/ref.json"
with open(ref_path, "r") as ref_file:
    reference_data = json.load(ref_file)

In [11]:
# Steps 1-2: turn every reference segment into two timeline events of the form
# (time, 'start' | 'end', speaker) and order them chronologically.
# On equal times the tuple comparison falls through to the label, so 'end'
# events sort before 'start' events.
events = sorted(
    (entry[time_key], label, entry['speaker'])
    for entry in reference_data
    for time_key, label in (('start_time', 'start'), ('end_time', 'end'))
)

In [None]:
# Step 3: Sweep over the sorted `events` to build maximal intervals during which
# the set of active speakers does not change.
# Speaker activity is reference-counted rather than kept in a plain set, so that
#   * a speaker whose own segments overlap stays active until their last segment
#     ends (a set would drop them at the first 'end'), and
#   * a zero-length segment, whose 'end' sorts before its 'start' at equal time,
#     nets out to zero instead of leaving the speaker stuck active.
active_counts = defaultdict(int)
result_intervals = []

# Pair each event with its successor; the final event opens no interval.
for (time, event_type, speaker), next_event in zip(events, events[1:]):
    if event_type == 'start':
        active_counts[speaker] += 1
    elif event_type == 'end':
        active_counts[speaker] -= 1

    next_time = next_event[0]
    active_speakers = sorted(s for s, count in active_counts.items() if count > 0)
    start, end = round(time, 2), round(next_time, 2)
    # Skip silent gaps and spans that collapse to zero length after rounding.
    if start != end and active_speakers:
        result_intervals.append((start, end, active_speakers))

# Step 4: Print results in the requested format
for start, end, speakers in result_intervals:
    print(f"{start:.2f} {end:.2f} {', '.join(speakers)}")


In [14]:
# Recordings processed below: sessions a-d of meetings EN2002, ES2004, IS1009 and TS3003,
# ordered meeting by meeting.
recording_ids = [
    f"{meeting}{session}"
    for meeting in ('EN2002', 'ES2004', 'IS1009', 'TS3003')
    for session in 'abcd'
]

In [8]:
def get_overlap(a_start, a_end, b_start, b_end):
    """Length of the intersection of the intervals [a_start, a_end] and [b_start, b_end].

    Returns 0 when the intervals are disjoint or merely touch.
    """
    latest_start = max(a_start, b_start)
    earliest_end = min(a_end, b_end)
    if earliest_end > latest_start:
        return earliest_end - latest_start
    return 0

In [18]:
# For the heat_oraclediar run: give every hypothesis segment the reference speaker
# whose segments overlap it the most, write one cp_wer_hyp_simple.json per
# recording, then write all recordings merged into all_cp_wer_hyp_simple.json.
# Segments overlapping no reference segment are labelled "UNKNOWN".
oracle_wer_dir = "/export/fs06/xhe69/TS-ASR-Whisper/exp/heat_oraclediar_md_beam5ctc02lp01/test/0/wer"

assigned_predictions = []  # merged over all recordings
for rec_id in recording_ids:
    with open(f"{oracle_wer_dir}/{rec_id}/ref.json", "r") as file:
        reference_data = json.load(file)
    with open(f"{oracle_wer_dir}/{rec_id}/tc_orc_wer_hyp.json", "r") as f:
        predicted_data = json.load(f)

    # For each predicted segment, find majority speaker from reference
    rec_predictions = []
    for pred in predicted_data:
        speaker_overlap = defaultdict(float)
        for ref in reference_data:
            overlap = get_overlap(pred["start_time"], pred["end_time"],
                                  ref["start_time"], ref["end_time"])
            if overlap > 0:
                speaker_overlap[ref["speaker"]] += overlap

        # max() keeps the first maximal key, so ties go to the speaker first seen in reference order.
        majority_speaker = (max(speaker_overlap, key=speaker_overlap.get)
                            if speaker_overlap else "UNKNOWN")

        rec_predictions.append({
            "session_id": pred["session_id"],
            "start_time": pred["start_time"],
            "end_time": pred["end_time"],
            "words": pred["words"],
            "speaker": majority_speaker
        })

    # Write results to a new JSON file
    with open(f"{oracle_wer_dir}/{rec_id}/cp_wer_hyp_simple.json", "w") as f_out:
        json.dump(rec_predictions, f_out, indent=2)
    # Accumulate in memory instead of re-reading the file just written
    # (and instead of quadratic `list + list` concatenation).
    assigned_predictions.extend(rec_predictions)

with open(f"{oracle_wer_dir}/all_cp_wer_hyp_simple.json", "w") as f_out:
    json.dump(assigned_predictions, f_out, indent=2)

In [20]:
# Recording id -> numeric suffix used in that recording's RTTM file name ({rec_id}-{suffix}.rttm).
recording_ids_rttm = {
    'EN2002a': '6', 'EN2002b': '2', 'EN2002c': '12', 'EN2002d': '8',
    'ES2004a': '3', 'ES2004b': '7', 'ES2004c': '13', 'ES2004d': '14',
    'IS1009a': '5', 'IS1009b': '1', 'IS1009c': '11', 'IS1009d': '9',
    'TS3003a': '15', 'TS3003b': '10', 'TS3003c': '0', 'TS3003d': '4',
}

In [22]:
# For the heat_harddiar run: give every hypothesis segment the RTTM speaker label
# (from diar_exp/diarizen_large) whose segments overlap it the most, write one
# cp_wer_hyp_simple.json per recording, then write all recordings merged into
# all_cp_wer_hyp_simple.json. Segments overlapping no RTTM segment are labelled "UNKNOWN".
rttm_dir = "/export/fs06/xhe69/TS-ASR-Whisper/diar_exp/diarizen_large/ami-sdm_test_sc_cutset"
harddiar_wer_dir = "/export/fs06/xhe69/TS-ASR-Whisper/exp/heat_harddiar_md_beam5ctc02lp01/test/0/wer"

assigned_predictions = []  # merged over all recordings
for rec_id in recording_ids:
    # Parse SPEAKER records: field 3 = start, field 4 = duration, field 7 = speaker label.
    rttm_segments = []
    with open(f"{rttm_dir}/{rec_id}-{recording_ids_rttm[rec_id]}.rttm", "r") as f:
        for line in f:
            parts = line.split()
            # Skip blank lines (previously an IndexError on parts[0]) and non-SPEAKER records.
            if not parts or parts[0] != "SPEAKER":
                continue
            start_time = float(parts[3])
            rttm_segments.append({
                "start_time": start_time,
                "end_time": start_time + float(parts[4]),
                "speaker": parts[7]
            })

    with open(f"{harddiar_wer_dir}/{rec_id}/tc_orc_wer_hyp.json", "r") as f:
        predicted_data = json.load(f)

    rec_predictions = []
    for pred in predicted_data:
        speaker_overlap = defaultdict(float)
        for seg in rttm_segments:
            overlap = get_overlap(pred["start_time"], pred["end_time"],
                                  seg["start_time"], seg["end_time"])
            if overlap > 0:
                speaker_overlap[seg["speaker"]] += overlap

        # max() keeps the first maximal key, so ties go to the speaker first seen in RTTM order.
        majority_speaker = (max(speaker_overlap, key=speaker_overlap.get)
                            if speaker_overlap else "UNKNOWN")

        rec_predictions.append({
            "session_id": pred["session_id"],
            "start_time": pred["start_time"],
            "end_time": pred["end_time"],
            "words": pred["words"],
            "speaker": majority_speaker
        })

    # Write results to a new JSON file
    with open(f"{harddiar_wer_dir}/{rec_id}/cp_wer_hyp_simple.json", "w") as f_out:
        json.dump(rec_predictions, f_out, indent=2)
    # Accumulate in memory instead of re-reading the file just written
    # (and instead of quadratic `list + list` concatenation).
    assigned_predictions.extend(rec_predictions)

with open(f"{harddiar_wer_dir}/all_cp_wer_hyp_simple.json", "w") as f_out:
    json.dump(assigned_predictions, f_out, indent=2)