<a href="https://colab.research.google.com/github/stmeinert/Recolorization_IANN/blob/main/train_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports:

In [None]:
!git clone https://github.com/stmeinert/Recolorization_IANN.git

import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm
!pip install tensorflow-io
import tensorflow_io as tfio
import time
import os 

import sys
if not "/content/Recolorization_IANN" in sys.path:
    sys.path.append("/content/Recolorization_IANN")
from src.iizuka.iizuka_recolorization_model import IizukaRecolorizationModel

from src.data_util.data_pipeline_util import unzip_and_load_ds


tf.keras.backend.clear_session()

# Parameters:

In [None]:
# Number of examples per batch; also passed to the model's constructor.
BATCH_SIZE = 32
model = IizukaRecolorizationModel(BATCH_SIZE)

# Dataset to load. Alternative: "celeb_data_set_preprocessed_part_0_3".
DS_NAME = "celeb_data_set_unbatch_30000"

# Zip archive of the dataset (under /content/drive/MyDrive, presumably a mounted
# Google Drive) and the local directory it is extracted into.
ZIP_DS_PATH = f"/content/drive/MyDrive/{DS_NAME}.zip"
EXTRACT_DS_PATH = "/content/current/Dataset"

# Split sizes, counted in individual examples (the splits are taken before batching).
TRAIN_IMAGES, TEST_IMAGES, VAL_IMAGES = 1000, 100, 50

# Training runs until the checkpointed epoch counter reaches this value.
EPOCHS = 50

# Where checkpoints and TensorBoard event files are written.
MODEL_SAVE_LOCATION = "/content/drive/MyDrive/checkpoints"
LOG_SAVE_LOCATION = "./logs"

# TensorBoard:

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard
# Show TensorBoard inline, reading event files from LOG_SAVE_LOCATION
# (IPython expands `$LOG_SAVE_LOCATION` to the Python variable's value).
%tensorboard --logdir $LOG_SAVE_LOCATION

# Preprocessing:

In [None]:
def prepare_train_dataset(image_ds, batch_size=None, shuffle_buffer_size=1000):
    """Shuffle, batch and prefetch a dataset for training.

    Args:
        image_ds: the ``tf.data.Dataset`` to prepare (the callers in this
            notebook pass unbatched splits of the loaded dataset).
        batch_size: examples per batch; ``None`` uses the global ``BATCH_SIZE``.
        shuffle_buffer_size: number of elements held in the shuffle buffer.

    Returns:
        The shuffled, batched, prefetching ``tf.data.Dataset``.
    """
    # Deliberately not a @tf.function: these calls only build a tf.data
    # pipeline. Tracing them adds overhead and would freeze the global
    # BATCH_SIZE at trace time.
    if batch_size is None:
        batch_size = BATCH_SIZE
    return (
        image_ds.shuffle(shuffle_buffer_size)
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

def prepare_test_dataset(image_ds, batch_size=None):
    """Batch and prefetch a dataset for evaluation, preserving element order.

    Args:
        image_ds: the ``tf.data.Dataset`` to prepare (the callers in this
            notebook pass unbatched splits of the loaded dataset).
        batch_size: examples per batch; ``None`` uses the global ``BATCH_SIZE``.

    Returns:
        The batched, prefetching ``tf.data.Dataset`` (no shuffling).
    """
    # Deliberately not a @tf.function: building a tf.data pipeline needs no tracing.
    if batch_size is None:
        batch_size = BATCH_SIZE
    return image_ds.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

### get Dataset in place

ds = unzip_and_load_ds(DS_NAME, EXTRACT_DS_PATH, ZIP_DS_PATH)

# Consecutive slices of `ds`: [TRAIN_IMAGES | TEST_IMAGES | VAL_IMAGES].
# NOTE(review): the slices are only disjoint if `ds` yields its elements in the
# same order on every iteration (e.g. it is not reshuffled per iteration);
# confirm in unzip_and_load_ds.
train_ds = prepare_train_dataset(ds.take(TRAIN_IMAGES))
test_ds = prepare_test_dataset(ds.skip(TRAIN_IMAGES).take(TEST_IMAGES))
# Validation only evaluates the model, so it needs no shuffling. A fixed order
# also keeps the batch composition identical from epoch to epoch.
val_ds = prepare_test_dataset(ds.skip(TRAIN_IMAGES + TEST_IMAGES).take(VAL_IMAGES))

# Main:

In [None]:
print("################ GPU in use: ################")
!nvidia-smi -L
print("#############################################")


ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=model.optimizer, net=model)
manager = tf.train.CheckpointManager(ckpt, MODEL_SAVE_LOCATION, max_to_keep=3)

ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")
    !mkdir $MODEL_SAVE_LOCATION
    #  clear all logs if the model is created newly and not loaded
    !rm -rf $LOG_SAVE_LOCATION


train_log_path = f"{LOG_SAVE_LOCATION}/train"
val_log_path = f"{LOG_SAVE_LOCATION}/val"
img_test_log_path = f"{LOG_SAVE_LOCATION}/img_test"
# log writer for training metrics
train_summary_writer = tf.summary.create_file_writer(train_log_path)
# log writer for validation metrics
val_summary_writer = tf.summary.create_file_writer(val_log_path)
# log writer for test images
test_summary_writer = tf.summary.create_file_writer(img_test_log_path)

# save first version validation images before training starts
print("Getting first example images from untrained model")
for input, target in tqdm.notebook.tqdm(test_ds.take(1),position=0, leave=True):
    prediction = model(input)
    # get l channel, target should be in shape (SIZE, SIZE, lab)
    l = tf.slice(target, begin=[0,0,0,0], size=[-1,-1,-1,1])
    prediction = tf.concat([l, prediction], axis=-1) # should be concatenating along last dimension
    prediction = tfio.experimental.color.lab_to_rgb(prediction)
    target = tfio.experimental.color.lab_to_rgb(target)
    input = (input+1)/2

    with test_summary_writer.as_default():
        tf.summary.image('Target', data=target, step=int(ckpt.step), max_outputs=16)
        tf.summary.image(name="Prediction", data=prediction, step=int(ckpt.step), max_outputs=16)
        tf.summary.image(name="Input", data=input, step=int(ckpt.step), max_outputs=16)

while int(ckpt.step) < EPOCHS:
    ckpt.step.assign_add(1)
    print(f"Epoch {int(ckpt.step)}:")
    start = time.time()

    ### Training:
    
    for input, target in tqdm.notebook.tqdm(train_ds, position=0, leave=True):
        metrics = model.train_step((input, target))

    end = time.time()
    
    # print the metrics
    print(f"Training took {end-start} seconds.")
    print([f"{key}: {value}" for (key, value) in zip(list(metrics.keys()), list(metrics.values()))])
    
    # logging the validation metrics to the log file which is used by tensorboard
    with train_summary_writer.as_default():
        for metric in model.metrics:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=int(ckpt.step))
    
    # reset all metrics (requires a reset_metrics method in the model)
    model.reset_metrics()
    
    
    ### Validation:
    
    for input, target in tqdm.notebook.tqdm(val_ds,position=0, leave=True):
        metrics = model.test_step((input, target))
    
    print([f"val_{key}: {value}" for (key, value) in zip(list(metrics.keys()), list(metrics.values()))])
    
    # logging the validation metrics to the log file which is used by tensorboard
    with val_summary_writer.as_default():
        for metric in model.metrics:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=int(ckpt.step))
    
    # reset all metrics
    model.reset_metrics()

    
    ### Test image:

    for input, target in tqdm.notebook.tqdm(test_ds.take(1),position=0, leave=True):
        prediction = model(input)
        
        # get l channel, target should be in shape (SIZE, SIZE, lab)
        l = tf.slice(target, begin=[0,0,0,0], size=[-1,-1,-1,1])
        prediction = tf.concat([l, prediction], axis=-1) # should be concatenating along last dimension
        prediction = tfio.experimental.color.lab_to_rgb(prediction)

        with test_summary_writer.as_default():
            tf.summary.image(name="Prediction", data=prediction, step=int(ckpt.step), max_outputs=16)

    print("\n")

    save_path = manager.save()
    print("Saved checkpoint for epoch {}: {}".format(int(ckpt.step), save_path))