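# Comparison of DCNN training weights

Sørensen–Dice coefficients (DSC) between the manually drawn X-masks and the DCNN-predicted X-masks of the HCP and CHIASM control subjects, compared across training weights, connectivity types and cutoff thresholds.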

In [1]:
import glob
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from nilearn.image import resample_img
import os
from scipy.ndimage.morphology import binary_erosion

import scipy.stats as stats

import seaborn as sbs

import pickle

In [2]:
# Sorensen-Dice similarity between two masks, 2*sum(A*B) / (sum(A) + sum(B)),
# with an optional additive smoothing term in numerator and denominator.
def dice(img_true, img_pred, smooth=0):
    """Return the Sorensen-Dice coefficient of two arrays.

    Parameters
    ----------
    img_true, img_pred : array-like
        Reference and predicted masks. They are multiplied element-wise,
        so their shapes must match (or broadcast).
    smooth : float, optional
        Added to both numerator and denominator. With the default of 0,
        two all-zero inputs give 0/0, i.e. NaN.

    Returns
    -------
    float
        (2 * sum(img_true * img_pred) + smooth)
        / (sum(img_true) + sum(img_pred) + smooth)
    """
    denominator = np.sum(img_true) + np.sum(img_pred)
    numerator = 2. * np.sum(img_true * img_pred)
    return (numerator + smooth) / (denominator + smooth)
                    
def calculate_dice(path1, path2, interpolation='linear'):
    """Load a reference mask and a predicted mask and return their Dice coefficient.

    The reference image (``path1``) is resampled onto the voxel grid
    (affine and shape) of the prediction (``path2``), so that both arrays
    can be compared voxel by voxel with ``dice``.

    Parameters
    ----------
    path1 : str
        Path to the reference (manually drawn) NIfTI mask.
    path2 : str
        Path to the predicted NIfTI mask; its affine and shape define the
        grid used for the comparison.
    interpolation : str, optional
        Interpolation mode passed to ``nilearn.image.resample_img`` when
        resampling the reference. Defaults to ``'linear'``.
        NOTE(review): linear interpolation can produce fractional values at
        mask borders whenever the two grids differ; ``'nearest'`` keeps a
        binary mask binary.

    Returns
    -------
    float
        The value returned by ``dice`` for the resampled reference data
        and the prediction data.
    """
    img_target = nib.load(path1)
    img_prediction = nib.load(path2)
    # Only the reference needs resampling. Resampling the prediction onto its
    # own affine and shape would be a no-op, so its data is used directly.
    img_target_on_pred_grid = resample_img(
        img_target,
        target_affine=img_prediction.affine,
        target_shape=img_prediction.shape,
        interpolation=interpolation,
    )
    return dice(img_target_on_pred_grid.get_fdata(), img_prediction.get_fdata())

In [3]:
# Data groups
groups = ['HCP', 'CHIASM']

# Participant identifiers per group
CHIASM_con = ['CON%d' % idx for idx in range(1, 9)]    # CON1 .. CON8
CHIASM_alb = ['ALB%d' % idx for idx in range(1, 10)]   # ALB1 .. ALB9
HCP_con = [
    '101107', '118730', '131823', '134223', '151425',
    '165436', '208226', '304727', '379657', '673455',
]

# Root folder of the manual and CNN-predicted masks (trailing slash is
# relied on by the string concatenation that builds file paths)
data_folder = '../../1_Data/'

In [4]:
# Parameters determining the DCNN prediction. Each value is used verbatim as
# part of a folder name under 5_X-mask_CNN/:
#   training_<weight>/connectivity_<connectivity>/threshold_<threshold>/
weights = [
    '13ep_00025lr_dice',
    '15ep_0003lr_dice',
    '30ep_00025lr_dice',
    '40ep_00015lr_dice',
    '100ep_00005lr_dice',
]
connectivity_type = '1 2 3'.split()
cutoff_threshold = '0.25 0.5 0.75 1'.split()

In [5]:
# DSC between the manual X-mask and the DCNN prediction for every HCP control,
# nested as HCP_qa[connectivity][weight][threshold][subject].
HCP_qa = {}

for connectivity in connectivity_type:
    dsc_by_weight = HCP_qa[connectivity] = {}
    for weight in weights:
        dsc_by_thr = dsc_by_weight[weight] = {}
        for thr in cutoff_threshold:
            dsc_by_sub = dsc_by_thr[thr] = {}
            for sub in HCP_con:
                hand = data_folder + '2_X-mask_manual/HCP/' + sub + '/X-mask_manual.nii.gz'
                pred = (data_folder + '5_X-mask_CNN/training_' + weight
                        + '/connectivity_' + connectivity
                        + '/threshold_' + thr
                        + '/HCP/' + sub + '/X-mask_CNN_cropped_to_gt.nii.gz')
                dsc_by_sub[sub] = calculate_dice(hand, pred)

In [6]:
# DSC between the manual X-mask and the DCNN prediction for every CHIASM
# control, nested as CHIASM_qa[connectivity][weight][threshold][subject].
CHIASM_qa = {}

for connectivity in connectivity_type:
    dsc_by_weight = CHIASM_qa[connectivity] = {}
    for weight in weights:
        dsc_by_thr = dsc_by_weight[weight] = {}
        for thr in cutoff_threshold:
            dsc_by_sub = dsc_by_thr[thr] = {}
            for sub in CHIASM_con:
                hand = data_folder + '2_X-mask_manual/CHIASM/' + sub + '/X-mask_manual.nii.gz'
                pred = (data_folder + '5_X-mask_CNN/training_' + weight
                        + '/connectivity_' + connectivity
                        + '/threshold_' + thr
                        + '/CHIASM/' + sub + '/X-mask_CNN_cropped_to_gt.nii.gz')
                dsc_by_sub[sub] = calculate_dice(hand, pred)

In [7]:
# Mean DSC for HCP: one table per connectivity type. Rows are training
# weights; columns follow cutoff_threshold (any number of thresholds), and
# each entry is the mean DSC over the HCP subjects.
for connectivity in connectivity_type:
    print('connectivity_%-7s' % connectivity + ''.join('%-20s' % thr for thr in cutoff_threshold))
    for weight in weights:
        row = [np.mean(list(HCP_qa[connectivity][weight][thr].values())) for thr in cutoff_threshold]
        print('%-20s' % weight + ''.join('%-20f' % value for value in row))
    print('')

connectivity_1      0.25                0.5                 0.75                1                   
13ep_00025lr_dice   0.766978            0.770998            0.774405            0.810526            
15ep_0003lr_dice    0.758212            0.761249            0.766351            0.785145            
30ep_00025lr_dice   0.748503            0.751528            0.755451            0.794636            
40ep_00015lr_dice   0.776841            0.783669            0.790386            0.721510            
100ep_00005lr_dice  0.766088            0.765429            0.766529            0.758185            

connectivity_2      0.25                0.5                 0.75                1                   
13ep_00025lr_dice   0.766978            0.770998            0.774405            0.810526            
15ep_0003lr_dice    0.758212            0.761249            0.766180            0.785145            
30ep_00025lr_dice   0.748503            0.751528            0.755451            0.794636  

In [8]:
# SEM of the DSC for HCP: standard error of the mean over the HCP subjects
# (scipy.stats.sem, ddof=1 by default) -- this is NOT the standard deviation.
# One table per connectivity type; rows are training weights, columns follow
# cutoff_threshold (any number of thresholds).
for connectivity in connectivity_type:
    print('connectivity_%-7s' % connectivity + ''.join('%-20s' % thr for thr in cutoff_threshold))
    for weight in weights:
        row = [stats.sem(list(HCP_qa[connectivity][weight][thr].values())) for thr in cutoff_threshold]
        print('%-20s' % weight + ''.join('%-20f' % value for value in row))
    print('')

connectivity_1      0.25                0.5                 0.75                1                   
13ep_00025lr_dice   0.021901            0.021156            0.020869            0.017743            
15ep_0003lr_dice    0.022031            0.022153            0.020911            0.016619            
30ep_00025lr_dice   0.022135            0.021953            0.021605            0.020032            
40ep_00015lr_dice   0.024265            0.024044            0.023328            0.042282            
100ep_00005lr_dice  0.016882            0.017978            0.018524            0.022553            

connectivity_2      0.25                0.5                 0.75                1                   
13ep_00025lr_dice   0.021901            0.021156            0.020869            0.017743            
15ep_0003lr_dice    0.022031            0.022153            0.020872            0.016619            
30ep_00025lr_dice   0.022135            0.021953            0.021605            0.020032  

In [9]:
# Mean DSC for CHIASM: one table per connectivity type. Rows are training
# weights; columns follow cutoff_threshold (any number of thresholds), and
# each entry is the mean DSC over the CHIASM subjects.
for connectivity in connectivity_type:
    print('connectivity_%-7s' % connectivity + ''.join('%-20s' % thr for thr in cutoff_threshold))
    for weight in weights:
        row = [np.mean(list(CHIASM_qa[connectivity][weight][thr].values())) for thr in cutoff_threshold]
        print('%-20s' % weight + ''.join('%-20f' % value for value in row))
    print('')

connectivity_1      0.25                0.5                 0.75                1                   
13ep_00025lr_dice   0.579935            0.560515            0.545142            0.245325            
15ep_0003lr_dice    0.088544            0.082007            0.077209            0.022426            
30ep_00025lr_dice   0.777995            0.778557            0.774959            0.747723            
40ep_00015lr_dice   0.601812            0.601190            0.601547            0.423171            
100ep_00005lr_dice  0.396440            0.386062            0.378783            0.253640            

connectivity_2      0.25                0.5                 0.75                1                   
13ep_00025lr_dice   0.584541            0.570050            0.553982            0.251205            
15ep_0003lr_dice    0.096688            0.089928            0.082181            0.027490            
30ep_00025lr_dice   0.777906            0.778557            0.774959            0.747948  

In [10]:
# SEM of the DSC for CHIASM: standard error of the mean over the CHIASM
# subjects (scipy.stats.sem, ddof=1 by default) -- this is NOT the standard
# deviation. One table per connectivity type; rows are training weights,
# columns follow cutoff_threshold (any number of thresholds).
for connectivity in connectivity_type:
    print('connectivity_%-7s' % connectivity + ''.join('%-20s' % thr for thr in cutoff_threshold))
    for weight in weights:
        row = [stats.sem(list(CHIASM_qa[connectivity][weight][thr].values())) for thr in cutoff_threshold]
        print('%-20s' % weight + ''.join('%-20f' % value for value in row))
    print('')

connectivity_1      0.25                0.5                 0.75                1                   
13ep_00025lr_dice   0.040836            0.043856            0.045127            0.059391            
15ep_0003lr_dice    0.027205            0.025332            0.023801            0.011964            
30ep_00025lr_dice   0.028792            0.029290            0.029248            0.025780            
40ep_00015lr_dice   0.032322            0.032600            0.031363            0.039250            
100ep_00005lr_dice  0.053125            0.053898            0.053029            0.056425            

connectivity_2      0.25                0.5                 0.75                1                   
13ep_00025lr_dice   0.039879            0.042212            0.043619            0.056282            
15ep_0003lr_dice    0.027966            0.026181            0.023635            0.011662            
30ep_00025lr_dice   0.028829            0.029290            0.029248            0.025722  