<a href="https://colab.research.google.com/github/sutummala/periCellNet/blob/main/periCellNet_Contrastive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import tensorflow as tf
#import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
#import tensorflow_addons as tfa
from sklearn.model_selection import StratifiedKFold, cross_val_score
import sklearn
import nibabel as nib
import numpy as np
import random
import matplotlib.pyplot as plt

In [None]:
# Data paths for loading individual cells 

basophil = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/basophil.npy')
print(f'basophil tensor shape is {basophil.shape}')
basophil_labels = np.zeros(len(basophil))

eosinophil = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/eosinophil.npy')
print(f'eosinophil tensor shape is {eosinophil.shape}')
eosinophil_labels = np.ones(len(eosinophil))

ig = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/ig.npy')
print(f'immature grannulocytes tensor shape is {ig.shape}')
ig_labels = 2 * np.ones(len(ig))

erythroblast = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/erythroblast.npy')
print(f'erythroblast tensor shape is {erythroblast.shape}')
erythroblast_labels = 3 * np.ones(len(erythroblast))

lymphocyte = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/lymphocyte.npy')
print(f'lymphocyte tensor shape is {lymphocyte.shape}')
lymphocyte_labels = 4 * np.ones(len(lymphocyte))

monocyte = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/monocyte.npy')
print(f'monocyte tensor shape is {monocyte.shape}')
monocyte_labels = 5 * np.ones(len(monocyte))

platelet = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/platelet.npy')
print(f'platelet tensor shape is {platelet.shape}')
platelet_labels = 6 * np.ones(len(platelet))

neutrophil = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/neutrophil.npy')
print(f'neutrophil tensor shape is {neutrophil.shape}')
neutrophil_labels =  7 * np.ones(len(neutrophil))

basophil tensor shape is (1218, 358, 358, 3)
eosinophil tensor shape is (3117, 358, 358, 3)
immature grannulocytes tensor shape is (2895, 358, 358, 3)
erythroblast tensor shape is (1551, 358, 358, 3)
lymphocyte tensor shape is (1214, 358, 358, 3)
monocyte tensor shape is (1420, 358, 358, 3)
platelet tensor shape is (2348, 358, 358, 3)
neutrophil tensor shape is (3329, 358, 358, 3)


In [None]:
def find_compare_to(l, i, j, index, no_of_pairs):
  
  # finding indices
  if True:
    compare_to = i
    while compare_to == i: # Making sure it's not comparing to itself
      if l == 0: # Label 0
        if (j % 2 != 0):
          compare_to = random.randint(0, index-1)
        else:
          compare_to = random.randint(index, 8*index-1)
      elif l == 1:
        if (j % 2 != 0):
          compare_to = random.randint(index, 2*index-1)
        elif j <= int(0.2*no_of_pairs) and (j % 2 == 0):
          compare_to = random.randint(0, index-1)
        else:
          compare_to = random.randint(2*index, 8*index-1)
      elif l == 2:
        if (j % 2 != 0):
          compare_to = random.randint(2*index, 3*index-1)
        elif j <= int(0.3*no_of_pairs) and (j % 2 == 0):
          compare_to = random.randint(0, 2*index-1)
        else:
          compare_to = random.randint(3*index, 8*index-1)
      elif l == 3:
        if (j % 2 != 0):
          compare_to = random.randint(3*index, 4*index-1)
        elif j <= int(0.4*no_of_pairs) and (j % 2 == 0):
          compare_to = random.randint(0, 3*index-1)
        else:
          compare_to = random.randint(4*index, 8*index-1)
      elif l == 4:
        if (j % 2 != 0):
          compare_to = random.randint(4*index, 5*index-1)
        elif j <= int(0.5*no_of_pairs) and (j % 2 == 0):
          compare_to = random.randint(0, 4*index-1)
        else:
          compare_to = random.randint(5*index, 8*index-1)
      elif l == 5:
        if (j % 2 != 0):
          compare_to = random.randint(5*index, 6*index-1)
        elif j <= int(0.6*no_of_pairs) and (j % 2 == 0):
          compare_to = random.randint(0, 5*index-1)
        else:
          compare_to = random.randint(6*index, 8*index-1)
      elif l == 6:
        if (j % 2 != 0):
          compare_to = random.randint(6*index, 7*index-1)
        elif j <= int(0.7*no_of_pairs) and (j % 2 == 0):
          compare_to = random.randint(0, 6*index-1)
        else:
          compare_to = random.randint(7*index, 8*index-1)
      elif l == 7:
        if (j % 2 != 0):
          compare_to = random.randint(7*index, 8*index-1)
        else:
          compare_to = random.randint(0, 7*index-1)
  
  return compare_to

In [None]:
def generate_pairs(X, y, index, no_of_pairs):
    
    image_list = np.split(X, X.shape[0])
    label_list = np.split(y, len(y))
    
    left_input = []
    right_input = []
    targets = []
    
    no_of_labels = np.unique(y) # Number of classes

    for l in range(len(no_of_labels)):
      if l == 0:
        print('doing for label 0')
        for i in range(0, index):
          for j in range(no_of_pairs):
            compare_to = find_compare_to(l, i, j, index, no_of_pairs)
            left_input.append(image_list[i])
            right_input.append(image_list[compare_to])
            if label_list[i] == label_list[compare_to]: # They are same
                targets.append(1.)
            else:# Not same
                targets.append(0.)
        print(f'targets {len(targets)}')
      if l == 1:
        print('doing for label 1')
        for i in range(index, 2*index):
          for j in range(no_of_pairs):
            compare_to = find_compare_to(l, i, j, index, no_of_pairs)
            left_input.append(image_list[i])
            right_input.append(image_list[compare_to])
            if label_list[i] == label_list[compare_to]: # They are same
                targets.append(1.)
            else:# Not same
                targets.append(0.)
        print(f'targets {len(targets)}')
      if l == 2:
        print('doing for label 2')
        for i in range(2*index, 3*index):
          for j in range(no_of_pairs):
            compare_to = find_compare_to(l, i, j, index, no_of_pairs)
            left_input.append(image_list[i])
            right_input.append(image_list[compare_to])
            if label_list[i] == label_list[compare_to]: # They are same
                targets.append(1.)
            else:# Not same
                targets.append(0.)
        print(f'targets {len(targets)}')
      if l == 3:
        print('doing for label 3')
        for i in range(3*index, 4*index):
          for j in range(no_of_pairs):
            compare_to = find_compare_to(l, i, j, index, no_of_pairs)
            left_input.append(image_list[i])
            right_input.append(image_list[compare_to])
            if label_list[i] == label_list[compare_to]: # They are same
                targets.append(1.)
            else:# Not same
                targets.append(0.)
        print(f'targets {len(targets)}')
      if l == 4:
        print('doing for label 4')
        for i in range(4*index, 5*index):
          for j in range(no_of_pairs):
            compare_to = find_compare_to(l, i, j, index, no_of_pairs)
            left_input.append(image_list[i])
            right_input.append(image_list[compare_to])
            if label_list[i] == label_list[compare_to]: # They are same
                targets.append(1.)
            else:# Not same
                targets.append(0.)
        print(f'targets {len(targets)}')
      if l == 5:
        print('doing for label 5')
        for i in range(5*index, 6*index):
          for j in range(no_of_pairs):
            compare_to = find_compare_to(l, i, j, index, no_of_pairs)
            left_input.append(image_list[i])
            right_input.append(image_list[compare_to])
            if label_list[i] == label_list[compare_to]: # They are same
                targets.append(1.)
            else:# Not same
                targets.append(0.)
        print(f'targets {len(targets)}')
      if l == 6:
        print('doing for label 6')
        for i in range(6*index, 7*index):
          for j in range(no_of_pairs):
            compare_to = find_compare_to(l, i, j, index, no_of_pairs)
            left_input.append(image_list[i])
            right_input.append(image_list[compare_to])
            if label_list[i] == label_list[compare_to]: # They are same
                targets.append(1.)
            else:# Not same
                targets.append(0.)
        print(f'targets {len(targets)}')
      if l == 7:
        print('doing for label 7')
        for i in range(7*index, 8*index):
          for j in range(no_of_pairs):
            compare_to = find_compare_to(l, i, j, index, no_of_pairs)
            left_input.append(image_list[i])
            right_input.append(image_list[compare_to])
            if label_list[i] == label_list[compare_to]: # They are same
                targets.append(1.)
            else:# Not same
                targets.append(0.)
        print(f'targets {len(targets)}')

    left_input = np.squeeze(np.array(left_input))
    right_input = np.squeeze(np.array(right_input))
    targets = np.squeeze(np.array(targets))
    
    return left_input, right_input, targets

In [None]:
#y_cor = np.ones(np.shape(lymphocyte)[0])

#y_incor = np.zeros(np.shape(monocyte)[0])

#y_true = np.concatenate((y_cor, y_incor))

#X = np.concatenate((lymphocyte, monocyte))

index = 125

X = np.concatenate((basophil[:index, :, :, :], eosinophil[:index, :, :, :], ig[:index, :, :, :], erythroblast[:index, :, :, :], lymphocyte[:index, :, :, :], monocyte[:index, :, :, :], platelet[:index, :, :, :], neutrophil[:index, :, :, :]))
y = np.concatenate((basophil_labels[:index], eosinophil_labels[:index], ig_labels[:index], erythroblast_labels[:index], lymphocyte_labels[:index], monocyte_labels[:index], platelet_labels[:index], neutrophil_labels[:index]))

print(X.shape)
print(y.shape)
print(len(np.unique(y)))

left_input, right_input, targets = generate_pairs(X, y, index, no_of_pairs = 20)

#targets = np.load('/content/drive/My Drive/Autism_CNN/autism_labels.npy')

targets = 1-targets # 1 for different pair and 0 for same pair 

print(f'total size of the data is {len(targets)}')

(1000, 358, 358, 3)
(1000,)
8
doing for label 0
targets 2500
doing for label 1
targets 5000
doing for label 2
targets 7500
doing for label 3
targets 10000
doing for label 4
targets 12500
doing for label 5
targets 15000
doing for label 6
targets 17500
doing for label 7
targets 20000
total size of the data is 20000


In [None]:
left_input = left_input[:, 30:330, 30:330, :]
right_input = right_input[:, 30:330, 30:330, :]

left_input.shape

(20000, 300, 300, 3)

In [None]:
np.save('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/siamese_left_input', left_input)
np.save('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/siamese_right_input', right_input)
np.save('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/siamese_labels', targets)

In [None]:
left_input = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/siamese_right_input.npy')
right_input = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/siamese_left_input.npy')
targets = np.load('/content/drive/MyDrive/Datasets/PBC_dataset_normal_DIB/siamese_labels.npy')

print(left_input.shape)
print(right_input.shape)
print(np.unique(targets))

(20000, 300, 300, 3)
(20000, 300, 300, 3)
[0. 1.]


In [None]:

cv_index = int(0.8 * len(targets)) 

left_input_cv = left_input[:cv_index]
right_input_cv = right_input[:cv_index]
targets_cv = targets[:cv_index]

print(f'shape of left/right input for CV is {left_input_cv.shape}')
print(f'input size for cross-validation is {len(targets_cv)}')
print(f'no.of positive pairs in CV are {np.shape(np.nonzero(targets_cv))[1]}')

left_input_test = left_input[cv_index:]
right_input_test = right_input[cv_index:]
targets_test = targets[cv_index:]

print(f'shape of left/right input for testing is {left_input_test.shape}')
print(f'input size for testing is {len(targets_test)}')
print(f'no.of positive pairs in test are {np.shape(np.nonzero(targets_test))[1]}')

shape of left/right input for CV is (12800, 300, 300, 3)
input size for cross-validation is 12800
no.of positive pairs in CV are 6400
shape of left/right input for testing is (3200, 300, 300, 3)
input size for testing is 3200
no.of positive pairs in test are 1600


In [None]:
def contrastive_loss(y_true, y_pred):
    margin = 1
    square_pred = K.square(y_pred)
    margin_square = K.square(K.maximum(margin - y_pred, 0))
    return K.mean((1-y_true) * square_pred + y_true * margin_square)

In [None]:
def SiameseNetwork(input_shape):
    
    moving_input = tf.keras.Input(input_shape)
    ref_input = tf.keras.Input(input_shape)
    
    # base_model using EfficientNetB3

    inp = tf.keras.layers.Input(shape=input_shape)
    base_model = tf.keras.applications.EfficientNetB3(include_top=False, weights='imagenet', input_tensor=inp,
                                          input_shape=input_shape)

    #index = int(0.1 * len(base_model.layers))
    
    #for layer in base_model.layers[:index]:
      #layer.trainable = False

    last_conv_layer = base_model.get_layer('top_activation')

    x =  tf.keras.layers.GlobalAveragePooling2D()(last_conv_layer.output)
    #x = tf.keras.layers.BatchNormalization()(x)
    #x = tf.keras.layers.Dense(256, activation = 'relu', name = 'dense_layer')(x)

    model = tf.keras.models.Model(inputs = inp, outputs = x)
    
    # Siamese model starts here

    encoded_moving  = model(moving_input)
    encoded_ref = model(ref_input)

    L1_layer = tf.keras.layers.Lambda(lambda tensors:K.abs(tensors[0] - tensors[1]), name = 'lambda_layer')
    #L2_layer = tf.keras.layers.Lambda(lambda tensors:K.l2_normalize((tensors[0] - tensors[1]), axis = 1))

    L1_distance = L1_layer([encoded_moving, encoded_ref]) # L1-norm
    #L2_distance = L2_layer([encoded_moving, encoded_ref]) # L2-norm or Euclidean Norm
    #L2_distance = tf.keras.layers.Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([encoded_moving, encoded_ref])

    prediction = tf.keras.layers.Dense(1, activation='sigmoid')(L1_distance)
    siamesenet = tf.keras.Model(inputs = [moving_input, ref_input], outputs = prediction, name = 'siamese_3dmodel')
    
    return siamesenet, model

In [None]:
img_shape = (left_input.shape[1], left_input.shape[2], left_input.shape[3])

In [None]:
siamese_model, base_model = SiameseNetwork(img_shape)
base_learning_rate = 0.0005
base_model.summary()
siamese_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 300, 300, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 300, 300, 3)  0           ['input_3[0][0]']                
                                                                                                  
 normalization (Normalization)  (None, 300, 300, 3)  7           ['rescaling[0][0]']              
                                                                                                  
 stem_conv_pad (ZeroPadding2D)  (None, 301, 301, 3)  0           ['normalization[0][0]']      

In [None]:
siamese_model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = base_learning_rate), loss = contrastive_loss, metrics = 'accuracy')
    
fine_tune_epochs = 10
history_fine = siamese_model.fit([left_input, right_input], targets, batch_size = 4,
                              epochs = fine_tune_epochs,
                              shuffle = True,
                              validation_split = 0.0)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [None]:
index = 125

# Selecting random labels for creating Support Set
baso_index = random.randint(index, np.shape(basophil)[0]-1)
eosi_index = random.randint(index, np.shape(eosinophil)[0]-1)
ig_index = random.randint(index, np.shape(ig)[0]-1)
ery_index = random.randint(index, np.shape(erythroblast)[0]-1)
lym_index = random.randint(index, np.shape(lymphocyte)[0]-1)
mono_index = random.randint(index, np.shape(monocyte)[0]-1)
pla_index = random.randint(index, np.shape(platelet)[0]-1)
neu_index = random.randint(index, np.shape(neutrophil)[0]-1)

support_set = np.concatenate((np.expand_dims(basophil[baso_index, :, :, :], axis=0),
                              np.expand_dims(eosinophil[eosi_index, :, :, :], axis=0), np.expand_dims(ig[ig_index, :, :, :], axis=0), 
                              np.expand_dims(erythroblast[ery_index, :, :, :], axis=0), np.expand_dims(lymphocyte[lym_index, :, :, :], axis=0), 
                              np.expand_dims(monocyte[mono_index, :, :, :], axis=0), np.expand_dims(platelet[pla_index, :, :, :], axis=0),
                              np.expand_dims(neutrophil[neu_index, :, :, :], axis=0)))
# Support Set
support_set = support_set[:, 30:330, 30:330, :]

cell_types = ['neutrophil', 'eosinophil', 'basophil', 'ig', 'lymphocyte', 'monocyte', 'erythroblast', 'platelet']

overall_accuracy = 0

for cells in cell_types:
  if cells == 'neutrophil':
    print(f'calculating accuracy for {cells}')
    cell = neutrophil
    cell_labels = neutrophil_labels
  elif cells == 'eosinophil':
    print(f'calculating accuracy for {cells}')
    cell = eosinophil
    cell_labels = eosinophil_labels
  elif cells == 'basophil':
    print(f'calculating accuracy for {cells}')
    cell = basophil
    cell_labels = basophil_labels
  elif cells == 'ig':
    print(f'calculating accuracy for {cells}')
    cell = ig
    cell_labels = ig_labels
  elif cells == 'lymphocyte':
    print(f'calculating accuracy for {cells}')
    cell = lymphocyte
    cell_labels = lymphocyte_labels
  elif cells == 'monocyte':
    print(f'calculating accuracy for {cells}')
    cell = monocyte
    cell_labels = monocyte_labels
  elif cells == 'erythroblast':
    print(f'calculating accuracy for {cells}')
    cell = erythroblast
    cell_labels = erythroblast_labels
  elif cells == 'platelet':
    print(f'calculating accuracy for {cells}')
    cell = platelet
    cell_labels = platelet_labels

  correct_predictions = 0 # Initializing for calculating accuracy
  for c in range(index, np.shape(cell)[0]):

      query_image = np.expand_dims(cell[c, :, :, :], axis=0)
      query_images = np.repeat(query_image, repeats = 8, axis = 0)
      query_images = query_images[:, 30:330, 30:330, :]

      # predicting on a query image
      query_predictions = siamese_model.predict([query_images, support_set])
      predicted_label = np.argmin(query_predictions)
      #print(np.transpose(query_predictions))
      #print(f'actual label is {erythroblast_labels[c]}, predicted label is {predicted_label}')
      if cell_labels[c] == predicted_label:
        correct_predictions += 1
  accuracy = correct_predictions/(np.shape(cell)[0]-index)
  print(f'accuracy for predicting {cells} is {accuracy}')
  overall_accuracy = accuracy + overall_accuracy

overall_accuracy = overall_accuracy/8
print(f'Overall Accuracy is {overall_accuracy}')

calculating accuracy for neutrophil
accuracy for predicting neutrophil is 0.9797128589263421
calculating accuracy for eosinophil
accuracy for predicting eosinophil is 0.9973262032085561
calculating accuracy for basophil
accuracy for predicting basophil is 0.9817017383348582
calculating accuracy for ig
accuracy for predicting ig is 0.8873646209386281
calculating accuracy for lymphocyte
accuracy for predicting lymphocyte is 0.9715335169880625
calculating accuracy for monocyte
accuracy for predicting monocyte is 0.4617760617760618
calculating accuracy for erythroblast
accuracy for predicting erythroblast is 0.4831697054698457
calculating accuracy for platelet
accuracy for predicting platelet is 0.9932523616734144
Overall Accuracy is 0.8444796334144711


In [None]:
predicton_prob = siamese_model.predict([left_input_test, right_input_test])

print(np.transpose(predicton_prob))

predicton_prob[predicton_prob > 0.5] = 1
predicton_prob[predicton_prob <= 0.5] = 0

tn, fp, fn, tp = sklearn.metrics.confusion_matrix(targets_test, predicton_prob).ravel()

print(f'test Accuracy: {sklearn.metrics.accuracy_score(targets_test, predicton_prob)}')
print(f'test ROC (AUC): {sklearn.metrics.roc_auc_score(targets_test, predicton_prob)}')
print(f'test Sensitivity (Recall): {sklearn.metrics.recall_score(targets_test, predicton_prob)}')
print(f'test Precision: {sklearn.metrics.precision_score(targets_test, predicton_prob)}')
print(f'test F1-score: {sklearn.metrics.f1_score(targets_test, predicton_prob)}')
print(f'test Mathews Correlation Coefficient: {sklearn.metrics.matthews_corrcoef(targets_test, predicton_prob)}')
print(f'test Specificity: {tn/(tn+fp)}')
print(targets_test)

[[0.00759898 0.00699868 0.05420027 ... 0.13596778 0.9997669  0.13029419]]
test Accuracy: 0.6828125
test ROC (AUC): 0.6828124999999998
test Sensitivity (Recall): 0.6275
test Precision: 0.7055516514406184
test F1-score: 0.6642408203771087
test Mathews Correlation Coefficient: 0.3678829853793725
test Specificity: 0.738125
[1. 0. 1. ... 0. 1. 0.]


In [None]:
_, siamese_embeddings_model = SiameseNetwork(img_shape)
#siamese_embeddings_model.summary()

In [None]:
embeddings_weights = siamese_model.layers[-3].get_weights()
siamese_embeddings_model.set_weights(embeddings_weights)

In [None]:
test_L = np.expand_dims(left_input_test[20], axis = 0)
test_R = np.expand_dims(right_input_test[20], axis = 0)

In [None]:
last_conv_layer = siamese_embeddings_model.get_layer('top_conv')
new_embeddings_model = tf.keras.models.Model(siamese_embeddings_model.inputs, [last_conv_layer.output, siamese_embeddings_model.output])

last_conv_output_left, vector_left = new_embeddings_model.predict(test_L)
last_conv_output_right, vector_right = new_embeddings_model.predict(test_R)

In [None]:
last_conv_output_left = np.squeeze(last_conv_output_left)
last_conv_output_left.shape

(10, 10, 1536)

In [None]:
last_conv_output_right = np.squeeze(last_conv_output_right)
last_conv_output_right.shape

(10, 10, 1536)

In [None]:
lambda_layer = siamese_model.get_layer('lambda_layer')
s_model = tf.keras.models.Model(inputs = [siamese_model.inputs], outputs = [lambda_layer.output, siamese_model.output])
L1, prediction = s_model.predict([test_L, test_R]) 

In [None]:
L1 = np.squeeze(L1)
L1.shape, np.min(L1), np.max(L1), np.mean(L1)

In [None]:
heat_map = np.zeros(last_conv_output_left.shape[0:2], dtype = np.float32)

In [None]:
for index, w in enumerate(L1):
  heat_map += w * last_conv_output_left[:, :, index]

ValueError: ignored

In [None]:
heat_map[heat_map < 0] = 0

In [None]:
np.min(heat_map), np.max(heat_map), np.mean(heat_map)

In [None]:
figure = plt.figure()
axes = figure.add_subplot(111)

caxes = axes.matshow(heat_map, interpolation = 'nearest')
figure.colorbar(caxes, boundaries=np.linspace(0, np.max(heat_map), 8, endpoint = True))

plt.show()

In [None]:
def save_and_display_gradcam(img, heatmap, cam_path="cam.jpg", alpha=0.5):
    
    # Rescale heatmap to a range 0-255
    heatmap = np.uint8(255 * heatmap)
    #heatmap = np.uint8(heatmap/50)

    # Use jet colormap to colorize heatmap
    jet = cm.get_cmap("jet")

    # Use RGB values of the colormap
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Create an image with RGB colorized heatmap
    jet_heatmap = keras.preprocessing.image.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = keras.preprocessing.image.img_to_array(jet_heatmap)

    # Superimpose the heatmap on original image
    superimposed_img = jet_heatmap * alpha + img
    superimposed_img = keras.preprocessing.image.array_to_img(superimposed_img)

    jet_heatmap_path = 'heatmap.jpg'

    # Save the superimposed image
    superimposed_img.save(cam_path)

    jet_heatmap1 = keras.preprocessing.image.array_to_img(jet_heatmap)
    jet_heatmap1.save(jet_heatmap_path)

    # Display Grad CAM
    display(Image(cam_path))
    display(Image(jet_heatmap_path))

In [None]:
from skimage.transform import resize
from IPython.display import Image, display
import matplotlib.cm as cm

hm = resize(heat_map, (300, 300), anti_aliasing = True)

heatmap = (hm - hm.min())/(hm.max() - hm.min())

save_and_display_gradcam(np.squeeze(test_L), heatmap)

In [None]:
# Save Model
siamese_model.save('/content/drive/My Drive/periCellNet/siamese_pericell_model')

INFO:tensorflow:Assets written to: /content/drive/My Drive/periCellNet/siamese_pericell_model/assets


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


In [None]:
new_model = tf.keras.models.load_model('/content/drive/My Drive/periCellNet/siamese_pericell_model', custom_objects={'contrastive_loss':contrastive_loss})