### Review results
This notebook evaluates predictions against ground truth segmentation masks; it contains utilities for computing the Dice score and visualizing results.

In [20]:
import numpy as np
import os
import sys
from shared_utils import *
from tqdm import tqdm
from matplotlib import pyplot as plt
sys.path.append("../dauphin/image_segmentation/datasets")
from utils import custom_preproc_op
import seaborn as sns
import pandas as pd
import yaml


In [21]:
# Register the project's four-colour palette as the seaborn default
colors = [
    '#1147A1',
    '#5DC1EE',
    '#89C579',
    '#DE6400',
]
# set_palette accepts the list of hex codes directly
sns.set_palette(colors)
# Last expression: show the active palette as the cell output
sns.color_palette()


### Set data to review - should only need to modify this cell for custom datasets

In [22]:
# --- Which dataset / split to evaluate ---
task = 'acdc'   # Task name; must match the directory name under which the data is stored
split = 'test'  # Data split to evaluate (train, val, test)

# --- Input locations, all derived from task and split ---
task_dir = f'../data/{task}'
data_path = f'{task_dir}/{split}'          # Where the image data is stored
config_file = f'{task_dir}/config.yaml'    # Config file for this dataset
gt_seg_path = f'{task_dir}/{split}'        # Ground truth masks; typically the same as data_path

# --- Prediction directories (list): add more entries to evaluate several prediction dirs ---
pred_seg_paths = [
    '/'.join([
        f'../logs/{task}_logs',
        'UNet',
        'gt_train_labels',
        'augment_4',
        'with_consistency_loss',
        'without_uncertainty_thresh',
        'csv_acdc_5-keys_all-slice-per-key_seed-1',
        'seed_1',
        'final_preds',
        split,
    ]),
]

# --- Target size (rows, cols); should equal the values used to train the seg networks ---
img_size_r = 224
img_size_c = 224

print('Done.')

Done.


### Load data

In [23]:
def _list_keys(directory, pattern):
    """Return the set of sample keys for files in `directory` matching `pattern`.

    A key is the file's base name with its last '_'-separated token removed
    (e.g. 'patient001_frame01_image.npy' -> 'patient001_frame01').
    """
    return {
        '_'.join(os.path.basename(path).split('_')[:-1])
        for path in glob.glob(os.path.join(directory, pattern))
    }


# NOTE(review): `glob` is not imported explicitly in this notebook; presumably it
# arrives through `from shared_utils import *` — confirm.

input_data = []
gt_data = []
pred_data = {pred_name: [] for pred_name in pred_seg_paths}

# Keep only keys that have an image AND a prediction in every prediction dir.
# Ground truth availability is not checked here; see the fallback below.
image_keys = _list_keys(data_path, '*_image*')
print('Number of image keys available:', len(image_keys))
common_keys = set(image_keys)
for pred_path in pred_seg_paths:
    common_keys &= _list_keys(pred_path, '*_seg*')
# glob returns files in arbitrary order; sort so that sample indices are reproducible
successful_keys = sorted(common_keys)
print('Number of keys that also have predictions saved:', len(successful_keys))

# Load images and ground truth masks
missing_gt_keys = []
print('Retreiving image and gt segs')
for key in tqdm(successful_keys):
    image = np.load(os.path.join(data_path, key + '_image.npy'))
    input_data.append(image)
    try:
        gt_data.append(np.load(os.path.join(gt_seg_path, key + '_seg.npy')))
    except FileNotFoundError:
        # Best-effort fallback: a missing ground truth file becomes an all-zero mask.
        # Any other failure (corrupt file, interrupt, ...) is allowed to surface.
        gt_data.append(np.zeros_like(image))
        missing_gt_keys.append(key)
if missing_gt_keys:
    print('Warning: no ground truth file for', len(missing_gt_keys),
          'key(s); all-zero masks are used for:', missing_gt_keys)

# Load predictions and collapse the last axis with argmax
# (presumably per-class scores along the last axis — confirm against the saving code)
print('Retreiving predictions')
for pred_path in pred_seg_paths:
    for key in tqdm(successful_keys):
        pred_scores = np.load(os.path.join(pred_path, key + '_seg.npy'))
        pred_data[pred_path].append(np.argmax(pred_scores, -1))

# Load config yaml
with open(config_file, "r") as cf:
    config = yaml.safe_load(cf)

print('Lengths of input, ground truth, and all predicted data:',len(input_data),len(gt_data),[len(data) for data in pred_data.values()])

print('Done.')

Number of image keys available: 30
Number of keys that also have predictions saved: 30
Retreiving image and gt segs


100%|██████████| 30/30 [00:00<00:00, 1212.85it/s]


Retreiving predictions


100%|██████████| 30/30 [00:00<00:00, 49.27it/s]

Lengths of input, ground truth, and all predicted data: 30 30 [30]
Done.





### Transform data

In [None]:
# Bring every image, ground truth mask, and prediction to the same (rows, cols) size.
target_size = (img_size_r, img_size_c)

for sample_idx in tqdm(range(len(input_data))):
    input_data[sample_idx] = custom_preproc_op(input_data[sample_idx], *target_size)
    # Label maps are resized with order=0
    # (presumably nearest-neighbour, so class ids are not blended — confirm in custom_preproc_op)
    gt_data[sample_idx] = custom_preproc_op(gt_data[sample_idx], *target_size, order=0)
    for pred_name in pred_seg_paths:
        pred_data[pred_name][sample_idx] = custom_preproc_op(
            pred_data[pred_name][sample_idx], *target_size, order=0
        )

print('Done.')

### Compute quantitative metrics

In [None]:
# Per-sample 3D Dice for every (prediction set, class name) pair.
# Class index 0 is skipped (the loop starts at 1).
label_dict = {class_idx: class_name for class_name, class_idx in config['label_mapping'].items()}
evaluated_classes = range(1, config['num_classes'])

dice_3d = {}
for preds_id, pred_masks in pred_data.items():
    for class_ind in evaluated_classes:
        # One Dice value per sample, comparing binary masks of this class
        dice_3d[(preds_id, label_dict[class_ind])] = [
            compute_dice(gt_data[ind] == class_ind, pred_mask == class_ind)
            for ind, pred_mask in enumerate(pred_masks)
        ]

print('Done.')

In [None]:
# Summarise the Dice distribution of every (prediction set, class) pair
print('Segmentation scores:')
for k, v in dice_3d.items():
    preds_id, class_label = k
    print('\nPredictions:', preds_id)
    print('Class:', class_label)

    # Each statistic is rounded to three decimals right before it is printed
    mean_dice, median_dice = (np.around(stat(v), 3) for stat in (np.mean, np.median))
    print('\t 3D mean, median dice:', mean_dice, median_dice)

    std_dice = np.around(np.std(v), 3)
    print('\t 3D std dice:', std_dice)

    min_dice, max_dice = (np.around(stat(v), 3) for stat in (np.min, np.max))
    print('\t 3D min, max dice:', min_dice, max_dice)

print('\nDone.')

In [None]:
# Plot one Dice histogram per (prediction set, class name) pair

for task_class, dice_scores in dice_3d.items():
    pred_name, class_name = task_class

    dice3d_df = pd.DataFrame({'Dice': dice_scores})

    fig, ax = plt.subplots(1, 1, figsize=(7, 5))

    # Draw on the axes created above instead of relying on the implicit current axes
    dice_hist = sns.histplot(data=dice3d_df, x='Dice', ax=ax)
    dice_hist.set_xlabel("Dice", fontsize=14)
    dice_hist.set_ylabel("Count", fontsize=14)
    dice_hist.set_xlim(0, 1)
    # Title comes from the current loop key, so each figure is labelled with its own
    # class / prediction set rather than with a variable left over from another cell.
    dice_hist.set_title(f'{class_name}\n{pred_name}', fontsize=10, wrap=True)
    dice_hist.tick_params(labelsize=14)

    plt.show()
    # Close this specific figure so figures do not accumulate across iterations
    plt.close(fig)


### Plot some example results that achieve min, median, and max Dice scores

In [None]:

# For every (prediction set, class name) pair, show the samples with the lowest,
# highest, and middle Dice scores as true-positive / false-positive / false-negative overlays.
num_per = 3  # number of examples shown per group (min / max / median)

for task_class, all_dice_scores in dice_3d.items():
    pred_name, class_name = task_class
    print('\nPredictions:', pred_name)
    print('Class:', class_name)

    # Sample indices ordered by ascending Dice (stable: ties keep index order)
    sorted_inds = sorted(range(len(all_dice_scores)), key=lambda i: all_dice_scores[i])

    min_ind = sorted_inds[0:num_per]
    max_ind = sorted_inds[-num_per:]
    # Exactly num_per indices centred on the middle of the ranking;
    # the start is clamped so tiny inputs never produce a negative slice start.
    med_start = max(0, len(sorted_inds) // 2 - num_per // 2)
    med_ind = sorted_inds[med_start:med_start + num_per]

    group_labels = ['Min Dice'] * len(min_ind) + ['Max Dice'] * len(max_ind) + ['Median Dice'] * len(med_ind)
    class_index = config['label_mapping'][class_name]

    for img_ind, img_id in zip(min_ind + max_ind + med_ind, group_labels):

        # Binary masks of the current class
        pred = pred_data[pred_name][img_ind] == class_index
        gt = gt_data[img_ind] == class_index
        tps = np.logical_and(pred, gt)                   # predicted and present
        fns = np.logical_and(np.logical_not(pred), gt)   # present but not predicted
        fps = np.logical_and(pred, np.logical_not(gt))   # predicted but not present

        img_key = successful_keys[img_ind]
        plot_classes(input_data[img_ind],tps,fps,fns,size=10,title=(' ').join([img_id,img_key,'Dice='+str(np.around(all_dice_scores[img_ind],3))]))


