In [1]:
import numpy as np
import nibabel as nib
import os
import glob
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from patchify import patchify

In [2]:
from pathlib import Path

# Get the path to the current script
current_path = Path.cwd()
root_path = current_path.parent
data_path = root_path / 'data' / 'prostate_cancer_data'

Legions_path = data_path / 'Legions'
MRIs_path = data_path / 'MRIs'
Prostates_path = data_path / 'Prostates'

# put legion mask files in a list
legion_files = sorted(list(Legions_path.glob('*.nii.gz')))

# put image files (MRI) in a list
image_files = sorted(list(MRIs_path.glob('*.nii.gz')))

print(f"Number of image files: {len(legion_files)}")
print(f"Number of lesion mask files: {len(image_files)}")

Number of image files: 6
Number of lesion mask files: 6


In [3]:
patch_size = (64, 64, 64)
step_size = 64

image_patches = []
mask_patches = []

for img_path, mask_path in zip(image_files, legion_files):
    # Load MRI volume and lesion mask
    mri_img = nib.load(img_path).get_fdata()
    lesion_mask = nib.load(mask_path).get_fdata()

    # Normalize the MRI volume to [0, 1]
    mri_img = (mri_img - np.min(mri_img)) / (np.max(mri_img) - np.min(mri_img))

    # Ensure volumes are of the same shape
    assert mri_img.shape == lesion_mask.shape, f"Shape mismatch between {img_path} and {mask_path}"

    # Pad volumes to ensure uniform shape of (64, 64, 64)
    desired_shape = (64, 64, 64)
    padding = [(0, max(0, desired - current)) for current, desired in zip(mri_img.shape, desired_shape)]
    mri_img = np.pad(mri_img, padding, mode='constant', constant_values=0)
    lesion_mask = np.pad(lesion_mask, padding, mode='constant', constant_values=0)

    # Extract patches
    img_patches = patchify(mri_img, patch_size, step=step_size)
    mask_patches_ = patchify(lesion_mask, patch_size, step=step_size)

    # Reshape patches to a list
    num_patches = img_patches.shape[0] * img_patches.shape[1] * img_patches.shape[2]
    img_patches = img_patches.reshape(num_patches, *patch_size)
    mask_patches_ = mask_patches_.reshape(num_patches, *patch_size)

    # Filter out patches without lesions
    for img_patch, mask_patch in zip(img_patches, mask_patches_):
        if np.max(mask_patch) > 0:  # Keep only patches with lesions
            image_patches.append(img_patch)
            mask_patches.append(mask_patch)

print(f"Total number of patches: {len(image_patches)}")

Total number of patches: 10


In [4]:
# Convert to NumPy arrays
image_patches = np.array(image_patches)
mask_patches = np.array(mask_patches)

print(f"Image patches shape: {image_patches.shape}")
print(f"Mask patches shape: {mask_patches.shape}")


Image patches shape: (10, 64, 64, 64)
Mask patches shape: (10, 64, 64, 64)


In [5]:
# Expand dimensions for the channel
# For grayscale images, expand dimensions to have shape (batch_size, x, y, z, channels)
image_patches = np.expand_dims(image_patches, axis=-1)  # Shape: (n_patches, x, y, z, 1)
mask_patches = np.expand_dims(mask_patches, axis=-1)    # Shape: (n_patches, x, y, z, 1)

# For 3 channels, you can stack the grayscale image
image_patches = np.concatenate((image_patches, image_patches, image_patches), axis=-1)  # Now channels=3

print(f"Image patches shape after channel expansion: {image_patches.shape}")

Image patches shape after channel expansion: (10, 64, 64, 64, 3)


In [6]:
# Ensure masks are binary (0 and 1)
# Convert all non-zero values to 1
mask_patches[mask_patches > 0] = 1

In [7]:
# train, test split
X_train, X_test, y_train, y_test = train_test_split(image_patches, mask_patches, test_size=0.1)

print(f"Training data shape: {X_train.shape}")
print(f"Validation data shape: {X_test.shape}")


Training data shape: (9, 64, 64, 64, 3)
Validation data shape: (1, 64, 64, 64, 3)


In [8]:
# Just for the testing purpose, let's use a pretrained 3D U-Net model from the library
import segmentation_models_3D as sm

Segmentation Models: using `tf.keras` framework.


In [9]:
# Parameters
n_classes = 1  # For binary segmentation
channels = 3

# Define loss and metrics
LR = 0.0001
optim = keras.optimizers.Adam(LR)

# Dice loss function
dice_loss = sm.losses.DiceLoss(class_weights=np.array([0.5, 0.5]))

# Binary Focal Loss
focal_loss = sm.losses.BinaryFocalLoss()

# Total loss
total_loss = dice_loss + (1 * focal_loss)

# Metrics
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

In [10]:
# Define the model
model = sm.Unet(
    'resnet34',
    encoder_weights=None,
    input_shape=(64, 64, 64, channels),
    classes=n_classes,
    activation='sigmoid'
)

# Compile the model
model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

# Print the model summary
model.summary()

In [11]:
# path to save the model as a check-point
model_save_path = current_path / 'models'

# Define callbacks
callbacks = [
    keras.callbacks.ModelCheckpoint(model_save_path / 'unet3d_model.keras', save_best_only=True),
    keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
]

# Train the model
history = model.fit(
    X_train, y_train,
    batch_size=2,
    epochs=100,
    verbose=1,
    validation_data=(X_test, y_test),
    callbacks=callbacks
)

Epoch 1/100
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 4s/step - f1-score: 0.0058 - iou_score: 0.0029 - loss: 1.1316 - val_f1-score: 0.0018 - val_iou_score: 8.8653e-04 - val_loss: 1.0082
Epoch 2/100
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 4s/step - f1-score: 0.0074 - iou_score: 0.0037 - loss: 1.0890 - val_f1-score: 0.0182 - val_iou_score: 0.0092 - val_loss: 1.0007
Epoch 3/100
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 4s/step - f1-score: 0.0069 - iou_score: 0.0035 - loss: 1.0652 - val_f1-score: 0.0237 - val_iou_score: 0.0120 - val_loss: 0.9938
Epoch 4/100
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 4s/step - f1-score: 0.0080 - iou_score: 0.0040 - loss: 1.0495 - val_f1-score: 0.0708 - val_iou_score: 0.0367 - val_loss: 0.9882
Epoch 5/100
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 4s/step - f1-score: 0.0179 - iou_score: 0.0092 - loss: 1.0409 - val_f1-score: 0.0506 - val_iou_score: 0.0259 