In [3]:
import numpy as np
import tensorflow as tf
from keras.models import Model
from tensorflow.keras.layers import Conv3DTranspose
from keras.layers import (Input, concatenate, Conv2D,
                          MaxPooling2D, Conv2DTranspose, 
                          Dropout, Conv3D, MaxPooling3D, UpSampling3D)

In [5]:
def compute_level_output_shape(filters, depth, pool_size, image_shape):
    """Return the channels-first Keras shape of one U-Net resolution level.

    Each pooling step divides every spatial dimension by ``pool_size``, so
    after ``depth`` steps the cumulative reduction factor is
    ``pool_size ** depth``.

    Args:
        filters: Number of feature channels at this level.
        depth: Number of pooling steps between the full-resolution input and
            this level (0 means full resolution).
        pool_size: Pooling factor, either a scalar or one entry per spatial
            dimension of ``image_shape``.
        image_shape: Spatial shape of the full-resolution input.

    Returns:
        ``(None, filters, *spatial)``; spatial sizes are truncated to ``int``.
    """
    if depth == 0:
        level_shape = image_shape
    else:
        # The reduction compounds per level: pool_size ** depth, not
        # pool_size * depth (the two only coincide for depth <= 2 with a
        # pooling factor of 2).
        level_shape = np.divide(image_shape, np.power(pool_size, depth)).tolist()

    return tuple([None, filters] + [int(x) for x in level_shape])


def get_upconv(depth, nb_filters, pool_size, image_shape, kernel_size=(2, 2, 2), strides=(2, 2, 2),
               deconvolution=False):
    """Build the (uncalled) layer that upsamples one decoder level of a 3D U-Net.

    Args:
        depth: Decoder level being produced. Not needed to build the layer;
            kept so existing callers keep working.
        nb_filters: Output channels of the transposed convolution (only used
            when ``deconvolution`` is true).
        pool_size: Upsampling factor used by ``UpSampling3D``.
        image_shape: Spatial shape of the full-resolution input. Not needed
            to build the layer; kept for backward compatibility.
        kernel_size: Kernel size of the transposed convolution.
        strides: Stride of the transposed convolution, i.e. its upsampling
            factor per spatial axis.
        deconvolution: If true, return a learnable ``Conv3DTranspose``;
            otherwise return a parameter-free ``UpSampling3D``.

    Returns:
        A Keras layer to be called on the lower-resolution feature map.
    """
    if not deconvolution:
        return UpSampling3D(size=pool_size)

    # Conv3DTranspose infers its output shape from its input and rejects an
    # ``output_shape`` keyword, so the shape is not passed explicitly.
    # padding='same' makes each spatial dim grow by exactly ``strides`` for
    # any kernel_size, so the result lines up with the skip connection.
    return Conv3DTranspose(
        filters=nb_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same')

In [6]:
def Unet(image_shape, lr=1e-04, decay=1e-08, sw=None, initializer='glorot_uniform', nb_classes=2,
         final_activation='relu'):
    """Build a 2D U-Net with four pooling levels as an uncompiled Keras Model.

    Args:
        image_shape: Shape of one input sample, channels last (skip
            connections are concatenated on axis 3 of the 4D batch tensor).
            Height and width should be divisible by 16 so that, after four
            2x2 poolings, the upsampled maps match their skip connections.
        lr: Unused here; kept for backward compatibility.
        decay: Unused here; kept for backward compatibility.
        sw: Unused here; kept for backward compatibility.
        initializer: Kernel initializer for every 3x3 convolution and every
            transposed convolution in the encoder and decoder.
        nb_classes: Number of output channels of the final 1x1 convolution.
        final_activation: Activation of the final 1x1 convolution. Defaults
            to 'relu' to preserve the previous behaviour.

    Returns:
        A ``Model`` mapping ``inputs`` to per-pixel class scores.
    """
    def double_conv(x, filters):
        # Two same-padded 3x3 ReLU convolutions: the basic U-Net unit.
        for _ in range(2):
            x = Conv2D(filters, (3, 3), activation='relu', padding='same',
                       kernel_initializer=initializer)(x)
        return x

    def up_and_merge(x, skip, filters):
        # Learnable 2x upsampling, then concatenation with the encoder map of
        # the same resolution along the channel axis. Every level uses the
        # configured initializer (previously the 256-filter level did not).
        up = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same',
                             kernel_initializer=initializer)(x)
        return concatenate([up, skip], axis=3)

    inputs = Input(shape=image_shape)

    # Encoder
    conv1 = double_conv(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = double_conv(pool3, 512)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck
    conv5 = double_conv(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)

    # Decoder: upsample, merge with the matching encoder level, refine.
    conv6 = double_conv(up_and_merge(drop5, drop4, 512), 512)
    conv7 = double_conv(up_and_merge(conv6, conv3, 256), 256)
    conv8 = double_conv(up_and_merge(conv7, conv2, 128), 128)
    conv9 = double_conv(up_and_merge(conv8, conv1, 64), 64)

    # NOTE(review): ReLU clamps class scores to >= 0. With a softmax /
    # cross-entropy loss, linear logits (final_activation=None) or 'softmax'
    # are the usual choice; confirm against the training code.
    conv10 = Conv2D(nb_classes, (1, 1), activation=final_activation)(conv9)

    return Model(inputs=[inputs], outputs=[conv10])



def Unet3D(input_shape, ds=1, pool_size=(2, 2, 2), n_labels=2,
           initial_learning_rate=1e-4, deconvolution=False):
    """Build a 3D U-Net with three pooling levels as an uncompiled Keras Model.

    Args:
        input_shape: Shape of one input volume, channels last, e.g.
            ``(x, y, z, channels)`` (skip connections are concatenated on
            axis 4 of the 5D batch tensor). Each spatial dim should be
            divisible by the cumulative pooling factor (8 for the default
            ``pool_size``).
        ds: Divisor applied to every convolution's filter count (64, 128,
            256, 512 at ``ds=1``); results are truncated with ``int``.
        pool_size: Pooling and upsampling factor per level.
        n_labels: Output channels of the final 1x1x1 convolution, which has
            no activation (linear scores).
        initial_learning_rate: Unused here; kept for backward compatibility.
        deconvolution: If true, upsample with learnable ``Conv3DTranspose``
            layers instead of ``UpSampling3D`` (via ``get_upconv``).

    Returns:
        A ``Model`` mapping ``inputs`` to per-voxel scores.
    """
    def n_filters(base):
        return int(base / ds)

    def double_conv(x, base):
        # Two same-padded 3x3x3 ReLU convolutions: the basic U-Net unit.
        for _ in range(2):
            x = Conv3D(n_filters(base), (3, 3, 3), activation='relu', padding='same')(x)
        return x

    # get_upconv's image_shape is divided element-wise by the pooling factor,
    # so it must hold only the spatial dims; the trailing channel axis is
    # dropped (passing it made deconvolution=True fail on a shape mismatch).
    spatial_shape = tuple(input_shape[:-1])

    def up_and_merge(x, skip, depth, base):
        up = get_upconv(pool_size=pool_size, deconvolution=deconvolution, depth=depth,
                        nb_filters=n_filters(base), image_shape=spatial_shape)(x)
        return concatenate([up, skip], axis=4)

    inputs = Input(input_shape)

    # Encoder
    conv1 = double_conv(inputs, 64)
    pool1 = MaxPooling3D(pool_size=pool_size)(conv1)
    conv2 = double_conv(pool1, 128)
    pool2 = MaxPooling3D(pool_size=pool_size)(conv2)
    conv3 = double_conv(pool2, 256)
    pool3 = MaxPooling3D(pool_size=pool_size)(conv3)

    # Bottleneck
    conv4 = double_conv(pool3, 512)
    drop4 = Dropout(0.5)(conv4)

    # Decoder.
    # NOTE(review): with deconvolution=True the three up-convs use 256, 256
    # and 128 base filters, which does not follow the 256/128/64 halving of
    # the decoder convs. Values are kept as-is; confirm intent.
    conv5 = double_conv(up_and_merge(drop4, conv3, depth=2, base=256), 256)
    conv6 = double_conv(up_and_merge(conv5, conv2, depth=1, base=256), 128)
    conv7 = double_conv(up_and_merge(conv6, conv1, depth=0, base=128), 64)

    conv8 = Conv3D(n_labels, (1, 1, 1))(conv7)
    return Model(inputs=inputs, outputs=conv8)