In [65]:
import glob
import os
import random
import re
import sys
import zipfile

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import scanf
import SimpleITK as sitk
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F

from tensorflow.keras.models import *
from tensorflow.keras.layers import *

In [190]:
# Folder holding the segmentation volumes, and the root folder for outputs.
# NOTE(review): `dir` shadows the builtin of the same name; kept unchanged
# because other cells may refer to it.
dir = 'synthetic_data/seg'
OUT_DIR = 'synthetic_data/'

files = os.listdir(dir)
training_images = list(map(lambda fname: os.path.join(dir, fname), files))
print(files)

['10.nii.gz', '1.nii.gz']


In [67]:
def preprocess(f):
    """Load a NIfTI volume and z-score normalise its intensities.

    Args:
        f: path to a file loadable with ``nib.load``.

    Returns:
        A float32 array with the same shape as the stored volume, shifted to
        zero mean and scaled to unit standard deviation. A constant volume
        (standard deviation 0) comes back as all zeros rather than NaNs.
    """
    # No axis is added here: the previous `niftiImage[tf.newaxis, ...]` line
    # discarded its result, so the output shape was always the volume's own.
    volume = nib.load(f).get_fdata(caching='unchanged').astype(np.float32)
    centred = volume - volume.mean()
    std = volume.std()
    # Guard against 0/0 -> NaN for constant volumes.
    return centred / std if std > 0 else centred

In [161]:
# Held-out validation volume (XT) and its ground-truth segmentation (YT).
XT, YT = (
    'validation_data/imageData.nii',
    'validation_data/segmentationData.nii',
)

# Model

In [162]:
def unet_model(output_channels, f=2, patch_size=48):
    """Build a small two-level 3D U-Net with sigmoid outputs.

    Args:
        output_channels: number of output channels; each is an independent
            sigmoid probability map.
        f: number of filters in the first encoder level. Deeper levels use
            2*f and 4*f filters.
        patch_size: edge length of the cubic single-channel input patch.
            It must be divisible by 4, because the encoder pools twice and the
            decoder concatenates with the encoder features of matching size.

    Returns:
        An uncompiled ``tf.keras.Model``. It maps inputs of shape
        ``(batch, patch_size, patch_size, patch_size, 1)`` to outputs of shape
        ``(batch, patch_size, patch_size, patch_size, output_channels)``.

    Raises:
        ValueError: if ``patch_size`` is not divisible by 4.
    """
    if patch_size % 4 != 0:
        raise ValueError('patch_size must be divisible by 4, got %r' % (patch_size,))

    def double_conv(x, filters):
        # Two 3x3x3 'same'-padded ReLU convolutions: the unit used at every level.
        for _ in range(2):
            x = tf.keras.layers.Conv3D(filters, 3, padding='same', activation='relu')(x)
        return x

    inputs = tf.keras.layers.Input(shape=(patch_size, patch_size, patch_size, 1))

    # Encoder: each default MaxPooling3D halves the spatial size; filters double.
    d1 = double_conv(inputs, f)
    d2 = double_conv(tf.keras.layers.MaxPooling3D()(d1), 2 * f)
    d3 = double_conv(tf.keras.layers.MaxPooling3D()(d2), 4 * f)

    # Decoder: upsample, then concatenate the same-resolution encoder
    # features (skip connections) before convolving.
    u2 = tf.keras.layers.concatenate([tf.keras.layers.UpSampling3D()(d3), d2])
    u2 = double_conv(u2, 2 * f)
    u1 = tf.keras.layers.concatenate([tf.keras.layers.UpSampling3D()(u2), d1])
    u1 = double_conv(u1, f)

    # 1x1x1 projection to `output_channels` sigmoid maps. The previous version
    # hard-coded 1 here and silently ignored `output_channels`.
    outputs = tf.keras.layers.Conv3D(output_channels, 1, activation='sigmoid')(u1)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [213]:
def save_as_nifti(data, folder, name, affine=None):
    """Write ``data`` to ``folder/name`` as a NIfTI-1 image with a float32 header dtype.

    Args:
        data: volume to save. It may be a NumPy array or anything NumPy can
            convert (nested lists, or an eager ``tf.Tensor`` such as a model
            prediction).
        folder: output directory. It is created, including any missing
            parents, if it does not exist.
        name: file name inside ``folder``, e.g. ``'0_pred.nii.gz'``.
        affine: 4x4 voxel-to-world matrix. Defaults to a new identity matrix
            on every call.
    """
    # A `np.eye(4)` default would be evaluated once and shared by every call,
    # so any mutation of it would leak into later saves.
    if affine is None:
        affine = np.eye(4)
    hdr = nib.Nifti1Header()
    hdr.set_data_dtype(np.float32)
    img = nib.Nifti1Image(np.asarray(data), affine, hdr)
    # makedirs copes with nested paths and with a folder that already exists.
    os.makedirs(folder, exist_ok=True)
    nib.save(img, os.path.join(folder, name))

In [163]:
# Training folders: raw intensity volumes (DATA) and segmentation masks (SEG).
# data_gen pairs files across the two folders by their numeric case id.
DATA, SEG = 'synthetic_data/raw', 'synthetic_data/seg'

In [170]:
def get_patch(img, seg, size=48):
    """Crop the same random cubic patch from an image and its segmentation.

    Args:
        img: 3-D intensity volume (NumPy array).
        seg: segmentation aligned voxel-for-voxel with ``img``. Its first three
            dimensions must equal ``img.shape``.
        size: edge length of the cubic patch, with
            ``1 <= size <= min(img.shape)``.

    Returns:
        ``(img_patch, seg_patch)``, each with a trailing channel axis of
        length 1. ``img_patch`` is z-score normalised over the patch (all zeros
        for a constant patch). ``seg_patch`` is returned unnormalised.

    Raises:
        ValueError: if ``img`` is not 3-D, ``seg`` is not aligned with it, or
            ``size`` does not fit inside ``img``.
    """
    shape = tuple(img.shape)
    if len(shape) != 3:
        raise ValueError('img must be 3-D, got shape %r' % (shape,))
    if tuple(seg.shape[:3]) != shape:
        raise ValueError('seg shape %r does not match img shape %r' % (tuple(seg.shape), shape))
    if not 1 <= size <= min(shape):
        raise ValueError('patch size %r does not fit in volume of shape %r' % (size, shape))

    def random_span(dim):
        # randint is inclusive, so start + size <= dim keeps the patch inside.
        start = random.randint(0, dim - size)
        return slice(start, start + size)

    # One shared window keeps the image and segmentation patches aligned.
    window = tuple(random_span(dim) for dim in shape)
    img_patch = img[window][..., np.newaxis]
    seg_patch = seg[window][..., np.newaxis]

    centred = img_patch - img_patch.mean()
    std = img_patch.std()
    # Guard against 0/0 -> NaN when the patch is constant (e.g. background).
    return (centred / std if std > 0 else centred), seg_patch

In [171]:
def data_gen(data_dir, seg_dir, images_per_batch=1, patches_per_img=4, patch_size=48):
    """Yield an endless stream of ``(images, segmentations)`` patch batches.

    Raw volumes in ``data_dir`` are paired with segmentations in ``seg_dir``
    by numeric case id: ``<id>.nii.gz`` in one folder matches the file with
    the same id in the other. Files not named ``<digits>.nii.gz`` are ignored.

    Each batch samples ``images_per_batch`` cases with replacement, loads both
    volumes, and cuts ``patches_per_img`` aligned patches from each case with
    ``get_patch``.

    Yields:
        ``(xb, yb)``: the image patches stacked into a float32 array, and the
        matching segmentation patches stacked the same way. Each batch holds
        ``images_per_batch * patches_per_img`` patches.

    Raises:
        ValueError: on the first ``next()``, if a case id appears in only one
            folder, appears twice in a folder, or no paired cases are found.
    """
    def case_files(folder):
        # Map numeric case id -> path for every '<digits>.nii.gz' in `folder`.
        found = {}
        for fname in sorted(os.listdir(folder)):
            match = re.fullmatch(r'(\d+)\.nii\.gz', fname)
            if match is None:
                continue  # not a case volume (e.g. '.DS_Store'); skip it
            case_num = int(match.group(1))
            if case_num in found:
                raise ValueError('duplicate case id %d in %r' % (case_num, folder))
            found[case_num] = os.path.join(folder, fname)
        return found

    raw_files = case_files(data_dir)
    seg_files = case_files(seg_dir)
    unpaired = sorted(set(raw_files).symmetric_difference(seg_files))
    if unpaired:
        raise ValueError(
            'case ids missing either a raw volume or a segmentation: %s' % unpaired)
    if not raw_files:
        raise ValueError('no <id>.nii.gz cases found in %r' % (data_dir,))

    # Ordered by case id, so that with a fixed `random` seed the sampled
    # batches do not depend on the filesystem's directory-listing order.
    pairs = [(raw_files[k], seg_files[k]) for k in sorted(raw_files)]

    while True:
        xb, yb = [], []
        for img_path, seg_path in random.choices(pairs, k=images_per_batch):
            img = nib.load(img_path).get_fdata()
            seg = nib.load(seg_path).get_fdata()
            for _ in range(patches_per_img):
                xp, yp = get_patch(img, seg, size=patch_size)
                xb.append(xp)
                yb.append(yp)
        yield np.array(xb).astype('float32'), np.array(yb)

In [214]:
if __name__ == "__main__":
    # Held-out volume/segmentation pair, used for the periodic prediction dumps.
    xt = nib.load(XT).get_fdata()
    yt = nib.load(YT).get_fdata()

    n_steps = 2        # number of optimisation steps in this run
    save_every = 1     # write a validation prediction every `save_every` steps
    pred_dir = os.path.join(OUT_DIR, 'pred')

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    g = data_gen(DATA, SEG)
    model = unet_model(1, 2)
    model.summary()

    for i in range(n_steps):
        xp, yp = next(g)
        # Gradients are taken once per step, so a non-persistent tape suffices.
        with tf.GradientTape() as tape:
            xp_out = model(xp, training=True)
            # binary_crossentropy only reduces the channel axis; average the
            # per-voxel losses into a scalar before differentiating.
            xp_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(yp, xp_out))
        gradients = tape.gradient(xp_loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))
        print("STEP=%05d, loss=%.5f" % (i, float(xp_loss)))

        if i % save_every == 0:
            # Separate names, so the training batch (xp, yp) is not shadowed.
            xv, yv = get_patch(xt, yt)
            # predict() returns a batch of one; [0] drops the batch axis and
            # leaves a NumPy array that save_as_nifti and np.round accept.
            pred = model.predict(xv.astype('float32')[np.newaxis, ...])[0]
            print(pred.shape)
            save_as_nifti(xv, pred_dir, '%d_x.nii.gz' % i)
            save_as_nifti(pred, pred_dir, '%d_pred.nii.gz' % i)
            save_as_nifti(np.round(pred), pred_dir, '%d_pred_rounded.nii.gz' % i)
            save_as_nifti(yv, pred_dir, '%d_y.nii.gz' % i)



Model: "functional_131"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_70 (InputLayer)           [(None, 48, 48, 48,  0                                            
__________________________________________________________________________________________________
conv3d_743 (Conv3D)             (None, 48, 48, 48, 2 56          input_70[0][0]                   
__________________________________________________________________________________________________
conv3d_744 (Conv3D)             (None, 48, 48, 48, 2 110         conv3d_743[0][0]                 
__________________________________________________________________________________________________
max_pooling3d_138 (MaxPooling3D (None, 24, 24, 24, 2 0           conv3d_744[0][0]                 
_____________________________________________________________________________________