In [1]:
from skimage.measure import regionprops, label
import matplotlib.pyplot as plt
import matplotlib
from natsort import natsorted
from glob import glob
from PIL import Image
import os
import numpy as np
from numba import jit
from scipy.optimize import linear_sum_assignment
from scipy.ndimage import convolve, mean
from tqdm import trange

def mask_ious(masks_true, masks_pred):
    """ return best-matched masks """
    iou = _intersection_over_union(masks_true, masks_pred)[1:,1:]
    n_min = min(iou.shape[0], iou.shape[1])
    costs = -(iou >= 0.5).astype(float) - iou / (2*n_min)
    true_ind, pred_ind = linear_sum_assignment(costs)
    iout = np.zeros(masks_true.max())
    iout[true_ind] = iou[true_ind,pred_ind]
    preds = np.zeros(masks_true.max(), 'int')
    preds[true_ind] = pred_ind+1
    return iout, preds


def aggregated_jaccard_index(masks_true, masks_pred):
    """ AJI = intersection of all matched masks / union of all masks 
    
    Parameters
    ------------
    
    masks_true: list of ND-arrays (int) or ND-array (int) 
        where 0=NO masks; 1,2... are mask labels
    masks_pred: list of ND-arrays (int) or ND-array (int) 
        ND-array (int) where 0=NO masks; 1,2... are mask labels
    Returns
    ------------
    aji : aggregated jaccard index for each set of masks
    """

    aji = np.zeros(len(masks_true))
    for n in range(len(masks_true)):
        iout, preds = mask_ious(masks_true[n], masks_pred[n])
        inds = np.arange(0, masks_true[n].max(), 1, int)
        overlap = _label_overlap(masks_true[n], masks_pred[n])
        union = np.logical_or(masks_true[n]>0, masks_pred[n]>0).sum()
        overlap = overlap[inds[preds>0]+1, preds[preds>0].astype(int)]
        aji[n] = overlap.sum() / union
    return aji 

def dice_score(true_mask, bn_mask, eps=0.00001):
    true_mask = 1*(true_mask>0)
    bn_mask   = 1*(bn_mask>0)

    # Computing intersection and union masks
    inter_mask = np.dot(bn_mask.flatten(), true_mask.flatten())
    union_mask = np.sum(bn_mask) + np.sum(true_mask) + eps

    # Computing the Dice coefficient
    return (2 * inter_mask + eps) / union_mask

def eval_metrics(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
    """ average precision estimation: AP = TP / (TP + FP + FN)
    This function is based heavily on the *fast* stardist matching functions
    (https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py)
    Parameters
    ------------
    
    masks_true: list of ND-arrays (int) or ND-array (int) 
        where 0=NO masks; 1,2... are mask labels
    masks_pred: list of ND-arrays (int) or ND-array (int) 
        ND-array (int) where 0=NO masks; 1,2... are mask labels
    Returns
    ------------
    ap: array [len(masks_true) x len(threshold)]
        average precision at thresholds
    tp: array [len(masks_true) x len(threshold)]
        number of true positives at thresholds
    fp: array [len(masks_true) x len(threshold)]
        number of false positives at thresholds
    fn: array [len(masks_true) x len(threshold)]
        number of false negatives at thresholds
    """
    not_list = False
    if not isinstance(masks_true, list):
        masks_true = [masks_true]
        masks_pred = [masks_pred]
        not_list = True
    if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray):
        threshold = [threshold]
    
    if len(masks_true) != len(masks_pred):
        raise ValueError('metrics.average_precision requires len(masks_true)==len(masks_pred)')

    f1  = np.zeros((len(masks_true), len(threshold)), np.float32)
    tp  = np.zeros((len(masks_true), len(threshold)), np.float32)
    fp  = np.zeros((len(masks_true), len(threshold)), np.float32)
    fn  = np.zeros((len(masks_true), len(threshold)), np.float32)

    accuracy  = np.zeros((len(masks_true), len(threshold)), np.float32)
    precision = np.zeros((len(masks_true), len(threshold)), np.float32)
    recall    = np.zeros((len(masks_true), len(threshold)), np.float32)

    n_true = np.array(list(map(np.max, masks_true)))
    n_pred = np.array(list(map(np.max, masks_pred)))
    dice = []
    miou = []

    for n in range(len(masks_true)):
        tmp_dice = dice_score(masks_true[n], masks_pred[n])

        #_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape)
        if n_pred[n] > 0:
            iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:]
            iou_arranged = arrange_iou(iou)
            miou.append(np.sum(iou_arranged)/ np.max(masks_true[n]))
            for k,th in enumerate(threshold):
                tp[n,k] = _true_positive(iou, th)
        fp[n] = n_pred[n] - tp[n]
        fn[n] = n_true[n] - tp[n]

        # Computing F1 score
        f1[n] = 2 * tp[n] / (2 * tp[n] + fp[n] + fn[n])
        # Computing the precision
        precision[n] =  tp[n] / (tp[n] + fp[n])

        # Computing the recall
        recall[n]    =  tp[n] / (tp[n] + fn[n])

        # Computing accuracy
        accuracy[n] = tp[n] / (tp[n] + fp[n] + fn[n])  

        dice.append(tmp_dice)
    dice = np.mean(dice)
    miou = np.mean(miou) 
    if not_list:
        miou, dice, f1, accuracy, precision, recall = miou[0], dice[0], f1[0], accuracy[0], precision[0], recall[0]#tp[0], fp[0], fn[0]
    return miou, dice, f1, accuracy, precision, recall #tp, fp, fn

@jit(nopython=True)
def _label_overlap(x, y):
    """ fast function to get pixel overlaps between masks in x and y 
    
    Parameters
    ------------
    x: ND-array, int
        where 0=NO masks; 1,2... are mask labels
    y: ND-array, int
        where 0=NO masks; 1,2... are mask labels
    Returns
    ------------
    overlap: ND-array, int
        matrix of pixel overlaps of size [x.max()+1, y.max()+1]
    
    """
    # put label arrays into standard form then flatten them 
#     x = (utils.format_labels(x)).ravel()
#     y = (utils.format_labels(y)).ravel()
    x = x.ravel()
    y = y.ravel()
    
    # preallocate a 'contact map' matrix
    overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)
    
    # loop over the labels in x and add to the corresponding
    # overlap entry. If label A in x and label B in y share P
    # pixels, then the resulting overlap is P
    # len(x)=len(y), the number of pixels in the whole image 
    for i in range(len(x)):
        overlap[x[i],y[i]] += 1
    return overlap

def _intersection_over_union(masks_true, masks_pred):
    """ intersection over union of all mask pairs
    
    Parameters
    ------------
    
    masks_true: ND-array, int 
        ground truth masks, where 0=NO masks; 1,2... are mask labels
    masks_pred: ND-array, int
        predicted masks, where 0=NO masks; 1,2... are mask labels
    Returns
    ------------
    iou: ND-array, float
        matrix of IOU pairs of size [x.max()+1, y.max()+1]
    
    ------------
    How it works:
        The overlap matrix is a lookup table of the area of intersection
        between each set of labels (true and predicted). The true labels
        are taken to be along axis 0, and the predicted labels are taken 
        to be along axis 1. The sum of the overlaps along axis 0 is thus
        an array giving the total overlap of the true labels with each of
        the predicted labels, and likewise the sum over axis 1 is the
        total overlap of the predicted labels with each of the true labels.
        Because the label 0 (background) is included, this sum is guaranteed
        to reconstruct the total area of each label. Adding this row and
        column vectors gives a 2D array with the areas of every label pair
        added together. This is equivalent to the union of the label areas
        except for the duplicated overlap area, so the overlap matrix is
        subtracted to find the union matrix. 
    """
    overlap = _label_overlap(masks_true, masks_pred)
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    iou = overlap / (n_pixels_pred + n_pixels_true - overlap)
    iou[np.isnan(iou)] = 0.0
    return iou

def _true_positive(iou, th):
    """ true positive at threshold th
    
    Parameters
    ------------
    iou: float, ND-array
        array of IOU pairs
    th: float
        threshold on IOU for positive label
    Returns
    ------------
    tp: float
        number of true positives at threshold
        
    ------------
    How it works:
        (1) Find minimum number of masks
        (2) Define cost matrix; for a given threshold, each element is negative
            the higher the IoU is (perfect IoU is 1, worst is 0). The second term
            gets more negative with higher IoU, but less negative with greater
            n_min (but that's a constant...)
        (3) Solve the linear sum assignment problem. The costs array defines the cost
            of matching a true label with a predicted label, so the problem is to 
            find the set of pairings that minimizes this cost. The scipy.optimize
            function gives the ordered lists of corresponding true and predicted labels. 
        (4) Extract the IoUs fro these parings and then threshold to get a boolean array
            whose sum is the number of true positives that is returned. 
    """
    n_min = min(iou.shape[0], iou.shape[1])
    costs = -(iou >= th).astype(float) - iou / (2*n_min)
    true_ind, pred_ind = linear_sum_assignment(costs)
    match_ok = iou[true_ind, pred_ind] >= th
    tp = match_ok.sum()
    # print('tp : ', tp)
    return tp

def arrange_iou(iou):
    n_min = min(iou.shape[0], iou.shape[1])
    costs = -(iou >= 0.1).astype(float) - iou / (2*n_min)
    true_ind, pred_ind = linear_sum_assignment(costs)
    return iou[true_ind, pred_ind]

In [2]:
methods = ['cellpose', 'stardist','LSTM_bionet','XAI_bionet', 'LSTM_attunet','XAI_attunet', 'LSTM_unet','XAI_unet']
show_methods = ['Cellpose', 'Stardist','BioNet-LSTM','BioNet(XAI)', 'AttUNet-LSTM','AttUNet(XAI)', 'UNet-LSTM','UNet(XAI)']

my_dict_f1   = {"methods":[],"IoU threshold":[],"F1 score":[], "F1 score (std)":[],"Accuracy score":[], "Accuracy score (std)":[],"Precision score":[], "Precision score (std)":[],"Recall score":[], "Recall score (std)":[]};
my_dict_miou = {"methods":[],"mIoU (mean)":[], "mIoU (std)":[], "Dice (mean)":[], "Dice (std)":[]};

thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1]


for m in trange(len(methods)):
    # print(show_methods[m])
    mIoU_list, dice_list, f1_list, std_f1_list, acc_list, std_acc_list, pres_list, std_pres_list, rec_list, std_rec_list = [], [], [], [], [], [], [], [], [], []
    for fold in range(1,6):
        gt_masks = natsorted(glob(os.path.join('dataset', 'microglial_cells_test_sequences_1024', '*', '*_binary_mask.tif')))
        xai_unet_preds = natsorted(glob(os.path.join(methods[m], 'results', 'FOLD-'+str(fold), '*', '*.tif')))

        masks_true, masks_pred = [], []
        for i in range(len(gt_masks)):

            masks_true.append(label(np.array(Image.open(gt_masks[i]))))
            masks_pred.append(np.array(Image.open(xai_unet_preds[i])))
            
        mIoU, dice, f1, accuracy, precision, recall = eval_metrics(masks_true, masks_pred, threshold=thresholds)

        mIoU_list.append(mIoU)
        dice_list.append(dice)
        f1_list.append(np.mean(f1, axis=0))
        acc_list.append(np.mean(accuracy, axis=0))
        pres_list.append(np.mean(precision, axis=0))
        rec_list.append(np.mean(recall, axis=0))

    f1_list_tmp = np.mean(f1_list, axis=0)
    std_f1_list = np.std(f1_list, axis=0)

    acc_list_tmp = np.mean(acc_list, axis=0)
    std_acc_list = np.std(acc_list, axis=0)

    pres_list_tmp = np.mean(pres_list, axis=0)
    std_pres_list = np.std(pres_list, axis=0)

    rec_list_tmp = np.mean(rec_list, axis=0)
    std_rec_list = np.std(rec_list, axis=0)

    for k in range(len(f1_list_tmp)):
        my_dict_f1["methods"].append(show_methods[m])
        my_dict_f1["IoU threshold"].append(thresholds[k])

        my_dict_f1["F1 score"].append(round(f1_list_tmp[k]*100,2))
        my_dict_f1["F1 score (std)"].append(round(std_f1_list[k]*100,2))

        my_dict_f1["Accuracy score"].append(round(acc_list_tmp[k]*100,2))
        my_dict_f1["Accuracy score (std)"].append(round(std_acc_list[k]*100,2))

        my_dict_f1["Precision score"].append(round(pres_list_tmp[k]*100,2))
        my_dict_f1["Precision score (std)"].append(round(std_pres_list[k]*100,2))

        my_dict_f1["Recall score"].append(round(rec_list_tmp[k]*100,2))
        my_dict_f1["Recall score (std)"].append(round(std_rec_list[k]*100,2))
    
    my_dict_miou["methods"].append(show_methods[m])
    my_dict_miou["mIoU (mean)"].append(round(np.mean(mIoU_list)*100,2))
    my_dict_miou["mIoU (std)"].append(round(np.std(mIoU_list)*100,2))
    my_dict_miou["Dice (mean)"].append(round(np.mean(dice_list)*100,2))
    my_dict_miou["Dice (std)"].append(round(np.std(dice_list)*100,2))


100%|██████████| 8/8 [06:53<00:00, 51.72s/it]


In [3]:
import pandas as pd

df_metrics_f1 = pd.DataFrame.from_dict(my_dict_f1)
df_metrics_f1.to_csv('f1_metrics.csv')


df_metrics_miou = pd.DataFrame.from_dict(my_dict_miou)
df_metrics_miou.to_csv('miou_metrics.csv')

In [4]:
import pandas as pd

df_metrics_f1    = pd.read_csv('f1_metrics.csv')
df_metrics_miou  = pd.read_csv('miou_metrics.csv')
df_metrics_exe_time = pd.read_csv('execution_time_metrics.csv')

In [5]:
df_metrics_f1.loc[df_metrics_f1['IoU threshold'] == 0.5]

Unnamed: 0.1,Unnamed: 0,methods,IoU threshold,F1 score,F1 score (std),Accuracy score,Accuracy score (std),Precision score,Precision score (std),Recall score,Recall score (std)
0,0,Cellpose,0.5,83.35,1.41,72.22,1.99,91.3,1.75,77.09,1.24
11,11,Stardist,0.5,85.82,0.93,76.04,1.31,94.8,0.15,78.95,1.47
22,22,BioNet-LSTM,0.5,80.25,5.41,68.44,7.01,90.83,2.88,72.73,7.44
33,33,BioNet(XAI),0.5,81.43,4.13,70.09,5.33,88.67,3.67,75.99,5.95
44,44,AttUNet-LSTM,0.5,86.12,2.52,76.73,3.54,92.2,3.2,81.33,2.72
55,55,AttUNet(XAI),0.5,86.53,1.64,77.0,2.47,89.0,2.63,84.53,1.55
66,66,UNet-LSTM,0.5,86.9,1.8,77.72,2.53,93.47,1.61,81.66,2.2
77,77,UNet(XAI),0.5,85.44,2.12,75.72,2.92,87.95,2.87,83.56,2.46


In [6]:
df_metrics_miou

Unnamed: 0.1,Unnamed: 0,methods,mIoU (mean),mIoU (std),Dice (mean),Dice (std)
0,0,Cellpose,63.85,1.39,84.51,1.16
1,1,Stardist,66.21,0.96,87.86,0.52
2,2,BioNet-LSTM,63.94,6.48,92.2,2.78
3,3,BioNet(XAI),65.94,5.31,91.24,2.66
4,4,AttUNet-LSTM,71.17,2.77,91.87,2.6
5,5,AttUNet(XAI),73.55,1.41,93.77,0.34
6,6,UNet-LSTM,72.23,1.95,94.04,0.28
7,7,UNet(XAI),72.29,2.6,93.18,1.18


In [7]:
df_metrics_exe_time

Unnamed: 0.1,Unnamed: 0,methods,execution time (mean),execution time (std)
0,0,Cellpose,1.0328,0.0238
1,1,Stardist,0.2224,0.0173
2,2,BioNet-LSTM,0.3062,0.0062
3,3,BioNet(XAI),0.2038,0.0042
4,4,Att-UNet-LSTM,0.3692,0.0062
5,5,AttUNet(XAI),0.1573,0.0021
6,6,UNet-LSTM,0.3692,0.0062
7,7,UNet(XAI),0.1367,0.0069


In [8]:
df1 = df_metrics_miou.loc[df_metrics_miou['methods'] == 'Cellpose']
df2 = df_metrics_miou.loc[df_metrics_miou['methods'] == 'Stardist']

df3 = df_metrics_miou.loc[df_metrics_miou['methods'] == 'BioNet-LSTM']
df4 = df_metrics_miou.loc[df_metrics_miou['methods'] == 'BioNet(XAI)']

df5 = df_metrics_miou.loc[df_metrics_miou['methods'] == 'AttUNet-LSTM']
df6 = df_metrics_miou.loc[df_metrics_miou['methods'] == 'AttUNet(XAI)']

df7 = df_metrics_miou.loc[df_metrics_miou['methods'] == 'UNet-LSTM']
df8 = df_metrics_miou.loc[df_metrics_miou['methods'] == 'UNet(XAI)']

In [9]:
fdf1 = df_metrics_exe_time.loc[df_metrics_exe_time['methods'] == 'Cellpose']
fdf2 = df_metrics_exe_time.loc[df_metrics_exe_time['methods'] == 'Stardist']

fdf3 = df_metrics_exe_time.loc[df_metrics_exe_time['methods'] == 'BioNet-LSTM']
fdf4 = df_metrics_exe_time.loc[df_metrics_exe_time['methods'] == 'BioNet(XAI)']

fdf5 = df_metrics_exe_time.loc[df_metrics_exe_time['methods'] == 'Att-UNet-LSTM']
fdf6 = df_metrics_exe_time.loc[df_metrics_exe_time['methods'] == 'AttUNet(XAI)']

fdf7 = df_metrics_exe_time.loc[df_metrics_exe_time['methods'] == 'UNet-LSTM']
fdf8 = df_metrics_exe_time.loc[df_metrics_exe_time['methods'] == 'UNet(XAI)']

In [14]:
import plotly.graph_objects as go

colors = ["+",] * 8
colors[0], colors[1] = '', ''

y=[float(df1['mIoU (mean)']), float(df2['mIoU (mean)']), 
    float(df3['mIoU (mean)']), float(df4['mIoU (mean)']),
    float(df5['mIoU (mean)']), float(df6['mIoU (mean)']),
    float(df7['mIoU (mean)']), float(df8['mIoU (mean)'])]

fig = go.Figure(data=[go.Bar(
    x=['Cellpose', 'Stardist', 'BioNet-LSTM', 'BioNet(XAI)', 'AttUNet-LSTM', 'AttUNet(XAI)', 'UNet-LSTM', 'UNet(XAI)'],
    y =y,
    text=y,
    textposition='outside',
    error_y=dict(type='data', array=[
        float(df1['mIoU (std)']), float(df2['mIoU (std)']),
        float(df3['mIoU (std)']), float(df4['mIoU (std)']),
        float(df5['mIoU (std)']), float(df6['mIoU (std)']),
        float(df7['mIoU (std)']), float(df8['mIoU (std)'])]),
    marker_pattern_shape=colors # marker color can be a single color value or an iterable
)])
fig.update_yaxes(range=(0, 100))
fig.update_layout( yaxis_title='mIoU',width=600,
                   height=500)

fig.write_image("./evaluation_figures/mIoU.pdf")
fig.show()

In [11]:
import plotly.graph_objects as go

colors = ["+",] * 8
colors[0], colors[1] = '', ''
y=[float(fdf1['execution time (mean)']), float(fdf2['execution time (mean)']), 
       float(fdf3['execution time (mean)']), float(fdf4['execution time (mean)']),
       float(fdf5['execution time (mean)']), float(fdf6['execution time (mean)']),
       float(fdf7['execution time (mean)']), float(fdf8['execution time (mean)'])]
fig = go.Figure(data=[go.Bar(
    x=['Cellpose', 'Stardist', 'BioNet-LSTM', 'BioNet(XAI)', 'AttUNet-LSTM', 'AttUNet(XAI)', 'UNet-LSTM', 'UNet(XAI)'],
    y =y,
    text=y,
    textposition='outside',
    error_y=dict(type='data', array=[
        float(fdf1['execution time (std)']), float(fdf2['execution time (std)']),
        float(fdf3['execution time (std)']), float(fdf4['execution time (std)']),
        float(fdf5['execution time (std)']), float(fdf6['execution time (std)']),
        float(fdf7['execution time (std)']), float(fdf8['execution time (std)'])]),
    marker_pattern_shape=colors # marker color can be a single color value or an iterable
)])
# fig.update_yaxes(range=(0, 1.2))
fig.update_layout( yaxis_title='Execution time per image (seconds)',width=600,
                   height=500)

fig.write_image("./evaluation_figures/exe_time.pdf")
fig.show()

In [12]:
import plotly.graph_objects as go

colors = ["+",] * 6

fig = go.Figure(data=[go.Bar(
    x=['BioNet-LSTM', 'BioNet(XAI)', 'AttUNet-LSTM', 'AttUNet(XAI)', 'UNet-LSTM', 'UNet(XAI)'],
    y=[60234946,8420065, 54168494, 7571033, 50325506, 7023385],
    text=[60234946,8420065, 54168494, 7571033, 50325506, 7023385],
    textposition='outside',
    marker_pattern_shape=colors # marker color can be a single color value or an iterable
)])
fig.update_layout( yaxis_title='Number of parameters',width=600, height=500)

fig.write_image('./evaluation_figures/LSTM_vs_XAI_params.pdf')
fig.show()

In [13]:

cellpose_idx = df_metrics_f1['methods']=='Cellpose'
stardist_idx = df_metrics_f1['methods']=='Stardist'

lstm_bio_idx = df_metrics_f1['methods']=='LSTM-BioNet'
xai_bio_idx = df_metrics_f1['methods'] =='BioNet(XAI)'

lstm_attunet_idx = df_metrics_f1['methods']=='AttUNet-LSTM'
xai_attunet_idx = df_metrics_f1['methods'] =='AttUNet(XAI)'

lstm_unet_idx = df_metrics_f1['methods']=='UNet-LSTM'
xai_unet_idx = df_metrics_f1['methods'] =='UNet(XAI)'

fig = go.Figure()

fig.add_trace(go.Scatter(x=list(df_metrics_f1[cellpose_idx]['IoU threshold']), y=list(df_metrics_f1[cellpose_idx]['Accuracy score']), name='CellPose', line = dict(width=4)
            , error_y=dict(type='data',array=list(df_metrics_f1[cellpose_idx]['Accuracy score (std)']),visible=True)))

fig.add_trace(go.Scatter(x=list(df_metrics_f1[stardist_idx]['IoU threshold']), y=list(df_metrics_f1[stardist_idx]['Accuracy score']), name='StarDist', line = dict(width=4)
            , error_y=dict(type='data',array=list(df_metrics_f1[stardist_idx]['Accuracy score (std)']),visible=True)))

fig.add_trace(go.Scatter(x=list(df_metrics_f1[xai_attunet_idx]['IoU threshold']), y=list(df_metrics_f1[xai_attunet_idx]['Accuracy score']), name='Ours:Att-Unet(XAI)', line = dict(width=4, dash='dash')
               , error_y=dict(type='data',array=list(df_metrics_f1[xai_attunet_idx]['Accuracy score (std)']),visible=True)))

fig.update_yaxes(range=(0, 100))
fig.update_xaxes(range=(0.49, 1))
fig.update_layout( # title='The influence of the IoU threshold on F1 score',
                   xaxis_title='IoU threshold',
                   yaxis_title='Accuracy score',
                   autosize=False,
                   width=1200,
                   height=500,
                   legend=dict(
                   yanchor="bottom",
                   y=0.03,
                   xanchor="left",
                   x=0.03
))
fig.write_image("./evaluation_figures/accuracy_score.pdf")
fig.show()