In [1]:
'''
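Calculate per-slice SSIM and RMSE of the Mayo denoising results against the full-dose FBP reference
'''

'\nCalculate per-slice SSIM and RMSE of the Mayo denoising results against the full-dose FBP reference\n'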

In [2]:
import skimage.metrics
import glob
import os
import SimpleITK as sitk
import pandas as pd
import pathlib
import numpy as np

In [3]:
def calc_windowed_ssim(ref, img, vmin=-160, vmax=240):
    """Per-slice SSIM between two volumes after applying an intensity window.

    Both volumes are mapped linearly so that [vmin, vmax] becomes [0, 1] and
    clipped to that interval. SSIM is then computed independently for each pair
    of slices taken along the first axis.

    Parameters
    ----------
    ref, img : array-like
        Reference and test volumes, iterated slice-by-slice along axis 0.
    vmin, vmax : float
        Lower/upper bound of the window, in the input's intensity units.

    Returns
    -------
    numpy.ndarray
        One SSIM value per slice pair.
    """
    def _window(volume):
        # Rescale to [0, 1] inside the window, saturate outside it.
        return np.clip((np.asarray(volume).astype(np.float32) - vmin) / (vmax - vmin), 0, 1)

    ref = _window(ref)
    img = _window(img)

    # The windowed images live in [0, 1], so the dynamic range is exactly 1.
    # Without an explicit data_range, skimage either infers 2 from the float
    # dtype range (older releases, giving wrong stabilising constants) or
    # refuses floating-point input altogether (newer releases).
    return np.array([
        skimage.metrics.structural_similarity(r, m, data_range=1.0)
        for r, m in zip(ref, img)
    ])

def calc_rmse(ref, img):
    """Root-mean-square error between `ref` and `img`, computed per slice.

    Slices are taken along the first axis; the mean is over all remaining
    elements of each slice. Differences are computed in float64 so integer
    inputs (e.g. int16 volumes) cannot overflow when squared.

    Parameters
    ----------
    ref, img : array-like
        Reference and test volumes, iterated slice-by-slice along axis 0.

    Returns
    -------
    numpy.ndarray
        One RMSE value per slice pair.
    """
    rmses = []
    for r, m in zip(ref, img):
        diff = np.asarray(r, dtype=np.float64) - np.asarray(m, dtype=np.float64)
        rmses.append(np.sqrt(np.mean(diff ** 2)))

    return np.array(rmses)

In [4]:
def calc_std_difference(ref, img, patch_size=(32, 32), hu_range=(-160, 240), max_ref_std=25):
    """Per-patch standard deviations of `ref` and `img` over flat regions of `ref`.

    Each slice (axis 0) of both volumes is tiled with a regular grid of
    patch_size patches. Grid origins that would overrun the border are clamped,
    so the last row/column of patches ends flush with the image edge and may
    overlap its neighbour. A patch is kept only if its reference mean lies
    strictly inside hu_range and its reference std is below max_ref_std.

    Parameters
    ----------
    ref, img : numpy.ndarray
        Volumes of identical shape, indexed as [slice, row, col].
    patch_size : (int, int)
        Patch height and width.
    hu_range : (float, float)
        Exclusive bounds on the reference patch mean.
    max_ref_std : float
        Exclusive upper bound on the reference patch std.

    Returns
    -------
    stds_ref, stds_img : numpy.ndarray
        Std of every kept patch in ref and img (same order).
    ref_patches : numpy.ndarray
        The kept reference patches, shape (n_kept, patch_h, patch_w).

    Raises
    ------
    ValueError
        If ref and img differ in shape, or an image is smaller than one patch.
    """
    ref = np.asarray(ref)
    img = np.asarray(img)
    ph, pw = patch_size

    if ref.shape != img.shape:
        raise ValueError('ref and img must have the same shape, got %s and %s' % (ref.shape, img.shape))
    ny, nx = ref.shape[1], ref.shape[2]
    if ny < ph or nx < pw:
        # Clamped origins would become negative and produce misshapen patches.
        raise ValueError('image size %s is smaller than patch size %s' % ((ny, nx), (ph, pw)))

    iys = np.minimum(np.arange(0, ny, ph), ny - ph)
    ixs = np.minimum(np.arange(0, nx, pw), nx - pw)

    # Patch-major ordering: for each (iy, ix) all slices, concatenated along axis 0.
    ref_patches = np.concatenate([ref[:, iy:iy + ph, ix:ix + pw] for iy in iys for ix in ixs])
    img_patches = np.concatenate([img[:, iy:iy + ph, ix:ix + pw] for iy in iys for ix in ixs])

    stds_ref = np.std(ref_patches, axis=(1, 2))
    mean_ref = np.mean(ref_patches, axis=(1, 2))
    stds_img = np.std(img_patches, axis=(1, 2))

    keep = (mean_ref > hu_range[0]) & (mean_ref < hu_range[1]) & (stds_ref < max_ref_std)

    return stds_ref[keep], stds_img[keep], ref_patches[keep]

In [5]:
def get_attrib_from_name(filename):
    """Derive the case name and a dose/method tag from a result file path.

    The path is expected to look like <...>/<dose_dir>/<method_dir>/<stem>.<ext>.
    The case name is the stem's text before the first underscore. The dose
    label is '4x (Mayo)' for a 'quarter' dose dir, otherwise '<last _-token>x'.
    The method dir is mapped to a display label; for 'l2' and 'wgan' the label
    is '(universal)' when 'all' appears in the stem, '(matched)' otherwise.

    Returns a dict with keys 'Image', 'Tag' ('<dose>/<method>') and 'filename'
    (the path as a pathlib.Path).

    Raises ValueError for an unrecognised method directory.
    """
    path = pathlib.Path(filename)

    stem = path.stem
    method_dir = path.parent.name
    dose_dir = path.parents[1].name

    if dose_dir == 'quarter':
        dose_label = '4x (Mayo)'
    else:
        dose_label = '%sx' % dose_dir.split('_')[-1]

    universal = 'all' in stem
    method_labels = {
        'ensemble': 'Ensemble',
        'fbp': 'FBP',
        'l2': 'L2 (universal)' if universal else 'L2 (matched)',
        'wgan': 'WGAN (universal)' if universal else 'WGAN (matched)',
    }
    if method_dir not in method_labels:
        raise ValueError('Unexpected method: %s' % method_dir)

    return {
        'Image': stem.split('_')[0],
        'Tag': '%s/%s' % (dose_label, method_labels[method_dir]),
        'filename': path,
    }

In [6]:
# Configuration: root of the test outputs (override with MAYO_RESULT_DIR), the cases
# to evaluate, and the number of pixels cropped from each in-plane border.
result_dir = os.environ.get('MAYO_RESULT_DIR',
                            '/home/dwu/trainData/deep_denoiser_ensemble/test/mayo_2d_3_layer_mean/')
names = ['L291', 'L143', 'L067']
margin = 96


def _load_cropped(path, margin):
    """Read an image with SimpleITK and crop `margin` pixels from both ends of axes 1 and 2.

    Explicit stop indices are used so that margin == 0 keeps the whole image;
    the shorthand ``[margin:-margin]`` would return an empty array in that case.
    """
    volume = sitk.GetArrayFromImage(sitk.ReadImage(path))
    return volume[:, margin:volume.shape[1] - margin, margin:volume.shape[2] - margin]


manifest = []
for name in names:
    print(name, flush=True)

    # Results are laid out as <result_dir>/<dose>/<method>/<name>*.nii. The 'full'
    # dose directory holds the reference, so it is excluded. Sorting makes the CSV
    # column order reproducible (glob order is arbitrary).
    result_list = sorted(
        f for f in glob.glob(os.path.join(result_dir, '*', '*', '%s*.nii' % name))
        if pathlib.Path(f).parents[1].name != 'full'
    )
    if not result_list:
        raise FileNotFoundError('No results found for %s under %s' % (name, result_dir))

    # The full-dose FBP image is the reference for every metric.
    ref = _load_cropped(os.path.join(result_dir, 'full', 'fbp', '%s.nii' % name), margin)
    print(len(result_list), end=': ', flush=True)

    metric_manifests = []
    seen_tags = set()
    for k, result_name in enumerate(result_list):
        print(k, end=',', flush=True)

        attrib = get_attrib_from_name(result_name)
        tag = attrib['Tag']
        # Two files with the same tag would make pd.merge silently suffix the columns.
        if tag in seen_tags:
            raise ValueError('Duplicate tag %s for %s' % (tag, result_name))
        seen_tags.add(tag)

        img = _load_cropped(result_name, margin)
        if img.shape != ref.shape:
            raise ValueError('Shape mismatch for %s: %s vs reference %s'
                             % (result_name, img.shape, ref.shape))

        metric_manifests.append(pd.DataFrame({
            'Image': [attrib['Image']] * len(img),
            'Slice': range(len(img)),
            tag + '/RMSE': calc_rmse(ref, img),
            tag + '/SSIM': calc_windowed_ssim(ref, img),
        }))

    print('')

    # Join all methods side by side on (Image, Slice). Shapes were checked above,
    # so every frame has the same keys and the merge must be one-to-one.
    df = metric_manifests[0]
    for m in metric_manifests[1:]:
        df = pd.merge(df, m, on=['Image', 'Slice'], validate='one_to_one')
    manifest.append(df)

manifest = pd.concat(manifest, ignore_index=True)
manifest.to_csv('./mayo2d_ssim_rmse.csv', index=False)

L291
30: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,
L143
30: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,
L067
30: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,
