In [11]:
from keras.layers import Conv2D, UpSampling2D, Input
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
import tensorflow as tf

In [12]:
# Root folder scanned by flow_from_directory below; Keras expects the images to
# live in sub-folder(s) of this directory (each sub-folder is treated as a class).
path = "downloads/"

In [13]:
# Image generator that only rescales pixel values from [0, 255] to [0, 1];
# no augmentation is configured.
train_datagen = ImageDataGenerator(rescale=1.0 / 255)

In [14]:
# Build an iterator over the images under `path`: each image is loaded, resized
# to 256x256, rescaled to [0, 1] by `train_datagen`, and grouped in batches of 128.
# class_mode=None yields image arrays only (no labels).
# flow_from_directory shuffles by default, so a fixed seed is set to make the
# image order (and therefore the content of each batch) reproducible across runs.
train = train_datagen.flow_from_directory(
    path,
    target_size=(256, 256),
    batch_size=128,
    class_mode=None,
    seed=42,
)


Found 614 images belonging to 1 classes.


In [15]:
# Convert the RGB training images to CIELAB:
# - X (network input) is the L (lightness) channel, i.e. the grayscale image.
# - Y (training target) is the a and b chroma channels, divided by 128 so they
#   lie roughly within [-1, 1], matching the model's tanh output.
# At inference time the model predicts Y from the L channel of a grayscale image.

# How many generator batches to convert. The notebook originally used only the
# first batch (128 images); set this to len(train) to train on every image found.
N_TRAIN_BATCHES = 1

X_batches = []
Y_batches = []
for batch_idx in range(min(N_TRAIN_BATCHES, len(train))):
    # rgb2lab converts a whole (batch, H, W, 3) array at once, so no per-image loop
    # (and no error-swallowing try/except) is needed.
    lab = rgb2lab(train[batch_idx])
    X_batches.append(lab[..., :1])        # L channel, trailing axis kept -> (batch, H, W, 1)
    Y_batches.append(lab[..., 1:] / 128)  # a/b are nominally within about [-128, 127]

if not X_batches:
    raise ValueError("No training images were produced by the generator")

X = np.concatenate(X_batches)
Y = np.concatenate(Y_batches)
assert X.shape[:-1] == Y.shape[:-1], "X and Y must describe the same images"
print(X.shape)
print(Y.shape)

(128, 256, 256, 1)
(128, 256, 256, 2)


In [22]:
# Encoder: three stride-2 convolutions shrink the 256x256 single-channel (L)
# input to 32x32 while increasing the number of feature maps.
model = Sequential()
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', strides=2, input_shape=(256, 256, 1)))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))

# Decoder: three 2x UpSampling2D layers restore 32x32 -> 256x256.
# The final Conv2D has 2 filters (the a and b channels, i.e. Y) with tanh
# activation, because Y was scaled by 1/128 into roughly [-1, 1] and tanh's
# output range is (-1, 1).
# NOTE(review): the last UpSampling2D comes after that Conv2D, so each predicted
# a/b value is repeated over a 2x2 pixel block (UpSampling2D defaults to
# nearest-neighbour); kept as-is to preserve the architecture.
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))
model.add(UpSampling2D((2, 2)))

# This is a regression on continuous a/b values, so 'accuracy' (an exact-match
# classification metric) is meaningless here; track mean absolute error instead.
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_39 (Conv2D)           (None, 128, 128, 64)      640       
_________________________________________________________________
conv2d_40 (Conv2D)           (None, 128, 128, 128)     73856     
_________________________________________________________________
conv2d_41 (Conv2D)           (None, 64, 64, 128)       147584    
_________________________________________________________________
conv2d_42 (Conv2D)           (None, 64, 64, 256)       295168    
_________________________________________________________________
conv2d_43 (Conv2D)           (None, 32, 32, 256)       590080    
_________________________________________________________________
conv2d_44 (Conv2D)           (None, 32, 32, 512)       1180160   
_________________________________________________________________
conv2d_45 (Conv2D)           (None, 32, 32, 512)      

In [26]:
# Train on L -> (a, b) pairs, holding out the last 10% of samples for validation,
# then save the model so the inference cells can reload it without retraining.
history = model.fit(
    x=X,
    y=Y,
    batch_size=16,
    epochs=20,
    validation_split=0.1,
)
model.save('image_recolourization.model')

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
INFO:tensorflow:Assets written to: image_recolourization.model/assets


In [27]:
# Reload the saved network from disk so inference does not depend on the
# training cells' in-memory state.
model = tf.keras.models.load_model(
    'image_recolourization.model',
    custom_objects=None,
    compile=True,
)

# Load the test photo as a float RGB array and resize it to the 256x256 training
# resolution. The 1/255 factor below assumes values are still in 0-255 after
# resize (resize does not rescale float input) — TODO confirm on library upgrade.
img1 = resize(img_to_array(load_img('test/test_image.jpeg')), (256, 256))

# Wrap the image in a batch of one, scale to [0, 1], convert to Lab and keep only
# the L channel, with a trailing channel axis, as the model input (1, H, W, 1).
img1_color = np.array([img1], dtype=float)
img1_color = rgb2lab(1.0/255*img1_color)[..., 0]
img1_color = img1_color[..., np.newaxis]

In [28]:
# Predict the a/b chroma channels from the L input and undo the /128 scaling
# that was applied to Y during training.
output1 = model.predict(img1_color)
output1 = output1 * 128

# Reassemble a Lab image: L from the input, a/b from the network. The buffer
# shape follows the input rather than a hard-coded 256x256.
result = np.zeros(img1_color.shape[1:3] + (3,))
result[:, :, 0] = img1_color[0][:, :, 0]
result[:, :, 1:] = output1[0]

# lab2rgb returns floats in [0, 1]. Convert explicitly to uint8 so the
# float -> 8-bit conversion is deterministic and is not left to the image I/O
# plugin, which would otherwise do a lossy conversion of the float64 array.
result_rgb = lab2rgb(result)
result_uint8 = (np.clip(result_rgb, 0.0, 1.0) * 255).round().astype(np.uint8)
imsave("result.png", result_uint8)

