In [1]:
import numpy as np
import h5py

In [3]:
def load_OASIS():
    with h5py.File('..\..\Datasets\OASIS_balanced.h5', 'r') as hdf:

        G1 = hdf.get('Train Data')
        trainX = np.array(G1.get('trainX'))
        trainY = np.array(G1.get('trainY'))
        G2 = hdf.get('Test Data')
        testX = np.array(G2.get('testX'))
        testY = np.array(G2.get('testY'))

        return trainX, trainY, testX, testY

In [4]:
def load_ADNI():
    with h5py.File('..\..\Datasets\ADNI_enhanced.h5', 'r') as hdf:

        G1 = hdf.get('Train Data')
        trainX = np.array(G1.get('x_train'))
        trainY = np.array(G1.get('y_train'))
        G2 = hdf.get('Test Data')
        testX = np.array(G2.get('x_test'))
        testY = np.array(G2.get('y_test'))

        return trainX, trainY, testX, testY

In [5]:
dataset = 'ADNdI'
#dataset = 'ADsNI'
# read the data which is also normalized.
if dataset == 'ADNI':
    x_train, y_train, x_test, y_test = load_ADNI()
else:
    x_train, y_train, x_test, y_test = load_OASIS()

print(x_train.shape, y_train.shape, x_test.shape,y_test.shape)

(8192, 176, 176) (8192,) (2560, 176, 176) (2560,)


In [6]:
#Dataset ready for deep learning
from tensorflow.keras.utils import to_categorical
#x_train = np.expand_dims(x_train, axis=-1)
x_train = np.repeat(x_train, 3, axis=3)

#x_test = np.expand_dims(x_test, axis=-1)
x_test = np.repeat(x_test, 3, axis=3)

y_train_cat = to_categorical(y_train, num_classes=4)
y_test_cat = to_categorical(y_test, num_classes=4)

print("Test",x_test.shape)
print ("Train",x_train.shape)
print(y_train.shape)
print (y_train_cat.shape)
print (y_test_cat.shape)
#print(np.max(x_train[0])) #to ensure that the values are within the range [0,1]


AxisError: axis 3 is out of bounds for array of dimension 3

In [None]:
print(x_test.shape)
print()
epochs = 10 #hyperparameter
num_classes=4


In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Conv2D, Dropout, MaxPooling2D, Activation, BatchNormalization
import tensorflow as tf
resnet_model = Sequential()


pretrained_model= tf.keras.applications.InceptionResNetV2(include_top=False,
                   #input_shape=(176,176,3),
                   input_shape=(218,182,3),
                   pooling='avg',classes=4,
                   weights='imagenet')

for layer in pretrained_model.layers:
        layer.trainable=False

resnet_model.add(pretrained_model)
resnet_model.add(Dropout(0.5))
resnet_model.add(Flatten())
resnet_model.add(BatchNormalization())
resnet_model.add(Dense(2048,kernel_initializer='he_uniform'))
resnet_model.add(BatchNormalization())
resnet_model.add(Activation('relu'))
resnet_model.add(Dropout(0.5))
resnet_model.add(Dense(1024,kernel_initializer='he_uniform'))
resnet_model.add(BatchNormalization())
resnet_model.add(Activation('relu'))
resnet_model.add(Dropout(0.5))
resnet_model.add(Dense(4,activation='softmax'))
resnet_model.summary()


In [None]:
METRIC = tf.keras.metrics.AUC(name = 'acc')
from tensorflow.keras.optimizers import Adam # optimizer hyperparameter
resnet_model.compile(optimizer=Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=METRIC) #learning rate hyperparameter.


In [None]:
filepath = 'best_weights.hdf5'
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

earlystopping = EarlyStopping(monitor = 'acc', 
                              mode = 'max' , 
                              patience = 15,
                              verbose = 1)

checkpoint    = ModelCheckpoint(filepath, 
                                monitor = 'acc', 
                                mode='max', 
                                save_best_only=True, 
                                verbose = 1)
callback_list = [earlystopping, checkpoint]



In [None]:
#training
#batch_size hyperparamter.
history=resnet_model.fit(x_train, y_train_cat, batch_size = 128, epochs = 200, verbose = 1, validation_data = (x_test, y_test_cat), callbacks=callback_list)


In [None]:
import pandas as pd
df = pd.DataFrame(history.history)
df.head()
# losses[['acc','val_acc']].plot()
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2,1,sharex=True, figsize=(10,8))
ax[0].plot(df['loss'],label = 'Training Loss')
ax[0].plot(df['val_loss'],label = 'Validation Loss')
ax[1].plot(df['acc'],label = 'Training Accuracy')
ax[1].plot(df['val_acc'],label = 'Validation Accuracy')
ax[0].legend(loc='best',prop={'size':12})
ax[1].legend(loc='best',prop={'size':12})
ax[0].minorticks_on()
ax[1].minorticks_on()
ax[1].set_xlabel('Epochs', fontsize=12)
ax[1].set_ylabel('Accuracy', fontsize=12)
ax[0].set_ylabel('Loss', fontsize=12)
ax[0].grid()
ax[1].grid()
plt.show()

#saving the results
df.to_csv('results.csv')




# print(history.history['val_loss'])
# print(history.history['val_acc'])
# print(history.history['acc'])
# print(history.history['loss'])

In [None]:
from sklearn.metrics import multilabel_confusion_matrix
#%%
prediction = resnet_model.predict_classes(x_test)

mcm = multilabel_confusion_matrix(y_true=y_test, y_pred=prediction, labels=[0,1,2,3], samplewise=False)
tn = mcm[:,0,0]
tp = mcm[:,1,1]
fp = mcm[:,0,1]
fn = mcm[:,1,0]
specificity = tn/(tn+fp)
print("Specificity or TNR", np.mean(specificity))

sensitivity = tp/(tp+fn)
print("Sensitivity or TPR or Recall", np.mean(sensitivity))

print("FNR ", 1-np.mean(sensitivity))

print("FPR", 1-np.mean(specificity))

from sklearn import metrics
print("Accuracy = ", metrics.accuracy_score(y_test, prediction))
#%%
# use model to predict probability that given y value is 1
y_pred_proba = resnet_model.predict_proba(x_test)

#calculate AUC of model
auc = metrics.roc_auc_score(y_test, y_pred_proba, multi_class='ovr')

#print AUC score
print("AUC =",auc)

from sklearn.metrics import confusion_matrix
import seaborn as sns
cf_matrix = confusion_matrix(y_test, prediction)
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(7, 6))
ax= plt.subplot()
sns.heatmap(cf_matrix, annot=True, cmap='Blues', cbar=False, linewidth=0.5,linecolor="black",fmt='')
ax.set_xlabel('Predicted Label', fontsize=12)
ax.set_ylabel('True Label', fontsize=12)