**Importing libraries and dependencies**

In [None]:
import cv2
from glob import glob
import tensorflow as tf
import tfimm
import timm
from builtins import range, input
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input, Dense, Flatten, Concatenate
from tensorflow.keras.models import Model, load_model
from sklearn.metrics import confusion_matrix, roc_curve, auc, ConfusionMatrixDisplay

In [12]:
## WM - Coronal Slices
# Placeholder directories: replace with the folders that hold the AD and CN images.
dement_path_y = "path to Alzheimer's disease samples"
nondement_path_y = "path to cognitively normal samples"

# Keras input tensor for 224x224 three-channel images; the model cells below build on it.
inp_y=Input(shape=(224, 224, 3))

## Hyperparameters
epochs = 20
batch_size = 32

# glob() yields paths in arbitrary, filesystem-dependent order. The train/val/test
# split further down slices the loaded images by position, so the listing is sorted
# to make that split identical from run to run and machine to machine.
NonDemfiles_y = sorted(glob( nondement_path_y +'/*' ))
Demfiles_y = sorted(glob( dement_path_y + '/*' ))

In [None]:
# Quick sanity check of what glob found for each cohort: a short sample and the count.
for cohort_tag, cohort_paths in (("CN", NonDemfiles_y), ("AD", Demfiles_y)):
    print(f"First 5 {cohort_tag} Files: ", cohort_paths[0:5])
    print("Total Count: ", len(cohort_paths))

**Image data preprocessing**

In [15]:
def _load_rgb_images(paths, label, size=(224, 224)):
    """Read every file in ``paths`` as an RGB image resized to ``size``.

    Args:
        paths: iterable of image file paths.
        label: class label recorded once per loaded image.
        size: (width, height) passed to ``cv2.resize``.

    Returns:
        (images, labels): two lists of equal length, in the order of ``paths``.

    Raises:
        ValueError: if OpenCV cannot decode one of the files.
    """
    images, labels = [], []
    for path in paths:
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising; without
        # this check the failure only surfaces as an opaque error inside cvtColor.
        if image is None:
            raise ValueError(f"Could not read image file: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads as BGR
        image = cv2.resize(image, size)
        images.append(image)
        labels.append(label)
    return images, labels


Dem_images_y, Dem_labels_y = _load_rgb_images(Demfiles_y, 'AD')
NonDem_images_y, NonDem_labels_y = _load_rgb_images(NonDemfiles_y, 'CN')

In [None]:
def plot_images(images, title):
    """Show the first two entries of ``images`` side by side under ``title``.

    Args:
        images: indexable collection of images accepted by ``imshow``;
            positions 0 and 1 are displayed.
        title: text used as the figure's super-title.
    """
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=[4, 4], facecolor=(1, 1, 1))

    for position, panel in enumerate(ax.flat):
        panel.imshow(images[position])
        panel.set_axis_off()

    plt.suptitle(title, fontsize=24)
    # rect leaves the top 10% of the figure free for the super-title.
    plt.tight_layout(pad=0.2, rect=[0, 0, 1, 0.9])
    plt.show()

for scans, caption in ((Dem_images_y, 'AD Scans y-axis'), (NonDem_images_y, 'CN Scans y-axis')):
    plot_images(scans, caption)

In [639]:
# Stack each image list into one array and divide by 255.
# NOTE(review): the names are rebound in place, so running this cell a second time
# divides the already-scaled data by 255 again.
Dem_images_y, NonDem_images_y = (
    np.array(image_stack) / 255 for image_stack in (Dem_images_y, NonDem_images_y)
)

**Splitting the dataset into training, validation, and test sets at the subject level**

In [640]:
#seed = 42  #Fold i ; i = 1,2,3,4,5 ; seed i = 42, 40, 38, 44, 46

#Fold i
# Each class is cut into contiguous train / validation / test index ranges.
# NOTE(review): slicing past the end of an array silently returns fewer samples,
# so these hard-coded bounds assume at least 2100 AD and 2550 CN images are loaded.
Dem_x_train_y_axis = Dem_images_y [0:1499]
Dem_y_train_y_axis = Dem_labels_y [0:1499]

Dem_x_val_y_axis = Dem_images_y [1499:1650]
Dem_y_val_y_axis = Dem_labels_y [1499:1650]

Dem_x_test_y_axis = Dem_images_y [1650:2100]
Dem_y_test_y_axis = Dem_labels_y [1650:2100]

NonDem_x_train_y_axis = NonDem_images_y [0:1769]
NonDem_y_train_y_axis = NonDem_labels_y [0:1769]

NonDem_x_val_y_axis = NonDem_images_y [1769:2070]
NonDem_y_val_y_axis = NonDem_labels_y [1769:2070]

NonDem_x_test_y_axis = NonDem_images_y [2070:2550]
NonDem_y_test_y_axis = NonDem_labels_y [2070:2550]

# Within every split the CN samples come first, followed by the AD samples.
X_train_y_axis = np.concatenate((NonDem_x_train_y_axis, Dem_x_train_y_axis), axis=0)
X_val_y_axis = np.concatenate((NonDem_x_val_y_axis, Dem_x_val_y_axis), axis=0)
X_test_y_axis = np.concatenate((NonDem_x_test_y_axis, Dem_x_test_y_axis), axis=0)
y_train_y_axis = np.concatenate((NonDem_y_train_y_axis, Dem_y_train_y_axis), axis=0)
y_val_y_axis = np.concatenate((NonDem_y_val_y_axis, Dem_y_val_y_axis), axis=0)
y_test_y_axis = np.concatenate((NonDem_y_test_y_axis, Dem_y_test_y_axis), axis=0)

# Fit the label encoder once, on the training labels, and reuse it for every split.
# A separate encoder per split would assign class indices independently, which stops
# agreeing as soon as one split lacks a class. num_classes=2 likewise pins the
# one-hot width instead of letting it be inferred from the labels present.
label_encoder = LabelBinarizer().fit(y_train_y_axis)

y_train_y_axis = to_categorical(label_encoder.transform(y_train_y_axis), num_classes=2)
y_val_y_axis = to_categorical(label_encoder.transform(y_val_y_axis), num_classes=2)
y_test_y_axis = to_categorical(label_encoder.transform(y_test_y_axis), num_classes=2)

**CNN (ResNet)**

In [None]:
# ResNet-18 backbone from tfimm, created with nb_classes=0 and kept frozen:
# every backbone layer is marked non-trainable, so only the Dense head below learns.
base_model_cnn = tfimm.create_model("resnet18", pretrained="timm", nb_classes=0)
for layer_cnn in base_model_cnn.layers:
    layer_cnn.trainable = False

# Backbone features -> flatten -> 2-way softmax head.
out_cnn = Flatten()(base_model_cnn(inp_y))
output = Dense(2, activation='softmax')(out_cnn)

model_cnn = Model(inputs=[inp_y], outputs=output)
model_cnn.summary()

**ViT**

In [None]:
# ViT-Tiny (patch 16, 224 input) backbone from tfimm, created with nb_classes=0 and
# kept frozen: every backbone layer is marked non-trainable.
base_model_vit = tfimm.create_model("vit_tiny_patch16_224", pretrained="timm", nb_classes=0)
for layer_vit in base_model_vit.layers:
    layer_vit.trainable = False

# Backbone features -> flatten -> 2-way softmax head.
out_vit = Flatten()(base_model_vit(inp_y))
output = Dense(2, activation='softmax')(out_vit)

model_vit = Model(inputs=[inp_y], outputs=output)
model_vit.summary()

**Pyramid ViT**

In [None]:
# PVT-Tiny backbone from tfimm, created with nb_classes=0 and kept frozen:
# every backbone layer is marked non-trainable.
base_model_pvt = tfimm.create_model("pvt_tiny", pretrained="timm", nb_classes=0)
for layer_pvt in base_model_pvt.layers:
    layer_pvt.trainable = False

# Backbone features -> flatten -> 2-way softmax head.
out_pvt = Flatten()(base_model_pvt(inp_y))
output = Dense(2, activation='softmax')(out_pvt)

model_pvt = Model(inputs=[inp_y], outputs=output)
model_pvt.summary()

**Parallel CNN (ResNet) & ViT**

In [None]:
# Two frozen backbones read the same input tensor; their flattened features are
# concatenated and fed to a single 2-way softmax head.

## CNN (ResNet-18)
base_model_cnn = tfimm.create_model("resnet18", pretrained="timm", nb_classes=0)
for layer_cnn in base_model_cnn.layers:
    layer_cnn.trainable = False
out_cnn = Flatten()(base_model_cnn(inp_y))

## ViT-Tiny
base_model_vit = tfimm.create_model("vit_tiny_patch16_224", pretrained="timm", nb_classes=0)
for layer_vit in base_model_vit.layers:
    layer_vit.trainable = False
out_vit = Flatten()(base_model_vit(inp_y))

## Fusion + classifier
merged = Concatenate()([out_cnn, out_vit])
output = Dense(2, activation='softmax')(merged)

model_concat_cnn_vit = Model(inputs=[inp_y], outputs=output)
model_concat_cnn_vit.summary()

**Training ResNet, ViT, PVT, & ResNet+ViT**

In [None]:
# One independent Adam optimiser per model, all with the same learning rate.
opt_model_cnn = Adam(learning_rate=0.001)
opt_model_vit = Adam(learning_rate=0.001)
opt_model_pvt = Adam(learning_rate=0.001)
opt_model_concat_cnn_vit = Adam(learning_rate=0.001)

# Identical loss and metric for every architecture.
for network, optimiser in (
    (model_cnn, opt_model_cnn),
    (model_vit, opt_model_vit),
    (model_pvt, opt_model_pvt),
    (model_concat_cnn_vit, opt_model_concat_cnn_vit),
):
    network.compile(optimizer=optimiser, loss='categorical_crossentropy', metrics=['accuracy'])

# Every model is fitted on the same data with the same schedule.
fit_options = dict(
    x=X_train_y_axis,
    y=y_train_y_axis,
    epochs=epochs,
    batch_size=batch_size,
    shuffle=True,
    validation_data=(X_val_y_axis, y_val_y_axis),
)

print("CNN Training:")
history_model_cnn = model_cnn.fit(**fit_options)

print("Standard ViT Training:")
history_model_vit = model_vit.fit(**fit_options)

print("Pyramid ViT Training:")
history_model_pvt = model_pvt.fit(**fit_options)

print("Concatenation of CNN & Standard ViT Training:")
history_model_concat_cnn_vit = model_concat_cnn_vit.fit(**fit_options)

In [611]:
# Optional persistence, left disabled: uncomment the save lines to write each
# trained model to an HDF5 file.
#model_cnn.save('resnet18.h5') 
#model_vit.save('vit_tiny.h5') 
#model_pvt.save('pvt_tiny.h5') 
#model_concat_cnn_vit.save('concat_resnet18_vit_tiny.h5')

# Uncomment exactly one load line to restore a saved model into `model`.
# NOTE(review): the evaluation cells below use model_cnn / model_vit / model_pvt /
# model_concat_cnn_vit, not `model` - rebind the matching name if reloading.
#model = load_model('resnet18.h5')
#model = load_model('vit_tiny.h5')
#model = load_model('pvt_tiny.h5')
#model = load_model('concat_resnet18_vit_tiny.h5')

**Results for CNN (ResNet)**

In [None]:
# Softmax outputs on the test set: one row per sample, one column per class.
y_pred_cnn = model_cnn.predict(X_test_y_axis, batch_size=batch_size)
# Hard class decisions and integer ground truth recovered from the one-hot targets.
y_pred_bin_cnn = np.argmax(y_pred_cnn, axis=1)
y_test_bin_y_axis = np.argmax(y_test_y_axis, axis=1)

# ROC/AUC are computed from the continuous class score, not the arg-max labels:
# thresholded 0/1 predictions give roc_curve a single operating point. Class 0 is
# the positive class here so the curve follows the same convention as the
# tp/fn/fp/tn unpacking below (first row/column of the matrix = positive).
fpr, tpr, threshold = roc_curve(y_test_bin_y_axis, y_pred_cnn[:, 0], pos_label=0)
roc_auc = auc(fpr, tpr)

# labels=[0, 1] guarantees a 2x2 matrix even if one class never occurs in the
# targets or predictions, so the four-way unpacking cannot fail.
tp,fn,fp,tn = confusion_matrix(y_test_bin_y_axis,y_pred_bin_cnn, labels=[0, 1]).flatten()
sensitivity = tp/(tp+fn)
specificity = tn/(tn+fp)
precision = tp/(tp+fp)
print('Results for Across Y - (ResNet-18):')
print('tp:',tp)
print('tn:',tn)
print('fp:',fp)
print('fn:',fn)
print('Accuracy: ', (tp+tn)/len(y_test_bin_y_axis))
print('Sensitivity: ', sensitivity)
print('Specificity: ', specificity)
print('Precision: ', precision)
print('F1 Score:', 2*(precision*sensitivity)/(precision+sensitivity))
print('AUC:', roc_auc)


**Plots for CNN (ResNet)**

In [None]:
### Training and validation curves (one value per epoch)
plt.plot(history_model_cnn.history["accuracy"], label="training accuracy")
plt.plot(history_model_cnn.history["val_accuracy"], label="validation accuracy")
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('ResNet-18 accuracy')
plt.legend()
plt.show()

plt.plot(history_model_cnn.history["loss"], label="training loss")
plt.plot(history_model_cnn.history["val_loss"], label="validation loss")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('ResNet-18 loss')
plt.legend()
plt.show()

### ROC Curve (uses fpr / tpr / roc_auc computed in the results cell)
plt.plot(fpr, tpr, lw=2, color='b', label = 'ResNet-18 (AUC= %0.2f)' % roc_auc, alpha=.8)
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.title('ROC Curve', fontsize=16)
plt.legend(loc="lower right", prop={'size': 12})
plt.show()

### Confusion Matrix Plot
def plot_confusion_matrix(y_true=None, y_pred=None):
  """Draw the AD/CN confusion matrix.

  Args:
    y_true: integer class labels; defaults to the notebook's y_test_bin_y_axis.
    y_pred: integer predicted labels; defaults to the ResNet-18 predictions.
  """
  if y_true is None:
    y_true = y_test_bin_y_axis
  if y_pred is None:
    y_pred = y_pred_bin_cnn
  classes = ['AD','CN']
  # labels=[0, 1] keeps the matrix 2x2 so it always lines up with the two display labels.
  cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
  disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=classes)
  disp.plot()
  plt.show()

plot_confusion_matrix()

**Results for ViT**

In [None]:
# Softmax outputs on the test set: one row per sample, one column per class.
y_pred_vit = model_vit.predict(X_test_y_axis, batch_size=batch_size)
# Hard class decisions and integer ground truth recovered from the one-hot targets.
y_pred_bin_vit = np.argmax(y_pred_vit, axis=1)
y_test_bin_y_axis = np.argmax(y_test_y_axis, axis=1)

# ROC/AUC are computed from the continuous class score, not the arg-max labels:
# thresholded 0/1 predictions give roc_curve a single operating point. Class 0 is
# the positive class here so the curve follows the same convention as the
# tp/fn/fp/tn unpacking below (first row/column of the matrix = positive).
fpr, tpr, threshold = roc_curve(y_test_bin_y_axis, y_pred_vit[:, 0], pos_label=0)
roc_auc = auc(fpr, tpr)

# labels=[0, 1] guarantees a 2x2 matrix even if one class never occurs in the
# targets or predictions, so the four-way unpacking cannot fail.
tp,fn,fp,tn = confusion_matrix(y_test_bin_y_axis,y_pred_bin_vit, labels=[0, 1]).flatten()
sensitivity = tp/(tp+fn)
specificity = tn/(tn+fp)
precision = tp/(tp+fp)
print('Results for Across Y - (Standard ViT):')
print('tp:',tp)
print('tn:',tn)
print('fp:',fp)
print('fn:',fn)
print('Accuracy: ', (tp+tn)/len(y_test_bin_y_axis))
print('Sensitivity: ', sensitivity)
print('Specificity: ', specificity)
print('Precision: ', precision)
print('F1 Score:', 2*(precision*sensitivity)/(precision+sensitivity))
print('AUC:', roc_auc)


**Plots for ViT**

In [None]:
### Training and validation curves (one value per epoch)
plt.plot(history_model_vit.history["accuracy"], label="training accuracy")
plt.plot(history_model_vit.history["val_accuracy"], label="validation accuracy")
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('ViT-Tiny accuracy')
plt.legend()
plt.show()

plt.plot(history_model_vit.history["loss"], label="training loss")
plt.plot(history_model_vit.history["val_loss"], label="validation loss")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('ViT-Tiny loss')
plt.legend()
plt.show()

### ROC Curve (uses fpr / tpr / roc_auc computed in the results cell)
plt.plot(fpr, tpr, lw=2, color='b', label = 'ViT-Tiny (AUC= %0.2f)' % roc_auc, alpha=.8)
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.title('ROC Curve', fontsize=16)
plt.legend(loc="lower right", prop={'size': 12})
plt.show()

### Confusion Matrix Plot
def plot_confusion_matrix(y_true=None, y_pred=None):
  """Draw the AD/CN confusion matrix.

  Args:
    y_true: integer class labels; defaults to the notebook's y_test_bin_y_axis.
    y_pred: integer predicted labels; defaults to the ViT-Tiny predictions.
  """
  if y_true is None:
    y_true = y_test_bin_y_axis
  if y_pred is None:
    y_pred = y_pred_bin_vit
  classes = ['AD','CN']
  # labels=[0, 1] keeps the matrix 2x2 so it always lines up with the two display labels.
  cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
  disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=classes)
  disp.plot()
  plt.show()

plot_confusion_matrix()

**Results for Pyramid ViT (PVT)**

In [None]:
# Softmax outputs on the test set: one row per sample, one column per class.
y_pred_pvt = model_pvt.predict(X_test_y_axis, batch_size=batch_size)
# Hard class decisions and integer ground truth recovered from the one-hot targets.
y_pred_bin_pvt = np.argmax(y_pred_pvt, axis=1)
y_test_bin_y_axis = np.argmax(y_test_y_axis, axis=1)

# ROC/AUC are computed from the continuous class score, not the arg-max labels:
# thresholded 0/1 predictions give roc_curve a single operating point. Class 0 is
# the positive class here so the curve follows the same convention as the
# tp/fn/fp/tn unpacking below (first row/column of the matrix = positive).
fpr, tpr, threshold = roc_curve(y_test_bin_y_axis, y_pred_pvt[:, 0], pos_label=0)
roc_auc = auc(fpr, tpr)

# labels=[0, 1] guarantees a 2x2 matrix even if one class never occurs in the
# targets or predictions, so the four-way unpacking cannot fail.
tp,fn,fp,tn = confusion_matrix(y_test_bin_y_axis,y_pred_bin_pvt, labels=[0, 1]).flatten()
sensitivity = tp/(tp+fn)
specificity = tn/(tn+fp)
precision = tp/(tp+fp)
print('Results for Across Y - (Pyramid ViT-Tiny):')
print('tp:',tp)
print('tn:',tn)
print('fp:',fp)
print('fn:',fn)
print('Accuracy: ', (tp+tn)/len(y_test_bin_y_axis))
print('Sensitivity: ', sensitivity)
print('Specificity: ', specificity)
print('Precision: ', precision)
print('F1 Score:', 2*(precision*sensitivity)/(precision+sensitivity))
print('AUC:', roc_auc)


**Plots for Pyramid ViT (PVT)**

In [None]:
### Training and validation curves (one value per epoch)
plt.plot(history_model_pvt.history["accuracy"], label="training accuracy")
plt.plot(history_model_pvt.history["val_accuracy"], label="validation accuracy")
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('PVT-Tiny accuracy')
plt.legend()
plt.show()

plt.plot(history_model_pvt.history["loss"], label="training loss")
plt.plot(history_model_pvt.history["val_loss"], label="validation loss")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('PVT-Tiny loss')
plt.legend()
plt.show()

### ROC Curve (uses fpr / tpr / roc_auc computed in the results cell)
plt.plot(fpr, tpr, lw=2, color='b', label = 'PVT-Tiny (AUC= %0.2f)' % roc_auc, alpha=.8)
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.title('ROC Curve', fontsize=16)
plt.legend(loc="lower right", prop={'size': 12})
plt.show()

### Confusion Matrix Plot
def plot_confusion_matrix(y_true=None, y_pred=None):
  """Draw the AD/CN confusion matrix.

  Args:
    y_true: integer class labels; defaults to the notebook's y_test_bin_y_axis.
    y_pred: integer predicted labels; defaults to the PVT-Tiny predictions.
  """
  if y_true is None:
    y_true = y_test_bin_y_axis
  if y_pred is None:
    y_pred = y_pred_bin_pvt
  classes = ['AD','CN']
  # labels=[0, 1] keeps the matrix 2x2 so it always lines up with the two display labels.
  cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
  disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=classes)
  disp.plot()
  plt.show()

plot_confusion_matrix()

**Results for Parallel CNN (ResNet) and ViT**

In [None]:
# Softmax outputs on the test set: one row per sample, one column per class.
y_pred_concat_cnn_vit = model_concat_cnn_vit.predict(X_test_y_axis, batch_size=batch_size)
# Hard class decisions and integer ground truth recovered from the one-hot targets.
y_pred_bin_concat_cnn_vit = np.argmax(y_pred_concat_cnn_vit, axis=1)
y_test_bin_y_axis = np.argmax(y_test_y_axis, axis=1)

# ROC/AUC are computed from the continuous class score, not the arg-max labels:
# thresholded 0/1 predictions give roc_curve a single operating point. Class 0 is
# the positive class here so the curve follows the same convention as the
# tp/fn/fp/tn unpacking below (first row/column of the matrix = positive).
fpr, tpr, threshold = roc_curve(y_test_bin_y_axis, y_pred_concat_cnn_vit[:, 0], pos_label=0)
roc_auc = auc(fpr, tpr)

# labels=[0, 1] guarantees a 2x2 matrix even if one class never occurs in the
# targets or predictions, so the four-way unpacking cannot fail.
tp,fn,fp,tn = confusion_matrix(y_test_bin_y_axis,y_pred_bin_concat_cnn_vit, labels=[0, 1]).flatten()
sensitivity = tp/(tp+fn)
specificity = tn/(tn+fp)
precision = tp/(tp+fp)
print('Results for Across Y - (Concatanation ResNet18 & Standard ViT-Tiny):')
print('tp:',tp)
print('tn:',tn)
print('fp:',fp)
print('fn:',fn)
print('Accuracy: ', (tp+tn)/len(y_test_bin_y_axis))
print('Sensitivity: ', sensitivity)
print('Specificity: ', specificity)
print('Precision: ', precision)
print('F1 Score:', 2*(precision*sensitivity)/(precision+sensitivity))
print('AUC:', roc_auc)


**Plots for Parallel CNN (ResNet) and ViT**

In [None]:
### Training and validation curves (one value per epoch)
plt.plot(history_model_concat_cnn_vit.history["accuracy"], label="training accuracy")
plt.plot(history_model_concat_cnn_vit.history["val_accuracy"], label="validation accuracy")
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('ResNet-18 & ViT-Tiny accuracy')
plt.legend()
plt.show()

plt.plot(history_model_concat_cnn_vit.history["loss"], label="training loss")
plt.plot(history_model_concat_cnn_vit.history["val_loss"], label="validation loss")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('ResNet-18 & ViT-Tiny loss')
plt.legend()
plt.show()

### ROC Curve (uses fpr / tpr / roc_auc computed in the results cell)
plt.plot(fpr, tpr, lw=2, color='b', label = 'ResNet-18 & ViT-Tiny (AUC= %0.2f)' % roc_auc, alpha=.8)
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.title('ROC Curve', fontsize=16)
plt.legend(loc="lower right", prop={'size': 12})
plt.show()

### Confusion Matrix Plot
def plot_confusion_matrix(y_true=None, y_pred=None):
  """Draw the AD/CN confusion matrix.

  Args:
    y_true: integer class labels; defaults to the notebook's y_test_bin_y_axis.
    y_pred: integer predicted labels; defaults to the ResNet-18 + ViT-Tiny predictions.
  """
  if y_true is None:
    y_true = y_test_bin_y_axis
  if y_pred is None:
    y_pred = y_pred_bin_concat_cnn_vit
  classes = ['AD','CN']
  # labels=[0, 1] keeps the matrix 2x2 so it always lines up with the two display labels.
  cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
  disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=classes)
  disp.plot()
  plt.show()

plot_confusion_matrix()

**ROC CURVE - 4 Architectures (ResNet & ViT & ResNet+ViT & PVT)**

In [None]:
### Test-set softmax outputs of the four architectures
y_pred_cnn = model_cnn.predict(X_test_y_axis, batch_size=batch_size)
y_pred_vit = model_vit.predict(X_test_y_axis, batch_size=batch_size)
y_pred_pvt = model_pvt.predict(X_test_y_axis, batch_size=batch_size)
y_pred_concat_cnn_vit = model_concat_cnn_vit.predict(X_test_y_axis, batch_size=batch_size)

# Integer ground truth and hard decisions; the latter are not needed for the ROC
# itself but are kept so the same notebook-level names stay defined after this cell.
y_test_bin_y_axis = np.argmax(y_test_y_axis, axis=1)
y_pred_bin_cnn = np.argmax(y_pred_cnn, axis=1)
y_pred_bin_vit = np.argmax(y_pred_vit, axis=1)
y_pred_bin_pvt = np.argmax(y_pred_pvt, axis=1)
y_pred_bin_concat_cnn_vit = np.argmax(y_pred_concat_cnn_vit, axis=1)

### One ROC curve per architecture: (legend name, softmax scores, line colour)
roc_series = [
    ('ResNet-18', y_pred_cnn, 'c'),
    ('ViT-Tiny', y_pred_vit, 'k'),
    ('PVT-Tiny', y_pred_pvt, 'y'),
    ('ResNet-18 & ViT-Tiny', y_pred_concat_cnn_vit, 'g'),
]
for curve_name, curve_scores, curve_color in roc_series:
    # Continuous class-0 scores are used instead of arg-max labels: thresholded
    # predictions would reduce each ROC curve to a single operating point.
    fpr, tpr, threshold = roc_curve(y_test_bin_y_axis, curve_scores[:, 0], pos_label=0)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, lw=2, color=curve_color,
             label='%s (AUC= %0.2f)' % (curve_name, roc_auc), alpha=.8)

# Diagonal reference line for a chance-level classifier.
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.title('ROC Curve', fontsize=16)
plt.legend(loc="lower right", prop={'size': 10})
plt.show()