In [1]:
import tensorflow as tf
# Show which TensorFlow version this notebook was executed with.
print('TensorFlow', tf.__version__)

TensorFlow 2.0.0


In [2]:
def downscale_conv2D(tensor, n_filters, kernel_size=4, strides=2, name=None, use_bn=True):
    """Apply a Conv2D -> (optional) BatchNormalization -> LeakyReLU block.

    Args:
        tensor: Input handed to the convolution layer.
        n_filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        name: Optional block identifier. When given, the layers are named
            ``downscale_block_<name>_conv2d``, ``downscale_block_<name>_bn``
            and ``downscale_block_<name>_lrelu``. When ``None`` (the
            default) the layers receive ``name=None`` so Keras picks unique
            names itself; previously the default raised ``TypeError`` while
            concatenating ``None`` into the layer name.
        use_bn: Whether to insert BatchNormalization after the convolution.

    Returns:
        The output of the LeakyReLU layer.
    """
    def _layer_name(suffix):
        # Propagate None untouched so Keras falls back to auto-naming.
        if name is None:
            return None
        return 'downscale_block_' + str(name) + '_' + suffix

    # The convolution never carries a bias term, including when use_bn=False.
    # NOTE(review): with use_bn=False the block then has no additive offset
    # at all -- confirm this is intended.
    conv = tf.keras.layers.Conv2D(filters=n_filters,
                                  kernel_size=kernel_size,
                                  strides=strides,
                                  padding='same',
                                  use_bias=False,
                                  name=_layer_name('conv2d'),
                                  activation=None)
    _x = conv(tensor)
    if use_bn:
        _x = tf.keras.layers.BatchNormalization(name=_layer_name('bn'))(_x)
    _x = tf.keras.layers.LeakyReLU(alpha=0.2, name=_layer_name('lrelu'))(_x)
    return _x

def upscale_deconv2d(tensor, n_filters, kernel_size=4, strides=2, name=None):
    """Apply a Conv2DTranspose -> BatchNormalization -> ReLU block.

    Args:
        tensor: Input handed to the transposed-convolution layer.
        n_filters: Number of transposed-convolution filters.
        kernel_size: Transposed-convolution kernel size.
        strides: Transposed-convolution stride.
        name: Optional block identifier. When given, the layers are named
            ``upscale_block_<name>_conv2d``, ``upscale_block_<name>_bn`` and
            ``upscale_block_<name>_relu``. When ``None`` (the default) the
            layers receive ``name=None`` so Keras picks unique names itself;
            previously the default raised ``TypeError`` while concatenating
            ``None`` into the layer name.

    Returns:
        The output of the ReLU layer.
    """
    def _layer_name(suffix):
        # Propagate None untouched so Keras falls back to auto-naming.
        if name is None:
            return None
        return 'upscale_block_' + str(name) + '_' + suffix

    # The '_conv2d' suffix is kept for the transposed convolution even though
    # the layer is a Conv2DTranspose, so existing layer names do not change.
    deconv = tf.keras.layers.Conv2DTranspose(filters=n_filters,
                                             kernel_size=kernel_size,
                                             strides=strides,
                                             padding='same',
                                             use_bias=False,
                                             name=_layer_name('conv2d'),
                                             activation=None)
    _x = deconv(tensor)
    _x = tf.keras.layers.BatchNormalization(name=_layer_name('bn'))(_x)
    _x = tf.keras.layers.ReLU(name=_layer_name('relu'))(_x)
    return _x

def build_generator():
    """Build the 'Generator' Keras model: an encoder-decoder with skip connections.

    The model takes a ``[256, 256, 1]`` input named ``image_input``. An
    initial stride-1 downscale block is followed by seven downscale blocks
    using the helper's default stride. Seven upscale blocks then follow, each
    concatenated with the encoder output paired to it (deepest remaining
    first, ending with the stride-1 block). A final 1x1 convolution with
    3 filters and ``tanh`` activation produces the output.

    Returns:
        A ``tf.keras.Model`` named ``'Generator'``.
    """
    encoder_filters = [64, 128, 256, 512, 512, 512, 512]
    decoder_filters = [512, 512, 512, 256, 128, 64, 64]

    image_in = tf.keras.Input(shape=[256, 256, 1], name='image_input')

    # Encoder: keep every stage's output so the decoder can reuse it.
    net = downscale_conv2D(image_in, 64, strides=1, name='0')
    skips = [net]
    for stage, width in enumerate(encoder_filters, start=1):
        net = downscale_conv2D(net, width, name=str(stage))
        skips.append(net)

    # Decoder: the last encoder output is the bottleneck itself, so the skip
    # candidates are all earlier outputs, visited from deepest to shallowest.
    skip_sources = reversed(skips[:-1])
    for stage, (width, skip) in enumerate(zip(decoder_filters, skip_sources), start=1):
        net = upscale_deconv2d(net, width, name=str(stage))
        net = tf.keras.layers.Concatenate()([skip, net])

    output_layer = tf.keras.layers.Conv2D(filters=3,
                                          kernel_size=1,
                                          strides=1,
                                          padding='same',
                                          name='output_conv2d',
                                          activation='tanh')
    image_out = output_layer(net)
    return tf.keras.Model(inputs=[image_in], outputs=[image_out], name='Generator')

def build_discriminator():
    """Build the 'Discriminator' Keras model: a stack of downscale blocks.

    The model takes a ``[256, 256, 4]`` input and runs it through four
    downscale blocks (64, 128, 256, 512 filters). The first block skips batch
    normalization and the last uses stride 1 instead of 2. A final 1x1
    convolution with a single filter and no activation produces the output.

    NOTE(review): the 4 input channels presumably hold the generator's
    1-channel input stacked with a 3-channel image -- confirm against the
    training code.

    Returns:
        A ``tf.keras.Model`` named ``'Discriminator'``.
    """
    # One row per downscale block: (filters, strides, use batch norm).
    stage_specs = (
        (64, 2, False),
        (128, 2, True),
        (256, 2, True),
        (512, 1, True),
    )

    image_in = tf.keras.Input(shape=[256, 256, 4])

    net = image_in
    for stage, (width, stride, with_bn) in enumerate(stage_specs):
        net = downscale_conv2D(net, width, strides=stride, name=str(stage), use_bn=with_bn)

    output_layer = tf.keras.layers.Conv2D(filters=1,
                                          kernel_size=1,
                                          strides=1,
                                          padding='same',
                                          name='output_conv2d',
                                          activation=None)
    scores = output_layer(net)
    return tf.keras.Model(inputs=[image_in], outputs=[scores], name='Discriminator')

In [3]:
# Instantiate both networks from the builder functions defined earlier in the notebook.
generator = build_generator()
discriminator = build_discriminator()