In [1]:
import numpy as np
import datetime
import os
import pickle
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing import image
from keras.layers import Dropout, Flatten, Dense
from keras.applications import ResNet50
from keras.models import Model, Sequential
from keras.layers import Dense, GlobalAveragePooling2D
from keras import backend as K
from keras.applications.resnet50 import preprocess_input
from keras import optimizers

import matplotlib as mpl

#  ######################################################
#  #### Matplotlib X display error - removing for server#
#  ######################################################
mpl.use('Agg')  # This has to run before pyplot import

import matplotlib.pyplot as plt

import sys
import pandas as pd
import numpy as np
import itertools


from sklearn.metrics import confusion_matrix

Using TensorFlow backend.


In [10]:
###################################################################
#  Getting main data directory
###############################################################

# Optional command-line arguments (when run as a script):
#   argv[1] -> main data directory, argv[2] -> model name to be saved.
# Guarded so that a short argv does not abort the run with IndexError.
main_data_dir = sys.argv[1] if len(sys.argv) > 1 else None
model_name = sys.argv[2] if len(sys.argv) > 2 else None

my_file_name = model_name  # model name to be saved (e.g. "8_1_pytorch_resnet18_v1")

###########################################

#  Set parameters here
# NOTE(review): data_dir is hard-coded; main_data_dir from argv is not used here.
data_dir = "../../../data/data_generated_medicotask_v1"
model_dir = os.path.join(data_dir, 'keras_models')
plot_dir = os.path.join(data_dir, 'keras_plots')

# Checkpoint to evaluate, relative to model_dir (overrides any argv value above).
model_name = "backups/8_1_keras_resnet50_v2"

# The plot name must be a bare file name: model_name contains a sub-directory
# ("backups/"), and 'cm_' + model_name would make savefig target a
# non-existent "cm_backups/" folder inside plot_dir.
cm_plot_name = 'cm_' + os.path.basename(model_name)

batch_size = 16

trgt_sz = 224  # side length (pixels) of the square images produced by the data generator

In [11]:
#########################################
#  Managing Directory
#########################################
# Create the plot output directory, including any missing parents.
# exist_ok=True replaces the exists()/mkdir() pair, which could race and which
# failed outright when a parent directory did not exist yet.
os.makedirs(plot_dir, exist_ok=True)


In [12]:
# Evaluation images are read from the "validation" split under data_dir.
test_data_dir = '{}/validation'.format(data_dir)

In [13]:
# Only ResNet50's input preprocessing is applied; no augmentation options are set.
test_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
)

In [14]:
# shuffle=False keeps the sample order fixed, so predictions can be compared
# position-by-position with test_generator.classes.
test_generator = test_datagen.flow_from_directory(
    test_data_dir,
    target_size=(trgt_sz, trgt_sz),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False,
)

Found 1067 images belonging to 16 classes.


In [15]:
# ResNet50 backbone without its ImageNet classifier, plus a new classification head.
# weights=None: model.load_weights() later overwrites every weight from the saved
# checkpoint, so downloading the ImageNet weights here is wasted work (and needs network).
base_model = ResNet50(weights=None, include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)  # average each feature map over its spatial extent
x = Dense(1024, activation='relu')(x)
# Output width must equal the number of classes the checkpoint was trained on
# (the validation directory holds 16 class folders).
predictions = Dense(16, activation='softmax')(x)

In [16]:
# Full network: ResNet50 backbone feeding the classification head defined above.
model = Model(
    inputs=base_model.input,
    outputs=predictions,
)

In [17]:
# Restore the trained checkpoint stored at <model_dir>/<model_name>.
weights_path = os.path.join(model_dir, model_name)
model.load_weights(weights_path)

In [18]:
# Rewind the iterator and request exactly one pass over the data, so the rows of
# `probabilities` line up one-to-one with test_generator.classes.
test_generator.reset()
probabilities = model.predict_generator(test_generator, steps=len(test_generator))
assert len(probabilities) == test_generator.samples, (
    'prediction count does not match the number of validation images')

In [19]:
# Predicted class index per image: arg-max over the class-probability axis.
# `probabilities` is already a NumPy array; `.data` (a PyTorch idiom) would hand
# np.argmax a raw memoryview buffer instead of the array itself.
predicted = np.argmax(probabilities, axis=1)

In [24]:
# Display the class-name -> integer-label mapping used by the generator.
test_generator.class_indices

{'blurry-nothing': 0,
 'colon-clear': 1,
 'dyed-lifted-polyps': 2,
 'dyed-resection-margins': 3,
 'esophagitis': 4,
 'instruments': 5,
 'normal-cecum': 6,
 'normal-pylorus': 7,
 'normal-z-line': 8,
 'out-of-patient': 9,
 'polyps': 10,
 'retroflex-rectum': 11,
 'retroflex-stomach': 12,
 'stool-inclusions': 13,
 'stool-plenty': 14,
 'ulcerative-colitis': 15}

In [23]:
# Display the predicted label index for each validation image.
predicted

array([ 0,  0,  0, ...,  6,  6, 10])

In [20]:
# Rows = true labels, columns = predicted labels. Passing `labels` fixes the
# matrix to every class in class_indices, in index order. It then stays aligned
# with the class-name ticks even if some class is absent from both truth and
# predictions.
cm = confusion_matrix(
    test_generator.classes,
    predicted,
    labels=np.arange(len(test_generator.class_indices)),
)

In [21]:
# Display the raw (un-normalized) confusion matrix.
cm

array([[ 35,   0,   0,   0,   0,   0,   0,   1,   0,   0,   0,   0,   0,
          0,   0,   0],
       [  0,  25,   0,   0,   0,   0,   7,   0,   0,   0,  22,   0,   0,
          0,   0,   0],
       [  0,   0,  48,   1,   0,   0,   8,   0,   0,   0,  33,   2,   0,
          0,   0,   0],
       [  0,   0,  57,  15,   0,   0,   4,   0,   0,   0,   5,   3,   0,
          0,   0,   0],
       [  0,   0,   0,   0,  45,   0,   0,  29,  15,   0,   0,   0,   0,
          0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   1,   0,   0,   7,   0,   0,
          0,   0,   0],
       [  0,   0,   0,   0,   0,   0,  72,   0,   0,   0,  12,   0,   0,
          0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  88,   0,   0,   0,   0,   0,
          0,   0,   0],
       [  0,   0,   0,   0,   7,   0,   0,  50,  31,   0,   0,   0,   0,
          0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,
          0,   0,   0],
       [  0,   0,   0,   0,   

In [14]:
# Start from the generator's class-name -> label-index mapping (a dict).
class_names = test_generator.class_indices

In [15]:
# Class names ordered by their integer label, so position i names label i.
# Built from test_generator directly rather than from the previous `class_names`,
# so the cell is idempotent. Re-running it after `class_names` has become an
# ndarray no longer fails on `.keys()`.
class_names = sorted(test_generator.class_indices,
                     key=test_generator.class_indices.get)

In [16]:
# NumPy array of class names, used as the plot's tick labels.
class_names = np.array(class_names)

In [17]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues,
                          plt_size=(10, 10),
                          save_path=None):
    """Print and plot a confusion matrix, then save the figure to disk.

    Parameters
    ----------
    cm : numpy.ndarray
        Square confusion matrix, indexed ``cm[true, predicted]`` (the layout
        produced by ``sklearn.metrics.confusion_matrix``).
    classes : sequence of str
        Tick labels, one per row/column of ``cm``, in label-index order.
    normalize : bool, default False
        If True, divide each row by its total so cells show per-class rates.
        Rows with no samples are left at 0 instead of becoming NaN.
    title : str
        Figure title.
    cmap : matplotlib colormap
        Colormap for the heat map.
    plt_size : (float, float), default (10, 10)
        Figure size in inches.
    save_path : str, optional
        Output file. Defaults to ``os.path.join(plot_dir, cm_plot_name)``,
        using the module-level settings.

    Returns
    -------
    None
    """
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Classes with no true samples would otherwise give 0/0 -> NaN.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    # Fresh figure per call: drawing on the current figure would stack a second
    # image and colour bar over any earlier plot.
    fig, ax = plt.subplots(figsize=plt_size)
    heatmap = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(heatmap, ax=ax)
    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    fmt = '.2f' if normalize else 'd'
    # Cells above half the maximum get white text so they stay readable on dark blue.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    # Lay out only after all labels exist, so the rotated tick labels are not clipped.
    fig.tight_layout()

    if save_path is None:
        save_path = os.path.join(plot_dir, cm_plot_name)
    fig.savefig(save_path)

In [19]:
# Plot the raw counts and save the figure to the configured plot directory.
plot_confusion_matrix(cm=cm,
                      classes=class_names,
                      title='my confusion matrix')

Confusion matrix, without normalization
[[ 35   0   0   0   0   0   0   1   0   0   0   0   0   0   0   0]
 [  0  25   0   0   0   0   7   0   0   0  22   0   0   0   0   0]
 [  0   0  48   1   0   0   8   0   0   0  33   2   0   0   0   0]
 [  0   0  57  15   0   0   4   0   0   0   5   3   0   0   0   0]
 [  0   0   0   0  45   0   0  29  15   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   1   0   0   7   0   0   0   0   0]
 [  0   0   0   0   0   0  72   0   0   0  12   0   0   0   0   0]
 [  0   0   0   0   0   0   0  88   0   0   0   0   0   0   0   0]
 [  0   0   0   0   7   0   0  50  31   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   1   0   0   0]
 [  0   0   0   0   0   0   8   2   0   0 113   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  19  29   0   0   0   0]
 [  0   0   0   0   0   0   0  26   0   0   4   1  49   0   0   0]
 [  0   1   0   0   0   0   8   0   0   0  16   0   0   1   0   0]
 [  0   0   0   0   0 