<a href="https://colab.research.google.com/github/stmeinert/Recolorization_IANN/blob/main/Iizuka_nb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Open questions that the paper leaves unanswered:


*   How should the Global Features Network transition from its Conv2D layers to its Dense layers?
*   Where is BatchNormalization applied?



In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm
import tqdm.notebook

In [11]:
class LowLevelFeatNet(tf.keras.layers.Layer):
    """Low-level feature network: six 3x3 convolutions with ReLU and BatchNormalization.

    Every convolution uses 'same' padding and is followed by a ReLU
    activation and a BatchNormalization layer. Three of the six
    convolutions use stride 2.
    """

    # (filters, stride) of each convolution, in execution order.
    # The fifth entry used to be 156 filters, which was a typo for 256.
    _CONV_SPECS = (
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
    )

    def __init__(self):
        super(LowLevelFeatNet, self).__init__()
        self.net_layers = []
        for filters, stride in self._CONV_SPECS:
            self.net_layers.append(tf.keras.layers.Conv2D(
                filters=filters, kernel_size=(3, 3), strides=(stride, stride), padding='same'))
            self.net_layers.append(tf.keras.layers.Activation(tf.nn.relu))
            self.net_layers.append(tf.keras.layers.BatchNormalization())

    @tf.function
    def call(self, x, training=False):
        """Run `x` through all layers in order.

        Args:
            x: Input tensor handed to the first Conv2D layer.
            training: Forwarded to every layer (affects BatchNormalization).

        Returns:
            The output of the last BatchNormalization layer.
        """
        for layer in self.net_layers:
            x = layer(x, training=training)
        return x


In [12]:
class MidLevelFeatNet(tf.keras.layers.Layer):
    """Mid-level feature network: two stride-1 3x3 convolutions (512, then 256 filters).

    Each convolution is followed by a ReLU activation and a
    BatchNormalization layer.
    """

    def __init__(self):
        super().__init__()
        stack = []
        for n_filters in (512, 256):
            stack.extend([
                tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same'),
                tf.keras.layers.Activation(tf.nn.relu),
                tf.keras.layers.BatchNormalization(),
            ])
        self.net_layers = stack

    @tf.function
    def call(self, x, training=False):
        """Apply the layer stack to `x`, forwarding `training` to every layer."""
        out = x
        for op in self.net_layers:
            out = op(out, training=training)
        return out

In [13]:
class GlobalFeatNet(tf.keras.layers.Layer):
    """Global feature network: four 512-filter convolutions, Flatten, then three Dense layers.

    The convolutions alternate between stride 2 and stride 1; the Dense
    layers have 1024, 512 and 256 units. Every Conv2D and Dense layer is
    followed by a ReLU activation and a BatchNormalization layer.
    """

    def __init__(self):
        super().__init__()
        stack = []

        # Convolutional part: 3x3 kernels, 512 filters, strides 2-1-2-1.
        for step in (2, 1, 2, 1):
            stack.append(tf.keras.layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(step, step), padding='same'))
            stack.append(tf.keras.layers.Activation(tf.nn.relu))
            stack.append(tf.keras.layers.BatchNormalization())

        # NOTE: Paper does not specify how to transition from Conv2D- to Dense-Layer (chose Flatten for simplicity)
        stack.append(tf.keras.layers.Flatten())

        # Fully connected part.
        for n_units in (1024, 512, 256):
            stack.append(tf.keras.layers.Dense(units=n_units))
            stack.append(tf.keras.layers.Activation(tf.nn.relu))
            stack.append(tf.keras.layers.BatchNormalization())

        self.net_layers = stack

    @tf.function
    def call(self, x, training=False):
        """Apply the layer stack to `x`, forwarding `training` to every layer."""
        out = x
        for op in self.net_layers:
            out = op(out, training=training)
        return out

In [15]:
class FusionLayer(tf.keras.layers.Layer):
    """Concatenate a per-image embedding vector onto every spatial position of a feature map."""

    def __init__(self):
        super(FusionLayer, self).__init__()

    @tf.function
    def call(self, x, training=False):
        """Fuse a feature map with an embedding by channel-wise concatenation.

        A similar approach can be found in
        https://github.com/baldassarreFe/deep-koalarization/blob/master/src/koalarization/fusion_layer.py
        That implementation relies on the Keras backend `repeat`, which adds
        a new axis. `tf.repeat` without `axis` instead flattens its input, so
        a direct port would scramble the embedding across pixels. Here the
        embedding is tiled explicitly.

        Args:
            x: Pair `(imgs, embs)`. `imgs` is indexed as a 4-D tensor whose
                last axis is concatenated on; `embs` is indexed as a 2-D
                tensor (one vector per batch entry).
            training: Unused; accepted for a uniform layer interface.

        Returns:
            `imgs` with `embs` appended along axis 3 at every position of
            axes 1 and 2.
        """
        imgs, embs = x
        # Use the dynamic shape so an unknown batch size does not break the op.
        img_shape = tf.shape(imgs)
        # (batch, channels) -> (batch, 1, 1, channels)
        embs = embs[:, tf.newaxis, tf.newaxis, :]
        # Copy the vector to every position along axes 1 and 2.
        embs = tf.tile(embs, [1, img_shape[1], img_shape[2], 1])
        return tf.concat([imgs, embs], axis=3)

In [14]:
class ColorizationNet(tf.keras.layers.Layer):
    """Colorization network placeholder.

    No layers are registered yet, so `call` currently returns its input
    unchanged.
    """

    def __init__(self):
        super().__init__()
        # Layers are applied in list order by `call`; the list is still empty.
        self.net_layers = list()

    @tf.function
    def call(self, x, training=False):
        """Apply every registered layer to `x`, forwarding `training` to each."""
        out = x
        for op in self.net_layers:
            out = op(out, training=training)
        return out

In [19]:
class IizukaRecolorizationModel(tf.keras.Model):
    """Model combining low-, mid- and global-level feature networks with a fusion and colorization stage.

    Provides `train_step` / `test_step` for use in a custom training loop.
    """

    def __init__(self):
        super(IizukaRecolorizationModel, self).__init__()

        self.low_1 = LowLevelFeatNet()
        self.low_2 = LowLevelFeatNet()
        self.mid = MidLevelFeatNet()
        self.glob = GlobalFeatNet()
        self.colorize = ColorizationNet()
        self.fusion = FusionLayer()

        # Both must be instances: `apply_gradients` and the loss call below
        # are instance methods and fail on the bare classes.
        self.optimizer = tf.keras.optimizers.Adadelta()
        self.loss_function = tf.keras.losses.MeanSquaredError()
        # NOTE(review): the accuracy metrics look like leftovers from a
        # classification template - confirm they are meaningful for this model.
        self.metrics_list = [
            tf.keras.metrics.Mean(name="loss"),
            tf.keras.metrics.CategoricalAccuracy(name="acc"),
            tf.keras.metrics.TopKCategoricalAccuracy(3, name="top-3-acc"),
        ]

    @tf.function
    def call(self, x, training=False):
        """Forward pass through the sub-networks.

        NOTE(review): the wiring below is inferred from the attribute names
        (low_1 -> mid, low_2 -> glob, fusion, colorize). Both branches
        currently receive the same input `x`; confirm whether the global
        branch should see a separately rescaled image.

        Args:
            x: Input batch handed to both low-level networks.
            training: Forwarded to every sub-network.

        Returns:
            Output of the colorization network applied to the fused features.
        """
        mid_features = self.mid(self.low_1(x, training=training), training=training)
        global_features = self.glob(self.low_2(x, training=training), training=training)
        fused = self.fusion((mid_features, global_features), training=training)
        return self.colorize(fused, training=training)

    def reset_metrics(self):
        """Reset the accumulated state of every tracked metric."""
        for metric in self.metrics:
            metric.reset_state()

    @tf.function
    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: Pair `(x, targets)`.

        Returns:
            Dict mapping each metric name to its current result.
        """
        x, targets = data

        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            # Add any layer-level (e.g. regularization) losses to the main loss.
            loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # update loss metric
        self.metrics[0].update_state(loss)

        # for all metrics except loss, update states (accuracy etc.)
        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        # Return a dictionary mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def test_step(self, data):
        """Evaluate a batch without updating weights.

        Args:
            data: Pair `(x, targets)`.

        Returns:
            Dict mapping each metric name to its current result.
        """
        x, targets = data

        predictions = self(x, training=False)

        loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)

        self.metrics[0].update_state(loss)

        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics}

In [17]:

# Dummy datasets
ds = tfds.load("fashion_mnist", as_supervised=True)

train_ds = ds["train"]
val_ds = ds["test"]

# Random horizontal flip, upscale to 32x32, then random crop back to 28x28.
data_augmentation_model = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.Resizing(32, 32),
    tf.keras.layers.RandomCrop(28, 28),
])


def augment(x):
    """Apply the augmentation pipeline to `x`.

    `training=True` is passed explicitly: when called inside `Dataset.map`
    (outside of `Model.fit`) the random layers may otherwise run in
    inference mode and skip the random flip / random crop.
    """
    return data_augmentation_model(x, training=True)


# Scale pixel values by 1/255 and one-hot encode the 10 class labels.
train_ds = (
    train_ds
    .map(lambda x, y: (augment(x) / 255, tf.one_hot(y, 10, dtype=tf.float32)),
         num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(5000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = (
    val_ds
    .map(lambda x, y: (x / 255, tf.one_hot(y, 10, dtype=tf.float32)),
         num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(5000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)


[1mDownloading and preparing dataset fashion_mnist/3.0.1 (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /root/tensorflow_datasets/fashion_mnist/3.0.1...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]






0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/fashion_mnist/3.0.1.incomplete4J411F/fashion_mnist-train.tfrecord


  0%|          | 0/60000 [00:00<?, ? examples/s]

0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/fashion_mnist/3.0.1.incomplete4J411F/fashion_mnist-test.tfrecord


  0%|          | 0/10000 [00:00<?, ? examples/s]

[1mDataset fashion_mnist downloaded and prepared to /root/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [None]:
""" This is main """

# load tensorboard extension
%load_ext tensorboard

# to clear all logs use this line:
!rm -rf ./logs/

train_log_path = f"logs/train"
val_log_path = f"logs/val"
# log writer for training metrics
train_summary_writer = tf.summary.create_file_writer(train_log_path)
# log writer for validation metrics
val_summary_writer = tf.summary.create_file_writer(val_log_path)

# Init model and params
model = IizukaRecolorizationModel()
EPOCHS = 5


for epoch in range(EPOCHS):
    
    print(f"Epoch {epoch}:")
    
    # Training:
    
    for data in tqdm.notebook.tqdm(train_ds,position=0, leave=True):
        metrics = model.train_step(data)
    
    # print the metrics
    print([f"{key}: {value}" for (key, value) in zip(list(metrics.keys()), list(metrics.values()))])
    
    # logging the validation metrics to the log file which is used by tensorboard
    with train_summary_writer.as_default():
        for metric in model.metrics:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)
    
    # reset all metrics (requires a reset_metrics method in the model)
    model.reset_metrics()
    
    
    # Validation:
    
    for data in val_ds:
        metrics = model.test_step(data)
    
    print([f"val_{key}: {value}" for (key, value) in zip(list(metrics.keys()), list(metrics.values()))])
    
    # logging the validation metrics to the log file which is used by tensorboard
    with val_summary_writer.as_default():
        for metric in model.metrics:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)
    
    # reset all metrics
    model.reset_metrics()
    
    print("\n")