In [1]:
import numpy as np
from skimage import measure
import pandas as pd
import os

In [2]:
def find_partial_overlap(pred, gt):
    """
    find the fractional volume of pred in gt
    """
    pred_volume = np.sum(pred)
    gt_volume = np.sum(gt)
    intersection = np.sum(pred & gt)
    return intersection / (pred_volume+1e-10)

def overlap_info(pred_region, gt_region, gt_lesion, overlap_info, pred_labels):
    for pred_label in pred_labels:
        if pred_label == 0:
            continue
        pred = pred_region == pred_label
        overlap = find_partial_overlap(pred, gt_region)
        if pred_label in overlap_info.keys():
            overlap_info[pred_label].append({'gt_lesion': gt_lesion, 'overlap' : overlap})
        else:
            overlap_info[pred_label] = [{'gt_lesion': gt_lesion, 'overlap' : overlap}]
    
    return overlap_info

def find_overlaps(pred, gt, labeled=False):
    if not labeled:
        pred = measure.label(pred, background=0)
        gt = measure.label(gt, background=0)
    pred_labels = np.unique(pred)
    
    overlaps = {}
    for pred_lb in pred_labels:
        if pred_lb == 0:
            continue
        overlaps[pred_lb] = []
    region_props = pd.DataFrame(measure.regionprops_table(gt, properties=('label','bbox',)))
    for i, row in region_props.iterrows():
        gt_lesion = row['label']
        xmin = row['bbox-0']
        ymin = row['bbox-1']
        zmin = row['bbox-2']
        xmax = row['bbox-3']
        ymax = row['bbox-4']
        zmax = row['bbox-5']
        gt_region = gt[xmin:xmax, ymin:ymax, zmin:zmax]
        gt_region = (gt_region == gt_lesion)
        pred_region = pred[xmin:xmax, ymin:ymax, zmin:zmax]
        overlaps = overlap_info(pred_region, gt_region, gt_lesion, overlaps, pred_labels)
    return overlaps

In [3]:
import nibabel as nib
def load_nifti_volume(file_path):
    nifti_image = nib.load(file_path)
    volume = nifti_image.get_fdata()
    return volume


In [155]:
gt = load_nifti_volume('../stacks/svuh/ground_truth/subID129_0.nii.gz')
pred = load_nifti_volume('../stacks/svuh/beams_siamese_unet/subID129_0.nii.gz')

In [158]:
pred_lb = measure.label(pred, background=0)
gt_lb = measure.label(gt, background=0)
ov_info = find_overlaps(pred_lb, gt_lb, labeled=True)
ov_info

1
2
pred volume 109
gt volume 278
intersection volume 107


{1: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 2: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 3: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 4: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 5: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 6: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 7: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 8: [{'gt_lesion': 1, 'overlap': 0.0},
  {'gt_lesion': 2, 'overlap': 0.9816513761458884}],
 9: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 10: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 11: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 12: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 13: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 14: [{'gt_lesion

In [4]:
def update_prediction_labels(pred_labels, overlap_dict, max_gt_label):
    updated_labels = pred_labels.copy()
    last_label = max_gt_label

    for pred_label, overlaps in overlap_dict.items():
        if pred_label == 0:
            continue  # skip background
        
        # Find the ground truth lesion with the highest overlap
        max_overlap = -1
        best_gt_lesion = None
        for overlap_info in overlaps:
            if overlap_info['overlap'] > max_overlap:
                max_overlap = overlap_info['overlap']
                best_gt_lesion = overlap_info['gt_lesion']
        
        if max_overlap > 0:
            # Update the prediction labels with the best matching ground truth label
            updated_labels[pred_labels == pred_label] = best_gt_lesion
        else:
            # Assign a unique label for zero overlap lesions
            last_label += 1
            updated_labels[pred_labels == pred_label] = last_label
    
    return updated_labels

In [159]:
pred_labels = np.unique(pred_lb)
updated_pred = update_prediction_labels(measure.label(pred_lb, background=0), ov_info, pred_labels[-1] )

In [160]:
np.unique(updated_pred)

array([ 0,  2, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27])

In [148]:
ov_info2 = find_overlaps(updated_pred, gt, labeled=True)

1
2
pred volume 109
gt volume 278
intersection volume 107


{2: [{'gt_lesion': 1, 'overlap': 0.0},
  {'gt_lesion': 2, 'overlap': 0.9816513761458884}],
 15: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 16: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 17: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 18: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 19: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 20: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 21: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 22: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 23: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 24: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 25: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 26: [{'gt_lesion': 1, 'overlap': 0.0}, {'gt_lesion': 2, 'overlap': 0.0}],
 27: [{'g

In [5]:
def calculate_metrics(y_true_labels, y_pred_labels):
    # Get unique labels in ground truth and prediction
    true_labels = np.unique(y_true_labels)
    pred_labels = np.unique(y_pred_labels)
    # Initialize counts
    true_positive = 0
    false_positive = 0
    false_negative = 0
    
    # Calculate true positives
    for label in true_labels:
        if label == 0:
            continue
        if label in pred_labels:
            true_positive += 1
        else:
            false_negative += 1
    
    # Calculate false positives
    for label in pred_labels:
        if label == 0:
            continue
        if label not in true_labels:
            false_positive += 1
    
    return true_positive, false_positive, false_negative

In [6]:
def find_metrics(model, dataset, target_dir, eps=1e-10):
    csv = os.path.join(target_dir, dataset, 'ground_truth.csv')
    df = pd.read_csv(csv)
    results = []
    for i, row in df.iterrows():
        filename = row['file']
        gt = load_nifti_volume(os.path.join(target_dir, dataset, 'ground_truth', filename))
        pred = load_nifti_volume(os.path.join(target_dir, dataset, model, filename))
        pred_lb = measure.label(pred, background=0)
        gt_lb = measure.label(gt, background=0)
        ov_info = find_overlaps(pred_lb, gt_lb, labeled=True)
        updated_pred = update_prediction_labels(measure.label(pred_lb, background=0), ov_info, pred_lb.max() )
        tp, fp, fn = calculate_metrics(gt_lb, updated_pred)
        if tp == 0:
            if fp == 0 and fn == 0:
                print(f'Perfect prediction for {filename}')
                dice = 1
                precision = 1
                recall = 1
            else:
                dice = 0
                precision = 0
                recall = 0
        else:
            dice = 2*tp / (2*tp + fp + fn + eps)
            precision = tp / (tp + fp + eps)
            recall = tp / (tp + fn + eps)
        results.append({'filename': filename, 'dice': dice, 'precision': precision, 'recall': recall,
                        'correct': tp, 'false_positive': fp, 'false_negative': fn})
    return pd.DataFrame(results)


In [7]:
target_dir = '../stacks/'
model = 'beams_siamese_unet'
dataset = 'svuh'
beams_su = find_metrics(model, dataset, target_dir)

UnboundLocalError: local variable 'dice' referenced before assignment

In [8]:
model = 'siamese_unet'
su = find_metrics(model, dataset, target_dir)

In [9]:
model = "att_pfpn"
pfpn = find_metrics(model, dataset, target_dir)

In [10]:
model = "beams_att_pfpn"
beams_pfpn = find_metrics(model, dataset, target_dir)

In [11]:
model = "nnUNet"
nnunet = find_metrics(model, dataset, target_dir)

In [12]:
model = "xbound22"
xbound22 = find_metrics(model, dataset, target_dir)

In [13]:
model = "vitseg_r18"
vit = find_metrics(model, dataset, target_dir)

In [15]:
model = "pfpn"
n_pfpn = find_metrics(model, dataset, target_dir)

In [16]:
model = "beams_pfpn"
n_beams_pfpn = find_metrics(model, dataset, target_dir)

In [38]:
target_dir = '../stacks/'
model = "beams_siamese_unet_bcef05"
dataset = 'svuh'
beams_siamese_p = find_metrics(model, dataset, target_dir)

In [7]:
target_dir = '../stacks/'
model = "vitseg_r50_backbone_enc_dec_2"
dataset = 'svuh'
vit_r50 = find_metrics(model, dataset, target_dir)

In [184]:
beams_su.to_csv('beams_su.csv', index=False)
su.to_csv('su.csv', index=False)
pfpn.to_csv('pfpn.csv', index=False)
beams_pfpn.to_csv('beams_pfpn.csv', index=False)
nnunet.to_csv('nnunet.csv', index=False)
xbound22.to_csv('xbound22.csv', index=False)
vit.to_csv('vit.csv', index=False)


In [8]:
import pandas as pd

def summarize(results):
    for model, df in results.items():
        print("-"*10)
        print(model)
        print("Dice mean: ", df['dice'].mean())
        print("Precision mean: ", df['precision'].mean())
        print("Recall mean: ", df['recall'].mean())
        print("Correct: ", df['correct'].sum())
        print("False Positive: ", df['false_positive'].sum())
        print("False Negative: ", df['false_negative'].sum())
        print("-"*10)

def summarize_df(results):
    data = []
    for model, df in results.items():
        data.append({'model': model, 'dice': df['dice'].mean(), 'precision': df['precision'].mean(), 'recall': df['recall'].mean(),
                    'correct': df['correct'].sum(), 'false_positive': df['false_positive'].sum(), 'false_negative': df['false_negative'].sum()})
    return pd.DataFrame(data)

In [9]:
# results = [("nnUNet", nnunet),
#            ("XboundFormer", xbound22),
#            ("ViT", vit),
#            ("att pfpn", pfpn),
#            ("pfpn", n_pfpn),
#            ("Siamese U-Net", su),
#            ("BEAMS att PFPN", beams_pfpn),
#            ("BEAMS PFPN", n_beams_pfpn),
#            ("BEAMS Siamese U-Net", beams_su),
#            ("BEAMS Siamese U-Net P", beams_siamese_p)]

results = {'vit': vit_r50}
summarize(dict(results))
           

----------
vit
Dice mean:  0.44666083440941334
Precision mean:  0.4222563650087221
Recall mean:  0.5818518518145306
Correct:  65
False Positive:  158
False Negative:  33
----------


In [195]:
res_df = summarize_df(dict(results))

In [10]:
res_df = summarize_df(dict(results))
ltx = res_df.style.to_latex()
print(ltx)

\begin{tabular}{llrrrrrr}
 & model & dice & precision & recall & correct & false_positive & false_negative \\
0 & vit & 0.446661 & 0.422256 & 0.581852 & 65 & 158 & 33 \\
\end{tabular}



In [8]:
target_dir = '../'
model = 'siamc-dfl'
dataset = 'ablation_stacks'
dfl = find_metrics(model, dataset, target_dir)

In [9]:
target_dir = '../'
model = 'siamc-dice'
dataset = 'ablation_stacks'
dc = find_metrics(model, dataset, target_dir)

In [10]:
target_dir = '../'
model = 'siamc-f05'
dataset = 'ablation_stacks'
f05 = find_metrics(model, dataset, target_dir)

In [11]:
target_dir = '../'
model = 'siamc-f2'
dataset = 'ablation_stacks'
f2 = find_metrics(model, dataset, target_dir)

In [16]:
results = [("DiceLoss", dc),
           ("DiceFocalLoss", dfl),
           ("FBetaLoss-2", f2),
           ("FBetaLoss-0.5", f05),
           ]

summarize(dict(results))

----------
DiceLoss
Dice mean:  0.35083611489812083
Precision mean:  0.2842400743071825
Recall mean:  0.6660493826740203
Correct:  76
False Positive:  451
False Negative:  22
----------
----------
DiceFocalLoss
Dice mean:  0.23836665077550126
Precision mean:  0.16098570542053717
Recall mean:  0.7228395061260641
Correct:  81
False Positive:  781
False Negative:  17
----------
----------
FBetaLoss-2
Dice mean:  0.13212098417061058
Precision mean:  0.07624578785951742
Recall mean:  0.7420987653848333
Correct:  85
False Positive:  1429
False Negative:  13
----------
----------
FBetaLoss-0.5
Dice mean:  0.4700853226095146
Precision mean:  0.426500843470907
Recall mean:  0.6309876542804093
Correct:  71
False Positive:  185
False Negative:  27
----------


In [17]:
res_df = summarize_df(dict(results))
ltx = res_df.style.to_latex()
print(ltx)

\begin{tabular}{llrrrrrr}
 & model & dice & precision & recall & correct & false_positive & false_negative \\
0 & DiceLoss & 0.350836 & 0.284240 & 0.666049 & 76 & 451 & 22 \\
1 & DiceFocalLoss & 0.238367 & 0.160986 & 0.722840 & 81 & 781 & 17 \\
2 & FBetaLoss-2 & 0.132121 & 0.076246 & 0.742099 & 85 & 1429 & 13 \\
3 & FBetaLoss-0.5 & 0.470085 & 0.426501 & 0.630988 & 71 & 185 & 27 \\
\end{tabular}



In [19]:
target_dir = '../stacks/'
model = 'nnUNet'
dataset = 'ms2'
ms2nnunet = find_metrics(model, dataset, target_dir)

In [20]:
target_dir = '../stacks/'
model = 'beams_siamese_unet'
dataset = 'ms2'
ms2_beams_siamese = find_metrics(model, dataset, target_dir)

In [21]:
target_dir = '../stacks/'
model = 'pfpn'
dataset = 'ms2'
ms2_pfpn = find_metrics(model, dataset, target_dir)

In [22]:
target_dir = '../stacks/'
model = 'beams_pfpn'
dataset = 'ms2'
ms2_beams_pfpn = find_metrics(model, dataset, target_dir)

In [23]:
target_dir = '../stacks/'
model = 'siamese_unet'
dataset = 'ms2'
ms2_siamese = find_metrics(model, dataset, target_dir)

In [24]:
target_dir = '../stacks/'
model = 'xbound22'
dataset = 'ms2'
ms2_xbound = find_metrics(model, dataset, target_dir)

In [25]:
target_dir = '../stacks/'
model = 'att_pfpn'
dataset = 'ms2'
ms2_att_pfpn = find_metrics(model, dataset, target_dir)

In [26]:
target_dir = '../stacks/'
model = 'beams_att_pfpn'
dataset = 'ms2'
ms2_beams_att_pfpn = find_metrics(model, dataset, target_dir)

In [28]:
target_dir = '../stacks/'
model = "beams_siamese_unet_f05p"
dataset = 'ms2'
ms2_beams_siamese_f05p = find_metrics(model, dataset, target_dir)

In [30]:
target_dir = '../stacks/'
model = "beams_siamese_unet_no_p"
dataset = 'ms2'
ms2_beams_siamese_nop = find_metrics(model, dataset, target_dir)

In [40]:
target_dir = '../stacks/'
model = "beams_siamese_unet_bcef05"
dataset = 'ms2'
ms2_beams_siamese_p = find_metrics(model, dataset, target_dir)

In [11]:
target_dir = '../stacks/'
model = "vitseg_r50_backbone_enc_dec_2"
dataset = 'ms2'
ms2_vit_r50 = find_metrics(model, dataset, target_dir)

In [25]:
ms2_xbound.to_csv('ms2_xbound.csv', index=False)
ms2_siamese.to_csv('ms2_siamese.csv', index=False)
ms2_beams_siamese.to_csv('ms2_beams_siamese.csv', index=False)
ms2_pfpn.to_csv('ms2_pfpn.csv', index=False)
ms2_beams_pfpn.to_csv('ms2_beams_pfpn.csv', index=False)
ms2nnunet.to_csv('ms2_nnunet.csv', index=False)
ms2_att_pfpn.to_csv('ms2_att_pfpn.csv', index=False)
ms2_beams_att_pfpn.to_csv('ms2_beams_att_pfpn.csv', index=False)



In [12]:
# results = [("nnUNet", ms2nnunet),
#            ("XboundFormer", ms2_xbound),
#            ("pfpn", ms2_pfpn),
#            ("Siamese U-Net", ms2_siamese),
           
#            ("BEAMS Siamese U-Net P", ms2_beams_siamese_p),
#            ("BEAMS PFPN", ms2_beams_pfpn),
#            ]
results = {'vit': ms2_vit_r50}


summarize(dict(results))

----------
vit
Dice mean:  0.2210874140447432
Precision mean:  0.34435515871170735
Recall mean:  0.21779801291145096
Correct:  50
False Positive:  51
False Negative:  95
----------


In [13]:
res_df = summarize_df(dict(results))
ltx = res_df.style.to_latex()
print(ltx)

\begin{tabular}{llrrrrrr}
 & model & dice & precision & recall & correct & false_positive & false_negative \\
0 & vit & 0.221087 & 0.344355 & 0.217798 & 50 & 51 & 95 \\
\end{tabular}

