<a href="https://colab.research.google.com/github/zi-bou/zi-bou/blob/main/Al4l.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **AI4Imaging-Hackathon-2024**
Classification of Heart disease based on cine MRI scan



## **Classification label :**
*   NOR - Normal subjects
*   MINF - Myocardial infarction
*   DCM - Dilated cardiomyopathy
*   HCM - Hypertrophic cardiomyopathy
*   RV - Abnormal right ventricle


## **Data structure**


/data/
- **drive-download-20250107T191042Z-001/**
  - **train/**
    - `p0096/`
    - `p0097/`
    - `p0088/`
    - `p0100/`
    - `p0098/`
    - `p0083/`
    - `p0090/`
    - `p0092/`
    - `p0094/`
    - `p0099/`
  - **test/**
    - Similar subdirectories as in `train/`
  - `test_sample_submission.csv`



NIfTI : Neuroimaging Informatics Technology Initiative
gt : ground truth

ROI (Region of interests)
The segmentation mask, e.g., p0001_frame01_gt.nii.gz, utilizes non-zero integer (i.e., 1-3) to highlight different anatomical structures of the heart. The numbers 1, 2, and 3 represent the right ventricle, left ventricle, and myocardium, respectively.

In [15]:
from google.colab import drive
import os
import nibabel as nib
import tensorflow as tf
import zipfile

In [16]:
# Upload files from your local system
drive.mount('/content/drive')

base_dir='/content/drive/My Drive/'
data_dir='/content/drive/My Drive/KAGGLE/Al4l/data'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## **Process Nested ZIP and GZ Files in All Directories**

In [17]:
#Code to Process Nested ZIP and GZ Files in All Directories

import os
import gzip
import zipfile
import shutil

# Initial setup
data_dir = data_dir

def process_compressed_files(directory):
    """
    Recursively checks for ZIP and GZ files in the given directory and its subdirectories.
    Extracts and removes compressed files until none are left.
    """
    while True:
        # List all files and subdirectories
        compressed_files = []
        subdirectories = []

        for root, dirs, files in os.walk(directory):
            for file in files:
                if file.endswith('.gz') or file.endswith('.zip'):
                    compressed_files.append(os.path.join(root, file))
            for dir in dirs:
                subdirectories.append(os.path.join(root, dir))

        # Ensure we process all compressed files before breaking
        if not compressed_files:
            print("No more compressed files to process.")
            break

        # Process each compressed file
        for compressed_file in compressed_files:
            if compressed_file.endswith('.gz'):
                extraction_path = os.path.splitext(compressed_file)[0]  # Remove .gz extension for output file

                # Extract the GZ file
                with gzip.open(compressed_file, 'rb') as f_in:
                    with open(extraction_path, 'wb') as f_out:
                        shutil.copyfileobj(f_in, f_out)
                        print(f"Extracted {compressed_file} to {extraction_path}")

                # Delete the processed GZ file to avoid reprocessing
                os.remove(compressed_file)
                print(f"Deleted {compressed_file}")

            elif compressed_file.endswith('.zip'):
                extraction_dir = os.path.splitext(compressed_file)[0]  # Remove .zip extension for output folder

                # Create the extraction directory if it doesn't exist
                os.makedirs(extraction_dir, exist_ok=True)

                # Extract the ZIP file
                with zipfile.ZipFile(compressed_file, 'r') as zip_ref:
                    zip_ref.extractall(extraction_dir)
                    print(f"Extracted {compressed_file} to {extraction_dir}")

                # Delete the processed ZIP file to avoid reprocessing
                os.remove(compressed_file)
                print(f"Deleted {compressed_file}")

        # Recursively process subdirectories
        for sub_dir in subdirectories:
            print(f"Checking subdirectory: {sub_dir}")
            process_compressed_files(sub_dir)  # Recursive call to ensure all compressed files are processed

# Start the recursive processing
process_compressed_files(data_dir)


No more compressed files to process.


In [18]:
data_dir=os.path.join(data_dir,"drive-download-20250107T191042Z-001")
train_dir =os.path.join(data_dir, 'train')
test_dir=os.path.join(data_dir, 'test')
patient_ids = [f"p{i:04d}" for i in range(1, 101)]  # Generate patient IDs from p0001 to p0100


## **Preprocessing_V2**

In [19]:
import os
import numpy as np
import nibabel as nib
from scipy.ndimage import zoom

##############################
# 1) locate_patient_data
##############################

def locate_patient_data(patient_id):
    """
    Finds the systolic frame, diastolic frame, systolic mask, and diastolic mask for a patient.
    Returns (file_names, paths, frames) or (None, None, None) on error.

    'frames' will be a list of nib.Nifti1Image objects:
      [systolic_img, diastolic_img, systolic_mask_img, diastolic_mask_img]
    """
    patient_dir = os.path.join(train_dir, patient_id)
    if not os.path.exists(patient_dir):
        print(f"Patient directory not found: {patient_dir}")
        return None, None, None

    files = os.listdir(patient_dir)
    systolic_frame_file = None
    diastolic_frame_file = None
    systolic_mask_file = None
    diastolic_mask_file = None

    for file in files:
        if "frame01" in file and "gt" not in file and file.endswith(".nii"):
            systolic_frame_file = file
        elif "frame01" in file and "gt" in file and file.endswith(".nii"):
            systolic_mask_file = file
        elif "frame" in file and "gt" not in file and "frame01" not in file and file.endswith(".nii"):
            diastolic_frame_file = file
        elif "frame" in file and "gt" in file and "frame01_gt" not in file and file.endswith(".nii"):
            diastolic_mask_file = file

    file_names = [systolic_frame_file, diastolic_frame_file, systolic_mask_file, diastolic_mask_file]
    if None in file_names:
        print(f"Missing files for patient {patient_id}. Check the directory.")
        return None, None, None

    # Construct paths
    systolic_path = os.path.join(patient_dir, systolic_frame_file)
    diastolic_path = os.path.join(patient_dir, diastolic_frame_file)
    systolic_mask_path = os.path.join(patient_dir, systolic_mask_file)
    diastolic_mask_path = os.path.join(patient_dir, diastolic_mask_file)
    label_path = os.path.join(patient_dir, "gt.txt")  # optional label

    paths = [systolic_path, diastolic_path, systolic_mask_path, diastolic_mask_path, label_path]

    # Load nibabel images (not .get_fdata() yet)
    try:
        systolic_img = nib.load(systolic_path)
        diastolic_img = nib.load(diastolic_path)
        systolic_mask_img = nib.load(systolic_mask_path)
        diastolic_mask_img = nib.load(diastolic_mask_path)
        frames = [systolic_img, diastolic_img, systolic_mask_img, diastolic_mask_img]
    except Exception as e:
        print(f"Error loading files for patient {patient_id}: {e}")
        return None, None, None

    return file_names, paths, frames

##############################
# 2) Gather all spacings
##############################

def check_all_patient_spacings(patient_ids):
    """
    Loops over all patient_ids, calls locate_patient_data,
    and collects voxel spacings for systolic/diastolic frames + masks.

    Returns: list of (sx, sy, sz) for all volumes.
    """
    spacings_list = []

    for pid in patient_ids:
        file_names, paths, frames = locate_patient_data(pid)
        if file_names and paths and frames:
            try:
                systolic_img, diastolic_img, systolic_mask_img, diastolic_mask_img = frames
                # get_zooms() gives spacing
                systolic_spacing = systolic_img.header.get_zooms()[:3]
                diastolic_spacing = diastolic_img.header.get_zooms()[:3]
                systolic_mask_spacing = systolic_mask_img.header.get_zooms()[:3]
                diastolic_mask_spacing = diastolic_mask_img.header.get_zooms()[:3]

                # Collect them
                spacings_list.append(systolic_spacing)
                spacings_list.append(diastolic_spacing)
                spacings_list.append(systolic_mask_spacing)
                spacings_list.append(diastolic_mask_spacing)

            except Exception as e:
                print(f"Error reading spacing for patient {pid}: {e}")

    return spacings_list

##############################
# 3) Decide on 90th percentile
##############################

def decide_target_spacing(spacings_list, percentile=90):
    """
    Compute e.g. the 90th percentile across each dimension (sx, sy, sz).
    """
    arr = np.array(spacings_list)  # shape (N, 3)
    target_sx = np.percentile(arr[:,0], percentile)
    target_sy = np.percentile(arr[:,1], percentile)
    target_sz = np.percentile(arr[:,2], percentile)
    return (target_sx, target_sy, target_sz)

##############################
# 4) Resample volume
##############################

def resample_volume(nifti_img, target_spacing, is_label=False):
    """
    Resample a nibabel image to target_spacing using scipy zoom.
    is_label=True => nearest-neighbor (order=0).
    is_label=False => linear interpolation (order=1).
    Returns a NumPy array of resampled data.
    """
    data = nifti_img.get_fdata()
    orig_spacing = nifti_img.header.get_zooms()[:3]
    scale_factors = (
        orig_spacing[0]/target_spacing[0],
        orig_spacing[1]/target_spacing[1],
        orig_spacing[2]/target_spacing[2],
    )
    order = 0 if is_label else 1
    resampled_data = zoom(data, scale_factors, order=order)
    return resampled_data

##############################
# 5) Crop or Pad Volume
##############################

def crop_or_pad_mri(volume, target_shape=(256,256,16)):
    """
    Crop or pad a volume to match (target_shape).
    volume.shape => (Dx, Dy, Dz) or (H, W, D), etc.
    Adjust to your convention.
    We'll assume volume.shape => (H, W, Z)
    and target_shape => (256,256,16).
    """

    # If your dimension ordering is different (D, H, W),
    # be consistent or rename variables as needed.
    adjusted_volume = volume
    for i, (dim, target_dim) in enumerate(zip(adjusted_volume.shape, target_shape)):
        diff = target_dim - dim
        if diff > 0:
            # pad
            pad_before = diff // 2
            pad_after = diff - pad_before
            pad_widths = [(0,0)]*3
            pad_widths[i] = (pad_before, pad_after)
            adjusted_volume = np.pad(adjusted_volume, pad_widths,
                                     mode='constant', constant_values=0)
        elif diff < 0:
            # crop
            crop_start = abs(diff)//2
            crop_end = crop_start + target_dim
            slices = [slice(None)]*3
            slices[i] = slice(crop_start, crop_end)
            adjusted_volume = adjusted_volume[tuple(slices)]
    return adjusted_volume

##############################
# 6) Full Preprocessing
##############################

def preprocess_and_resample_all(patient_ids, target_spacing=(1.0,1.0,1.0), final_shape=(256,256,16)):
    """
    For each patient:
      - locate data
      - resample (systolic/diastolic + masks) to target_spacing
      - crop or pad each to final_shape
      - stack frames, do normalization, etc.
    Returns a list/dict of results.
    """
    all_volumes = []

    for pid in patient_ids:
        file_names, paths, frames = locate_patient_data(pid)
        if file_names and paths and frames:
            systolic_img, diastolic_img, systolic_mask_img, diastolic_mask_img = frames

            # 1) Resample
            systolic_resampled = resample_volume(systolic_img, target_spacing, is_label=False)
            diastolic_resampled = resample_volume(diastolic_img, target_spacing, is_label=False)
            systolic_mask_resampled = resample_volume(systolic_mask_img, target_spacing, is_label=True)
            diastolic_mask_resampled = resample_volume(diastolic_mask_img, target_spacing, is_label=True)

            # 2) Crop/Pad to final_shape=(256,256,16)
            systolic_resampled = crop_or_pad_mri(systolic_resampled, final_shape)
            diastolic_resampled = crop_or_pad_mri(diastolic_resampled, final_shape)
            systolic_mask_resampled = crop_or_pad_mri(systolic_mask_resampled, final_shape)
            diastolic_mask_resampled = crop_or_pad_mri(diastolic_mask_resampled, final_shape)

            # 3) Stack frames => shape (H, W, D, 2)
            # If your shape is (256,256,16), then after stacking => (256,256,16,2)
            combined_frames = np.stack((systolic_resampled, diastolic_resampled), axis=-1)

            # 4) Normalize intensities (avoid division by zero if empty)
            max_val = combined_frames.max()
            if max_val > 0:
                combined_frames = combined_frames / max_val

            # Save final results
            all_volumes.append({
                'patient_id': pid,
                'frames': combined_frames,   # shape (256,256,16,2)
                'systolic_mask': systolic_mask_resampled,   # shape (256,256,16)
                'diastolic_mask': diastolic_mask_resampled, # shape (256,256,16)
                # Optionally store more: label, original spacing, etc.
            })

    return all_volumes

################################
# 7) Putting It All Together
################################

# 2) Gather all spacings
spacings_list = check_all_patient_spacings(patient_ids)

# 3) Compute 90th percentile
target_spacing_90th = decide_target_spacing(spacings_list, percentile=90)
print("Chosen 90th-percentile spacing:", target_spacing_90th)

# 4) Preprocess: resample to that spacing, then crop/pad to (256,256,16)
all_preprocessed = preprocess_and_resample_all(
    patient_ids=patient_ids,
    target_spacing=target_spacing_90th,
    final_shape=(256,256,16)
)

print(f"\nPreprocessed {len(all_preprocessed)} patients.")
print("Sample keys in first entry:", all_preprocessed[0].keys() if all_preprocessed else "No data")


Chosen 90th-percentile spacing: (1.7382799744606023, 1.7382799744606023, 10.0)

Preprocessed 100 patients.
Sample keys in first entry: dict_keys(['patient_id', 'frames', 'systolic_mask', 'diastolic_mask'])


## **Visualization**

In [20]:
import matplotlib.pyplot as plt

In [23]:
def visualize_mri_and_mask(patient_id):
    """
    Visualize a single slice from an MRI patient image and its corresponding segmentation mask.

    Args:
        patient_id (str): Patient ID (e.g., 'p0001').
    """
    # Locate patient data
    file_names, paths, frames = locate_patient_data(patient_id)

    if not frames:
        print(f"No data found for patient {patient_id}.")
        return

    # Check dimensions
    print(f"Systolic MRI Shape: {frames[0].shape}")
    print(f"Diastolic MRI Shape: {frames[1].shape}")
    print(f"Systolic Mask Shape: {frames[2].shape}")
    print(f"Diastolic Mask Shape: {frames[3].shape}")

    # Load and print patient classification
    label_path = paths[4]  # Label file is the fifth element in paths
    try:
        with open(label_path, 'r') as label_file:
            label = label_file.readline().strip()
            print(f"Patient Classification: {label}")
    except FileNotFoundError:
        print(f"Label file not found for patient {patient_id}.")
        return

    # Choose slice index dynamically
    slice_index = frames[0].shape[2] // 2

    # Extract the slice for visualization
    mri_systolic_slice = frames[0][:, :, slice_index]
    systolic_mask_slice = frames[2][:, :, slice_index]

    mri_diastolic_slice = frames[1][:, :, slice_index]
    diastolic_mask_slice = frames[3][:, :, slice_index]

    # Plot the MRI image and the segmentation mask
    plt.figure(figsize=(12, 12))

    # MRI systolic slice
    plt.subplot(2, 2, 1)
    plt.imshow(mri_systolic_slice, cmap='gray')
    plt.title("MRI Systolic Slice")
    plt.axis("off")

    # Systolic segmentation mask
    plt.subplot(2, 2, 2)
    plt.imshow(mri_systolic_slice, cmap='gray')  # Display the MRI slice as the background
    plt.imshow(systolic_mask_slice, alpha=0.5, cmap='jet')  # Overlay the mask with transparency
    plt.title("MRI Systolic Slice with Segmentation Mask")
    plt.axis("off")

    # MRI diastolic slice
    plt.subplot(2, 2, 3)
    plt.imshow(mri_diastolic_slice, cmap='gray')
    plt.title("MRI Diastolic Slice")
    plt.axis("off")

    # Diastolic segmentation mask
    plt.subplot(2, 2, 4)
    plt.imshow(mri_diastolic_slice, cmap='gray')  # Display the MRI slice as the background
    plt.imshow(diastolic_mask_slice, alpha=0.5, cmap='jet')  # Overlay the mask with transparency
    plt.title("MRI Diastolic Slice with Segmentation Mask")
    plt.axis("off")

    plt.tight_layout(pad=1)
    plt.show()


random_patient_id = np.random.choice(patient_ids)
print(f"Randomly selected patient ID: {random_patient_id}")

visualize_mri_and_mask(random_patient_id)

Randomly selected patient ID: p0053
Systolic MRI Shape: (216, 256, 10)
Diastolic MRI Shape: (216, 256, 10)
Systolic Mask Shape: (216, 256, 10)
Diastolic Mask Shape: (216, 256, 10)
Patient Classification: HCM


TypeError: Cannot slice image objects; consider using `img.slicer[slice]` to generate a sliced image (see documentation for caveats) or slicing image array data with `img.dataobj[slice]` or `img.get_fdata()[slice]`

In [None]:

# Let's check image shapes

import numpy as np

shapes = []
for patient_id in patient_ids:
    patient_dir = os.path.join(train_dir, patient_id)
    files = os.listdir(patient_dir)
    for file in files:
        if file.endswith(".nii"):
            file_path = os.path.join(patient_dir, file)
            data = nib.load(file_path).get_fdata()
            print(f"File: {file}, Shape: {data.shape}")
            shapes.append(data.shape)

# Compute the median shape for each dimension
median_shape = tuple(np.median(np.array(shapes), axis=0).astype(int))
print(f"Median shape: {median_shape}")

# Compute the 90 percentile
percentile_90 = tuple(np.percentile(np.array(shapes), 90, axis=0).astype(int))
print(f"90th percentile shape: {percentile_90}")


## **Build the model**

simple Conv3D model with :
* adam optimizer
* he initializer
* same padding
* 80/20 train/validation split

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.models import Model

def build_3d_cnn(input_shape=(256,256,16,2), num_classes=5):
    """
    Build a 3D CNN model with Keras
    """
    model = models.Sequential([
        layers.Conv3D(32, (3,3,3), activation='relu', padding='same', input_shape=input_shape, kernel_initializer='he_normal'),
        layers.MaxPooling3D((2,2,2)),

        layers.Conv3D(64, (3,3,3), activation='relu', padding='same', kernel_initializer='he_normal'),
        layers.MaxPooling3D((2,2,2)),

        layers.Flatten(),
        layers.Dense(128, activation='relu', kernel_initializer='he_normal'),
        layers.Dense(num_classes, activation='softmax', kernel_initializer='he_normal')
    ])
    return model

#Go through the entire dataset
patient_ids = [f"p{i:04d}" for i in range(1, 101)]  # e.g., 100 patients

# Preprocess the data
inputs, labels = preprocess_and_resample_all(patient_ids, target_spacing=(1.0,1.0,1.0), final_shape=(256,256,16))

# Split into train/val (80/20 here, but adjust as you like)
train_data, val_data, train_labels, val_labels = train_test_split(
    inputs, labels, test_size=0.2, random_state=42
)

# Build model
model = build_3d_cnn(input_shape=(256,256,16,2), num_classes=5)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Print summary
model.summary()

# TRAIN ON THE ENTIRE DATASET (80% train, 20% val)
history = model.fit(
    train_data, train_labels,
    validation_data=(val_data, val_labels),
    epochs=20,  # recommended values between 50 to 100
    batch_size=4 # recommended values between 2 to 8
)


In [None]:
###########################
#VISUALIZE INTERMEDIATE LAYERS
##########################

# Create an activation model that outputs the feature maps of all layers
layer_outputs = [layer.output for layer in model.layers]
activation_model = Model(inputs=model.input, outputs=layer_outputs)

# Pick a single sample from your dataset
sample_volume = train_data[0:1]  # shape => (1, 256, 256, 16, 2)

# Get the activations
activations = activation_model.predict(sample_volume)

# Grab layer names (for plot titles)
layer_names = [layer.name for layer in model.layers]

# Visualize each layer's activations
for layer_idx, layer_activation in enumerate(activations):
    # If shape is (1, D, H, W, C) or something else, remove batch dimension
    if layer_activation.ndim == 5:
        # shape => (1, newD, newH, newW, newC)
        layer_activation = layer_activation[0]  # (newD, newH, newW, newC)

        D, H, W, C = layer_activation.shape
        # We'll show the middle slice in depth
        mid_slice = D // 2

        # We'll try to display up to 8 channels in a row
        cols = 8
        rows = C // cols + 1

        fig = plt.figure(figsize=(cols * 2, rows * 2))
        fig.suptitle(f"Layer {layer_idx+1}: {layer_names[layer_idx]}", fontsize=16)

        for c in range(C):
            ax = plt.subplot(rows, cols, c + 1)
            # Show the middle slice for channel c
            channel_slice = layer_activation[mid_slice, :, :, c]
            plt.imshow(channel_slice, cmap='viridis', aspect='auto')
            plt.axis('off')

        plt.tight_layout()
        plt.show()

    elif layer_activation.ndim == 2:
        # It's probably a Dense layer output with shape (1, features)
        # or after Flatten with shape (1, features).
        # We can skip or just print shape info
        print(f"Layer {layer_idx+1}: {layer_names[layer_idx]} output shape => {layer_activation.shape}. Skipping visualization.")
    else:
        print(f"Layer {layer_idx+1}: {layer_names[layer_idx]} has shape {layer_activation.shape}. Skipping.")