In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from tqdm import tqdm_notebook

print('TensorFlow', tf.__version__)
print('TensorFlow Datasets', tfds.__version__)

TensorFlow 2.0.0
TensorFlow Datasets 1.3.0


In [3]:
def downscale_conv2D(tensor, n_filters, kernel_size=4, strides=2, name=None, use_bn=True):
    _x = tf.keras.layers.Conv2D(filters=n_filters,
                                kernel_size=kernel_size,
                                strides=strides, 
                                padding='same',
                                use_bias=False,
                                name='downscale_block_' + name + '_conv2d', 
                                activation=None)(tensor)
    if use_bn:
        _x = tf.keras.layers.BatchNormalization(name='downscale_block_' + name + '_bn')(_x)
    _x = tf.keras.layers.LeakyReLU(alpha=0.2, name='downscale_block_' + name + '_lrelu')(_x)
    return _x

def upscale_deconv2d(tensor, n_filters, kernel_size=4, strides=2, name=None):
    _x = tf.keras.layers.Conv2DTranspose(filters=n_filters,
                                         kernel_size=kernel_size,
                                         strides=strides, 
                                         padding='same',
                                         use_bias=False,
                                         name='upscale_block_' + name + '_conv2d', 
                                         activation=None)(tensor)
    _x = tf.keras.layers.BatchNormalization(name='upscale_block_' + name + '_bn')(_x)
    _x = tf.keras.layers.ReLU(name='upscale_block_' + name + '_relu')(_x)
    return _x

def build_generator():
    _input = tf.keras.Input(shape=[256, 256, 1], name='image_input')
    x = downscale_conv2D(_input, 64, strides=1, name='0')
    features = [x]
    for i, n_filters in enumerate([64, 128, 256, 512, 512, 512, 512]):
        x = downscale_conv2D(x, n_filters, name=str(i+1))
        features.append(x)

    for i, n_filters in enumerate([512, 512, 512, 256, 128, 64, 64]):
        x = upscale_deconv2d(x, n_filters, name=str(i+1))
        x = tf.keras.layers.Concatenate()([features[-(i+2)], x])
    _output = tf.keras.layers.Conv2D(filters=3, 
                                     kernel_size=1, 
                                     strides=1, 
                                     padding='same',
                                     name='output_conv2d', 
                                     activation='tanh')(x)
    return tf.keras.Model(inputs=[_input], outputs=[_output], name='Generator')

def build_discriminator():
    _input = tf.keras.Input(shape=[256, 256,  4])
    x = downscale_conv2D(_input, 64, strides=2, name='0', use_bn=False)
    x = downscale_conv2D(x, 128, strides=2, name='1')
    x = downscale_conv2D(x, 256, strides=2, name='2')
    x = downscale_conv2D(x, 512, strides=1, name='3')
    _output = tf.keras.layers.Conv2D(filters=1,
                                     kernel_size=1, 
                                     strides=1, 
                                     padding='same', 
                                     name='output_conv2d', 
                                     activation=None)(x)
    return tf.keras.Model(inputs=[_input], outputs=[_output], name='Discriminator')

In [3]:
class Colorizer:
    def __init__(self, config):
        super(Colorizer, self).__init__()
        self.distribute_strategy = config['distribute_strategy']
#         self.build_models()
        self.epochs = config['epochs']
        self.batch_size = config['batch_size']
        self.build_dataset()
        self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=config['d_lr'])
        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=config['g_lr'])        
        
        
    def build_models(self):        
        with self.distribute_strategy.scope():
            self.generator = build_generator()
            self.discriminator = build_discriminator()
            
    @staticmethod
    def preprocess_input(image):
        def _preprocess_input(image):
            image_n = image.numpy()
            image_gray = rgb2gray(image_n)
            image_lab = rgb2lab(image_n)
            return image_gray, image_lab
        return tf.py_function(_preprocess_input, [image], [tf.float32, tf.float32])

    def build_dataset(self):
#         dataset = tfds.load(name='places365_small', as_supervised=False)
        images = tf.random.uniform(shape=[1000, 256, 256, 3], maxval=255)
        self.dataset = tf.data.Dataset.from_tensor_slices(images)
        self.dataset = self.dataset.map(Colorizer.preprocess_input,
                                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
        self.datset = self.dataset.batch(self.batch_size, drop_remainder=True)
        self.dataset = self.dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return self.dataset
    
    
    def train(self):
        def distributed_train_step(batch):
            

In [4]:
config = {
    'distribute_strategy': tf.distribute.OneDeviceStrategy(device='/cpu:0'),
    'epochs':20,
    'batch_size':16,
    'd_lr':1e-4,
    'g_lr':1e-4
}

colorizer = Colorizer(config)
colorizer.train()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [None]:
tf.distribute.OneDeviceStrategy.experimental_run_v2()