# UNet with VGG16

In [None]:
def get_backbone(width = WIDTH, height = HEIGHT):
    """Build an ImageNet-pretrained VGG16 feature extractor for the U-Net.

    Args:
        width: First spatial dimension of the input shape handed to VGG16.
        height: Second spatial dimension of the input shape handed to VGG16.

    Returns:
        A ``keras.Model`` mapping an image tensor to a list of six feature
        maps ordered deepest first: ``block5_pool``, ``block5_conv3``,
        ``block4_conv3``, ``block3_conv3``, ``block2_conv2`` and
        ``block1_conv2``.
    """
    # NOTE(review): the shape is passed as (width, height, 3); this only
    # matters for non-square inputs -- confirm the intended axis order.
    backbone = keras.applications.VGG16(
        include_top=False,
        input_shape=(width, height, 3),
        weights="imagenet",
    )

    # Deepest feature map first, i.e. the order in which a decoder consumes
    # the bottleneck and then progressively shallower skip connections.
    tap_names = (
        "block5_pool",
        "block5_conv3",
        "block4_conv3",
        "block3_conv3",
        "block2_conv2",
        "block1_conv2",
    )
    feature_maps = [backbone.get_layer(tap).output for tap in tap_names]

    # Use the single input tensor directly: wrapping ``backbone.inputs``
    # (already a list) in another list declares a nested input structure
    # that does not match a call with a bare tensor.
    return keras.Model(inputs=backbone.input, outputs=feature_maps)

In [None]:
def twice_Conv2D(input_image, num_filters, i):
    """Apply two chained 3x3 Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        input_image: Tensor fed into the first convolution.
        num_filters: Number of filters used by both convolutions.
        i: Zero-based decoder block index; layers are named with ``i + 1``
            (``block{i+1}_Conv{j+1}``, ``block{i+1}_BN{j+1}``,
            ``block{i+1}_ReLU{j+1}``).

    Returns:
        The output tensor of the second ReLU.
    """
    # Start from the input and feed each stage the result of the previous
    # one. Applying every convolution to ``input_image`` would discard the
    # first stage and leave only a single convolution in the graph.
    output = input_image
    for j in range(2):
        output = keras.layers.Conv2D(num_filters, 3, 1,
                                     padding="same",
                                     kernel_initializer='he_normal',
                                     name=f"block{i+1}_Conv{j+1}")(output)
        output = keras.layers.BatchNormalization(name=f"block{i+1}_BN{j+1}")(output)
        output = keras.layers.ReLU(name=f"block{i+1}_ReLU{j+1}")(output)
    return output

def Unet(num_classes = NUM_CLASSES, width = WIDTH, height = HEIGHT):
    """Build a U-Net segmentation model on top of the VGG16 backbone.

    Args:
        num_classes: Number of channels of the final softmax layer.
        width: First spatial dimension of the model input.
        height: Second spatial dimension of the model input.

    Returns:
        A ``keras.models.Model`` mapping an image of shape
        ``(width, height, 3)`` to per-pixel softmax scores with
        ``num_classes`` channels.
    """
    input_image = keras.Input(shape=(width, height, 3), name="Image")

    # Build the backbone for the requested size; relying on its defaults
    # would ignore ``width``/``height`` and mismatch the input declared above.
    # The backbone is called in inference mode (training=False).
    output_list = get_backbone(width, height)(input_image, training = False)

    # Bottleneck: two Conv -> BN -> ReLU stages on the deepest feature map.
    output = output_list[0]
    for i in range(2):
        output = keras.layers.Conv2D(512, 3, 1,
                                     padding = "same",
                                     kernel_initializer='he_normal',
                                     name=f"block0_Conv{i}")(output)
        output = keras.layers.BatchNormalization(name=f"block0_BN{i}")(output)
        output = keras.layers.ReLU(name=f"block0_ReLU{i}")(output)

    # Decoder: five stages, each one upsamples by 2, applies a 2x2
    # convolution that keeps the channel count, concatenates the matching
    # encoder feature map and refines the result with twice_Conv2D.
    for i, filters in enumerate([256, 128, 64, 32, 16]):
        output = keras.layers.UpSampling2D(2, name=f"block{i+1}_UpSampling0")(output)
        output = keras.layers.Conv2D(output.shape[-1], 2, 1,
                                     padding = "same",
                                     kernel_initializer='he_normal',
                                     name=f"block{i+1}_Conv0")(output)
        output = keras.layers.BatchNormalization(name=f"block{i+1}_BN0")(output)
        output = keras.layers.ReLU(name=f"block{i+1}_ReLU0")(output)
        # output_list[0] is the bottleneck input, so skips start at index 1.
        output = keras.layers.concatenate([output, output_list[i+1]], name = f"block{i+1}_concat")
        output = twice_Conv2D(output, filters, i)

    # 1x1 classification head. The layer name says "ReLU" but the activation
    # is softmax; the name is kept so weights saved by name still load.
    # NOTE(review): the constant bias is the same for every class, and a
    # uniform shift does not change a softmax output -- confirm it is wanted.
    output = keras.layers.Conv2D(num_classes, 1, 1,
                                 padding = "same",
                                 activation = "softmax",
                                 kernel_initializer=tf.initializers.RandomNormal(0.0, 0.01),
                                 bias_initializer = tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
                                 name = "block5_Conv_ReLU")(output)
    model = keras.models.Model(inputs = input_image, outputs=output)
    return model
# Unet(3,1024,1024)

In [None]:
# Sanity check: run the model over the test set and plot, for every sample,
# the input image, the ground-truth mask and the predicted class map.
dataset = Dataset_Generator()
model = Unet()

# Derive the input spec from the model instead of hard-coding 1024x1024, so
# this cell keeps working when WIDTH / HEIGHT change.
input_image = tf.keras.Input(shape=model.input_shape[1:], name="Image")

test_dataset = tf.data.Dataset.from_generator(
    dataset.test_generator,
    (tf.float32, tf.int32),
    (tf.TensorShape([1, HEIGHT, WIDTH, 3]), tf.TensorShape([1, HEIGHT, WIDTH])),
)

# training=False: BatchNormalization must use its moving statistics during
# inference; with training=True it would normalise with per-batch statistics
# and update the moving averages while predicting.
predictions = model(input_image, training=False)
inference_model = tf.keras.Model(inputs=input_image, outputs=predictions)

for img, mask in test_dataset:
    prediction = inference_model.predict(img)

    # Every dataset element carries a leading batch axis of size 1.
    img = img[0].numpy()
    mask = mask[0].numpy()

    # Collapse the per-class scores into one class index per pixel.
    prediction = tf.math.argmax(prediction[0], -1).numpy()

    fig = plt.figure(10, figsize=(20, 20))

    ax1 = fig.add_subplot(1, 3, 1)
    # presumably the generator yields BGR images (OpenCV order) -- verify
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    ax1.imshow(img)
    ax1.set_title('Image')
    ax1.axis("off")

    ax2 = fig.add_subplot(1, 3, 2)
    ax2.imshow(mask)
    ax2.set_title('Ground Truth Mask')
    ax2.axis("off")

    ax3 = fig.add_subplot(1, 3, 3)
    ax3.imshow(prediction)
    ax3.set_title('Prediction')
    ax3.axis("off")
    plt.show()

In [None]:
# Build a fresh U-Net using the default NUM_CLASSES, WIDTH and HEIGHT.
model = Unet()