<a href="https://colab.research.google.com/github/stmeinert/Recolorization_IANN/blob/main/Iizuka_nb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Open questions that are not described in the paper:


*   How should the Global Features Network transition from the Conv2D layers to the Dense layers?
*   Where should BatchNormalization be applied?
*   Which interpolation method and which cropping behaviour should the Resizing layer use?



In [30]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm

In [31]:
class LowLevelFeatNet(tf.keras.layers.Layer):
    """Low-level features network: six 3x3 Conv2D -> ReLU -> BatchNormalization stages.

    Three of the convolutions use stride 2 with 'same' padding, so the spatial
    height and width of the output are the input's divided by 8 (rounded up).
    The final stage outputs 512 channels.
    """

    # (filters, stride) for each convolution stage, in order.
    # The fifth stage previously had 156 filters, a typo that broke the
    # 64-128-128-256-256-512 progression of Iizuka et al.; it is 256 here.
    _CONV_SPECS = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1))

    def __init__(self): 
        super(LowLevelFeatNet, self).__init__()
        self.net_layers = []
        for filters, stride in self._CONV_SPECS:
            self.net_layers.append(tf.keras.layers.Conv2D(filters=filters, kernel_size=(3,3), strides=(stride, stride), padding='same'))
            self.net_layers.append(tf.keras.layers.Activation(tf.nn.relu))
            self.net_layers.append(tf.keras.layers.BatchNormalization())

    @tf.function
    def call(self, x, training=False):
        """Apply all stages in order.

        Args:
            x: image batch; presumably (batch, height, width, channels) since it
               feeds Conv2D layers -- TODO confirm the expected channel count.
            training: forwarded to every sub-layer (affects BatchNormalization).

        Returns:
            Feature map with 512 channels at 1/8 of the input resolution.
        """
        for layer in self.net_layers:
            x = layer(x, training=training)
        return x


In [32]:
class MidLevelFeatNet(tf.keras.layers.Layer):
    """Mid-level features network.

    Two stride-1 3x3 Conv2D -> ReLU -> BatchNormalization stages with 512 and
    then 256 filters; spatial resolution is preserved ('same' padding).
    """

    def __init__(self):
        super().__init__()
        stages = []
        for n_filters in (512, 256):
            stages += [
                tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same'),
                tf.keras.layers.Activation(tf.nn.relu),
                tf.keras.layers.BatchNormalization(),
            ]
        self.net_layers = stages

    @tf.function
    def call(self, x, training=False):
        """Feed `x` through every stage in order, forwarding the training flag."""
        out = x
        for stage in self.net_layers:
            out = stage(out, training=training)
        return out

In [33]:
class GlobalFeatNet(tf.keras.layers.Layer):
    """Global features network.

    Four 3x3 Conv2D -> ReLU -> BatchNormalization stages (512 filters each,
    strides 2, 1, 2, 1), a global max-pool, then three Dense -> ReLU ->
    BatchNormalization stages with 1024, 512 and 256 units. The result is one
    256-dimensional vector per batch element.
    """

    def __init__(self):
        super().__init__()
        stack = []

        # Convolutional part: (filters, stride) per stage.
        for n_filters, step in ((512, 2), (512, 1), (512, 2), (512, 1)):
            stack.append(tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(3, 3), strides=(step, step), padding='same'))
            stack.append(tf.keras.layers.Activation(tf.nn.relu))
            stack.append(tf.keras.layers.BatchNormalization())

        # NOTE: Paper does not specify how to transition from Conv2D- to Dense-Layer
        # (Flatten causes number of variables to explode), so a global max-pool is used.
        stack.append(tf.keras.layers.GlobalMaxPooling2D())

        # Fully connected part.
        for n_units in (1024, 512, 256):
            stack.append(tf.keras.layers.Dense(units=n_units))
            stack.append(tf.keras.layers.Activation(tf.nn.relu))
            stack.append(tf.keras.layers.BatchNormalization())

        self.net_layers = stack

    @tf.function
    def call(self, x, training=False):
        """Feed `x` through every stage in order, forwarding the training flag."""
        out = x
        for stage in self.net_layers:
            out = stage(out, training=training)
        return out

In [34]:
class FusionLayer(tf.keras.layers.Layer):
    """Attach a per-image global feature vector to every pixel of a feature map.

    Has no trainable weights; it only broadcasts and concatenates.
    """

    def __init__(self): 
        super(FusionLayer, self).__init__()

    @tf.function
    def call(self, x, training=False):
        """Concatenate the global embedding onto each spatial position.

        Implementation of a similar approach can be found in
        https://github.com/baldassarreFe/deep-koalarization/blob/master/src/koalarization/fusion_layer.py

        Args:
            x: tuple ``(imgs, embs)``; ``imgs`` is a rank-4 feature map
               ``(batch, height, width, channels)`` and ``embs`` a rank-2
               tensor ``(batch, features)`` with a matching batch size.
            training: unused; accepted so all layers share one call interface.

        Returns:
            Tensor ``(batch, height, width, channels + features)`` where every
            pixel of image ``b`` carries a copy of ``embs[b]``.
        """
        imgs, embs = x
        # `tf.repeat` without an axis flattens its input, so reshaping its result
        # to (batch, H, W, features) scattered the embedding values across pixels.
        # Tiling a (batch, 1, 1, features) view copies the whole vector per pixel.
        # Dynamic shapes keep this working when the batch size is unknown.
        img_shape = tf.shape(imgs)
        embs = embs[:, tf.newaxis, tf.newaxis, :]
        embs = tf.tile(embs, [1, img_shape[1], img_shape[2], 1])
        return tf.concat([imgs, embs], axis=3)

In [35]:
class ColorizationNet(tf.keras.layers.Layer):
    """Colorization decoder.

    Fuses the mid-level feature map with the global feature vector
    (FusionLayer), then applies 3x3 Conv2D -> ReLU -> BatchNormalization
    stages interleaved with two nearest-neighbour 2x upsamplings. It ends in a
    2-channel convolution with a sigmoid, so outputs lie in [0, 1]
    (presumably the two chrominance channels -- confirm against the data
    pipeline).
    """

    def __init__(self):
        super().__init__()

        def conv_relu_bn(n_filters):
            # One convolution stage, kept as three separate Keras layers.
            return [
                tf.keras.layers.Conv2D(n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same'),
                tf.keras.layers.Activation(tf.nn.relu),
                tf.keras.layers.BatchNormalization(),
            ]

        def upsample():
            return tf.keras.layers.UpSampling2D(size=(2, 2), data_format='channels_last', interpolation='nearest')

        # List elements are built left to right, so layer creation order is explicit.
        self.net_layers = [
            FusionLayer(),
            *conv_relu_bn(128),
            upsample(),
            *conv_relu_bn(64),
            *conv_relu_bn(64),
            upsample(),
            *conv_relu_bn(32),
            tf.keras.layers.Conv2D(2, kernel_size=(3, 3), strides=(1, 1), padding='same'),
            tf.keras.layers.Activation(tf.nn.sigmoid),
        ]

    @tf.function
    def call(self, x, training=False):
        """Run ``x = (feature_map, global_vector)`` through all stages in order."""
        out = x
        for stage in self.net_layers:
            out = stage(out, training=training)
        return out

In [36]:
class IizukaRecolorizationModel(tf.keras.Model):
    """Colorization model after Iizuka et al.

    A single LowLevelFeatNet (shared weights) is applied twice: once to a
    224x224 resized copy of the input, feeding the global-features path, and
    once to the full-size input, feeding the mid-level path. The two are fused
    and decoded by ColorizationNet, and the result is upsampled 2x.
    """

    def __init__(self): 
        super(IizukaRecolorizationModel, self).__init__()

        # The global path always sees a fixed-size 224x224 version of the input.
        self.rescale = tf.keras.layers.Resizing(224, 224, interpolation='nearest', crop_to_aspect_ratio=True)
        self.low = LowLevelFeatNet()
        self.mid = MidLevelFeatNet()
        self.glob = GlobalFeatNet()
        self.colorize = ColorizationNet()
        self.upS = tf.keras.layers.UpSampling2D(size=(2,2), data_format='channels_last', interpolation='nearest')

        # These must be instances. Assigning the classes themselves made
        # `apply_gradients` an unbound method, and turned
        # `loss_function(targets, predictions)` into a constructor call rather
        # than a loss computation.
        self.optimizer = tf.keras.optimizers.Adadelta()
        self.loss_function = tf.keras.losses.MeanSquaredError()
        self.metrics_list = [
            tf.keras.metrics.Mean(name="loss"),
        ]

    @property
    def metrics(self):
        """Metrics updated by train_step/test_step; index 0 is the loss mean.

        Overridden so that ``self.metrics[0]`` is always the loss tracker below.
        Keras' default property may list compiled metrics first (after
        ``compile()``), which would break the positional updates in
        train_step/test_step.
        """
        return self.metrics_list

    @tf.function
    def call(self, x, training=False):
        """Predict the 2-channel output for an input image batch.

        Args:
            x: image batch; the smoke test uses shape (1, 256, 256, 1) --
               TODO confirm the intended input format (e.g. luminance only).
            training: forwarded explicitly to every sub-network so
               BatchNormalization switches between batch and moving statistics.

        Returns:
            Tensor with 2 channels. Its height and width equal the input's when
            they are divisible by 8 (1/8 after the low-level net, x4 in the
            colorization net, x2 by the final upsampling).
        """
        re = self.rescale(x)
        l1 = self.low(re, training=training)
        g = self.glob(l1, training=training)

        l2 = self.low(x, training=training)
        m = self.mid(l2, training=training)

        c = self.colorize((m, g), training=training)
        return self.upS(c)

    def reset_metrics(self):
        """Reset the state of every tracked metric."""
        for metric in self.metrics:
            # `reset_state` is the current name; `reset_states` is a deprecated alias.
            metric.reset_state()

    @tf.function
    def train_step(self, data):
        """Run one optimization step on ``data = (inputs, targets)``.

        Returns:
            Dict mapping metric names to their current values.
        """
        x, targets = data

        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            # Add any regularization losses registered by the layers.
            loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Index 0 is the loss mean (see the `metrics` property).
        self.metrics[0].update_state(loss)

        # All remaining metrics compare targets against predictions.
        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def test_step(self, data):
        """Evaluate one batch ``data = (inputs, targets)`` without updating weights.

        Returns:
            Dict mapping metric names to their current values.
        """
        x, targets = data

        predictions = self(x, training=False)

        loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)

        self.metrics[0].update_state(loss)

        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics}

In [37]:
# Smoke test: push one random single-channel 256x256 image through the model.
# tf.random.uniform defaults to the range [0, 1) for floats.
myinput = tf.random.uniform(shape=(1, 256, 256, 1), dtype=tf.dtypes.float32)
print(myinput)
mymodel = IizukaRecolorizationModel()
print(mymodel(myinput))
# summary() prints the table itself and returns None; wrapping it in print()
# only added a stray "None" line.
mymodel.summary()

tf.Tensor(
[[[[0.9834081 ]
   [0.07496083]
   [0.24160588]
   ...
   [0.32004106]
   [0.30351353]
   [0.365389  ]]

  [[0.1136843 ]
   [0.86566937]
   [0.02430964]
   ...
   [0.76642203]
   [0.6408502 ]
   [0.7014291 ]]

  [[0.913852  ]
   [0.14544249]
   [0.12423933]
   ...
   [0.7304437 ]
   [0.9731121 ]
   [0.1085248 ]]

  ...

  [[0.6241597 ]
   [0.9112793 ]
   [0.71760976]
   ...
   [0.99834406]
   [0.1007278 ]
   [0.3495505 ]]

  [[0.27207565]
   [0.7413845 ]
   [0.6636585 ]
   ...
   [0.87155867]
   [0.05006909]
   [0.85351336]]

  [[0.33585382]
   [0.47191715]
   [0.70817494]
   ...
   [0.08633208]
   [0.1073488 ]
   [0.9850364 ]]]], shape=(1, 256, 256, 1), dtype=float32)
tf.Tensor(
[[[[0.49994278 0.50001585]
   [0.49994278 0.50001585]
   [0.49985743 0.5000104 ]
   ...
   [0.49987265 0.49994752]
   [0.49992558 0.49995065]
   [0.49992558 0.49995065]]

  [[0.49994278 0.50001585]
   [0.49994278 0.50001585]
   [0.49985743 0.5000104 ]
   ...
   [0.49987265 0.49994752]
   [0.49992558