In [24]:
import keras
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.preprocessing import image
from keras.engine import Layer
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard 
from keras.models import Sequential, Model
from keras.layers.core import RepeatVector, Permute
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
import os
import random
import tensorflow as tf
import time

In [25]:
print(tf.__version__)
print( keras.__version__)

1.12.0
2.2.4.2


# 取得training data

In [28]:
# Get images
X = []
for filename in os.listdir('Shiitake_jyneda/shiitake_512/'):
    X.append(img_to_array(load_img('Shiitake_jyneda/shiitake_512/'+filename)))
X = np.array(X, dtype=float)
Xtrain = 1.0/255*X


#Load weights
inception = InceptionResNetV2(weights='imagenet', include_top=True)
inception.graph = tf.get_default_graph()

# 如果出現 類似'data / training / images / .ipynb_checkpoints'的error 

In [27]:
!rm -rf Shiitake_jyneda/shiitake_512/.ipynb_checkpoints

In [29]:
embed_input = Input(shape=(1000,))

#Encoder
encoder_input = Input(shape=(512, 512, 1,))
print("encoder_input:",encoder_input)
encoder_output = Conv2D(32, (3,3), activation='relu', padding='same', strides=2)(encoder_input)
print("encoder_output00:",encoder_output)
encoder_output = Conv2D(64, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output000:",encoder_output)
encoder_output = Conv2D(64, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
print("encoder_output0:",encoder_output)
encoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output1:",encoder_output)
encoder_output = Conv2D(128, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
print("encoder_output2:",encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output3:",encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
print("encoder_output4:",encoder_output)
encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output5:",encoder_output)
encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output6:",encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)
print("encoder_output7:",encoder_output)
#Fusion
fusion_output = RepeatVector(32 * 32)(embed_input)
print("fusion_output0",fusion_output)
fusion_output = Reshape(([32, 32, 1000]))(fusion_output)
print("fusion_output1",fusion_output)
fusion_output = concatenate([encoder_output, fusion_output], axis=3) 
print("fusion_output2",fusion_output)
fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output) 
print("fusion_output3",fusion_output)
#Decoder
decoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(fusion_output)
print("decoder_output0",decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
print("decoder_output1",decoder_output)
decoder_output = Conv2D(64, (3,3), activation='relu', padding='same')(decoder_output)
print("decoder_output2",decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
print("decoder_output3",decoder_output)
decoder_output = Conv2D(32, (3,3), activation='relu', padding='same')(decoder_output)
print("decoder_output4",decoder_output)
decoder_output = Conv2D(16, (3,3), activation='relu', padding='same')(decoder_output)
print("decoder_output5",decoder_output)
decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)
print("decoder_output6",decoder_output)
decoder_output = UpSampling2D((4, 4))(decoder_output)
print("decoder_output7",decoder_output)
model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)

encoder_input: Tensor("input_9:0", shape=(?, 512, 512, 1), dtype=float32)
encoder_output00: Tensor("conv2d_642/Relu:0", shape=(?, 256, 256, 32), dtype=float32)
encoder_output000: Tensor("conv2d_643/Relu:0", shape=(?, 256, 256, 64), dtype=float32)
encoder_output0: Tensor("conv2d_644/Relu:0", shape=(?, 128, 128, 64), dtype=float32)
encoder_output1: Tensor("conv2d_645/Relu:0", shape=(?, 128, 128, 128), dtype=float32)
encoder_output2: Tensor("conv2d_646/Relu:0", shape=(?, 64, 64, 128), dtype=float32)
encoder_output3: Tensor("conv2d_647/Relu:0", shape=(?, 64, 64, 256), dtype=float32)
encoder_output4: Tensor("conv2d_648/Relu:0", shape=(?, 32, 32, 256), dtype=float32)
encoder_output5: Tensor("conv2d_649/Relu:0", shape=(?, 32, 32, 512), dtype=float32)
encoder_output6: Tensor("conv2d_650/Relu:0", shape=(?, 32, 32, 512), dtype=float32)
encoder_output7: Tensor("conv2d_651/Relu:0", shape=(?, 32, 32, 256), dtype=float32)
fusion_output0 Tensor("repeat_vector_3/Tile:0", shape=(?, 1024, 1000), dtype=f

# 設定參數&訓練

In [30]:
def create_inception_embedding(grayscaled_rgb):
    grayscaled_rgb_resized = []
    for i in grayscaled_rgb:
        i = resize(i, (299, 299, 3), mode='constant')
        grayscaled_rgb_resized.append(i)
    grayscaled_rgb_resized = np.array(grayscaled_rgb_resized)
    grayscaled_rgb_resized = preprocess_input(grayscaled_rgb_resized)
    with inception.graph.as_default():
        embed = inception.predict(grayscaled_rgb_resized)
    return embed

# Image transformer
datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True
)

#Generate training data
batch_size = 1

def image_a_b_gen(batch_size):
    for batch in datagen.flow(Xtrain, batch_size=batch_size):
        grayscaled_rgb = gray2rgb(rgb2gray(batch))
        embed = create_inception_embedding(grayscaled_rgb)
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:,:,:,0]
        X_batch = X_batch.reshape(X_batch.shape+(1,))
        Y_batch = lab_batch[:,:,:,1:] / 128
        yield ([X_batch, create_inception_embedding(grayscaled_rgb)], Y_batch)


#Train model      
model.compile(optimizer='adam', loss='mse')
model.fit_generator(image_a_b_gen(batch_size), epochs=1500, steps_per_epoch=1)

Epoch 1/1500
Epoch 2/1500
Epoch 3/1500
Epoch 4/1500
Epoch 5/1500
Epoch 6/1500
Epoch 7/1500
Epoch 8/1500
Epoch 9/1500
Epoch 10/1500
Epoch 11/1500
Epoch 12/1500
Epoch 13/1500
Epoch 14/1500
Epoch 15/1500
Epoch 16/1500
Epoch 17/1500
Epoch 18/1500
Epoch 19/1500
Epoch 20/1500
Epoch 21/1500
Epoch 22/1500
Epoch 23/1500
Epoch 24/1500
Epoch 25/1500
Epoch 26/1500
Epoch 27/1500
Epoch 28/1500
Epoch 29/1500
Epoch 30/1500
Epoch 31/1500
Epoch 32/1500
Epoch 33/1500
Epoch 34/1500
Epoch 35/1500
Epoch 36/1500
Epoch 37/1500
Epoch 38/1500
Epoch 39/1500
Epoch 40/1500
Epoch 41/1500
Epoch 42/1500
Epoch 43/1500
Epoch 44/1500
Epoch 45/1500
Epoch 46/1500
Epoch 47/1500
Epoch 48/1500
Epoch 49/1500
Epoch 50/1500
Epoch 51/1500
Epoch 52/1500
Epoch 53/1500
Epoch 54/1500
Epoch 55/1500
Epoch 56/1500
Epoch 57/1500
Epoch 58/1500
Epoch 59/1500
Epoch 60/1500
Epoch 61/1500
Epoch 62/1500
Epoch 63/1500
Epoch 64/1500
Epoch 65/1500
Epoch 66/1500
Epoch 67/1500
Epoch 68/1500
Epoch 69/1500
Epoch 70/1500
Epoch 71/1500
Epoch 72/1500
E

<keras.callbacks.History at 0x7f13a966ae48>

In [31]:
model.save('Shiitake_jyneda/Shiitak_jyneda_512_1img.h5')  # creates a HDF5 file 'my_model.h5'

In [35]:
from keras.models import load_model
model = load_model('Shiitake_jyneda/Shiitak_jyneda_512_1img.h5')

# 執行預測(上色)

In [38]:
start = time.time()
color_me = []
for filename in os.listdir('Shiitake_jyneda/shiitake_512_gray/'):
    color_me.append(img_to_array(load_img('Shiitake_jyneda/shiitake_512_gray/'+filename)))
color_me = np.array(color_me, dtype=float)
gray_me = gray2rgb(rgb2gray(1.0/255*color_me))
color_me_embed = create_inception_embedding(gray_me)
color_me = rgb2lab(1.0/255*color_me)[:,:,:,0]
color_me = color_me.reshape(color_me.shape+(1,))


# Test model
output = model.predict([color_me, color_me_embed])
output = output * 128

# Output colorizations
for i in range(len(output)):
    cur = np.zeros((512, 512, 3))
    cur[:,:,0] = color_me[i][:,:,0]
    cur[:,:,1:] = output[i]
    imsave("Shiitake_jyneda/result/img_"+str(i)+".png", lab2rgb(cur))
end = time.time() 
print("執行時間：%f 秒" % (end - start))
    

執行時間：15.826204 秒


In [37]:
!rm -rf Shiitake_jyneda/shiitake_512_gray/.ipynb_checkpoints

In [None]:
! zip -r Shiitake_gray.zip Shiitake/Shiitake_stream_512_gray