In [1]:
# Pin the TensorFlow 2.0 GPU build. The install log below shows the runtime
# ships tensorflow 1.15.0, which pip reports as conflicting with the pinned
# tensorboard 2.0.0.
# NOTE(review): a runtime restart may be needed before `import tensorflow`
# resolves to 2.0 — confirm by checking the version printed below.
!pip install tensorflow-gpu==2.0.0
import tensorflow as tf
import tensorflow_datasets as tfds
# rgb2lab / rgb2gray are used by Colorizer.preprocess_input.
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from tqdm import tqdm_notebook

# Record library versions next to the outputs for reproducibility.
print('TensorFlow', tf.__version__)
print('TensorFlow Datasets', tfds.__version__)

Collecting tensorflow-gpu==2.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/25/44/47f0722aea081697143fbcf5d2aa60d1aee4aaacb5869aee2b568974777b/tensorflow_gpu-2.0.0-cp36-cp36m-manylinux2010_x86_64.whl (380.8MB)
[K     |████████████████████████████████| 380.8MB 40kB/s 
Collecting tensorboard<2.1.0,>=2.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/9b/a6/e8ffa4e2ddb216449d34cfcb825ebb38206bee5c4553d69e7bc8bc2c5d64/tensorboard-2.0.0-py3-none-any.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 33.9MB/s 
Collecting tensorflow-estimator<2.1.0,>=2.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/fc/08/8b927337b7019c374719145d1dceba21a8bb909b93b1ad6f8fb7d22c1ca1/tensorflow_estimator-2.0.1-py2.py3-none-any.whl (449kB)
[K     |████████████████████████████████| 450kB 58.3MB/s 
[31mERROR: tensorflow 1.15.0 has requirement tensorboard<1.16.0,>=1.15.0, but you'll have tensorboard 2.0.0 which is incompatible.[0m
[31mERROR: tensorflow 1

In [0]:
def downscale_conv2D(tensor, n_filters, kernel_size=4, strides=2, name=None, use_bn=True):
    """Conv2D -> optional BatchNormalization -> LeakyReLU(alpha=0.2) block.

    Args:
        tensor: Input Keras tensor.
        n_filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution stride. Padding is 'same', so a stride of 2
            halves the spatial size (rounding up).
        name: Suffix for the layer names
            ``downscale_block_<name>_{conv2d,bn,lrelu}``. If ``None`` (the
            default), Keras auto-generates unique layer names.
        use_bn: Insert a BatchNormalization layer after the convolution.

    Returns:
        Output Keras tensor of the LeakyReLU activation.
    """
    def _layer_name(suffix):
        # Returning None lets Keras pick a unique name instead of crashing on
        # string concatenation with None.
        if name is None:
            return None
        return 'downscale_block_{}_{}'.format(name, suffix)

    # The convolution has no bias (use_bias=False); note this also holds when
    # use_bn is False.
    _x = tf.keras.layers.Conv2D(filters=n_filters,
                                kernel_size=kernel_size,
                                strides=strides,
                                padding='same',
                                use_bias=False,
                                name=_layer_name('conv2d'),
                                activation=None)(tensor)
    if use_bn:
        _x = tf.keras.layers.BatchNormalization(name=_layer_name('bn'))(_x)
    _x = tf.keras.layers.LeakyReLU(alpha=0.2, name=_layer_name('lrelu'))(_x)
    return _x

def upscale_deconv2d(tensor, n_filters, kernel_size=4, strides=2, name=None):
    """Conv2DTranspose -> BatchNormalization -> ReLU block.

    Args:
        tensor: Input Keras tensor.
        n_filters: Number of transposed-convolution filters.
        kernel_size: Kernel size of the transposed convolution.
        strides: Stride of the transposed convolution. Padding is 'same', so
            a stride of 2 doubles the spatial size.
        name: Suffix for the layer names
            ``upscale_block_<name>_{conv2d,bn,relu}``. If ``None`` (the
            default), Keras auto-generates unique layer names.

    Returns:
        Output Keras tensor of the ReLU activation.
    """
    def _layer_name(suffix):
        # Returning None lets Keras pick a unique name instead of crashing on
        # string concatenation with None.
        if name is None:
            return None
        return 'upscale_block_{}_{}'.format(name, suffix)

    # No bias on the transposed convolution; BatchNormalization follows.
    _x = tf.keras.layers.Conv2DTranspose(filters=n_filters,
                                         kernel_size=kernel_size,
                                         strides=strides,
                                         padding='same',
                                         use_bias=False,
                                         name=_layer_name('conv2d'),
                                         activation=None)(tensor)
    _x = tf.keras.layers.BatchNormalization(name=_layer_name('bn'))(_x)
    _x = tf.keras.layers.ReLU(name=_layer_name('relu'))(_x)
    return _x

def build_generator():
    """Build the encoder-decoder generator with skip connections.

    Input: a (256, 256, 1) grayscale image named 'image_input'.
    Output: a 3-channel image of the same spatial size, squashed by tanh.

    The encoder is one stride-1 block followed by seven stride-2 blocks. The
    decoder is seven stride-2 transposed blocks, each concatenated with the
    encoder feature map of matching resolution.
    """
    gray_input = tf.keras.Input(shape=[256, 256, 1], name='image_input')

    encoder_filters = (64, 128, 256, 512, 512, 512, 512)
    decoder_filters = (512, 512, 512, 256, 128, 64, 64)

    # Encoder: keep every intermediate map for the skip connections.
    net = downscale_conv2D(gray_input, 64, strides=1, name='0')
    skips = [net]
    for block_idx, filters in enumerate(encoder_filters, start=1):
        net = downscale_conv2D(net, filters, name=str(block_idx))
        skips.append(net)

    # Decoder: the deepest map is already `net`, so pair the decoder blocks
    # with the remaining encoder maps from deepest to shallowest.
    skip_iter = reversed(skips[:-1])
    for block_idx, (filters, skip) in enumerate(zip(decoder_filters, skip_iter), start=1):
        net = upscale_deconv2d(net, filters, name=str(block_idx))
        net = tf.keras.layers.Concatenate()([skip, net])

    colour_output = tf.keras.layers.Conv2D(filters=3,
                                           kernel_size=1,
                                           strides=1,
                                           padding='same',
                                           name='output_conv2d',
                                           activation='tanh')(net)
    return tf.keras.Model(inputs=[gray_input], outputs=[colour_output], name='Generator')

def build_discriminator():
    """Build the discriminator over a (256, 256, 4) input.

    The input concatenates the grayscale condition and a 3-channel colour
    image. The network emits a single-channel map of unnormalized logits
    (no output activation), one per spatial location of the last feature map.
    """
    pair_input = tf.keras.Input(shape=[256, 256,  4])

    # (filters, stride, use_bn) per block; the first block has no BatchNorm.
    block_specs = (
        (64, 2, False),
        (128, 2, True),
        (256, 2, True),
        (512, 1, True),
    )
    net = pair_input
    for block_idx, (filters, stride, with_bn) in enumerate(block_specs):
        net = downscale_conv2D(net, filters, strides=stride,
                               name=str(block_idx), use_bn=with_bn)

    logits = tf.keras.layers.Conv2D(filters=1,
                                    kernel_size=1,
                                    strides=1,
                                    padding='same',
                                    name='output_conv2d',
                                    activation=None)(net)
    return tf.keras.Model(inputs=[pair_input], outputs=[logits], name='Discriminator')

In [0]:
class Colorizer:
    """Conditional GAN trainer that learns to colorize grayscale images.

    The generator maps a 1-channel grayscale image to a 3-channel (normalized)
    Lab image. The discriminator scores the 4-channel concatenation
    (grayscale, Lab) of either a real or a generated pair.

    Expected ``config`` keys:
        distribute_strategy: ``tf.distribute.Strategy`` under which models and
            optimizers are created and the training step is run.
        epochs: number of passes over the dataset in :meth:`train`.
        batch_size: global batch size.
        d_lr, g_lr: Adam learning rates for discriminator / generator.
        log_dir (optional): if present, per-step losses are written there as
            TensorBoard scalars.
    """

    def __init__(self, config):
        self.distribute_strategy = config['distribute_strategy']
        self.epochs = config['epochs']
        self.batch_size = config['batch_size']
        self.log_dir = config.get('log_dir')
        self.build_dataset()
        # Models and optimizers must exist before train(); both are created
        # under the strategy scope so their variables belong to the strategy.
        self.build_models()
        with self.distribute_strategy.scope():
            self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=config['d_lr'])
            self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=config['g_lr'])
        # Per-element losses (Reduction.NONE). Averaging over the global batch
        # is done explicitly in _average_loss, as tf.distribute custom training
        # loops require.
        self.bce_smooth = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, label_smoothing=0.1,
            reduction=tf.keras.losses.Reduction.NONE)
        self.bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def build_models(self):
        """Create the generator and discriminator under the strategy scope."""
        with self.distribute_strategy.scope():
            self.generator = build_generator()
            self.discriminator = build_discriminator()

    @staticmethod
    def preprocess_input(image):
        """Convert one RGB image into a (grayscale, normalized Lab) pair.

        ``image`` should be a float RGB tensor with values in [0, 1], the range
        skimage assumes for float input.

        Returns:
            gray: float32 tensor of shape (H, W, 1).
            lab: float32 tensor of shape (H, W, 3). L is rescaled from [0, 100]
                to [-1, 1] and a/b are divided by 128, so real targets share
                the generator's tanh output range.
        """
        def _preprocess_input(image):
            image_n = image.numpy()
            # rgb2gray drops the channel axis; restore it so gray is (H, W, 1)
            # and can be concatenated with the 3-channel Lab image.
            image_gray = rgb2gray(image_n)[..., None]
            image_lab = rgb2lab(image_n)
            image_lab[..., 0] = image_lab[..., 0] / 50.0 - 1.0
            image_lab[..., 1:] = image_lab[..., 1:] / 128.0
            return image_gray.astype('float32'), image_lab.astype('float32')

        gray, lab = tf.py_function(_preprocess_input, [image], [tf.float32, tf.float32])
        # py_function discards static shapes; restore them for Keras/batching.
        gray.set_shape(image.shape[:-1].concatenate([1]))
        lab.set_shape(image.shape)
        return gray, lab

    def build_dataset(self):
        """Build the batched, prefetched training pipeline in ``self.dataset``.

        TODO: replace the random placeholder images with a real source, e.g.
        ``tfds.load(name='places365_small', as_supervised=False)`` (its images
        would need converting to float RGB in [0, 1] first).
        """
        # Placeholder data: float RGB in [0, 1], as skimage expects.
        images = tf.random.uniform(shape=[1000, 256, 256, 3], minval=0.0, maxval=1.0)
        dataset = tf.data.Dataset.from_tensor_slices(images)
        dataset = dataset.map(Colorizer.preprocess_input,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # drop_remainder keeps every batch at exactly batch_size, which the
        # loss averaging divides by.
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        self.dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    def _average_loss(self, loss_fn, labels, logits):
        """Average a per-location loss over locations and the global batch."""
        per_location = loss_fn(labels, logits)
        # Mean over every non-batch axis gives one value per example.
        per_example = tf.reduce_mean(per_location, axis=tf.range(1, tf.rank(per_location)))
        # Divide by the global batch size so a SUM across replicas yields the
        # global mean.
        return tf.reduce_sum(per_example) / float(self.batch_size)

    def loss_G(self, fake_logits):
        """Generator loss: push discriminator logits on fakes towards 'real'."""
        return self._average_loss(self.bce, tf.ones_like(fake_logits), fake_logits)

    def loss_D_real(self, real_logits):
        '''Discriminator loss, real images (label smoothing 0.1).'''
        return self._average_loss(self.bce_smooth, tf.ones_like(real_logits), real_logits)

    def loss_D_fake(self, fake_logits):
        '''Discriminator loss, generated images.'''
        return self._average_loss(self.bce, tf.zeros_like(fake_logits), fake_logits)

    def train(self):
        """Run adversarial training for ``self.epochs`` epochs.

        Returns:
            Dict mapping 'd_real_loss', 'd_fake_loss' and 'g_loss' to lists of
            per-epoch mean losses.
        """
        strategy = self.distribute_strategy

        def train_step(grayscale_image, lab_image):
            real_input = tf.concat([grayscale_image, lab_image], axis=-1)

            # 1) Discriminator update on real (grayscale, Lab) pairs.
            with tf.GradientTape() as r_tape:
                real_logits = self.discriminator(real_input, training=True)
                d_real_loss = self.loss_D_real(real_logits)
            d_r_gradients = r_tape.gradient(d_real_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(zip(d_r_gradients, self.discriminator.trainable_variables))

            # 2) Discriminator and generator updates on generated pairs.
            with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
                fake_image = self.generator(grayscale_image, training=True)
                fake_input = tf.concat([grayscale_image, fake_image], axis=-1)
                # Score the generated pair; scoring real_input here would leave
                # the generator with no gradient path.
                fake_logits = self.discriminator(fake_input, training=True)
                d_fake_loss = self.loss_D_fake(fake_logits)
                g_loss = self.loss_G(fake_logits)
            d_f_gradients = d_tape.gradient(d_fake_loss, self.discriminator.trainable_variables)
            g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables)
            self.d_optimizer.apply_gradients(zip(d_f_gradients, self.discriminator.trainable_variables))
            self.g_optimizer.apply_gradients(zip(g_gradients, self.generator.trainable_variables))
            return d_real_loss, d_fake_loss, g_loss

        @tf.function
        def distributed_train_step(grayscale_image, lab_image):
            per_replica_losses = strategy.experimental_run_v2(
                train_step, args=(grayscale_image, lab_image))
            # Each replica already divided by the global batch size, so SUM
            # across replicas gives the global mean of each loss.
            return tuple(strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
                         for loss in per_replica_losses)

        dist_dataset = strategy.experimental_distribute_dataset(self.dataset)
        writer = tf.summary.create_file_writer(self.log_dir) if self.log_dir else None
        loss_names = ('d_real_loss', 'd_fake_loss', 'g_loss')
        history = {loss_name: [] for loss_name in loss_names}
        step = 0

        for epoch in range(self.epochs):
            epoch_totals = [0.0] * len(loss_names)
            n_batches = 0
            for grayscale_image, lab_image in dist_dataset:
                losses = [float(loss) for loss in distributed_train_step(grayscale_image, lab_image)]
                epoch_totals = [total + loss for total, loss in zip(epoch_totals, losses)]
                n_batches += 1
                if writer is not None:
                    with writer.as_default():
                        for loss_name, value in zip(loss_names, losses):
                            tf.summary.scalar(loss_name, value, step=step)
                step += 1

            epoch_means = [total / max(n_batches, 1) for total in epoch_totals]
            for loss_name, value in zip(loss_names, epoch_means):
                history[loss_name].append(value)
            print('Epoch {}/{}: '.format(epoch + 1, self.epochs) +
                  ', '.join('{}={:.4f}'.format(n, v) for n, v in zip(loss_names, epoch_means)))

        if writer is not None:
            writer.flush()
        return history

In [0]:
# Use the first GPU when one is visible; otherwise fall back to the CPU so the
# notebook can still run end to end. Pinning '/gpu:0' on a CPU-only runtime is
# not guaranteed to work.
_visible_gpus = tf.config.experimental.list_physical_devices('GPU')
_train_device = '/gpu:0' if _visible_gpus else '/cpu:0'

config = {
    'distribute_strategy': tf.distribute.OneDeviceStrategy(device=_train_device),
    'epochs': 20,
    'batch_size': 16,
    'd_lr': 1e-4,
    'g_lr': 1e-4,
}

colorizer = Colorizer(config)
colorizer.train()