In [1]:
import glob
import os
import sys

import numpy as np
import scipy
import scipy.ndimage
import scipy.signal
import tensorflow as tf

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [6]:
# Project-local network definition; it has to be importable from the current
# working directory (or an existing sys.path entry).
import UNet

# The two modules below are presumably located in ../Preprocess -- the path is
# relative to the working directory, so the notebook must be started from its
# own folder. NOTE(review): re-running this cell appends the same entries to
# sys.path again (harmless, but the list keeps growing).
sys.path.append('../Preprocess')
import CTPPreprocess as preprocess
import CalcParaMaps as paramaps

# `helper` is presumably two directory levels above the working directory.
sys.path.append('../../')
import helper

In [31]:
import argparse

# Command-line interface of the test script. The same parser is later extended
# with the network's own options and parsed either from a hard-coded list
# (notebook) or from sys.argv (script).
parser = argparse.ArgumentParser(description = 'ctp noise2noise network')

# input data
#   imgFile  : .npy image series that is fed through the network
#   refFile  : .npy reference series
#   paraFile : .npz archive read for the keys cbf / cbv / mtt / mask / cbfFac
#   aifFile  : .npy file, divided by 1000 after loading
parser.add_argument('--imgFile', type=str, default='/home/dwu/trainData/Noise2Noise/train/ctp/simul/data/imgs_100000.npy')
parser.add_argument('--refFile', type=str, default='/home/dwu/trainData/Noise2Noise/train/ctp/simul/data/imgs_-1.npy')
parser.add_argument('--paraFile', type=str, default='/home/dwu/trainData/Noise2Noise/train/ctp/simul/data/paras_tikh_0.3.npz')
parser.add_argument('--aifFile', type=str, default='/home/dwu/trainData/Noise2Noise/train/ctp/simul/data/aif0.npy')
# nTest > 0 keeps only the last nTest slices; any other value keeps all of them
parser.add_argument('--nTest', type=int, default=-1)

# paths
#   checkPoint : checkpoint prefix handed to the Saver's restore()
#   outFile    : .npy file the masked network output is written to
parser.add_argument('--checkPoint', type=str, default=None)
parser.add_argument('--outFile', type=str, default=None)

# execution device: used as the GPU's visible_device_list entry
parser.add_argument('--device', type=int, default=0)

# input normalisation: the network is fed  img / imgNormIn + imgOffsetIn
parser.add_argument('--imgNormIn', type=float, default=0.15)
parser.add_argument('--imgOffsetIn', type=float, default=-1)

# output normalisation: imgOffsetOut is subtracted from the network output.
# NOTE(review): imgNormOut is not read by any of the visible cells -- confirm
# whether the output should also be rescaled by it.
parser.add_argument('--imgNormOut', type=float, default=0.025)
parser.add_argument('--imgOffsetOut', type=float, default=0)

_StoreAction(option_strings=['--imgOffsetOut'], dest='imgOffsetOut', nargs=None, const=None, default=0, type=<class 'float'>, choices=None, help=None, metavar=None)

In [32]:
# A temporary UNet instance is created only so that it can add its own options
# to the shared parser (presumably the scope / imgshapeIn / nFilters / ...
# entries that show up when the parsed arguments are printed). The instance is
# replaced by a freshly configured one once the arguments have been parsed.
# NOTE(review): re-running this cell without re-creating `parser` first would
# presumably try to register the same options twice -- confirm.
tf.reset_default_graph()
net = UNet.UNet()
parser = net.AddArgsToArgParser(parser)

In [33]:
if sys.argv[0] != 'TestNetwork.py':
    from IPython import display
    import matplotlib.pyplot as plt
    %matplotlib inline
    args = parser.parse_args(['--device', '0',
                              '--imgFile', '/home/dwu/trainData/Noise2Noise/train/ctp/simul/data/imgs_200000.npy',
                              '--checkPoint', '/home/dwu/trainData/Noise2Noise/train/ctp/simul/beta_0_N0_200000/24',
                              '--nTest', '-1',
                              '--outFile', '/home/dwu/trainData/Noise2Noise/train/ctp/simul/beta_0_N0_200000/tmp/iodines_24.npy'])
else:
    args = parser.parse_args(sys.argv[1:])

for k in args.__dict__:
    print (k, args.__dict__[k], sep=': ', flush=True)

imgFile: /home/dwu/trainData/Noise2Noise/train/ctp/simul/data/imgs_200000.npy
refFile: /home/dwu/trainData/Noise2Noise/train/ctp/simul/data/imgs_-1.npy
paraFile: /home/dwu/trainData/Noise2Noise/train/ctp/simul/data/paras_tikh_0.3.npz
aifFile: /home/dwu/trainData/Noise2Noise/train/ctp/simul/data/aif0.npy
nTest: -1
checkPoint: /home/dwu/trainData/Noise2Noise/train/ctp/simul/beta_0_N0_200000/24
outFile: /home/dwu/trainData/Noise2Noise/train/ctp/simul/beta_0_N0_200000/tmp/iodines_24.npy
device: 0
imgNormIn: 0.15
imgOffsetIn: -1
imgNormOut: 0.025
imgOffsetOut: 0
scope: unet2d
imgshapeIn: [256, 256, 1]
imgshapeOut: [256, 256, 1]
nFilters: 32
filterSz: [3, 3, 3]
depth: 4
model: unet
bn: 0
beta: 0
biasKernelSz: 37
biasKernelStd: 6


In [34]:
# Rebuild the network from scratch, this time configured from the parsed args.
tf.reset_default_graph()
net = UNet.UNet()
net.FromParser(args)
# The network takes one extra input channel: TestSequence feeds every frame
# together with a second frame (frame 0 or frame 1). A new list is assigned
# instead of editing the existing one in place, so that a list possibly shared
# with `args` is left untouched and re-running this cell cannot add the channel
# twice. NOTE(review): whether FromParser aliases args.imgshapeIn is not
# visible here -- the copy is purely defensive.
net.imgshapeIn = list(net.imgshapeIn[:-1]) + [net.imgshapeIn[-1] + 1]
net.BuildModel()

loader = tf.train.Saver()

# Create the output folder if needed. dirname() is '' for a bare file name, in
# which case the file goes to the working directory and nothing must be
# created; exist_ok avoids the check-then-create race.
outDir = os.path.dirname(args.outFile)
if outDir:
    os.makedirs(outDir, exist_ok=True)

In [35]:
# Read the input series (imgFile) and the reference series (refFile);
# 1 is subtracted from every value of both arrays right after loading.
imgs = np.subtract(np.load(args.imgFile), 1)
refs = np.subtract(np.load(args.refFile), 1)

In [36]:
# load param files
with np.load(args.paraFile) as f:
    cbf0 = f['cbf']
    cbv0 = f['cbv']
    mtt0 = f['mtt']
    # trailing singleton axis so the mask broadcasts over the last image axis
    masks = f['mask'][..., np.newaxis]
    cbfFac = f['cbfFac']
aif0 = np.load(args.aifFile) / 1000

# Vessel mask: pixels inside `masks` whose maximum over the last axis exceeds
# 0.1, grown by one binary dilation per slice.
maskVessels = np.where(np.max(imgs, -1) > 0.1, 1, 0)[...,np.newaxis]
maskVessels *= masks
for i in range(maskVessels.shape[0]):
    # scipy.ndimage.binary_dilation is the supported entry point; the
    # scipy.ndimage.morphology namespace is deprecated.
    maskVessels[i,...,0] = scipy.ndimage.binary_dilation(maskVessels[i,...,0])
# remove the (dilated) vessels from the mask
masks *= (1-maskVessels)

# Zero everything outside the mask. The singleton last axis of `masks`
# broadcasts over all frames, so no tiled copy is needed and `refs` does not
# have to have the same number of frames as `imgs`.
imgs *= masks
refs *= masks

In [37]:
def TestSequence(sess, net, imgs, args, iSlices = None):
    """Run the network frame by frame over the selected slices.

    Every frame i (index on the last axis of `imgs`) is evaluated twice: once
    stacked with frame 0 and once stacked with frame 1 as the second input
    channel. The two outputs are averaged and `args.imgOffsetOut` is
    subtracted.

    Parameters
    ----------
    sess : object providing ``run(fetch, feed_dict)`` (a TensorFlow session).
    net : object exposing the output tensor ``recon`` and the input ``x``.
    imgs : array indexed by slice on the first axis and by frame on the last
        axis (assumed layout: slice, row, column, frame -- TODO confirm).
    args : namespace with ``imgNormIn``, ``imgOffsetIn`` and ``imgOffsetOut``.
    iSlices : ``None`` picks one random slice, the scalar ``-1`` selects all
        slices, anything else is used directly as an index on the first axis.

    Returns
    -------
    (recons, iSlices) : the per-frame outputs concatenated along the last
        axis, and the slice selection that was actually used.
    """
    if iSlices is None:
        iSlices = [np.random.randint(imgs.shape[0])]
    elif np.ndim(iSlices) == 0 and iSlices == -1:
        # Only a scalar -1 is the "all slices" sentinel. Comparing an index
        # array with -1 directly would give an element-wise result whose truth
        # value is ambiguous (ValueError), so sequences are used as given.
        iSlices = list(range(imgs.shape[0]))
    print (iSlices)

    imgNormIn = args.imgNormIn
    imgOffsetIn = args.imgOffsetIn

    imgs = imgs[iSlices, ...]
    recons = []
    for i in range(imgs.shape[-1]):
        print (i, end=',')
        frame = imgs[..., [i]]
        # same frame, paired with frame 0 and with frame 1 respectively
        inputImg1 = np.concatenate((frame, imgs[..., [0]]), -1)
        inputImg2 = np.concatenate((frame, imgs[..., [1]]), -1)

        recon1 = sess.run(net.recon, {net.x: inputImg1 / imgNormIn + imgOffsetIn})
        recon2 = sess.run(net.recon, {net.x: inputImg2 / imgNormIn + imgOffsetIn})

        # NOTE(review): only the output offset is undone here; args.imgNormOut
        # is not applied -- confirm that the result is meant to stay in
        # normalised units.
        recon = (recon1 + recon2) / 2 - args.imgOffsetOut
        recons.append(recon)

    recons = np.concatenate(recons, -1)

    return recons, iSlices

In [38]:
# Open a session that only sees the requested GPU and allocates GPU memory
# on demand (allow_growth).
sess = tf.Session(
    config=tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            visible_device_list=str(args.device),
            allow_growth=True,
        )
    )
)
# Initialise all variables, then load the checkpoint values over them.
sess.run(tf.global_variables_initializer())
loader.restore(sess, args.checkPoint)

INFO:tensorflow:Restoring parameters from /home/dwu/trainData/Noise2Noise/train/ctp/simul/beta_0_N0_200000/24


In [39]:
# Run the network on the test slices and write the masked result to disk.
print ('Generating results')
if args.nTest > 0:
    # keep only the trailing nTest slices of both the images and the masks
    imgs, masks = imgs[-args.nTest:, ...], masks[-args.nTest:, ...]

reconTest, _ = TestSequence(sess, net, imgs, args, -1)
# repeat the single-channel mask once per output frame
maskFrame = np.tile(masks, (1,1,1,reconTest.shape[-1]))

# Saved as float32 with the last axis moved to position 1, C-contiguous.
np.save(args.outFile,
        np.copy((reconTest * maskFrame).astype(np.float32).transpose(0, 3, 1, 2), order='C'))

Generating results
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69]
0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,