In [None]:
# Colab-only setup: mount Google Drive and make the project folder the working
# directory, so the relative paths used later (checkpoint prefixes, the
# plot_model PNG) resolve inside it. `%cd` is an IPython magic, so this cell
# only runs under IPython/Colab, not as a plain Python script.
from google.colab import drive
drive.mount('/content/drive')
%cd "/content/drive/My Drive/Xray/"

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1jTYiDXauUNccfsOBLyb8M2KSfb1N8T8A/Xray


In [None]:
import numpy as np
import time

import PIL.Image as Image
import matplotlib.pylab as plt

import tensorflow as tf
import tensorflow_hub as hub

from sklearn.metrics import confusion_matrix



def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    """
    Plot a confusion matrix as a colour-coded grid with per-cell annotations.

    Accuracy and misclassification rate are computed from the raw counts
    (before any normalisation) and shown under the x axis.

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix
                  (square array-like; rows are true labels, columns are
                  predicted labels)

    target_names: display names of the classes, in the same order as the
                  rows/columns of cm, for example: ['high', 'medium', 'low'].
                  If None, no tick labels are drawn.

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues; defaults to 'Blues'

    normalize:    If False, plot the raw counts
                  If True, plot each row as proportions of that row's total;
                  a row with no samples (all zeros) is shown as 0 rather than NaN

    Returns
    -------
    None. The figure is drawn and displayed with plt.show().

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citation
    --------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import itertools

    # Accept plain nested lists as well as numpy arrays.
    cm = np.asarray(cm)

    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class that never occurs among the true labels has an all-zero row;
        # leave it at 0 instead of producing NaN (and a warning) from 0/0.
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    # Cells darker than this threshold get white text so it stays readable.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_fmt = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cell_fmt.format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    # Lay out only after the axis labels exist, so the two-line x label is
    # taken into account and not clipped.
    plt.tight_layout()
    plt.show()

def fix_dims(tens):
  """Repeat every entry of `tens` three times along its last axis.

  A single-channel array of shape (..., 1) therefore becomes (..., 3), e.g. to
  present grayscale images to a model expecting 3 channels (presumably the
  intent of the commented-out preprocessing_function usage — confirm).
  """
  tripled = np.repeat(tens, repeats=3, axis=-1)
  return tripled

# Fetch the ImageNet class-name list (one name per line) via Keras' get_file.
labels_path = tf.keras.utils.get_file(
    'ImageNetLabels.txt',
    'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
# Read it inside a context manager so the file handle is closed, not leaked.
with open(labels_path, encoding="utf-8") as labels_file:
    imagenet_labels = np.array(labels_file.read().splitlines())

# Spatial size (height, width) the model input is built around.
IMAGE_SHAPE = (224, 224)

# Number of target classes (the disabled training code names them
# Pneumonia, COVID and Regular — confirm against the dataset folders).
NUM_CLASSES = 3

"""
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)#, preprocessing_function=fix_dims)
image_data = image_generator.flow_from_directory("CT-Data", target_size=IMAGE_SHAPE, color_mode="rgb", batch_size=128, shuffle=True)#grayscale
image_generator2 = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)#, preprocessing_function=fix_dims)
val_data = image_generator2.flow_from_directory("CT-Data-Youngho", target_size=IMAGE_SHAPE, color_mode="rgb", batch_size=128, shuffle=False)

for image_batch, label_batch in image_data:
  print("Image batch shape: ", image_batch.shape)
  print("Label batch shape: ", label_batch.shape)
  break

"""

# Train-from-scratch switch (read by the training code, which is currently disabled).
first_train = True

# Weight-checkpoint path prefixes, one per ResNet branch.
ckpt_dir1 = "ckpt1/cp.ckpt"

ckpt_dir2 = "ckpt2/cp.ckpt"

ckpt_dir3 = "ckpt3/cp.ckpt"

#feature_extractor_layer = hub.KerasLayer("https://tfhub.dev/google/imagenet/pnasnet_large/classification/4")

RESNET_V2_50_URL = "https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4"


def _resnet_features(inputs):
    """Apply a new, frozen ResNet-V2-50 feature-vector hub layer to `inputs`.

    Every call constructs its own KerasLayer, so each branch gets a separate
    layer instance.
    """
    # Optional extra: arguments=dict(batch_norm_momentum=0.997)
    layer = hub.KerasLayer(RESNET_V2_50_URL, trainable=False,
                           input_shape=IMAGE_SHAPE + (3,))
    return layer(inputs)


img_inputs = tf.keras.Input(shape=IMAGE_SHAPE + (3,))

# Three ResNet-V2-50 branches over the same input. Despite their names these
# variables hold the branch OUTPUT tensors, not the hub layers themselves.
feature_extractor_layer1 = _resnet_features(img_inputs)
feature_extractor_layer2 = _resnet_features(img_inputs)
feature_extractor_layer3 = _resnet_features(img_inputs)

# The hub models are headless, so give each branch its own softmax classifier.
dense1 = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(feature_extractor_layer1)

dense2 = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(feature_extractor_layer2)

dense3 = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(feature_extractor_layer3)

# Concatenate the three per-branch class distributions and pass them through a
# ReLU MLP (intended to be trained as the last training step).
concat = tf.keras.layers.Concatenate()([dense1, dense2, dense3])

relu1 = tf.keras.layers.Dense(units=512, activation='relu')(concat)

relu2 = tf.keras.layers.Dense(units=256, activation='relu')(relu1)

relu3 = tf.keras.layers.Dense(units=192, activation='relu')(relu2)

relu4 = tf.keras.layers.Dense(units=128, activation='relu')(relu3)

# Final head with one unit per class and no activation, i.e. it emits logits;
# a loss consuming it should be built with from_logits=True.
output = tf.keras.layers.Dense(units=NUM_CLASSES)(relu4)

model = tf.keras.Model(inputs=img_inputs, outputs=output, name="triple-resnet")

# Writes a diagram of the model graph to the working directory.
tf.keras.utils.plot_model(model, "triple-resnet-structure.png", show_shapes=True)

"""
model1 = tf.keras.Sequential([
  feature_extractor_layer1,
  tf.keras.layers.Dense(image_data.num_classes, activation="softmax")
])

model2 = tf.keras.Sequential([
  feature_extractor_layer2,
  tf.keras.layers.Dense(image_data.num_classes, activation="softmax")
])

model3 = tf.keras.Sequential([
  feature_extractor_layer3,
  tf.keras.layers.Dense(image_data.num_classes, activation="softmax")
])

print("Model 1 (Pneumonia):")
model1.summary()

print("Model 2 (COVID):")
model2.summary()

print("Model 3 (Regular):")
model3.summary()
"""

"""
model1.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

model2.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

model3.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

if first_train:

  class CollectBatchStats(tf.keras.callbacks.Callback):
    def __init__(self):
      self.batch_losses = []
      self.batch_acc = []

    def on_train_batch_end(self, batch, logs=None):
      self.batch_losses.append(logs['loss'])
      self.batch_acc.append(logs['acc'])
      self.model.reset_metrics()

  cp_callback1 = tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_dir1,
                                                  save_weights_only=True,
                                                  verbose=1)

  cp_callback1 = tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_dir2,
                                                  save_weights_only=True,
                                                  verbose=1)

  cp_callback1 = tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_dir3,
                                                  save_weights_only=True,
                                                  verbose=1)

  #Train model 1 (Pneumonia)    

  steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)

  batch_stats_callback = CollectBatchStats()

  model1.save_weights(ckpt_dir1.format(epoch=0))

  history = model1.fit(image_data, epochs=10,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=[batch_stats_callback, cp_callback1])

  plt.figure()
  plt.ylabel("Loss")
  plt.xlabel("Training Steps")
  plt.ylim([0,2])
  plt.plot(batch_stats_callback.batch_losses)


  plt.figure()
  plt.ylabel("Accuracy")
  plt.xlabel("Training Steps")
  plt.ylim([0,1])
  plt.plot(batch_stats_callback.batch_acc)

  #Train model 2 (COVID)    

  steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)

  batch_stats_callback = CollectBatchStats()

  model2.save_weights(ckpt_dir2.format(epoch=0))

  history = model2.fit(image_data, epochs=10,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=[batch_stats_callback, cp_callback2])

  plt.figure()
  plt.ylabel("Loss")
  plt.xlabel("Training Steps")
  plt.ylim([0,2])
  plt.plot(batch_stats_callback.batch_losses)


  plt.figure()
  plt.ylabel("Accuracy")
  plt.xlabel("Training Steps")
  plt.ylim([0,1])
  plt.plot(batch_stats_callback.batch_acc)


  #Train model 3 (Regular)    

  steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)

  batch_stats_callback = CollectBatchStats()

  model3.save_weights(ckpt_dir3.format(epoch=0))

  history = model3.fit(image_data, epochs=10,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=[batch_stats_callback, cp_callback3])

  plt.figure()
  plt.ylabel("Loss")
  plt.xlabel("Training Steps")
  plt.ylim([0,2])
  plt.plot(batch_stats_callback.batch_losses)


  plt.figure()
  plt.ylabel("Accuracy")
  plt.xlabel("Training Steps")
  plt.ylim([0,1])
  plt.plot(batch_stats_callback.batch_acc)

else:
  model1.load_weights(tf.train.latest_checkpoint(ckpt_dir1))
  model2.load_weights(tf.train.latest_checkpoint(ckpt_dir2))
  model3.load_weights(tf.train.latest_checkpoint(ckpt_dir3))
"""





"""
label_id = np.argmax(label_batch, axis=-1)

class_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])
predicted_batch = model.predict(val_data)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]
"""




"""
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  color = "green" if predicted_id[n] == label_id[n] else "red"
  plt.title(predicted_label_batch[n].title(), color=color)
  plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")
"""


#label_id = val_data.classes
#print("Predictions:", predicted_id)
#print("Ground truth:", label_id)

#cm = confusion_matrix(label_id, predicted_id)  # sklearn argument order is (y_true, y_pred)

#plot_confusion_matrix(cm, class_names)

'\nplt.figure(figsize=(10,9))\nplt.subplots_adjust(hspace=0.5)\nfor n in range(30):\n  plt.subplot(6,5,n+1)\n  plt.imshow(image_batch[n])\n  color = "green" if predicted_id[n] == label_id[n] else "red"\n  plt.title(predicted_label_batch[n].title(), color=color)\n  plt.axis(\'off\')\n_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")\n'