In [None]:
%matplotlib inline

import numpy as np
import pandas as pd 
from skimage import io, color
import skimage
import matplotlib.pyplot as plt
import cv2
import numpy as np

import os

# Print the path of every file in the attached Kaggle input datasets.
for dirname, _, filenames in os.walk('/kaggle/input'):
    paths = [os.path.join(dirname, filename) for filename in filenames]
    for path in paths:
        print(path)

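## Utilities

Helper functions for converting between RGB, LAB and separate L / ab channel arrays, and for reading and writing images.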

In [None]:
def display(img):
    """Show ``img`` in a new figure with the 'gray' colormap.

    Args:
        img: anything ``plt.imshow`` accepts (2-D grayscale or H x W x 3 RGB).

    Note:
        Inside a Jupyter kernel this definition shadows IPython's ``display``.
        ``plt.set_cmap('gray')`` also changes pyplot's default colormap, so
        later ``imshow`` calls default to gray too.
    """
    plt.figure()
    plt.set_cmap('gray')
    plt.imshow(img)
    plt.show()


def combineLAB(l, a, b):
    """Stack separate L, a and b planes into one (H, W, 3) float64 array.

    H and W are taken from ``l``; ``a`` and ``b`` must broadcast to that shape.
    """
    height, width = l.shape[0], l.shape[1]
    lab = np.zeros((height, width, 3))
    for channel, plane in enumerate((l, a, b)):
        lab[:, :, channel] = plane
    return lab


def combineAB(a, b):
    """Stack the a and b chroma planes into one (H, W, 2) float64 array.

    Args:
        a: 2-D array with the a channel; it defines the output H and W.
        b: array broadcastable to ``a``'s shape, with the b channel.

    Returns:
        float64 array of shape (a.shape[0], a.shape[1], 2), holding ``a``
        in channel 0 and ``b`` in channel 1.
    """
    # Both dimensions come from ``a``, matching combineLAB / combineL_AB.
    # The earlier version mixed a.shape[0] with b.shape[1].
    ab = np.zeros((a.shape[0], a.shape[1], 2))
    ab[:, :, 0] = a
    ab[:, :, 1] = b
    return ab


def combineL_AB(l, ab):
    """Join an L plane with a two-channel ab array into (H, W, 3) float64.

    H and W come from ``l``. Channel 0 is ``l``; channels 1 and 2 are
    ``ab[:, :, 0]`` and ``ab[:, :, 1]``.
    """
    height, width = l.shape[0], l.shape[1]
    lab = np.zeros((height, width, 3))
    lab[:, :, 0] = l
    for src in range(2):
        lab[:, :, src + 1] = ab[:, :, src]
    return lab


def make3channels(gray):
    """Copy a 2-D grayscale plane into all three channels of a float64 image."""
    height, width = gray.shape[0], gray.shape[1]
    rgb = np.zeros((height, width, 3))
    for channel in range(3):
        rgb[:, :, channel] = gray
    return rgb


def get_l_from_gray(img_path):
    """Return the L channel of the grayscale version of an image.

    The image is read, resized to 64x64, converted to grayscale, copied into
    three channels and converted to LAB (D50 illuminant); only L is kept.

    Returns:
        2-D array of shape (64, 64).
    """
    resized = skimage.transform.resize(io.imread(img_path), (64, 64))
    gray_rgb = make3channels(color.rgb2gray(resized))
    return color.rgb2lab(gray_rgb, illuminant='D50')[:, :, 0]


def get_ab_from_file(file):
    """Read an image file and return its channels 1 and 2 as an ab array.

    Channel 0 of the stored image is ignored; ``save_ab_file`` writes zeros
    there. This is the same operation as ``load_ab_image`` and delegates to
    it, so the output size follows the image instead of assuming 64x64.

    Args:
        file: path accepted by ``skimage.io.imread``; the image needs at
            least three channels.

    Returns:
        float64 array of shape (H, W, 2), with H and W taken from the image.
    """
    return load_ab_image(file)


def lab_normal_image(path):
    """Load a training pair with ``load_img_for_training`` and rescale it.

    Both the L plane and the ab planes are mapped with
    ``(v - 127.5) / 127.5``.

    NOTE(review): that formula maps [0, 255] onto [-1, 1]. skimage's
    ``rgb2lab`` gives L in [0, 100] and a/b roughly in [-128, 127], so the
    result is not in [-1, 1]. Confirm the intended scaling.
    """
    half_range = 127.5
    lgray, ab = load_img_for_training(path)
    return (lgray - half_range) / half_range, (ab - half_range) / half_range


def rgb_image(l, ab):
    """Convert an (L, ab) pair into an RGB uint8 image with OpenCV.

    Args:
        l: array of shape (H, W, 1) with the L channel (only channel 0 is read).
        ab: array broadcastable to (H, W, 2) with the a and b channels.

    Returns:
        RGB image produced by ``cv2.cvtColor(..., cv2.COLOR_LAB2RGB)``.

    NOTE(review): values are cast to uint8 (fractions truncated, no
    clipping), so inputs presumably already use OpenCV's 8-bit LAB encoding.
    Confirm this for the dataset.
    """
    height, width = l.shape[0], l.shape[1]
    lab = np.zeros((height, width, 3))
    lab[:, :, 0] = l[:, :, 0]
    lab[:, :, 1:] = ab
    lab_u8 = lab.astype('uint8')
    return cv2.cvtColor(lab_u8, cv2.COLOR_LAB2RGB)


def load_img_for_training(img_path):
    """Load one image and split it into model input (L) and target (ab).

    The image is resized to 64x64 and converted to LAB (D50 illuminant).
    The input is the L channel of the image's *grayscale* version, computed
    with the same steps as ``get_l_from_gray``.

    Returns:
        (lgray, ab): ``lgray`` has shape (64, 64); ``ab`` has shape
        (64, 64, 2) with the colour image's a and b channels.
    """
    img = io.imread(img_path)
    img = skimage.transform.resize(img, (64, 64))
    lab = color.rgb2lab(img, illuminant='D50')
    ab = combineAB(lab[:, :, 1], lab[:, :, 2])
    # Derive the grayscale L from the already-resized pixels. Calling
    # get_l_from_gray(img_path) would read, decode and resize the file again.
    gray = make3channels(color.rgb2gray(img))
    lgray = color.rgb2lab(gray, illuminant='D50')[:, :, 0]
    return lgray, ab


def save_ab_file(image, filepath):
    """Write a two-channel ab array to ``filepath`` as a 3-channel image.

    Channel 0 of the written image is all zeros. ``image[:, :, 0]`` goes to
    channel 1 and ``image[:, :, 1]`` to channel 2.
    """
    height, width = image.shape[0], image.shape[1]
    padded = np.zeros((height, width, 3))
    for src, dst in ((0, 1), (1, 2)):
        padded[:, :, dst] = image[:, :, src]
    save_file(padded, filepath)


def save_file(image, filepath):
    """Write ``image`` to ``filepath`` using ``skimage.io.imsave``."""
    io.imsave(fname=filepath, arr=image)


def load_ab_image(path):
    """Read an image and return its channels 1 and 2 as a float64 (H, W, 2) array.

    Channel 0 of the stored image is ignored.
    """
    stored = io.imread(path)
    ab = np.zeros((stored.shape[0], stored.shape[1], 2))
    for dst, src in enumerate((1, 2)):
        ab[:, :, dst] = stored[:, :, src]
    return ab


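## Data loading

Load the grayscale (L) images and their matching ab channels, then rebuild the RGB images used for training.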

In [None]:
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

In [None]:
def normalize(image):
    """Map 8-bit pixel values (0-255) to floats in [0, 1] by dividing by 255."""
    scaled = image / 255
    return scaled

def unnormalize(image):
    """Map floats in [0, 1] back to 8-bit pixel values (inverse of ``normalize``).

    Values are rounded to the nearest integer and clipped to [0, 255] before
    the uint8 cast. A plain cast truncates, so float error such as
    0.99999 * 255 loses a level. It also wraps out-of-range values, e.g.
    306 becomes 50.

    Args:
        image: array (or scalar) of floats, nominally in [0, 1].

    Returns:
        uint8 array of the same shape.
    """
    scaled = np.rint(np.asarray(image) * 255)
    return np.clip(scaled, 0, 255).astype('uint8')

In [None]:
# Keep the first N_SAMPLES (L, ab) pairs. The .npy files are memory-mapped and
# only the slice is copied into RAM. A plain `np.load(path)[:N]` reads the
# whole file, and the slice (a view) keeps that full array alive.
N_SAMPLES = 6000
gray_scale = np.array(np.load('/kaggle/input/image-colorization/l/gray_scale.npy', mmap_mode='r')[:N_SAMPLES])
ab_scale = np.array(np.load('/kaggle/input/image-colorization/ab/ab/ab1.npy', mmap_mode='r')[:N_SAMPLES])
print(gray_scale.shape)
print(ab_scale.shape)

In [None]:
# Sanity-check one (L, ab) pair: show the reconstructed colour image,
# then the L channel on its own.
index = 4579
l_sample = gray_scale[index].reshape((224, 224, 1))
ab_sample = ab_scale[index]
rgb_sample = rgb_image(l_sample, ab_sample)
display(rgb_sample)
display(l_sample[:, :, 0])

In [None]:
# Rebuild RGB images from the (L, ab) pairs; the VGG16 encoder takes 3-channel
# input. The image count comes from the loaded arrays, so changing the
# sample count in the loading cell is enough.
n_images = min(len(gray_scale), len(ab_scale))
img_size = 224
x = np.zeros((n_images, img_size, img_size, 3), dtype='uint8')

for i in range(n_images):
    l_sample = gray_scale[i].reshape((img_size, img_size, 1))
    ab_sample = ab_scale[i]
    x[i] = rgb_image(l_sample, ab_sample)

display(x[0])

In [None]:
# Summarise the first image (shape, dtype, value range) instead of dumping all of its pixel values.
print(x[0].shape, x[0].dtype, x[0].min(), x[0].max())

In [None]:
# Scale uint8 pixels to [0, 1] with the same /255 convention as `unnormalize`
# (used on predictions later). The original /256.0 never reached 1.0 and did
# not round-trip. Not idempotent: run this cell once per rebuild of `x`.
x = normalize(x)

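## Architecture

A frozen VGG16 convolutional base (ImageNet weights) serves as the encoder. A trainable transposed-convolution decoder upsamples its 7×7×512 feature map back to a 224×224×3 image.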

In [None]:
from keras import *
from keras.layers import *
from keras.activations import *
from keras.optimizers import *
from matplotlib import pyplot as plt
from utils import *
from keras.initializers import RandomNormal, Zeros

In [None]:
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose, concatenate
from keras.models import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam
from keras.applications.vgg16 import VGG16

In [None]:
# Pre-trained VGG16 convolutional base, with no fully connected head and every layer frozen.
base_model = VGG16(
    input_shape=(224, 224, 3),  # shape of our images
    include_top=False,          # leave out the last fully connected layers
    weights='imagenet',
)

for layer in base_model.layers:
    layer.trainable = False

base_model.summary()
encoder = base_model

In [None]:
from keras.layers import *

# Decoder: five stride-2 transposed convolutions take the 7x7x512 encoder
# output up to 224x224 (7 -> 14 -> 28 -> 56 -> 112 -> 224). A final sigmoid
# conv gives 3 channels in [0, 1].
dec_input = Input((7, 7, 512))
dec_x = ReLU()(dec_input)

layer_dim = [512, 256, 128, 64, 32]
for i in layer_dim:
    dec_x = Conv2DTranspose(filters=i, kernel_size=(2, 2), strides=(2, 2),
                            padding='same', activation='relu')(dec_x)
    dec_x = Conv2D(filters=i, kernel_size=2, padding='same', activation='relu')(dec_x)

dec_x = Conv2D(filters=3, kernel_size=2, padding='same', activation='sigmoid')(dec_x)

decoder = Model(inputs=[dec_input], outputs=[dec_x])
# NOTE(review): 'acc' is not meaningful for pixel regression. It is kept
# because the plotting cell reads history['acc'].
decoder.compile(optimizer='adam', loss='mse', metrics=['acc'])
decoder.summary()

In [None]:
# End-to-end autoencoder: the frozen VGG16 encoder feeds the trainable
# decoder, trained with mean absolute error.
model_out = decoder(encoder.output)
encoder.trainable = False
model = Model(encoder.input, model_out)
model.compile(
    optimizer='adam',
    loss='mae',
    metrics=['acc'],
)
model.summary()

In [None]:
# Train the autoencoder to reproduce its own input. Keras takes the last 10%
# of `x` (before shuffling) as the validation split.
epochs = 100
history = model.fit(x, x, validation_split=0.1, epochs=epochs, batch_size=64)
model.save('model.h5')

In [None]:
# Training curves: one figure per metric, training vs. validation. These
# replace the four earlier plots, two of which repeated the others without
# the validation curve.
h = history
for metric, title in (('acc', 'Model accuracy'), ('loss', 'Model Loss')):
    fig, ax = plt.subplots()
    ax.plot(h.history[metric], label='train')
    ax.plot(h.history['val_' + metric], label='validation')
    ax.set(title=title, xlabel='epoch', ylabel=metric)
    ax.legend()
    plt.show()

In [None]:
def save_images(generator, samples, filename='gan_generated_image.png'):
    """Plot each input above the model's reconstruction and save the grid.

    Args:
        generator: model with a ``predict`` method. Each prediction is
            assumed to be an image in [0, 1] and is converted with
            ``unnormalize`` before plotting.
        samples: batch of input images passed to ``generator.predict``; also
            drawn directly, so they should be displayable by ``imshow``.
        filename: path the comparison figure is written to.
    """
    if len(samples) == 0:
        return
    preds = generator.predict(samples)
    n = len(preds)
    # Row 0 holds inputs, row 1 reconstructions. squeeze=False keeps `axes` 2-D when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(2 * n, 4), squeeze=False)
    for i in range(n):
        axes[0, i].imshow(samples[i])
        axes[0, i].set_title(f'input {i}')
        axes[1, i].imshow(unnormalize(preds[i]))
        axes[1, i].set_title(f'output {i}')
    for ax in axes.ravel():
        ax.axis('off')
    fig.tight_layout()
    # Save before showing: plt.show() may close the figure, so saving afterwards can write a blank image.
    fig.savefig(filename)
    plt.show()

# Reconstruct the first ten images and save the comparison grid.
samples = x[:10]
save_images(model, samples)

In [None]:
# Run the encoder alone to inspect the bottleneck feature maps.
enc_x = encoder.predict(x=samples)
print(enc_x.shape)

In [None]:
# Decode those feature maps back to images with the decoder alone.
dec_x = decoder.predict(x=enc_x)
print(dec_x.shape)

In [None]:
display(img=samples[2])  # an input image, for comparison with the decoder output below

In [None]:
display(img=dec_x[2])  # the decoder's reconstruction of the same sample

In [None]:
encoder.save(filepath='enc.h5')  # persist the frozen VGG16 encoder

In [None]:
decoder.save(filepath='dec.h5')  # persist the trained decoder

In [None]:
# list(...) works for both a TF TensorShape and a Keras 3 shape tuple; `.as_list()` exists only on TensorShape.
list(encoder.output.shape)

In [None]:
# list(...) works for both a TF TensorShape and a Keras 3 shape tuple; `.as_list()` exists only on TensorShape.
list(decoder.input.shape)

In [None]:
# True when the decoder accepts exactly the encoder's output shape, so each can be reused standalone.
list(decoder.input.shape) == list(encoder.output.shape)