In [2]:
import os
import numpy as np
import nibabel as nib
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow import keras
from tensorflow.keras import layers, models # Import the layers and models modules from keras
from scipy.ndimage import zoom
import matplotlib.pyplot as plt

In [3]:
# Step 1: Organize Data
# Directory containing the files
base_dir = "/content"

# List all files in the directory
all_files = sorted(os.listdir(base_dir))

# Create dictionaries for images and labels
images = {}
labels = {}

for file_name in all_files:
    if file_name.endswith(".nii.gz"):  # Only consider .nii.gz files
        if "_seg" in file_name:  # Label files
            key = file_name.replace("IBSR_", "").replace("_seg.nii.gz", "")
            labels[key] = os.path.join(base_dir, file_name)
        else:  # Image files
            key = file_name.replace("IBSR_", "").replace(".nii.gz", "")
            images[key] = os.path.join(base_dir, file_name)

# Debugging: Display the dictionaries
print("Images Dictionary:", images)
print("Labels Dictionary:", labels)

# Training and Validation Data
training_keys = ['01', '03', '04', '05', '06', '07', '08', '09', '18']
validation_keys = ['11', '12', '13', '14', '17']

training_images = {k: images[k] for k in training_keys}
training_labels = {k: labels[k] for k in training_keys}
validation_images = {k: images[k] for k in validation_keys}
validation_labels = {k: labels[k] for k in validation_keys}


Images Dictionary: {'01': '/content/IBSR_01.nii.gz', '03': '/content/IBSR_03.nii.gz', '04': '/content/IBSR_04.nii.gz', '05': '/content/IBSR_05.nii.gz', '06': '/content/IBSR_06.nii.gz', '07': '/content/IBSR_07.nii.gz', '08': '/content/IBSR_08.nii.gz', '09': '/content/IBSR_09.nii.gz', '11': '/content/IBSR_11.nii.gz', '12': '/content/IBSR_12.nii.gz', '13': '/content/IBSR_13.nii.gz', '14': '/content/IBSR_14.nii.gz', '16': '/content/IBSR_16.nii.gz', '17': '/content/IBSR_17.nii.gz', '18': '/content/IBSR_18.nii.gz'}
Labels Dictionary: {'01': '/content/IBSR_01_seg.nii.gz', '03': '/content/IBSR_03_seg.nii.gz', '04': '/content/IBSR_04_seg.nii.gz', '05': '/content/IBSR_05_seg.nii.gz', '06': '/content/IBSR_06_seg.nii.gz', '07': '/content/IBSR_07_seg.nii.gz', '08': '/content/IBSR_08_seg.nii.gz', '09': '/content/IBSR_09_seg.nii.gz', '11': '/content/IBSR_11_seg.nii.gz', '12': '/content/IBSR_12_seg.nii.gz', '13': '/content/IBSR_13_seg.nii.gz', '14': '/content/IBSR_14_seg.nii.gz', '16': '/content/IBSR_

In [4]:
# Step 2: Preprocess Data
def load_nifti(file_path):
    """Load a NIfTI file."""
    img = nib.load(file_path)
    return img.get_fdata()

def normalize_image(image):
    """
    Normalize the image intensity values to the range [0, 1].
    Handles images with uniform intensity values gracefully.
    """
    min_val = np.min(image)
    max_val = np.max(image)
    if max_val - min_val == 0:  # Avoid division by zero
        return np.zeros_like(image)
    return (image - min_val) / (max_val - min_val)


def resize_image(image, target_shape):
    """
    Resize an image to the target shape using interpolation.
    Handles images with an additional channel dimension.
    """
    if len(image.shape) == len(target_shape) + 1:  # If image has extra channel dimension
        # Resize only the spatial dimensions
        spatial_shape = image.shape[:-1]
        zoom_factors = [t / s for t, s in zip(target_shape, spatial_shape)]
        resized_image = zoom(image[..., 0], zoom_factors, order=1)  # Remove channel, resize, and re-add it
        return resized_image[..., np.newaxis]  # Add channel dimension back
    elif len(image.shape) == len(target_shape):  # If image matches target rank
        zoom_factors = [t / s for t, s in zip(target_shape, image.shape)]
        resized_image = zoom(image, zoom_factors, order=1)
        return resized_image
    else:
        raise ValueError(f"Target shape {target_shape} and input shape {image.shape} must match in rank.")

def preprocess_data(images_dict, labels_dict, target_shape, n_classes):
    """
    Preprocess the images and labels: normalize, resize, and encode.
    """
    images_list = []
    labels_list = []

    for key in sorted(images_dict.keys()):
        # Load and preprocess images
        image = load_nifti(images_dict[key])
        print(f"Original image shape {key}: {image.shape}")
        image = normalize_image(image)
        image_resized = resize_image(image, target_shape)
        print(f"Resized image shape {key}: {image_resized.shape}")
        images_list.append(image_resized)

        # Load and preprocess labels (if available)
        if key in labels_dict:
            label = load_nifti(labels_dict[key])
            print(f"Original label shape {key}: {label.shape}")
            label_resized = resize_image(label, target_shape)
            print(f"Resized label shape {key}: {label_resized.shape}")
            labels_list.append(label_resized)

    # Convert lists to NumPy arrays
    images_array = np.array(images_list)
    images_array = images_array[..., np.newaxis]  # Ensure channel dimension (5D)
    print(f"Images array shape: {images_array.shape}")

    if labels_list:  # Check if labels_list is not empty
        labels_array = np.array(labels_list, dtype=np.int32)
        labels_array = to_categorical(labels_array, num_classes=n_classes)  # One-hot encode labels
        print(f"Labels array shape: {labels_array.shape}")
    else:
        labels_array = None
        print("Warning: No labels found during preprocessing.")

    return images_array, labels_array

# Target shape for images
target_shape = (256, 128, 256)
n_classes = 4

# Preprocess training and validation data
training_images_array, training_labels_array = preprocess_data(
    training_images, training_labels, target_shape, n_classes
)
validation_images_array, validation_labels_array = preprocess_data(
    validation_images, validation_labels, target_shape, n_classes
)

print(f"Training Images Shape: {training_images_array.shape}")
print(f"Training Labels Shape: {training_labels_array.shape}")
print(f"Validation Images Shape: {validation_images_array.shape}")
print(f"Validation Labels Shape: {validation_labels_array.shape}")




Original image shape 01: (256, 128, 256, 1)
Resized image shape 01: (256, 128, 256, 1)
Original label shape 01: (256, 128, 256, 1)
Resized label shape 01: (256, 128, 256, 1)
Original image shape 03: (256, 128, 256, 1)
Resized image shape 03: (256, 128, 256, 1)
Original label shape 03: (256, 128, 256, 1)
Resized label shape 03: (256, 128, 256, 1)
Original image shape 04: (256, 128, 256, 1)
Resized image shape 04: (256, 128, 256, 1)
Original label shape 04: (256, 128, 256, 1)
Resized label shape 04: (256, 128, 256, 1)
Original image shape 05: (256, 128, 256, 1)
Resized image shape 05: (256, 128, 256, 1)
Original label shape 05: (256, 128, 256, 1)
Resized label shape 05: (256, 128, 256, 1)
Original image shape 06: (256, 128, 256, 1)
Resized image shape 06: (256, 128, 256, 1)
Original label shape 06: (256, 128, 256, 1)
Resized label shape 06: (256, 128, 256, 1)
Original image shape 07: (256, 128, 256, 1)
Resized image shape 07: (256, 128, 256, 1)
Original label shape 07: (256, 128, 256, 1)

In [5]:
def extract_patches(images, labels, patch_size, stride, threshold, num_classes):
    """
    Extract useful patches from 4D or 5D images and labels.
    """
    # Reshape images and labels if they have extra dimensions
    if images.ndim == 6:
        batch_size, depth, height, width, channels, _ = images.shape
        images = images.reshape(batch_size * depth, height, width, channels)
    elif images.ndim == 5:
        batch_size, height, width, depth, channels = images.shape
        images = images.reshape(batch_size * depth, height, width, channels)

    if labels.ndim == 6:
        batch_size, depth, height, width, channels, _ = labels.shape
        labels = labels.reshape(batch_size * depth, height, width, channels)
    elif labels.ndim == 5:
        batch_size, height, width, depth, channels = labels.shape
        labels = labels.reshape(batch_size * depth, height, width, channels)

    # Transpose labels if needed
    if images.shape[1:3] != labels.shape[1:3]:  # Compare height and width
        labels = np.transpose(labels, axes=(0, 2, 1, 3))  # Swap height and width

    # Check input dimensions
    assert images.shape[:3] == labels.shape[:3], (
        f"Images and labels must have the same spatial dimensions. "
        f"Images shape: {images.shape}, Labels shape: {labels.shape}"
    )

    # Extract patches
    vol_patches = tf.image.extract_patches(
        images=images,
        sizes=[1, *patch_size, 1],
        strides=[1, *stride, 1],
        rates=[1, 1, 1, 1],
        padding="SAME"
    ).numpy()

    seg_patches = tf.image.extract_patches(
        images=labels,
        sizes=[1, *patch_size, 1],
        strides=[1, *stride, 1],
        rates=[1, 1, 1, 1],
        padding="SAME"
    ).numpy()

    # Reshape patches
    vol_patches = vol_patches.reshape([-1, *patch_size, 1])
    seg_patches = seg_patches.reshape([-1, *patch_size])

    # Adjust the foreground mask to match vol_patches
    foreground_mask = seg_patches.sum(axis=(1, 2)) > 0  # Only keep patches with non-zero labels
    useful_patches = np.arange(len(vol_patches))[foreground_mask[:len(vol_patches)]]  # Adjust length

    # Debug shapes
    print(f"vol_patches shape: {vol_patches.shape}")
    print(f"seg_patches shape: {seg_patches.shape}")
    print(f"useful_patches count: {len(useful_patches)}")

    # Apply the filtering mask
    vol_patches = vol_patches[useful_patches]
    seg_patches = seg_patches[useful_patches]

    # Convert labels to one-hot encoding
    seg_patches = tf.keras.utils.to_categorical(seg_patches, num_classes=num_classes)

    return vol_patches, seg_patches


# Patch parameters
patch_size = (16, 16)
stride = (16, 16)
threshold = 0.5

# Extract patches
x_train, y_train = extract_patches(
    training_images_array,
    training_labels_array,
    patch_size,
    stride,
    threshold,
    n_classes
)
x_val, y_val = extract_patches(
    validation_images_array,
    validation_labels_array,
    patch_size,
    stride,
    threshold,
    n_classes
)

print(f"Training Patches Shape: {x_train.shape}")
print(f"Validation Patches Shape: {x_val.shape}")


vol_patches shape: (73728, 32, 32, 1)
seg_patches shape: (294912, 32, 32)
useful_patches count: 73728
vol_patches shape: (40960, 32, 32, 1)
seg_patches shape: (163840, 32, 32)
useful_patches count: 40960
Training Patches Shape: (73728, 32, 32, 1)
Validation Patches Shape: (40960, 32, 32, 1)


In [None]:
# Step 4: Build the SegNet Model
def get_segnet(img_size, n_classes, n_input_channels):
    inputs = tf.keras.Input(shape=img_size + (n_input_channels,))

    # Encoding Path
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    pool1 = layers.MaxPooling2D((2, 2))(conv1)

    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    pool2 = layers.MaxPooling2D((2, 2))(conv2)

    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    pool3 = layers.MaxPooling2D((2, 2))(conv3)

    # Decoding Path
    up1 = layers.UpSampling2D((2, 2))(pool3)
    conv4 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(up1)

    up2 = layers.UpSampling2D((2, 2))(conv4)
    conv5 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(up2)

    up3 = layers.UpSampling2D((2, 2))(conv5)
    conv6 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(up3)

    outputs = layers.Conv2D(n_classes, (1, 1), activation='softmax')(conv6)

    model = models.Model(inputs, outputs)
    return model

segnet = get_segnet(img_size=(32, 32), n_classes=n_classes, n_input_channels=1)
segnet.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Step 5: Train the Model
batch_size = 32
epochs = 20

history = segnet.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    batch_size=batch_size,
    epochs=epochs,
    verbose=1
)

# Step 6: Visualize Training Results
plt.figure()
plt.plot(history.history['loss'], label='Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()