In [2]:
import sys
import os
import tensorflow as tf
import numpy as np
import argparse
import SimpleITK as sitk
import random
import keras
import keras.backend as K
import time

Using TensorFlow backend.


In [17]:
# Read one segmentation file from the kits19 data directory and convert it
# to a NumPy array for inspection.
# NOTE(review): absolute, machine-specific path; the cell only runs where
# this file exists.
img = sitk.ReadImage(r"E:\kits19\data\case_00000\segmentation.nii.gz")
arry = sitk.GetArrayFromImage(img)


# Print the shape of the 2-D plane obtained by fixing the last array axis at 0.
print(arry[:,:,0].shape)

(602, 602)


In [5]:
# Module-level holder for the parsed command-line options; it is assigned in
# the ``__main__`` guard and read by ``main``.
args = None

def ParseArgs(argv=None):
    """Define the command-line interface of the training script and parse it.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse (without the program name). When ``None``
        (the default) argparse falls back to ``sys.argv[1:]``, which is the
        previous behaviour. Passing an explicit list makes the function
        usable from a notebook or a test, where ``sys.argv`` holds the
        kernel's own arguments.

    Returns
    -------
    argparse.Namespace
        The parsed options. ``trainingdatafile`` and ``modelfile`` are
        required positionals; every other option is optional.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("trainingdatafile", help="Input Dataset file for training")
    parser.add_argument("modelfile", help="Output trained model file in HDF5 format (*.hdf5).")
    parser.add_argument("-t","--testfile", help="Input Dataset file for validation")
    parser.add_argument("-e", "--epochs", help="Number of epochs", default=1000, type=int)
    parser.add_argument("-b", "--batchsize", help="Batch size", default=10, type=int)
    parser.add_argument("-l", "--learningrate", help="Learning rate", default=1e-3, type=float)
    parser.add_argument("--nobn", help="Do not use batch normalization layer", action='store_true')
    parser.add_argument("--nodropout", help="Do not use dropout layer", action='store_true')
    parser.add_argument("--noaugmentation", help="Do not use training data augmentation", action='store_true')
    parser.add_argument("--magnification", help="Magnification coefficient for data augmentation", default=10, type=int)
    parser.add_argument("--latestfile", help="The filename of the latest weights.")
    parser.add_argument("--bestfile", help="The filename of the best weights.")
    parser.add_argument("--weightinterval", help="The interval between checkpoint for weight saving.", type=int)
    parser.add_argument("--weightfile", help="The filename of the trained weight parameters file for fine tuning or resuming.")
    parser.add_argument("--premodel", help="The filename of the previously trained model")
    parser.add_argument("--initialepoch", help="Epoch at which to start training for resuming a previous training", default=0, type=int)
    #parser.add_argument("--idlist", help="The filename of ID list for splitting input datasets into training and validation datasets.")
    #parser.add_argument("--split", help="Fraction of the training data to be used as validation data.", default=0.0, type=float)
    parser.add_argument("--logdir", help="Log directory", default='log')
    parser.add_argument("-g", "--gpuid", help="ID of GPU to be used for segmentation. [default=0]", default=0, type=int)
    parser.add_argument("--history", help="Text file to which the per-epoch training history is appended.")

    # parse_args(None) reads sys.argv[1:], so existing ParseArgs() calls are unaffected.
    return parser.parse_args(argv)


def createParentPath(filepath):
    """Make sure the directory that will contain ``filepath`` exists.

    Nothing happens when ``filepath`` has no directory component; an already
    existing directory is not an error.
    """
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok = True)







def GetInputShapes(filenamepair):
    """Return ``(image_shape, label_shape)`` for one (image, label) file pair.

    Both files are loaded with ``ImportImage``; ``filenamepair[0]`` is the
    image file and ``filenamepair[1]`` the label file.
    """
    image_shape = ImportImage(filenamepair[0]).shape
    label_shape = ImportImage(filenamepair[1]).shape
    return (image_shape, label_shape)


def GetMinimumValue(image):
    """Return the minimum reported by SimpleITK's MinimumMaximumImageFilter for ``image``."""
    extrema_filter = sitk.MinimumMaximumImageFilter()
    extrema_filter.Execute(image)
    lowest = extrema_filter.GetMinimum()
    return lowest


def Affine(t, r, scale, shear, c):
    """Build a 2-D SimpleITK affine transform from its individual components.

    Parameters
    ----------
    t : sequence
        Translation, forwarded to ``AffineTransform.Translate``.
    r : float
        Rotation between axes 0 and 1, forwarded to ``AffineTransform.Rotate``
        (the caller in this file passes a value converted with ``np.radians``).
    scale : float
        Scale factor, forwarded to ``AffineTransform.Scale``.
    shear : sequence of two floats
        ``shear[0]`` is applied to the (0, 1) axis pair, ``shear[1]`` to (1, 0).
    c : sequence
        Centre of the transform, forwarded to ``AffineTransform.SetCenter``.

    Returns
    -------
    sitk.AffineTransform
    """
    transform = sitk.AffineTransform(2)
    transform.SetCenter(c)
    # The operations below are applied in this fixed order:
    # scale, rotate, the two shears, then translate.
    transform.Scale(scale)
    transform.Rotate(0, 1, r)
    transform.Shear(0, 1, shear[0])
    transform.Shear(1, 0, shear[1])
    transform.Translate(t)
    return transform


def Transforming(image, bspline, affine, interpolator, minval):
    """Resample ``image`` through the B-spline transform, then the affine one.

    Both resampling passes use the same ``interpolator`` and the same default
    pixel value ``minval``; the second pass takes the output of the first.
    """
    warped = image
    for transform in (bspline, affine):
        warped = sitk.Resample(warped, transform, interpolator, minval)
    return warped


def ImportImageTransformed(imagefile, labelfile):
    """Load an image/label pair and apply one random deformation to both.

    A random B-spline transform followed by a random affine transform is
    drawn once and applied to the image (linear interpolation, padded with
    the image minimum) and to the label (nearest-neighbour interpolation,
    padded with 0).

    Returns
    -------
    (imagearry, labelarry)
        ``imagearry`` is the transformed image array with a trailing axis
        added; ``labelarry`` is the transformed label array as returned by
        SimpleITK (no extra axis, not one-hot encoded).
    """
    # Strength of each random component.
    bspline_sigma = 4
    max_translation = 5 # [mm]
    max_rotation = 5 # [deg]
    max_shear = 1/16
    max_scale_delta = 0.05

    source_image = sitk.ReadImage(imagefile)
    source_label = sitk.ReadImage(labelfile)

    # B-spline transform whose coefficients are all replaced by Gaussian noise.
    bspline = sitk.BSplineTransformInitializer(source_image, [5,5])
    n_coefficients = len(bspline.GetParameters())
    bspline.SetParameters(np.random.normal(0, bspline_sigma, n_coefficients))

    # Affine components; the draw order (translation, rotation, shear, scale)
    # determines the result for a given random seed.
    translation = np.random.uniform(-max_translation, max_translation, 2)
    rotation = np.radians(np.random.uniform(-max_rotation, max_rotation))
    shear = np.random.uniform(-max_shear, max_shear, 2)
    scale = np.random.uniform(1-max_scale_delta, 1+max_scale_delta)
    # Half of the physical extent (size * spacing) is used as the centre.
    size = np.array(source_image.GetSize())
    spacing = np.array(source_image.GetSpacing())
    affine = Affine(translation, rotation, scale, shear, size * spacing / 2)

    background = GetMinimumValue(source_image)

    warped_image = Transforming(source_image, bspline, affine, sitk.sitkLinear, background)
    warped_label = Transforming(source_label, bspline, affine, sitk.sitkNearestNeighbor, 0)

    image_array = sitk.GetArrayFromImage(warped_image)[..., np.newaxis]
    label_array = sitk.GetArrayFromImage(warped_label)

    return image_array, label_array

In [6]:
def ImportBatchArray(datalist, batch_size = 32, apply_augmentation = False, num_classes = 3):
    """Endlessly yield ``(images, one-hot labels)`` batches for Keras training.

    Each pass shuffles the indices of ``datalist`` and walks through them in
    chunks of ``batch_size`` (the last chunk of a pass may be smaller); after
    a pass finishes a new shuffled pass starts, so the generator never ends.

    Parameters
    ----------
    datalist : sequence of (imagefile, labelfile)
        File pairs, e.g. as returned by ``ReadSliceDataList``.
    batch_size : int
        Maximum number of samples per batch.
    apply_augmentation : bool
        When true, samples are loaded with ``ImportImageTransformed`` (random
        deformation); otherwise with ``ImportImage``.
    num_classes : int
        Number of classes used for one-hot encoding of the labels. Defaults
        to 3, the value that was previously hard-coded.

    Yields
    ------
    (numpy.ndarray, numpy.ndarray)
        Stacked image arrays and stacked one-hot label arrays.

    Raises
    ------
    ValueError
        If ``datalist`` is empty (previously this looped forever without
        ever yielding).
    """
    if len(datalist) == 0:
        raise ValueError("datalist is empty; no batches can be generated")

    while True:
        indices = list(range(len(datalist)))
        random.shuffle(indices)

        for i in range(0, len(indices), batch_size):
            batch = [ datalist[idx] for idx in indices[i:i+batch_size] ]

            if apply_augmentation:
                # Image and label of one sample must share the same random
                # transform, so they are loaded together.
                imagelabellist = [ ImportImageTransformed(entry[0], entry[1]) for entry in batch ]
                imagelist, labellist = zip(*imagelabellist)
            else:
                imagelist = [ ImportImage(entry[0]) for entry in batch ]
                labellist = [ ImportImage(entry[1]) for entry in batch ]

            # Both branches yield one-hot labels. The augmentation branch used
            # to yield the raw label arrays and crashed on a debug print that
            # accessed ``.shape`` on a tuple.
            onehotlabellist = np.array([ keras.utils.to_categorical(label, num_classes=num_classes) for label in labellist ])

            yield (np.array(imagelist), onehotlabellist)

1. パスを同じ腎臓ごとにまとめる
2. それぞれのまとまりの中で、3つずつに分ける
3. 分けたものを3chの配列にする関数
4. batcharrayにch数で場合分けして、バッチする

In [53]:
def ImportImage(filename):
    """Read an image file with SimpleITK and return it as a NumPy array.

    When the image has exactly one component per pixel, a trailing axis of
    length 1 is appended; otherwise the array is returned as SimpleITK
    produced it.
    """
    source = sitk.ReadImage(filename)
    pixels = sitk.GetArrayFromImage(source)
    single_component = source.GetNumberOfComponentsPerPixel() == 1
    return pixels[..., np.newaxis] if single_component else pixels



def ImportImage3ch(pList):
    """Read several image files and stack them along the third axis.

    ``pList`` is one group of file paths, for example
    ``["case_00000/image0_00.mha", "case_00000/image0_01.mha",
    "case_00000/image0_02.mha"]``; each file becomes one slice along axis 2
    of the result, in list order (``numpy.dstack`` semantics).

    Parameters
    ----------
    pList : sequence of str
        Paths of the images to stack.

    Returns
    -------
    numpy.ndarray
        The stacked array. For a single path the array is returned exactly
        as SimpleITK produced it, without an added axis (unchanged from the
        previous behaviour).

    Raises
    ------
    ValueError
        If ``pList`` is empty (previously an ``UnboundLocalError``).
    """
    arrays = [ sitk.GetArrayFromImage(sitk.ReadImage(path)) for path in pList ]

    if not arrays:
        raise ValueError("pList must contain at least one image path")
    if len(arrays) == 1:
        return arrays[0]

    # One dstack over all arrays gives the same result as stacking them one
    # by one, without re-copying the growing array on every step.
    return np.dstack(arrays)

In [67]:
def ReadSliceDataList(filename):
    """Read a tab-separated slice list file.

    Every non-blank line holds ``<labelfile>\\t<imagefile>``. Blank lines
    (for example a trailing empty line) are skipped instead of raising an
    unpacking error.

    Returns
    -------
    list of (imagefile, labelfile)
        One tuple per line, in file order. Note that the image comes first
        in the tuple although it is the second column in the file.

    Raises
    ------
    ValueError
        If a non-blank line does not consist of exactly two tab-separated fields.
    """
    datalist = []
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            labelfile, imagefile = line.split('\t')
            datalist.append((imagefile, labelfile))

    return datalist

def ReadSliceDataList3ch(filename):
    """Read a slice list file and build 3-slice windows of images and labels.

    Every non-blank line of ``filename`` holds ``<labelfile>\\t<imagefile>``.
    Lines are grouped by the image path with its last ``_<suffix>`` removed
    (e.g. ``~/case_00000/image0`` for ``~/case_00000/image0_00.mha``), i.e.
    one group per kidney. Within a group, each slice that has both a
    predecessor and a successor yields one sample made of the previous,
    current and next entries. Entries keep the order they have in the file;
    no sorting is done here.

    Returns
    -------
    list of (list of 3 image paths, list of 3 label paths)

    Raises
    ------
    ValueError
        If a non-blank line is not two tab-separated fields, or an image
        filename contains no underscore.
    """
    # Image and label of one line are stored together, so the two windows of
    # a sample always cover the same three lines.
    groups = {}   # image prefix -> [(imagefile, labelfile), ...] in file order
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            labelfile, imagefile = line.split('\t')

            directory, basename = os.path.split(imagefile)
            # Split on the LAST underscore so prefixes may contain underscores.
            prefix, separator, _ = basename.rpartition("_")
            if not separator:
                raise ValueError("image filename has no '_' separator: {}".format(imagefile))

            groups.setdefault(os.path.join(directory, prefix), []).append((imagefile, labelfile))

    # Within one kidney, join each slice with its previous and next slice (paths).
    pathList = []   # samples grouped three slices at a time: (images, labels)
    for pairs in groups.values():
        images = [ pair[0] for pair in pairs ]
        labels = [ pair[1] for pair in pairs ]
        for x in range(1, len(pairs)-1):
            pathList.append((images[x-1:x+2], labels[x-1:x+2]))

    return pathList

In [106]:
def ReadSliceDataList3ch_1ch(filename):
    """Read a slice list file and build (3 image slices, centre label) samples.

    Every non-blank line of ``filename`` holds ``<labelfile>\\t<imagefile>``.
    Lines are grouped by the image path with its last ``_<suffix>`` removed
    (e.g. ``~/case_00000/image0`` for ``~/case_00000/image0_00.mha``), i.e.
    one group per kidney, and each group is sorted by path. Each slice that
    has both a predecessor and a successor yields one sample: the previous,
    current and next image paths together with the label path of the
    current (centre) slice.

    Returns
    -------
    list of (list of 3 image paths, label path)

    Raises
    ------
    ValueError
        If a non-blank line is not two tab-separated fields, or an image
        filename contains no underscore.
    """
    groups = {}   # image prefix -> [(imagefile, labelfile), ...]
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            labelfile, imagefile = line.split('\t')

            directory, basename = os.path.split(imagefile)
            # Split on the LAST underscore so prefixes may contain underscores.
            prefix, separator, _ = basename.rpartition("_")
            if not separator:
                raise ValueError("image filename has no '_' separator: {}".format(imagefile))

            groups.setdefault(os.path.join(directory, prefix), []).append((imagefile, labelfile))

    # Within one kidney, join each slice with its previous and next slice (paths).
    pathList = []   # samples grouped three slices at a time: (images, label)
    for pairs in groups.values():
        # Sort image and label TOGETHER (by image path) so a label can never
        # drift away from the image it was listed with.
        pairs = sorted(pairs)
        images = [ pair[0] for pair in pairs ]
        for x in range(1, len(pairs)-1):
            pathList.append((images[x-1:x+2], pairs[x][1]))

    return pathList

In [109]:
# Build (3 image paths, centre label path) samples from the merged training list.
# NOTE(review): absolute, machine-specific path.
pathList = ReadSliceDataList3ch_1ch("/Users/tanimotoryou/Documents/Documents/lab/program/2Dkidney/slice/margeTraining_8_2.txt")
for x, y in pathList:
    # File names of the centre image (x[1]) and of the label paired with it.
    _, a = os.path.split(x[1])
    _, b = os.path.split(y)
    # Keep only the part after the underscore of each file name.
    _, a = a.split("_")
    _, b = b.split("_")
    # NOTE(review): a and b are computed but never compared or displayed;
    # presumably this loop was meant to check that they match - confirm.

In [102]:
# Build the 3-slice samples for one slice list and display them.
pathList = ReadSliceDataList3ch_1ch("slice/0.txt")
print(pathList)

# Load the three image slices of every sample as one stacked array and show its shape.
for p in pathList:
    array = ImportImage3ch(p[0])
    print(array.shape)

[(['slice/0/image0_03.mha', 'slice/0/image0_03.mha', 'slice/0/image0_03.mha'], 'slice/0/label0_03.mha'), (['slice/0/image0_03.mha', 'slice/0/image0_03.mha', 'slice/0/image0_03.mha'], 'slice/0/label0_03.mha'), (['slice/0/image0_03.mha', 'slice/0/image0_03.mha', 'slice/0/image0_02.mha'], 'slice/0/label0_03.mha'), (['slice/0/image0_03.mha', 'slice/0/image0_02.mha', 'slice/0/image0_02.mha'], 'slice/0/label0_02.mha'), (['slice/0/image0_02.mha', 'slice/0/image0_02.mha', 'slice/0/image0_02.mha'], 'slice/0/label0_02.mha'), (['slice/0/image0_02.mha', 'slice/0/image0_02.mha', 'slice/0/image0_02.mha'], 'slice/0/label0_02.mha'), (['slice/0/image0_02.mha', 'slice/0/image0_02.mha', 'slice/0/image0_01.mha'], 'slice/0/label0_02.mha'), (['slice/0/image0_02.mha', 'slice/0/image0_01.mha', 'slice/0/image0_01.mha'], 'slice/0/label0_01.mha'), (['slice/0/image0_01.mha', 'slice/0/image0_01.mha', 'slice/0/image0_01.mha'], 'slice/0/label0_01.mha'), (['slice/0/image0_01.mha', 'slice/0/image0_01.mha', 'slice/0/im

In [160]:
def To3chDataList(dataPath):
    """Group image paths per kidney and return sliding windows of three paths.

    Paths are grouped by their path with the last ``_<suffix>`` removed
    (e.g. ``~/case_00000/image0`` for ``~/case_00000/image0_00.mha``). Within
    a group, in the order given, every path that has both a predecessor and
    a successor yields the list ``[previous, current, next]``.

    Parameters
    ----------
    dataPath : iterable of str
        Image file paths (a list of paths, not a single path string).

    Returns
    -------
    list of list of str
        One 3-element list per window; groups with fewer than three paths
        contribute nothing.

    Raises
    ------
    ValueError
        If a filename contains no underscore.
    """
    pathDic = {}   # {~/case_00000/image0 : [~/case_00000/image0_00.mha, ...]}

    # Group the paths belonging to the same kidney.
    for path in dataPath:
        dicPath, filePath = os.path.split(path)
        # Split on the LAST underscore so prefixes may contain underscores.
        prefix, separator, _ = filePath.rpartition("_")
        if not separator:
            raise ValueError("filename has no '_' separator: {}".format(path))
        pathDic.setdefault(os.path.join(dicPath, prefix), []).append(path)

    # Within one kidney, join each slice with its previous and next slice (paths).
    pathList = []   # list of paths grouped three at a time
    for value in pathDic.values():
        for x in range(1, len(value)-1):
            pathList.append(value[x-1:x+2])

    return pathList

In [65]:
# NOTE(review): absolute, machine-specific path.
datalist = ReadSliceDataList("/Users/tanimotoryou/Documents/Documents/lab/program/2Dkidney/slice/2.txt")

# To3chDataList expects the whole list of image paths. The previous loop
# passed one path string at a time, which was then iterated character by
# character and failed.
pathList = To3chDataList([d[0] for d in datalist])
print(pathList)

# Load every 3-slice window as one stacked array and show its shape.
for p in pathList:
    array = ImportImage3ch(p)
    print(array.shape)

NameError: name 'To3chDataList' is not defined

In [None]:
def main(_):
    """Build, train and log the segmentation model described by the global ``args``.

    Steps: configure the TensorFlow session, read the training (and optional
    validation) slice lists, build and compile the model, write the model
    architecture as YAML to ``args.modelfile``, optionally load weights, set
    up the callbacks, train with ``fit_generator`` and, when ``--history``
    is given, append the per-epoch loss values to that file.

    ``ConstructModel``, ``penalty_categorical``, ``kidney_dice``,
    ``cancer_dice``, ``LatestWeightSaver`` and ``PeriodicWeightSaver`` are
    used here but are defined outside the part of the file shown.

    The single positional parameter is the argv list handed over by
    ``tf.app.run`` and is ignored.
    """
    # NOTE(review): an unresolved merge conflict sat here. The HEAD side created
    # a comet_ml ``Experiment`` with an API key hard-coded in the source and a
    # callback that was never added to ``callbacks``; ``Experiment`` is not
    # imported in this file. The conflict is resolved in favour of the side
    # without it. The key that was committed should be revoked. If the comet_ml
    # logging is wanted, import it explicitly and read the key from the environment.

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)

    trainingdatalist = ReadSliceDataList(args.trainingdatafile)
    testdatalist = None
    if args.testfile is not None:
        testdatalist = ReadSliceDataList(args.testfile)
        if len(testdatalist) == 0:
            raise ValueError("validation list file contains no entries: {}".format(args.testfile))
        # Validate on a random 10% subset, but on at least one sample: an
        # empty subset would leave the validation generator with nothing to yield.
        nvalidation = max(1, int(len(testdatalist)*0.1))
        testdatalist = random.sample(testdatalist, nvalidation)

    (imageshape, labelshape) = GetInputShapes(trainingdatalist[0])
    nclasses = 3 # Number of classes
    print(imageshape)
    print(labelshape)

    with tf.device('/device:GPU:{}'.format(args.gpuid)):
        x = tf.keras.layers.Input(shape=imageshape, name="x")
        segmentation = ConstructModel(x, nclasses, not args.nobn, not args.nodropout)
        model = tf.keras.models.Model(x, segmentation)
        model.summary()

        optimizer = tf.keras.optimizers.Adam(lr=args.learningrate)

        model.compile(loss=penalty_categorical, optimizer=optimizer, metrics=[kidney_dice, cancer_dice])

    # The model architecture (not the weights) is stored as YAML text.
    createParentPath(args.modelfile)
    with open(args.modelfile, 'w') as f:
        f.write(model.to_yaml())

    if args.weightfile is None:
        initial_epoch = 0
    else:
        model.load_weights(args.weightfile)
        initial_epoch = args.initialepoch

    if args.latestfile is None:
        latestfile = args.logdir + '/latestweights.hdf5'
    else:
        latestfile = args.latestfile
    # Also needed for the default path: the log directory may not exist yet.
    createParentPath(latestfile)

    tb_cbk = tf.keras.callbacks.TensorBoard(log_dir=args.logdir)
    latest_cbk = LatestWeightSaver(latestfile)
    callbacks = [tb_cbk, latest_cbk]
    if testdatalist is not None:
        if args.bestfile is None:
            bestfile = args.logdir + '/bestweights.hdf5'
        else:
            bestfile = args.bestfile
        createParentPath(bestfile)
        chkp_cbk = tf.keras.callbacks.ModelCheckpoint(filepath=bestfile, save_best_only = True, save_weights_only = True)
        callbacks.append(chkp_cbk)
    if args.weightinterval is not None:
        periodic_cbk = PeriodicWeightSaver(logdir=args.logdir, interval=args.weightinterval)
        callbacks.append(periodic_cbk)

    steps_per_epoch = len(trainingdatalist) / args.batchsize
    print ("Batch size: {}".format(args.batchsize))
    print ("Number of Epochs: {}".format(args.epochs))
    print ("Number of Steps/epoch: {}".format(steps_per_epoch))

    with tf.device('/device:GPU:{}'.format(args.gpuid)):
        if testdatalist is not None:
            # validation_steps counts batches, not samples: one pass over the
            # validation subset is ceil(len / batchsize) batches.
            validation_steps = (len(testdatalist) + args.batchsize - 1) // args.batchsize
            historys = model.fit_generator(ImportBatchArray(trainingdatalist, batch_size = args.batchsize, apply_augmentation = False),
                    steps_per_epoch = int(steps_per_epoch), epochs = args.epochs,
                    callbacks=callbacks,
                    validation_data = ImportBatchArray(testdatalist, batch_size = args.batchsize),
                    validation_steps = validation_steps,
                    initial_epoch = int(initial_epoch))
        else:
            historys = model.fit_generator(ImportBatchArray(trainingdatalist, batch_size = args.batchsize, apply_augmentation = False),
                    steps_per_epoch = int(steps_per_epoch), epochs = args.epochs,
                    callbacks=callbacks,
                    initial_epoch = int(initial_epoch))

    # The history file is optional; without --history nothing is written.
    if args.history is not None:
        # The model is compiled with the metrics kidney_dice and cancer_dice,
        # so the history has no 'dice' / 'val_dice' entries; 'loss' is always
        # recorded and 'val_loss' only when validation data was supplied.
        loss = historys.history['loss']
        val_loss = historys.history.get('val_loss')

        createParentPath(args.history)
        with open(args.history, "a") as history_file:
            for x in range(len(loss)):
                if val_loss is not None:
                    print("{}\t{}".format(loss[x],val_loss[x]),file = history_file)
                else:
                    # No validation run: only the training loss column is written.
                    print("{}".format(loss[x]),file = history_file)
            print("\n",file=history_file)

In [None]:
if __name__ == '__main__':
    # Parse the command line once; main() reads the result from the global `args`.
    args = ParseArgs()
    t1 = time.time()

    try:
        # Only the program name is forwarded; the options are already in `args`.
        tf.app.run(main=main, argv=[sys.argv[0]])
    finally:
        # tf.app.run() finishes by calling sys.exit(), so statements placed
        # after it never ran. The elapsed time is reported from `finally`
        # instead, which also covers a run that ends with an exception.
        # NOTE(review): caluculateTime is not defined in the part of the file
        # shown here - confirm it exists.
        t2 = time.time()
        caluculateTime(t1, t2)