<a href="https://colab.research.google.com/github/stmeinert/Recolorization_IANN/blob/main/Iizuka_nb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Open questions that are not described in the paper:


*   How should the Global Features Network transition from its Conv2D layers to its Dense layers?
*   Where is BatchNormalization applied?
*   Which interpolation and cropping should the Resizing layer use?
*   Which learning rate should be used for Adadelta?


# Util:

In [None]:
# Reset helpers (IPython shell escapes): delete the directory below and all TensorBoard logs.
# NOTE(review): in Colab, Drive is normally mounted under /content/drive, so this
# path may not point at anything — confirm it matches where models are actually saved.
!rm -rf /drive/MyDrive/saved_model/model
!rm -rf ./logs/

# Imports

In [None]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import tqdm
!pip install tensorflow-io
import tensorflow_io as tfio
import time
import os 
import pickle
import keras.backend as K
import numpy as np
import zipfile

# Reset Keras' global state (e.g. auto-generated layer-name counters) for this session.
tf.keras.backend.clear_session()
# Batch size used by the data pipeline and passed to the model (and its fusion layer).
BATCH_SIZE = 32

# Model

In [None]:
class LowLevelFeatNet(tf.keras.layers.Layer):
    """Low-level feature network: six 3x3 conv blocks (Conv2D -> ReLU -> BatchNorm).

    Three of the convolutions use stride 2, so the spatial resolution shrinks
    by a factor of 8 while the channel count grows from 64 to 512.
    """

    # (filters, stride) of every conv block, in network order.
    _BLOCK_SPECS = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1))

    def __init__(self, **kwargs):
        super(LowLevelFeatNet, self).__init__(**kwargs)
        stack = []
        for filters, stride in self._BLOCK_SPECS:
            stack.append(tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3), strides=(stride, stride), padding='same'))
            stack.append(tf.keras.layers.Activation(tf.nn.relu))
            stack.append(tf.keras.layers.BatchNormalization())
        # Assigned once, in the same order as before, so tracked weights keep their positions.
        self.net_layers = stack

    @tf.function
    def call(self, x, training=False):
        out = x
        for sublayer in self.net_layers:
            out = sublayer(out, training=training)
        return out

    def get_config(self):
        # The layer stack is rebuilt by __init__, so only the base config is needed.
        return super(LowLevelFeatNet, self).get_config()

    @classmethod
    def from_config(cls, config):
        return cls(**config)

In [None]:
class MidLevelFeatNet(tf.keras.layers.Layer):
    """Mid-level feature network: two stride-1 3x3 conv blocks (Conv2D -> ReLU -> BatchNorm).

    Keeps the spatial resolution and reduces the channels from 512 to 256.
    """

    # (filters, stride) of every conv block, in network order.
    _BLOCK_SPECS = ((512, 1), (256, 1))

    def __init__(self, **kwargs):
        super(MidLevelFeatNet, self).__init__(**kwargs)
        stack = []
        for filters, stride in self._BLOCK_SPECS:
            stack.append(tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3), strides=(stride, stride), padding='same'))
            stack.append(tf.keras.layers.Activation(tf.nn.relu))
            stack.append(tf.keras.layers.BatchNormalization())
        # Assigned once, in the same order as before, so tracked weights keep their positions.
        self.net_layers = stack

    @tf.function
    def call(self, x, training=False):
        out = x
        for sublayer in self.net_layers:
            out = sublayer(out, training=training)
        return out

    def get_config(self):
        # The layer stack is rebuilt by __init__, so only the base config is needed.
        return super(MidLevelFeatNet, self).get_config()

    @classmethod
    def from_config(cls, config):
        return cls(**config)

In [None]:
class GlobalFeatNet(tf.keras.layers.Layer):
    """Global feature network: four conv blocks, global max pooling, then three dense blocks.

    Every Conv2D/Dense layer is followed by ReLU and BatchNormalization; the
    output is one 256-dimensional vector per image.
    """

    # (filters, stride) of the 3x3 conv blocks, in network order.
    _CONV_SPECS = ((512, 2), (512, 1), (512, 2), (512, 1))
    # Units of the dense blocks after pooling, in network order.
    _DENSE_UNITS = (1024, 512, 256)

    def __init__(self, **kwargs):
        super(GlobalFeatNet, self).__init__(**kwargs)
        stack = []
        for filters, stride in self._CONV_SPECS:
            stack.append(tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3), strides=(stride, stride), padding='same'))
            stack.append(tf.keras.layers.Activation(tf.nn.relu))
            stack.append(tf.keras.layers.BatchNormalization())

        # NOTE: Paper does not specify how to transition from Conv2D- to Dense-Layer
        # (Flatten causes number of variables to explode), so global max pooling is used.
        stack.append(tf.keras.layers.GlobalMaxPooling2D())
        for units in self._DENSE_UNITS:
            stack.append(tf.keras.layers.Dense(units=units))
            stack.append(tf.keras.layers.Activation(tf.nn.relu))
            stack.append(tf.keras.layers.BatchNormalization())
        # Assigned once, in the same order as before, so tracked weights keep their positions.
        self.net_layers = stack

    @tf.function
    def call(self, x, training=False):
        out = x
        for sublayer in self.net_layers:
            out = sublayer(out, training=training)
        return out

    def get_config(self):
        # The layer stack is rebuilt by __init__, so only the base config is needed.
        return super(GlobalFeatNet, self).get_config()

    @classmethod
    def from_config(cls, config):
        return cls(**config)

In [None]:
class FusionLayer(tf.keras.layers.Layer):
    """Fuse one global feature vector per image into every position of a feature map.

    Called with a tuple ``(imgs, embs)``: ``imgs`` of shape
    ``(batch, height, width, channels)`` and ``embs`` of shape ``(batch, emb_dim)``.
    Each image's vector is copied to all ``height * width`` positions and
    concatenated on the channel axis, giving
    ``(batch, height, width, channels + emb_dim)``.

    Implementation of a similar approach can be found in
    https://github.com/baldassarreFe/deep-koalarization/blob/master/src/koalarization/fusion_layer.py

    Args:
        batch_size: kept for interface/config compatibility only; the batch size
            is read from the inputs, so a smaller final batch also works.
    """

    def __init__(self, batch_size, **kwargs):
        super(FusionLayer, self).__init__(**kwargs)
        # Plain Python value so it can be serialized by get_config().
        self.batch_size = batch_size

    @tf.function
    def call(self, x, training=False):
        imgs, embs = x
        # (batch, emb_dim) -> (batch, 1, 1, emb_dim)
        embs = embs[:, tf.newaxis, tf.newaxis, :]
        # Copy the vector to every spatial position. tf.repeat without an axis
        # would flatten the tensor first and scatter the vector's entries across
        # positions instead of broadcasting it, so tile per image instead.
        embs = tf.tile(embs, [1, tf.shape(imgs)[1], tf.shape(imgs)[2], 1])
        return tf.concat([imgs, embs], axis=3)

    def get_config(self):
        config = super(FusionLayer, self).get_config()
        # batch_size is a required constructor argument, so from_config needs it.
        config.update({"batch_size": self.batch_size})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

In [None]:
class ColorizationNet(tf.keras.layers.Layer):
    """Colorization decoder: fusion, conv blocks with two 2x upsamplings, 2-channel sigmoid output.

    Called with ``(mid_level_features, global_features)``, which are passed to
    the fusion layer. Output values are in [0, 1] (sigmoid).

    Args:
        batch_size: forwarded to FusionLayer and stored for get_config().
    """

    def __init__(self, batch_size, **kwargs):
        super(ColorizationNet, self).__init__(**kwargs)
        # Stored so get_config()/from_config() can rebuild the layer.
        self.batch_size = batch_size
        self.net_layers = []
        self.net_layers.append(FusionLayer(batch_size))
        self.net_layers.append(tf.keras.layers.Conv2D(128, kernel_size=(3,3), strides=(1,1), padding='same'))
        self.net_layers.append(tf.keras.layers.Activation(tf.nn.relu))
        self.net_layers.append(tf.keras.layers.BatchNormalization())

        self.net_layers.append(tf.keras.layers.UpSampling2D(size=(2,2), data_format='channels_last', interpolation='nearest'))
        self.net_layers.append(tf.keras.layers.Conv2D(64, kernel_size=(3,3), strides=(1,1), padding='same'))
        self.net_layers.append(tf.keras.layers.Activation(tf.nn.relu))
        self.net_layers.append(tf.keras.layers.BatchNormalization())
        self.net_layers.append(tf.keras.layers.Conv2D(64, kernel_size=(3,3), strides=(1,1), padding='same'))
        self.net_layers.append(tf.keras.layers.Activation(tf.nn.relu))
        self.net_layers.append(tf.keras.layers.BatchNormalization())

        self.net_layers.append(tf.keras.layers.UpSampling2D(size=(2,2), data_format='channels_last', interpolation='nearest'))
        self.net_layers.append(tf.keras.layers.Conv2D(32, kernel_size=(3,3), strides=(1,1), padding='same'))
        self.net_layers.append(tf.keras.layers.Activation(tf.nn.relu))
        self.net_layers.append(tf.keras.layers.BatchNormalization())
        # Two output channels (a and b), squashed to [0, 1].
        self.net_layers.append(tf.keras.layers.Conv2D(2, kernel_size=(3,3), strides=(1,1), padding='same'))
        self.net_layers.append(tf.keras.layers.Activation(tf.nn.sigmoid))

    @tf.function
    def call(self, x, training=False):
        for layer in self.net_layers:
            x = layer(x, training=training)
        return x

    def get_config(self):
        config = super(ColorizationNet, self).get_config()
        # batch_size is a required constructor argument, so from_config needs it.
        config.update({"batch_size": self.batch_size})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

In [None]:
class IizukaRecolorizationModel(tf.keras.Model):
    """Colorization model after Iizuka et al.: predicts the a/b channels from the L channel.

    The grayscale input goes through the shared low-level network twice: once
    resized to 224x224 to feed the global feature network, and once at its
    original size to feed the mid-level network. Both results are fused and
    decoded by ColorizationNet, upsampled by 2, and mapped from [0, 1] to the
    a/b range [-128, 127].

    Args:
        batch_size: forwarded to ColorizationNet and stored for get_config().
    """

    def __init__(self, batch_size, **kwargs):
        super(IizukaRecolorizationModel, self).__init__(**kwargs)
        # Stored so get_config()/from_config() can rebuild the model.
        self.batch_size = batch_size

        # NOTE: the paper does not specify interpolation/cropping for this resize.
        self.rescale = tf.keras.layers.Resizing(224, 224, interpolation='nearest', crop_to_aspect_ratio=True)
        self.low = LowLevelFeatNet()
        self.mid = MidLevelFeatNet()
        self.glob = GlobalFeatNet()
        self.colorize = ColorizationNet(batch_size)
        self.upS = tf.keras.layers.UpSampling2D(size=(2,2), data_format='channels_last', interpolation='nearest')

        # NOTE: the paper does not specify the Adadelta learning rate.
        self.optimizer = tf.keras.optimizers.Adadelta(learning_rate=1.0)
        self.loss_function = tf.keras.losses.MeanSquaredError()
        # train_step/test_step feed the loss to self.metrics[0]; presumably this Mean
        # is that first metric (Keras collects metrics assigned as attributes) — verify
        # if more metrics are added.
        self.metrics_list = [
                        tf.keras.metrics.Mean(name="loss"),
                        ]

    @tf.function
    def call(self, x, training=False):
        # global path: fixed-size copy of the input -> low-level -> global features
        re = self.rescale(x, training=training)
        l1 = self.low(re, training=training)
        g = self.glob(l1, training=training)

        # local path: full-resolution input through the same (shared) low-level net
        l2 = self.low(x, training=training)
        m = self.mid(l2, training=training)

        c = self.colorize((m,g), training=training)
        out = self.upS(c, training=training)

        # bring the a-b-values from range [0,1] to [-128, 127]
        out = out * 255.0
        out = out - 128.0
        return out

    @tf.function
    def reset_metrics(self):
        """Reset the state of every metric in ``self.metrics``."""
        for metric in self.metrics:
            metric.reset_states()

    @tf.function
    def train_step(self, data):
        """Run one optimization step on ``(x, targets)`` and return the current metric values."""
        x, targets = data

        # throw away L-dimension in target: the model predicts only the a/b channels
        targets = targets[:,:,:,-2:]

        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            # NOTE: regularization losses (self.losses) are not added to the loss.
            loss = self.loss_function(targets, predictions)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # update loss metric
        self.metrics[0].update_state(loss)

        # for all metrics except loss, update states (accuracy etc.)
        for metric in self.metrics[1:]:
            metric.update_state(targets,predictions)

        # Return a dictionary mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def test_step(self, data):
        """Evaluate ``(x, targets)`` without updating weights and return the current metric values."""
        x, targets = data

        # throw away L-dimension in target: the model predicts only the a/b channels
        targets = targets[:,:,:,-2:]

        predictions = self(x, training=False)

        # NOTE: regularization losses (self.losses) are not added to the loss.
        loss = self.loss_function(targets, predictions)

        self.metrics[0].update_state(loss)

        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics}

    def get_config(self):
        config = super(IizukaRecolorizationModel, self).get_config()
        # batch_size is a required constructor argument, so from_config needs it.
        config.update({"batch_size": self.batch_size})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

# Model Summary

In [None]:
# Smoke test: build the model on a random batch of two 128x128 single-channel
# images (uniform in [0, 1)) and show the layer summary.
myinput = tf.random.uniform(shape=(2, 128, 128, 1), dtype=tf.dtypes.float32)
mymodel = IizukaRecolorizationModel(2)
mymodel(myinput)

# summary() prints the table itself and returns None, so it is not wrapped in print().
mymodel.summary()

# Preprocessing

In [None]:

# Name of the preprocessed dataset archive on Google Drive (without ".zip").
DS_NAME = "celeb_data_set_preprocessed_part_0_3"
# DS_NAME = "celeb_data_set_preprocessed_part_0_15_tiny_64_30000"

# Archive location on the mounted Drive, and the local folder it is extracted into.
ZIP_DS_PATH = f'/content/drive/MyDrive/{DS_NAME}.zip'
EXTRACT_DS_PATH = '/content/current/Dataset'

# (height, width) that images are resized/padded to.
SIZE = (128, 128)

#################################################
# Prepare data
#################################################

@tf.function
def resize(image):
    """Resize ``image`` to SIZE with bilinear interpolation, padding to preserve the aspect ratio."""
    height, width = SIZE
    return tf.image.resize_with_pad(
        image,
        target_height=height,
        target_width=width,
        method=tf.image.ResizeMethod.BILINEAR,
    )


@tf.function
def to_lab(image):
    """Convert an RGB image to the Lab color space using tensorflow-io.

    The input must already be normalized to [0, 1]; the output's channels are (L, a, b).
    """
    lab_image = tfio.experimental.color.rgb_to_lab(image)
    return lab_image


@tf.function
def to_grayscale(image):
    """Keep only the first (L) channel of a rank-3 Lab image, retaining the channel axis."""
    # start at the origin; take every row and column (-1) but only one channel
    return tf.slice(image, [0, 0, 0], [-1, -1, 1])

def _to_input_and_target(example):
    """Map one dataset element to an ``(input, target)`` pair in Lab space.

    Args:
        example: dict whose ``'image'`` entry holds an RGB image.

    Returns:
        input: the L channel, rescaled from [0, 100] to [-1, 1].
        target: the full Lab image (L in [0, 100], a/b in roughly [-128, 127]).
    """
    # Input and target come from the same image, so resize and convert only once.
    rgb = resize(example['image'])
    # rgb_to_lab expects values in [0, 1]; dividing by 255 (not 256) maps full
    # intensity to exactly 1.0. Assumes 8-bit pixel values (0-255) — TODO confirm.
    lab = to_lab(rgb / 255.0)
    # Only the L channel is fed to the network; map it from [0, 100] to [-1, 1].
    gray = to_grayscale(lab) / 50 - 1
    return gray, lab


def prepare_image_data(image_ds):
    """Build the shuffled, batched training pipeline of ``(L input, Lab target)`` pairs.

    Args:
        image_ds: tf.data.Dataset of dicts with an ``'image'`` entry.

    Returns:
        Dataset of batches of BATCH_SIZE elements, shuffled with a 1000-element buffer.
    """
    # No @tf.function here: this only assembles a tf.data pipeline, and
    # Dataset.map already traces the mapped function into a graph.
    image_ds = image_ds.map(_to_input_and_target)
    return image_ds.shuffle(1000).batch(BATCH_SIZE)


def prepare_validation_data(image_ds):
    """Build the validation pipeline: same preprocessing as training, but not shuffled.

    Keeping the order fixed lets you follow the progress on the same images in
    TensorBoard.
    """
    image_ds = image_ds.map(_to_input_and_target)
    return image_ds.batch(BATCH_SIZE).prefetch(20)

def unzip_and_load_ds():
    """Extract the dataset archive if needed and load it as a GZIP-compressed tf.data dataset.

    The dataset is looked up at ``EXTRACT_DS_PATH/content/DS_NAME``. If that path
    does not exist yet, ZIP_DS_PATH is first extracted into EXTRACT_DS_PATH
    (presumably the archive stores the data under ``content/DS_NAME``).
    """
    dataset_dir = os.path.join(os.getcwd(), EXTRACT_DS_PATH, 'content', DS_NAME)

    already_extracted = os.path.exists(dataset_dir)
    if not already_extracted:
        with zipfile.ZipFile(ZIP_DS_PATH, 'r') as archive:
            archive.extractall(EXTRACT_DS_PATH)

    return tf.data.experimental.load(dataset_dir, compression='GZIP')

# Tensorboard

In [None]:
# IPython magics (not plain Python): load the TensorBoard notebook extension
%load_ext tensorboard
# show TensorBoard inline, reading the event files written under logs/
%tensorboard --logdir logs/

# Main

In [None]:
# get Dataset in place

# Despite the names, these hold numbers of dataset *elements* (image counts
# divided by BATCH_SIZE, presumably because the saved dataset is already
# batched), with at least one element per split.
TRAIN_IMAGES = max(1, (500 * BATCH_SIZE) // BATCH_SIZE)
TEST_IMAGES = max(1, (1 * BATCH_SIZE) // BATCH_SIZE)
VAL_IMAGES = max(1, (50 * BATCH_SIZE) // BATCH_SIZE)
EPOCHS = 50

ds = unzip_and_load_ds()
# Consecutive, non-overlapping slices of the dataset: test | validation | train.
test_ds = ds.take(TEST_IMAGES)
val_ds = ds.skip(TEST_IMAGES).take(VAL_IMAGES)
train_ds = ds.skip(TEST_IMAGES + VAL_IMAGES).take(TRAIN_IMAGES)

In [None]:
def load_model(model):
    """Load the full SavedModel from ``model_save_loc`` together with the stored epoch number.

    Note: the ``model`` argument is not used; a freshly loaded model is returned.

    Returns:
        ``(loaded_model, epoch)``.
    """
    print("Load model...")

    loaded_model = tf.keras.models.load_model(
        model_save_loc,
        compile=False,
        custom_objects={"IizukaRecolorizationModel": IizukaRecolorizationModel},
    )

    # Epoch counter pickled by save_model. pickle must only read trusted local files.
    with open("saved_model/epoch.dump", "rb") as epoch_file:
        epoch = pickle.load(epoch_file)

    return loaded_model, epoch

def save_model(model, epoch):
    """Save ``model`` as a SavedModel at ``model_save_loc`` and pickle ``epoch``.

    The epoch number goes to ``saved_model/epoch.dump`` relative to the current
    working directory, which is where load_model reads it from.

    Args:
        model: the Keras model to save.
        epoch: picklable epoch counter.
    """
    # save the model for this epoch
    model.save(model_save_loc, save_format="tf", save_traces=False)

    # open(..., "wb") does not create missing parent directories, and model.save
    # writes to model_save_loc, not into saved_model/, so create it here.
    os.makedirs("saved_model", exist_ok=True)
    with open("saved_model/epoch.dump", "wb") as f:
        pickle.dump(epoch, f)


# List the GPU(s) visible to this session (IPython shell escape, not plain Python).
print("GPU in use:")
!nvidia-smi -L
print("#############################")


# get the model in place
model = IizukaRecolorizationModel(BATCH_SIZE)
# directory the CheckpointManager below writes checkpoints to
model_save_loc = "/content/drive/MyDrive/checkpoints"
# one checkpoint tracks the epoch counter (step), the optimizer state and the model
ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=model.optimizer, net=model)
# keep only the three most recent checkpoints
manager = tf.train.CheckpointManager(ckpt, model_save_loc, max_to_keep=3)
log_save_loc = "./logs"

# resume from the newest checkpoint; when none exists latest_checkpoint is None
# and the freshly initialized model is used
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")
    #  clear all logs if the model is created newly and not loaded
    !rm -rf ./logs/


# TensorBoard log directories: training metrics, validation metrics, test images.
train_log_path = f"{log_save_loc}/train"
val_log_path = f"{log_save_loc}/val"
img_test_log_path = f"{log_save_loc}/img_test"

# One file writer per stream, created in the same order as the paths above.
train_summary_writer, val_summary_writer, test_summary_writer = (
    tf.summary.create_file_writer(log_dir)
    for log_dir in (train_log_path, val_log_path, img_test_log_path)
)


def _lab_prediction_to_rgb(lab_target, ab_prediction):
    """Combine the target's L channel with predicted a/b channels and convert to RGB.

    The network only predicts a/b, so the lightness is taken from ``lab_target``
    (a batch of Lab images whose last axis is (L, a, b)).
    """
    lightness = lab_target[..., :1]
    lab = tf.concat([lightness, ab_prediction], axis=-1)
    return tfio.experimental.color.lab_to_rgb(lab)


def _format_metrics(prefix, metric_values):
    """Render a metrics dict as ``["<prefix><name>: <value>", ...]`` with plain numbers."""
    return [f"{prefix}{key}: {float(value)}" for key, value in metric_values.items()]


def _log_metrics(writer, source_model, step):
    """Write the current result of every metric of ``source_model`` as a scalar summary."""
    with writer.as_default():
        for metric in source_model.metrics:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=step)


# save first version validation images before training starts
print("Getting first example images from untrained model")
for gray_batch, lab_batch in tqdm.notebook.tqdm(test_ds.take(1), position=0, leave=True):
    step = int(ckpt.step)
    prediction_rgb = _lab_prediction_to_rgb(lab_batch, model(gray_batch))
    target_rgb = tfio.experimental.color.lab_to_rgb(lab_batch)
    # the input L channel was scaled to [-1, 1]; map it back to [0, 1] for display
    input_display = (gray_batch + 1) / 2

    with test_summary_writer.as_default():
        tf.summary.image('Target', data=target_rgb, step=step, max_outputs=16)
        tf.summary.image(name="Prediction", data=prediction_rgb, step=step, max_outputs=16)
        tf.summary.image(name="Input", data=input_display, step=step, max_outputs=16)

while int(ckpt.step) < EPOCHS:
    ckpt.step.assign_add(1)
    epoch = int(ckpt.step)
    print(f"Epoch {epoch}:")
    start = time.time()

    ### Training:

    metrics = {}  # stays empty (instead of undefined) if train_ds yields nothing
    for gray_batch, lab_batch in tqdm.notebook.tqdm(train_ds, position=0, leave=True):
        metrics = model.train_step((gray_batch, lab_batch))

    end = time.time()

    # print the metrics
    print(f"Training took {end-start} seconds.")
    print(_format_metrics("", metrics))

    # log the training metrics for tensorboard, then reset them for validation
    _log_metrics(train_summary_writer, model, epoch)
    model.reset_metrics()

    ### Validation:

    metrics = {}  # stays empty (instead of stale training values) if val_ds yields nothing
    for gray_batch, lab_batch in tqdm.notebook.tqdm(val_ds, position=0, leave=True):
        metrics = model.test_step((gray_batch, lab_batch))

    print(_format_metrics("val_", metrics))

    # log the validation metrics for tensorboard, then reset them
    _log_metrics(val_summary_writer, model, epoch)
    model.reset_metrics()

    ### Test image:

    for gray_batch, lab_batch in tqdm.notebook.tqdm(test_ds.take(1), position=0, leave=True):
        prediction_rgb = _lab_prediction_to_rgb(lab_batch, model(gray_batch))
        with test_summary_writer.as_default():
            tf.summary.image(name="Prediction", data=prediction_rgb, step=epoch, max_outputs=16)

    print("\n")

    save_path = manager.save()
    print("Saved checkpoint for epoch {}: {}".format(epoch, save_path))