## Model Prediction Test

In [None]:
from keras.models import load_model
from sklearn.preprocessing import LabelEncoder
from keras.utils import np_utils

import numpy as np
import os
import glob
import cv2
from collections import Counter
import random


import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (10.0, 5.0)

def resize_image(image, shape=(500, 500)):
    """Resize *image* with OpenCV and return the resized array.

    ``shape`` is forwarded unchanged as ``cv2.resize``'s ``dsize`` argument,
    which OpenCV reads as ``(width, height)``.
    """
    return cv2.resize(image, shape)


def label_one_hot_encoding(labe_list):
    """Build a ``{label: one_hot_vector}`` lookup for the given labels.

    Labels are integer-encoded with scikit-learn's ``LabelEncoder`` and then
    expanded to one-hot rows with Keras' ``to_categorical``; row ``i`` belongs
    to ``labe_list[i]``.  The resulting mapping is printed and returned.
    """
    codes = LabelEncoder().fit_transform(labe_list)
    vectors = np_utils.to_categorical(codes)
    # fit_transform preserves input order, so zipping pairs each label with
    # its own row.
    encoding = dict(zip(labe_list, vectors))

    print("one hot encoding: ", encoding)
    return encoding


def load_samples(DATA_PATH, one_hot_encoding_dict):
    """Collect ``[one_hot_label, image_path]`` pairs for every labelled JPEG.

    For each directory in ``DATA_PATH`` and each label key of
    ``one_hot_encoding_dict``, every file matching ``<dir>/<label>*.jpg`` is
    paired with that label's one-hot vector.  Only paths are recorded; the
    images themselves are not opened here.

    Args:
        DATA_PATH: Iterable of directory paths to scan.  A single path (str or
            path-like) is also accepted and treated as a one-element list.
        one_hot_encoding_dict: Mapping from label (which is also used as the
            filename prefix) to its one-hot vector.

    Returns:
        list: ``[one_hot_vector, image_path]`` pairs, grouped by directory and
        then by label (in the dict's iteration order), with paths sorted
        inside each group so the ordering is reproducible across runs.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(DATA_PATH, (str, os.PathLike)):
        DATA_PATH = [DATA_PATH]

    samples = []  # [one_hot_label, image_path]
    for data_dir in DATA_PATH:
        print("loading data from: " + str(data_dir))
        for key in one_hot_encoding_dict:
            # glob.glob returns matches in arbitrary, filesystem-dependent
            # order; sort so sample indices are stable between runs/machines.
            # NOTE(review): the "<label>*" prefix pattern also matches files of
            # any label that starts with this one (e.g. "cat" vs "catfish").
            matches = sorted(glob.glob(os.path.join(data_dir, key + '*.jpg')))
            samples += [[one_hot_encoding_dict[key], path] for path in matches]
            print("total " + str(len(matches)) + " " + key + " images loaded.")

        print("total " + str(len(samples)) + " raw data samples loaded.")

    return samples




# Experiment configuration: class labels, training-image directories, and the
# directory holding saved models (*.h5).
LABEL_LIST = ["dog", "cat"]
TRAIN_DATA_PATH_LIST = ["./train"]
MODEL_PATH = "./models"

# Pass directory and pattern as separate components so os.path.join supplies
# the platform's path separator.
model_path_list = sorted(glob.glob(os.path.join(MODEL_PATH, '*.h5')))
print("models: ", model_path_list)

one_hot_encoding_dict = label_one_hot_encoding(LABEL_LIST)
print("label encoding is done.", one_hot_encoding_dict)

raw_samples = load_samples(TRAIN_DATA_PATH_LIST, one_hot_encoding_dict)
print("raw sample loaded.")


In [None]:
#model = load_model("dog_vs_cat_model.h5")

def predict(model, image, one_hot_encoding_dict):
    """Classify one image with a Keras-style model and return its label.

    Args:
        model: Object exposing ``input_shape`` (indices 1 and 2 are read as
            the spatial height and width) and ``predict(batch)``.
        image: Single image array indexed ``(height, width, ...)``.  It is
            resized to the model's spatial input size when that size is fixed
            and differs from the image's.
        one_hot_encoding_dict: Mapping from label to its one-hot vector, as
            produced by ``label_one_hot_encoding``.

    Returns:
        tuple: ``(raw_prediction, label)``. ``raw_prediction`` is the model
        output with singleton axes squeezed out. ``label`` is the key whose
        one-hot vector has its 1 at the arg-max score (first one on ties).

    Raises:
        ValueError: if the squeezed model output is not a 1-D score vector.
    """
    # Class index (position of the 1 in each one-hot vector) -> label.
    index_to_label = {int(np.argmax(v)): k for k, v in one_hot_encoding_dict.items()}

    # Assumes a channels-last image model, i.e. input_shape is
    # (batch, height, width, channels) -- TODO confirm for the saved models.
    height, width = model.input_shape[1], model.input_shape[2]
    image_input_shape = (height, width)

    print(image_input_shape, (image.shape[0], image.shape[1]))

    # A None dimension means the model accepts any size along that axis, so
    # only resize when both are fixed.  cv2.resize takes dsize as
    # (width, height), hence the swap relative to image_input_shape.
    if None not in image_input_shape and tuple(image.shape[:2]) != image_input_shape:
        image = resize_image(image, (width, height))

    # NOTE(review): no pixel scaling or BGR->RGB conversion happens here;
    # confirm this matches the preprocessing used at training time.
    batch = np.expand_dims(image, axis=0)

    raw_prediction = np.squeeze(model.predict(batch))
    if raw_prediction.ndim != 1:
        raise ValueError(
            "expected a 1-D vector of class scores, got shape "
            + str(raw_prediction.shape)
        )

    # argmax resolves ties deterministically instead of building a multi-hot
    # vector that matches no label.
    return raw_prediction, index_to_label[int(np.argmax(raw_prediction))]
    
    
    

# Disabled single-image sanity check, kept as an unused string literal so it
# can be pasted back into a cell.  It expects at least four entries in
# model_path_list and more than 24500 entries in raw_samples.
'''
model = load_model(model_path_list[3])

img = cv2.imread(raw_samples[24500][1])
img = resize_image(img, (200, 200))
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()

raw_prediction, label  = predict(model, img, one_hot_encoding_dict)
print("raw: ", raw_prediction, ", label: ", label)

'''

# Keep the demo small: only the first few test images and models are used.
MAX_TEST_IMAGES = 2
MAX_MODELS = 2

test_image_path = sorted(glob.glob(os.path.join("./test_image", '*.jpg')))
test_image_path = test_image_path[:MAX_TEST_IMAGES]
model_path_list = model_path_list[:MAX_MODELS]

# Key each model by its file name without directory or extension.
# os.path.splitext strips only the final extension, so names containing dots
# (e.g. "cnn.v1.h5" and "cnn.v2.h5") stay distinct instead of overwriting
# each other.
model_dict = {os.path.splitext(os.path.basename(path))[0]: path for path in model_path_list}
print(model_dict)


def predict_and_plot(image_path_list, model_dict, one_hot_encoding_dict, num_column=1):
    """Run every model on every image and show the results in one figure.

    Each subplot shows one image (converted from OpenCV's BGR to RGB for
    display), titled with one line per model giving its predicted label and
    raw output.

    Args:
        image_path_list: Paths of the images to classify.
        model_dict: Mapping from display name to a saved-model path loadable
            with ``load_model``; models are applied in sorted-name order.
        one_hot_encoding_dict: Label -> one-hot mapping passed to ``predict``.
        num_column: Number of subplot columns (default 1).

    Raises:
        OSError: if ``cv2.imread`` cannot read one of the images.
    """
    if not image_path_list:
        return  # nothing to plot; plt.subplots rejects a zero-row grid

    images = []
    for image_path in image_path_list:
        image = cv2.imread(image_path)
        # cv2.imread signals a missing/unreadable file by returning None.
        if image is None:
            raise OSError("could not read image: " + str(image_path))
        images.append(image)

    # Load each model once and apply it to all images, rather than reloading
    # every model for every image.
    title_lines = [[] for _ in images]
    for model_name in sorted(model_dict):
        model = load_model(model_dict[model_name])
        for lines, image in zip(title_lines, images):
            raw_prediction, label = predict(model, image, one_hot_encoding_dict)
            lines.append(model_name + " prediction: " + label + " raw output: " + str(raw_prediction) + "\n")

    # Round up so every image gets its own axis.
    num_row = (len(images) + num_column - 1) // num_column
    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(num_row, num_column, squeeze=False)
    axes = axes.ravel()

    for ax, image, lines in zip(axes, images, title_lines):
        ax.set_title("".join(lines))
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    # Blank out leftover grid cells when images don't fill the last row.
    for ax in axes[len(images):]:
        ax.axis("off")

    fig.tight_layout()  # keep multi-line titles from overlapping images
    plt.show()
    
    


# Classify the selected test images with every selected model and plot them.
predict_and_plot(
    image_path_list=test_image_path,
    model_dict=model_dict,
    one_hot_encoding_dict=one_hot_encoding_dict,
)