In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd

from psifx.io import json, rttm, vtt

In [None]:
# Configuration: which session to process and where its artefacts live.
# Set the PSIFX_DATA_ROOT environment variable to point at a local copy of the
# dataset instead of editing the default below; change SESSION to process
# another recording.
SESSION = "CH.102"
DATA_ROOT = Path(os.environ.get("PSIFX_DATA_ROOT", "/home/guillaume/Datasets/UNIL"))

root = DATA_ROOT / SESSION
transcription_path = root / "Transcriptions" / f"{SESSION}.combined.vtt"
diarization_path = root / "Diarizations" / f"{SESSION}.combined.rttm"
identification_path = root / "Identifications" / f"{SESSION}.combined.json"
enhanced_transcription_path = root / "Transcriptions" / f"{SESSION}.combined.enhanced.vtt"

In [None]:
# Load the VTT transcription as one DataFrame row per subtitle segment.
transcription = pd.DataFrame.from_records(vtt.VTTReader.read(transcription_path))
transcription

In [None]:
# Load the RTTM diarization and derive each turn's end time from start + duration.
diarization = pd.DataFrame.from_records(
    rttm.RTTMReader.read(diarization_path)
).assign(end=lambda turns: turns["start"] + turns["duration"])
diarization

In [None]:
# Load the speaker identification; its "mapping" entry is looked up below with
# diarization speaker labels as keys.
mapping = (identification := json.JSONReader.read(identification_path))["mapping"]
identification

In [None]:
# Attribute each transcribed segment to the diarization turn it overlaps most,
# measured by intersection-over-union (IoU) of their [start, end] intervals.
# Segments that overlap no turn (best IoU == 0) are labelled "NA".
diarization_starts = diarization["start"].to_numpy(dtype=float)
diarization_ends = diarization["end"].to_numpy(dtype=float)
diarization_speakers = diarization["speaker_name"].to_numpy()


def best_matching_speaker(start: float, end: float) -> str:
    """Return the mapped name of the diarization turn with the highest IoU
    against the interval [start, end], or "NA" if no turn overlaps it.

    Ties are broken in favour of the earliest diarization row.
    Raises KeyError if the winning turn's speaker label is absent from `mapping`.
    """
    if diarization_starts.size == 0:
        return "NA"
    intersection = np.clip(
        np.minimum(end, diarization_ends) - np.maximum(start, diarization_starts), 0.0, None
    )
    union = np.maximum(end, diarization_ends) - np.minimum(start, diarization_starts)
    # A zero-length union (two zero-length segments at the same instant) would
    # divide by zero; treat it as "no overlap" instead of crashing.
    iou = np.divide(intersection, union, out=np.zeros_like(intersection), where=union > 0)
    best = int(np.argmax(iou))  # argmax returns the first maximum on ties
    if iou[best] <= 0.0:
        return "NA"
    return mapping[diarization_speakers[best]]


# Assign positionally (list in row order) so this works for any index, and so
# re-running the cell simply overwrites the column.
transcription["speaker"] = [
    best_matching_speaker(start, end)
    for start, end in zip(transcription["start"], transcription["end"])
]
transcription

In [None]:
# Write the speaker-attributed segments back out as an enhanced VTT file.
# The output path comes from the configuration cell instead of being repeated here.
segments = transcription[["start", "end", "speaker", "text"]].to_dict(orient="records")
vtt.VTTWriter.write(
    path=enhanced_transcription_path,
    segments=segments,
    verbose=True,
    overwrite=True,
)