In [1]:
import keras
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.preprocessing import image
from keras.engine import Layer
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard 
from keras.models import Sequential, Model
from keras.layers.core import RepeatVector, Permute
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
import os
import random
import tensorflow as tf
import time

Using TensorFlow backend.


# 取得training data 

In [2]:
# Load the paired training images.
# os.listdir() returns entries in arbitrary order, so both folders are listed
# in sorted order to keep the i-th IR image aligned with the i-th daylight
# image (assumes matching files sort to the same position in both folders --
# TODO confirm naming scheme).  Hidden entries and non-files, such as the
# '.ipynb_checkpoints' folder that Jupyter creates, are skipped because
# load_img() cannot open them.

# Get images (training input: IR photos)
X = []
for filename in sorted(os.listdir('container1012/IR_resize1024/')):
    path = 'container1012/IR_resize1024/' + filename
    if filename.startswith('.') or not os.path.isfile(path):
        continue
    X.append(img_to_array(load_img(path)))
X = np.array(X, dtype=float)
Xtrain = 1.0/255*X  # rescale pixel values by 1/255

# Get images (training label: daylight photos)
Y = []
for filename in sorted(os.listdir('container1012/daylight_resize1024/')):
    path = 'container1012/daylight_resize1024/' + filename
    if filename.startswith('.') or not os.path.isfile(path):
        continue
    Y.append(img_to_array(load_img(path)))
Y = np.array(Y, dtype=float)
Ytrain = 1.0/255*Y  # rescale pixel values by 1/255

# Every input image needs exactly one label image; a mismatch would silently
# pair unrelated pictures during training.
if len(X) != len(Y):
    raise ValueError(
        "Found %d IR images but %d daylight images; the two folders must "
        "contain matching pairs." % (len(X), len(Y)))

# Load the ImageNet-pretrained classifier and remember the TensorFlow graph it
# was built in, so that predictions can later be run inside that graph.
inception = InceptionResNetV2(weights='imagenet', include_top=True)
inception.graph = tf.get_default_graph()

# 如果出現類似 'data / training / images / .ipynb_checkpoints' 的 error，表示 os.listdir 讀到了 Jupyter 自動產生的 .ipynb_checkpoints 資料夾；請執行下方指令將其刪除，再重新執行載入資料的 cell

In [5]:
# Delete the Jupyter checkpoint folder inside the daylight image directory so
# that listing that directory yields image files only.
!rm -rf container1012/daylight_resize1024/.ipynb_checkpoints

In [3]:
# Build the colorization network: an encoder over the single-channel input,
# a fusion step that mixes in a 1000-value embedding at every spatial
# position, and a decoder that emits two tanh-activated output channels.
# The print() calls only show the symbolic tensor after each layer so the
# shapes can be checked by eye.

# Second model input: one 1000-value vector per image.
embed_input = Input(shape=(1000,))

# Encoder: five stride-2 convolutions halve the resolution
# (1024 -> 512 -> 256 -> 128 -> 64 -> 32) while the channel count grows.
encoder_input = Input(shape=(1024, 1024, 1,))
print("encoder_input:",encoder_input)

# 1024x1024x1 -> 512x512x16 -> 512x512x32
encoder_output = Conv2D(16, (3,3), activation='relu', padding='same', strides=2)(encoder_input)
print("encoder_output00:",encoder_output)
encoder_output = Conv2D(32, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output000:",encoder_output)

# -> 256x256x32 -> 256x256x64
# (these two prints reuse the labels "encoder_output00"/"encoder_output000"
# from the pair above)
encoder_output = Conv2D(32, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
print("encoder_output00:",encoder_output)
encoder_output = Conv2D(64, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output000:",encoder_output)

# -> 128x128x64 -> 128x128x128
encoder_output = Conv2D(64, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
print("encoder_output0:",encoder_output)
encoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output1:",encoder_output)
# -> 64x64x128 -> 64x64x256
encoder_output = Conv2D(128, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
print("encoder_output2:",encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output3:",encoder_output)
# -> 32x32x256 -> 32x32x512 -> 32x32x512 -> 32x32x256
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
print("encoder_output4:",encoder_output)
encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output5:",encoder_output)
encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output6:",encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output7:",encoder_output)
# Fusion: repeat the embedding 32*32 times and reshape it to a 32x32x1000
# volume so it lines up with the 32x32 encoder feature map, concatenate both
# along the channel axis, then reduce to 256 channels with a 1x1 convolution.
fusion_output = RepeatVector(32 * 32)(embed_input)
print("fusion_output0",fusion_output)
fusion_output = Reshape(([32, 32, 1000]))(fusion_output)
print("fusion_output1",fusion_output)
fusion_output = concatenate([encoder_output, fusion_output], axis=3) 
print("fusion_output2",fusion_output)
fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output) 
print("fusion_output3",fusion_output)
# Decoder: two 2x upsamplings bring the map from 32x32 to 128x128 while the
# channels shrink 128 -> 64 -> 32 -> 16 -> 2.
decoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(fusion_output)
print("decoder_output0",decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
print("decoder_output1",decoder_output)
decoder_output = Conv2D(64, (3,3), activation='relu', padding='same')(decoder_output)
print("decoder_output2",decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
print("decoder_output3",decoder_output)
decoder_output = Conv2D(32, (3,3), activation='relu', padding='same')(decoder_output)
print("decoder_output4",decoder_output)
decoder_output = Conv2D(16, (3,3), activation='relu', padding='same')(decoder_output)
print("decoder_output5",decoder_output)
# Two output channels squashed to [-1, 1] by tanh.
decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)
print("decoder_output6",decoder_output)
# NOTE(review): the final 8x upsampling (128 -> 1024) happens after the last
# convolution, so the two predicted channels are computed at 128x128 and
# only enlarged to the input resolution.
decoder_output = UpSampling2D((8, 8))(decoder_output)
print("decoder_output7",decoder_output)
# The model takes [single-channel image, embedding] and returns the
# 1024x1024x2 decoder output.
model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)

encoder_input: Tensor("input_3:0", shape=(?, 1024, 1024, 1), dtype=float32)
encoder_output00: Tensor("conv2d_204/Relu:0", shape=(?, 512, 512, 16), dtype=float32)
encoder_output000: Tensor("conv2d_205/Relu:0", shape=(?, 512, 512, 32), dtype=float32)
encoder_output00: Tensor("conv2d_206/Relu:0", shape=(?, 256, 256, 32), dtype=float32)
encoder_output000: Tensor("conv2d_207/Relu:0", shape=(?, 256, 256, 64), dtype=float32)
encoder_output0: Tensor("conv2d_208/Relu:0", shape=(?, 128, 128, 64), dtype=float32)
encoder_output1: Tensor("conv2d_209/Relu:0", shape=(?, 128, 128, 128), dtype=float32)
encoder_output2: Tensor("conv2d_210/Relu:0", shape=(?, 64, 64, 128), dtype=float32)
encoder_output3: Tensor("conv2d_211/Relu:0", shape=(?, 64, 64, 256), dtype=float32)
encoder_output4: Tensor("conv2d_212/Relu:0", shape=(?, 32, 32, 256), dtype=float32)
encoder_output5: Tensor("conv2d_213/Relu:0", shape=(?, 32, 32, 512), dtype=float32)
encoder_output6: Tensor("conv2d_214/Relu:0", shape=(?, 32, 32, 512), dt

# 設定參數

In [4]:
def create_inception_embedding(grayscaled_rgb):
    """Run a batch of images through the pretrained Inception classifier.

    Parameters
    ----------
    grayscaled_rgb : iterable of array-like
        Batch of 3-channel images.  The callers in this notebook pass float
        images whose values lie in [0, 1] (pixels divided by 255).

    Returns
    -------
    numpy.ndarray
        The result of ``inception.predict`` for the batch, one row per image.
        The fusion layer of the colorization model consumes it as a
        1000-value embedding.
    """
    # The classifier takes 299x299 RGB inputs, so resize every image first.
    grayscaled_rgb_resized = np.array(
        [resize(img, (299, 299, 3), mode='constant') for img in grayscaled_rgb])
    # preprocess_input maps pixel values from [0, 255] to [-1, 1].  The images
    # arrive in [0, 1], so they are scaled back to [0, 255] first; without
    # this the network would see an almost constant image of about -1.
    # NOTE(review): models trained with the previous (unscaled) embedding need
    # to be retrained to match.
    grayscaled_rgb_resized = preprocess_input(grayscaled_rgb_resized * 255.0)
    # Predict inside the graph the Inception model was created in.
    with inception.graph.as_default():
        embed = inception.predict(grayscaled_rgb_resized)
    return embed

# Image transformer
# Every augmentation option below is commented out, so the generator is
# created with its default settings and is only used to iterate over the
# training arrays in batches.
datagen = ImageDataGenerator(
#         shear_range=0.2,
#         zoom_range=0.2,
#         rotation_range=20,
#         horizontal_flip=True
)

# Generate training data
# Number of images per batch produced by the training generator.
batch_size = 1

def image_a_b_gen(batch_size, seed=1):
    """Yield ``([L_channel, embedding], ab_channels)`` training batches forever.

    Parameters
    ----------
    batch_size : int
        Number of image pairs per yielded batch.
    seed : int, optional
        Seed shared by the input and the label iterator.  Both iterators
        shuffle their data; giving them the same seed makes them draw the
        same order (and the same random transforms, should augmentation be
        enabled) so each input stays paired with its own label.

    Yields
    ------
    tuple
        ``([X_batch, embed], Y_batch)`` where ``X_batch`` is the first Lab
        channel of the input images with a trailing channel axis, ``embed``
        is the Inception embedding of the grayscaled input images, and
        ``Y_batch`` holds the remaining two Lab channels of the label images
        divided by 128.
    """
    input_flow = datagen.flow(Xtrain, batch_size=batch_size, seed=seed)
    label_flow = datagen.flow(Ytrain, batch_size=batch_size, seed=seed)
    for batch, batch_label in zip(input_flow, label_flow):
        # The classifier expects 3 channels: collapse to gray, then replicate.
        grayscaled_rgb = gray2rgb(rgb2gray(batch))
        # Computed once per batch and reused in the yield below.
        embed = create_inception_embedding(grayscaled_rgb)
        # Network input: lightness channel of the input image, shape (..., 1).
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:,:,:,0]
        X_batch = X_batch.reshape(X_batch.shape+(1,))
        # Network target: the two colour channels of the label image, divided
        # by 128 to bring them towards the [-1, 1] range of the tanh output.
        lab_batch_label = rgb2lab(batch_label)
        Y_batch = lab_batch_label[:,:,:,1:] / 128
        yield ([X_batch, embed], Y_batch)

# 訓練

In [7]:
# Train model: mean-squared error on the predicted channels, Adam optimizer.
model.compile(loss='mse', optimizer='adam')
# Each "epoch" consumes exactly one generator batch (steps_per_epoch=1), so
# 3000 epochs correspond to 3000 training steps.
training_batches = image_a_b_gen(batch_size)
model.fit_generator(training_batches, steps_per_epoch=1, epochs=3000)

Epoch 1/3000
Epoch 2/3000
Epoch 3/3000
Epoch 4/3000
Epoch 5/3000
Epoch 6/3000
Epoch 7/3000
Epoch 8/3000
Epoch 9/3000
Epoch 10/3000
Epoch 11/3000
Epoch 12/3000
Epoch 13/3000
Epoch 14/3000
Epoch 15/3000
Epoch 16/3000
Epoch 17/3000
Epoch 18/3000
Epoch 19/3000
Epoch 20/3000
Epoch 21/3000
Epoch 22/3000
Epoch 23/3000
Epoch 24/3000
Epoch 25/3000
Epoch 26/3000
Epoch 27/3000
Epoch 28/3000
Epoch 29/3000
Epoch 30/3000
Epoch 31/3000
Epoch 32/3000
Epoch 33/3000
Epoch 34/3000
Epoch 35/3000
Epoch 36/3000
Epoch 37/3000
Epoch 38/3000
Epoch 39/3000
Epoch 40/3000
Epoch 41/3000
Epoch 42/3000
Epoch 43/3000
Epoch 44/3000
Epoch 45/3000
Epoch 46/3000
Epoch 47/3000
Epoch 48/3000
Epoch 49/3000
Epoch 50/3000
Epoch 51/3000
Epoch 52/3000
Epoch 53/3000
Epoch 54/3000
Epoch 55/3000
Epoch 56/3000
Epoch 57/3000
Epoch 58/3000
Epoch 59/3000
Epoch 60/3000
Epoch 61/3000
Epoch 62/3000
Epoch 63/3000
Epoch 64/3000
Epoch 65/3000
Epoch 66/3000
Epoch 67/3000
Epoch 68/3000
Epoch 69/3000
Epoch 70/3000
Epoch 71/3000
Epoch 72/3000
E

<keras.callbacks.History at 0x7f79e04699b0>

In [8]:
# Persist the trained model as an HDF5 file at the path given below.
model.save('container1012/resize1024_2_model.h5')

In [9]:
# Reload the model from the file written by the save cell; `model` now refers
# to the restored copy, which the prediction cell uses.
# NOTE(review): this import would normally live in the import cell at the top.
from keras.models import load_model
model = load_model('container1012/resize1024_2_model.h5')

# 執行預測(上色)

In [10]:
# Colorize every image in the test folder and report the elapsed time.
start = time.time()
input_dir = 'test1012/IRStream_resize1024_gray/'
output_dir = 'test1012/IRColor_resize1024_2/'

color_me = []
# Load the grayscale photos to be colorized.  The listing is sorted so that
# the img_<i>.jpg numbering is reproducible (os.listdir() order is arbitrary),
# and hidden entries / non-files such as '.ipynb_checkpoints' are skipped
# because load_img() cannot open them.
for filename in sorted(os.listdir(input_dir)):
    path = input_dir + filename
    if filename.startswith('.') or not os.path.isfile(path):
        continue
    color_me.append(img_to_array(load_img(path)))
# Rescale pixel values by 1/255 once and reuse the result below.
color_me = 1.0/255 * np.array(color_me, dtype=float)
# Inception embedding of the 3-channel gray version of each image.
gray_me = gray2rgb(rgb2gray(color_me))
color_me_embed = create_inception_embedding(gray_me)
# Network input: first Lab channel with a trailing channel axis.
color_me = rgb2lab(color_me)[:,:,:,0]
color_me = color_me.reshape(color_me.shape+(1,))


# Test model
output = model.predict([color_me, color_me_embed])
# Undo the division by 128 applied to the training targets.
output = output * 128

# Output colorizations
os.makedirs(output_dir, exist_ok=True)  # imsave does not create folders
for i in range(len(output)):
    # Reassemble a Lab image: original lightness + the two predicted channels.
    cur = np.zeros((1024, 1024, 3))
    cur[:,:,0] = color_me[i][:,:,0]
    cur[:,:,1:] = output[i]
    imsave(output_dir+"img_"+str(i)+".jpg", lab2rgb(cur))
end = time.time() 
print("執行時間：%f 秒" % (end - start))
    

  .format(dtypeobj_in, dtypeobj_out))


執行時間：202.279401 秒


In [12]:
# Delete the Jupyter checkpoint folder inside the test image directory so
# that listing that directory yields image files only.
!rm -rf test1012/IRStream_resize1024_gray/.ipynb_checkpoints

In [11]:
# Bundle the colorized output folder into a single zip archive.
!zip -r colorstream1024_2.zip test1012/IRColor_resize1024_2

  adding: test1012/IRColor_resize1024_2/ (stored 0%)
  adding: test1012/IRColor_resize1024_2/img_0.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_1.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_2.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_3.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_4.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_5.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_6.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_7.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_8.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_9.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_10.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_11.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_12.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_13.jpg (deflated 1%)
  adding: test1012/IRColor_resize1024_2/img_14.