<a href="https://colab.research.google.com/github/ysikalie/colorization/blob/master/full_version_pre_extractedFeatures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from keras.applications.vgg16 import  VGG16,preprocess_input
from keras.layers import Conv2D, UpSampling2D, Input, Reshape,  concatenate
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard
from keras.models import  Model
from keras.layers.core import RepeatVector
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.io import imsave
import numpy as np
import os
import tensorflow as tf
from pickle import dump,load


# Get images: every file in Test/test140, resized to the 256x256 network
# input, then scaled from [0, 255] to [0, 1] for skimage's rgb2lab.
# NOTE(review): features in the pickle are matched to these images purely by
# position, so this os.listdir order must match the order used when the
# features were extracted — confirm.
X = np.array(
    [img_to_array(load_img('Test/test140/' + fname, target_size=(256, 256)))
     for fname in os.listdir('Test/test140')],
    dtype=float,
)
Xtest = 1.0/255*X
Xtest_length = len(Xtest)

# Training configuration shared by the generator and the fit call.
feature_file = 'test_features.pkl'
batch_size = 2

# Load weights. Despite its name, `inception` is a VGG16 classifier (1000-way
# top). Its graph is stored, presumably so later predict calls can run under
# `with inception.graph.as_default():`.
inception = VGG16(weights='vgg16_weights_tf_dim_ordering_tf_kernels.h5', include_top=True)
inception.graph = tf.get_default_graph()

# Second model input: one 1000-d VGG16 output vector per image.
embed_input = Input(shape=(1000,))

def extract_features(directory,
                     weights_path='/home/zhaoyangze/PycharmProjects/ful_version/vgg16_weights_tf_dim_ordering_tf_kernels.h5'):
    """Run VGG16 (with its classification top) over every image in a folder.

    Args:
        directory: folder whose entries are all loadable image files.
        weights_path: VGG16 weight file. The default is the original author's
            absolute path; pass your own when running on another machine.

    Returns:
        dict mapping each file name without its extension to the
        ``model.predict`` output for that image. Insertion order follows
        ``os.listdir(directory)``. ``load_features`` relies on this order,
        because it stacks the dict's values positionally.
    """
    features = dict()
    model = VGG16(weights=weights_path, include_top=True)
    for name in os.listdir(directory):
        filename = os.path.join(directory, name)
        # VGG16 with include_top=True requires 224x224 input.
        image = img_to_array(load_img(filename, target_size=(224, 224)))
        image = image.reshape((1,) + image.shape)  # add a batch axis
        image = preprocess_input(image)
        feature = model.predict(image, verbose=0)
        # splitext strips only the final extension. The old split('.')[0]
        # made 'a.1.jpg' and 'a.2.jpg' collide on 'a', so one feature was
        # silently lost and the positional order drifted.
        image_id = os.path.splitext(name)[0]
        features[image_id] = feature
    return features

def _conv_bn(tensor, filters, kernel=(3, 3), activation='relu', strides=1):
    """Apply a same-padded Conv2D, then BatchNormalization, and return the result."""
    tensor = Conv2D(filters, kernel, activation=activation, padding='same', strides=strides)(tensor)
    return BatchNormalization()(tensor)


#Encoder: (filters, stride) for each conv+BN pair, applied in order.
# The three stride-2 convs take the 256x256 L channel down to 32x32.
encoder_input = Input(shape=(256, 256, 1,))
encoder_output = encoder_input
for n_filters, stride in ((64, 2), (128, 1), (128, 2), (256, 1),
                          (256, 2), (512, 1), (512, 1), (256, 1)):
    encoder_output = _conv_bn(encoder_output, n_filters, strides=stride)

#Fusion: tile the 1000-d embedding over the 32x32 grid and merge it with the
# encoder features along the channel axis.
fusion_output = RepeatVector(32 * 32)(embed_input)
fusion_output = Reshape(([32, 32, 1000]))(fusion_output)
fusion_output = concatenate([encoder_output, fusion_output], axis=3)
fusion_output = _conv_bn(fusion_output, 256, kernel=(1, 1))

#Decoder: three 2x upsamplings bring 32x32 back to 256x256 with 2 output channels.
decoder_output = _conv_bn(fusion_output, 128)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = _conv_bn(decoder_output, 64)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = _conv_bn(decoder_output, 32)
decoder_output = _conv_bn(decoder_output, 16)
# NOTE(review): BatchNormalization after the tanh output layer means the final
# values are not bounded to [-1, 1]. It is kept as-is so existing weight files
# still load.
decoder_output = _conv_bn(decoder_output, 2, activation='tanh')
decoder_output = UpSampling2D((2, 2))(decoder_output)

model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)
model.summary()


def load_features(feature_file):
    """Load a pickled dict of per-image features and stack its values.

    Args:
        feature_file: path to a pickle holding a dict of id -> feature array.
            Presumably this is the dumped output of ``extract_features`` —
            confirm.

    Returns:
        numpy array of the dict's values, stacked in insertion order. The
        shape is printed as a side effect.
    """
    # SECURITY: pickle.load can execute arbitrary code. Only load trusted files.
    # Honour the argument; the file was previously hard-coded and never closed.
    with open(feature_file, 'rb') as fh:
        features = load(fh)
    feature_array = np.array(list(features.values()))
    print(feature_array.shape)
    return feature_array


def create_inception_embedding(batch_count, feature_file, size=None):
    '''Return the pre-extracted VGG16 features for one batch of images.

    :param batch_count: zero-based index of the batch.
    :param feature_file: pickle path passed through to ``load_features``.
    :param size: images per batch. Defaults to the module-level ``batch_size``.
    :return: array of shape (n, 1000) holding the features of images
        [batch_count*size, batch_count*size + size). n < size only for a
        short final batch.
    :raises ValueError: if the batch starts past the end of the features.
    '''
    size = batch_size if size is None else size
    feature_array = load_features(feature_file)
    start = batch_count * size
    batch_feature = feature_array[start:start + size]
    if len(batch_feature) == 0:
        raise ValueError('batch %d (size %d) is past the end of %d features'
                         % (batch_count, size, len(feature_array)))
    # -1 rather than `size`, so a short final batch does not crash the reshape.
    return batch_feature.reshape(-1, 1000)
# Image transformer (no augmentation; only used to batch Xtest).
datagen = ImageDataGenerator()

def train_image_a_b_gen(batch_size):
    """Yield ``([L, vgg_features], ab)`` training batches indefinitely.

    Images come from the global ``Xtest`` in fixed order. Their features come
    from ``feature_file`` by the same position.

    :param batch_size: images per batch. It is used both for the image
        batches and for the feature lookup.
    :raises ValueError: if the feature file and Xtest differ in length.
    """
    features = load_features(feature_file)  # load once, not once per batch
    if len(features) != Xtest_length:
        raise ValueError('feature file has %d entries but Xtest has %d images'
                         % (len(features), Xtest_length))
    features = features.reshape(len(features), -1)
    # The flow iterator emits ceil(n / batch_size) batches per pass (the last
    # one may be short), then restarts from the beginning.
    steps = (Xtest_length + batch_size - 1) // batch_size
    batch_count = 0
    # shuffle=False: features are looked up by position, so the image order
    # must stay fixed. The default shuffle=True mis-paired images and features.
    for batch in datagen.flow(Xtest, batch_size=batch_size, shuffle=False):
        start = batch_count * batch_size
        batch_feature = features[start:start + len(batch)]
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:, :, :, 0]
        X_batch = X_batch.reshape(X_batch.shape + (1,))  # add channel axis
        # NOTE(review): this maps ab in [-128, 127] to roughly [-3, 1], which
        # is decoded at prediction time by (output + 1) * 64. It does not
        # match tanh's [-1, 1] range. Confirm this is intended before changing
        # both sides.
        Y_batch = 2*lab_batch[:, :, :, 1:] / 128 - 1
        batch_count = (batch_count + 1) % steps
        yield ([X_batch, batch_feature], Y_batch)

# Training: resume from the last weights file if one exists, then fit on the
# Xtest / pre-extracted-feature generator.
from keras import callbacks
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

filepath = "full_version_weights.hdf5"
if os.path.exists(filepath):
    model.load_weights(filepath)
    print("checkpoint_loaded")

# No validation data is passed to fit_generator, so 'val_loss' never exists
# and these callbacks would do nothing. Monitor the training loss instead.
es = EarlyStopping(monitor='loss', mode='min', verbose=1, patience=20)
rlrp = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=50, min_delta=1E-7, verbose=1)
# save_weights_only: the file is read back with model.load_weights above.
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True,
                             save_weights_only=True, mode='auto', period=5)
# One full pass over Xtest per epoch. Ceil includes a short final batch.
steps_per_epoch = (Xtest_length + batch_size - 1) // batch_size
#Train model
model.compile(optimizer='rmsprop', loss='mse')
model.fit_generator(train_image_a_b_gen(batch_size), epochs=100,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[es, rlrp, checkpoint])
model.save_weights(filepath)



# Colorize every image in 'predict/' and write the results to 'result/'.
color_me = []
vgg_input = []
for filename in os.listdir('predict'):
    path = 'predict/' + filename
    # The colorization net takes 256x256 input; VGG16 with its top needs 224x224.
    color_me.append(img_to_array(load_img(path, target_size=(256, 256))))
    vgg_input.append(img_to_array(load_img(path, target_size=(224, 224))))
# rgb2lab expects floats in [0, 1]; the old code passed raw 0-255 values.
color_me = 1.0/255*np.array(color_me, dtype=float)
vgg_input = np.array(vgg_input, dtype=float)

# Embed the grayscale version of each image, since no colour is available at
# inference time. preprocess_input expects 0-255 RGB, so scale back up.
# (The old call create_inception_embedding(gray_me) did not match that
# function's (batch_count, feature_file) signature.)
# NOTE(review): confirm the training features in the pickle were extracted
# the same way (they may have come from colour images).
vgg_gray = 255 * gray2rgb(rgb2gray(1.0/255*vgg_input))
with inception.graph.as_default():
    color_me_embed = inception.predict(preprocess_input(vgg_gray))

color_me = rgb2lab(color_me)[:, :, :, 0]
color_me = color_me.reshape(color_me.shape + (1,))  # add channel axis

# Test model
output = model.predict([color_me, color_me_embed])
output = (output + 1) * 64  # inverse of the training target mapping

# Output colorizations
if not os.path.isdir('result'):
    os.makedirs('result')
for i, ab in enumerate(output):
    cur = np.zeros((256, 256, 3))
    cur[:, :, 0] = color_me[i][:, :, 0]
    cur[:, :, 1:] = ab
    imsave("result/img_" + str(i) + ".png", lab2rgb(cur))