In [1]:
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, Conv3D, UpSampling3D, MaxPooling2D
from keras.layers import Activation, Dense, Dropout, Flatten, InputLayer
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard
from keras.models import Sequential
from keras.utils import plot_model
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imsave
import numpy as np
import os
import random
import tensorflow as tf
from PIL import Image
import PIL

Using TensorFlow backend.


In [2]:
class Colorize:
    """Train and apply image-colorization CNNs that predict Lab a/b from L.

    Typical order of use:
        key = c.define_model_up()          # or define_model_dense / define_model_maxpool
        c.compile_model(key)
        c.set_train_images(); c.train_model(key, epochs, steps_per_epoch)
        c.evaluate(key)
        c.set_test_images(); c.predict(key); c.colorized_output(key)
    """

    # Lab a/b channels are divided by this for training targets and multiplied
    # back after prediction.
    AB_SCALE = 128

    def __init__(self, model_type, img_size, batch_size):
        """Store configuration and derive the dataset directories.

        Args:
            model_type: dataset name used in directory and output file names.
            img_size: side length (pixels) of the square images.
            batch_size: batch size for training, evaluation and augmentation.
        """
        self.train_dir = 'Dataset/' + model_type + '_train_' + str(img_size) + '/'
        self.test_dir = 'Dataset/' + model_type + '_test_' + str(img_size) + '/'
        self.gold_dir = 'Dataset/' + model_type + '_gold_' + str(img_size) + '/'
        self.model_type = model_type
        self.test_filenames = []
        self.img_size = img_size
        self.batch_size = batch_size
        # Filled in by the define_model_* methods; the models do not exist yet,
        # so they cannot be referenced here.
        self.model_dict = {}

    @staticmethod
    def _image_filenames(directory):
        """Return os.listdir(directory) minus macOS '.DS_Store' entries."""
        return [name for name in os.listdir(directory) if name != '.DS_Store']

    @staticmethod
    def _split_lab(rgb_images):
        """Convert a batch of RGB images to Lab and split into (L, scaled ab).

        Assumes rgb_images is (batch, h, w, 3) with values in [0, 1] (callers
        scale by 1/255). Returns L with a trailing channel axis added, and the
        a/b channels divided by AB_SCALE.
        """
        lab = rgb2lab(rgb_images)
        lightness = lab[:, :, :, 0]
        return lightness.reshape(lightness.shape + (1,)), lab[:, :, :, 1:] / Colorize.AB_SCALE

    def _output_stem(self, model_key):
        """Return the '<model_key>_<model_type>' prefix used for saved artifacts."""
        return model_key + '_' + self.model_type

    def set_train_images(self):
        """Load the training images from self.train_dir.

        Sets self.X_train to all images scaled by 1/255, and self.eval to the
        last 10 images (unscaled) for evaluate().
        """
        images = np.array(
            [img_to_array(load_img(self.train_dir + name))
             for name in self._image_filenames(self.train_dir)],
            dtype=float)
        # NOTE(review): the evaluation images are also in the training set, so
        # evaluate() does not measure generalization.
        self.eval = images[-10:]
        self.X_train = 1.0 / 255 * images

    def define_model_dense(self):
        """Build the 'dense' model, register it in model_dict and return its key.

        All convolutions keep full resolution; the Dense(1) layers act on the
        channel axis of the 4-D tensor, squeezing it to a single channel.
        """
        model = Sequential()
        model.add(InputLayer(input_shape=(self.img_size, self.img_size, 1)))
        model.add(Conv2D(8, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))
        model.add(Dense(1, activation='sigmoid'))
        model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
        model.add(Dense(1, activation='sigmoid'))
        model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(8, (3, 3), activation='relu', padding='same'))
        # tanh (not sigmoid): the targets are Lab a/b divided by AB_SCALE, which
        # are signed, so a [0, 1] output could never reach negative values.
        model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))
        self.model_dense = model
        self.model_dict['dense'] = model
        return 'dense'

    def define_model_up(self):
        """Build the strided-conv / upsampling model, register it and return its key.

        Three stride-2 convolutions are undone by three 2x upsamplings, so
        img_size should be divisible by 8.
        """
        model = Sequential()
        model.add(InputLayer(input_shape=(self.img_size, self.img_size, 1)))
        model.add(Conv2D(8, (3, 3), activation='relu', padding='same', strides=2))
        model.add(Conv2D(8, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(16, (3, 3), activation='relu', padding='same', strides=2))
        model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(32, (3, 3), activation='relu', padding='same', strides=2))
        model.add(UpSampling2D((2, 2)))
        model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
        model.add(UpSampling2D((2, 2)))
        model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))
        model.add(UpSampling2D((2, 2)))
        model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))
        self.model_up = model
        self.model_dict['up'] = model
        return 'up'

    def define_model_maxpool(self):
        """Build the max-pool / upsampling model, register it and return its key.

        Each 2x2 'valid' max-pool is immediately followed by a 2x upsampling,
        so img_size should be divisible by 2.
        """
        model = Sequential()
        model.add(InputLayer(input_shape=(self.img_size, self.img_size, 1)))
        model.add(Conv2D(8, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))
        model.add(MaxPooling2D(pool_size=(2, 2), padding='valid'))
        model.add(UpSampling2D((2, 2)))
        model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
        model.add(MaxPooling2D(pool_size=(2, 2), padding='valid'))
        model.add(UpSampling2D((2, 2)))
        model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(8, (3, 3), activation='relu', padding='same'))
        model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))
        self.model_maxpool = model
        self.model_dict['maxpool'] = model
        return 'maxpool'

    def compile_model(self, model_key, model_type=None):
        """Compile the registered model with RMSprop/MSE, print and plot it.

        Args:
            model_key: key returned by one of the define_model_* methods.
            model_type: unused (self.model_type is used); kept for backward
                compatibility with existing callers.
        """
        model = self.model_dict[model_key]
        model.compile(optimizer='rmsprop', loss='mse')
        model.summary()
        plot_model(model, to_file=self._output_stem(model_key) + '.png',
                   show_shapes=True, show_layer_names=False)

    def image_a_b_gen(self, batch_size):
        """Yield augmented (L, ab / AB_SCALE) batches from self.X_train forever."""
        datagen = ImageDataGenerator(
            shear_range=0.2,
            zoom_range=0.2,
            rotation_range=20,
            horizontal_flip=True)
        for batch in datagen.flow(self.X_train, batch_size=batch_size):
            yield self._split_lab(batch)

    def train_model(self, model_key, epochs, steps_per_epoch, log_dir="first_run"):
        """Train the registered model, then save its architecture and weights.

        Writes '<model_key>_<model_type>.json' and '.h5' in the working
        directory; TensorBoard logs go to log_dir.
        """
        model = self.model_dict[model_key]
        tensorboard = TensorBoard(log_dir=log_dir)
        model.fit_generator(self.image_a_b_gen(self.batch_size), callbacks=[tensorboard],
                            epochs=epochs, steps_per_epoch=steps_per_epoch)
        stem = self._output_stem(model_key)
        with open(stem + ".json", "w") as json_file:
            json_file.write(model.to_json())
        model.save_weights(stem + ".h5")

    def evaluate(self, model_key):
        """Print (and return) the model's loss on the held-back self.eval images."""
        X_eval, Y_eval = self._split_lab(1.0 / 255 * self.eval)
        score = self.model_dict[model_key].evaluate(X_eval, Y_eval, batch_size=self.batch_size)
        print(score)
        return score

    def set_test_images(self):
        """Load test images into self.X_test as L channels (with channel axis).

        self.test_filenames is rebuilt on every call so its order always
        matches the rows of self.X_test.
        """
        self.test_filenames = self._image_filenames(self.test_dir)
        images = np.array(
            [img_to_array(load_img(self.test_dir + name)) for name in self.test_filenames],
            dtype=float)
        self.X_test, _ = self._split_lab(1.0 / 255 * images)

    def predict(self, model_key):
        """Predict a/b for self.X_test and rescale back by AB_SCALE into self.y_test."""
        self.y_test = self.model_dict[model_key].predict(self.X_test)
        self.y_test *= self.AB_SCALE

    def colorized_output(self, model_key):
        """Combine test L with predicted a/b and save RGB images.

        Images are written to 'Results_<model_key>_<model_type>/' using the
        original test filenames.
        """
        out_dir = 'Results_' + self._output_stem(model_key)
        # Create the directory once, without a shell command.
        os.makedirs(out_dir, exist_ok=True)
        for i in range(len(self.y_test)):
            result = np.zeros((self.img_size, self.img_size, 3))
            result[:, :, 0] = self.X_test[i][:, :, 0]
            result[:, :, 1:] = self.y_test[i]
            imsave(os.path.join(out_dir, self.test_filenames[i]), lab2rgb(result))