In [1]:
import numpy as np
import os
import SimpleITK as sitk
import scipy.spatial
import pandas as pd
from tqdm import tqdm


def get_hausdorff(test_image, result_image):
    """Compute the 95th-percentile symmetric Hausdorff distance (HD95).

    Surfaces are extracted as ``mask - erode(mask)`` (3D binary erosion with
    radius 1), their voxel indices are mapped to physical coordinates, and for
    every surface point the distance to the nearest point on the other surface
    is found with a KD-tree. The result is the larger of the two directed
    95th-percentile distances, in the images' physical units.

    Assumes both inputs are binary (0/1) masks on the same grid; the caller in
    this notebook builds them with ``image == label`` after ``CopyInformation``.

    Args:
        test_image: reference (ground-truth) binary ``sitk.Image``.
        result_image: predicted binary ``sitk.Image``.

    Returns:
        HD95 as a float, or ``np.nan`` when either mask has no surface voxels
        (e.g. an empty mask), because the distance is then undefined.
    """

    def _surface_points(image):
        """Physical coordinates of the boundary voxels of a binary mask."""
        # Edge detection is ORIGINAL - ERODED, keeping the outer boundary of
        # each lesion. Erosion is performed in 3D.
        eroded = sitk.BinaryErode(image, (1, 1, 1))
        surface = sitk.Subtract(image, eroded)
        # GetArrayFromImage is (z, y, x) ordered; physical lookup wants (x, y, z).
        indices = np.argwhere(sitk.GetArrayFromImage(surface))[:, ::-1]
        return [image.TransformIndexToPhysicalPoint(idx) for idx in indices.tolist()]

    def _nearest_distances(points, queries):
        """Distance from each query point to its nearest neighbour in ``points``."""
        kd_tree = scipy.spatial.KDTree(points, leafsize=100)
        return kd_tree.query(queries, k=1, eps=0, p=2)[0]

    test_points = _surface_points(test_image)
    # Map prediction indices with the prediction's own geometry (previously the
    # ground-truth image was used, which is only right after CopyInformation).
    result_points = _surface_points(result_image)

    # An empty surface on either side would make KDTree fail; HD is undefined.
    if not test_points or not result_points:
        return np.nan

    d_result_to_test = _nearest_distances(test_points, result_points)
    d_test_to_result = _nearest_distances(result_points, test_points)

    return max(np.percentile(d_test_to_result, 95), np.percentile(d_result_to_test, 95))


def get_metrics(test_image, label_image):
    """Compute voxel-overlap metrics between a reference and a predicted mask.

    Any non-zero voxel is treated as foreground. For 0/1 masks (what the caller
    passes) this matches the previous arithmetic exactly. It also avoids the
    ``1 - uint8_array`` wrap-around that corrupted counts for values > 1.

    Args:
        test_image: reference (ground-truth) ``sitk.Image``.
        label_image: predicted ``sitk.Image`` of the same size.

    Returns:
        Tuple ``(sensitivity, precision, dice, vs, tp, fp, fn)``:
        ``vs`` is volume similarity ``1 - |V_ref - V_pred| / (V_ref + V_pred)``
        and ``tp``/``fp``/``fn`` are voxel counts. A ratio whose denominator is
        zero is reported as 0. If both masks are empty, every value is 0.
    """
    test_mask = sitk.GetArrayFromImage(test_image).ravel() != 0
    result_mask = sitk.GetArrayFromImage(label_image).ravel() != 0
    if not test_mask.any() and not result_mask.any():
        return 0, 0, 0, 0, 0, 0, 0

    # count_nonzero returns Python ints, so the subtraction below cannot
    # underflow as unsigned NumPy sums could.
    true_positive = np.count_nonzero(test_mask & result_mask)
    false_positive = np.count_nonzero(~test_mask & result_mask)
    false_negative = np.count_nonzero(test_mask & ~result_mask)

    def _ratio(numerator, denominator):
        """Safe division: 0 when the denominator is 0."""
        return numerator / denominator if denominator else 0

    sensitivity = _ratio(true_positive, true_positive + false_negative)
    precision = _ratio(true_positive, true_positive + false_positive)
    dice = _ratio(2 * true_positive, 2 * true_positive + false_positive + false_negative)

    # Volume similarity from the counts already computed, instead of two extra
    # StatisticsImageFilter passes over the images.
    test_volume = true_positive + false_negative
    result_volume = true_positive + false_positive
    total_volume = test_volume + result_volume
    vs = 1 - abs(test_volume - result_volume) / total_volume if total_volume else 0

    return (sensitivity, precision, dice, vs, true_positive, false_positive, false_negative)


def calculate(prediction_path, ground_truth_path, labels=range(1, 24)):
    """Score every prediction file against its ground truth, one row per label.

    For each ``.nii.gz`` / ``.mha`` file in ``prediction_path``, the file with
    the same name is read from ``ground_truth_path``. For each value in
    ``labels``, both images are turned into binary masks (``image == label``)
    and scored with ``get_metrics``.

    Args:
        prediction_path: folder containing predicted label maps.
        ground_truth_path: folder containing reference label maps with the
            same file names.
        labels: label values to evaluate (default 1..23, as before).

    Returns:
        ``pd.DataFrame`` with columns name, sensitivity, precision, dice, vs,
        tp, fp, fn. ``name`` is the file name with the label appended. A file
        that fails (missing ground truth, size mismatch, read error) is
        reported and skipped entirely, so no partial rows are kept.
    """
    metrics_list = []

    test_list = [
        x for x in os.listdir(prediction_path) if x.endswith((".nii.gz", ".mha"))
    ]
    for name in tqdm(sorted(test_list)):
        try:
            # Read each pair once per file; it was previously re-read for every label.
            test_image = sitk.ReadImage(os.path.join(ground_truth_path, name))
            result_image = sitk.ReadImage(os.path.join(prediction_path, name))
            if test_image.GetSize() != result_image.GetSize():
                raise ValueError(
                    f"size mismatch: ground truth {test_image.GetSize()} "
                    f"vs prediction {result_image.GetSize()}"
                )

            # Copy meta information so both images share one physical space.
            result_image.CopyInformation(test_image)

            case_rows = []
            for i in labels:
                # HD95 (get_hausdorff) is not computed here.
                (
                    sensitivity,
                    precision,
                    dice,
                    vs,
                    true_positive,
                    false_positive,
                    false_negative,
                ) = get_metrics(test_image == i, result_image == i)

                case_rows.append(
                    {
                        "name": name + str(i),
                        "sensitivity": sensitivity,
                        "precision": precision,
                        "dice": dice,
                        "vs": vs,
                        "tp": true_positive,
                        "fp": false_positive,
                        "fn": false_negative,
                    }
                )
        except Exception as e:
            # Best effort: report and skip this file, keep scoring the rest.
            print(f"Failed to process {name}: {e}")
            continue

        metrics_list.extend(case_rows)

    # Create DataFrame from list of metrics
    info_df = pd.DataFrame(metrics_list)
    return info_df

In [2]:
# Prediction folders to evaluate; each is scored against the same label folder.
# NOTE(review): absolute, machine-specific paths; consider a configurable base dir.
paths = [
    "/home/songwei/Data/results/Dataset011_AortaSeg/mynet_1000_66_deep_512x160x160/fold_1_multi/validation_best",
]
label_path = "/home/songwei/Data/raw/labelsTr"

import warnings

results = []
for p in paths:
    # Silence warnings only while scoring, not for the rest of the kernel session.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        new_fold = calculate(prediction_path=p, ground_truth_path=label_path)

    print(p)
    print(new_fold.describe())
    results.append(new_fold)

100%|██████████| 10/10 [04:27<00:00, 26.74s/it]

/home/songwei/Data/results/Dataset011_AortaSeg/mynet_1000_66_deep_512x160x160/fold_1_multi/validation_best
       sensitivity   precision        dice          vs             tp  \
count   230.000000  230.000000  230.000000  230.000000     230.000000   
mean      0.744524    0.747869    0.727097    0.863897   18131.647826   
std       0.192232    0.177555    0.157751    0.117066   38201.826831   
min       0.013636    0.007391    0.010304    0.320468       6.000000   
25%       0.668446    0.658034    0.671987    0.816348    1905.000000   
50%       0.777108    0.782818    0.760032    0.894968    5603.000000   
75%       0.886015    0.868689    0.820565    0.943222   12052.500000   
max       0.987621    0.996453    0.928963    0.998844  341505.000000   

                 fp            fn  
count    230.000000    230.000000  
mean    3616.708696   4564.152174  
std     6214.480707   8519.934050  
min       12.000000     33.000000  
25%      563.750000    528.500000  
50%     1473.500000




In [7]:
# Show the per-(file, label) metrics table for the last folder scored in the loop above.
new_fold

Unnamed: 0,name,sensitivity,precision,dice,vs,tp,fp,fn
0,subject003.mha1,0.650293,0.865294,0.742543,0.858140,30422,4736,16360
1,subject003.mha2,0.959206,0.514590,0.669832,0.698319,2610,2462,111
2,subject003.mha3,0.750103,0.545140,0.631405,0.841757,3629,3028,1209
3,subject003.mha4,0.885591,0.557632,0.684349,0.772760,1432,1136,185
4,subject003.mha5,0.621006,0.740536,0.675524,0.912210,6025,2111,3677
...,...,...,...,...,...,...,...,...
225,subject047.mha19,0.839885,0.851262,0.845536,0.993273,9678,1691,1845
226,subject047.mha20,0.784442,0.734211,0.758496,0.966923,1674,606,460
227,subject047.mha21,0.683525,0.709379,0.696212,0.981439,1838,753,851
228,subject047.mha22,0.754938,0.732630,0.743617,0.985004,8752,3194,2841
