In [None]:
import speedup
import tensorflow as tf
import numpy as np
import tensorflow as tf
import random
# Seed all three random number generators used by this notebook so that runs
# are as repeatable as possible:
#   * Python's `random` module (a later cell imports randint/uniform from it),
#   * NumPy's legacy global RNG,
#   * TensorFlow's global seed.
random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)

In [None]:
def double_conv_block(x, n_filters):
    """Run `x` through two identical conv + batch-norm stages.

    Each stage is a 3x3, same-padded, ReLU-activated Conv2D with He-normal
    kernel initialisation, followed by BatchNormalization.

    Args:
        x: Keras tensor to transform.
        n_filters: Number of output filters for both convolutions.

    Returns:
        The Keras tensor produced by the second BatchNormalization layer.
    """
    stage_output = x
    for _ in range(2):
        conv = tf.keras.layers.Conv2D(
            n_filters,
            (3, 3),
            activation='relu',
            kernel_initializer="he_normal",
            padding="same",
        )
        stage_output = conv(stage_output)
        stage_output = tf.keras.layers.BatchNormalization()(stage_output)
    return stage_output

In [None]:
def downsample_block(x, n_filters):
    """Encoder step: double convolution followed by 2x2 max pooling.

    Args:
        x: Keras tensor entering the encoder step.
        n_filters: Number of filters used by the double convolution.

    Returns:
        A tuple ``(features, pooled)`` where ``features`` is the output of
        ``double_conv_block`` (before pooling) and ``pooled`` is the same
        tensor after ``MaxPool2D((2, 2))``.
    """
    features = double_conv_block(x, n_filters)
    pooling = tf.keras.layers.MaxPool2D(pool_size=(2, 2))
    pooled = pooling(features)
    return features, pooled

In [None]:
def upsample_block(x, conv_features, n_filters):
    """Decoder step: transposed convolution, skip concatenation, double conv.

    Args:
        x: Keras tensor coming from the previous (coarser) decoder level.
        conv_features: Tensor that is concatenated with the upsampled `x`;
            it must be concatenable with the transposed-convolution output.
        n_filters: Number of filters for the transposed convolution and for
            the following double convolution.

    Returns:
        The Keras tensor produced by ``double_conv_block`` on the
        concatenated features.
    """
    # 3x3 transposed convolution with stride 2 and "same" padding.
    upsampler = tf.keras.layers.Conv2DTranspose(
        filters=n_filters,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding="same",
    )
    upsampled = upsampler(x)
    merged = tf.keras.layers.concatenate([upsampled, conv_features])
    return double_conv_block(merged, n_filters)

In [None]:
# Height and width of the square images accepted by the generator and the
# discriminator (used in their tf.keras.Input shapes).
imageSize = 512
# Number of channels of the network inputs; also used as the channel count
# of the arrays built by the data generator.
m = 3

def Generator():
    """Build the U-Net style generator model.

    The network takes an ``(imageSize, imageSize, m)`` input, runs four
    downsampling blocks (64, 128, 256, 512 filters), a 1024-filter
    bottleneck, and four upsampling blocks that mirror the encoder and
    consume its pre-pooling features as skip connections. A final 1x1
    sigmoid convolution produces 3 output channels.

    Returns:
        A ``tf.keras.Model`` named ``"generator"``.
    """
    encoder_filters = (64, 128, 256, 512)

    inputs = tf.keras.Input(shape=(imageSize, imageSize, m))

    # Encoder: remember each level's pre-pooling features for the decoder.
    skip_features = []
    current = inputs
    for n_filters in encoder_filters:
        features, current = downsample_block(current, n_filters)
        skip_features.append(features)

    current = double_conv_block(current, 1024)

    # Decoder: walk the encoder levels from deepest to shallowest.
    for n_filters, features in zip(reversed(encoder_filters), reversed(skip_features)):
        current = upsample_block(current, features, n_filters)

    # NOTE(review): the output channel count is the literal 3, not `m`.
    output_conv = tf.keras.layers.Conv2D(3, (1, 1), activation='sigmoid', padding="same")
    outputs = output_conv(current)

    return tf.keras.Model(inputs, outputs, name="generator")

In [None]:
def Discriminator():
    """Build the conditional discriminator model.

    The model receives two ``(imageSize, imageSize, m)`` tensors named
    ``input_image`` and ``target_image``, concatenates them, applies three
    downsampling blocks (32, 64, 128 filters) and finishes with a 3x3
    sigmoid convolution with a single output channel. Only the pooled
    output of each downsampling block is used.

    Returns:
        A ``tf.keras.Model`` named ``"discriminator"`` taking
        ``[input_image, target_image]``.
    """
    inp = tf.keras.Input(shape=(imageSize, imageSize, m), name='input_image')
    tar = tf.keras.Input(shape=(imageSize, imageSize, m), name='target_image')

    current = tf.keras.layers.concatenate([inp, tar])
    for n_filters in (32, 64, 128):
        _, current = downsample_block(current, n_filters)

    score_conv = tf.keras.layers.Conv2D(
        1,
        (3, 3),
        activation='sigmoid',
        kernel_initializer="he_normal",
        padding="same",
    )
    outputs = score_conv(current)

    return tf.keras.Model([inp, tar], outputs, name="discriminator")

In [None]:
# Directory holding the source images; the data generator appends
# '/imageNNNN.png' to it, so it must NOT end with a slash.
image_path = '/content/drive/MyDrive/source2'
# Output location for the weight checkpoint and the CSV training log; file
# names are appended directly, so the trailing slash is required.
models_path = '/content/drive/MyDrive/models/'

In [None]:
# Weight applied to the segmentation term relative to the adversarial term
# in generator_loss.
LAMBDA = 100
# Shared binary cross-entropy loss object (constructed with default
# arguments) used by both the discriminator and the generator losses.
bce = tf.keras.losses.BinaryCrossentropy()

def discriminator_loss(disc_real_output, disc_generated_output):
    """Binary cross-entropy loss of the discriminator.

    Args:
        disc_real_output: Discriminator output for the (input, target) pair;
            compared against an all-ones tensor of the same shape.
        disc_generated_output: Discriminator output for the
            (input, generated) pair; compared against an all-zeros tensor
            of the same shape.

    Returns:
        The sum of the two binary cross-entropy terms.
    """
    real_labels = tf.ones_like(disc_real_output)
    fake_labels = tf.zeros_like(disc_generated_output)
    return bce(real_labels, disc_real_output) + bce(fake_labels, disc_generated_output)

def generator_loss(disc_generated_output, gen_output, target):
    """Combined adversarial + segmentation loss of the generator.

    Args:
        disc_generated_output: Discriminator output for the generated image;
            the adversarial term compares it against an all-ones tensor.
        gen_output: Generator prediction.
        target: Ground-truth tensor the prediction is compared against.

    Returns:
        A tuple ``(total_gen_loss, gan_loss, seg_loss)`` where
        ``total_gen_loss = gan_loss + LAMBDA * seg_loss``.
    """
    # Adversarial term: the generator wants the discriminator to answer "real".
    fooled_labels = tf.ones_like(disc_generated_output)
    gan_loss = bce(fooled_labels, disc_generated_output)

    # Segmentation term: per-pixel cross-entropy against the ground truth.
    seg_loss = bce(target, gen_output)

    return gan_loss + (LAMBDA * seg_loss), gan_loss, seg_loss

In [None]:
class GAN(tf.keras.Model):
    """Keras model that trains a generator/discriminator pair through ``fit``.

    The class overrides ``train_step`` and ``test_step`` so that
    ``Model.fit`` can drive adversarial training. Each batch is expected to
    be an ``(input_image, target_image)`` pair.

    Args:
        discriminator: Model called as ``discriminator([input, image])``.
        generator: Model called as ``generator(input)``.
    """

    def __init__(self, discriminator, generator):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator

    def compile(self, discriminator_optimizer, generator_optimizer, discriminator_loss, generator_loss):
        """Store the optimizers and loss callables used by the custom steps.

        Args:
            discriminator_optimizer: Optimizer updating the discriminator.
            generator_optimizer: Optimizer updating the generator.
            discriminator_loss: Callable ``(real_output, generated_output)``
                returning the discriminator loss.
            generator_loss: Callable
                ``(generated_output, gen_output, target)`` returning
                ``(total_loss, gan_loss, seg_loss)``.
        """
        super(GAN, self).compile()
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_optimizer = generator_optimizer
        self.discriminator_loss = discriminator_loss
        self.generator_loss = generator_loss

    def call(self, data, training=False):
        """Run the generator on `data`.

        Previously this method returned ``None``, which made calling the
        model directly (or via ``predict``) unusable. Forwarding to the
        generator gives inference a meaningful result; ``fit`` does not
        depend on it because the train/test steps call the sub-models
        themselves.
        """
        return self.generator(data, training=training)

    def _compute_losses(self, input_image, target_image, training):
        """Forward both networks once and evaluate all loss terms.

        Returns:
            ``(gen_total_loss, gen_gan_loss, gen_seg_loss, disc_loss)``.
        """
        # The training flag is passed explicitly so that layers whose
        # behaviour depends on it (BatchNormalization in both networks) use
        # batch statistics while training and moving averages while
        # evaluating, instead of relying on an implicit default.
        generator_output = self.generator(input_image, training=training)
        discriminator_real_output = self.discriminator(
            [input_image, target_image], training=training)
        discriminator_generated_output = self.discriminator(
            [input_image, generator_output], training=training)

        gen_total_loss, gen_gan_loss, gen_seg_loss = self.generator_loss(
            discriminator_generated_output, generator_output, target_image)
        disc_loss = self.discriminator_loss(
            discriminator_real_output, discriminator_generated_output)
        return gen_total_loss, gen_gan_loss, gen_seg_loss, disc_loss

    def train_step(self, data):
        """Perform one simultaneous update of the generator and discriminator.

        Args:
            data: ``(input_image, target_image)`` batch.

        Returns:
            Dict with ``generator_loss``, ``discriminator_loss``,
            ``gan_loss`` and ``seg_loss``.
        """
        input_image, target_image = data

        # Two independent tapes: each network is differentiated with respect
        # to its own loss from the same forward pass.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_total_loss, gen_gan_loss, gen_seg_loss, disc_loss = self._compute_losses(
                input_image, target_image, training=True)

        generator_gradients = gen_tape.gradient(
            gen_total_loss, self.generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(
            disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(
            zip(generator_gradients, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, self.discriminator.trainable_variables))
        return {"generator_loss": gen_total_loss, "discriminator_loss": disc_loss,
                "gan_loss": gen_gan_loss, "seg_loss": gen_seg_loss}

    def test_step(self, data):
        """Evaluate the losses on one validation batch without updating weights.

        The earlier implementation recorded gradients and applied optimizer
        updates here, which meant every validation batch trained both
        networks. Evaluation now only runs the forward pass in inference
        mode.

        Args:
            data: ``(input_image, target_image)`` batch.

        Returns:
            Dict with ``generator_loss``, ``discriminator_loss``,
            ``gan_loss`` and ``seg_loss`` (Keras prefixes them with ``val_``
            when used as validation metrics).
        """
        input_image, target_image = data

        gen_total_loss, gen_gan_loss, gen_seg_loss, disc_loss = self._compute_losses(
            input_image, target_image, training=False)

        return {"generator_loss": gen_total_loss, "discriminator_loss": disc_loss,
                "gan_loss": gen_gan_loss, "seg_loss": gen_seg_loss}




In [None]:
import itertools
import tensorflow as tf
from speedup import generate_out_images3
import numpy as np
from random import randint, uniform
import imageio
import time


# Largest image index drawn by gen(); indices are sampled with
# randint(0, source_num), which includes both end points.
source_num = 2799
# Side length of the square arrays produced by gen() and declared in the
# dataset tensor shapes. These tensors are fed to models whose Input shape
# uses imageSize, so the two values have to agree.
dim = 512
# Fixed value forwarded to generate_out_images3 as its fourth positional
# argument on every sample.
stationary_defocus = 0.05


def gen():
    """Endlessly yield ``(out, src)`` training pairs built from random images.

    For every sample three source images are picked at random, their first
    channel is stacked into a ``(dim, dim, m)`` float64 array and rescaled to
    [0, 1]. That array plus a set of randomly drawn parameters is passed to
    ``generate_out_images3``; element ``[1]`` of its result, divided by its
    maximum, is the network input ``out``. The target ``src`` is the source
    stack binarised in place (every value > 0 becomes 1).

    Draws that cannot be normalised (a constant source stack, or a simulated
    image whose maximum is zero or not finite) are skipped, because dividing
    by such a maximum would put NaN/inf values into the training data.

    Yields:
        Tuple ``(out, src)`` of NumPy arrays.
    """
    while True:
        # One independently chosen source image per channel 0..2.
        # NOTE(review): randint is inclusive on both ends, so index
        # `source_num` itself can be drawn -- confirm that
        # image{source_num:04d}.png exists in image_path.
        layer_numbers = [randint(0, source_num) for _ in range(3)]

        src = np.zeros((dim, dim, m), np.double)
        for channel, layer_number in enumerate(layer_numbers):
            layer = imageio.imread(f"{image_path}/image{layer_number:04d}.png")
            # Only the first channel of each image file is used.
            src[:, :, channel] = layer[:, :, 0]

        # Rescale the stack to [0, 1]; a constant stack has zero range and
        # cannot be normalised, so draw a new one instead of dividing by 0.
        src = src - np.amin(src)
        src_peak = np.amax(src)
        if src_peak == 0:
            continue
        src = src / src_peak

        # Randomised parameters forwarded to generate_out_images3; their
        # meaning is defined by the speedup module.
        w = uniform(0.05, 0.5)

        a_10 = uniform(-1e3, 1e3)
        a_01 = uniform(-1e3, 1e3)
        b_20 = uniform(1, 1.5)
        b_11 = uniform(-0.1, 0.1)
        b_02 = uniform(1, 1.5)
        c_30 = uniform(-1.5e-6, 1.5e-6)
        c_21 = uniform(-2e-6, 2e-6)
        c_12 = uniform(-2e-6, 2e-6)
        c_03 = uniform(-1.5e-6, 1.5e-6)

        out = generate_out_images3(dim, m, w, stationary_defocus, a_10, a_01, b_20, b_11, b_02, c_30, c_21, c_12, c_03, src)[1]

        # Same guard for the simulated image before scaling by its maximum.
        out_peak = np.amax(out)
        if out_peak == 0 or not np.isfinite(out_peak):
            continue
        out = out / out_peak

        # Turn the source stack into a binary target mask (in place).
        src[src > 0] = 1.

        yield (out, src)

# Optimizers: the discriminator uses a ten times smaller learning rate than
# the generator.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

# Both datasets are fed by gen() and therefore share element dtypes/shapes.
_pair_dtypes = (tf.float64, tf.float64)
_pair_shapes = (tf.TensorShape([dim, dim, m]), tf.TensorShape([dim, dim, m]))

# Endless training stream, one sample per batch, prefetched in the background.
tr_dataset = (
    tf.data.Dataset.from_generator(gen, _pair_dtypes, _pair_shapes)
    .batch(batch_size=1)
    .prefetch(buffer_size=8)
)

# Validation set: the first 64 generated samples, cached so that the same
# samples are reused every time the dataset is iterated again.
val_dataset = (
    tf.data.Dataset.from_generator(gen, _pair_dtypes, _pair_shapes)
    .take(count=64)
    .cache()
    .batch(batch_size=1)
)

# Quantity watched by checkpointing and early stopping.
metric = 'val_generator_loss'
save_best_callback = tf.keras.callbacks.ModelCheckpoint(
    models_path + 'bestmodel_gan.hdf5',
    save_weights_only=True,
    save_best_only=True,
    verbose=True,
    monitor=metric,
)
csv_logger_callback = tf.keras.callbacks.CSVLogger(models_path + 'log_gan.csv')
#generator_lr_reduce_callback = CustomReduceLRoP(factor=0.5, min_delta=5e-4, patience=5, monitor = 'val_generator_loss', optim_lr = generator_optimizer)
#discriminator_lr_reduce_callback = CustomReduceLRoP(factor=0.5, min_delta=5e-4, patience=5, monitor = 'val_discriminator_loss', optim_lr = discriminator_optimizer)
# NOTE(review): patience (25) is close to the total epoch count (30), so
# early stopping can only trigger if the best epoch is among the first few.
early_stop_callback = tf.keras.callbacks.EarlyStopping(patience=25, monitor=metric)
training_callbacks = [save_best_callback, csv_logger_callback, early_stop_callback]

# The discriminator is built before the generator; keep this order, since
# weight initialisation draws from the globally seeded RNG.
discriminator = Discriminator()
generator = Generator()
#generator.load_weights(models_path + 'bestmodel_unet2d_fix_def.hdf5')
model_instance = GAN(discriminator, generator)
model_instance.compile(
    discriminator_optimizer=discriminator_optimizer,
    generator_optimizer=generator_optimizer,
    discriminator_loss=discriminator_loss,
    generator_loss=generator_loss,
)

# 30 epochs of 256 training batches each, validating on 64 batches per epoch.
model_instance.fit(
    x=tr_dataset,
    validation_data=val_dataset,
    verbose=1,
    validation_steps=64,
    steps_per_epoch=256,
    epochs=30,
    callbacks=training_callbacks,
)