In [56]:
import os
import numpy as np
import nibabel as nib
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from keras.layers import Input
from keras.applications import ResNet50
from keras.optimizers import Adam

In [37]:
# Helper function to load and preprocess NIfTI images
def load_and_preprocess_nifti1(nifti_path, target_size=(224, 224)):
    """
    Loads a NIfTI image, resizes it, and returns it as a NumPy array.

    Parameters:
    - nifti_path: Path to the NIfTI image file.
    - target_size: Tuple specifying the target size (height, width) for resizing.

    Returns:
    - Resized image as a NumPy array.
    """
    nifti_image = nib.load(nifti_path)
    nifti_data = nifti_image.get_fdata()  # Get the image data as a NumPy array
    # Assuming we are working with 2D slices, pick the middle slice along the z-axis
    middle_slice = nifti_data[:, :, nifti_data.shape[2] // 2]
    resized_img = tf.image.resize(middle_slice, target_size)
    return np.array(resized_img)

In [38]:
# Helper function to load and preprocess NIfTI images
def load_and_preprocess_nifti(nifti_path, target_size=(224, 224)):
    """
    Loads a NIfTI image, resizes it, and returns it as a NumPy array.

    Parameters:
    - nifti_path: Path to the NIfTI image file.
    - target_size: Tuple specifying the target size (height, width) for resizing.

    Returns:
    - Resized image as a NumPy array.
    """
    nifti_image = nib.load(nifti_path)
    nifti_data = nifti_image.get_fdata()  # Get the image data as a NumPy array
    # Assuming we are working with 2D slices, pick the middle slice along the z-axis
    middle_slice = nifti_data[:, :, nifti_data.shape[2] // 2]

    # Add an extra dimension to make it a 3D tensor (H, W, 1) before resizing
    middle_slice = np.expand_dims(middle_slice, axis=-1)

    # Resize the image
    resized_img = tf.image.resize(middle_slice, target_size)

    # Squeeze to remove the added dimension
    return np.squeeze(resized_img)

In [39]:
def is_empty_image(nifti_path):
    """
    Check if the NIfTI image is empty.

    Parameters:
    - nifti_path: Path to the NIfTI image file.

    Returns:
    - True if the image is empty, False otherwise.
    """
    try:
        nifti_image = nib.load(nifti_path)
        return True
    except ImageFileError:
        return False

In [40]:
# Usage in loading function
def load_and_stack_nifti(patient_folder, patient_folder_name, target_size=(224, 224)):
    stacked_imgs = []
    
    # Define MRI file name suffixes
    mri_suffixes = {
        'T1': '-t1n.nii.gz',
        'T1 CE': '-t1c.nii.gz',
        'T2': '-t2w.nii.gz',
        'T2 FLAIR': '-t2f.nii.gz'
    }
    
    for mri_type, suffix in mri_suffixes.items():
        nifti_image_path = os.path.join(patient_folder, f'{patient_folder_name}{suffix}')
        if os.path.exists(nifti_image_path):
            try:
                img = load_and_preprocess_nifti(nifti_image_path, target_size)
                stacked_imgs.append(img)
            except Exception:
                continue
    
    if len(stacked_imgs) < 4:
        print(f"Warning: No valid MRI images found for patient {patient_folder_name}.")
        return None  # Return None or handle as needed
    
    return np.stack(stacked_imgs, axis=-1)

In [41]:
# Helper function to load and stack multiple NIfTI MRI images
def load_and_stack_nifti1(patient_folder, patient_folder_name, target_size=(224, 224)):
    """
    Loads and stacks NIfTI images for different MRI types as separate channels.

    Parameters:
    - patient_folder: Path to the folder containing patient's MRI scans.
    - patient_folder_name: Name of the patient folder for constructing file names.
    - target_size: Target size to resize the images (ignoring depth for 2D CNN).

    Returns:
    - Stacked image array where each MRI type is a separate channel.
    """
    stacked_imgs = []
    
    # Define MRI file name suffixes
    mri_suffixes = {
        'T1': '-t1n.nii.gz',
        'T1 CE': '-t1c.nii.gz',
        'T2': '-t2w.nii.gz',
        'T2 FLAIR': '-t2f.nii.gz'
    }
    
    for mri_type, suffix in mri_suffixes.items():
        nifti_image_path = os.path.join(patient_folder, f'{patient_folder_name}{suffix}')
        if os.path.exists(nifti_image_path):
            img = load_and_preprocess_nifti(nifti_image_path, target_size)
            stacked_imgs.append(img)  # Append each image to the list of channels
    
    return np.stack(stacked_imgs, axis=-1)  # Stack along the channel axis

In [47]:
# Function to load dataset with MRI stacking
def load_dataset(dataset_path):
    """
    Loads the dataset and stacks multiple MRI scans as channels.

    Parameters:
    - dataset_path: Path to the dataset folder.

    Returns:
    - X: Array of stacked images.
    - y: Array of labels (cancer levels).
    """
    d = {"51_OtherNeoplasms":1, "95_Glioma":2}
    X = []
    y = []
    cancer_levels = os.listdir(dataset_path)  # Cancer level directories (labels)
    skipped_images = 0
    
    for cancer_level in cancer_levels:
        patient_folders = os.listdir(os.path.join(dataset_path, cancer_level))
        
        for patient in patient_folders:
            patient_folder = os.path.join(dataset_path, cancer_level, patient)
            
            # Load and stack the four MRI images: T1, T1 CE, T2, T2 FLAIR
            stacked_img = load_and_stack_nifti(patient_folder, patient, target_size=(224, 224))
            
            if stacked_img is not None:
                X.append(stacked_img)  # Add stacked image to the dataset
                y.append(d[cancer_level])  # Store cancer level as label (convert to integer)
            else:
                skipped_images += 1  # Count skipped images

    print("Skipped Images =", skipped_images)
    # Convert to numpy arrays only if X is not empty and all images are valid
    if X:
        return np.array(X), np.array(y)
    else:
        print("Warning: No valid MRI images found. Returning empty arrays.")
        return np.array([]), np.array([])

In [54]:
def build_model(input_shape=(224, 224, 4), num_classes=4):
    """
    Builds the CNN model using a modified ResNet50 as the base model.
    
    Parameters:
    - input_shape: Tuple representing the shape of the input image.
    - num_classes: Number of output classes.
    
    Returns:
    - Compiled Keras model.
    """
    inputs = Input(shape=input_shape)

    inp = layers.Input((dshape[0], dshape[1], dshape[2], 1))
x = Conv3D(32, kernel_size=3,padding="same", use_bias=False)(inp)
x = Activation("relu")(x)
x = Conv3D(32, kernel_size=3,padding="same", use_bias=False)(x)
x = Activation("relu")(x)
x = MaxPool3D(pool_size=3, padding='same')(x)
x = Dropout(0)(x)

x = Conv3D(64, kernel_size=3,padding="same")(x)
x = Activation("relu")(x)
x = Conv3D(64, kernel_size=3,padding="same")(x)
x = Activation("relu")(x)
x = MaxPool3D(pool_size=2, padding='same')(x)
x = Dropout(0.2)(x)

x = Conv3D(64, kernel_size=3,padding="same")(x)
x = Activation("relu")(x)
x = Conv3D(64, kernel_size=3,padding="same")(x)
x = Activation("relu")(x)
x = MaxPool3D(pool_size=2, padding='same')(x)
x = Dropout(0.25)(x)

x = Flatten()(x)
x = Dense(512, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.2)(x)
x = layers.Dense(3,activation='softmax')(x) 
model = models.Model(inputs=inp, outputs=x)
model.summary()

    model = Model(inputs=inputs, outputs=x)
    #model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    
    return model

In [44]:
# Function to build the CNN model using ResNet50
def build_model1(input_shape=(224, 224, 4), num_classes=4):
    """
    Builds the CNN model using ResNet50 as the base model.

    Parameters:
    - input_shape: Shape of the input image (height, width, channels).
    - num_classes: Number of output classes for classification.

    Returns:
    - Compiled Keras model.
    """
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Flatten()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs=base_model.input, outputs=predictions)
    
    # Freeze the base model layers
    for layer in base_model.layers:
        layer.trainable = False
    
    # Compile the model
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

In [59]:
# Main function to train and evaluate the model
def main():
    dataset_path = 'PKG - BraTS-Africa/BraTS-Africa'  # Replace with the path to your dataset
    
    # Load dataset
    X, y = load_dataset(dataset_path)
    
    # Split dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # Build model
    model = build_model(input_shape=(224, 224, 4), num_classes=len(np.unique(y)))
    
    # Train the model
    #model.fit(X_train, y_train, epochs=10, batch_size=16, validation_split=0.1)

    #history = model.fit(X_train, y_train, epochs=10, batch_size=16, validation_split=0.1)
    print(y_train)
    class_weights = class_weight.compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weight_dict = {i: weight for i, weight in enumerate(class_weights)}
    history = model.fit(X_train, y_train, epochs=10, batch_size=16, validation_split=0.1, class_weight=class_weights)

    # Plot training & validation accuracy values
    import matplotlib.pyplot as plt
    
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.show()

    predictions = model.predict(X_test)
    predicted_classes = np.argmax(predictions, axis=1)
    
    # Compare predicted classes with actual classes
    print("Predicted classes:", predicted_classes)
    print("Actual classes:", y_test)
        
    # Evaluate the model on the test set
    test_loss, test_acc = model.evaluate(X_test, y_test)
    print(f'Test accuracy: {test_acc:.4f}')

In [60]:
main()

Skipped Images = 1


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()