In [15]:
from tensorflow.keras import Sequential, Model
from tensorflow.keras.applications import InceptionResNetV2, VGG19
from keras.layers import *
import keras.backend as K

In [16]:
# Frozen VGG19 classifier. The keyword arguments below are the Keras
# defaults, spelled out so it is obvious that the full network (including the
# 1000-way classification head) with ImageNet weights is being loaded.
vgg_cl = VGG19(include_top=True, weights="imagenet")
# Freeze every VGG19 weight; only the layers defined in later cells train.
vgg_cl.trainable = False
vgg_cl.summary()

Model: "vgg19"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_8 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0     

In [17]:
def _conv3x3(filters, strides=(1, 1)):
    """Return a 3x3, same-padded, ReLU-activated Conv2D layer.

    Args:
        filters: number of output channels.
        strides: convolution strides; (2, 2) halves the spatial size.
    """
    return Conv2D(filters, (3, 3), padding="same", activation="relu", strides=strides)


_HALVE = (2, 2)  # stride used by the downsampling convolutions

# Encoder: four (conv, strided conv) pairs. Shape comments assume the
# 224 x 224 x 1 input that the wiring cell feeds into this model.
model_ph = Sequential([
    _conv3x3(64),
    _conv3x3(64, strides=_HALVE),    # -> 112 x 112 x 64
    _conv3x3(256),
    _conv3x3(256, strides=_HALVE),   # -> 56 x 56 x 256
    _conv3x3(512),
    _conv3x3(512, strides=_HALVE),   # -> 28 x 28 x 512
    _conv3x3(1024),                  # -> 28 x 28 x 1024
    _conv3x3(1024, strides=_HALVE),  # -> 14 x 14 x 1024
])

# Decoder: takes a 14 x 14 x 1030 tensor (batch size pinned to 1) and
# upsamples it back to 224 x 224 with 2 tanh-activated output channels.
model_out = Sequential([
    Input(shape=(14, 14, 1030), batch_size=1),
    UpSampling2D((2, 2)),            # -> 28 x 28 x 1030
    _conv3x3(512),                   # -> 28 x 28 x 512
    UpSampling2D((2, 2)),            # -> 56 x 56 x 512
    _conv3x3(128),                   # -> 56 x 56 x 128
    UpSampling2D((2, 2)),            # -> 112 x 112 x 128
    _conv3x3(64),
    _conv3x3(32),
    Conv2D(2, (3, 3), padding="same", activation="tanh"),  # -> 112 x 112 x 2
    # The last upsampling runs after the tanh conv, so the 224 x 224 output
    # is a 2x enlargement of the 112 x 112 prediction.
    UpSampling2D((2, 2)),            # -> 224 x 224 x 2
])


In [18]:
# Single-channel 224 x 224 input. The batch size stays pinned to 1 because
# model_out's own Input layer is declared with batch_size=1.
inp = Input(shape=(224, 224, 1), batch_size=1)

# VGG19 expects 3 channels, so the single channel is replicated three times.
# NOTE(review): the ImageNet weights were trained on inputs passed through
# vgg19.preprocess_input; whether `x` is scaled compatibly is not visible
# here -- confirm against data_preparation.
inp_for_vgg = concatenate([inp, inp, inp])
out_vgg_cl = vgg_cl(inp_for_vgg)  # (batch, 1000)
out_ph_model = model_ph(inp)      # (batch, 14, 14, 1024)

# The 1000 VGG outputs are zero-padded at the end so they can be laid out as
# a 14 x 14 x 6 grid and concatenated channel-wise with the encoder output.
grid_shape = (14, 14, 6)
n_vgg_outputs = out_vgg_cl.shape[-1]                                    # 1000
n_pad = grid_shape[0] * grid_shape[1] * grid_shape[2] - n_vgg_outputs  # 176

# The padding is done with layers instead of a Python-built constant sized by
# the static batch dimension, so it no longer breaks when that dimension is
# None. Values and their order are identical: the vector followed by zeros.
vgg_as_sequence = Reshape((n_vgg_outputs, 1))(out_vgg_cl)               # (batch, 1000, 1)
vgg_padded = ZeroPadding1D(padding=(0, n_pad))(vgg_as_sequence)         # (batch, 1176, 1)
concatenate_vgg_zero = Flatten()(vgg_padded)                            # (batch, 1176)

concatenate_vgg_re = Reshape(grid_shape)(concatenate_vgg_zero)          # (batch, 14, 14, 6)

concatenate_vgg_ph = concatenate([concatenate_vgg_re, out_ph_model])    # (batch, 14, 14, 1030)

output = model_out(concatenate_vgg_ph)                                  # (batch, 224, 224, 2)

model = Model(inp, output)
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_10 (InputLayer)          [(1, 224, 224, 1)]   0           []                               
                                                                                                  
 concatenate_6 (Concatenate)    (1, 224, 224, 3)     0           ['input_10[0][0]',               
                                                                  'input_10[0][0]',               
                                                                  'input_10[0][0]']               
                                                                                                  
 vgg19 (Functional)             (None, 1000)         143667240   ['concatenate_6[0][0]']          
                                                                                            

In [19]:
from data_preparation import *

# Fetch one sample from the data helper and display the three array shapes.
# The recorded output shows x as (1, 224, 224, 1), ab as (1, 224, 224, 2) and
# rgb as (224, 224, 3); presumably x is the lightness channel and ab the
# colour target of the same image -- confirm in data_preparation.
x, ab, rgb = create_data_imagenet()
tuple(arr.shape for arr in (x, ab, rgb))

((1, 224, 224, 1), (1, 224, 224, 2), (224, 224, 3))

In [20]:
# NOTE(review): 'accuracy' is kept for compatibility with any later code that
# reads it, but it is not a meaningful metric for an MSE regression target.
model.compile(optimizer="adam",
              loss='mse',
              metrics=['accuracy'])

epochs = 300
steps_per_epoch = 14500  # samples drawn from the iterator per epoch
save_every = 1000        # checkpoint (and log) period, in steps

for i in range(epochs):
    print(f"epochs = {i}", end=" ")
    # NOTE(review): the recorded run of this cell stopped here with
    # "NameError: name 'IteratorImages' is not defined". It must be defined or
    # imported before this cell runs; presumably it was meant to come from
    # data_preparation -- confirm.
    itr = IteratorImages()

    for j in range(steps_per_epoch):
        # Distinct names on purpose: rebinding the global `x` here would leave
        # the later visualisation cell pairing the last training sample with
        # the `rgb` image loaded earlier.
        x_batch, y_batch = next(itr)
        # The model input is pinned to batch size 1, so each sample is exactly
        # one gradient step; train_on_batch does that step without fit()'s
        # per-call setup and progress-bar output.
        logs = model.train_on_batch(x_batch, y_batch, return_dict=True)
        if j % save_every == 0:
            print(j, logs)
            model.save(f"model_{i}_{j}")

epochs = 0 

NameError: name 'IteratorImages' is not defined

In [None]:
# Run the network on `x` and hand the input together with the prediction to
# lab_ab_in_rgb (presumably converts L + ab channels to an RGB image --
# confirm in data_preparation).
# NOTE(review): `np` and `plt` are not imported in the visible cells;
# presumably they arrive through the star import -- confirm.
a = model(x)
res = lab_ab_in_rgb(x, a)

# Cast the reference image to integers and print its minimum value.
rgb = np.asarray(rgb).astype(int)
print(rgb.min())

# Left panel: reference image. Right panel: result built from the prediction.
for panel_index, panel_image in enumerate((rgb, res), start=1):
    plt.subplot(1, 2, panel_index)
    plt.imshow(panel_image)
plt.show()

In [None]:
# avg_photo and grey_in_lab are helpers defined outside the visible cells;
# presumably they build an averaged version of the image and split it into
# the network input / ab target pair -- confirm in data_preparation.
avg_img = avg_photo(rgb)
avg_img_int = avg_img.astype(int)  # grey_in_lab is fed integer pixel values
x2, ab2 = grey_in_lab(avg_img_int)