In [None]:
from keras.models import model_from_json
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools

In [2]:
test_dir = "./dataset/val_image/"
inception_json = "./model/inception_model.json"
checkpoint_path_2 = './model/checkpoint_2.h5'

In [3]:
def main():
    test_data_gen = ImageDataGenerator(rescale=1. / 255)
    test_generator = test_data_gen.flow_from_directory(
        test_dir,
        target_size=(299, 299),
        batch_size=1,
        class_mode="categorical"
    )
    file_names = test_generator.filenames
    nb_samples = len(file_names)

    json_file = open(inception_json, 'r')
    loaded_model_json = json_file.read()
    json_file.close()
    loaded_model = model_from_json(loaded_model_json)
    loaded_model.load_weights(checkpoint_path_2)

    # evaluate loaded model on test data
    sgd = SGD(lr=0.0001, momentum=0.9, nesterov=True)
    loaded_model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
    score = loaded_model.predict_generator(test_generator, steps=nb_samples)

    for i in range(0, nb_samples):
        idx = np.argmax(score[i])
        print(idx, score[i][idx])
        
    y_pred = np.argmax(score, axis=1)
    cm = confusion_matrix(test_generator.classes, y_pred)
    plot_confusion_matrix(cm, idx, normalize=True, title="Normalize confusion matrix")
    plt.show()

In [None]:


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = ".2f" if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalaligment="center",
                 color="white" if cm[i, j] > thresh else "black")
        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")

In [4]:
if __name__ == '__main__':
    main()

Found 300 images belonging to 100 classes.
46 0.51191705
66 0.3933231
51 0.9956701
75 0.43075076
97 0.76174384
29 0.67993844
65 0.57227725
77 0.70035625
18 0.44198772
78 0.41699308
36 0.75286645
0 0.98290145
89 0.21628495
71 0.99998975
77 0.6000119
4 0.37337434
19 0.5687937
19 0.96824306
36 0.5066966
44 0.595791
82 0.6916528
80 0.5429663
75 0.64374906
44 0.5821129
19 0.27172455
52 0.47531354
53 0.9955852
70 0.9317397
70 0.9990716
19 0.40822777
18 0.9514206
67 0.9707619
60 0.6843123
41 0.33773026
63 0.5767306
3 0.3447093
12 0.51234686
81 0.32455313
52 0.8661684
19 0.77063835
59 0.20143333
57 0.85878354
77 0.8079099
85 0.57485896
71 0.9997392
72 0.9972173
72 0.5448835
68 0.9997651
37 0.52963936
36 0.71400917
78 0.94628304
66 0.44700345
18 0.92396396
48 0.8497892
73 0.80848134
3 0.49802136
44 0.9707965
36 0.68528354
85 0.47348684
27 0.7626124
70 0.5629862
72 0.9999484
19 0.63772595
6 0.84269726
82 0.8743816
58 0.91439694
70 0.9984219
93 0.44089243
76 0.5372803
67 0.41863278
74 0.4801778
6