<a href="https://colab.research.google.com/github/macLeHoang/BTL-AI-AI-Colorization/blob/main/Ai4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# gpu_info = !nvidia-smi
# gpu_info = '\n'.join(gpu_info)
# if gpu_info.find('failed') >= 0:
#   print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
#   print('and then re-execute this cell.')
# else:
#   print(gpu_info)

In [None]:
from google.colab import drive

# Mount Google Drive at an absolute path: every checkpoint and log path used
# later in this notebook is rooted at /content/gdrive, so the mount point must
# not depend on the current working directory.
drive.mount('/content/gdrive')

In [None]:
!pip install git+https://github.com/qubvel/classification_models.git

In [None]:
!wget http://images.cocodataset.org/zips/test2017.zip
!unzip -qq -o test2017.zip
!rm test2017.zip

In [None]:
from skimage.color import rgb2lab, lab2rgb
from skimage import transform
import numpy as np
from PIL import Image
from tqdm import tqdm
import datetime
from matplotlib import pyplot as plt
import numpy as np
import os
import random

import tensorflow as tf
from classification_models.tfkeras import Classifiers

# DATA LOADER

In [None]:
# Delete every image that is not stored in RGB mode: the loader feeds images
# straight into rgb2lab, which needs three color channels.
# WARNING: this permanently removes files from `path`.
path = '/content/test2017'
c = 0 # number of images removed
for f in tqdm(os.listdir(path)):
  fPath = os.path.join(path, f)
  # Read the mode inside a context manager so the file handle is released
  # immediately instead of lingering until garbage collection.
  with Image.open(fPath) as img:
    isRgb = img.mode == 'RGB'
  if not isRgb:
    os.remove(fPath) # remove only once the handle is closed
    c += 1

print()
print(f'Remove {c} gray images')
print(f'Remain {len(os.listdir(path))} images')

In [None]:
# Size passed to PIL's `resize` in the data loader; every image is resized to it.
SIZE = (256, 256)

In [None]:
# Dataset of file-path strings, one per .jpg file in the image folder.
dataset = tf.data.Dataset.list_files('/content/test2017/*.jpg')

def process(path):
  """Load one image file and return its normalized Lab channels.

  Must be executed eagerly (it is wrapped in `tf.py_function` by the caller)
  because it calls `path.numpy()` and relies on PIL / scikit-image.

  Args:
    path: scalar string tensor holding the path of the image file.

  Returns:
    A tuple `(lChannel, abChannels)` of float32 tensors:
      * lChannel   - L channel scaled as L / 50 - 1, one channel.
      * abChannels - a and b channels scaled as ab / 110, two channels.
    Both keep the spatial size given by `SIZE`.
  """
  path_ = bytes.decode(path.numpy())

  # Open through a context manager so the file handle is closed as soon as the
  # pixels have been read. Converting to RGB guarantees rgb2lab always
  # receives three channels, whatever mode the file was stored in.
  with Image.open(path_) as src:
    img = src.convert('RGB').resize(SIZE, Image.BICUBIC)

  # Light augmentation: with probability 0.3 flip the image, picking a
  # left-right or a top-bottom flip with equal probability.
  # (Rotation and shear augmentations were tried before and are disabled.)
  randNumber = random.random() # random number in range [0, 1)
  if randNumber > 0.7:
    anotherRandNumber = random.random()
    if anotherRandNumber < 0.5:
      img = img.transpose(Image.FLIP_LEFT_RIGHT) # mirror left <-> right
    else:
      img = img.transpose(Image.FLIP_TOP_BOTTOM) # mirror top <-> bottom

  # Convert exactly once, after augmentation, so every branch yields an array.
  img = np.array(img)

  labImg = rgb2lab(img)
  lChannel = labImg[:, :, 0:1] / 50.0 - 1 # L in [0, 100] -> [-1, 1]
  abChannels = labImg[:, :, 1:] / 110.0 # ab scaled to roughly [-1, 1]

  return tf.convert_to_tensor(lChannel, dtype = tf.float32), \
         tf.convert_to_tensor(abChannels, dtype = tf.float32)

In [None]:
# Decode and augment each file through the Python loader, then batch.
# NOTE: the name `process` is rebound to a convolution helper in the GENERATOR
# section below, so this cell has to run before that definition (which is the
# case when the notebook is executed top to bottom).
dataset = dataset.map(lambda x: tf.py_function(process, [x], [tf.float32, tf.float32]))
dataset = dataset.batch(16)
# Prepare upcoming batches in the background while the current one is training.
dataset = dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
# l, ab = next(iter(dataset))

# plt.imshow(l[0, :, :, 0].numpy(),cmap = 'gray')
# np.max(ab[0, :, :, 0].numpy()), np.min(ab[0, :, :, 0].numpy()), np.max(ab[0, :, :, 1].numpy()), np.min(ab[0, :, :, 1].numpy())

# GENERATOR

In [None]:
def process(input_, nfilters_1 = 1024, nfilters_2 = 512, ksize = (3, 3), strides = 1, last_relu = True):
  """Apply two zero-pad + convolution stages to a feature map.

  Note: this definition reuses the name `process` of the data-loading function
  defined earlier in the notebook and replaces it from this cell onwards.

  Args:
    input_: feature map to transform.
    nfilters_1: number of filters of the first convolution.
    nfilters_2: number of filters of the second convolution.
    ksize: kernel size shared by both convolutions.
    strides: stride shared by both convolutions.
    last_relu: whether a ReLU follows the second convolution (the first
      convolution is always followed by one).

  Returns:
    The transformed feature map.
  """
  # (filters, apply ReLU afterwards) for each of the two stages, in order.
  stages = ((nfilters_1, True), (nfilters_2, last_relu))

  features = input_
  for nfilters, apply_relu in stages:
    features = tf.keras.layers.ZeroPadding2D()(features)
    features = tf.keras.layers.Conv2D(nfilters, ksize, strides)(features)
    if apply_relu:
      features = tf.keras.layers.ReLU()(features)

  return features

In [None]:
def decoder(input_, concat, nfilters = 1024, ksize = (1, 1), strides = 1):
  """Upsample a feature map by 2x and merge it with a skip connection.

  Args:
    input_: feature map coming from the previous decoder stage.
    concat: encoder feature map concatenated with the upsampled result.
    nfilters: filters of the convolution applied before upsampling;
      `depth_to_space` with block size 2 divides the channel count by 4.
    ksize: kernel size of that convolution.
    strides: stride of that convolution.

  Returns:
    ReLU-activated concatenation of `concat` and the normalized, upsampled map.
  """
  projected = tf.keras.layers.Conv2D(nfilters, ksize, strides)(input_)
  activated = tf.keras.layers.ReLU()(projected)

  # Trade channels for spatial resolution (doubles height and width).
  upsampled = tf.nn.depth_to_space(activated, 2)

  # NOTE(review): momentum=0.1 / epsilon=1e-5 look like PyTorch defaults; in
  # Keras `momentum` weights the OLD moving average, so 0.1 makes the running
  # statistics follow the latest batch closely - confirm this is intended.
  normalized = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5)(upsampled)

  merged = tf.keras.layers.Concatenate()([concat, normalized])
  return tf.keras.layers.ReLU()(merged)

In [None]:
class GENERATOR(tf.keras.models.Model):
  """Factory for the generator: a ResNet18 encoder with a skip-connected decoder.

  The instance itself is not used as a network. Calling it (with no arguments)
  builds and returns a functional `tf.keras` model that takes the ResNet18
  input (256x256x3) and outputs two tanh-activated channels.
  """

  def __init__(self):
    """Fetch the ResNet18 constructor from `classification_models`."""
    super().__init__()
    # Classifiers.get returns a pair; only the first item (used below as the
    # model constructor) is kept.
    self.ResNet18, _ = Classifiers.get('resnet18')
  
  def __call__(self):
    """Build and return the functional generator model."""
    # Backbone without the classification head, initialized from ImageNet.
    resnet18 = self.ResNet18(input_shape = (256, 256, 3), 
                             weights = 'imagenet',
                             include_top = False)
  
    # Intermediate backbone activations reused as skip connections.
    encoder_1 = resnet18.get_layer('bn0').output # encode 128 - shape = (None, 128, 128, 64)
    encoder_2 = resnet18.get_layer('stage2_unit1_bn1').output # encode 64 - shape = (None, 64, 64, 64)
    encoder_3 = resnet18.get_layer('stage3_unit1_bn1').output # encode 32 - shape = (None, 32, 32, 128)
    encoder_4 = resnet18.get_layer('stage4_unit1_bn1').output # encode 16 - shape = (None, 16, 16, 256)

    # Bottleneck between encoder and decoder (default 1024 -> 512 filters).
    last_layer = resnet18.layers[-1].output
    bridge = process(last_layer)

    # Each decoder call doubles the resolution and concatenates a skip tensor;
    # each following `process` call refines the merged features.
    x = decoder(bridge, encoder_4) # shape = (None, 16, 16, 512)
    x = process(x, 512, 512)
    x = decoder(x, encoder_3, 1024) # shape = (None, 32, 32, 384)
    x = process(x, 384, 384)
    x = decoder(x, encoder_2, 768) # shape = (None, 64, 64, 256)
    x = process(x, 256, 256)
    x = decoder(x, encoder_1, 512) # shape = (None, 128, 128, 192)
    x = process(x, 96, 96)
    # Final upsampling back to the input resolution (384 / 4 = 96 channels).
    x = tf.keras.layers.Conv2D(384, (1, 1), 1)(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.nn.depth_to_space(x, 2) # shape = (256, 256, 96)
    # Append the first input channel to the decoded features.
    x = tf.keras.layers.Concatenate()([x, resnet18.input[:, :, :, 0:1]]) # shape = (None, 256, 256, 97)
    # Residual block: the branch skips its last ReLU so the activation is
    # applied after the addition.
    res = process(x, 97, 97, last_relu = False)
    x = tf.keras.layers.Add()([x, res])
    x = tf.keras.layers.ReLU()(x)
    # Two output channels squashed to [-1, 1] by tanh.
    x = tf.keras.layers.Conv2D(2, (1, 1), 1)(x)
    x = tf.keras.layers.Activation('tanh')(x)

    return tf.keras.models.Model(inputs = resnet18.input, outputs = x)

In [None]:
# GENERATOR is a factory: calling the instance builds the functional Keras
# model that the training code uses as `generator`.
gModel = GENERATOR()
generator = gModel()

In [None]:
# generator.summary()

In [None]:
# tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
def generative_loss(target, predict, discriminator_output_of_predict, LAMBDA = 100.0):
  """Generator objective: adversarial term plus a weighted L1 term.

  Args:
    target: ground-truth tensor the generator output is compared with.
    predict: generator output.
    discriminator_output_of_predict: discriminator logits for `predict`.
    LAMBDA: weight of the L1 term.

  Returns:
    `(total_loss, adversarial_loss)` where
    `total_loss = adversarial_loss + LAMBDA * mean(|predict - target|)`.
  """
  bce = tf.keras.losses.BinaryCrossentropy(from_logits = True)

  # The generator is rewarded when the discriminator labels its output "real".
  real_labels = tf.ones_like(discriminator_output_of_predict)
  adversarial_term = bce(real_labels, discriminator_output_of_predict)

  # Mean absolute error between generated and ground-truth values.
  reconstruction_term = tf.reduce_mean(tf.abs(predict - target))

  return adversarial_term + LAMBDA*reconstruction_term, adversarial_term

# DISCRIMINATOR

In [None]:
# Normal(mean = 0.0, stddev = 0.02) kernel initializer shared by the
# discriminator convolutions.
initializer = tf.random_normal_initializer(0.0, 0.02)

In [None]:
def down_sample(input, nums_filters, kernel_size = (4, 4), strides = 2, use_batchnorm = True, **kwags):
  """Convolution -> optional batch norm -> LeakyReLU block.

  Args:
    input: feature map to process.
    nums_filters: number of convolution filters.
    kernel_size: convolution kernel size.
    strides: convolution stride (the default of 2 halves the resolution).
    use_batchnorm: whether to insert batch normalization after the convolution.
    **kwags: extra keyword arguments forwarded to `tf.keras.layers.Conv2D`.

  Returns:
    The activated feature map.
  """
  # Bias-free 'same'-padded convolution using the shared module-level initializer.
  conv = tf.keras.layers.Conv2D(nums_filters,
                                kernel_size = kernel_size,
                                strides = strides,
                                kernel_initializer = initializer,
                                use_bias = False,
                                padding = 'same',
                                **kwags)
  features = conv(input)

  if use_batchnorm:
    norm = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5)
    features = norm(features)

  return tf.keras.layers.LeakyReLU()(features)

In [None]:
class DISCRIMINATOR(tf.keras.models.Model):
  """Factory for the conditional patch-based discriminator.

  The instance itself is not used as a network. Calling it (with no arguments)
  builds and returns a functional `tf.keras` model taking `[cpart, predict]`
  (1-channel condition, 2-channel candidate) and returning a one-channel map
  of logits, one per image patch.
  """

  def __init__(self):
    """Create the two symbolic inputs of the discriminator."""
    # Fix: actually invoke the base-class initializer. The original wrote
    # `super().__init__` without parentheses, which is a no-op attribute access.
    super().__init__()
    self.cpart = tf.keras.layers.Input(shape = (None, None, 1)) # use as condition input
    # NOTE: this attribute shadows `Model.predict` on the instance; the name is
    # kept so existing code reading `dModel.predict` keeps working.
    self.predict = tf.keras.layers.Input(shape = (None, None, 2))

  def __call__(self):
    """Build and return the functional discriminator model.

    Shapes in the comments below assume 256x256 inputs. Because `down_sample`
    always uses 'same' padding, the stride-1 block keeps its spatial size.
    """
    x = tf.keras.layers.Concatenate(axis = -1)([self.cpart, self.predict]) # shape = None, 256, 256, 3

    x = down_sample(x, 64, use_batchnorm = False) # shape = None, 128, 128, 64
    x = down_sample(x, 128) # shape = None, 64, 64, 128
    x = down_sample(x, 256) # shape = None, 32, 32, 256

    x = tf.keras.layers.ZeroPadding2D(((1, 1), (1, 1)))(x) # shape = None, 34, 34, 256
    x = down_sample(x, 512, strides = 1) # shape = None, 34, 34, 512 ('same' padding)

    x = tf.keras.layers.ZeroPadding2D(((1, 1), (1, 1)))(x) # shape = None, 36, 36, 512

    # Each pixel in the feature map looks up to 70*70 patch of the origin image
    x = tf.keras.layers.Conv2D(1, kernel_size = (4, 4),
                               strides = 1,
                               kernel_initializer = initializer)(x) # shape = None, 33, 33, 1
    return tf.keras.models.Model(inputs = [self.cpart, self.predict], outputs = x)

In [None]:
# DISCRIMINATOR is a factory: calling the instance builds the functional Keras
# model that the training code uses as `discriminator`.
dModel = DISCRIMINATOR()
discriminator = dModel()

In [None]:
def discriminative_loss(target, predict):
  """Discriminator objective: label `target` as real and `predict` as fake.

  Args:
    target: discriminator logits computed on the ground-truth pair.
    predict: discriminator logits computed on the generated pair.

  Returns:
    Sum of the two binary cross-entropy terms.
  """
  bce = tf.keras.losses.BinaryCrossentropy(from_logits = True)

  real_term = bce(tf.ones_like(target), target)     # real pairs -> label 1
  fake_term = bce(tf.zeros_like(predict), predict)  # generated pairs -> label 0

  return real_term + fake_term

# Generator Pre-training

In [None]:
# def pretrained_loss(target, predict):
#   l1 = tf.reduce_mean(tf.abs(target - predict))
#   return l1

In [None]:
# pretrain_opt = tf.keras.optimizers.Adam(1e-4)

In [None]:
# log_dir = '/content/gdrive/MyDrive/AI_color_weights/Logs/Pretrain_Logs/'
# summary_writer = tf.summary.create_file_writer(
#   log_dir + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
# gen_ckpt_list_dir = '/content/gdrive/MyDrive/AI_color_weights/Generator_v2'
# gen_latest_checkpoint = tf.train.latest_checkpoint(gen_ckpt_list_dir)
# print(gen_latest_checkpoint)
# generator.load_weights(gen_latest_checkpoint)

/content/gdrive/MyDrive/AI_color_weights/Generator_v2/pre_generator-20220614-153134


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1410061c50>

In [None]:
# @tf.function
# def pre_step(L_target, ab_target, step):
#   with tf.GradientTape() as preTape:
#     L = tf.repeat(L_target, repeats = 3, axis = 3)
#     ab_predict = generator(L, training = True)
#     l1 = pretrained_loss(ab_target, ab_predict)
  
#   grads = preTape.gradient(l1, generator.trainable_variables)
#   pretrain_opt.apply_gradients(zip(grads,
#                                    generator.trainable_variables))
  
#   with summary_writer.as_default():
#     tf.summary.scalar('L1_pretrained_Loss', l1, step = step//10)

In [None]:
# def pre_fit(dataset, epochs):
#   for epoch in range(epochs):
#     for idx, (L, ab) in tqdm(dataset.enumerate()):
#       pre_step(L, ab, idx)
    
#     gen_ckpt_dir = '/content/gdrive/MyDrive/AI_color_weights/Generator_v2'
#     gen_ckpt_name = 'pre_generator-' + datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
#     generator.save_weights(os.path.join(gen_ckpt_dir, gen_ckpt_name))

In [None]:
# pre_fit(dataset, 1)

100%|██████████| 2540/2540 [36:51<00:00,  1.15it/s]


# Train Steps

In [None]:
# TensorBoard writer for this session; each run logs into its own
# timestamp-named sub-folder of `log_dir`.
log_dir = '/content/gdrive/MyDrive/AI_color_weights/Logs/Train_Logs/'
summary_writer = tf.summary.create_file_writer(
  os.path.join(log_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))

In [None]:
# Separate Adam optimizers with identical settings: one for the generator,
# one for the discriminator.
gOpt = tf.keras.optimizers.Adam(learning_rate = 2e-4, beta_1 = 0.5)
dOpt = tf.keras.optimizers.Adam(learning_rate = 2e-4, beta_1 = 0.5)

In [None]:
# gen_ckpt_list_dir = '/content/gdrive/MyDrive/AI_color_weights/Generator_v2'
# gen_lastest_checkpoint = tf.train.latest_checkpoint(gen_ckpt_list_dir)
# print(gen_lastest_checkpoint)
# generator.load_weights(gen_lastest_checkpoint)

# dis_ckpt_list_dir = '/content/gdrive/MyDrive/AI_color_weights/Discriminator_v2'
# dis_lastest_checkpoint = tf.train.latest_checkpoint(dis_ckpt_list_dir)
# print(dis_lastest_checkpoint)
# discriminator.load_weights(dis_lastest_checkpoint)

In [None]:
@tf.function
def train_step(Limg, ab_target, step):
  """Run one joint generator/discriminator update on a single batch.

  Args:
    Limg: batch of L channels. It is repeated 3 times along axis 3 to feed the
      generator and passed unchanged to the discriminator as condition input.
    ab_target: batch of ground-truth ab channels.
    step: integer step index; summaries are written at `step//10`.

  Side effects:
    Applies one optimizer step to `generator` and to `discriminator`, and
    writes three loss scalars through `summary_writer`.
  """
  # One tape per network: each tape is consumed by exactly one gradient call.
  with tf.GradientTape() as gTape, tf.GradientTape() as dTape:
    # Tile the single L channel to the 3 channels the generator input expects.
    L = tf.repeat(Limg, 3, axis = 3)
    ab_predict = generator(L, training = True) # ab Generative image

    d_predict = discriminator([Limg, ab_predict], training = True) # Discriminator output of predict
    d_target = discriminator([Limg, ab_target], training = True) # Discriminator output of target

    # Discriminative Loss
    d_loss = discriminative_loss(d_target, d_predict)

    # Generative loss
    g_loss, g_gan_loss = generative_loss(ab_target, ab_predict, d_predict)

  # Both gradients come from the same forward pass, so the two networks are
  # updated simultaneously rather than alternately.
  gGradients = gTape.gradient(g_loss, generator.trainable_variables)
  dGradients = dTape.gradient(d_loss, discriminator.trainable_variables)

  gOpt.apply_gradients(zip(gGradients,
                           generator.trainable_variables))
  dOpt.apply_gradients(zip(dGradients,
                           discriminator.trainable_variables))
  # NOTE(review): summaries are emitted on every call, so ten consecutive
  # calls share one summary step (step//10) - confirm this is intended.
  with summary_writer.as_default():
    tf.summary.scalar('Total Gen loss', g_loss, step = step//10)
    tf.summary.scalar('Gan loss', g_gan_loss, step = step//10)
    tf.summary.scalar('Total Disc loss', d_loss, step = step//10)
    # tf.summary.scalar('Positive Disc loss', pos_d_loss, step = step//10)
    # tf.summary.scalar('Negative Disc loss', neg_d_loss, step = step//10)

In [None]:
def fit(dataset, epochs):
  """Train the GAN for `epochs` passes over `dataset`, checkpointing each epoch.

  Args:
    dataset: iterable `tf.data.Dataset` yielding `(L, ab)` batches.
    epochs: number of full passes over `dataset`.

  Side effects:
    Calls `train_step` on every batch and, after each epoch, saves the weights
    of `generator` and `discriminator` to their Google Drive folders.
  """
  gen_ckpt_dir = '/content/gdrive/MyDrive/AI_color_weights/Generator_v2'
  dis_ckpt_dir = '/content/gdrive/MyDrive/AI_color_weights/Discriminator_v2'

  # Running step offset so the logged step keeps increasing across epochs
  # instead of restarting at 0 (which would overlay the curves of each epoch).
  # Kept as an int64 tensor, like the indices produced by `enumerate`, so
  # `train_step` is not retraced.
  step_offset = tf.constant(0, dtype = tf.int64)

  for epoch in range(epochs):
    last_idx = None
    for idx, (L, ab) in tqdm(dataset.enumerate()):
      train_step(L, ab, step_offset + idx)
      last_idx = idx

    if last_idx is not None: # dataset was not empty
      step_offset = step_offset + last_idx + 1

    # save generator weights and discriminator weights after each epochs;
    # one shared timestamp keeps the two checkpoint names of an epoch matched.
    stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    generator.save_weights(os.path.join(gen_ckpt_dir, 'generator-' + stamp))
    discriminator.save_weights(os.path.join(dis_ckpt_dir, 'discriminator-' + stamp))

In [None]:
# Train for a single epoch; weights are checkpointed at the end of the epoch.
fit(dataset, 1)

In [None]:
%load_ext tensorboard
%tensorboard --logdir /content/gdrive/MyDrive/AI_color_weights/Logs/Train_Logs/20220627-014237