In [None]:
import speedup
import tensorflow as tf
import numpy as np
import tensorflow as tf
import random
# Seed the global random generators of Python's `random`, NumPy and TensorFlow
# with the same fixed value so repeated runs start from the same RNG state.
random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)

In [None]:
def conv2d_bn(x, filters, num_row, num_col, padding = 'same', strides = (1, 1), activation = 'relu', name = None):
    """Apply a bias-free 2D convolution, batch normalization and an optional activation.

    Args:
        x: Input tensor. Batch normalization is applied over axis 3
            (presumably a channels-last 4-D tensor; confirm against callers).
        filters: Number of convolution output filters.
        num_row: Kernel height.
        num_col: Kernel width.
        padding: Padding mode forwarded to ``Conv2D``.
        strides: Strides forwarded to ``Conv2D``.
        activation: Activation identifier forwarded to ``Activation``.
            ``None`` skips the activation layer entirely.
        name: Optional name for the block's output layer: the activation
            layer when one is added, otherwise the batch-normalization layer.

    Returns:
        The output tensor of the last layer that was applied.
    """
    x = tf.keras.layers.Conv2D(filters, (num_row, num_col), strides = strides, padding = padding, use_bias = False)(x)

    # Without an activation the batch-norm layer is the block's output, so it
    # carries the requested name instead of the name being silently dropped.
    bn_name = name if activation is None else None
    x = tf.keras.layers.BatchNormalization(axis = 3, scale = False, name = bn_name)(x)

    if activation is None:
        return x

    return tf.keras.layers.Activation(activation, name = name)(x)

In [None]:
def trans_conv2d_bn(x, filters, num_row, num_col, padding = 'same', strides = (2, 2), name = None):
    """Apply a transposed 2D convolution followed by batch normalization.

    Args:
        x: Input tensor. Batch normalization is applied over axis 3
            (presumably a channels-last 4-D tensor; confirm against callers).
        filters: Number of output filters of the transposed convolution.
        num_row: Kernel height.
        num_col: Kernel width.
        padding: Padding mode forwarded to ``Conv2DTranspose``.
        strides: Strides forwarded to ``Conv2DTranspose``.
        name: Optional name for the block's output (batch-normalization)
            layer. ``None`` lets Keras generate a name automatically.

    Returns:
        The batch-normalized output tensor.
    """
    x = tf.keras.layers.Conv2DTranspose(filters, (num_row, num_col), strides = strides, padding = padding)(x)
    # The name is attached to the final layer; previously the argument was
    # accepted but never used.
    x = tf.keras.layers.BatchNormalization(axis = 3, scale = False, name = name)(x)

    return x

In [None]:
def DCBlock(U, inp, alpha = 1.67):
    """Build a dual-branch block of chained 3x3 convolutions.

    Each of the two branches runs three ``conv2d_bn`` stages in sequence,
    concatenates the three stage outputs along axis 3 and batch-normalizes
    the result. The two branches are then added, passed through ReLU and
    batch-normalized once more.

    Args:
        U: Base filter count; scaled by ``alpha`` to obtain the block width.
        inp: Input tensor fed to the first stage of both branches.
        alpha: Multiplier applied to ``U``.

    Returns:
        The output tensor of the final batch-normalization layer.
    """
    W = alpha * U

    # Filter counts of the three chained stages, as truncated fractions of W.
    stage_filters = (int(W*0.167), int(W*0.333), int(W*0.5))

    def build_branch():
        # Every stage consumes the previous stage's output; all stage outputs
        # are kept so they can be concatenated afterwards.
        stage_outputs = []
        current = inp
        for n_filters in stage_filters:
            current = conv2d_bn(current, n_filters, 3, 3,
                                activation='relu', padding='same')
            stage_outputs.append(current)

        joined = tf.keras.layers.concatenate(stage_outputs, axis=3)
        return tf.keras.layers.BatchNormalization(axis=3)(joined)

    # The first branch is fully built before the second, which keeps the
    # layer creation order unchanged.
    first_branch = build_branch()
    second_branch = build_branch()

    merged = tf.keras.layers.Add()([first_branch, second_branch])
    merged = tf.keras.layers.Activation('relu')(merged)
    return tf.keras.layers.BatchNormalization(axis=3)(merged)

In [None]:
def ResPath(filters, length, inp):
    """Chain residual units, each mixing a 1x1 shortcut with a 3x3 convolution.

    One unit is always built from ``inp``; ``length - 1`` further units are
    stacked on top of it, so any ``length`` below 2 yields a single unit.

    Args:
        filters: Number of filters used by every convolution in the path.
        length: Requested number of residual units.
        inp: Input tensor of the path.

    Returns:
        The output tensor of the last unit's batch-normalization layer.
    """
    def residual_unit(tensor):
        # The 1x1 shortcut is created before the 3x3 convolution, which
        # keeps the layer creation order unchanged.
        skip = conv2d_bn(tensor, filters, 1, 1,
                         activation=None, padding='same')
        body = conv2d_bn(tensor, filters, 3, 3, activation='relu', padding='same')

        summed = tf.keras.layers.Add()([skip, body])
        summed = tf.keras.layers.Activation('relu')(summed)
        return tf.keras.layers.BatchNormalization(axis=3)(summed)

    out = residual_unit(inp)

    for _ in range(length-1):
        out = residual_unit(out)

    return out

In [None]:
# Side length of the square network input, in pixels.
imageSize = 512
# Channel count of the network input; also used for the network output so the
# prediction matches the (dim, dim, m) targets the data generator yields.
m = 3

def constructModel():
    """Build the U-shaped encoder/decoder network from DCBlock and ResPath units.

    The encoder has four levels with base widths 32, 64, 128 and 256. Each
    level runs a ``DCBlock``, halves the resolution with 2x2 max pooling and
    refines the skip connection with a ``ResPath`` of length 4, 3, 2 and 1
    respectively. A ``DCBlock`` of base width 512 forms the bottleneck. The
    decoder mirrors the encoder: a stride-2 transposed convolution upsamples,
    the result is concatenated with the matching skip connection along axis 3
    and passed through a ``DCBlock``.

    The final 1x1 ``conv2d_bn`` with sigmoid activation emits ``m`` channels
    (previously a hard-coded 3, which only agreed with the input while
    ``m == 3``).

    Returns:
        A ``tf.keras.Model`` mapping an ``(imageSize, imageSize, m)`` input
        to an ``m``-channel sigmoid output.
    """
    base_width = 32
    skip_lengths = (4, 3, 2, 1)  # ResPath length per encoder level, top to bottom

    inputs = tf.keras.layers.Input((imageSize, imageSize, m))

    # Encoder: layers are created in the order DCBlock -> pooling -> ResPath
    # at every level.
    skips = []
    x = inputs
    for level, skip_length in enumerate(skip_lengths):
        width = base_width * 2 ** level
        block = DCBlock(width, x)
        x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(block)
        skips.append(ResPath(width, skip_length, block))

    # Bottleneck.
    x = DCBlock(base_width * 2 ** len(skip_lengths), x)

    # Decoder: walk the encoder levels from deepest to shallowest.
    for level in reversed(range(len(skip_lengths))):
        width = base_width * 2 ** level
        upsampled = tf.keras.layers.Conv2DTranspose(
            width, (2, 2), strides=(2, 2), padding='same')(x)
        merged = tf.keras.layers.concatenate([upsampled, skips[level]], axis=3)
        x = DCBlock(width, merged)

    outputs = conv2d_bn(x, m, 1, 1, activation='sigmoid')

    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])

    return model

In [None]:
# Build the network once and print its layer-by-layer summary as a quick
# check of the architecture.
model_instance = constructModel()
model_instance.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_75 (Conv2D)             (None, 512, 512, 8)  216         ['input_2[0][0]']                
                                                                                                  
 conv2d_78 (Conv2D)             (None, 512, 512, 8)  216         ['input_2[0][0]']                
                                                                                                  
 batch_normalization_112 (Batch  (None, 512, 512, 8)  24         ['conv2d_75[0][0]']        

In [None]:
# Directory containing the numbered source images ('image0000.png', ...) that
# the data generator reads.
image_path = '/content/drive/MyDrive/source2'
# Output location for checkpoints and CSV logs. File names are appended by
# plain string concatenation, so the trailing slash is required.
models_path = '/content/drive/MyDrive/models/'

In [None]:
import itertools
import tensorflow as tf
from speedup import generate_out_images3
import numpy as np
from random import randint, uniform
import imageio
import time


# Largest image index the generator may draw; `randint(0, source_num)` is
# inclusive on both ends.
source_num = 2799
# Side length of the arrays the generator yields and of the declared dataset
# element shape.
dim = 512
# Fixed value passed to `generate_out_images3` as its fourth argument.
stationary_defocus = 0.05


def _read_source_channel(index):
    """Load source image number ``index`` and return channel 0 of it."""
    image = imageio.imread(image_path + '/image' + str(index).zfill(4) + '.png')
    return image[:, :, 0]


def gen():
    """Endlessly yield random ``(out, src)`` training pairs.

    For every sample, three source images are picked at random and channel 0
    of each becomes one channel of ``src``, which is then shifted and scaled
    to the range [0, 1]. Random parameters are drawn and passed, together
    with ``src``, to ``generate_out_images3``; element ``[1]`` of its result,
    divided by its maximum, is the network input ``out``. Finally ``src`` is
    binarized (every strictly positive value becomes 1) and serves as target.

    Samples whose normalization would divide by zero (a completely flat
    ``src``, or an ``out`` whose maximum is 0) are skipped and redrawn
    instead of yielding NaN-filled arrays.

    Yields:
        Tuple ``(out, src)``; ``src`` has shape ``(dim, dim, m)``.
        ``out`` is presumably the same shape, as the dataset declaration
        requires it; confirm against ``generate_out_images3``.
    """
    while True:
        # All three indices are drawn before any file is read.
        # NOTE(review): randint is inclusive, so index `source_num` itself can
        # be drawn; confirm that image file exists.
        layer_numbers = [randint(0, source_num) for _ in range(3)]

        src = np.zeros((dim, dim, m), np.double)
        for channel, layer_number in enumerate(layer_numbers):
            src[:, :, channel] = _read_source_channel(layer_number)

        # Min-max normalize over all channels jointly.
        src = src - np.amin(src)
        src_peak = np.amax(src)
        if src_peak == 0:
            # Every pixel is identical: dividing would produce NaNs.
            continue
        src = src / src_peak

        # Random parameters of the forward model; the draw order is fixed so
        # seeded runs consume the random stream identically.
        w = uniform(0.05, 0.5)

        a_10 = uniform(-1e3, 1e3)
        a_01 = uniform(-1e3, 1e3)
        b_20 = uniform(1, 1.5)
        b_11 = uniform(-0.1, 0.1)
        b_02 = uniform(1, 1.5)
        c_30 = uniform(-1.5e-6, 1.5e-6)
        c_21 = uniform(-2e-6, 2e-6)
        c_12 = uniform(-2e-6, 2e-6)
        c_03 = uniform(-1.5e-6, 1.5e-6)

        out = generate_out_images3(dim, m, w, stationary_defocus, a_10, a_01, b_20, b_11, b_02, c_30, c_21, c_12, c_03, src)[1]

        out_peak = np.amax(out)
        if out_peak == 0:
            # Same division-by-zero guard for the simulated image.
            continue
        out = out / out_peak

        # Turn the normalized source into a binary target mask.
        src[src > 0] = 1.

        yield (out, src)


# Training stream: endless samples from `gen`, batched in pairs and prefetched
# so sample generation overlaps with training.
tr_dataset = tf.data.Dataset.from_generator(
     gen, (tf.float64, tf.float64), (tf.TensorShape([dim, dim, m]), tf.TensorShape([dim, dim, m])))\
    .batch(batch_size=2).prefetch(buffer_size=8)

# Validation set: the first 128 generated samples, cached so that later epochs
# can reuse them, then batched into 64 batches of 2.
val_dataset = tf.data.Dataset.from_generator(
     gen, (tf.float64, tf.float64), (tf.TensorShape([dim, dim, m]), tf.TensorShape([dim, dim, m])))\
    .take(count=128).cache().batch(batch_size=2)


opt = tf.keras.optimizers.Adam()

# Keep only the weights of the epoch with the best monitored value.
save_best_callback = tf.keras.callbacks.ModelCheckpoint(models_path + 'bestmodel_dc_unet_3params_rate.hdf5',
                                                        save_weights_only=True,save_best_only=True, verbose=True)
csv_logger_callback = tf.keras.callbacks.CSVLogger(models_path + 'log_dc_unet_3params_rate.csv')
# Halve the learning rate after 5 epochs without an improvement of at least 5e-4.
lr_reduce_callback = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, min_delta=5e-4, patience=5)
# Stop training after 25 epochs without improvement.
early_stop_callback = tf.keras.callbacks.EarlyStopping(patience=25)

model_instance = constructModel()
model_instance.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_crossentropy', 'mse'])
# `validation_steps` is deliberately not passed. The validation dataset is
# finite (64 batches), and letting Keras iterate it to exhaustion is what
# allows `cache()` to be finalized. Stopping exactly on the last batch never
# reaches the end of the cached input, in which case the partial cache is
# dropped and a different random validation set would be generated each epoch,
# making `save_best_only`, LR reduction and early stopping compare
# incomparable losses.
model_instance.fit(x=tr_dataset, validation_data=val_dataset, verbose=1,
                   steps_per_epoch=256, epochs=200,
                   callbacks=[save_best_callback, csv_logger_callback, lr_reduce_callback, early_stop_callback])