In [60]:
import tensorflow as tf

from tensorflow.keras import Input, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, concatenate

import numpy as np

In [29]:
def compose_f(functions, x):
    """Thread ``x`` through ``functions`` from left to right.

    ``compose_f([f, g, h], x)`` evaluates ``h(g(f(x)))``. An empty sequence
    returns ``x`` unchanged.
    """
    for step in functions:
        x = step(x)
    return x

In [7]:
def conv_layer(n_filters, size, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal'):
    """Create (but do not apply) a ``Conv2D`` layer.

    Args:
        n_filters: number of output filters.
        size: kernel size, passed straight through to ``Conv2D``.
        activation, padding, kernel_initializer: forwarded to ``Conv2D``. The
            defaults ('relu', 'same', 'he_normal') are used throughout this notebook.

    Returns:
        The layer object; call it on a tensor to apply it.
    """
    settings = {
        'activation': activation,
        'padding': padding,
        'kernel_initializer': kernel_initializer,
    }
    return Conv2D(n_filters, size, **settings)

In [10]:
def pool_layer(pool_size=(2, 2)):
    """Create (but do not apply) a ``MaxPooling2D`` layer.

    With the default ``(2, 2)`` pool the spatial size is halved; the model
    summary shows 256x256 -> 128x128 after the first pooling layer.
    """
    layer = MaxPooling2D(pool_size=pool_size)
    return layer

In [35]:
def contracting_layer(n_filters, x):
    """One encoder stage of the U-Net: two 3x3 convs, then max-pooling.

    Args:
        n_filters: filter count for both convolutions.
        x: input tensor.

    Returns:
        ``(pooled, skip_conn)``:
        - ``pooled`` is the downsampled output that feeds the next stage.
        - ``skip_conn`` is the pre-pooling feature map, kept so the matching
          decoder stage can concatenate it back in.
    """
    double_conv = conv_layer(n_filters, 3)(conv_layer(n_filters, 3)(x))
    return pool_layer()(double_conv), double_conv

In [37]:
def contracting_path(x, start_layers=64, n_layers=4):
    """Run the U-Net encoder.

    Applies ``n_layers`` contracting stages. Stage ``i`` (0-based) uses
    ``start_layers * 2**i`` filters, so the defaults give 64, 128, 256, 512.

    Returns:
        ``(x, skip_conns)``:
        - ``x`` is the output of the last pooling step.
        - ``skip_conns`` lists each stage's pre-pooling feature map, shallowest first.
    """
    skip_conns = []
    for depth in range(n_layers):
        x, skip = contracting_layer(start_layers * 2 ** depth, x)
        skip_conns.append(skip)

    # TODO: dropout?
    return x, skip_conns

In [38]:
def upconv_layer(x, n_filters, size, up_sampling_size, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal'):
    """Apply a U-Net "up-convolution": ``UpSampling2D`` followed by a convolution.

    Args:
        x: input tensor.
        n_filters: filter count of the convolution.
        size: kernel size of the convolution.
        up_sampling_size: upsampling factor passed to ``UpSampling2D(size=...)``.
        activation, padding, kernel_initializer: forwarded to the convolution.

    Returns:
        The output tensor of the convolution.
    """
    x = UpSampling2D(size = up_sampling_size)(x)
    # Forward the conv options. They were previously accepted but silently
    # ignored, so callers could not change them.
    x = conv_layer(n_filters, size, activation = activation, padding = padding,
                   kernel_initializer = kernel_initializer)(x)
    return x

In [41]:
def expanding_layer(n_filters, x, skip_conn):
    """One decoder stage of the U-Net.

    Steps:
    1. Up-convolve ``x``: 2x upsampling, then a 2x2 conv with ``n_filters``.
    2. Concatenate the matching encoder map ``skip_conn`` in front of the
       result along the channel axis.
    3. Refine the merged map with two 3x3 convs of ``n_filters`` each.

    NOTE(review): the concatenation assumes ``skip_conn`` and the upsampled
    tensor have equal height/width. That holds with 'same' padding when the
    input size is divisible by 2**depth; confirm for other input sizes.
    """
    upsampled = upconv_layer(x, n_filters, 2, (2,2))
    # axis=3 is the last (channel) axis of the 4-D (batch, h, w, channels) tensors
    # shown in the model summary. It is not the "3rd dim".
    merged = concatenate([skip_conn, upsampled], axis=3)
    for _ in range(2):
        merged = conv_layer(n_filters, 3)(merged)
    return merged

In [44]:
def expanding_path(x, skip_conns, start_layers=512):
    """Run the U-Net decoder.

    Consumes ``skip_conns`` deepest-first, i.e. in reverse of the order
    ``contracting_path`` appends them. The filter count starts at
    ``start_layers`` and halves after each stage, so the default gives
    512, 256, 128, 64.

    NOTE: the default 512 matches the deepest stage of ``contracting_path``
    with its defaults (64 * 2**3). Change both together.

    Returns:
        The output tensor of the last expanding stage.
    """
    n_filters = start_layers

    for skip_conn in reversed(skip_conns):
        x = expanding_layer(n_filters, x, skip_conn)
        # Integer halving. `/=` turned the count into a float (256.0, ...)
        # before it was passed on as a Conv2D filter count.
        n_filters //= 2

    return x

In [52]:
def _unet(input_size = (256, 256, 1)):
    x = Input(input_size)
    inputs = x
    
    x, skip_conns = contracting_path(x)
    x = conv_layer(1024, 3)(conv_layer(1024, 3)(x))  # todo: dropout ?
    
    x = expanding_path(x, skip_conns)
    
    x = conv_layer(2, 3)(x)  # output segmentation map
    x = conv_layer(1, 1, activation = 'sigmoid')(x)
    
    outputs = x
    return inputs, outputs

In [69]:
def unet(optimizer = None, loss = 'binary_crossentropy', metrics = None):
    """Build, compile and print a summary of the U-Net model.

    Args:
        optimizer: anything ``Model.compile`` accepts as an optimizer. Defaults
            to a fresh ``Adam(learning_rate=1e-4)`` created on every call.
        loss: loss passed to ``Model.compile``.
        metrics: metrics list passed to ``Model.compile``. Defaults to a fresh
            ``['accuracy']`` list on every call.

    Returns:
        The compiled Keras ``Model``, named 'simple UNet'.
    """
    # Build defaults per call. A default evaluated at definition time is a single
    # Adam instance (and a single list) shared by every model this function builds.
    if optimizer is None:
        # `learning_rate`, not the deprecated `lr` alias that newer Keras optimizers reject.
        optimizer = Adam(learning_rate = 1e-4)
    if metrics is None:
        metrics = ['accuracy']

    inputs, outputs = _unet()

    model = Model(inputs = inputs, outputs = outputs, name='simple UNet')
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    model.summary()

    return model

In [70]:
# Build and compile the default U-Net (input shape (256, 256, 1)); unet() also
# prints the layer-by-layer model summary captured in this cell's output.
u = unet()

Model: "simple UNet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_8 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
conv2d_141 (Conv2D)             (None, 256, 256, 64) 640         input_8[0][0]                    
__________________________________________________________________________________________________
conv2d_140 (Conv2D)             (None, 256, 256, 64) 36928       conv2d_141[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_28 (MaxPooling2D) (None, 128, 128, 64) 0           conv2d_140[0][0]                 
________________________________________________________________________________________