In [1]:
import nibabel as nib
import os
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
import pandas as pd
import random
import seaborn as sns

### 0 Preparation
#### 0.1 Select out scans for test
* select AD/HC dev and test scans
* round the age
* if there are multiple scans at an age, only keep the one cloest to the rounded age
* restrict age range: 60 ~ 80 (with good template quality)
* select subjects with multiple scans
#### 0.2 Select reference scan
* randomly select a source (reference) scan for each subject, and the others are destination scans
<br>    if there are 3 or more scans for a subejct, randomly select a middle one as the reference

In [148]:
def filter_csv(disease_condn='HC', age_min=60, age_max=80, csv_path='./FS_scans_metadata_with_split.csv', save_path='./'):
    
    disease_code = 0 if disease_condn == 'HC' else 1
    df = pd.read_csv(csv_path)

    ids_set = df['ID'][df['Partition'].isin(['dev', 'test'])][df['HC0/AD1'] == disease_code]
    df = df[df['ID'].isin(ids_set)]

    df['age_rounded'] = df['age'].round().astype('int')

    for i in ids_set:
        df.loc[df['ID'] == i, 'subject'] = i[: 8]

    for s in df['subject']:
        age_set = df['age_rounded'][df['subject'] == s]
        age_idx = age_set.index

        for a in range(len(age_set) - 1):
            if age_set[age_idx[a]] == age_set[age_idx[a + 1]]:
                if abs(df['age'][age_idx[a]] - df['age_rounded'][age_idx[a]]) < abs(df['age'][age_idx[a + 1]] - df['age_rounded'][age_idx[a + 1]]):
                    df = df.drop(index=age_idx[a + 1])
                else:
                    df = df.drop(index=age_idx[a])

    for i in df['ID']:
        df.loc[df['ID'] == i, 'uid'] = i.split('_')[0] + '_' + i.split('_')[2]

    df = df[(df['age_rounded'] >= age_min) & (df['age_rounded'] <= age_max)]

    cnt_result = df['subject'].value_counts()
    for s in df['subject']:
        df.loc[df['subject'] == s, 'subject_cnt'] = cnt_result[s]
    df['subject_cnt'] = df['subject_cnt'].astype('int')
    df = df[df['subject_cnt'] > 1]

    num_of_scans = len(df['subject'])
    num_of_subjects = len(df['subject'].value_counts())
    print(f'num of {disease_condn} subjects: ', num_of_subjects)
    print(f'num of {disease_condn} scans: ', num_of_scans)

    random.seed(90)
    subject_list = df['subject'].value_counts().index
    for s in subject_list:
        age_list = df['age_rounded'][df['subject'] == s]
        if len(age_list) == 2:
            rdm_src = int(np.array(random.sample(list(age_list), 1)))
        elif len(age_list) == 3:
            rdm_src = int(np.array(list(age_list)[1]))
        else:
            rdm_src = int(np.array(random.sample(list(age_list)[1: -1], 1)))
            
        df.loc[(df['subject'] == s) & (df['age_rounded'] == rdm_src), 'tag'] = 'src'

    df.loc[df['tag'] != 'src', 'tag'] = 'dst'

    save_name = os.path.join(save_path, 'FS_scans_test_{}.csv'.format(disease_condn))
    pd.DataFrame(df).to_csv(save_name, index=False)

    return df

In [None]:
# HC: 33 subjects, 70 scans
# AD: 7 subjects, 16 scans
csv_path = './FS_scans_metadata_with_split.csv'
df = filter_csv(disease_condn='AD', age_min=60, age_max=80, csv_path=csv_path, save_path='./0418/')

In [152]:
csv_path = './0418/FS_scans_test_HC.csv'
df = pd.read_csv(csv_path)

In [194]:
def construct_subject_list(csv_path):

    df = pd.read_csv(csv_path)
    subject_list = df['subject'].value_counts().index
    test_subject_list = []
    for s in subject_list:
        src_age = df['age_rounded'][df['subject'] == s][df['tag'] == 'src'].tolist()[0]
        src_subject = df['uid'][df['subject'] == s][df['tag'] == 'src'].tolist()[0]

        dst_age = df['age_rounded'][df['subject'] == s][df['tag'] == 'dst']
        for d in dst_age:
            dst_subject = df['uid'][df['subject'] == s][df['age_rounded'] == d].tolist()[0]
            test_subject_list.append([s, src_subject, src_age, dst_subject, d])
    
    return test_subject_list

### 1 Test a single subject


In [27]:
def load_nii(path):
    img = nib.load(path)
    return(img.get_fdata())


def mha2nii(mha_path, nii_path):
    img = sitk.ReadImage(mha_path)
    sitk.WriteImage(img, nii_path)


def np2nii(data, save_path, save_name):
    img = nib.Nifti1Image(data, np.eye(4))

    img.header.get_xyzt_units()
    img.to_filename(os.path.join(save_path, save_name))

In [None]:
d_path = '/users/yuqiz/downloads/test_0310/0418/demo_inv_0425/T71toOAS30671_0_vel.nii.gz'
d_inv_path = '/users/yuqiz/downloads/test_0310/0418/demo_inv_0425/T71toOAS30671_0_inv_vel.nii.gz'

d = load_nii(d_path)
d_inv = load_nii(d_inv_path)

print(np.sum(d+d_inv))

In [20]:
mha_path = '/users/yuqiz/downloads/test_0310/0418/demo_inv_0425/OAS30671_71to68_0_vel_25.mha'
mha_inv_path = '/users/yuqiz/downloads/test_0310/0418/demo_inv_0425/OAS30671_71to68_0_inv_vel_25.mha'

mha2nii(mha_path, mha_path[: -4] + '.nii.gz')
mha2nii(mha_inv_path, mha_inv_path[: -4] + '.nii.gz')

In [22]:
# similarity metrics
from skimage.metrics import structural_similarity as SSIM
from skimage.metrics import peak_signal_noise_ratio as PSNR

def mae_cal(im1, im2):
    # mean absolute error
    return np.mean(np.abs(im1 - im2))


def ssim_cal(im1, im2):
    # structural similarity index
    return SSIM(im1, im2)


def ncc_cal(im1, im2):
    # normalized cross-correlation
    mean1 = np.mean(im1)
    mean2 = np.mean(im2)
    numerator = np.sum((im1 - mean1) * (im2 - mean2))
    denominator = np.sqrt(np.sum((im1 - mean1) ** 2) * np.sum((im2 - mean2) ** 2))

    return numerator / denominator


def psnr_cal(im1, im2):
    # peak signal-to-noise ratio
    if PSNR(im1, im2) == float('inf'):
        psnr = 100.0
    else:
        psnr = PSNR(im1, im2)

    return psnr


def nfn_cal(im1, im2):
    # normalized Frobenius norm
    return np.sqrt(np.mean((im1 - im2) ** 2))


def DSC_loss(seg1, seg2):
    smooth = 1e-5
    m1 = seg1.flatten()
    m2 = seg2.flatten()

    return (2. * (m1 * m2).sum() + smooth) / (m1.sum() + m2.sum() + smooth)


def label_dice_cal(seg1, seg2):
    # Dice score
    label_list = np.unique(seg1)
    label_list = label_list[label_list != 0]
    dice_list = []
    for label in label_list:
        dice = DSC_loss(seg1 == label, seg2 == label)
        dice_list.append(dice)

    return np.mean(dice_list)


In [26]:
# calculate similarity metrics:
i = ['OAS30671', 'OAS30671_d1122', 71, 'OAS30671_d0267', 68]
img1 = f'/users/yuqiz/downloads/test_0310/0319/all_nii/{i[3]}.nii.gz'
seg1 = f'/users/yuqiz/downloads/test_0310/0319/segmentations/subjects/{i[3]}_seg.nii.gz'
# pre_OAS30671_71to68_0_inv_25_seg.nii
img2 = f'/users/yuqiz/downloads/test_0310/0418/demo_inv_0425/pre_{i[0]}_{i[2]}to{i[-1]}_0.nii.gz'
seg2 = f'/users/yuqiz/downloads/test_0310/0418/demo_inv_0425/pre_{i[0]}_{i[2]}to{i[-1]}_0_seg.nii.gz'
img1 = load_nii(img1)
img2 = load_nii(img2)
seg1 = load_nii(seg1)
seg2 = load_nii(seg2)

mae = mae_cal(img1, img2)
ssim = ssim_cal(img1, img2)
psnr = psnr_cal(img1, img2)
ncc = ncc_cal(img1, img2)
nfn = nfn_cal(img1, img2)
dsc = label_dice_cal(seg1, seg2)

print(mae, ssim, psnr, ncc, nfn, dsc)

0.016515874258181473 0.9615557624659089 28.617788044300543 0.9903289388614493 0.03707751316565434 0.9321995409511787


### 2 Cohort-level analysis

In [None]:
def save_list(list_to_save, save_name, save_path):
    file = open(f'{save_path}{save_name}.txt', 'w')
    for s in list_to_save:
        file.write(str(s))
        file.write('\n')
    file.close()


def read_list(list_name, list_path):
    file = open(f'{list_path}{list_name}.txt', 'r')
    X = file.readlines()
    for i in range(len(X)):
        X[i] = X[i].strip()
        X[i] = X[i].strip("[]")
        X[i] = X[i].split(",")
        X[i] = [X[i][j].strip(" '") for j in range(len(X[i]))]
    file.close()

    return X


def extract_voxels_in_seg(seg, labels, sumup=True):
    voxels = []
    for label in labels:
        voxels.append(np.count_nonzero(seg == label))
    if sumup:
        voxels = sum(voxels)
    return voxels

# cohorts: HC/AD × train, template, test subject(ground truth), test subject(prediction)
# train: npz ['seg']
# template, ground truth, prediction: {}_seg.nii.gz

def calc_vox_subject(subject_list, structure_label, disease='HC', option='GT'):

    ages = []
    voxels = []
    cohorts = [f'Test/{disease}/{option}'] * len(subject_list)

    for i in subject_list:
        if option == 'GT':
            seg = f'/users/yuqiz/downloads/test_0310/0319/segmentations/subjects/{i[3]}_seg.nii.gz'
        elif option == 'pre':
            seg = f'/users/yuqiz/downloads/test_0310/0319/segmentations/pre_{disease}/pre_{i[0]}_{i[2]}to{i[-1]}_seg.nii.gz'
        seg = load_nii(seg)
        voxels.append(extract_voxels_in_seg(seg, structure_label))
        ages.append(i[-1])

    return ages, voxels, cohorts


def calc_vox_template(age_range, structure_label, disease='HC', disease_condn=0):
    
    ages = []
    voxels = []
    cohorts = [f'Template/{disease}'] * (age_range[1] - age_range[0] + 1)

    for i in range(age_range[0], age_range[1] + 1):
        seg = f'/users/yuqiz/downloads/test_0310/corrected_nii/age_{i}disease_{disease_condn}_SynthSeg.nii.gz'
        seg = load_nii(seg)
        voxels.append(extract_voxels_in_seg(seg, structure_label))
        ages.append(str(i))

    return ages, voxels, cohorts


def plot_given_structure(ages, voxels, cohorts, columns=['age', 'Number of Voxels', 'Cohorts'], save_path_name=None):
    data = {'age': ages, 'Number of Voxels': voxels, 'Cohorts': cohorts}
    df = pd.DataFrame(data, columns=columns)

    sns.lineplot(data=df, x='age', y='Number of Voxels', hue='Cohorts', style='Cohorts', markers=True)
    plt.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
    if save_path_name:
        plt.savefig(save_path_name, bbox_inches='tight', dpi=1200)
    plt.show()


def draw_volumetric_trends(structure_label, subject_list_HC, subject_list_AD, age_range=[60, 80], save_path_name=None):

    # Template/HC
    ages, voxels, cohorts = calc_vox_template(age_range, structure_label, disease='HC', disease_condn=0)
    # Template/AD
    ages1, voxels1, cohorts1 = calc_vox_template(age_range, structure_label, disease='AD', disease_condn=1)

    # Test/HC/GT
    ages2, voxels2, cohorts2 = calc_vox_subject(subject_list_HC, structure_label, disease='HC', option='GT')
    # Test/AD/GT
    ages3, voxels3, cohorts3 = calc_vox_subject(subject_list_AD, structure_label, disease='AD', option='GT')
    # Test/HC/pre
    ages4, voxels4, cohorts4 = calc_vox_subject(subject_list_HC, structure_label, disease='HC', option='pre')
    # Test/AD/pre
    ages5, voxels5, cohorts5 = calc_vox_subject(subject_list_AD, structure_label, disease='AD', option='pre')

    ages.extend(ages1)
    ages.extend(ages2)
    ages.extend(ages3)
    ages.extend(ages4)
    ages.extend(ages5)

    voxels.extend(voxels1)
    voxels.extend(voxels2)
    voxels.extend(voxels3)
    voxels.extend(voxels4)
    voxels.extend(voxels5)

    cohorts.extend(cohorts1)
    cohorts.extend(cohorts2)
    cohorts.extend(cohorts3)
    cohorts.extend(cohorts4)
    cohorts.extend(cohorts5)

    plot_given_structure(ages, voxels, cohorts, columns=['age', 'Number of Voxels', 'Cohorts'], save_path_name=save_path_name)

In [None]:
save_path = '/users/yuqiz/downloads/test_0310/0319/'
subject_list_HC = read_list('test_subject_list_HC', save_path)
subject_list_AD = read_list('test_subject_list_AD', save_path)
structure_label = [4, 14, 15, 43]
age_range=[60, 80]
save_path_name = '/users/yuqiz/downloads/test_0310/0319/ventricle_volumetric_trends.jpg'
draw_volumetric_trends(structure_label, subject_list_HC, subject_list_AD, age_range, save_path_name=save_path_name)