In [2]:
# 3D general patch from T2

# background info
# For T2, we used 1. mean \in (0.2*max(image), 0.6*max(img)) 2. diff 

# guideline
# - use the axial axis for the 3d volume
# - check the first and last slice from the volume
# - apply mean and diff 
# - use AND for the two slices
# - zoomed

# questions:
# - do we save enough patch using the strategy above?  Wait for the current training
# - do we need to filter all slice within the volume?  No
# - do the criteria for T2 apply for the T1 cases? Yes

# func: 
# input: numpy array after pre-processing 
# output: hdf5 files

# procedures:
# - select one case (101_Id_013) => with motion, 
# - save the slice location for each patch (axial axis)
# - calculate the max for each slice(image-based)
# - filter the first and the last patch
# - apply AND
# - saved as hdf5

In [3]:
import pydicom
import os
import sys
sys.path.append("/home/subtle/Long/SubtleMR/src/utils")
from co_registration import extract_info, register_im
from others import pad_array
import numpy as np
import matplotlib.pyplot as plt
from skimage.util.shape import view_as_windows
import h5py

In [10]:
# extract patch
def filter_3d(norm_low, norm_high, patch_size, step, thresh_low_ratio, thresh_high_ratio,
              thresh_diff_ratio, channel, verbose):
    """
    generate a set of 3d patch that could be used for training
    
    args:
    -------
    norm_low: the original volume for the lowres (after pre-processing)
    norm_high: the original volume for the highres (after pre-processing)
    patch_size: the patch size for the 3d patch volume (int)
    step: steps (int)
    thresh_low_ratio: a ratio used to filter out the background area
    thresh_diff_ratio: a ratio used to filter out the large intensity difference
    verbose: verbose mode
    
    return:
    -------
    saved_patch_ins: 3d low-res patches for training
    saved_patch_outs: 3d high-res patches for training
    """
    # calculate norm_max
    norm_slice_max = np.max(norm_low, axis=tuple(x for x in (0, 1, 2) if x != channel))
    
    # set up the patch
    norm_patch_ins = view_as_windows(norm_low, window_shape=patch_size, step=step)
    norm_patch_outs = view_as_windows(norm_high, window_shape=patch_size, step=step)
    nrow, ncol, nchan = norm_patch_ins.shape[:3]

    saved_patch_ins, saved_patch_outs = [], []
    # go thru each patch volume
    for i in range(nrow):
        for j in range(ncol):
            for k in range(nchan):
                # set up the taxis 
                taxis = [i, j, k]
                
                # select two index: subidx1 and subidx2
                subidx1, subidx2 = 0, patch_size[channel] - 1
                
                # three cases based on the channel selection
                sub_patch_ins1, sub_patch_outs1 = divid_channel(norm_patch_ins, 
                                                                norm_patch_outs, channel, 
                                                                i, j, k, subidx1)
                sub_patch_ins2, sub_patch_outs2 = divid_channel(norm_patch_ins, 
                                                                norm_patch_outs, channel, 
                                                                i, j, k, subidx2)
                
                # check the two slices
                pass1, val11, lcriteria11, hcriteria11, val12, criteria12 = filter_2d(sub_patch_ins1, sub_patch_outs1, 
                                  norm_slice_max[step[channel]*taxis[channel]+subidx1], 
                                  thresh_low_ratio, thresh_high_ratio, thresh_diff_ratio, verbose)
                pass2, val21, lcriteria21, hcriteria21, val22, criteria22 = filter_2d(sub_patch_ins2, sub_patch_outs2, 
                                  norm_slice_max[step[channel]*taxis[channel]+subidx2], 
                                  thresh_low_ratio, thresh_high_ratio, thresh_diff_ratio, verbose)
                            
                if pass1 and pass2:
                    saved_patch_ins.append(norm_patch_ins[i, j, k])
                    saved_patch_outs.append(norm_patch_outs[i, j, k])
                #else:
                    #print (val11, lcriteria11, hcriteria11, val12, criteria12)
                    #print (val21, lcriteria21, hcriteria21, val22, criteria22)
                    #disp_3d(norm_patch_ins[i, j, k], norm_patch_outs[i, j, k])
                    
    return np.array(saved_patch_ins), np.array(saved_patch_outs)


def divid_channel(norm_patch_ins, norm_patch_outs, channel, i, j, k, subidx1):
    if channel == 0:
        sub_patch_ins = norm_patch_ins[i,j,k,subidx1]
        sub_patch_outs = norm_patch_outs[i,j,k,subidx1]
    elif channel == 1:
        sub_patch_ins = norm_patch_ins[i,j,k,subidx1]
        sub_patch_outs = norm_patch_outs[i,j,k,subidx1]
    elif channel == 2:
        sub_patch_ins = norm_patch_ins[i,j,k,subidx1]
        sub_patch_outs = norm_patch_outs[i,j,k,subidx1]
    else :
        raise NotImplementedError('Channel [{:s}] not recognized.'.format(str(channel)))
    return sub_patch_ins, sub_patch_outs 


def disp_3d(volume_ins, volume_outs):
    """
    a helper function to display the all 3 central slice in the 3 axis
    * the lowres image
    * the highres image
    * the absolute difference between the lowres image and the highres image
    
    ## an assumption here is that the volume is a cubic (n*n*n)
    args:
    -----
    volume_ins: a 3d numpy array (low-res)
    volume_outs: a 3d numpy array (high-res)
    
    return:
    -----
    None
    """
    nrow, ncol, nchan = volume_ins.shape
    # an assumption for using stack
    assert nrow == ncol and ncol == nchan
    assert volume_ins.shape == volume_outs.shape
    
    # set up the rows
    lowres_val = np.hstack((volume_ins[nrow//2], volume_ins[:,ncol//2,:], volume_ins[...,nchan//2]))
    highres_val = np.hstack((volume_outs[nrow//2], volume_outs[:,ncol//2,:], volume_outs[...,nchan//2]))
    diff = np.abs(volume_ins - volume_outs)
    diff_val = np.hstack((diff[nrow//2], diff[:,ncol//2,:], diff[...,nchan//2]))
    
    # display
    #plt.imshow(np.vstack((lowres_val, highres_val, diff_val)), clim=[0,5], cmap='gray')
    #plt.show()
    #plt.close()
    
    
def filter_2d(patch_ins, patch_outs, pixel_max, thresh_low_ratio, thresh_high_ratio, thresh_diff_ratio, verbose):
    """
    used to extract the patches that could be used in the training sets.
    (tested on the 2D T2 axial => thresh_low_ratio=0.3, thresh_diff_ratio=0.05)
    
    args:
    --------
    patch_ins : the numpy array for the target low-res patch (2d)
    patch_outs: the numpy array for the target high-res patch (2d)
    pixel_max: the max value for the current image(2d)
    thresh_low_ratio: a ratio used to filter out the background area
    thresh_diff_ratio: a ratio used to filter out the area with intensity difference
    verbose: verbose mode
    
    return:
    -------
    boolean var: whether the current patch should be filtered out or kept
    """
    patch_val = np.mean(patch_ins)
    patch_diff = np.mean(np.abs(patch_ins - patch_outs))
    if patch_val < thresh_high_ratio * pixel_max and patch_val > thresh_low_ratio * pixel_max and patch_diff < thresh_diff_ratio * pixel_max:
        return True, patch_val, thresh_low_ratio * pixel_max, thresh_high_ratio * pixel_max, patch_diff, thresh_diff_ratio * pixel_max
    else:
        return False, patch_val, thresh_low_ratio * pixel_max, thresh_high_ratio * pixel_max, patch_diff, thresh_diff_ratio * pixel_max

In [16]:
shared_path = '/home/subtle/Data/Long/data/hoag_T1_zoomed/S1'
saved_path = "/home/subtle/Long/"
cnt = 0
for element in os.listdir(shared_path):
    file = h5py.File(os.path.join(shared_path, 'cor_T_Id0017.h5'), 'r')
    norm_low = file["input"][:][...,0]
    norm_high = file["output"][:][...,0]
    pixel_mean = file["mean"][:][0]
    print (norm_low.shape, np.mean(norm_low), pixel_mean)
    saved_patch_ins, saved_patch_outs = filter_3d(norm_low, norm_high, (32,32,32), (8,16,16), 0.3, 0.6, 0.05, 1, False)
    cnt += saved_patch_ins.shape[0]
    with h5py.File(os.path.join(saved_path, 'p32_3d_S1_lr_hr' + element), 'w') as hf:
        hf.create_dataset("input", data=saved_patch_ins[..., np.newaxis], compression="gzip")
        hf.create_dataset("output", data=saved_patch_outs[...,np.newaxis], compression="gzip")
        hf.create_dataset("nslice", data=[saved_patch_ins.shape[0]], compression='gzip')
        hf.create_dataset("mean", data=[pixel_mean], compression='gzip')
    
print ("cnt=", cnt)

(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
(224, 512, 512) 1.002882 115.22112
cnt= 57200
