In [1]:
import tensorflow as tf
import os

In [2]:
class InstanceNormalization(tf.keras.layers.Layer):
  """Instance Normalization Layer (https://arxiv.org/abs/1607.08022).

  For every sample and channel, activations are normalized to zero mean and
  unit variance over axes 1 and 2 (height/width for the channels-last 4-D
  feature maps produced by the Conv2D layers in this notebook), then a learned
  per-channel scale and offset are applied.

  Args:
    epsilon: small constant added to the variance for numerical stability.
    **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``)
      so the layer can be named and re-created from its config.
  """

  def __init__(self, epsilon=1e-5, **kwargs):
    super(InstanceNormalization, self).__init__(**kwargs)
    self.epsilon = epsilon

  def build(self, input_shape):
    # One scale/offset pair per channel (the last axis of the input).
    # Scale starts around 1.0 so the layer is initially close to a plain
    # normalization.
    self.scale = self.add_weight(
        name='scale',
        shape=input_shape[-1:],
        initializer=tf.random_normal_initializer(1., 0.02),
        trainable=True)

    self.offset = self.add_weight(
        name='offset',
        shape=input_shape[-1:],
        initializer='zeros',
        trainable=True)

  def call(self, x):
    # Statistics are computed per sample and per channel (axes 1 and 2 only).
    mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    inv = tf.math.rsqrt(variance + self.epsilon)
    normalized = (x - mean) * inv
    return self.scale * normalized + self.offset

  def get_config(self):
    """Return the layer config so the layer can be saved and reloaded."""
    config = super(InstanceNormalization, self).get_config()
    config.update({'epsilon': self.epsilon})
    return config

In [3]:
def downsample(filters, size, norm_type='instancenorm', apply_norm=True, last=False):
    """Build a Conv2D(stride 2) -> [normalization] -> LeakyReLU block.

    The convolution uses ``padding='same'`` and ``strides=2``, so the block
    halves both spatial dimensions and outputs ``filters`` channels.

    Args:
        filters: number of output channels of the convolution.
        size: convolution kernel size.
        norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Only
            consulted when ``apply_norm`` is True.
        apply_norm: whether to insert a normalization layer after the conv.
        last: unsupported, kept only for signature compatibility. A
            Sequential block has a single input and therefore cannot
            concatenate a conditioning tensor; do that in the functional
            model instead.

    Returns:
        A ``tf.keras.Sequential`` block.

    Raises:
        NotImplementedError: if ``last`` is True.
        ValueError: if ``apply_norm`` is True and ``norm_type`` is unknown.
    """
    if last:
        # The previous implementation added Concatenate((1,1,1,512)) here,
        # which misuses the `axis` argument and cannot run inside a
        # Sequential model; fail early instead of at call time.
        raise NotImplementedError(
            'downsample(last=True) is not supported; concatenate conditioning '
            'inputs in the functional model instead.')

    # Random initialization of parameters (DCGAN/pix2pix-style N(0, 0.02)).
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer, use_bias=False))

    if apply_norm:
        kind = norm_type.lower()
        if kind == 'batchnorm':
            result.add(tf.keras.layers.BatchNormalization())
        elif kind == 'instancenorm':
            result.add(InstanceNormalization())
        else:
            raise ValueError('Unknown norm_type: {!r}'.format(norm_type))

    # The activation is applied whether or not normalization is used;
    # otherwise an un-normalized block would be purely linear.
    result.add(tf.keras.layers.LeakyReLU())

    return result



def upsample(filters, size, norm_type='batchnorm', apply_dropout=False, dropout_rate=0.5):
    """Build a Conv2DTranspose(stride 2) -> norm -> [Dropout] -> LeakyReLU block.

    The transposed convolution uses ``padding='same'`` and ``strides=2``, so the
    block doubles both spatial dimensions and outputs ``filters`` channels.

    Args:
        filters: number of output channels of the transposed convolution.
        size: kernel size.
        norm_type: 'batchnorm' or 'instancenorm' (case-insensitive).
        apply_dropout: whether to insert a Dropout layer after normalization.
        dropout_rate: dropout rate used when ``apply_dropout`` is True.

    Returns:
        A ``tf.keras.Sequential`` block.

    Raises:
        ValueError: if ``norm_type`` is unknown.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False))

    # Exactly one normalization layer. Previously a BatchNormalization was
    # always added before this choice, which normalized twice.
    kind = norm_type.lower()
    if kind == 'batchnorm':
        result.add(tf.keras.layers.BatchNormalization())
    elif kind == 'instancenorm':
        result.add(InstanceNormalization())
    else:
        raise ValueError('Unknown norm_type: {!r}'.format(norm_type))

    # `apply_dropout` was previously accepted but ignored.
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(dropout_rate))

    result.add(tf.keras.layers.LeakyReLU())

    return result

In [4]:
def unetGenerator256():
  """Build a 256x256 U-Net generator.

  Input: (bs, 256, 256, 6). Output: (bs, 256, 256, 3) with a tanh activation.
  Six stride-2 encoder blocks reduce the input to a (bs, 4, 4, 512)
  bottleneck. Five decoder blocks each upsample and concatenate the matching
  encoder activation (skip connection). A final stride-2 transposed
  convolution restores full resolution.
  """
  norm_type = 'instancenorm'
  inputs = tf.keras.layers.Input(shape=[256, 256, 6])

  down_stack = [
    downsample(256, 4, apply_norm=False),  # (bs, 128, 128, 256)
    downsample(512, 4, norm_type),  # (bs, 64, 64, 512)
    downsample(512, 4, norm_type),  # (bs, 32, 32, 512)
    downsample(512, 4, norm_type),  # (bs, 16, 16, 512)
    downsample(512, 4, norm_type),  # (bs, 8, 8, 512)
    downsample(512, 4, norm_type),  # (bs, 4, 4, 512)
  ]

  # One decoder block per skip connection (len(down_stack) - 1). Shapes are
  # after the skip concatenation. The same norm_type as the encoder is used.
  up_stack = [
    upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 8, 8, 1024)
    upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 16, 16, 1024)
    upsample(512, 4, norm_type),  # (bs, 32, 32, 1024)
    upsample(512, 4, norm_type),  # (bs, 64, 64, 1024)
    upsample(256, 4, norm_type),  # (bs, 128, 128, 512)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(3, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (bs, 256, 256, 3)

  x = inputs

  # Downsampling through the model, remembering every activation.
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  # The bottleneck is not a skip. Pair the remaining ones deepest-first.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections.
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)
    

In [5]:
def unetGenerator512():
  """Build a 512x512 conditional U-Net generator.

  Inputs:
    image input: (bs, 512, 512, 6)
    conditioning vector: (bs, 1, 1, 4)
  Output: (bs, 512, 512, 3) with a tanh activation.

  The conditioning vector is tiled over the spatial grid. It is concatenated
  at the (bs, 8, 8, 1024) bottleneck and again after the first decoder block.
  The returned model takes ``[image, conds]`` as its inputs.
  """
  norm_type = 'instancenorm'
  inputs = tf.keras.layers.Input(shape=[512, 512, 6])
  conds = tf.keras.layers.Input(shape=[1, 1, 4])

  def concat_condition(x, cond):
    """Tile a (bs, 1, 1, C) tensor to x's height/width and concat on channels."""
    # Nearest-neighbour upsampling of a 1x1 map repeats it at every position.
    height, width = int(x.shape[1]), int(x.shape[2])
    tiled = tf.keras.layers.UpSampling2D(size=(height, width))(cond)
    return tf.keras.layers.Concatenate()([x, tiled])

  down_stack = [
    downsample(256, 4, apply_norm=False),  # (bs, 256, 256, 256)
    downsample(512, 4, norm_type),  # (bs, 128, 128, 512)
    downsample(512, 4, norm_type),  # (bs, 64, 64, 512)
    downsample(512, 4, norm_type),  # (bs, 32, 32, 512)
    downsample(512, 4, norm_type),  # (bs, 16, 16, 512)
    downsample(1024, 4, norm_type),  # (bs, 8, 8, 1024)
  ]

  # One decoder block per skip connection (len(down_stack) - 1). Shapes are
  # after the skip concatenation.
  up_stack = [
    upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 16, 16, 1028)
    upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 32, 32, 1024)
    upsample(512, 4, norm_type),  # (bs, 64, 64, 1024)
    upsample(512, 4, norm_type),  # (bs, 128, 128, 1024)
    upsample(256, 4, norm_type),  # (bs, 256, 256, 512)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(3, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (bs, 512, 512, 3)

  x = inputs
  # Downsampling through the model.
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  # Inject the conditioning vector at the bottleneck: (bs, 8, 8, 1028).
  x = concat_condition(x, conds)

  # The bottleneck is not a skip. Pair the remaining ones deepest-first.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections.
  for index, (up, skip) in enumerate(zip(up_stack, skips)):
    x = up(x)
    if index == 0:
      # Re-inject the conditioning after the first decoder block.
      x = concat_condition(x, conds)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=[inputs, conds], outputs=x)

In [6]:
# Build the 512x512 U-Net generator defined above.
generator = unetGenerator512()

In [None]:
import os

# Make Graphviz's `dot` visible to plot_model. Only append the directory once
# so re-running this cell does not keep growing PATH.
GRAPHVIZ_BIN = 'C:/Program Files/Graphviz/bin/'
path_entries = os.environ.get("PATH", "").split(os.pathsep)
if GRAPHVIZ_BIN not in path_entries:
    os.environ["PATH"] = os.pathsep.join(path_entries + [GRAPHVIZ_BIN])

generator = unetGenerator512()
# Save a high-resolution diagram. As the last expression of the cell, the
# returned image is also rendered inline.
dot_img_file = 'model_1.png'
tf.keras.utils.plot_model(generator, to_file=dot_img_file, show_shapes=True, dpi=256)

In [6]:
def patchDiscriminator():
    """Build a 256x256 PatchGAN discriminator.

    Inputs:
        'input_image': (bs, 256, 256, 6) conditioning image
        'target_image': (bs, 256, 256, 3) real or generated image
    The two are concatenated on channels.

    Output: a (bs, 30, 30, 1) map of raw scores (no final activation), one per
    receptive-field patch.
    """
    norm_type = 'instancenorm'
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[256, 256, 6], name='input_image')
    tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, 9)

    down1 = downsample(64, 4, norm_type, False)(x)  # (bs, 128, 128, 64)
    down2 = downsample(128, 4, norm_type)(down1)  # (bs, 64, 64, 128)
    down3 = downsample(256, 4, norm_type)(down2)  # (bs, 32, 32, 256)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)

    # 'valid' 4x4 conv with stride 1: 34 -> 31.
    conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                  kernel_initializer=initializer,
                                  use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

    kind = norm_type.lower()
    if kind == 'batchnorm':
        norm1 = tf.keras.layers.BatchNormalization()(conv)
    elif kind == 'instancenorm':
        norm1 = InstanceNormalization()(conv)
    else:
        # Without this, `norm1` would be unbound below.
        raise ValueError('Unknown norm_type: {!r}'.format(norm_type))

    leaky_relu = tf.keras.layers.LeakyReLU()(norm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)

    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)


In [7]:
def patchDiscriminator512():
    """Build a 512x512 PatchGAN discriminator.

    Inputs:
        'input_image': (bs, 512, 512, 6) conditioning image
        'target_image': (bs, 512, 512, 3) real or generated image
    The two are concatenated on channels.

    Output: a (bs, 30, 30, 1) map of raw scores (no final activation), one per
    receptive-field patch. The extra downsampling block compared with
    patchDiscriminator compensates for the doubled input resolution.
    """
    norm_type = 'instancenorm'
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[512, 512, 6], name='input_image')
    tar = tf.keras.layers.Input(shape=[512, 512, 3], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar])  # (bs, 512, 512, 9)

    down1 = downsample(64, 4, norm_type, False)(x)  # (bs, 256, 256, 64)
    down2 = downsample(128, 4, norm_type)(down1)  # (bs, 128, 128, 128)
    down3 = downsample(256, 4, norm_type)(down2)  # (bs, 64, 64, 256)
    down4 = downsample(512, 4, norm_type)(down3)  # (bs, 32, 32, 512)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down4)  # (bs, 34, 34, 512)

    # 'valid' 4x4 conv with stride 1: 34 -> 31.
    conv = tf.keras.layers.Conv2D(1024, 4, strides=1,
                                  kernel_initializer=initializer,
                                  use_bias=False)(zero_pad1)  # (bs, 31, 31, 1024)

    kind = norm_type.lower()
    if kind == 'batchnorm':
        norm1 = tf.keras.layers.BatchNormalization()(conv)
    elif kind == 'instancenorm':
        norm1 = InstanceNormalization()(conv)
    else:
        # Without this, `norm1` would be unbound below.
        raise ValueError('Unknown norm_type: {!r}'.format(norm_type))

    leaky_relu = tf.keras.layers.LeakyReLU()(norm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 1024)

    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)

In [None]:
import os

# Make Graphviz's `dot` visible to plot_model. Only append the directory once
# so re-running this cell does not keep growing PATH.
GRAPHVIZ_BIN = 'C:/Program Files/Graphviz/bin/'
path_entries = os.environ.get("PATH", "").split(os.pathsep)
if GRAPHVIZ_BIN not in path_entries:
    os.environ["PATH"] = os.pathsep.join(path_entries + [GRAPHVIZ_BIN])

discriminator = patchDiscriminator512()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=256)

In [8]:
def classicDiscriminator():
    """Build a 256x256 image discriminator that outputs one score per image.

    Five stride-2 downsampling blocks reduce a (bs, 256, 256, 3) image to
    (bs, 8, 8, 1). That map is flattened and mapped by a Dense layer to a
    (bs, 1) raw score (no activation).
    """
    norm_type = 'instancenorm'

    inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')

    down1 = downsample(64, 4, norm_type, False)(inp)  # (bs, 128, 128, 64)
    down2 = downsample(128, 4, norm_type)(down1)  # (bs, 64, 64, 128)
    down3 = downsample(256, 4, norm_type)(down2)  # (bs, 32, 32, 256)
    down4 = downsample(256, 4, norm_type)(down3)  # (bs, 16, 16, 256)
    down5 = downsample(1, 4, norm_type)(down4)  # (bs, 8, 8, 1)

    flat = tf.keras.layers.Flatten()(down5)  # (bs, 64)
    last = tf.keras.layers.Dense(1)(flat)  # (bs, 1)

    return tf.keras.Model(inputs=inp, outputs=last)

In [9]:
def classicDiscriminator512():
    """Build a 512x512 image discriminator that outputs one score per image.

    Six stride-2 downsampling blocks reduce a (bs, 512, 512, 3) image to
    (bs, 8, 8, 1). That map is flattened and mapped by a Dense layer to a
    (bs, 1) raw score (no activation).
    """
    norm_type = 'instancenorm'

    inp = tf.keras.layers.Input(shape=[512, 512, 3], name='input_image')

    down1 = downsample(64, 4, norm_type, False)(inp)  # (bs, 256, 256, 64)
    down2 = downsample(128, 4, norm_type)(down1)  # (bs, 128, 128, 128)
    down3 = downsample(256, 4, norm_type)(down2)  # (bs, 64, 64, 256)
    down4 = downsample(512, 4, norm_type)(down3)  # (bs, 32, 32, 512)
    down5 = downsample(512, 4, norm_type)(down4)  # (bs, 16, 16, 512)
    down6 = downsample(1, 4, norm_type)(down5)  # (bs, 8, 8, 1)

    flat = tf.keras.layers.Flatten()(down6)  # (bs, 64)
    last = tf.keras.layers.Dense(1)(flat)  # (bs, 1)

    return tf.keras.Model(inputs=inp, outputs=last)

In [1]:
def classicGenerator():
    """Placeholder for a non-U-Net generator; not implemented yet.

    Raises:
        NotImplementedError: always. The previous body, ``return null``,
            raised an accidental NameError because ``null`` is not Python.
    """
    raise NotImplementedError('classicGenerator is not implemented yet')