In [None]:
import os

import numpy as np
import pandas as pd


def iou(onehot_preds, onehot_labels):
    """Mean intersection-over-union across classes.

    Class ``c`` is read as ``array[:, c]``, i.e. axis 1 is the class axis;
    all remaining axes (samples, positions) are pooled together before the
    per-class ratio is taken.

    Args:
        onehot_preds (np.ndarray): One-hot (or 0/1) predictions.
        onehot_labels (np.ndarray): One-hot ground truth, same shape.

    Returns:
        The unweighted mean of the per-class IoUs. A class with an empty
        union (absent from both inputs) contributes 0.0 to the mean.
    """
    def _class_iou(c):
        pred_c = onehot_preds[:, c]
        label_c = onehot_labels[:, c]
        overlap = (pred_c * label_c).sum()
        combined = pred_c.sum() + label_c.sum() - overlap
        if combined > 0:
            return overlap / combined
        return 0.0

    per_class = [_class_iou(c) for c in range(onehot_preds.shape[1])]
    return np.array(per_class).mean()


def _wave_boundaries(labels):
    """Collect onset/offset sample indices of every P (1), QRS (2) and T (3) segment.

    An onset is the first sample of a segment and an offset is its last
    sample (inclusive). Only label changes are inspected, so a segment that
    touches either end of the sequence has no recorded boundary on that side.
    """
    segments = {
        'P_onset': [], 'P_offset': [],
        'QRS_onset': [], 'QRS_offset': [],
        'T_onset': [], 'T_offset': []
    }
    transitions = {
        1: ('P_onset', 'P_offset'),
        2: ('QRS_onset', 'QRS_offset'),
        3: ('T_onset', 'T_offset')
    }

    for idx in np.where(np.diff(labels) != 0)[0]:
        prev_label, curr_label = labels[idx], labels[idx + 1]
        if curr_label in transitions:
            # idx + 1 is the first sample carrying the new label.
            segments[transitions[curr_label][0]].append(idx + 1)
        if prev_label in transitions:
            # idx is the last sample carrying the old label.
            segments[transitions[prev_label][1]].append(idx)
    return segments


def _pair_with_next_event(anchors, candidates):
    """Pair each anchor with the first candidate strictly between it and the next anchor.

    The last anchor is bounded by +inf, so the final beat is measured exactly
    like the others (and like the single-anchor case). Anchors with no
    matching candidate are skipped.
    """
    pairs = []
    for i, start in enumerate(anchors):
        stop = anchors[i + 1] if i + 1 < len(anchors) else np.inf
        end = next((c for c in candidates if start < c < stop), None)
        if end is not None:
            pairs.append((start, end))
    return pairs


def _median_duration_ms(pairs, sample_rate):
    """Median of (end - start) over `pairs`, converted from samples to ms; 0 if `pairs` is empty."""
    if not pairs:
        return 0
    return np.median([(end - start) / sample_rate for start, end in pairs]) * 1000


def compute_numerics(
    delineation_output,
    sample_rate=250,
):
    """
    Derive median PR interval, QRS duration and QT interval from a label sequence.

    Segment boundaries are found from label changes (1 = P, 2 = QRS, 3 = T;
    other labels are treated as background). Each beat is then measured as:
      - PR:  P onset   -> first QRS onset before the next P onset
      - QRS: QRS onset -> first QRS offset before the next QRS onset
      - QT:  QRS onset -> first T offset before the next QRS onset
    Offsets are the last sample of a segment (inclusive), so QRS and QT are
    measured to that sample.

    Args:
        delineation_output (array-like): 1-D sequence of labels {0, 1, 2, 3}.
        sample_rate (float): Sampling rate in Hz used to convert samples to time.

    Returns:
        dict: ``{"PR_interval", "QRS_duration", "QT_interval"}``, each the
        median over beats in milliseconds, or 0 when no complete interval of
        that kind was found.
    """
    segments = _wave_boundaries(np.asarray(delineation_output))

    intervals = {
        "PR": _pair_with_next_event(segments['P_onset'], segments['QRS_onset']),
        "QRS": _pair_with_next_event(segments['QRS_onset'], segments['QRS_offset']),
        "QT": _pair_with_next_event(segments['QRS_onset'], segments['T_offset']),
    }

    return {
        "PR_interval": _median_duration_ms(intervals["PR"], sample_rate),
        "QRS_duration": _median_duration_ms(intervals["QRS"], sample_rate),
        "QT_interval": _median_duration_ms(intervals["QT"], sample_rate),
    }


In [None]:
# Runs are read from {exp_dir}/{model}/{algo}/{setting}/{dataset}/.
exp_dir = "../exps"
setting = "cross_domain"
dataset = "merged"

# Backbones and training algorithms to compare.
models = ["vit_tiny"]
algos = ["scratch", "mean_teacher", "fixmatch", "cps", "reco", "stpp"]

# Empty results table; the evaluation loop adds one "<model>-<algo>" row per run.
performance_table = pd.DataFrame(
    columns=["mIoU", "MAE (Avg.)", "MAE (PR)", "MAE (QRS)", "MAE (QT)"]
)

In [None]:
# Evaluate every (model, algorithm) run: segmentation mIoU plus the mean
# absolute error (ms) of the PR / QRS / QT numerics derived from the labels.
numeric_keys = ("PR_interval", "QRS_duration", "QT_interval")

for model_name in models:
    for algo in algos:
        run_dir = os.path.join(exp_dir, model_name, algo, setting, dataset)
        row = f"{model_name}-{algo}"

        # Load predictions and one-hot labels; axis 1 is the class axis
        # (presumably (N, C, L) — confirm against the export code).
        probs = np.load(os.path.join(run_dir, "test_outputs.npy"))
        onehot_masks = np.load(os.path.join(run_dir, "test_labels.npy"))

        # Hard predictions via argmax. Unlike `probs == probs.max(...)`, this
        # gives exactly one class per position even when scores tie, so a tie
        # cannot make one position count towards several classes in mIoU.
        pred_labels = probs.argmax(axis=1)
        mask_labels = onehot_masks.argmax(axis=1)
        onehot_preds = np.moveaxis(
            np.eye(probs.shape[1], dtype=np.float32)[pred_labels], -1, 1
        )

        miou = iou(onehot_preds, onehot_masks)
        performance_table.loc[row, "mIoU"] = f"{miou * 100:.1f}"

        # Per-sample numerics (ms) for the predicted and the true label sequence.
        numerics_preds = {key: [] for key in numeric_keys}
        numerics_masks = {key: [] for key in numeric_keys}
        for pred, mask in zip(pred_labels, mask_labels):
            pred_numerics = compute_numerics(pred)
            true_numerics = compute_numerics(mask)
            for key in numeric_keys:
                numerics_preds[key].append(pred_numerics[key])
                numerics_masks[key].append(true_numerics[key])

        mae_pr, mae_qrs, mae_qt = (
            np.mean(np.abs(np.array(numerics_preds[key]) - np.array(numerics_masks[key])))
            for key in numeric_keys
        )
        mae_avg = np.mean([mae_pr, mae_qrs, mae_qt])
        for column, value in (
            ("MAE (Avg.)", mae_avg),
            ("MAE (PR)", mae_pr),
            ("MAE (QRS)", mae_qrs),
            ("MAE (QT)", mae_qt),
        ):
            performance_table.loc[row, column] = f"{value:.1f}"

performance_table
