In [1]:
'''
Calculate the texture difference of the mayo results (currently using only patch-std)
'''

'\nCalculate the texture difference of the mayo results (currently using only patch-std)\n'

In [2]:
import glob
import os
import SimpleITK as sitk
import pandas as pd
import pathlib
import numpy as np

In [3]:
def calc_std_difference(ref, img, patch_size = (32, 32), hu_range = (0, 240), max_ref_std = 20):
    '''
    Compare the per-patch standard deviation of an image against a reference.

    Both volumes are cut, for every index along axis 0, into 2D patches laid out on a
    regular grid over axes 1 and 2. When the patch size does not divide the image size,
    the last patch along that axis is shifted back so that it ends exactly at the image
    border (it then overlaps its neighbour). Only patches whose reference mean lies
    strictly inside hu_range and whose reference std is strictly below max_ref_std are kept.

    Args:
        ref: 3D reference array; axes 1 and 2 are tiled into patches.
        img: array with exactly the same shape as ref.
        patch_size: (size along axis 1, size along axis 2) of a patch.
        hu_range: (low, high) exclusive bounds on the mean of a reference patch
            (the name suggests Hounsfield units -- confirm against the data).
        max_ref_std: exclusive upper bound on the std of a reference patch.

    Returns:
        A tuple (std_diff, patches):
            std_diff: 1D array, std(img patch) - std(ref patch) for every kept patch,
                ordered by ascending reference std.
            patches: the kept reference patches in the same order,
                shape (n_kept, patch_size[0], patch_size[1]).

    Raises:
        ValueError: if ref is not 3D, if img and ref differ in shape, or if the patch
            does not fit inside the image.
    '''
    ref = np.asarray(ref)
    img = np.asarray(img)
    if ref.ndim != 3:
        raise ValueError('ref must be a 3D array, got shape %s'%(ref.shape,))
    if img.shape != ref.shape:
        # a mismatch would otherwise pair patches from different locations or slices
        raise ValueError('img shape %s does not match ref shape %s'%(img.shape, ref.shape))
    
    patch_h, patch_w = patch_size[0], patch_size[1]
    if patch_h < 1 or patch_w < 1 or patch_h > ref.shape[1] or patch_w > ref.shape[2]:
        # an oversized patch would give a negative start index, which silently wraps around
        raise ValueError('patch_size %s does not fit into an image of size %s'%(
            (patch_h, patch_w), ref.shape[1:]))
    
    # patch start positions; the last one is clamped so that the patch stays inside the image
    iys = np.minimum(np.arange(0, ref.shape[1], patch_h), ref.shape[1] - patch_h)
    ixs = np.minimum(np.arange(0, ref.shape[2], patch_w), ref.shape[2] - patch_w)
    
    # stack the patches of all grid positions along axis 0 (same order for ref and img)
    ref_patches = np.concatenate([ref[:, iy:iy+patch_h, ix:ix+patch_w] for iy in iys for ix in ixs])
    img_patches = np.concatenate([img[:, iy:iy+patch_h, ix:ix+patch_w] for iy in iys for ix in ixs])
    
    stds_ref = np.std(ref_patches, axis = (1,2))
    mean_ref = np.mean(ref_patches, axis = (1,2))
    stds_img = np.std(img_patches, axis = (1,2))
    
    # the selection depends on the reference only, so it is identical for every img compared to the same ref
    keep = (mean_ref > hu_range[0]) & (mean_ref < hu_range[1]) & (stds_ref < max_ref_std)
    stds_img = stds_img[keep]
    stds_ref = stds_ref[keep]
    ref_patches = ref_patches[keep]
    
    order = np.argsort(stds_ref)
    
    return stds_img[order] - stds_ref[order], ref_patches[order]

In [7]:
def get_attrib_from_name(filename):
    '''
    Derive the image name and a "dose/method" tag from the path of a result file.

    The file stem up to the first underscore is the image name, the name of the
    containing directory is the method and the name of the directory above that is
    the dose rate. For the 'l2' and 'wgan' methods the tag additionally distinguishes
    whether the stem contains the substring 'all' (universal) or not (matched).

    Args:
        filename: path of the result file (str or path-like).

    Returns:
        dict with keys 'Image' (image name), 'Tag' ('<dose>/<method>') and
        'filename' (the path as a pathlib.Path).

    Raises:
        ValueError: if the method directory name is not one of the known methods.
    '''
    filename = pathlib.Path(filename)
    
    stem = filename.stem
    method_dir = filename.parents[0].name
    dose_dir = filename.parents[1].name
    
    # methods whose label does not depend on the file name
    plain_labels = {'ensemble': 'Ensemble', 
                    'average': 'Average', 
                    'fbp': 'FBP'}
    # methods with a (matched, universal) pair of labels, selected by 'all' in the stem
    paired_labels = {'l2': ('L2 (matched)', 'L2 (universal)'), 
                     'wgan': ('WGAN (matched)', 'WGAN (universal)')}
    
    if method_dir in plain_labels:
        label = plain_labels[method_dir]
    elif method_dir in paired_labels:
        label = paired_labels[method_dir][1 if 'all' in stem else 0]
    else:
        raise ValueError('Unexpected method: %s'%method_dir)
    
    # 'quarter' is a special directory name; otherwise the last '_'-separated token is the dose factor
    dose = '4x (Mayo)' if dose_dir == 'quarter' else '%sx'%(dose_dir.split('_')[-1])
    
    attrib = {}
    attrib['Image'] = stem.split('_')[0]
    attrib['Tag'] = '%s/%s'%(dose, label)
    attrib['filename'] = filename
    return attrib

In [9]:
# Build one table per image with the patch-std difference of every result against the
# full-dose FBP reference, then write all images to a single csv.
result_dir = '/home/local/PARTNERS/dw640/mnt/women_health_internal/dufan.wu/deep_denoiser_ensemble/test/mayo_2d_3_layer_mean'
names = ['L291', 'L143', 'L067']
margin = 96       # pixels cropped from each border of axes 1 and 2 before patching
n_patches = 100   # number of patches (lowest reference std first) kept per result


def _load_cropped(path, margin):
    '''Read an image with SimpleITK and crop margin pixels from each border of axes 1 and 2.'''
    arr = sitk.GetArrayFromImage(sitk.ReadImage(path))
    # explicit end indices: arr[:, margin:-margin] would be empty for margin = 0
    return arr[:, margin:arr.shape[1] - margin, margin:arr.shape[2] - margin]


manifest = []
for name in names:
    print (name, flush=True)
    # glob returns files in file-system order; sort so that the csv column order is reproducible
    result_list = sorted(f for f in glob.glob(os.path.join(result_dir, '*/*', '%s*.nii'%name)) if 'full' not in f)
    if len(result_list) == 0:
        raise FileNotFoundError('No result images found for %s under %s'%(name, result_dir))
    
    # load the reference image (format only the file name, not the whole joined path)
    ref = _load_cropped(os.path.join(result_dir, 'full', 'fbp', '%s.nii'%name), margin)
    print (len(result_list), end=': ', flush=True)
    metric_manifests = []
    for k, result_name in enumerate(result_list):
        print (k, end = ',', flush=True)
        
        attrib = get_attrib_from_name(result_name)
        img = _load_cropped(result_name, margin)
        
        # the returned values are ordered by ascending reference std; keep the first n_patches
        std_dist, ref_patches = calc_std_difference(ref, img)
        std_dist = std_dist[:n_patches]
        
        # one column per result, named '<dose>/<method>/std'
        df = pd.DataFrame({'Image': [attrib['Image']] * len(std_dist), 
                           'Patch': range(len(std_dist)), 
                           attrib['Tag'] + '/std': std_dist})
        metric_manifests.append(df)
        
    print ('')
    
    # merge the manifests: join the per-result columns on (Image, Patch)
    df = metric_manifests[0]
    for m in metric_manifests[1:]:
        df = pd.merge(df, m, on = ['Image', 'Patch'])
    manifest.append(df)
        
manifest = pd.concat(manifest, ignore_index = True)
manifest.to_csv('./mayo2d_texture_avg.csv', index=False)

L291
29: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,
L143
29: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,
L067
29: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,


In [6]:
# ref_patches.shape

In [7]:
# import matplotlib.pyplot as plt
# patch_img = np.zeros([320, 320])
# for i in range(100):
#     ix = i % 10
#     iy = i // 10
    
#     patch_img[iy*32:(iy+1)*32, ix*32:(ix+1)*32] = ref_patches[i] - np.mean(ref_patches[i])
# plt.imshow(patch_img, 'gray', vmin=-160, vmax=240)