# Image Segmentation using VGGSegNet
# VGG-16-based encoder-decoder fully convolutional network (FCN)

In [None]:
# import the necessary packages
import glob
import cv2
import numpy as np
import random
from keras.utils import plot_model
from VGGSegnet import VGGSegnet
import LoadBatches

## Set Training Parameters

In [None]:
# --- dataset locations: directories holding the prepped training images
#     and their matching annotation (label) images ---
train_images_path = "data/images_prepped_train/"
train_segs_path = "data/annotations_prepped_train/"

# --- dataset / network geometry ---
n_classes = 10                     # number of segmentation classes predicted
input_height = input_width = 224   # network input size (height, width)

# --- training hyper-parameters ---
train_batch_size, epochs = 2, 1

## Build Model

In [None]:
# Build a fresh VGGSegnet for `n_classes` classes at the configured input
# size and compile it with categorical cross-entropy and Adadelta.
optimizer_name = 'adadelta'
model = VGGSegnet(n_classes, input_height=input_height, input_width=input_width)
model.compile(
    optimizer=optimizer_name,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

print("Model output shape", model.output_shape)

## Load Training Data

In [None]:
# Output dimensions reported by the model; the generator receives both the
# input and the output size, presumably because the label maps must match
# the decoder's output resolution (confirm in LoadBatches).
output_height, output_width = model.outputHeight, model.outputWidth

# Batch generator over the training images/annotations. Assumed to yield
# (image batch, label batch) pairs for training - TODO confirm in LoadBatches.
G = LoadBatches.imageSegmentationGenerator(
    train_images_path,
    train_segs_path,
    train_batch_size,
    n_classes,
    input_height,
    input_width,
    output_height,
    output_width,
)

## Train Model

In [None]:
# Train for `epochs` epochs, drawing 512 batches from the generator G per
# epoch (the second fit_generator argument is steps_per_epoch).
model.fit_generator(G, steps_per_epoch=512, epochs=epochs)

# Persist the trained network twice: a weights-only file and a full-model file.
model.save_weights('vggsegnet_weights_test.h5')
model.save('vggsegnet_model_test.h5')

## Evaluate Pre-trained Model (visualize predictions on the test set)

In [None]:
# Directory of prepped test images; keep the trailing slash, because the
# path is used as a prefix for the "*.png" glob pattern.
test_images = "data/images_prepped_test/"

In [None]:
# Rebuild the same VGGSegnet architecture and compile it as in training,
# then restore pre-trained weights from disk for prediction on the test set.
optimizer_name = 'adadelta'
modelFN = VGGSegnet

model = modelFN(n_classes, input_height=input_height, input_width=input_width)
model.compile(
    optimizer=optimizer_name,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# NOTE: this is a separate checkpoint, not the 'vggsegnet_weights_test.h5'
# file written by the training cell.
model.load_weights('weights/vggsegnet_weights.19.h5')

print("Model output shape", model.output_shape)

In [None]:
# Collect every PNG in the test directory, in lexicographic order so the
# images are processed deterministically.
images = sorted(glob.glob(test_images + "*.png"))
print("Test Set Size: ", len(images))

In [None]:
# output dimensions of the predicted segmentation map
output_height = model.outputHeight
output_width = model.outputWidth

# Fixed per-class display colours, in OpenCV's BGR channel order (they are
# shown with cv2.imshow). If the model has more classes than palette entries,
# pad with random colours so every class index has a colour instead of
# raising IndexError.
colors = [(250, 206, 135), (0, 255, 255), (0, 255, 0), (64, 64, 64), (255, 255, 255),
          (34, 139, 34), (0, 0, 0), (255, 0, 255), (0, 0, 255), (0, 0, 128)]
while len(colors) < n_classes:
    colors.append(tuple(random.randint(0, 255) for _ in range(3)))

# uint8 lookup table: row c holds the colour of class c.
palette = np.array(colors[:n_classes], dtype=np.uint8)

# process image one by one
for imgName in images:
    X = LoadBatches.getImageArr(imgName, input_width, input_height)

    # Per-pixel class scores reshaped to (output_height, output_width,
    # n_classes), then reduced to the most likely class index per pixel.
    pr = model.predict(np.array([X]))[0]
    pr = pr.reshape((output_height, output_width, n_classes)).argmax(axis=2)

    # Map class indices to colours in one vectorised lookup. Building the
    # image as uint8 lets cv2.imshow display it directly (float images are
    # interpreted as [0, 1] intensities), so no disk round-trip is needed.
    seg_img = palette[pr]

    # Upscale to the input size. Nearest-neighbour keeps every pixel an exact
    # class colour instead of blending the colours of neighbouring classes.
    seg_img = cv2.resize(seg_img, (input_width, input_height),
                         interpolation=cv2.INTER_NEAREST)

    # Save the current prediction (same path each time, so only the last
    # image's prediction remains on disk) and display it.
    cv2.imwrite('data/prediction.png', seg_img)
    cv2.imshow("predictions", seg_img)
    cv2.waitKey(0)

# close the display window
cv2.destroyAllWindows()