# 0 Load dataset

In [None]:
# Mount Google Drive so the project files under /content/drive are accessible.
from google.colab import drive
drive.mount('/content/drive')
# Show the attached NVIDIA GPU(s) (IPython shell command).
!nvidia-smi

In [None]:
# Project root on the mounted Drive; the dataset is read from PROJECT_path/SEM.
PROJECT_path = '/content/drive/MyDrive/IDB_diamond_damage'

In [None]:
import os
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image
from tqdm import tqdm

In [None]:
def load_dataset(dataset_path, crop_size=1024):
  """Load every image under ``dataset_path`` together with its class index.

  Each sub-folder of ``dataset_path`` is one class. Folders are sorted by name
  and the i-th folder gets label ``i``. Files inside a folder are read in sorted
  (lexicographic) name order, so the returned order is deterministic.

  Args:
    dataset_path: directory containing one sub-folder per class.
    crop_size: side length of the top-left crop kept from each image
      (default 1024). ``None`` keeps the full image.

  Returns:
    (image_list, label_list): a list of float32 arrays and the matching list
    of integer labels. A trailing channel axis is appended to every image, so
    single-channel (2-D) sources become (rows, cols, 1). Multi-channel sources
    are presumably not expected here; TODO confirm.
  """
  category_names = sorted(os.listdir(dataset_path))
  print(category_names)
  image_list = []
  label_list = []
  for tag, category in enumerate(category_names):
    category_path = os.path.join(dataset_path, category)
    for file in tqdm(sorted(os.listdir(category_path))):
      file_path = os.path.join(category_path, file)
      # Close the file handle as soon as the pixels have been copied out.
      with Image.open(file_path) as image:
        img = np.asarray(image, dtype="float32")
      if crop_size is not None:
        # Keep only the top-left crop_size x crop_size region.
        img = img[:crop_size, :crop_size]
      # Append a channel axis: (rows, cols) -> (rows, cols, 1).
      img = img[:, :, np.newaxis]
      image_list.append(img)
      label_list.append(tag)
  return image_list, label_list

In [None]:
# Load the SEM images; labels are the indices of the sorted class sub-folders.
X_set, Y_set = load_dataset(os.path.join(PROJECT_path, 'SEM'))

In [None]:
CLASS_num = 65
# Human-readable class names "#01" ... "#65": one per class index, zero-padded to two digits.
CLASS = ["#%02d" % k for k in range(1, CLASS_num + 1)]

# 1 Dataset processing

In [None]:
def classification_dataset_process(X_set, Y_set, target_size=(224, 224), num_classes=65):
  """Turn loaded images and labels into model-ready arrays.

  Each image is resized to ``target_size`` and converted from one grey channel
  to three RGB channels. The images are stacked into a single float32 array and
  divided by 255. Labels are one-hot encoded.

  Args:
    X_set: sequence of single-channel images (e.g. the output of load_dataset).
    Y_set: sequence of integer class labels.
    target_size: (width, height) passed to cv2.resize; default (224, 224).
    num_classes: width of the one-hot encoding; default 65 (the value of CLASS_num).

  Returns:
    (X, Y): X is a float32 array of shape (N, height, width, 3); Y is the
    one-hot label matrix of shape (N, num_classes).
  """
  # cv2.resize returns a 2-D array for single-channel input, which GRAY2RGB
  # then replicates into three channels.
  X_set = np.asarray(
      [cv2.cvtColor(cv2.resize(img, target_size), cv2.COLOR_GRAY2RGB) for img in X_set],
      dtype='float32')
  # NOTE(review): assumes 8-bit source pixels (0-255); confirm for 16-bit SEM images.
  X_set /= 255.0
  Y_set = tf.keras.utils.to_categorical(Y_set, num_classes)
  return X_set, Y_set

In [None]:
# Resize, convert to RGB and normalise the images; one-hot encode the labels.
# NOTE(review): this overwrites X_set/Y_set. Re-running the cell without first
# re-running load_dataset feeds already-processed data back in.
X_set, Y_set = classification_dataset_process(X_set, Y_set)

# 2 Grad-CAM algorithms

In [None]:
import os
import cv2
import heapq
import keras
import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as K
from google.colab.patches import cv2_imshow
from tensorflow.keras.applications.vgg16 import (VGG16, preprocess_input, decode_predictions)
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.preprocessing import image
from tensorflow.python.framework import ops
# Switch TensorFlow to TF1-style graph mode. The code below builds symbolic
# gradients (K.gradients, K.function, gradient_override_map), which work on
# graph tensors rather than eager ones. The model is loaded after this call.
tf.compat.v1.disable_eager_execution()

In [None]:
# Trained classifier, stored under the project folder on Drive. Deriving the
# path from PROJECT_path keeps a single source of truth for the project root.
model_path = os.path.join(PROJECT_path, "saved_models", "classification_model.h5")
model = load_model(model_path)

In [None]:
def register_gradient():
    """Register the "GuidedBackProp" gradient function once per process.

    The registered gradient passes the incoming gradient through only where
    both that gradient and the op's first input are positive (guided
    backpropagation). Calling this again is a no-op.
    """
    # Registering the same gradient name twice raises, so bail out early.
    if "GuidedBackProp" in ops._gradient_registry._registry:
        return

    @ops.RegisterGradient("GuidedBackProp")
    def _guided_relu_grad(op, grad):
        input_dtype = op.inputs[0].dtype
        positive_grad = tf.cast(grad > 0., input_dtype)
        positive_input = tf.cast(op.inputs[0] > 0., input_dtype)
        return grad * positive_grad * positive_input

def compile_saliency_function(model, activation_layer='block5_conv3'):
    """Build a backend function that returns input-space saliency for ``model``.

    Saliency is the gradient, with respect to the model input, of the sum of
    the per-location channel maximum of ``activation_layer``'s output.
    The returned function takes ``[input_batch, learning_phase]``; callers here
    pass 0 for the learning phase.
    """
    # The first entry of model.layers is skipped, as in the lookup it replaces.
    layers_by_name = {layer.name: layer for layer in model.layers[1:]}
    target_activations = layers_by_name[activation_layer].output
    # axis=3 is the channel axis assuming channels_last (batch, rows, cols, channels) — TODO confirm.
    channel_max = K.max(target_activations, axis=3)
    saliency = K.gradients(K.sum(channel_max), model.input)[0]
    return K.function([model.input, K.learning_phase()], [saliency])

def modify_backprop(model, name):
    """Return a freshly loaded model whose 'Relu' ops use the gradient ``name``.

    Every 'Relu' op created inside ``gradient_override_map`` uses the gradient
    registered under ``name`` (e.g. "GuidedBackProp" from register_gradient).
    The returned model is re-loaded from the module-level ``model_path`` inside
    that context, so its ReLU ops are created with the override.

    NOTE(review): ``model`` is only used to swap ``keras.activations.relu`` for
    ``tf.nn.relu`` on its layers. This mutates the caller's model in place, and
    the returned model is not derived from it. Presumably the swap does not
    affect the already-built graph; confirm.
    NOTE(review): the comparison uses standalone ``keras`` while models are loaded
    via ``tensorflow.keras``. Whether these relu functions are the same object
    depends on installed versions.
    NOTE(review): depends on the global ``model_path`` and calls load_model on
    every invocation.
    """
    g = tf.compat.v1.get_default_graph()
    # Ops created inside this block get the overridden 'Relu' gradient.
    with g.gradient_override_map({'Relu': name}):
        # Layers that expose an activation attribute (input layer skipped).
        layer_dict = [layer for layer in model.layers[1:]
                      if hasattr(layer, 'activation')]
        for layer in layer_dict:
            if layer.activation == keras.activations.relu:
                layer.activation = tf.nn.relu
        # Re-instantiate the model so its ops are built under the override.
        new_model = load_model(model_path)
    return new_model

def deprocess_image(x):
    '''
    Convert a gradient/saliency array into a displayable uint8 image.

    Same normalization as in:
    https://github.com/fchollet/keras/blob/master/examples/conv_filter_visualization.py

    The values are standardized (zero mean, unit std), rescaled to std 0.1
    around 0.5, clipped to [0, 1] and mapped to [0, 255]. Arrays with more
    than three dimensions are squeezed first. With a channels_first backend
    the result is transposed to (rows, cols, channels).

    The caller's array is never modified, because all work is done on a copy.
    Callers may pass views of data they still need.

    Args:
        x: numeric array that squeezes to an image.
    Returns:
        uint8 array with values in [0, 255].
    '''
    # Copy first: the in-place arithmetic below must not leak into the caller's data.
    x = np.array(x, copy=True)
    if not np.issubdtype(x.dtype, np.floating):
        # Integer input cannot take the in-place float arithmetic below.
        x = x.astype('float32')
    if x.ndim > 3:
        x = np.squeeze(x)
    x -= x.mean()
    # Epsilon guards against division by zero for constant inputs.
    x /= (x.std() + 1e-5)
    x *= 0.1
    x += 0.5
    x = np.clip(x, 0, 1)
    x *= 255
    if K.image_data_format() == 'channels_first':
        x = x.transpose((1, 2, 0))
    x = np.clip(x, 0, 255).astype('uint8')
    return x

def _compute_gradients(tensor, var_list):
    """Return symbolic gradients of ``tensor`` with respect to each entry of ``var_list``.

    Uses ``tf.gradients``, which adds gradient ops to the current graph. That
    requires graph mode, which this notebook enables with
    disable_eager_execution(). A GradientTape opened after ``tensor`` was
    already computed records no operations and so yields None for every
    variable; that is why ``tf.gradients`` is used here.

    Variables not connected to ``tensor`` get a zero tensor shaped like the
    variable instead of None.

    NOTE(review): not referenced elsewhere in the visible notebook.
    """
    grads = tf.gradients(tensor, var_list)
    return [grad if grad is not None else tf.zeros_like(var)
            for var, grad in zip(var_list, grads)]

def load_image(path, crop_size=1024, target_size=(224, 224)):
    """Read one image file and return it as a preprocessed single-image batch.

    Steps:
    - Read the file with OpenCV (default colour mode, BGR).
    - Keep the top-left ``crop_size`` x ``crop_size`` region, the same crop
      load_dataset applies.
    - Resize to ``target_size`` with nearest-neighbour interpolation.
    - Add a leading batch axis and apply VGG16 ``preprocess_input``.

    Args:
        path: image file path.
        crop_size: side of the top-left crop; default 1024.
        target_size: (width, height) for cv2.resize; default (224, 224).

    Returns:
        Array of shape (1, height, width, channels) ready for ``model.predict``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports unreadable or missing files by returning None.
        raise FileNotFoundError("Could not read image: " + str(path))
    # Crop both axes in one subscript. Chained slicing (img[0:n][0:n]) slices
    # the rows twice and leaves the width uncropped.
    img = img[:crop_size, :crop_size]
    img = cv2.resize(img, target_size, interpolation=cv2.INTER_NEAREST)
    # NOTE(review): VGG16 preprocess_input differs from the /255 scaling used by
    # classification_dataset_process for model.predict(X_set). Confirm which
    # preprocessing the model was trained with.
    return preprocess_input(np.expand_dims(img, axis=0))

def _guided_saliency_fn():
    """Return the guided-backprop saliency function, built once per loaded ``model``."""
    # modify_backprop re-loads the model from disk. Rebuilding it for every image
    # is slow and keeps adding ops to the default graph, so cache the result.
    cached = getattr(_guided_saliency_fn, "cache", None)
    if cached is None or cached[0] is not model:
        register_gradient()
        guided_model = modify_backprop(model, 'GuidedBackProp')
        _guided_saliency_fn.cache = (model, compile_saliency_function(guided_model))
    return _guided_saliency_fn.cache[1]


def _gradcam_heatmap(preprocessed_input, class_idx, conv_layer_name):
    """Return the Grad-CAM map for ``class_idx`` at ``conv_layer_name``'s resolution, scaled to [0, 1]."""
    class_output = model.output[:, class_idx]
    conv_layer = model.get_layer(conv_layer_name)
    grads = K.gradients(class_output, conv_layer.output)[0]
    # Per-channel weights: gradients averaged over axes 0-2. This assumes a
    # (batch, rows, cols, channels) layer output — TODO confirm for other layers.
    pooled_grads = K.mean(grads, axis=(0, 1, 2))
    iterate = K.function([model.input], [pooled_grads, conv_layer.output[0]])
    pooled_grads_value, conv_output_value = iterate([preprocessed_input])
    # Weight each feature map by its channel weight. Broadcasting over the last
    # axis works for any channel count, not just 512.
    heatmap = np.mean(conv_output_value * pooled_grads_value, axis=-1)
    heatmap = np.maximum(heatmap, 0)
    peak = np.max(heatmap)
    # An all-non-positive map would otherwise divide 0 by 0 and produce NaNs.
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap


def doCAM(image_path, image_code, dirs1, conv_layer_name="block5_pool"):
    """Write Guided Backprop, Grad-CAM and Guided Grad-CAM images for one file.

    Uses the module-level ``model``. The explained class is the model's top-1
    prediction, whose index is printed. Four JPEGs are written to ``dirs1``:
    ``Guided_BP_<code>.jpg``, ``Heatmap_<code>.jpg``, ``Grd-CAM_<code>.jpg``
    and ``Guided-CAM_<code>.jpg``.

    Args:
        image_path: image file to explain.
        image_code: tag inserted into the output file names.
        dirs1: existing output directory.
        conv_layer_name: layer whose activations and gradients form the
            Grad-CAM heatmap; default "block5_pool".
    """
    preprocessed_input = load_image(image_path)

    # Guided backpropagation saliency for the input image.
    saliency = _guided_saliency_fn()([preprocessed_input, 0])
    guided_bp = saliency[0][0]  # drop the batch axis -> (rows, cols, channels)
    # Pass a copy so normalisation cannot alter guided_bp, which is reused below.
    cv2.imwrite(os.path.join(dirs1, "Guided_BP_" + image_code + ".jpg"),
                deprocess_image(guided_bp.copy()))

    pred = model.predict(preprocessed_input)
    top1_idx = int(np.argmax(pred[0]))
    print(top1_idx)

    heatmap = _gradcam_heatmap(preprocessed_input, top1_idx, conv_layer_name)

    # Overlay image: top-left 1024x1024 crop resized to 224x224, without model preprocessing.
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError("Could not read image: " + str(image_path))
    img = cv2.resize(img[:1024, :1024], (224, 224), interpolation=cv2.INTER_NEAREST)

    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    cv2.imwrite(os.path.join(dirs1, "Heatmap_" + image_code + ".jpg"), heatmap)
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    grd_cam = cv2.addWeighted(img, 0.6, heatmap_color, 0.4, 0)
    cv2.imwrite(os.path.join(dirs1, "Grd-CAM_" + image_code + ".jpg"), grd_cam)

    # Guided Grad-CAM: raw guided saliency weighted by the in-memory heatmap
    # (no lossy JPEG round-trip), broadcast across the colour channels.
    guided_cam = deprocess_image(guided_bp * heatmap[..., np.newaxis])
    cv2.imwrite(os.path.join(dirs1, "Guided-CAM_" + image_code + ".jpg"), guided_cam)

# 3 Feature visualization

In [None]:
from keras.models import Model

In [None]:
# Model outputs for every sample in X_set, one row per image
# (presumably softmax class scores; TODO confirm). Printing dumps the full array.
Y_set_pred=model.predict(X_set)
print(Y_set_pred)

In [None]:
# True class index of each sample, decoded from its one-hot label row.
Y_set_true_label = [np.argmax(one_hot_row) for one_hot_row in Y_set]

In [None]:
# Predicted class index of each sample: the column with the highest model output.
Y_set_pred_label = [np.argmax(score_row) for score_row in Y_set_pred]

In [None]:
# Debug dump of true/predicted labels and the raw arrays.
# NOTE(review): prints full arrays, which is very long for the whole dataset.
print(Y_set_true_label)
print(Y_set)
print(Y_set_pred_label)
print(Y_set_pred)

In [None]:
# Indices (in X_set order) of samples whose predicted class equals the true class.
Y_true_set = [idx for idx in range(len(Y_set_pred))
              if Y_set_true_label[idx] == Y_set_pred_label[idx]]
print(Y_true_set)
print(len(Y_true_set))

In [None]:
def load_image_list(dataset_path):
  """Return the path of every image under ``dataset_path``, class by class.

  The ordering is the same as load_dataset's: class folders sorted by name,
  then files in each folder sorted by name. Index ``i`` of the returned list
  therefore names the same image as X_set[i] / Y_set[i], which the
  Y_true_set -> path lookup relies on. Do not sort numerically here: for names
  like "2.tiff" and "10.tiff" a numeric order disagrees with load_dataset's
  lexicographic order.
  """
  img_path_list = []
  for category in sorted(os.listdir(dataset_path)):
    category_path = os.path.join(dataset_path, category)
    for file in sorted(os.listdir(category_path)):
      img_path_list.append(os.path.join(category_path, file))
  return img_path_list

In [None]:
# Same dataset folder as the load_dataset call, so indices line up with X_set.
image_path_list = load_image_list(os.path.join(PROJECT_path, 'SEM'))

In [None]:
# File paths of the correctly classified samples, in Y_true_set order.
Y_true_list = [image_path_list[idx] for idx in Y_true_set]

In [None]:
# Paths of the images that will be explained with Grad-CAM below.
print(Y_true_list)

In [None]:
# Output folder for the Guided BP / Grad-CAM / Guided Grad-CAM images.
dirs1 = os.path.join(PROJECT_path, "visual")
# exist_ok avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(dirs1, exist_ok=True)

In [None]:
# Explain every correctly classified image; outputs are tagged with the loop index.
for idx, image_path in enumerate(Y_true_list):
  image_code = str(idx)
  print("> " + image_code)
  doCAM(image_path, image_code, dirs1)