In [None]:
import os
from os.path import join as pjoin
import pdb

import plotly.graph_objs as go
import plotly as py
import ipywidgets as widgets
py.offline.init_notebook_mode(connected=True)

import numpy as np
from sklearn.metrics import confusion_matrix
import pyift.pyift as ift

In [None]:
# Per-image evaluation: for every supervoxel that overlaps the (mirrored) GT
# lesion mask, record overlap statistics, split by whether the supervoxel
# label appears in the detection result.
result_dir = "../exps/SAAD"
gt_dir = "../bases/ATLAS-304/3T/regs/nonrigid/labels/primary_stroke"

# Keys: 1 -> supervoxel label present in result.nii.gz ("detected"), 0 -> absent.
tpr_dict = {0: [], 1: []}
precision_dict = {0: [], 1: []}
fdr_dict = {0: [], 1: []}
svoxel_vols_dict = {0: [], 1: []}
dist_to_border_dict = {0: [], 1: []}  # NOTE(review): not filled by this cell

# Each image ID is a sub-directory of result_dir; skip stray files (logs, etc.).
img_ids_list = sorted(
    entry for entry in os.listdir(result_dir)
    if os.path.isdir(pjoin(result_dir, entry))
)

for idx, img_id in enumerate(img_ids_list):
    print(f"[{idx}/{len(img_ids_list) - 1}] Image ID: {img_id}")

    svoxels_path = pjoin(result_dir, img_id, "svoxels.nii.gz")
    result_path = pjoin(result_dir, img_id, "result.nii.gz")
    gt_path = pjoin(gt_dir, img_id + ".nii.gz")

    gt_img = ift.ReadImageByExt(gt_path)
    # Midpoint of the image x-extent; everything from this index on along the
    # last array axis is discarded below.
    # NOTE(review): assumes AsNumPy() is ordered (z, y, x) — confirm.
    sagittal_slice = gt_img.xsize // 2

    # Union of the GT and its x-mirrored copy, kept only in the first half.
    # logical_or (rather than `+`) cannot wrap around in narrow integer
    # dtypes, and it yields a bool array directly — the former
    # `.astype(np.bool)` fails on NumPy >= 1.24 where that alias was removed.
    gt_img_flip = ift.FlipImage(gt_img, ift.IFT_AXIS_X)
    gt_mask = np.logical_or(gt_img.AsNumPy(), gt_img_flip.AsNumPy())
    gt_mask[:, :, sagittal_slice:] = False
    gt_volume = np.count_nonzero(gt_mask)

    # Apply the same half-volume crop to supervoxels and detection result.
    # NOTE(review): assumes both are on the same grid as the GT image.
    svoxels_data = ift.ReadImageByExt(svoxels_path).AsNumPy()
    svoxels_data[:, :, sagittal_slice:] = 0
    result_data = ift.ReadImageByExt(result_path).AsNumPy()
    result_data[:, :, sagittal_slice:] = 0

    # Supervoxel labels touching the GT mask, excluding the background label 0.
    target_svoxels = np.unique(svoxels_data[gt_mask])
    target_svoxels = target_svoxels[target_svoxels != 0]

    for svoxel in target_svoxels:
        print(f"\tsvoxel = {svoxel}")
        target_svoxel_mask = svoxels_data == svoxel

        # Confusion counts of the supervoxel mask (prediction) vs. the GT mask.
        # tp > 0 because the label was selected for overlapping the GT, so
        # none of the ratios below can divide by zero.
        tp = np.count_nonzero(target_svoxel_mask & gt_mask)
        svoxel_vol = np.count_nonzero(target_svoxel_mask)  # == tp + fp
        fp = svoxel_vol - tp
        fn = gt_volume - tp

        tpr = tp / (tp + fn)  # true positive rate (recall)
        precision = tp / (tp + fp)
        fdr = fp / (fp + tp)  # false discovery rate (== 1 - precision)
        # Plain int so it indexes the {0, 1}-keyed dicts unambiguously.
        was_detected = int(np.any(result_data == svoxel))

        tpr_dict[was_detected].append(tpr)
        precision_dict[was_detected].append(precision)
        fdr_dict[was_detected].append(fdr)
        svoxel_vols_dict[was_detected].append(svoxel_vol)


In [None]:
# Parallel-coordinates view of the supervoxels that overlap the GT and WERE
# detected (key 1 of the metric dicts); one line per supervoxel.
# Precision/TPR axes are pinned to [0, 1] so this figure is comparable with
# the non-detected one.
# FDR is intentionally not plotted: fdr = fp / (fp + tp) = 1 - precision, so
# that axis would only mirror the precision axis.
data = [
    go.Parcoords(
        line=dict(color='#33a067'),
        dimensions=[
            dict(
                label='SVoxel Volume',
                values=svoxel_vols_dict[1],
                tickformat=".2f"),
            dict(
                label='Perc. SVoxel on GT (Precision)',
                values=precision_dict[1],
                range=[0.0, 1.0],
                tickformat=".2f"),
            dict(
                label='Perc. Intersection SVoxel and GT (TPR)',
                values=tpr_dict[1],
                range=[0.0, 1.0],
                tickformat=".2f"),
        ],
    )
]

layout = go.Layout(
    title="Detected Supervoxels on GT"
)

fig = go.Figure(data=data, layout=layout)
py.offline.iplot(fig)

In [None]:
# Parallel-coordinates view of the supervoxels that overlap the GT but were
# NOT detected (key 0 of the metric dicts); one line per supervoxel.
# Precision/TPR axes are pinned to [0, 1] — matching the detected-supervoxel
# plot — so both figures share the same scale instead of being auto-ranged
# independently. Both metrics lie in [0, 1] by construction, so nothing is clipped.
# FDR is intentionally not plotted: fdr = fp / (fp + tp) = 1 - precision, so
# that axis would only mirror the precision axis.
data = [
    go.Parcoords(
        line=dict(color='#ee6e73'),
        dimensions=[
            dict(
                label='SVoxel Volume',
                values=svoxel_vols_dict[0],
                tickformat=".2f"),
            dict(
                label='Perc. SVoxel on GT (Precision)',
                values=precision_dict[0],
                range=[0.0, 1.0],
                tickformat=".2f"),
            dict(
                label='Perc. Intersection SVoxel and GT (TPR)',
                values=tpr_dict[0],
                range=[0.0, 1.0],
                tickformat=".2f"),
        ],
    )
]

layout = go.Layout(
    title="Non-Detected Supervoxels on GT"
)

fig = go.Figure(data=data, layout=layout)
py.offline.iplot(fig)