In [240]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import sigpy

plt.set_cmap('gray')
plt.rcParams['figure.figsize'] = (6, 4)

import subtle.subtle_io as suio

from keras.layers import Input, Conv3D, Activation
import keras.models
from keras.optimizers import Adam


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def show_img(img, title='', axis=False, vmin=None, vmax=None):
    imshow_args = {}
    
    if vmin:
        imshow_args['vmin'] = vmin
    if vmax:
        imshow_args['vmax'] = vmax
    
    im_axis = 'on' if axis else 'off'
    plt.axis(im_axis)
    plt.imshow(img, **imshow_args)
    plt.title(title, fontsize=15)
    plt.show()

    
def get_patches(img_vol, block_size=64):
    patches = sigpy.block.array_to_blocks(img_vol, [block_size]*3, [block_size]*3)
    n_patches = patches.shape[0] * patches.shape[1] * patches.shape[2]
    patches = np.reshape(patches, (n_patches, block_size, block_size, block_size))
    return patches

def filter_patches(patches, pixel_percent=0.1, block_size=64):
    filt_patches = []
    get_nz = lambda plist: np.sum([len(np.nonzero(sl)[0]) for sl in plist])
    sel_idx = []
    
    for i, patch in enumerate(patches):
        percent = get_nz(patch.reshape((block_size, block_size**2))) / (block_size ** 3)
        if percent >= pixel_percent:
            filt_patches.append(patch)
            sel_idx.append(i)
    return np.array(filt_patches), np.array(sel_idx)

def conv_block(input, filters, strides=2):
    d_out = Conv3D(filters=filters, kernel_size=3, strides=strides, padding='same')(input)
    d_out = ReLU()(d_out)
    return d_out

<Figure size 432x288 with 0 Axes>

In [161]:
fpath_h5 = '/home/srivathsa/projects/studies/gad/stanford/preprocess/data/Patient_0101.h5'
study_data = suio.load_h5_file(fpath_h5)

In [241]:
ims_zero, ims_low, ims_full = study_data.transpose(1, 0, 2, 3)

zero_patches, good_idx = filter_patches(get_patches(ims_zero))
low_patches = get_patches(ims_low)[good_idx]

In [242]:
input = Input(shape=(64, 64, 64, 2))

d_out = conv_block(input, 32)
d_out = conv_block(d_out, 64)
d_out = conv_block(d_out, 128)

output = conv_block(d_out, 1, strides=1)
output = Activation('sigmoid')(output)

model = keras.models.Model(inputs=input, outputs=output)
model.summary()

model.compile(loss='mse', optimizer=Adam())

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         (None, 64, 64, 64, 2)     0         
_________________________________________________________________
conv3d_22 (Conv3D)           (None, 32, 32, 32, 32)    1760      
_________________________________________________________________
re_lu_22 (ReLU)              (None, 32, 32, 32, 32)    0         
_________________________________________________________________
conv3d_23 (Conv3D)           (None, 16, 16, 16, 64)    55360     
_________________________________________________________________
re_lu_23 (ReLU)              (None, 16, 16, 16, 64)    0         
_________________________________________________________________
conv3d_24 (Conv3D)           (None, 8, 8, 8, 128)      221312    
_________________________________________________________________
re_lu_24 (ReLU)              (None, 8, 8, 8, 128)      0         
__________

In [243]:
inp_patches = np.array([zero_patches, low_patches]).transpose((1, 2, 3, 4, 0))
out_patches = np.ones((inp_patches.shape[0], 8, 8, 8, 1))

In [244]:
model.fit(inp_patches, out_patches, batch_size=16, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7fe39f415240>