In [1]:
import os

In [None]:
# Switch the working directory to the repository checkout. os.chdir does not
# expand "~" itself, so the home-directory prefix is expanded explicitly.
repo_folder_path = os.path.expanduser('~/STCarotidSeg4D')
os.chdir(repo_folder_path)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from dataclasses import dataclass

import re

##### Model Output and Data Paths

In [None]:
# Root folders for model outputs and datasets. "~" is expanded here so that
# every downstream consumer (pandas, pathlib, matplotlib's savefig) receives a
# real path; not all of them expand "~" on their own.
model_output_folder = os.path.expanduser('~/output')
data_folder = os.path.expanduser('~/data')

In [5]:
# Test-set image and label folders per dataset. These must be f-strings so
# that `{data_folder}` is substituted with the configured data root.
test_images = {
    'magnitude': f'{data_folder}/TestDataCorrected/2dtime_PCMRI/images',
    'velocities': f'{data_folder}/TestDataCorrected/Velocities/images/'
}

test_labels = {
    'magnitude': f'{data_folder}/TestDataCorrected/2dtime_PCMRI/labels/',
    'velocities': f'{data_folder}/TestDataCorrected/Velocities/labels/'
}

##### Evaluation Result Files (per dataset and model variant)

In [23]:
def _evaluation_xlsx(*segments: str) -> str:
    """Return the data.xlsx path below `model_output_folder` for the given sub-folders."""
    return f"{model_output_folder}/{'/'.join(segments)}/data.xlsx"


# Evaluation spreadsheet per dataset and model variant. The variant keys are
# looked up in `all_models` by the plotting helpers.
evaluation_paths = {
    "magnitude": {
        "unetr_interp16": _evaluation_xlsx("UNETR", "magnitude", "Interp_16", "evaluation"),
        "unetr_interp32": _evaluation_xlsx("UNETR", "magnitude", "Interp_32", "evaluation"),
        "unetr_pad32_crop": _evaluation_xlsx("UNETR", "magnitude", "Pad_32", "evaluation_crop"),
        "unetr_pad32_average": _evaluation_xlsx("UNETR", "magnitude", "Pad_32", "evaluation_avg"),
        "unet3d_interp16": _evaluation_xlsx("UNet3D", "magnitude", "Interp_16", "evaluation"),
        "unet3d_interp32": _evaluation_xlsx("UNet3D", "magnitude", "Interp_32", "evaluation"),
        "unet3d_pad32_average": _evaluation_xlsx("UNet3D", "magnitude", "Pad_32", "evaluation_avg"),
        "unet3d_pad32_crop": _evaluation_xlsx("UNet3D", "magnitude", "Pad_32", "evaluation_crop"),
        "spatio_temporal_interp16": _evaluation_xlsx("SpatioTemporalTransformer", "magnitude", "Interp_16", "evaluation"),
        "spatio_temporal_interp32": _evaluation_xlsx("SpatioTemporalTransformer", "magnitude", "Interp_32", "evaluation"),
        "spatio_temporal_pad32_crop": _evaluation_xlsx("SpatioTemporalTransformer", "magnitude", "Pad_32", "evaluation_crop"),
        "spatio_temporal_pad32_average": _evaluation_xlsx("SpatioTemporalTransformer", "magnitude", "Pad_32", "evaluation_avg"),
        "unet2d": _evaluation_xlsx("UNet2D", "magnitude", "None_16", "evaluation"),
        "nnunet_2d": _evaluation_xlsx("nnUNet_results", "Dataset003_MagnitudesCorrected", "nnUNetTrainer__nnUNetPlans__2d", "evaluation"),
        "nnunet_3d_fullres": _evaluation_xlsx("nnUNet_results", "Dataset003_MagnitudesCorrected", "nnUNetTrainer__nnUNetPlans__3d_fullres", "evaluation"),
    },
    "velocities": {
        "unetr_interp16": _evaluation_xlsx("UNETR", "velocities", "Interp_16", "evaluation"),
        "unetr_interp32": _evaluation_xlsx("UNETR", "velocities", "Interp_32", "evaluation"),
        "unetr_pad32_crop": _evaluation_xlsx("UNETR", "velocities", "Pad_32", "evaluation_crop"),
        "unetr_pad32_average": _evaluation_xlsx("UNETR", "velocities", "Pad_32", "evaluation_avg"),
        "unet3d_interp16": _evaluation_xlsx("UNet3D", "velocities", "Interp_16", "evaluation"),
        "unet3d_interp16_ensemble": _evaluation_xlsx("UNet3D", "velocities", "Interp_16", "evaluation_ensemble"),
        "unet3d_interp32": _evaluation_xlsx("UNet3D", "velocities", "Interp_32", "evaluation"),
        "unet3d_interp32_ensemble": _evaluation_xlsx("UNet3D", "velocities", "Interp_32", "evaluation_ensemble"),
        "unet3d_pad32_average": _evaluation_xlsx("UNet3D", "velocities", "Pad_32", "evaluation_avg"),
        "unet3d_pad32_average_ensemble": _evaluation_xlsx("UNet3D", "velocities", "Pad_32", "evaluation_avg_ensemble"),
        "unet3d_pad32_crop": _evaluation_xlsx("UNet3D", "velocities", "Pad_32", "evaluation_crop"),
        "spatio_temporal_interp16": _evaluation_xlsx("SpatioTemporalTransformer", "velocities", "Interp_16", "evaluation"),
        "spatio_temporal_interp32": _evaluation_xlsx("SpatioTemporalTransformer", "velocities", "Interp_32", "evaluation"),
        "spatio_temporal_pad32_crop": _evaluation_xlsx("SpatioTemporalTransformer", "velocities", "Pad_32", "evaluation_crop"),
        "spatio_temporal_pad32_average": _evaluation_xlsx("SpatioTemporalTransformer", "velocities", "Pad_32", "evaluation_avg"),
        "unet2d": _evaluation_xlsx("UNet2D", "velocities", "None_16", "evaluation"),
        "nnunet_2d": _evaluation_xlsx("nnUNet_results", "Dataset004_VelocitiesCorrected", "nnUNetTrainer__nnUNetPlans__2d", "evaluation"),
        "nnunet_3d_fullres": _evaluation_xlsx("nnUNet_results", "Dataset004_VelocitiesCorrected", "nnUNetTrainer__nnUNetPlans__3d_fullres", "evaluation"),
    },
}

del _evaluation_xlsx  # only needed while building the table above

In [21]:
@dataclass
class ModelInfo:
    name: str
    transform: str
    temporal_dimension: int | None

    def __init__(
        self, name: str, transform: str, temporal_dimension: int | None = None
    ):
        self.name = name
        self.transform = transform
        self.temporal_dimension = temporal_dimension

    @property
    def description(self) -> str:
        if self.transform:
            return f"{self.name} ({self.transform}, T'={self.temporal_dimension})"
        else:
            return self.name


all_models = {
    "spatio_temporal_interp16": ModelInfo(
        "SpatioTemporalTransformer", "Interpolate", 16
    ),
    "spatio_temporal_interp32": ModelInfo(
        "SpatioTemporalTransformer", "Interpolate", 32
    ),
    "spatio_temporal_pad32_average": ModelInfo(
        "SpatioTemporalTransformer", "Pad Average", 32
    ),
    "spatio_temporal_pad32_crop": ModelInfo(
        "SpatioTemporalTransformer", "Pad Crop", 32
    ),
    "unetr_interp16": ModelInfo("UNETR", "Interpolate", 16),
    "unetr_interp16_ensemble": ModelInfo("UNETR", "Interpolate", 16),
    "unetr_interp32": ModelInfo("UNETR", "Interpolate", 32),
    "unetr_pad32_average": ModelInfo("UNETR", "Pad Average", 32),
    "unetr_pad32_crop": ModelInfo("UNETR", "Pad Crop", 32),
    "unet3d_interp16": ModelInfo("U-Net 3D", "Interpolate", 16),
    "unet3d_interp16_ensemble": ModelInfo("U-Net 3D", "Interpolate", 16),
    "unet3d_interp32": ModelInfo("U-Net 3D", "Interpolate", 32),
    "unet3d_interp32_ensemble": ModelInfo("U-Net 3D", "Interpolate", 32),
    "unet3d_pad32_average": ModelInfo("U-Net 3D", "Pad Average", 32),
    "unet3d_pad32_average_ensemble": ModelInfo("U-Net 3D", "Pad Average", 32),
    "unet3d_pad32_crop": ModelInfo("U-Net 3D", "Pad Crop", 32),
    "unet2d": ModelInfo("U-Net 2D", "None", None),
    "nnunet_2d": ModelInfo("U-Net 2D", "nnU-Net", None),
    "nnunet_3d_fullres": ModelInfo("U-Net 3D", "nnU-Net", 14),
}

#### Bland-Altman Plots (Validation)

In [7]:
@dataclass
class Metric:
    name: str
    label: str
    units: str
    ground_truth_key: str
    prediction_key: str

    def __init__(self, *, name: str, label: str, units: str, ground_truth_key:str, prediction_key: str):
        self.name = name
        self.label = label
        self.units = units
        self.ground_truth_key = ground_truth_key
        self.prediction_key = prediction_key

In [8]:
# Metrics compared between ground truth and prediction in the Bland-Altman
# plots. The *_key fields name columns in each evaluation spreadsheet.
lumen_diameter = Metric(
    name="lumen_diameter",
    ground_truth_key="gt_lumen_diameter",
    prediction_key="pred_lumen_diameter",
    label="max($d_{lumen}$)",
    units="[mm]",
)

total_flow = Metric(
    name="total_flow",
    ground_truth_key="gt_flow_rate",
    prediction_key="pred_flow_rate",
    label="Q",
    units="[mL/min]",
)

max_velocity = Metric(
    name="max_velocity",
    ground_truth_key="gt_max_velocity",
    prediction_key="pred_max_velocity",
    label="$v_{max}$",
    units="[m/s]",
)

In [16]:
def _limits_of_agreement(differences: np.ndarray) -> tuple[float, float, float]:
    """Return ``(bias, lower, upper)`` for a Bland-Altman analysis.

    ``bias`` is the mean difference; the limits of agreement are
    ``bias -/+ 1.96 * SD`` with numpy's default population SD (ddof=0).
    Some references use the sample SD (ddof=1) instead.
    """
    bias = float(differences.mean())
    half_width = 1.96 * float(differences.std())
    return bias, bias - half_width, bias + half_width


def _paired_finite(gt_values, pred_values, category=None):
    """Convert to float arrays and drop pairs where either value is NaN/inf.

    Returns ``(gt, pred, category)``; ``category`` is filtered with the same
    mask, or stays None.

    Raises:
        ValueError: If the inputs differ in shape or no finite pair remains.
    """
    gt = np.asarray(gt_values, dtype=float)
    pred = np.asarray(pred_values, dtype=float)
    if gt.shape != pred.shape:
        raise ValueError(
            f"gt_values and pred_values differ in shape: {gt.shape} vs {pred.shape}"
        )
    keep = np.isfinite(gt) & np.isfinite(pred)
    if not keep.any():
        raise ValueError("no finite (ground truth, prediction) pairs to plot")
    if category is not None:
        category = np.asarray(category)[keep]
    return gt[keep], pred[keep], category


def _render_bland_altman(
    gt_values, pred_values, save_path, metric, comparison, category=None, category_name=None
):
    """Draw one Bland-Altman figure and save it to ``save_path`` (``~`` expanded).

    y = ground truth - prediction, x = pairwise mean. When ``category`` is given,
    points are coloured by it and the legend lists every distinct value.
    The figure is always closed, even if saving fails, so a failed call cannot
    leave artists behind for the next plot.
    """
    gt, pred, category = _paired_finite(gt_values, pred_values, category)
    differences = gt - pred
    means = (gt + pred) / 2
    bias, loa_lower, loa_upper = _limits_of_agreement(differences)

    plt.style.use("default")
    fig, ax = plt.subplots()
    try:
        ax.axhline(loa_upper, color="green", linestyle="--")
        ax.axhline(loa_lower, color="green", linestyle="--")
        if category is None:
            ax.scatter(means, differences, alpha=0.7)
        else:
            points = ax.scatter(means, differences, c=category, alpha=0.7)
            # num=None -> one handle per unique value, in np.unique (sorted) order.
            # The default num="auto" subsamples to ~9 handles when there are more
            # distinct values, so a separate label list would no longer line up.
            handles, _ = points.legend_elements(num=None)
            labels = [str(value) for value in np.unique(category)]
            ax.legend(handles, labels, title=category_name)
        ax.axhline(bias, color="red", linestyle="--")
        ax.set_xlabel(f"Mean {metric.label} {metric.units}")
        ax.set_ylabel(f"Ground truth {metric.label} - Model {metric.label} {metric.units}")
        ax.set_title(f"Bland-Altman Plot Ground Truth vs {comparison}")
        fig.savefig(Path(save_path).expanduser())
    finally:
        plt.close(fig)


def bland_altman_plot(
    gt_values: list[float],
    pred_values: list[float],
    save_path,
    metric: Metric,
    comparison: str = "Model",
):
    """Save a Bland-Altman plot of ground truth vs prediction.

    Args:
        gt_values: Ground-truth measurements (1-D array-like).
        pred_values: Predictions, paired element-wise with ``gt_values``.
        save_path: Output file path; ``~`` is expanded.
        metric: Supplies the axis label and units.
        comparison: Name of the compared method, shown in the title.

    Raises:
        ValueError: If the inputs differ in shape or contain no finite pair.
    """
    _render_bland_altman(gt_values, pred_values, save_path, metric, comparison)


def bland_altman_plot_planes(
    gt_values,
    pred_values,
    save_path,
    metric: Metric,
    comparison: str = "Model",
    category=None,
    category_name="Planes",
):
    """Save a Bland-Altman plot with points coloured by ``category``.

    Args:
        gt_values: Ground-truth measurements (1-D array-like).
        pred_values: Predictions, paired element-wise with ``gt_values``.
        save_path: Output file path; ``~`` is expanded.
        metric: Supplies the axis label and units.
        comparison: Name of the compared method, shown in the title.
        category: Numeric group per point (e.g. plane or time step). If None,
            the plot is drawn without colouring or legend.
        category_name: Legend title.

    Raises:
        ValueError: If the inputs differ in shape or contain no finite pair.
    """
    _render_bland_altman(
        gt_values,
        pred_values,
        save_path,
        metric,
        comparison,
        category=category,
        category_name=category_name,
    )


def _load_evaluation(evaluations: dict, dataset: str, model_variant: str):
    """Look up a model variant and read its evaluation spreadsheet.

    Returns ``(model_info, df, output_dir)``. ``output_dir`` is the folder that
    holds the spreadsheet (``~`` expanded); the plots are saved there.

    Raises:
        KeyError: If ``model_variant`` is missing from ``all_models`` or
            ``dataset``/``model_variant`` is missing from ``evaluations``.
    """
    model_info: ModelInfo = all_models[model_variant]
    evaluation_file = Path(evaluations[dataset][model_variant]).expanduser()
    df = pd.read_excel(evaluation_file)
    return model_info, df, evaluation_file.parent


def create_bland_altman_plot(
    evaluations: dict, dataset: str, model_variant: str, metric: Metric
):
    """Plot ``metric`` for one model variant.

    The plot is saved as ``bland_altman_{metric.name}.svg`` next to the variant's
    evaluation spreadsheet.
    """
    model_info, df, output_dir = _load_evaluation(evaluations, dataset, model_variant)
    bland_altman_plot(
        gt_values=df[metric.ground_truth_key],
        pred_values=df[metric.prediction_key],
        save_path=output_dir / f"bland_altman_{metric.name}.svg",
        metric=metric,
        comparison=model_info.description,
    )


def create_bland_altman_plot_planes(
    evaluations: dict, dataset: str, model_variant: str, metric: Metric
):
    """Plot ``metric`` for one model variant, coloured by plane.

    The plane is parsed from the ``sample`` column. The plot is saved as
    ``bland_altman_{metric.name}_planes.svg``.

    Raises:
        ValueError: If a sample name contains no ``slice<digit>`` marker.
    """
    model_info, df, output_dir = _load_evaluation(evaluations, dataset, model_variant)
    # The first "slice<digit>" occurrence in each sample name identifies the plane.
    # NOTE(review): the pattern captures a single digit, so a plane >= 10 would be
    # truncated to its first digit. Widen to r"slice(\d+)" if the naming scheme allows.
    plane_digits = df["sample"].astype(str).str.extract(r"slice(\d)", expand=False)
    unparsed = plane_digits.isna()
    if unparsed.any():
        examples = df.loc[unparsed, "sample"].head(3).tolist()
        raise ValueError(f"no 'slice<digit>' plane marker in sample name(s): {examples}")
    bland_altman_plot_planes(
        gt_values=df[metric.ground_truth_key],
        pred_values=df[metric.prediction_key],
        save_path=output_dir / f"bland_altman_{metric.name}_planes.svg",
        metric=metric,
        comparison=model_info.description,
        category=plane_digits.astype(int),
        category_name="Plane",
    )


def create_bland_altman_plot_time_steps(
    evaluations: dict, dataset: str, model_variant: str, metric: Metric
):
    """Plot ``metric`` for one model variant, coloured by time step.

    The plot is saved as ``bland_altman_{metric.name}_time_step.svg``.
    """
    model_info, df, output_dir = _load_evaluation(evaluations, dataset, model_variant)
    bland_altman_plot_planes(
        gt_values=df[metric.ground_truth_key],
        pred_values=df[metric.prediction_key],
        save_path=output_dir / f"bland_altman_{metric.name}_time_step.svg",
        metric=metric,
        comparison=model_info.description,
        # Shown as time_step + 1; presumably the sheet stores 0-based steps (verify).
        category=list(df["time_step"] + 1),
        category_name="Time Step",
    )

In [None]:
# Lumen diameter plots
datasets = ['magnitude', 'velocities']

for dataset in datasets:
    for model_variation in evaluation_paths[dataset]:
        for create_plot in (create_bland_altman_plot, create_bland_altman_plot_planes):
            try:
                create_plot(
                    evaluations=evaluation_paths,
                    dataset=dataset,
                    model_variant=model_variation,
                    metric=lumen_diameter,
                )
            except Exception as e:
                # Best effort: report which plot failed and continue with the rest.
                print(f"{create_plot.__name__} [{dataset}/{model_variation}] failed: {e!r}")

In [None]:
# Flow rate plots
datasets = ['magnitude', 'velocities']

for dataset in datasets:
    for model_variation in evaluation_paths[dataset]:
        for create_plot in (create_bland_altman_plot, create_bland_altman_plot_planes):
            try:
                create_plot(
                    evaluations=evaluation_paths,
                    dataset=dataset,
                    model_variant=model_variation,
                    metric=total_flow,
                )
            except Exception as e:
                # Best effort: report which plot failed and continue with the rest.
                print(f"{create_plot.__name__} [{dataset}/{model_variation}] failed: {e!r}")

In [None]:
# Max velocity plots
datasets = ['magnitude', 'velocities']

for dataset in datasets:
    for model_variation in evaluation_paths[dataset]:
        for create_plot in (create_bland_altman_plot, create_bland_altman_plot_planes):
            try:
                create_plot(
                    evaluations=evaluation_paths,
                    dataset=dataset,
                    model_variant=model_variation,
                    metric=max_velocity,
                )
            except Exception as e:
                # Best effort: report which plot failed and continue with the rest.
                print(f"{create_plot.__name__} [{dataset}/{model_variation}] failed: {e!r}")

#### Bland-Altman Plots (Test Data)

In [None]:
# Plots for the best models on the test data
model_configurations = [
    ('velocities', 'unet3d_interp16_ensemble'),
    ('velocities', 'unet3d_pad32_average_ensemble'),
]
for dataset, variation in model_configurations:
    for plot_metric in (lumen_diameter, total_flow, max_velocity):
        for create_plot in (create_bland_altman_plot, create_bland_altman_plot_planes):
            try:
                create_plot(
                    evaluations=evaluation_paths,
                    dataset=dataset,
                    model_variant=variation,
                    metric=plot_metric,
                )
            except Exception as e:
                # Best effort: report which plot failed and continue with the rest.
                print(
                    f"{create_plot.__name__} [{dataset}/{variation}, "
                    f"{plot_metric.name}] failed: {e!r}"
                )