# Detection of Atrial Fibrillation (AF) Recurrence Using Deep Learning Approaches on Segmented Pulmonary Vein Scans

In [2]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import re

from scipy.ndimage import zoom
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.cluster import DBSCAN

import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

import seaborn as sns
import matplotlib.pyplot as plt
from nilearn import plotting
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc
import matplotlib.image
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

In [5]:
# define directory holding the segmented NIfTI (.nii.gz) scans
input_folder = './vanderbilt/'  # use './small_batch_test' for testing

# load demographics / phenotype data
# dem = pd.read_excel("./CCF_CT_demographic.xlsx") # CCF data
dem = pd.read_csv("./vanderbilt_ct_phenotype_2-14-23.csv")

# output directories; exist_ok=True avoids the check-then-create race and makes re-runs a no-op
os.makedirs("./projections", exist_ok=True)
os.makedirs("./slices", exist_ok=True)
os.makedirs("./plots", exist_ok=True)

# define desired output voxel size
output_spacing = (1.0, 1.0, 1.0)  # 1 mm isotropic spacing

# define number of k-means clusters to create
n_clusters = 2

# accumulators for the per-scan results produced by the loading cell
resampled_data = []
projected_data = []
scan_IDs = []

dem.head()

Unnamed: 0,study_id,age_ablation,gender,race,ethnicity,pt_height,weight,htn,diabetes,chf,...,add_ablation___6,hx_cti_ablation,cryo,discharged_on_aad,recurrence,date_of_recur,time_to_recur,cont_monitor,mri_ct,time_scan_ablation
0,10407,55.06,1,4,1,172.7,94.3,0,0,0,...,0,0,0,4,1,9/28/11,203.0,0,2,0
1,10411,77.85,1,4,1,178.0,91.0,1,1,1,...,0,0,0,3,0,,,1,2,18
2,10415,60.95,1,4,1,193.0,111.0,1,1,0,...,0,0,0,1,1,8/20/12,298.0,0,2,111
3,10422,56.27,1,4,1,175.0,103.0,1,1,0,...,0,0,0,1,0,,,0,2,1
4,10461,67.6,0,4,1,163.0,89.0,1,1,1,...,0,0,0,4,1,11/7/12,258.0,1,2,1


In [None]:
def center_pad_or_crop(volume, target_shape):
    """Return `volume` zero-padded and/or center-cropped to exactly `target_shape`.

    Each axis is handled independently. An axis shorter than its target is padded
    with zeros split evenly on both sides (the odd voxel, if any, goes after). An
    axis longer than its target is cropped to its central `target` voxels (the odd
    voxel, if any, is removed from the end).
    """
    crop = []
    pad_width = []
    for size, target in zip(volume.shape, target_shape):
        excess = max(size - target, 0)
        start = excess // 2
        crop.append(slice(start, start + target))
        missing = max(target - size, 0)
        pad_width.append((missing // 2, missing - missing // 2))
    return np.pad(volume[tuple(crop)], pad_width, mode='constant', constant_values=0)


print("Loading and resampling NIfTI files...")

# Start from empty accumulators so re-running this cell does not append duplicates.
resampled_data = []
projected_data = []
scan_IDs = []

# Every resampled volume is brought to this shape (in output_spacing voxels) so that all
# projections have the same size and can be stacked into one array downstream.
target_shape = (500, 500, 500)

# os.listdir returns entries in arbitrary order; sorting makes the scan order (and therefore
# the seeded train/test split later on) reproducible.
nifti_files = sorted(name for name in os.listdir(input_folder) if name.endswith('.nii.gz'))

for file_name in tqdm(nifti_files, desc='Progress', unit='image'):
    id_match = re.search(r"Cardiac_(\d+)_", file_name)
    if id_match is None:
        raise ValueError(f"Cannot extract a scan ID from file name {file_name!r}")

    file_path = os.path.join(input_folder, file_name)
    img = nib.load(file_path)

    # resample to output_spacing: factor = input spacing / output spacing per axis
    input_spacing = img.header.get_zooms()
    resampling_factors = tuple(np.array(input_spacing) / np.array(output_spacing))
    resampled_image = zoom(img.get_fdata(), resampling_factors, order=1)  # order=1: linear interpolation

    # Pad, or crop if the volume is larger than target_shape, to a common size.
    # Padding alone left oversized volumes at their own shape.
    resampled_image_padded = center_pad_or_crop(resampled_image, target_shape)
    # NOTE(review): get_fdata() yields float64, so each 500^3 volume kept here is ~1 GB;
    # resampled_data appears unused in the rest of this notebook — consider not storing it.
    resampled_data.append(resampled_image_padded)

    # create a 2D projection of the 3D volume by summing along the third axis
    proj = np.sum(resampled_image_padded, axis=2)
    proj_path = "./projections/" + file_name[:-7] + ".png"  # [:-7] strips '.nii.gz'
    matplotlib.image.imsave(proj_path, proj)

    projected_data.append(proj)
    # append the ID only after the scan was fully processed so scan_IDs stays aligned with projected_data
    scan_IDs.append(id_match.group(1))

Loading and resampling NIfTI files...


Progress:  13%|███▍                       | 95/749 [14:06<1:14:14,  6.81s/image]

In [None]:
# Build the AF-recurrence label vector aligned row-by-row with scan_IDs (and hence projected_data).
# extract the study_id and recurrence columns; study_id as string to match the regex-extracted scan IDs
af_recur_status = dem[['study_id', 'recurrence']].astype({"study_id": "string"})

# filter the dataframe to only include rows whose study_id is in the scan_IDs list
af_recur_status = af_recur_status[af_recur_status['study_id'].isin(scan_IDs)]

# every loaded scan needs a label, otherwise images and labels would be misaligned
missing_ids = sorted(set(scan_IDs) - set(af_recur_status['study_id']))
if missing_ids:
    raise KeyError(f"No recurrence label found for scan IDs: {missing_ids}")

# sort the dataframe based on the order of the scan_IDs list
af_recur_status = af_recur_status.set_index('study_id').loc[scan_IDs].reset_index()

af_recur = np.array(af_recur_status['recurrence'].values)

# duplicated study_ids in the phenotype file would yield more labels than scans
if len(af_recur) != len(scan_IDs):
    raise ValueError(f"Expected {len(scan_IDs)} labels, got {len(af_recur)} (duplicate study_id rows?)")

af_recur_status

## PCA

In [None]:
# Unroll each 2D projection into a single row, giving an (n_scans, n_pixels) matrix.
flattened_images = np.array([proj.ravel() for proj in projected_data])

# Keep as many principal components as the data allow (the smaller of the number of scans
# and the number of pixels); only the first two are plotted below.
max_components = min(len(flattened_images), flattened_images[0].size)
pca = PCA(n_components=max_components)
pca.fit(flattened_images)
reduced_images = pca.transform(flattened_images)

# Scatter the scans on the first two components, coloured by recurrence label.
pca_df = pd.DataFrame({
    'PC1': reduced_images[:, 0],
    'PC2': reduced_images[:, 1],
    'Recurrence': af_recur,
})
fig, ax = plt.subplots()
pc_plot = sns.scatterplot(data=pca_df, x="PC1", y="PC2", hue="Recurrence", ax=ax)
pc_plot.set_title("PCA of 2D Projections")
plt.show()

## K-Means Clustering

In [None]:
# Partition the PCA-reduced projections into n_clusters groups (n_clusters is set in the config cell).
print("Performing K-means clustering on projections...")
# fixed random_state makes the centroid initialisation, and therefore the labels, reproducible
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
kmeans.fit(reduced_images)
labels = kmeans.labels_

# Print the labels for each image (1-based for readability)
for i, label in enumerate(labels):
    print(f"Image {i+1} belongs to cluster {label+1}")

## DBSCAN

In [None]:
# Density-based clustering of the flattened projections; DBSCAN labels noise points as -1.
# NOTE(review): eps=0.5 is measured in the raw, unscaled units of the flattened projections
# (sums of voxel values along z) — confirm this radius is meaningful at that scale.
dbscan = DBSCAN(eps=0.5, min_samples=5)

# fit the algorithm to the flattened image data
clusters = dbscan.fit_predict(flattened_images)

# Count clusters, excluding the noise label. Stored under its own name so the config
# constant n_clusters (used by k-means) is not overwritten.
n_dbscan_clusters = len(set(clusters)) - (1 if -1 in clusters else 0)
print("Number of clusters:", n_dbscan_clusters)
print("Cluster indices:", clusters)

## Convolutional Neural Network (CNN)

In [None]:
# Train a small 2D CNN to predict AF recurrence from the projections.

# Seed TensorFlow so weight initialisation and per-epoch shuffling repeat across runs
# (the split below is already seeded with random_state=42).
tf.random.set_seed(42)

# Split data into train (80%) and test (20%) sets
X_train, X_test, y_train, y_test = train_test_split(np.array(projected_data), af_recur, test_size=0.2, random_state=42)

# Add a trailing channel axis, (n, H, W) -> (n, H, W, 1), without hard-coding the image size.
# Use float32 (Keras' default dtype) to halve memory relative to float64.
X_train = X_train[..., np.newaxis].astype('float32')
X_test = X_test[..., np.newaxis].astype('float32')

# Standardise using training-set statistics only, so no test information leaks in.
# The projections are raw sums over up to 500 voxels and can be large in magnitude.
train_mean = X_train.mean()
train_std = X_train.std() or 1.0  # guard against a constant training set
X_train = (X_train - train_mean) / train_std
X_test = (X_test - train_mean) / train_std

# Define the model architecture
model = Sequential()
model.add(Conv2D(32, (5, 5), activation='relu', input_shape=X_train.shape[1:]))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (5, 5), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))  # probability of recurrence

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model
# NOTE(review): the test set doubles as validation data; that is only safe while no
# choices (epochs, architecture) are tuned on the validation curve.
n_epochs = 2
history = model.fit(X_train, y_train, epochs=n_epochs, batch_size=8, validation_data=(X_test, y_test))

In [None]:
# Training vs. validation accuracy per epoch, read from the Keras History object.
fig, ax = plt.subplots()
for metric_key in ('accuracy', 'val_accuracy'):
    ax.plot(history.history[metric_key])
ax.set_title('Model Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend(['Train', 'Validation'], loc='upper left')
plt.show()

In [None]:
# Evaluate the trained CNN on the held-out test set with an ROC curve.
y_pred = model.predict(X_test)  # shape (n_test, 1): sigmoid outputs

# Sweep all decision thresholds over the predicted scores, then integrate the curve.
fpr, tpr, thresholds = roc_curve(y_test, y_pred.ravel())
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic (ROC)')
ax.legend(loc="lower right")
plt.show()