<a href="https://colab.research.google.com/github/sutummala/AutismNet/blob/main/Autism_siamese.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the files read and written under
# '/content/drive/My Drive/Autism_CNN' later in this notebook are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import tensorflow as tf
from tensorflow.keras import backend as K
#import tensorflow_addons as tfa
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
import sklearn
import nibabel as nib
import numpy as np
import random
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM

In [None]:
# normalizing the input to have values between zero and one to make them suitable for further analysis
def normalize(input):
  '''Min-max scale every sample of `input` independently to the range [0, 1].

  Args:
    input: array or sequence whose first axis indexes samples (e.g. a stack of
      3D volumes).
  Returns:
    list with one scaled array per sample, in input order. A constant sample
    (max == min) becomes all zeros instead of NaNs from a 0/0 division. An
    empty input returns an empty list.
  '''
  if len(input) == 0:
    return []
  print(f'shape of input is {np.shape(input[0])}')
  norm_input = []
  for sample in input:
    shifted = sample - np.min(sample)
    value_range = np.max(sample) - np.min(sample)
    if value_range == 0:
      # constant sample: avoid 0/0, but keep the float dtype the division would produce
      norm_input.append(np.zeros_like(shifted, dtype=np.result_type(shifted, 1.0)))
    else:
      norm_input.append(shifted / value_range)
  return norm_input

In [None]:
# All inputs live in one Drive folder; keep the location in a single place.
data_dir = '/content/drive/My Drive/Autism_CNN'

# Each pair is (left volume, right volume); both sides are min-max scaled per sample.
left_input = np.load(os.path.join(data_dir, 'left_input.npy'))
left_input = np.squeeze(normalize(left_input))

right_input = np.load(os.path.join(data_dir, 'right_input.npy'))
right_input = np.squeeze(normalize(right_input))

targets = np.load(os.path.join(data_dir, 'autism_labels.npy'))

targets = 1-targets # 1 for negative pair and 0 for positive pair

# Every pair needs a left volume, a right volume and a label, aligned by index.
assert len(left_input) == len(right_input) == len(targets), (
    f'mismatched inputs: {len(left_input)} left, {len(right_input)} right, {len(targets)} labels')

print(f'total size of the data is {len(targets)}')

shape of input is (30, 45, 30)
shape of input is (30, 45, 30)
total size of the data is 4280


In [None]:
# Hold out one stratified 80/20 train/test split. Only one fold is needed as the test
# set, so take it directly rather than looping over all five folds (a loop re-indexes
# the large image arrays five times and keeps only the last fold). The fixed seed makes
# the split reproducible across kernel restarts.
SPLIT_SEED = 42
folds = RepeatedStratifiedKFold(n_splits = 5, n_repeats = 1, random_state = SPLIT_SEED)
train_index, test_index = next(folds.split(left_input, targets))

# split() yields indices in ascending (file) order, and Keras' `validation_split` used at
# training time takes the *last* 20% of the arrays it is given. Permute the training
# indices so that validation subset is a random draw instead of the tail of the file.
train_index = np.random.default_rng(SPLIT_SEED).permutation(train_index)

# NOTE(review): pairs are split independently; if the same subject/volume appears in
# several pairs it can land on both sides of the split — confirm this is acceptable.
left_input_cv, left_input_test = left_input[train_index], left_input[test_index]
right_input_cv, right_input_test = right_input[train_index], right_input[test_index]
targets_cv, targets_test = targets[train_index], targets[test_index]

# Label convention (set when loading): 1 = negative/dissimilar pair, 0 = positive pair.
print(f'shape of left/right input for CV is {left_input_cv.shape}')
print(f'input size for cross-validation is {len(targets_cv)}')
print(f'no.of negative (label 1) pairs in CV are {np.count_nonzero(targets_cv)}')

print(f'shape of left/right input for testing is {left_input_test.shape}')
print(f'input size for testing is {len(targets_test)}')
print(f'no.of negative (label 1) pairs in test are {np.count_nonzero(targets_test)}')

shape of left/right input for CV is (3424, 30, 45, 30)
input size for cross-validation is 3424
no.of positive pairs in CV are 1694
shape of left/right input for testing is (856, 30, 45, 30)
input size for testing is 856
no.of positive pairs in test are 423


In [None]:
def specificity(y_true, y_pred):
    '''Specificity = TN / (TN + FP): fraction of label-0 samples predicted as 0.

    y_pred is thresholded at 0.5, as in `accuracy`. Written with vectorised Keras
    backend ops so it works on symbolic batches inside Keras' training loop (a
    Python loop with len()/indexing needs a concrete batch size). Returns 0 when
    the batch has no label-0 samples (the epsilon guards the division).
    NOTE: Keras averages this per-batch value over batches, so the epoch figure
    approximates, rather than equals, the dataset-level specificity.
    '''
    y_pred = K.cast(y_pred > 0.5, y_true.dtype)
    # (true 0, predicted 0) cells
    true_negatives = K.sum((1 - y_true) * (1 - y_pred))
    # all true-0 samples = TN + FP
    actual_negatives = K.sum(1 - y_true)
    return true_negatives / (actual_negatives + K.epsilon())

In [None]:
def contrastive_loss(y_true, y_pred):
    """Mean contrastive loss over the batch, with `y_pred` read as a pair distance.

    Samples labelled 0 are penalised by d**2 (pulled together); samples labelled 1
    are penalised by max(margin - d, 0)**2 (pushed apart up to a margin of 1).
    """
    margin = 1
    pull_term = (1 - y_true) * K.square(y_pred)
    push_term = y_true * K.square(K.maximum(margin - y_pred, 0))
    return K.mean(pull_term + push_term)

In [None]:
## newly added methods begin
def recall_m(y_true, y_pred):
  '''Recall (sensitivity) = TP / (TP + FN) over the batch, thresholding y_pred at 0.5.

  Returns 0 when the batch contains no label-1 samples (epsilon-guarded division).
  NOTE: Keras averages this per-batch value over batches.
  '''
  y_pred = K.cast(y_pred > 0.5, y_true.dtype)
  true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
  return true_positives / (possible_positives + K.epsilon())

def precision_m(y_true, y_pred):
  '''Precision = TP / (TP + FP) over the batch, thresholding y_pred at 0.5.

  Returns 0 when nothing in the batch is predicted as 1 (epsilon-guarded division).
  NOTE: Keras averages this per-batch value over batches.
  '''
  y_pred = K.cast(y_pred > 0.5, y_true.dtype)
  true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
  return true_positives / (predicted_positives + K.epsilon())

def f1_m(y_true, y_pred):
  '''F1 score: harmonic mean of precision_m and recall_m, epsilon-guarded against 0/0.'''
  p = precision_m(y_true, y_pred)
  r = recall_m(y_true, y_pred)
  return 2 * p * r / (p + r + K.epsilon())
##newly added end

def compute_accuracy(y_true, y_pred):
    '''Compute classification accuracy with a fixed 0.5 threshold on distances.

    Follows this notebook's label convention: 1 = dissimilar pair, which
    `contrastive_loss` pushes to a large distance. So a distance above 0.5 is
    predicted as 1 and a distance at or below 0.5 as 0.
    Both arguments are flattened, so (n,), (n, 1) or lists are accepted without
    accidental (n, n) broadcasting.
    '''
    predicted_dissimilar = np.ravel(y_pred) > 0.5
    return np.mean(predicted_dissimilar == np.ravel(y_true))

def accuracy(y_true, y_pred):
    '''Fraction of samples whose prediction, thresholded at 0.5, equals the label.'''
    predicted_labels = K.cast(K.greater(y_pred, 0.5), y_true.dtype)
    hits = K.equal(y_true, predicted_labels)
    return K.mean(hits)

In [None]:
def euclidean_distance(vectors):
    """Euclidean distance between two feature batches, summed along axis 1.

    `vectors` is a pair (first, second). The squared sum keeps its reduced axis
    (for 2-D feature batches the result is (batch, 1)) and is clamped to
    K.epsilon() before the square root so the value and gradient stay finite at 0.
    """
    first, second = vectors
    squared_sum = K.sum(K.square(first - second), axis=1, keepdims=True)
    return K.sqrt(K.maximum(squared_sum, K.epsilon()))

In [None]:
def eucl_dist_output_shape(shapes):
    """Output shape of a pairwise-distance Lambda: (batch, 1).

    `shapes` must contain exactly two input shapes; the batch size is taken from
    the first one.
    """
    (first_shape, _unused_second) = shapes
    batch = first_shape[0]
    return batch, 1

In [None]:
def specificity1(y_true, y_pred):
    """Soft specificity: rounded (1 - y_pred) over label-0 samples, divided by their count.

    Rounding 1 - y_pred acts as a ~0.5 threshold; the epsilon avoids 0/0 when a
    batch has no label-0 samples.
    """
    negative_mask = 1 - y_true
    hit_count = K.sum(K.round(K.clip(negative_mask * (1 - y_pred), 0, 1)))
    negative_count = K.sum(K.round(K.clip(negative_mask, 0, 1)))
    return hit_count / (negative_count + K.epsilon())

In [None]:
def SiameseNetwork(input_shape):
    """Build a weight-sharing 3D-CNN Siamese classifier.

    Args:
        input_shape: shape of one input volume without the batch axis
            (Conv3D expects (D, H, W, channels)).
    Returns:
        (siamesenet, encoding_model): the two-input model 'siamese_3dmodel',
        whose sigmoid output is computed from the element-wise |difference| of
        the two embeddings, and the shared encoder 'base_3dcnn' that both inputs
        pass through.
    """

    def conv_bn_relu(tensor, filters, layer_name):
        # 3x3x3 conv (stride 1, 'same', L2 kernel penalty) -> batch norm -> ReLU
        conv = tf.keras.layers.Conv3D(filters, (3, 3, 3), strides=(1, 1, 1), padding='same',
                                      kernel_regularizer='L2', name=layer_name)(tensor)
        normed = tf.keras.layers.BatchNormalization(axis=-1)(conv)
        return tf.keras.layers.Activation('relu')(normed)

    def downsample(tensor):
        # halve every spatial dimension
        return tf.keras.layers.MaxPooling3D(strides=(2, 2, 2))(tensor)

    # ---- shared encoder ----
    volume = tf.keras.Input(input_shape)

    # block 1: two 32-filter convolutions, then pooling
    features = conv_bn_relu(volume, 32, 'conv3d_1')
    features = downsample(conv_bn_relu(features, 32, 'conv3d_2'))

    # block 2: two 64-filter convolutions, then pooling
    features = conv_bn_relu(features, 64, 'conv3d_3')
    features = downsample(conv_bn_relu(features, 64, 'conv3d_4'))

    # block 3: one 256-filter convolution, then pooling
    features = downsample(conv_bn_relu(features, 256, 'conv3d_5'))

    pooled = tf.keras.layers.GlobalAveragePooling3D()(features)
    embeddings = tf.keras.layers.Dense(1024, activation='relu', kernel_regularizer='L2')(pooled)
    encoding_model = tf.keras.Model(inputs=volume, outputs=embeddings, name='base_3dcnn')

    # ---- Siamese head: the same encoder applied to both inputs ----
    moving_input = tf.keras.Input(input_shape)
    ref_input = tf.keras.Input(input_shape)
    encoded_moving = encoding_model(moving_input)
    encoded_ref = encoding_model(ref_input)

    # element-wise L1 distance between the two embeddings
    L1_layer = tf.keras.layers.Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))
    L1_distance = L1_layer([encoded_moving, encoded_ref])

    prediction = tf.keras.layers.Dense(1, activation='sigmoid')(L1_distance)
    siamesenet = tf.keras.Model(inputs=[moving_input, ref_input], outputs=prediction,
                                name='siamese_3dmodel')

    return siamesenet, encoding_model

In [None]:
img_shape = (30, 45, 30) + (1,)  # one (30, 45, 30) volume plus a trailing single-channel axis

In [None]:
# Adam learning rate used when the Siamese model is compiled for training.
base_learning_rate = 5e-5

# Build the shared encoder and the two-input Siamese model around it, then show both.
siamese_model, base_model = SiameseNetwork(img_shape)
base_model.summary()
siamese_model.summary()

Model: "base_3dcnn"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 30, 45, 30, 1)]   0         
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 30, 45, 30, 32)    896       
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 45, 30, 32)    128       
_________________________________________________________________
activation (Activation)      (None, 30, 45, 30, 32)    0         
_________________________________________________________________
conv3d_2 (Conv3D)            (None, 30, 45, 30, 32)    27680     
_________________________________________________________________
batch_normalization_1 (Batch (None, 30, 45, 30, 32)    128       
_________________________________________________________________
activation_1 (Activation)    (None, 30, 45, 30, 32)    0

In [None]:
fine_tune_epochs = 20

siamese_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
    loss='binary_crossentropy',
    metrics=[accuracy, recall_m, specificity, precision_m, f1_m],
)

# NOTE(review): Keras takes `validation_split` from the *last* 20% of the given arrays,
# before shuffling, so which pairs are validated depends on the order of the CV arrays.
history_fine = siamese_model.fit(
    x=[left_input_cv, right_input_cv],
    y=targets_cv,
    batch_size=32,
    epochs=fine_tune_epochs,
    shuffle=True,
    validation_split=0.2,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
15/86 [====>.........................] - ETA: 1:36 - loss: 4.8270 - accuracy: 0.5417 - recall_m: 0.5417 - specificity: 0.5417 - precision_m: 1.0000 - f1_m: 0.6996

In [None]:
def _plot_training_curves(history):
    """Plot training vs. validation curves, one figure per metric plus one for the loss.

    `history` is a Keras `History.history` dict; the keys read are the compiled metric
    names, their 'val_' counterparts, 'loss' and 'val_loss'.
    """
    def new_panel(key, label, ylabel, legend_loc):
        # one figure comparing history[key] with history['val_' + key]
        fig, ax = plt.subplots()
        ax.plot(history[key], label=f'Training {label}')
        ax.plot(history[f'val_{key}'], label=f'Validation {label}')
        ax.legend(loc=legend_loc)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('epoch')
        return ax

    metric_panels = [
        ('accuracy', 'Accuracy', 'Accuracy'),
        ('recall_m', 'Recall', 'Recall(Sensitivity)'),
        ('specificity', 'Specificity', 'Specificity'),
        ('precision_m', 'Precision', 'Precision'),
        ('f1_m', 'F1-score', 'F1-score'),
    ]
    for key, label, ylabel in metric_panels:
        ax = new_panel(key, label, ylabel, 'lower right')
        ax.set_ylim([min(ax.get_ylim()), 1.01])

    # The model is compiled with loss='binary_crossentropy', so label the loss as such.
    ax = new_panel('loss', 'Loss', 'Binary cross-entropy loss', 'upper right')
    ax.set_ylim([0, max(ax.get_ylim())])

    plt.show()


# Read the curves straight from the History object rather than binding them to globals
# such as `recall_m` / `specificity`, which would shadow the metric functions of the same
# names that are passed to Keras when compiling or loading the model.
_plot_training_curves(history_fine.history)

In [None]:
from sklearn import metrics as sk_metrics

# Sigmoid scores for every held-out pair, flattened to one value per pair.
prediction_prob = siamese_model.predict([left_input_test, right_input_test]).ravel()
print(f'first predicted probabilities: {prediction_prob[:10]}')

# Hard labels at the same 0.5 threshold the training metrics use. The scores are kept
# separately because ROC AUC needs continuous scores; feeding it thresholded labels
# collapses the curve to a single operating point.
prediction_labels = (prediction_prob > 0.5).astype(int)

# labels=[0, 1] keeps the 2x2 layout (and the 4-way unpack) valid even if only one
# class appears in the targets and predictions.
tn, fp, fn, tp = sk_metrics.confusion_matrix(targets_test, prediction_labels, labels=[0, 1]).ravel()

print(f'test Accuracy: {sk_metrics.accuracy_score(targets_test, prediction_labels)}')
print(f'test ROC (AUC): {sk_metrics.roc_auc_score(targets_test, prediction_prob)}')
print(f'test Sensitivity (Recall): {sk_metrics.recall_score(targets_test, prediction_labels)}')
print(f'test Precision: {sk_metrics.precision_score(targets_test, prediction_labels)}')
print(f'test F1-score: {sk_metrics.f1_score(targets_test, prediction_labels)}')
print(f'test Mathews Correlation Coefficient: {sk_metrics.matthews_corrcoef(targets_test, prediction_labels)}')
print(f'test Specificity: {tn/(tn+fp)}')
print(f'first test labels: {targets_test[:10]}')

In [None]:
# Save Model
# Persist the trained Siamese network to Drive; the next cell reloads it from this path.
# NOTE(review): with no .h5/.keras suffix this presumably writes the TF SavedModel
# (directory) format — confirm for the installed TensorFlow version.
siamese_model.save('/content/drive/My Drive/Autism_CNN/autism_cnn_model')

In [None]:
# The reloaded model is only used for prediction and Grad-CAM below, so skip restoring
# the training configuration (optimizer, loss, metrics). This also removes the need to
# resolve the custom loss/metric objects, whose notebook-global names can be rebound
# by other cells (e.g. to per-epoch metric lists) by the time this cell runs.
siamese_model = tf.keras.models.load_model('/content/drive/My Drive/Autism_CNN/autism_cnn_model', compile=False)
print('Model is loaded')

In [None]:

# Which held-out pair to inspect (the Grad-CAM cells further down reuse left_test/right_test).
sample_index = 1

# Add batch and channel axes: (D, H, W) -> (1, D, H, W, 1), the layout the model takes.
left_test = np.expand_dims(left_input_test[sample_index], axis = (0, -1))
print(left_test.shape)
right_test = np.expand_dims(right_input_test[sample_index], axis = (0, -1))
p = siamese_model.predict([left_test, right_test])
# Targets use 1 = dissimilar, so similarity is reported as 1 - score (and 1 - label).
print(f'predicted percentage similarity is {(1-p[0][0])*100} and actual similarity is {(1-targets_test[sample_index])*100}')

In [None]:
# Build a Grad-CAM model with the same interface as before:
#   s_model([left, right]) -> [last-conv activations of one branch, pre-sigmoid logit].
#
# `base_3dcnn` is a nested model shared by both branches, so the inner
# `conv3d_5.output` tensor belongs to the encoder's own graph (fed by its private
# Input), not to the Siamese graph. Using it directly as an output of a model whose
# inputs are `siamese_model.inputs` leaves the graph disconnected. Instead, expose the
# conv output on the encoder and re-apply that encoder to each Siamese input.
last_conv_layer_name = 'conv3d_5'
encoder = siamese_model.get_layer('base_3dcnn')
last_layer = encoder.get_layer(last_conv_layer_name)
encoder_with_conv = tf.keras.models.Model(encoder.inputs, [last_layer.output, encoder.output])

moving_in, ref_in = siamese_model.inputs
_, moving_embedding = encoder_with_conv(moving_in)
# The Grad-CAM map is taken from the reference (right) branch, the volume overlaid later.
ref_conv_output, ref_embedding = encoder_with_conv(ref_in)
l1_distance = tf.keras.layers.Lambda(lambda t: K.abs(t[0] - t[1]))([moving_embedding, ref_embedding])

# Copy the trained output Dense into an activation-free clone to get the logit without
# mutating `siamese_model` (setting `.activation = None` in place would make every later
# `siamese_model.predict` return logits instead of probabilities).
sigmoid_head = siamese_model.layers[-1]
logit_head = tf.keras.layers.Dense(1, activation=None, name='gradcam_logit')
logit = logit_head(l1_distance)
logit_head.set_weights(sigmoid_head.get_weights())

s_model = tf.keras.models.Model([moving_in, ref_in], [ref_conv_output, logit])

In [None]:
# Record the forward pass of the explained pair so the next cell can differentiate the
# score w.r.t. the last conv activations; persistent=True lets that cell be re-run.
with tf.GradientTape(persistent=True) as tape:
    last_conv_layer_output, prediction = s_model([left_test, right_test])
    loss = tf.gather(prediction, 0, axis=1)  # first output unit of each sample

output = last_conv_layer_output[0]  # activations of the single sample in the batch
output_mean_map = tf.reduce_mean(output, axis=(0, 1, 2))  # mean over the 3 spatial axes, per channel
print(loss)
print(output_mean_map)

In [None]:
# Gradient of the score w.r.t. the last conv activations, for the single sample in the batch.
grads_batch = tape.gradient(loss, last_conv_layer_output)
if grads_batch is None:
    # Fail loudly: substituting zeros here would silently produce an all-zero heat map.
    raise RuntimeError('Grad-CAM score does not depend on the returned conv activations; '
                       'rebuild s_model so its prediction is computed from them.')
grads = grads_batch[0]
print(grads.shape)

# Grad-CAM channel weights: gradients averaged over the three spatial axes.
weights = tf.reduce_mean(grads, axis = (0, 1, 2))
print(weights)

# Channel-weighted sum of the activation maps, then ReLU (as in the Grad-CAM definition)
# so only regions that increase the score are kept.
weighted_maps = np.asarray(output) * np.asarray(weights)
heat_map = np.maximum(weighted_maps.sum(axis=-1), 0).astype(np.float32)

In [None]:
from skimage.transform import resize

# The heat map is overlaid on the reference (right) volume of the explained pair.
mri_image = np.squeeze(right_test)

# Upsample the conv-resolution heat map to the MRI grid so slices line up voxel-for-voxel.
hm = resize(heat_map, mri_image.shape)
heatmap = hm

slice_count = 14   # sagittal slice: index along axis 0
slice_count2 = 5   # coronal (axis 1) and axial (axis 2) slice index

# Rotate MRI and heat-map slices identically so the overlays stay aligned.
sagittal_mri_img = np.rot90(mri_image[slice_count, :, :])
sagittal_grad_cmap_img = np.rot90(heatmap[slice_count, :, :])

coronal_mri_img = np.rot90(mri_image[:, slice_count2, :])
coronal_grad_cmap_img = np.rot90(heatmap[:, slice_count2, :])

axial_mri_img = np.rot90(mri_image[:, :, slice_count2])
axial_grad_cmap_img = np.rot90(heatmap[:, :, slice_count2])

f, axarr = plt.subplots(2, 3, figsize=(15, 10))
f.suptitle('Grad-CAM')

# Top row: MRI slices. Bottom row: the same slices with the heat map blended on top.
views = [
    ('Axial', axial_mri_img, axial_grad_cmap_img),
    ('Coronal', coronal_mri_img, coronal_grad_cmap_img),
    ('Sagittal', sagittal_mri_img, sagittal_grad_cmap_img),
]
for col, (view, mri_slice, cam_slice) in enumerate(views):
    axarr[0, col].imshow(mri_slice, cmap='gray')
    axarr[0, col].set_title(f'MRI-{view}')
    axarr[1, col].imshow(mri_slice, cmap='gray')
    axarr[1, col].imshow(cam_slice, cmap='jet', alpha=0.4)
    axarr[1, col].set_title(f'Grad-CAM-{view}')
    for ax in axarr[:, col]:
        ax.axis('off')

plt.show()
    



