<a href="https://colab.research.google.com/github/toshNaik/Image-Colorizer/blob/master/ImageColorizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import math
import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from skimage import color, io, transform
from tensorflow.keras.models import Model

In [None]:
# Mount Google Drive: the dataset archive and the checkpoint directory used
# later in this notebook both live under '/content/drive/My Drive/'.
import google.colab.drive as drive
drive.mount(mountpoint='/content/drive')

In [4]:
!tar -xf /content/drive/My\ Drive/datasets/color-net.tar

In [106]:
# Sorted so the file order is reproducible (glob order is filesystem-dependent).
image_path_list = sorted(glob.glob('Train/*'))
# Fail fast: from_tensor_slices([]) yields a float32 (not string) dataset, which
# would only surface later as an opaque dtype error inside preprocess().
assert image_path_list, "No training images found under 'Train/' - was the dataset archive extracted?"
data = tf.data.Dataset.from_tensor_slices(image_path_list)

In [None]:
# Number of training image paths found by the glob (shown as this cell's output).
len(image_path_list)

In [6]:
# Sentinel that lets tf.data tune parallelism / buffer sizes at runtime; passed
# to both map() and prefetch() when building train_ds below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

def tf_rgb2lab(image):
  """Convert an RGB image tensor to CIE-Lab with skimage, usable inside tf.data.

  skimage is not a TensorFlow op, so the conversion is run through
  `tf.py_function`. The input's static shape is re-attached to the result
  so downstream slicing/stacking sees a fully defined shape.

  Args:
    image: RGB image tensor; presumably float32 in [0, 1] with a known static
      shape (preprocess() resizes to 224x224 before calling) — TODO confirm
      for any other caller.

  Returns:
    A float32 tensor with the same static shape as `image`, holding the
    channels returned by `skimage.color.rgb2lab`.
  """
  static_shape = image.shape
  outputs = tf.py_function(func=color.rgb2lab, inp=[image], Tout=[tf.float32])
  lab_image = outputs[0]
  lab_image.set_shape(static_shape)
  return lab_image

def preprocess(path):
  """Load one training image and split it into network input and target.

  Args:
    path: scalar string tensor holding an image file path.

  Returns:
    (model_input, ab):
      model_input: float32 tensor of shape [224, 224, 3] — the Lab lightness
        channel divided by 100 and repeated three times to fit the
        3-channel VGG16 encoder.
      ab: float32 tensor of shape [224, 224, 2] — the Lab a/b channels
        divided by 128, the scale the tanh decoder output is trained against.
  """
  raw = tf.io.read_file(path)
  # decode_image (instead of decode_jpeg) also accepts PNG/BMP/GIF files, which
  # decode_jpeg rejects. expand_animations=False forces a single 3-D frame; its
  # static shape may still be partially unknown, so pin rank and channel count
  # for tf.image.resize.
  image = tf.io.decode_image(raw, channels=3, expand_animations=False)
  image.set_shape([None, None, 3])
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [224, 224])
  lab = tf_rgb2lab(image)
  L = lab[:, :, 0] / 100.
  ab = lab[:, :, 1:] / 128.
  # Named model_input rather than `input` to avoid shadowing the builtin.
  model_input = tf.stack([L, L, L], axis=2)
  return model_input, ab

# Input pipeline: reshuffle every epoch, decode/convert in parallel, loop
# forever (fit() bounds each epoch via steps_per_epoch), batch, and prefetch.
train_ds = (
    data
    .shuffle(10000, reshuffle_each_iteration=True)
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .repeat()
    .batch(32)
    .prefetch(AUTOTUNE)
)

In [10]:
# Encoder: the ImageNet-pretrained VGG16 convolutional base (no classifier
# head). Its layers are reused inside a Sequential model so decoder layers can
# be appended afterwards.
# NOTE(review): these ImageNet weights were trained on inputs prepared with
# tf.keras.applications.vgg16.preprocess_input, whereas preprocess() feeds
# L/100 in [0, 1]; confirm this mismatch is intended.
vggmodel = tf.keras.applications.VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)
model = tf.keras.Sequential()
for vgg_layer in vggmodel.layers:
  model.add(vgg_layer)
# Freeze every encoder layer; layers added to `model` later stay trainable.
for encoder_layer in model.layers:
  encoder_layer.trainable = False

In [11]:
# Decoder: four (3x3 conv + 2x upsample) stages with shrinking filter counts,
# growing the encoder's feature map (presumably 7x7 for a 224x224 VGG16 input)
# by 2**4.
for n_filters in (256, 128, 64, 16):
  model.add(tf.keras.layers.Conv2D(n_filters, (3, 3), padding='same', activation='relu'))
  model.add(tf.keras.layers.UpSampling2D((2, 2)))
# Head: predict the two ab channels in [-1, 1] via tanh, then one final 2x
# upsample up to the 224x224 target resolution.
model.add(tf.keras.layers.Conv2D(8, (3, 3), padding='same', activation='relu'))
model.add(tf.keras.layers.Conv2D(2, (3, 3), padding='same', activation='tanh'))
model.add(tf.keras.layers.UpSampling2D((2, 2)))

In [12]:
# Pixel-wise regression of the ab channels. Keras resolves 'accuracy' to a
# classification accuracy (binary/categorical, picked from loss and output
# shape), which is meaningless for continuous targets; track mean absolute
# error instead.
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# Checkpoint full models to Drive so progress survives Colab session resets.
# `monitor` / `mode` only take effect with save_best_only=True; without it a
# full model file is written after every one of up to 500 epochs.
model_checkpoint_callback = ModelCheckpoint('/content/drive/My Drive/models/colorizer/color-net_weights.{epoch:02d}.hdf5',
                                            monitor='loss',
                                            mode='min',
                                            save_best_only=True)
# Stop once training loss has not improved for 5 epochs. It must be listed in
# fit()'s callbacks below; constructing it alone has no effect.
early_stopping_callback = EarlyStopping(monitor='loss', mode='min', verbose=1, patience=5)
# 32 must match the .batch(32) used to build train_ds. max(1, ...) keeps
# steps_per_epoch positive when there are fewer than 32 training images.
steps = max(1, len(image_path_list) // 32)
history = model.fit(train_ds,
                    epochs=500,
                    steps_per_epoch=steps,
                    callbacks=[model_checkpoint_callback, early_stopping_callback])

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

In [12]:
from tensorflow.keras.models import load_model
# Restore a model written by the ModelCheckpoint callback used during training.
# TODO: epoch 328 is hard-coded; consider selecting the newest/best file.
checkpoint_path = '/content/drive/My Drive/models/colorizer/color-net_weights.328.hdf5'
model = load_model(checkpoint_path)

In [None]:
# Sorted so positional picks such as test_images[256] select the same file on
# every machine; glob's order depends on the filesystem.
test_images = sorted(glob.glob('Test/*'))
len(test_images)

In [None]:
image = io.imread(test_images[256])
# Normalise to 3-channel RGB: training decoded every image with channels=3,
# and the colour conversion in the next cell expects exactly three channels.
if image.ndim == 2:
  image = color.gray2rgb(image)
elif image.shape[-1] == 4:
  image = color.rgba2rgb(image)
image = transform.resize(image, [224, 224])
plt.imshow(image)

In [None]:
# The network was trained on Lab lightness scaled to [0, 1] (L / 100), and the
# reconstruction step multiplies this array by 100 to recover L. rgb2gray
# returns luminance Y, a different quantity, so take L from rgb2lab instead.
image = color.rgb2lab(image)[:, :, 0] / 100.
print(np.amax(image))
plt.imshow(image, cmap='gray')

In [None]:
# Replicate the single lightness channel three times (the VGG16 encoder takes
# 3-channel input) and prepend a batch axis of size 1.
# NOTE(review): `input` shadows the Python builtin for the rest of the session;
# it is kept because the prediction cell reads this name.
L = np.stack([image] * 3, axis=-1)
input = L[np.newaxis, ...]
input.shape

In [126]:
# Predict the (scaled) ab channels for the single-image batch.
ab = model.predict(x=input)

In [None]:
# Drop the size-1 batch axis. Not idempotent: re-running this cell without
# re-running predict raises, because axis 0 is no longer of size 1.
ab = ab.squeeze(axis=0)
print(np.shape(ab))

In [None]:
# Append a trailing channel axis to the lightness array so it can be
# concatenated with the two ab channels. Each re-run adds another axis.
image = image[..., np.newaxis]
image.shape

In [None]:
# Undo the training-time scaling (L / 100, ab / 128), stack L|a|b along the
# channel axis, and convert back to RGB for display.
ori = np.dstack((image * 100, ab * 128))
print(ori.shape)
rgb = color.lab2rgb(ori)

In [None]:
# Display the colourised reconstruction.
plt.imshow(rgb)