In [None]:
import numpy as np
import os
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.layers import Input, Conv3D, BatchNormalization, Activation, Add, GlobalAveragePooling3D, Dense, Dropout, Masking
from tensorflow.keras.models import Model
import keras_tuner as kt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, confusion_matrix
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.scorecam import ScoreCAM
from tf_keras_vis.gradcam_plus_plus import GradcamPlusPlus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore
import ants

In [None]:
def add_rician_noise(image, mask):
    """Corrupt `image` with Rician-style noise and re-apply `mask`.

    Two independent zero-mean Gaussian fields (a "real" and an "imaginary"
    component) are drawn with a shared standard deviation, and the magnitude
    sqrt((image + n_real)^2 + n_imag^2) is returned.

    Args:
        image: Volume to perturb; its shape defines the shape of the noise.
        mask: Array multiplied into the result after a trailing axis is
            appended (assumes `image` carries a trailing channel axis that
            `mask` lacks -- TODO confirm against callers).

    Returns:
        A TensorFlow tensor: the noisy magnitude image multiplied by the mask.
    """
    # One standard deviation per call, drawn uniformly from [0.1, 0.5).
    sigma = np.random.uniform(low=0.1, high=0.5, size=(1,))

    real_part = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=sigma)
    imag_part = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=sigma)

    magnitude = tf.sqrt(tf.square(image + real_part) + tf.square(imag_part))

    # Zero everything outside the mask again; the noise is added everywhere.
    mask_with_channel = tf.cast(mask[..., np.newaxis], tf.float32)
    return tf.multiply(magnitude, mask_with_channel)


def augment_image_3d(image, mask):
    """Return an augmented copy of `image`.

    The only augmentation currently applied is `add_rician_noise`.
    """
    return add_rician_noise(image, mask)

def augment_dataset(images, labels, mask, positive_augmentation_factor=2):
    """Build an augmented dataset, oversampling the positive class.

    Every sample contributes one augmented copy. Samples whose label equals 1
    contribute `positive_augmentation_factor` additional augmented copies.
    The un-augmented originals are not included in the output.

    Args:
        images: Indexable collection of volumes, aligned with `labels`.
        labels: Iterable of labels; `label == 1` marks the positive class.
        mask: Mask forwarded to `augment_image_3d` for every copy.
        positive_augmentation_factor: Extra copies generated per positive
            sample.

    Returns:
        Tuple `(augmented_images, augmented_labels)` of NumPy arrays.
    """
    augmented_images, augmented_labels = [], []

    for idx, label in enumerate(labels):
        source = images[idx]
        # Positives get the extra copies on top of the one every sample gets.
        extra_copies = max(positive_augmentation_factor, 0) if label == 1 else 0
        for _ in range(extra_copies + 1):
            augmented_images.append(augment_image_3d(source, mask))
            augmented_labels.append(label)

    return np.array(augmented_images), np.array(augmented_labels)

In [None]:
def conv_block(input_tensor, filters, kernel_size, strides=(1, 1, 1), padding='same'):
    """Apply Conv3D -> BatchNormalization -> ReLU to `input_tensor`.

    Args:
        input_tensor: Keras tensor to convolve.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Convolution padding mode.

    Returns:
        The activated output tensor.
    """
    convolved = Conv3D(filters, kernel_size, strides=strides, padding=padding)(input_tensor)
    normalized = BatchNormalization()(convolved)
    return Activation('relu')(normalized)

def identity_block(input_tensor, filters, kernel_size):
    """Residual block whose skip connection is the unmodified input.

    Main path: Conv3D -> BN -> ReLU -> Conv3D -> BN. The input is then added
    back and a final ReLU is applied. Because the input is added as-is, its
    shape must match the main path's output (same spatial size, `filters`
    channels).

    Args:
        input_tensor: Keras tensor entering the block.
        filters: Number of filters for both convolutions.
        kernel_size: Kernel size for both convolutions.

    Returns:
        The block's output tensor.
    """
    # First conv stage, with activation.
    main_path = Conv3D(filters, kernel_size, padding='same')(input_tensor)
    main_path = BatchNormalization()(main_path)
    main_path = Activation('relu')(main_path)

    # Second conv stage; activation is deferred until after the residual sum.
    main_path = Conv3D(filters, kernel_size, padding='same')(main_path)
    main_path = BatchNormalization()(main_path)

    merged = Add()([main_path, input_tensor])
    return Activation('relu')(merged)

def conv_identity_block(input_tensor, filters, kernel_size, strides=(2, 2, 2)):
    """Residual block with a projection (1x1x1 convolution) shortcut.

    Main path: strided Conv3D -> BN -> ReLU -> Conv3D -> BN.
    Shortcut:  strided 1x1x1 Conv3D -> BN, so its output matches the main
    path in both spatial size and channel count before the two are added.

    Args:
        input_tensor: Keras tensor entering the block.
        filters: Number of filters for every convolution in the block.
        kernel_size: Kernel size for the two main-path convolutions.
        strides: Strides of the first main-path convolution and of the
            shortcut convolution.

    Returns:
        The block's output tensor.
    """
    # Main path (built first so layer creation order is unchanged).
    main_path = Conv3D(filters, kernel_size, strides=strides, padding='same')(input_tensor)
    main_path = BatchNormalization()(main_path)
    main_path = Activation('relu')(main_path)

    main_path = Conv3D(filters, kernel_size, padding='same')(main_path)
    main_path = BatchNormalization()(main_path)

    # Projection shortcut.
    projected = Conv3D(filters, kernel_size=(1, 1, 1), strides=strides, padding='same')(input_tensor)
    projected = BatchNormalization()(projected)

    merged = Add()([main_path, projected])
    return Activation('relu')(merged)

def resnet_3d(input_shape, dropout_rate=0.5):
    """Build a 3D residual network with a single sigmoid output.

    Layout: Masking -> 7x7x7 strided conv stem -> four residual stages ->
    global average pooling -> dropout -> Dense(1, sigmoid). Each stage is one
    strided `conv_identity_block` followed by several `identity_block`s.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.
        dropout_rate: Dropout rate applied before the output layer.

    Returns:
        An uncompiled `Model` mapping the input volume to one probability.
    """
    # (filters, number of identity blocks) for each residual stage.
    stage_layout = ((64, 3), (128, 4), (256, 6), (512, 3))
    kernel = (3, 3, 3)

    input_layer = Input(shape=input_shape)
    # NOTE(review): the mask produced here is presumably not consumed by the
    # Conv3D layers that follow -- confirm whether this layer has any effect.
    masked_input = Masking(mask_value=0.0)(input_layer)
    features = conv_block(masked_input, 64, (7, 7, 7), strides=(2, 2, 2))

    for stage_filters, identity_count in stage_layout:
        # Each stage halves the spatial resolution, then refines at that scale.
        features = conv_identity_block(features, stage_filters, kernel, strides=(2, 2, 2))
        for _ in range(identity_count):
            features = identity_block(features, stage_filters, kernel)

    pooled = GlobalAveragePooling3D()(features)
    pooled = Dropout(rate=dropout_rate)(pooled)
    output_layer = Dense(1, activation='sigmoid')(pooled)

    return Model(inputs=input_layer, outputs=output_layer)

In [None]:
# Input shape and batch size shared by the model-building and training cells.
# No class count is defined here: the network ends in a single sigmoid unit.
# The spatial dims (64, 82, 32) match the 32:96, 22:104, 28:60 crop applied
# when the data is loaded; the trailing 1 is the channel axis.
input_shape = (64, 82, 32, 1)  # (dim0, dim1, dim2, channels)
batch_size = 12

In [None]:
# --- Load template and data, drop unusable subjects, preprocess, split -------

# Voxel ranges kept from every volume and from the template mask. The result
# is 64 x 82 x 32, i.e. the spatial part of `input_shape`.
CROP = np.s_[32:96, 22:104, 28:60]
# Intensities are clipped to this (low, high) window before min-max scaling.
INTENSITY_WINDOW = (10, 100)


def _drop_subjects(arr, indices):
    """Return a copy of `arr` without the axis-0 entries listed in `indices`.

    Indices beyond the end of `arr` are ignored rather than raising.
    """
    keep = ~np.isin(np.arange(len(arr)), indices)
    return arr[keep]


template = ants.image_read('CT_template_resamp.nii')
# Binary mask of the template: 1.0 where the template is positive, else 0.0.
template_mask = np.where(template.numpy() > 0, 1.0, 0.0)

# First exclusion list: positions in the arrays exactly as stored on disk.
bad_subj = [5, 28, 29, 30, 43, 65, 66, 70, 89, 94, 99, 101, 102, 112, 116, 118, 119, 134, 135, 138, 151]
x_arr = _drop_subjects(np.load('x_arr.npy'), bad_subj)
y_arr = _drop_subjects(np.load('y_arr.npy'), bad_subj)
assert len(x_arr) == len(y_arr), "x_arr and y_arr must contain the same number of subjects"

# Per subject: clip to the intensity window, crop, min-max scale to [0, 1],
# then zero everything outside the (cropped) template mask.
mask_cropped = template_mask[CROP]
x_resamp = []
for subj in x_arr:
    subj_windowed = np.clip(subj, *INTENSITY_WINDOW)[CROP]
    subj_min = subj_windowed.min()
    subj_range = subj_windowed.max() - subj_min
    if subj_range == 0:
        # A flat volume would give 0/0 = NaN everywhere; scale it to all zeros.
        subj_range = 1
    subj_norm = (subj_windowed - subj_min) / subj_range
    x_resamp.append(subj_norm * mask_cropped)
x_arr = np.stack(x_resamp, axis=0)

# Second exclusion list. NOTE: these are positions in the arrays *after* the
# first exclusion above, not positions in the files on disk.
bad_subj_2 = [41,42,44,59, 75, 86, 92, 97, 98, 99, 105, 108, 109, 110, 111, 116, 121,125,128, 131, 135, 139]
x_arr = _drop_subjects(x_arr, bad_subj_2)
y_arr = _drop_subjects(y_arr, bad_subj_2)

# Add the trailing channel axis expected by the network.
x_arr = x_arr[..., np.newaxis]
assert x_arr.shape[1:] == (64, 82, 32, 1), f"unexpected sample shape {x_arr.shape[1:]}"

# Stratified 80/20 hold-out split with a fixed seed.
x_train, x_val, y_train, y_val = train_test_split(x_arr, y_arr, random_state=40, test_size=0.2, stratify=y_arr)

In [None]:
# Class weights inversely proportional to class frequency in the hold-out
# training split, scaled so that the two classes contribute equally overall.
n_train_samples = y_train.shape[0]
n_train_positive = y_train.sum()
n_train_negative = len(y_train) - n_train_positive

weight_for_0 = (1 / n_train_negative) * (n_train_samples / 2.0)
weight_for_1 = (1 / n_train_positive) * (n_train_samples / 2.0)
class_weight = {0: weight_for_0, 1: weight_for_1}
class_weight

In [None]:
from sklearn.model_selection import StratifiedKFold

# 5-fold stratified cross-validation. Shuffling uses a fixed seed, so
# rebuilding the splitter with the same arguments reproduces the same folds.
folds = list(StratifiedKFold(n_splits=5, shuffle=True, random_state=42).split(x_arr, y_arr))
auc_per_fold = []   # validation AUC at each fold's best (checkpointed) epoch
loss_per_fold = []  # validation loss at that same epoch
for j, (train_idx, val_idx) in enumerate(folds):
    #Split data
    x_train_cv = x_arr[train_idx]
    y_train_cv = y_arr[train_idx]
    x_val_cv = x_arr[val_idx]
    y_val_cv = y_arr[val_idx]

    # Release the previous fold's model state before building a new network,
    # so memory does not accumulate across folds.
    tf.keras.backend.clear_session()

    # Create the model
    model = resnet_3d(input_shape, dropout_rate=0.3)

    # The metric is named explicitly. Without a name Keras auto-numbers it
    # ("auc", "auc_1", ...) depending on how many AUC objects were created
    # earlier in the session, and a checkpoint monitor that guesses the suffix
    # silently stops saving as soon as the cell is re-run.
    model.compile(
        optimizer=tf.keras.optimizers.Nadam(learning_rate=0.0001),
        loss=tf.keras.losses.BinaryFocalCrossentropy(),
        metrics=[tf.keras.metrics.AUC(name='auc')],
    )
    filepath = "best_model_weights_" + str(j) + ".h5"
    monitor_val = 'val_auc'

    # Per-fold class weights, inversely proportional to class frequency.
    n_train_cv = y_train_cv.shape[0]
    n_positive_cv = y_train_cv.sum()
    weight_for_0 = (1 / (n_train_cv - n_positive_cv)) * (n_train_cv / 2.0)
    weight_for_1 = (1 / n_positive_cv) * (n_train_cv / 2.0)
    cv_class_weight = {0: weight_for_0, 1: weight_for_1}

    # Keep only the weights of the epoch with the highest validation AUC.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=filepath, monitor=monitor_val,
                                                    mode="max", save_weights_only=True,
                                                    save_best_only=True)
    history = model.fit(x_train_cv, y_train_cv, validation_data=(x_val_cv, y_val_cv),
                        batch_size=batch_size, epochs=120, class_weight=cv_class_weight,
                        callbacks=[checkpoint])

    # Record the metrics of the checkpointed epoch (first epoch reaching the
    # maximum validation AUC).
    val_auc_history = history.history[monitor_val]
    best_epoch = int(np.argmax(val_auc_history))
    auc_per_fold.append(val_auc_history[best_epoch])
    loss_per_fold.append(history.history['val_loss'][best_epoch])

In [None]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_curve, auc
import seaborn as sns

# Set style for the plot
sns.set(style="whitegrid")
colors = sns.color_palette("husl", 5)

# Rebuild the folds with the same splitter arguments (and seed) that were used
# for training, so each checkpoint is scored on the split it never trained on.
folds = list(StratifiedKFold(n_splits=5, shuffle=True, random_state=42).split(x_arr, y_arr))
auc_per_fold = []  # ROC AUC of each fold's checkpointed model on its validation split

fig, ax = plt.subplots()
for j, (train_idx, val_idx) in enumerate(folds):
    # Only the validation split is needed here; nothing is trained in this cell.
    x_val_cv = x_arr[val_idx]
    y_val_cv = y_arr[val_idx]

    # Rebuild the architecture and restore this fold's best weights. The model
    # is still compiled so that the `model` left over after the loop can be
    # passed to `model.evaluate` later in the notebook.
    model = resnet_3d(input_shape, dropout_rate=0.3)
    model.compile(
        optimizer=tf.keras.optimizers.Nadam(learning_rate=0.0001),
        loss=tf.keras.losses.BinaryFocalCrossentropy(),
        metrics=[tf.keras.metrics.AUC(name='auc')],
    )
    filepath = "best_model_weights_" + str(j) + ".h5"
    model.load_weights(filepath)

    results = model.predict(x_val_cv, batch_size=batch_size)

    # Calculate ROC curve and AUC (predictions flattened from (n, 1) to (n,)).
    fpr, tpr, _thresholds = roc_curve(y_val_cv, results.ravel())
    roc_auc = auc(fpr, tpr)
    auc_per_fold.append(roc_auc)

    # Plot ROC curve for the current fold
    ax.plot(fpr, tpr, color=colors[j], lw=3, label=f'Fold {j + 1} (AUC = {roc_auc:.2f})')

# Plot diagonal line (chance level)
ax.plot([0, 1], [0, 1], linestyle='--', color='gray', lw=2)

# Customize plot details
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=14)
ax.set_ylabel('True Positive Rate', fontsize=14)
ax.set_title('ROC Curve for Each Fold', fontsize=16, weight='bold')
ax.legend(loc='lower right', fontsize=12)
ax.grid(True, linestyle='--', linewidth=0.7, alpha=0.7)
ax.tick_params(axis='both', which='major', labelsize=12)

fig.savefig('roc_curve_plot.png', dpi=300, bbox_inches='tight')

# Display the plot
plt.show()

In [None]:
# Evaluate on the hold-out split and pick two examples for the saliency maps.
#
# NOTE(review): `model` here is whatever the preceding cell left behind, i.e.
# the network restored for the LAST cross-validation fold. That fold's training
# indices and the hold-out split (`x_val`) are both drawn from the same `x_arr`,
# so `x_val` may contain samples this model was trained on -- confirm whether
# this number is meant to be an unbiased estimate.
results = model.evaluate(x_val, y_val, batch_size=12)
predictions = model.predict(x_val)
# NOTE(review): predictions above are computed on `x_val`, but the examples
# below are indexed from `x_val_cv` (the last fold's validation split). Confirm
# which array indices 17 and 0 were chosen against.
ht_example_correct = x_val_cv[17] #Hand selected examples
ht_example_incorrect = x_val_cv[0] #Hand selected examples

In [None]:
# --- Saliency maps for the two hand-picked examples, exported as NIfTI -------

def _with_template_header(volume):
    """Wrap a NumPy volume as an ANTs image carrying the template's header info."""
    return ants.copy_image_info(template_resamp, ants.from_numpy(volume.squeeze()))


def _smooth_and_standardize(image, sigma=3):
    """Smooth `image`, then z-score it; returns (smoothed, standardized)."""
    smoothed = image.smooth_image(sigma)
    return smoothed, (smoothed - smoothed.mean()) / smoothed.std()


# ReplaceToLinear is applied to a clone of the model, so `model` itself is
# left untouched. CategoricalScore([0]) selects the network's single output.
saliency = Saliency(model, model_modifier=ReplaceToLinear(), clone=True)
saliency_map_correct = saliency(CategoricalScore([0]), ht_example_correct)
saliency_map_incorrect = saliency(CategoricalScore([0]), ht_example_incorrect)

# NOTE(review): the examples are cropped volumes while this template is read
# uncropped; only header info is copied from it, so the images written below
# presumably keep the uncropped template's origin -- confirm the alignment.
template_resamp = ants.image_read('CT_template_resamp.nii')
x_correct_ants = _with_template_header(ht_example_correct)
x_incorrect_ants = _with_template_header(ht_example_incorrect)
sal_correct_ants = _with_template_header(saliency_map_correct)
sal_incorrect_ants = _with_template_header(saliency_map_incorrect)

# Smooth and z-score each map so the two are on a comparable scale before
# taking their difference.
sal_correct_smooth, sal_correct_norm = _smooth_and_standardize(sal_correct_ants)
sal_incorrect_smooth, sal_incorrect_norm = _smooth_and_standardize(sal_incorrect_ants)
sal_diff = sal_correct_norm - sal_incorrect_norm

ants.image_write(x_correct_ants, 'ht_example_correct.nii')
ants.image_write(sal_correct_norm, 'sal_example_correct.nii')
ants.image_write(x_incorrect_ants, 'ht_example_incorrect.nii')
ants.image_write(sal_incorrect_norm, 'sal_example_incorrect.nii')
ants.image_write(sal_diff, 'sal_example_diff.nii')

# Stored under its own name: `template_mask` is a NumPy array used by the
# data-loading cell, and rebinding it to an ANTs image here would leave the
# notebook in a different state depending on which cells had been run.
template_reduced_mask = ants.image_read('template_reduced_mask.nii')
sal_diff_masked = sal_diff * template_reduced_mask
ants.image_write(sal_diff_masked, 'sal_example_diff_masked.nii')