<a href="https://colab.research.google.com/github/stmeinert/Recolorization_IANN/blob/main/Iizuka_nb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Open questions that the paper does not answer:


*   How should the Global Features Network transition from Conv2D to Dense layers?
*   Where should BatchNormalization be applied?
*   Which interpolation and cropping should the Resizing layer use?



# Imports

In [143]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm
!pip install tensorflow-io
import tensorflow_io as tfio
import time



# Model

In [144]:
class LowLevelFeatNet(tf.keras.layers.Layer):
    """Low-level feature network (shared by the global and the mid-level path).

    Six 3x3 convolutions with 'same' padding, each followed by ReLU and
    BatchNormalization. Three of them use stride 2, so the spatial resolution
    is divided by 8 and the output has 512 channels.
    """

    # (filters, stride) of each Conv2D, in order. Iizuka et al. use
    # 64, 128, 128, 256, 256, 512 filters; the fifth layer was previously 156 (typo).
    _CONV_SPECS = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1))

    def __init__(self): 
        super(LowLevelFeatNet, self).__init__()
        self.net_layers = []
        for filters, stride in self._CONV_SPECS:
            self.net_layers.append(tf.keras.layers.Conv2D(filters=filters, kernel_size=(3,3), strides=(stride, stride), padding='same'))
            self.net_layers.append(tf.keras.layers.Activation(tf.nn.relu))
            self.net_layers.append(tf.keras.layers.BatchNormalization())


    @tf.function
    def call(self, x, training=False):
        """Apply all layers in order; `training` is forwarded to every layer (BatchNormalization uses it)."""
        for layer in self.net_layers:
            x = layer(x, training=training)
        return x


In [145]:
class MidLevelFeatNet(tf.keras.layers.Layer):
    """Mid-level feature network: two 3x3, stride-1 convolutions (512 then 256 filters).

    Each convolution uses 'same' padding and is followed by ReLU and
    BatchNormalization, so the spatial size of the input is preserved.
    """

    def __init__(self): 
        super(MidLevelFeatNet, self).__init__()
        self.net_layers = [
            layer
            for n_filters in (512, 256)
            for layer in (
                tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(3,3), strides=(1,1), padding='same'),
                tf.keras.layers.Activation(tf.nn.relu),
                tf.keras.layers.BatchNormalization(),
            )
        ]


    @tf.function
    def call(self, x, training=False):
        """Feed `x` through every layer in sequence, forwarding the `training` flag."""
        out = x
        for sublayer in self.net_layers:
            out = sublayer(out, training=training)
        return out

In [146]:
class GlobalFeatNet(tf.keras.layers.Layer):
    """Global feature network: maps a feature map to a 256-dimensional feature vector.

    Four 3x3 convolutions with 512 filters (strides 2, 1, 2, 1) are followed by
    global max pooling and three dense layers (1024, 512, 256 units). Every
    convolution and dense layer is followed by ReLU and BatchNormalization.
    """

    def __init__(self): 
        super(GlobalFeatNet, self).__init__()

        def relu_then_bn():
            return [tf.keras.layers.Activation(tf.nn.relu), tf.keras.layers.BatchNormalization()]

        stack = []
        for stride in (2, 1, 2, 1):
            stack.append(tf.keras.layers.Conv2D(filters=512, kernel_size=(3,3), strides=(stride, stride), padding='same'))
            stack.extend(relu_then_bn())

        # NOTE: Paper does not specify how to transition from Conv2D- to Dense-Layer (Flatten causes number of variables to explode)
        stack.append(tf.keras.layers.GlobalMaxPooling2D())

        for units in (1024, 512, 256):
            stack.append(tf.keras.layers.Dense(units=units))
            stack.extend(relu_then_bn())

        self.net_layers = stack

    @tf.function
    def call(self, x, training=False):
        """Feed `x` through every layer in sequence, forwarding the `training` flag."""
        features = x
        for sublayer in self.net_layers:
            features = sublayer(features, training=training)
        return features

In [147]:
class FusionLayer(tf.keras.layers.Layer):
    """Concatenate a per-image global feature vector onto every pixel of a feature map."""

    def __init__(self): 
        super(FusionLayer, self).__init__()

    @tf.function
    def call(self, x, training=False):
        """ Implementation of a similar approach can be found in https://github.com/baldassarreFe/deep-koalarization/blob/master/src/koalarization/fusion_layer.py

        Args:
            x: pair (imgs, embs). imgs is a 4-D feature map (batch, H, W, C_img);
               embs is a 2-D batch of vectors (batch, C_emb).
            training: unused; accepted for a uniform layer interface.

        Returns:
            Tensor (batch, H, W, C_img + C_emb) where every pixel carries its own
            image features followed by the full global embedding vector.
        """
        imgs, embs = x
        # Dynamic shape, so unknown batch/spatial dimensions at trace time are fine.
        img_shape = tf.shape(imgs)
        # (batch, C_emb) -> (batch, 1, 1, C_emb), then tile over height and width.
        # (tf.repeat without an axis flattens its input and would scatter the
        # embedding values across channels instead of copying the vector per pixel.)
        embs = embs[:, tf.newaxis, tf.newaxis, :]
        embs = tf.tile(embs, [1, img_shape[1], img_shape[2], 1])
        return tf.concat([imgs, embs], axis=3)

In [148]:
class ColorizationNet(tf.keras.layers.Layer):
    """Colorization decoder: fuses (mid-level, global) features and predicts two channels.

    Layer order: fusion -> conv 128 -> 2x upsample -> conv 64 -> conv 64 ->
    2x upsample -> conv 32 -> conv 2 + sigmoid. All convolutions are 3x3,
    stride 1, 'same' padding; all but the last are followed by ReLU and
    BatchNormalization. The output is 4x the fused input's spatial size, with
    values in (0, 1).
    """

    def __init__(self): 
        super(ColorizationNet, self).__init__()

        def conv_relu_bn(filters):
            return [
                tf.keras.layers.Conv2D(filters, kernel_size=(3,3), strides=(1,1), padding='same'),
                tf.keras.layers.Activation(tf.nn.relu),
                tf.keras.layers.BatchNormalization(),
            ]

        def nearest_upsample():
            return tf.keras.layers.UpSampling2D(size=(2,2), data_format='channels_last', interpolation='nearest')

        stack = [FusionLayer()]
        stack += conv_relu_bn(128)

        stack.append(nearest_upsample())
        stack += conv_relu_bn(64)
        stack += conv_relu_bn(64)

        stack.append(nearest_upsample())
        stack += conv_relu_bn(32)
        stack.append(tf.keras.layers.Conv2D(2, kernel_size=(3,3), strides=(1,1), padding='same'))
        stack.append(tf.keras.layers.Activation(tf.nn.sigmoid))

        self.net_layers = stack

    @tf.function
    def call(self, x, training=False):
        """Run the (mid, global) pair through fusion and the decoder stack, forwarding `training`."""
        out = x
        for sublayer in self.net_layers:
            out = sublayer(out, training=training)
        return out

In [149]:
class IizukaRecolorizationModel(tf.keras.Model):
    """Colorization model after Iizuka et al.: predicts two color channels from a grayscale input.

    Two paths share one low-level feature network (`self.low`):
      * global path: input resized/cropped to 224x224 -> low -> global features (vector)
      * local path:  input at original size -> low -> mid-level features
    Both are fused and decoded by `self.colorize`, then upsampled 2x.
    """

    def __init__(self): 
        super(IizukaRecolorizationModel, self).__init__()

        # Global path sees a fixed 224x224 view (nearest resize, cropped to keep aspect ratio).
        self.rescale = tf.keras.layers.Resizing(224, 224, interpolation='nearest', crop_to_aspect_ratio=True)
        self.low = LowLevelFeatNet()
        self.mid = MidLevelFeatNet()
        self.glob = GlobalFeatNet()
        self.colorize = ColorizationNet()
        # low downsamples by 8, colorize upsamples by 4; this final 2x restores the input size.
        self.upS = tf.keras.layers.UpSampling2D(size=(2,2), data_format='channels_last', interpolation='nearest')

        self.optimizer = tf.keras.optimizers.Adadelta()
        self.loss_function = tf.keras.losses.MeanSquaredError()
        # metrics_list[0] must be the loss tracker: train_step/test_step feed it the scalar loss.
        self.metrics_list = [
                        tf.keras.metrics.Mean(name="loss"),
                        ]

    @property
    def metrics(self):
        """The model's own metric objects, in a fixed order (loss tracker first)."""
        return self.metrics_list

    @tf.function
    def call(self, x, training=False):
        """Predict the two color channels for a batch of grayscale images.

        Args:
            x: grayscale batch — assumed (batch, H, W, 1) with H and W divisible
               by 8 so the output matches the input size (TODO confirm for other sizes).
            training: forwarded explicitly to every sub-network so BatchNormalization
               runs in the requested mode. Being a call argument, it is also part of
               the tf.function trace signature, so training and inference get separate graphs.

        Returns:
            Tensor of shape (batch, H, W, 2) with sigmoid outputs in (0, 1).
        """
        resized = self.rescale(x)
        global_feats = self.glob(self.low(resized, training=training), training=training)

        # Same self.low instance: the low-level weights are shared between both paths.
        mid_feats = self.mid(self.low(x, training=training), training=training)

        colors = self.colorize((mid_feats, global_feats), training=training)
        return self.upS(colors)

    
    def reset_metrics(self):
        """Reset the accumulated state of every metric in `self.metrics`."""
        for metric in self.metrics:
            metric.reset_states()
            
    @tf.function
    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: tuple (x, targets) of model input and target color channels.

        Returns:
            dict mapping each metric name to its accumulated value.
        """
        x, targets = data
        
        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            # self.losses collects layer-added losses (no regularizers are configured in this file).
            loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)
        
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # update loss metric
        self.metrics[0].update_state(loss)
        
        # for all metrics except loss, update states (accuracy etc.)
        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        # Return a dictionary mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def test_step(self, data):
        """Evaluate a batch without updating weights; same inputs/outputs as train_step."""
        x, targets = data
        
        predictions = self(x, training=False)
        
        loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)
        
        self.metrics[0].update_state(loss)
        
        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics}

# Model Summary

In [150]:
# Smoke test: run one random single-channel 128x128 image through a fresh model.
myinput = tf.random.uniform(shape=(1,128,128,1), minval=0, maxval=None, dtype=tf.dtypes.float32, seed=None, name=None)
print(myinput)
mymodel = IizukaRecolorizationModel()
print(mymodel(myinput))
# summary() prints the table itself and returns None, so it must not be wrapped in print().
mymodel.summary()

tf.Tensor(
[[[[0.0275408 ]
   [0.46375656]
   [0.1423949 ]
   ...
   [0.7775687 ]
   [0.6604465 ]
   [0.9971471 ]]

  [[0.86306024]
   [0.6645349 ]
   [0.99020994]
   ...
   [0.07342672]
   [0.5866457 ]
   [0.52783   ]]

  [[0.36038697]
   [0.5252266 ]
   [0.02579057]
   ...
   [0.64316726]
   [0.12055314]
   [0.8900324 ]]

  ...

  [[0.1611799 ]
   [0.7792616 ]
   [0.42649996]
   ...
   [0.5582876 ]
   [0.23931122]
   [0.02001822]]

  [[0.16111386]
   [0.95230985]
   [0.98579085]
   ...
   [0.42401528]
   [0.82885313]
   [0.7420373 ]]

  [[0.05723846]
   [0.3443383 ]
   [0.21857095]
   ...
   [0.08773398]
   [0.3194325 ]
   [0.667907  ]]]], shape=(1, 128, 128, 1), dtype=float32)
tf.Tensor(
[[[[0.4999587  0.50014246]
   [0.4999587  0.50014246]
   [0.49997485 0.5001073 ]
   ...
   [0.49987736 0.5000309 ]
   [0.49999803 0.49996227]
   [0.49999803 0.49996227]]

  [[0.4999587  0.50014246]
   [0.4999587  0.50014246]
   [0.49997485 0.5001073 ]
   ...
   [0.49987736 0.5000309 ]
   [0.49999803

# Preprocessing

In [151]:

# Target (height, width) every image is resized/padded to by `resize`.
SIZE = (128,128)
# Number of images per batch in the prepared datasets.
BATCH_SIZE = 128

#################################################
# Prepare data
#################################################

@tf.function
def resize(image):
    """Bilinearly resize `image` to SIZE, keeping its aspect ratio by padding."""
    height, width = SIZE
    return tf.image.resize_with_pad(
        image,
        target_height=height,
        target_width=width,
        method=tf.image.ResizeMethod.BILINEAR,
    )


@tf.function
def to_lab(image):
    """Convert an RGB image to the LAB color space.

    The input must already be scaled to [0, 1]; the output channels are ordered [L, a, b].
    """
    lab_image = tfio.experimental.color.rgb_to_lab(image)
    return lab_image


@tf.function
def to_grayscale(image):
    """Keep only the first channel (L of a LAB image) of a rank-3 (H, W, C) tensor."""
    # A size of -1 means "through the end of that axis"; the channel axis keeps exactly one entry.
    start = [0, 0, 0]
    extent = [-1, -1, 1]
    first_channel = tf.slice(image, begin=start, size=extent)
    return first_channel

def prepare_image_data(image_ds):
    """Turn a tfds image dataset into shuffled, batched (input, target) pairs.

    Args:
        image_ds: dataset of dicts with an 'image' entry (as from
            tfds.load(..., as_supervised=False)); pixel values are assumed to be
            RGB in [0, 255].

    Returns:
        Dataset of (input, target) batches:
          * input:  L channel only, scaled from [0, 100] to [-1, 1], shape (..., H, W, 1)
          * target: a/b channels, scaled from about [-128, 127] to [0, 1], shape (..., H, W, 2),
                    so the model's sigmoid output can match it.
    """
    # Resize once; the same image provides both the grayscale input and the color target.
    image_ds = image_ds.map(lambda x: resize(x['image']))

    # normalize to [0, 1] as expected by the lab encoder (255 maps to exactly 1.0)
    image_ds = image_ds.map(lambda image: image / 255.0)

    # convert to lab color space; channels are [L, a, b]
    image_ds = image_ds.map(lambda image: to_lab(image))

    # Input: L channel, [0;100] -> [-1;1].
    # Target: drop L and keep a/b (indices 1:), [-128;127] -> [0;1].
    # (The previous slice `[:, :, :-2]` kept L and dropped a/b.)
    image_ds = image_ds.map(
        lambda lab: ((to_grayscale(lab) / 50) - 1, (lab[:, :, 1:] + 128) / 255)
    )

    image_ds = image_ds.shuffle(1000).batch(BATCH_SIZE)#.prefetch(20)
    return image_ds



# TensorBoard Setup

In [None]:
# Set up TensorBoard: one summary writer per kind of log (training metrics,
# validation metrics, sample colorized images), all under ./logs/.
# load tensorboard extension
%load_ext tensorboard

# to clear all logs use this line:
# NOTE(review): this runs every time the cell executes, so logs of earlier runs are deleted.
!rm -rf ./logs/

# separate directories keep the three kinds of logs apart
train_log_path = f"logs/train"
val_log_path = f"logs/val"
img_test_log_path = f"logs/img_test"

# log writer for training metrics
train_summary_writer = tf.summary.create_file_writer(train_log_path)

# log writer for validation metrics
val_summary_writer = tf.summary.create_file_writer(val_log_path)

# log writer for the sample images produced at the end of each epoch
img_test_summary_writer = tf.summary.create_file_writer(img_test_log_path)

# show tensorboard
%tensorboard --logdir logs/

# Main

In [None]:

# `import tqdm` alone does not reliably load the `tqdm.notebook` submodule used below.
import tqdm.notebook

# Number of examples per split; '' means the whole split (e.g. 'train[:]').
TRAIN_IMAGES = ''
TEST_IMAGES = '10'
VAL_IMAGES = ''

train_ds, val_ds = tfds.load("imagenette", split=(f'train[:{TRAIN_IMAGES}]', f'validation[:{VAL_IMAGES}]'), as_supervised=False)

train_ds = train_ds.apply(prepare_image_data)
val_ds = val_ds.apply(prepare_image_data)


model = IizukaRecolorizationModel()
EPOCHS = 10

for epoch in range(EPOCHS):

    print(f"Epoch {epoch}:")
    start = time.time()

    # Training:
    for data in tqdm.notebook.tqdm(train_ds, position=0, leave=True):
        model.train_step(data)

    end = time.time()

    # print the metrics (read from the metric objects, so this also works for an empty dataset)
    print(f"Training took {end-start} seconds.")
    print([f"{metric.name}: {metric.result()}" for metric in model.metrics])

    # logging the training metrics to the log file which is used by tensorboard
    with train_summary_writer.as_default():
        for metric in model.metrics:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)

    # reset all metrics (requires a reset_metrics method in the model)
    model.reset_metrics()


    # Validation:
    for data in tqdm.notebook.tqdm(val_ds, position=0, leave=True):
        model.test_step(data)

    print([f"val_{metric.name}: {metric.result()}" for metric in model.metrics])

    # logging the validation metrics to the log file which is used by tensorboard
    with val_summary_writer.as_default():
        for metric in model.metrics:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)

    # reset all metrics
    model.reset_metrics()


    # Test image: colorize one validation batch and log it.
    for inp, _ in val_ds.take(1):
        pred_ab = model(inp, training=False)
        # Undo the input scaling ((L / 50) - 1) to recover L in [0, 100].
        l_channel = (inp + 1) * 50
        # The model outputs sigmoid values in [0, 1]; map them back to the a/b range [-128, 127].
        ab_channels = pred_ab * 255 - 128
        lab = tf.concat([l_channel, ab_channels], axis=-1)
        # tf.summary.image expects float images in [0, 1].
        prediction = tf.clip_by_value(tfio.experimental.color.lab_to_rgb(lab), 0.0, 1.0)

        with img_test_summary_writer.as_default():
            tf.summary.image(name="generated_images", step=epoch, data=prediction, max_outputs=5)


    print("\n")