In [None]:
# Inspect the GPUs on this machine; the next cell pins the kernel to one of them.
!nvidia-smi

In [None]:
import numpy as np
import os 
# Pin this kernel to one GPU. This has to happen before TensorFlow initialises CUDA
# (TensorFlow is imported in a later cell). PCI_BUS_ID ordering makes the index
# below agree with the numbering shown by nvidia-smi.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "3",
})


# Data Loading

In [None]:
# Validation split: `sp` samples are labelled 1 and `sa` samples 0.
DATA_DIR = '/home/souryas2/new_idea/sksbks_2/data'  # TODO: make relative / configurable
sp = np.load(os.path.join(DATA_DIR, 'sp_val.npy'))
sa = np.load(os.path.join(DATA_DIR, 'sa_val.npy'))
data = np.concatenate([sp, sa])
sp_lab = np.ones(len(sp))
sa_lab = np.zeros(len(sa))
labels = np.concatenate([sp_lab, sa_lab])

# Reproducible shuffle. A private RandomState(101) yields the same permutation as
# np.random.seed(101) followed by np.random.shuffle(np.arange(n)), but leaves the
# global NumPy RNG untouched for the rest of the notebook.
arr = np.random.RandomState(101).permutation(len(data))
data_valid = data[arr].astype('float32')
labels_valid = labels[arr]

# im_size is the image side length; the model input is (im_size, im_size, 1).
im_size = sp.shape[1]
# Guard the reshape: with non-square or multi-channel images it would silently regroup
# pixels into a different number of samples and misalign data_valid with labels_valid.
assert int(np.prod(data.shape[1:])) == im_size * im_size, \
    f'expected square single-channel images, got per-sample shape {data.shape[1:]}'
data_valid = np.reshape(data_valid, [-1, im_size, im_size, 1])
assert len(data_valid) == len(labels_valid)

# Parameter Loading (total1 = total number of Conv layers in the black-box model)

In [None]:
# Architecture knobs for rebuilding the models: total1 is the number of Conv2D layers in
# the black-box backbone; padding1 is the padding of the extra decoder convolutions.
total1, padding1 = 2, 'same'

# Directory holding both checkpoints (best_model.hdf5 and best_selfinterpretable_model.hdf5).
best_blackbox_ckpt = best_interpretable_ckpt = '/home/souryas2/new_idea/sksbks_2/github'

# Loading both the Classification and Estimation Models

In [None]:
import tensorflow as tf
from keras import backend as K
from skimage.util.shape import view_as_windows
import glob
from os import walk
import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.utils import to_categorical
from keras.models import Model
from keras.layers import Dense, Flatten, Input, Conv2D, MaxPooling2D, concatenate, Lambda, Deconvolution2D
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

# Define Mish Activation Function
def mish(x):
    return Lambda(lambda x: x * tf.tanh(tf.log(1 + tf.exp(x))))(x)

# Custom ReLU function with maximum value constraint
def create_relu_advanced(max_value=1.):
    def relu_advanced(x):
        return K.relu(x, max_value=K.cast_to_floatx(max_value))
    return relu_advanced


# Activation for the decoder (estimation) head; the black-box backbone itself uses ReLU.
activation1 = mish

input_size = (im_size, im_size, 1)
n = total1  # number of Conv2D layers in the black-box backbone

# Black-box classifier: n stacked 5x5 convs -> 2x2 max-pool -> Dense(64) -> sigmoid.
# This architecture must match the checkpoint loaded at the end of this block.
inputs = Input(input_size)
conv1 = Conv2D(64, 5, padding='same', activation='relu')(inputs)

for _ in range(n - 1):
    conv1 = Conv2D(128, 5, padding='same', activation='relu')(conv1)

pool2 = MaxPooling2D(pool_size=(2, 2))(conv1)
out = Flatten()(pool2)
out = Dense(64)(out)
predictions = Dense(1, activation='sigmoid')(out)
# Keras 2 keyword names; the legacy `input=`/`output=` aliases are deprecated
# (and rejected by tf.keras).
model1 = Model(inputs=inputs, outputs=predictions)

ckpt_path = best_blackbox_ckpt
best_ckpt = os.path.join(ckpt_path, 'best_model.hdf5')
model1.load_weights(best_ckpt)

# Self-interpretable (estimation) head grafted onto the black-box backbone (model1).
# With model1's layer stack [Input, Conv..., MaxPool, Flatten, Dense, Dense],
# layers[-4] is the 2x2 max-pool and layers[-5] the last backbone conv (pre-pooling).
max3 = model1.layers[-4].output
# A 2x2 kernel with stride 2 doubles the pooled map's spatial size back to the conv
# resolution (NOTE(review): assumes im_size is even, otherwise the concat shapes differ).
conv4 = Deconvolution2D(64, (2, 2), strides=2)(max3)
conv5 = Conv2D(64, 5, activation=activation1, padding='same', use_bias=False)(conv4)
# Skip connection: reuse the full-resolution backbone features.
conv5 = concatenate([conv5, model1.layers[-5].output])

for t1 in range(1, total1):
    conv5 = Conv2D(64, 5, activation=activation1, padding=padding1, use_bias=False)(conv5)

# Single-channel map -> one Dense weight per pixel -> scalar score clipped to [0, 1].
conv7 = Conv2D(1, 5, activation=activation1, padding='same', use_bias=False)(conv5)
flat2 = Flatten()(conv7)
dense2 = Dense(1, activation=create_relu_advanced(max_value=1.), use_bias=False)(flat2)

# Final model definition (Keras 2 keyword names; `input=`/`output=` are deprecated).
model_test3 = Model(inputs=model1.input, outputs=dense2)

# Freeze the first 6 layers and leave the rest trainable. This only affects training;
# the cells in this notebook only run inference with model_test3.
# NOTE(review): the cut-off 6 is hard-coded and does not follow total1; confirm which
# layers it is meant to freeze (it may reach past the black-box backbone).
for layer in model_test3.layers[:6]:
    layer.trainable = False
for layer in model_test3.layers[6:]:
    layer.trainable = True

# Compiling the model
model_test3.compile(optimizer=Adam(lr=1e-4, clipvalue=0.5), loss='mean_squared_error', metrics=['mean_squared_error'])


# Load the model weights and compare accuracy and AUC

In [None]:
# Point model1 at the black-box checkpoint, compile it as a binary classifier, and
# report [loss, accuracy] on the shuffled validation set.
ckpt_path = best_blackbox_ckpt
best_ckpt = os.path.join(ckpt_path, 'best_model.hdf5')
model1.compile(
    optimizer=Adam(lr=1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model1.load_weights(best_ckpt)
model1.evaluate(data_valid, labels_valid)


In [None]:
# Accuracy of the self-interpretable model on the validation set at a 0.5 threshold.
test_labels_r = []  # NOTE(review): never populated in the visible cells
ckpt_path = best_interpretable_ckpt
best_ckpt = os.path.join(ckpt_path, 'best_selfinterpretable_model.hdf5')
model_test3.load_weights(best_ckpt)

# One batched predict() over the whole set instead of one model call per image; each
# list entry is still that image's output array, as predict(single_image)[0] would give.
test_prediction_r = list(model_test3.predict(data_valid))

# A prediction is correct when it lies strictly on the label's side of 0.5; a score of
# exactly 0.5 counts as wrong for either class.
probs = np.asarray(test_prediction_r).reshape(-1)
correct = ((probs > 0.5) & (labels_valid == 1)) | ((probs < 0.5) & (labels_valid == 0))
count = int(np.count_nonzero(correct))

print(count/len(labels_valid))

In [None]:
from sklearn.metrics import roc_curve,auc
# Load each model from its own configured checkpoint directory. The shared `ckpt_path`
# is left pointing at the interpretable-model directory by an earlier cell, so it must
# not be used for the black-box weights.
best_ckpt = os.path.join(best_blackbox_ckpt, 'best_model.hdf5')
model1.load_weights(best_ckpt)
pred_bb = model1.predict(data_valid)

best_ckpt = os.path.join(best_interpretable_ckpt, 'best_selfinterpretable_model.hdf5')
model_test3.load_weights(best_ckpt)
pred_si = model_test3.predict(data_valid)

# ROC curves; predictions are (N, 1) so flatten them to the 1-D scores roc_curve expects.
fpr_bb, tpr_bb, thresholds_bb = roc_curve(labels_valid, pred_bb.ravel())
fpr_si, tpr_si, thresholds_si = roc_curve(labels_valid, pred_si.ravel())
auc_bb = auc(fpr_bb, tpr_bb)
auc_si = auc(fpr_si, tpr_si)

plt.figure()
plt.plot(fpr_bb, tpr_bb, label=f'Black-Box Classifier (AUC = {auc_bb:.2f})', color='blue')
plt.plot(fpr_si, tpr_si, label=f'Self-Interpretable (AUC = {auc_si:.2f})', color='red')

plt.xlabel('False Positive Rate', fontsize=14)
plt.ylabel('True Positive Rate', fontsize=14)
plt.title('ROC Curve Comparison', fontsize=14)
plt.legend(loc='lower right', fontsize=12)

plt.show()

# Visualizing E-maps for the abnormal class

In [None]:
# E-map for one abnormal-class image (an `sp` sample, labelled 1 in the data cell).
t = 331
data_valid_1 = sp[t]
# Use the interpretable checkpoint directory explicitly rather than relying on
# `ckpt_path` having been set by an earlier cell.
best_ckpt = os.path.join(best_interpretable_ckpt, 'best_selfinterpretable_model.hdf5')
model_test3.load_weights(best_ckpt)

# Sub-model exposing model_test3.layers[-3]; presumably the single-channel Conv2D map
# that feeds Flatten -> Dense(1) (its output is reshaped to im_size x im_size below).
model_map_relu = Model(inputs=model_test3.inputs, outputs=model_test3.layers[-3].output)

# Single prepared input (float32, shape (1, im_size, im_size, 1)) reused for the map
# and for the final score, so both see exactly the same data.
x_single = np.reshape(data_valid_1.astype('float32'), [-1, im_size, im_size, 1])
feature_maprelu = model_map_relu.predict(x_single)

# Per-pixel contribution: map value times that pixel's weight in the final Dense layer.
weights = model_test3.layers[-1].get_weights()[0]
weights_reshape = np.reshape(weights, [im_size, im_size])
feat = np.reshape(feature_maprelu, [im_size, im_size])
feat2 = np.multiply(feat, weights_reshape)

# Smooth for display. GaussianBlur's third positional parameter is sigmaX (not the border
# type), so sigma is given explicitly as 4 (the value of cv2.BORDER_DEFAULT that used to
# occupy that slot); the border mode is passed by keyword.
feat3_1 = cv2.GaussianBlur(feat2, (19, 19), sigmaX=4, borderType=cv2.BORDER_DEFAULT)

# Raw image.
plt.figure()
plt.imshow(data_valid_1, cmap='gray')
plt.xticks([])
plt.yticks([])

# Raw image with the smoothed E-map overlaid.
plt.figure()
plt.imshow(data_valid_1, cmap='gray')
plt.imshow(feat3_1, cmap='jet', alpha=0.6)
plt.xticks([])
plt.yticks([])

print(model_test3.predict(x_single))