In [None]:
!pip install torchio

In [None]:
import os
import random
import zipfile

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import torchio as tio
from scipy import ndimage
from sklearn.metrics import confusion_matrix
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Display the TensorFlow version this notebook is running against.
tf.__version__

### Import data from the HDF5 file

Checking the data

In [None]:
# Path of the HDF5 file that the following cells open with h5py.
PATH_3D_H5 = 'Images/tof_data.hdf5'

In [None]:
# Quick look at the file: available keys and the shapes of images and labels.
with h5py.File(PATH_3D_H5, 'r') as dd:
    print(list(dd.keys()))
    for dataset_name in ('X', 'stroke'):
        print(dd[dataset_name].shape)

In [None]:
def decode_data(string):
    """Decode an iterable of byte strings into a list of Python strings.

    Each element is decoded as UTF-8; bytes that are not valid UTF-8 are
    dropped silently (error handler "ignore").
    """
    texts = []
    for raw in string:
        texts.append(raw.decode("UTF-8", "ignore"))
    return texts

# Read everything into memory while the file is open.
with h5py.File(PATH_3D_H5, 'r') as h5:
    print('H5-file: ', list(h5.keys()))

    # X: image matrices, pat: patient IDs, Y_pat: patient labels (1=stroke, 0=TIA)
    X, pat, Y_pat = (h5[name][:] for name in ("X", "pat", "stroke"))
    # Paths to the images, stored as bytes and decoded to str
    path = decode_data(h5["path"][:])

# All four collections are expected to have one entry per image.
print(len(X), len(Y_pat), len(pat), len(path))

### Train, validation, test split (cf. notebook from Lisa)

There are 508 patients with TOF-MRA images. For training, validation and testing, we need 3 sets.
- Training set: 304 images (~60%) -> 211 stroke, 93 non-stroke
- Validation set: 102 images (~20%) -> 70 stroke, 32 non-stroke
- Test set: 102 images (~20%) -> 70 stroke, 32 non-stroke

Every set contains the same proportion of stroke vs. non-stroke patients (approx. 69% vs. 31%).

In [None]:
# Collect the unique patient IDs of each class separately.
stroke_idx = np.where(Y_pat == 1)[0]
non_stroke_idx = np.where(Y_pat == 0)[0]
stroke_patients = np.unique(pat[stroke_idx])
non_stroke_patients = np.unique(pat[non_stroke_idx])
print(len(stroke_patients), len(non_stroke_patients))

In [None]:
# randomly shuffle the stroke and non-stroke patients
# Drawing all elements without replacement yields a permutation; the fixed seed
# makes it repeatable. The order of the two calls matters for the result.
# NOTE: despite the `_test` suffix both arrays hold ALL patients of their class;
# only their leading entries are taken for the test set in the next cell.
np.random.seed(1)
stroke_patients_test = np.random.choice(stroke_patients, size=len(stroke_patients), replace=False)
non_stroke_patients_test = np.random.choice(non_stroke_patients, size=len(non_stroke_patients), replace=False)

In [None]:
#test set
# The first 70 shuffled stroke patients and the first 32 shuffled non-stroke
# patients form the test set; they are then permuted so both classes are mixed.
np.random.seed(1)
test_tmp = np.concatenate([stroke_patients_test[:70], non_stroke_patients_test[:32]], axis=0)
test = np.random.choice(test_tmp, size=len(test_tmp), replace=False)

In [None]:
# Patients that were not assigned to the test set (original order is kept).
stroke_patients_run = stroke_patients[~np.isin(stroke_patients, test)]
non_stroke_patients_run = non_stroke_patients[~np.isin(non_stroke_patients, test)]

# randomly shuffle the remaining patients of each class
np.random.seed(100)
stroke_patients_tmp = np.random.choice(
    stroke_patients_run, size=len(stroke_patients_run), replace=False
)
non_stroke_patients_tmp = np.random.choice(
    non_stroke_patients_run, size=len(non_stroke_patients_run), replace=False
)
print(len(stroke_patients_tmp), len(non_stroke_patients_tmp))

In [None]:
# 211 stroke and 93 non-stroke patients go to training; everything that is left
# of each class forms the validation set (open-ended slices take the remainder).
train_tmp = np.concatenate([stroke_patients_tmp[:211], non_stroke_patients_tmp[:93]])
valid_tmp = np.concatenate([stroke_patients_tmp[211:], non_stroke_patients_tmp[93:]])

In [None]:
# randomly shuffle the datasets such that stroke and no-stroke patients are mixed
# The seed is fixed and the three draws happen in this order, so the result is repeatable.
# NOTE(review): get_datasets selects samples by membership and walks `pat` in its
# stored order, so the order produced here does not carry over to the X_* arrays.
np.random.seed(100)
train = np.random.choice(train_tmp, size=len(train_tmp), replace=False)
valid = np.random.choice(valid_tmp, size=len(valid_tmp), replace=False)
test = np.random.choice(test, size=len(test), replace=False)
print(len(train), len(valid), len(test))

In [None]:
def get_datasets(set_i, X, Y_pat, pat, path):
    """Select all samples whose patient ID is contained in ``set_i``.

    Args:
        set_i: Collection of patient IDs that make up the requested subset.
        X: Array of images, indexed by sample along the first axis.
        Y_pat: Per-sample labels, aligned with ``X``.
        pat: Per-sample patient IDs, aligned with ``X``.
        path: Per-sample image paths, aligned with ``X``.

    Returns:
        Tuple ``(X_set, Y_pat_set, pat_set, path_set)`` of arrays holding the
        selected samples in the order in which they appear in ``pat``.
    """
    # np.isin cannot look inside a Python set, so turn it into a sequence first.
    if isinstance(set_i, (set, frozenset)):
        set_i = tuple(set_i)
    # Vectorised membership test; flatnonzero keeps the original sample order.
    idx = np.flatnonzero(np.isin(pat, set_i))
    # Fancy indexing keeps the dtype of every array, also for an empty selection.
    return (
        X[idx],
        np.asarray(Y_pat)[idx],
        np.asarray(pat)[idx],
        np.asarray(path)[idx],
    )

In [None]:
# Images, labels, patient IDs and paths of the training patients.
X_train, Y_train, pat_train, path_train = get_datasets(train, X, Y_pat, pat, path)

In [None]:
# Images, labels, patient IDs and paths of the validation patients.
X_valid, Y_valid, pat_valid, path_valid = get_datasets(valid, X, Y_pat, pat, path)

In [None]:
# Images, labels, patient IDs and paths of the held-out test patients.
X_test, Y_test, pat_test, path_test = get_datasets(test, X, Y_pat, pat, path)

### Preprocess data

Add dimension to arrays for 3D tensor

In [None]:
#X_train = X_train[:,:,:,:,np.newaxis] 
#X_valid = X_valid[:,:,:,:,np.newaxis] 

In [None]:
def resize_volume(img, desired_width=160, desired_height=140, desired_depth=50):
    """Resample a 3-D volume to a fixed size along all three axes.

    Axis 0 is treated as width, axis 1 as height and the last axis as depth.
    Linear interpolation (``order=1``) is used.

    Args:
        img: 3-D array to resample.
        desired_width: Target size of axis 0.
        desired_height: Target size of axis 1.
        desired_depth: Target size of the last axis.

    Returns:
        The resampled array; its shape is the input shape scaled by the
        per-axis zoom factors.
    """
    current_width, current_height = img.shape[0], img.shape[1]
    current_depth = img.shape[-1]
    # Per-axis zoom factor = target size / current size.
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )
    return ndimage.zoom(img, zoom_factors, order=1)

In [None]:
# Resample every training volume with resize_volume and stack the results again.
X_train = np.array([resize_volume(img) for img in X_train])

In [None]:
# Resample every validation volume. The loop must run over X_valid: iterating
# over X_train would replace the validation images with training images and
# break the alignment with Y_valid.
X_valid = np.array([resize_volume(img) for img in X_valid])

### Data Augmentation

- Rotate volumes by random angles

In [None]:
# Take the first 10 training volumes and append a trailing channel axis (5-D result).
# NOTE(review): this rebinds `X`, which until here held the full image array read
# from the HDF5 file; re-running the get_datasets cells afterwards would use this
# small array instead.
X = X_train[:10,:,:,:,np.newaxis] 
#X = X_train[0,:,:]
X.shape

In [None]:
# X is 5-D (sample, axis 0, axis 1, axis 2, channel). Select one sample and one
# position along the last spatial axis so that imshow receives a 2-D image;
# slicing the whole batch would hand it a 3-D array it cannot draw.
plt.imshow(np.squeeze(X[0, :, :, 10]), cmap = 'gray')

In [None]:
# Elastic deformation from torchio; the class has to be referenced through the
# `tio` module alias because it is not imported as a bare name.
transform = tio.RandomElasticDeformation(num_control_points=4, locked_borders=1, max_displacement=1)

In [None]:
# Apply the current transform to the demo array.
# NOTE(review): X is 5-D here (samples first, channel last); torchio transforms
# presumably expect a single 4-D volume with the channel axis first — TODO confirm.
transformed = transform(X)

In [None]:
# Show index 20 along the third axis of the transformed data.
# NOTE(review): whether this yields a 2-D image depends on the shape returned by
# the transform above — TODO confirm.
plt.imshow(np.squeeze(transformed[:,:,20]), cmap = 'gray')

In [None]:
# Candidate augmentations mapped to the weights that are handed to tio.OneOf
# (presumably selection probabilities — TODO confirm against the torchio docs).
transforms_dict = {
    tio.RandomAffine(): 0.75,
    tio.RandomElasticDeformation(): 0.25}

In [None]:
# Elastic deformation with torchio's default parameters.
# NOTE(review): `transform` is reassigned to tio.OneOf(...) in the next cell, so
# this binding only matters if that cell is skipped.
transform = tio.RandomElasticDeformation()

In [None]:
# Compose the candidates from transforms_dict and apply the result to the first
# demo volume.
# NOTE(review): X[0] carries its channel axis last; torchio presumably expects
# the channel axis first — TODO confirm.
transform = tio.OneOf(transforms_dict)
transformed = transform(X[0])

In [None]:
# Show one randomly augmented version of every volume in the demo batch.
n_images = X.shape[0]
columns = 5
# Enough rows to hold all images (ceiling division).
rows = -(-n_images // columns)
fig = plt.figure(figsize = (10, 5)) # total figure size (including all subplots)
fig_all = []
# Array indices are 0-based while subplot positions are 1-based, hence `i + 1`;
# iterating over range(n_images) makes sure the first volume is shown as well.
for i in range(n_images):
    img = transform(X[i])
    ax = fig.add_subplot(rows, columns, i + 1)
    fig_all.append(ax)
    ax.imshow(np.squeeze(img[:,:,0]), cmap = 'gray')
plt.show()

In [None]:
# Show index 10 along the third axis of the most recently transformed volume.
plt.imshow(np.squeeze(transformed[:,:,10]), cmap = 'gray')

In [None]:
@tf.function
def rotate(volume):
    """Rotate the volume by a randomly chosen small angle.

    The rotation itself runs in NumPy/SciPy and is wrapped with
    ``tf.numpy_function`` so it can be used inside a ``tf.data`` pipeline.

    Args:
        volume: Tensor holding one volume.

    Returns:
        Tensor with the same shape and dtype as ``volume``.
    """

    def scipy_rotate(volume):
        # Candidate rotation angles in degrees; one is drawn per call.
        angles = [-20, -10, -5, 5, 10, 20]
        angle = random.choice(angles)
        # reshape=False keeps the array shape; the result has the input's dtype.
        volume = ndimage.rotate(volume, angle, reshape=False)
        # Interpolation can overshoot, so values are clamped to [0, 1].
        # NOTE(review): this assumes intensities are already scaled to [0, 1];
        # no normalisation step is visible in this notebook — TODO confirm.
        return np.clip(volume, 0, 1)

    volume_shape = volume.shape
    # Declare the output dtype as the input dtype: scipy_rotate preserves it, and
    # a fixed float64 declaration would fail for float32 (or other) inputs.
    augmented_volume = tf.numpy_function(scipy_rotate, [volume], volume.dtype)
    # numpy_function drops static shape information; restore it.
    augmented_volume = tf.reshape(augmented_volume, volume_shape)
    return augmented_volume

In [None]:
def train_preprocessing(volume, label):
    """Augment a training volume by rotation and append a channel axis.

    Returns the processed volume together with the unchanged label.
    """
    rotated = rotate(volume)
    # New axis at position 3, i.e. after the three spatial axes.
    with_channel = tf.expand_dims(rotated, axis=3)
    return with_channel, label


def validation_preprocessing(volume, label):
    """Append a channel axis to a validation volume; no augmentation is applied.

    Returns the processed volume together with the unchanged label.
    """
    with_channel = tf.expand_dims(volume, axis=3)
    return with_channel, label

In [None]:
# Define data loaders.
# Each dataset element is one (volume, label) pair sliced along the first axis.
train_loader = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((X_valid, Y_valid))

In [None]:
# Display the dataset description of the un-batched training loader.
train_loader

In [None]:
batch_size = 4
# Training pipeline: rotate + add channel on the fly, then batch and prefetch.
train_dataset = train_loader.map(train_preprocessing).batch(batch_size).prefetch(2)
# Validation pipeline: only a channel axis is added before batching.
validation_dataset = validation_loader.map(validation_preprocessing).batch(batch_size).prefetch(2)

In [None]:
# Display the dataset description of the batched training pipeline.
train_dataset

In [None]:
# Only the first batch is needed, so fetch exactly one instead of materialising
# (and augmenting) ten batches just to discard nine of them.
data = train_dataset.take(1)
images, labels = next(iter(data))
images = images.numpy()
image = images[0]
print("Dimension of the CT scan is:", image.shape)
# Show position 10 along the third axis of the first volume in the batch.
plt.imshow(np.squeeze(image[:, :, 10]), cmap="gray")

In [None]:
# Fetch a single (freshly augmented) batch; one batch is all that is used here.
data = train_dataset.take(1)
images, labels = next(iter(data))
images = images.numpy()
image = images[0]
# Plot one slice with a colour bar and save the figure to disk.
fig, ax = plt.subplots()
pos = ax.imshow(np.squeeze(image[:, :, 10]), cmap="gray")
cbar = fig.colorbar(pos, ax=ax)
cbar.minorticks_on()
plt.show()
fig.savefig('Rotatet volume', dpi = 100,bbox_inches='tight')

In [None]:
def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a ``num_rows`` x ``num_columns`` montage of slices and save it.

    ``data`` is rotated by 90 degrees, transposed and reshaped to
    ``(num_rows, num_columns, width, height)``; every ``(width, height)`` entry
    is drawn in its own axis. The figure is shown and then written to the file
    'Rotatet slices'.
    """
    montage = np.transpose(np.rot90(np.array(data)))
    montage = np.reshape(montage, (num_rows, num_columns, width, height))
    rows_data, columns_data = montage.shape[:2]
    # Extent of one tile per grid row / grid column, used for the aspect ratio.
    heights = [row[0].shape[0] for row in montage]
    widths = [tile.shape[1] for tile in montage[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    fig, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
    )
    for i, j in np.ndindex(rows_data, columns_data):
        axarr[i, j].imshow(montage[i][j], cmap="gray")
        axarr[i, j].axis("off")
    # Remove all padding so the tiles touch each other.
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()
    fig.savefig('Rotatet slices', dpi = 100,bbox_inches='tight')

In [None]:
# Visualize montage of slices.
# 4 rows and 10 columns for the first 40 slices of the volume. The in-plane size
# is read from the image itself so the reshape inside plot_slices always matches
# the resized data.
plot_slices(4, 10, image.shape[0], image.shape[1], image[:, :, :40])

In [None]:
# Shape of the single volume inspected above (includes the channel axis).
image.shape

### Train CNN

In [None]:
# Model input size (width, height, depth), taken from the resized training
# volumes so that the network input always matches the data that is fed to it.
w, h, d = X_train.shape[1:4]

In [None]:
def get_model(width=w, height=h, depth=d):
    """Build a 3D convolutional neural network for binary classification.

    Four Conv3D -> MaxPool3D -> BatchNormalization stages (64, 64, 128 and 256
    filters) are followed by global average pooling, a 512-unit dense layer
    with dropout and a single sigmoid output unit.
    """
    inputs = keras.Input((width, height, depth, 1))

    x = inputs
    for n_filters in (64, 64, 128, 256):
        x = layers.Conv3D(filters=n_filters, kernel_size=3, activation="relu")(x)
        x = layers.MaxPool3D(pool_size=2, padding = 'same')(x)
        x = layers.BatchNormalization()(x)

    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)  # or 0.6

    outputs = layers.Dense(units=1, activation="sigmoid")(x)

    return keras.Model(inputs, outputs, name="3dcnn")

In [None]:
# Build model.
model = get_model(width=w, height=h, depth=d)
model.summary()

In [None]:
# Learning rate starts at 1e-4 and is multiplied by 0.96 every 1000 steps
# (staircase=True: the decay is applied in discrete jumps).
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=1000,
    decay_rate=0.96,
    staircase=True,
)

In [None]:
# Plain SGD driven by the decaying learning-rate schedule.
# Alternatives that were tried and are currently disabled:
#   keras.optimizers.Adam(learning_rate=lr_schedule)
#   keras.optimizers.RMSprop(lr=1e-4)
sgd = keras.optimizers.SGD(learning_rate=lr_schedule)
model.compile(optimizer=sgd, loss="binary_crossentropy", metrics=["acc"])

In [None]:
# Callbacks: stop when val_acc has not improved for 15 epochs, and keep only the
# best weights on disk.
# (disabled alternative: keras.callbacks.ReduceLROnPlateau(monitor = 'val_loss', factor = 0.1, patience = 10))
early_stopping = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)
checkpoint = keras.callbacks.ModelCheckpoint("3d_image_classification.h5", save_best_only=True)
callback_list = [early_stopping, checkpoint]

In [None]:
# Train on the augmented training pipeline and validate after every epoch.
epochs = 20
model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=validation_dataset,
    callbacks=callback_list,
    verbose=1,
)

In [None]:
# Training vs. validation curves for accuracy (left) and loss (right).
fig, ax = plt.subplots(1, 2, figsize=(20, 4))
ax = ax.ravel()

history = model.history.history
for axis, metric in zip(ax, ["acc", "loss"]):
    axis.plot(history[metric])
    axis.plot(history["val_" + metric])
    axis.set_title("Model {}".format(metric))
    axis.set_xlabel("epochs")
    axis.set_ylabel(metric)
    axis.legend(["train", "val"])

# Accuracy is shown on [0, 1]; the loss axis starts at 0.
ax[0].set_ylim(0,1)
ax[1].set_ylim(0)

p = 'Modell_SGD_BS4'
fig.savefig(p)

In [None]:
# Append a channel axis to the validation volumes so they match the model input.
X_val = X_valid[:, :, :, :, np.newaxis]
X_val.shape

In [None]:
# Restore the best checkpoint and predict on the validation volumes.
model.load_weights("3d_image_classification.h5")
y_pred = model.predict(X_val)
# Threshold the sigmoid outputs at 0.5. The builtin `int` is used as dtype
# because the `np.int` alias no longer exists in current NumPy releases.
y_pred = (y_pred > 0.5).astype(int)

In [None]:
# Confusion matrix of the validation labels vs. the thresholded predictions (TensorFlow).
tf.math.confusion_matrix(labels = Y_valid, predictions = y_pred)

In [None]:
# Collapse the predictions to a 1-D array (the model has a single output unit).
y_pred = y_pred.flatten()

In [None]:
# Same confusion matrix computed with scikit-learn (true labels first, predictions second).
confusion_matrix(Y_valid, y_pred)

### Visualizing learning rate

In [None]:
def decayed_learning_rate(step, initial_learning_rate, decay_rate, decay_steps, staircase=False):
    """Learning rate of an exponential-decay schedule at a given step.

    Computes ``initial_learning_rate * decay_rate ** (step / decay_steps)``.

    Args:
        step: Optimizer step for which the rate is evaluated.
        initial_learning_rate: Rate at step 0.
        decay_rate: Factor applied once per ``decay_steps`` steps.
        decay_steps: Number of steps over which one decay factor is applied.
        staircase: If True, the exponent is floored (``step // decay_steps``)
            so the rate drops in discrete jumps, as a schedule configured with
            ``staircase=True`` does. Defaults to False (continuous decay).

    Returns:
        The decayed learning rate as a float.
    """
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_learning_rate * decay_rate ** exponent

In [None]:
epochs = 20
# Optimizer steps per epoch, derived from the training data and batch size
# rather than a hand-maintained number.
steps = int(np.ceil(len(X_train) / batch_size))
# Learning rate at the start of every epoch. The result is stored under its own
# name: binding it to `list` would shadow the builtin and break cells that call
# list(...) when they are re-run.
learning_rates = [
    decayed_learning_rate(steps * epoch, 0.0001, 0.96, 1000)
    for epoch in range(epochs)
]
# NOTE(review): lr_schedule is configured with staircase=True, whereas this
# curve shows the continuous decay.
plt.plot(learning_rates)