In [3]:
import speedup
import tensorflow as tf
import numpy as np
import tensorflow as tf
import random
# Seed every global RNG from one place. Python's `random` also drives the
# randint/uniform calls in gen(); TensorFlow's global seed feeds layer
# initialisers such as he_normal. Change SEED here to get a different run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [4]:
def double_conv_block(x, n_filters):
    """Apply two stacked 3x3x3 ReLU Conv3D layers, each with `n_filters` filters.

    Both convolutions use He-normal initialisation and "same" padding.

    Args:
        x: input tensor for the first convolution.
        n_filters: number of output channels of each convolution.

    Returns:
        The output tensor of the second convolution.
    """
    conv_kwargs = dict(activation="relu", kernel_initializer="he_normal", padding="same")
    for _ in range(2):
        x = tf.keras.layers.Conv3D(n_filters, (3, 3, 3), **conv_kwargs)(x)
    return x

In [5]:
def downsample_block(x, n_filters):
    """Encoder stage: double convolution followed by 2x2x2 max-pooling.

    Args:
        x: input tensor.
        n_filters: filters per convolution in double_conv_block.

    Returns:
        (features, pooled): the pre-pool feature map, which the decoder uses as
        a skip connection, and its max-pooled version for the next stage.
    """
    features = double_conv_block(x, n_filters)
    pool = tf.keras.layers.MaxPool3D((2, 2, 2))
    return features, pool(features)

In [6]:
def upsample_block(x, conv_features, n_filters):
    """Decoder stage: stride-2 transposed conv, skip concatenation, double conv.

    Args:
        x: tensor coming up from the previous (coarser) stage.
        conv_features: matching encoder feature map (skip connection).
        n_filters: filters for the transposed conv and the double conv.

    Returns:
        Output tensor of the trailing double_conv_block.
    """
    layers = tf.keras.layers
    upsampled = layers.Conv3DTranspose(n_filters, (3, 3, 3), (2, 2, 2), padding="same")(x)
    # Channel-axis concat: upsampled features first, skip features second.
    merged = layers.Concatenate(axis=-1)([upsampled, conv_features])
    return double_conv_block(merged, n_filters)

In [7]:
# Geometry shared by constructModel() (input shape) and the data pipeline.
imageSize = 512  # height and width of the network's input volume
m = 8            # depth (third spatial axis) of the network's input volume
# Bounds of the uniform draw for the `w` argument that gen() passes to
# generate_out_images.
min_w, max_w = 0.1, 0.5

def constructModel():
    inputs = tf.keras.Input(shape=(imageSize, imageSize, m, 1))
    f1, p1 = downsample_block(inputs, 64)
    f2, p2 = downsample_block(p1, 128)
    f3, p3 = downsample_block(p2, 256)
    

    bottleneck = double_conv_block(p3, 512)

    u6 = upsample_block(bottleneck, f3, 256)
    u7 = upsample_block(u6, f2, 128)
    u8 = upsample_block(u7, f1, 64)
    pre_outputs = tf.keras.layers.Conv3D(1, (1, 1, 1), activation = "sigmoid", padding="same")(u8)
    outputs = tf.keras.layers.Lambda(lambda x: tf.concat([x[:, :, :, 0, :], x[:, :, :, 7, :]], 3))(pre_outputs)

    return tf.keras.Model(inputs, outputs, name="U-Net")

In [30]:
# Build the U-Net once and print its per-layer output shapes and parameter counts.
# NOTE(review): the training cell calls constructModel() again, so this instance
# is only used for inspection.
model_instance = constructModel()
model_instance.summary()

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 512, 512, 8  0           []                               
                                , 1)]                                                             
                                                                                                  
 conv3d_15 (Conv3D)             (None, 512, 512, 8,  1792        ['input_2[0][0]']                
                                 64)                                                              
                                                                                                  
 conv3d_16 (Conv3D)             (None, 512, 512, 8,  110656      ['conv3d_15[0][0]']              
                                 64)                                                          

In [10]:
# Google Drive locations (Colab mount). Note the asymmetry the callers rely on:
# image_path has no trailing slash because gen() appends '/src_NNNN.png' itself,
# while models_path ends in '/' because checkpoint/log file names are appended
# to it directly.
image_path, models_path = (
    '/content/drive/MyDrive/source',
    '/content/drive/MyDrive/models/',
)

In [None]:
import itertools
import tensorflow as tf
from speedup import generate_out_images
import numpy as np
from random import randint, uniform
import imageio
import time

# Settings for the synthetic sample generator gen().
dim = 512                  # side length of generated samples; must equal imageSize (model input)
source_num = 2999          # highest src_NNNN.png index drawn; randint() includes this bound
stationary_defocus = 0.05  # fixed argument forwarded to generate_out_images

def gen():
    """Yield an endless stream of synthetic (input, target) pairs for the U-Net.

    Each sample builds an otherwise-zero (dim, dim, m) volume:
    - Two random source images are read, and their first channel is placed in
      depth slices 0 and m - 1.
    - The volume is scaled to a peak of 1 and passed to generate_out_images,
      with `w` drawn uniformly from [min_w, max_w].
    - Element [1] of that call's result, also scaled to a peak of 1, is the
      network input.

    Yields:
        tuple: (out_ext, src_ext).
        - out_ext: float64 array of shape (dim, dim, m, 1), the network input.
        - src_ext: float64 array of shape (dim, dim, 2), the network target.
          It holds depth slices 0 and m - 1 of the normalised source volume.
    """
    last = m - 1  # second source slice (7 for the default m = 8)

    def _read_layer(number):
        # First channel of src_NNNN.png. A grayscale PNG loads as a 2-D array
        # and is used as-is instead of failing on the channel index.
        img = imageio.imread(image_path + '/src_' + str(number).zfill(4) + '.png')
        return img[:, :, 0] if img.ndim == 3 else img

    def _scale_to_unit_peak(arr):
        # Divide by the maximum. If the peak is 0 (e.g. two all-black sources),
        # return the array unchanged: 0/0 would otherwise inject NaNs into the
        # samples and the training loss.
        peak = np.amax(arr)
        return arr / peak if peak != 0 else arr

    while True:
        layer1_number = randint(0, source_num)
        layer2_number = randint(0, source_num)

        src1 = _read_layer(layer1_number)
        src2 = _read_layer(layer2_number)

        src = np.zeros((dim, dim, m), np.double)
        src[:, :, 0] = src1
        src[:, :, last] = src2
        src = _scale_to_unit_peak(src)

        w_rand = uniform(min_w, max_w)
        out = generate_out_images(dim, m, w_rand, stationary_defocus, src)[1]
        out = _scale_to_unit_peak(out)

        src_ext = np.zeros((dim, dim, 2), np.double)
        src_ext[:, :, 0] = src[:, :, 0]
        src_ext[:, :, 1] = src[:, :, last]

        out_ext = np.zeros((dim, dim, m, 1), np.double)
        out_ext[:, :, :, 0] = out

        yield (out_ext, src_ext)


# Element spec shared by both pipelines; it must match exactly what gen() yields.
sample_signature = (
    tf.TensorSpec(shape=(dim, dim, m, 1), dtype=tf.float64),
    tf.TensorSpec(shape=(dim, dim, 2), dtype=tf.float64),
)

# Training: an endless stream of freshly generated samples.
tr_dataset = (
    tf.data.Dataset.from_generator(gen, output_signature=sample_signature)
    .batch(batch_size=1)
    .prefetch(buffer_size=8)
)

# Validation: the first `val_samples` generated pairs, cached so that every
# epoch is scored on the same data. tf.data only keeps a cache once its input
# has been read to end-of-sequence. If fit() stopped each validation pass early
# via validation_steps, the partial cache would be discarded and gen() would
# draw a new random validation set every epoch. val_loss, which ModelCheckpoint,
# ReduceLROnPlateau and EarlyStopping all monitor, would then compare
# different data from epoch to epoch.
val_samples = 64
val_dataset = (
    tf.data.Dataset.from_generator(gen, output_signature=sample_signature)
    .take(count=val_samples)
    .cache()
    .batch(batch_size=1)
)


opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

save_best_callback = tf.keras.callbacks.ModelCheckpoint(models_path + 'bestmodel_unet4.hdf5',
                                                        save_weights_only=True, save_best_only=True, verbose=True)
csv_logger_callback = tf.keras.callbacks.CSVLogger(models_path + 'log_unet4.csv')
lr_reduce_callback = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, min_delta=5e-4, patience=5)
early_stop_callback = tf.keras.callbacks.EarlyStopping(patience=25)

model_instance = constructModel()
model_instance.compile(loss='mse', optimizer=opt, metrics=['mae', "mse"])
# validation_steps is deliberately left unset so each validation pass reads
# val_dataset to its end, which lets the cache above be finalised and reused.
model_instance.fit(x=tr_dataset, validation_data=val_dataset, verbose=1,
                   steps_per_epoch=128, epochs=100,
                   callbacks=[save_best_callback, csv_logger_callback, lr_reduce_callback, early_stop_callback])