In [None]:
import os
import SimpleITK as sitk
import numpy as np
import pandas as pd
import math

# Module-level accumulator: process_file appends one metrics dict per
# (file, dilation radius) pair to this list.
results = list()

def calculate_SNR(signal_region):
    """Return the signal-to-noise ratio of ``signal_region``.

    SNR is the mean of the samples divided by their (population) standard
    deviation, both computed with NumPy. Returns 0 when the standard deviation
    is exactly zero.
    """
    sigma = np.std(signal_region)
    if sigma == 0:
        return 0
    return np.mean(signal_region) / sigma

def calculate_contrast(SI_CBD, SI_periductal):
    """Return the Michelson contrast (a - b) / (a + b) of two mean intensities.

    Parameters:
    SI_CBD: mean signal intensity inside the duct region.
    SI_periductal: mean signal intensity of the surrounding region.

    Returns 0 when the two intensities sum to zero. This mirrors how
    calculate_SNR and calculate_CNR handle a zero denominator; previously
    plain Python numbers raised ZeroDivisionError here.
    """
    total = SI_CBD + SI_periductal
    if total == 0:
        return 0
    return (SI_CBD - SI_periductal) / total

def calculate_CNR(SI_CBD, SI_periductal, noise_CBD, noise_periductal):
    """Return the contrast-to-noise ratio between the duct and periductal regions.

    CNR = (SI_CBD - SI_periductal) / sqrt((noise_CBD**2 + noise_periductal**2) / 2),
    i.e. the intensity difference divided by the RMS of the two noise terms.

    Returns 0 only when that pooled noise is zero. A single region with zero
    spread (e.g. a one-voxel region) still leaves a well-defined pooled noise,
    so it no longer forces the result to 0.
    """
    noise_rms = math.sqrt((noise_CBD ** 2 + noise_periductal ** 2) / 2)
    if noise_rms == 0:
        return 0
    return (SI_CBD - SI_periductal) / noise_rms

def calculate_FWHM_projection(region):
    """
    Estimate the full width at half maximum (FWHM) of a region's intensity profile.

    For a 3D array the data are averaged over axis 1 and then summed over the
    remaining last axis, giving a 1D profile along axis 0. For arrays from
    ``sitk.GetArrayFromImage`` (axis order z, y, x) this averages over y and
    sums over x, leaving a profile along z. Any non-3D input is used as the
    profile directly.

    The width is the index span (last index - first index) of the samples that
    are >= half of the profile maximum. It is measured in samples and is not
    interpolated. Samples between those two indices are not required to stay
    above half maximum.

    Returns 0 when the profile is empty or its maximum is not positive (zero,
    negative or NaN). Previously an all-zero profile reported its full length,
    because every sample satisfied ">= 0".
    """
    region = np.asarray(region)
    if region.ndim == 3:
        projection = np.mean(region, axis=1)  # average over axis 1
        profile_1d = np.sum(projection, axis=1)  # sum over the last axis -> 1D profile
    else:
        profile_1d = region

    if profile_1d.size == 0:
        return 0

    max_val = np.max(profile_1d)
    # "not >" also catches NaN, which compares False against everything.
    if not max_val > 0:
        return 0
    half_max_val = max_val / 2

    indices_above_half_max = np.where(profile_1d >= half_max_val)[0]
    if len(indices_above_half_max) == 0:
        return 0
    return indices_above_half_max[-1] - indices_above_half_max[0]

def process_file(file, img_dir, gt_dir, output_dir, dilation_radii):
    """
    Compute CBD / periductal image-quality metrics for one case at several dilation radii.

    ``img_dir/file`` is the image and ``gt_dir/file`` the ground-truth mask
    (same file name in both folders). For each entry of ``dilation_radii``:

    * the mask is dilated with ``sitk.BinaryDilate``;
    * the original mask is the CBD region, and the dilated-minus-original
      shell is the periductal region;
    * one row (File, Radius, SNR, Contrast, CNR, FWHM) is appended to the
      module-level ``results`` list;
    * the masked image of each region is written to ``output_dir`` via
      ``save_nifti``.

    Parameters:
    file: file name shared by the image and ground-truth folders.
    img_dir, gt_dir: folders holding the image and the ground-truth mask.
    output_dir: folder for the per-radius NIfTI outputs.
    dilation_radii: iterable of kernel radii passed to ``sitk.BinaryDilate``
        (the callers in this file pass per-axis voxel radii such as [6, 6, 6]).
    """
    img = sitk.ReadImage(os.path.join(img_dir, file))
    GT = sitk.ReadImage(os.path.join(gt_dir, file))

    # Everything up to the loop is independent of the dilation radius, so it
    # is computed once per file instead of once per radius.
    GT_cast = sitk.Cast(GT, sitk.sitkUInt8)
    img_arr = sitk.GetArrayFromImage(img)
    GT_arr = sitk.GetArrayFromImage(GT_cast)
    GT_region = img_arr * GT_arr

    # Select region voxels by mask membership, not by "masked intensity > 0".
    # Otherwise voxels whose intensity is zero (or negative) would silently
    # drop out of the region's statistics.
    cbd_mask = GT_arr > 0
    cbd_values = img_arr[cbd_mask]
    SI_CBD = np.mean(cbd_values)
    noise_CBD = np.std(cbd_values)
    SNR_CBD = calculate_SNR(cbd_values)
    FWHM_CBD = calculate_FWHM_projection(GT_region)

    for radius in dilation_radii:
        GT_dilated_arr = sitk.GetArrayFromImage(sitk.BinaryDilate(GT_cast, radius))

        # Masked images written to disk below.
        GT_region_dilated = img_arr * GT_dilated_arr
        GT_periduct = GT_region_dilated - GT_region

        periduct_values = img_arr[(GT_dilated_arr > 0) & ~cbd_mask]
        SI_periductal = np.mean(periduct_values)
        noise_periductal = np.std(periduct_values)

        results.append({
            "File": file,
            "Radius": radius,
            "SNR": SNR_CBD,
            "Contrast": calculate_contrast(SI_CBD, SI_periductal),
            "CNR": calculate_CNR(SI_CBD, SI_periductal, noise_CBD, noise_periductal),
            "FWHM": FWHM_CBD,
        })

        save_nifti(img, GT_region, output_dir, file, f"GT_region_radius_{radius}")
        save_nifti(img, GT_periduct, output_dir, file, f"GT_periduct_radius_{radius}")

def save_nifti(img, data, output_dir, file_name, suffix):
    """Write ``data`` to ``<output_dir>/<file_name>_<suffix>.nii.gz``.

    The array is wrapped in a SimpleITK image whose meta-data (origin,
    spacing, direction) is copied from the reference image ``img``.
    """
    target = os.path.join(output_dir, f"{file_name}_{suffix}.nii.gz")
    out_image = sitk.GetImageFromArray(data)
    out_image.CopyInformation(img)
    sitk.WriteImage(out_image, target)

def process_files(img_dir, gt_dir, output_dir, dilation_radii):
    """
    Run process_file on every non-hidden file in ``img_dir`` and save an Excel summary.

    The summary goes to ``output_dir/MRCP_Results.xlsx`` and contains only the
    rows produced by this call. ``results`` is a module-level list that
    outlives each call, so exporting all of it would repeat every earlier
    batch's rows, e.g. the second ``run_batch`` would include the first
    dataset.
    """
    os.makedirs(output_dir, exist_ok=True)
    first_new_row = len(results)

    # os.listdir order is arbitrary; sorting makes the row order reproducible.
    for file in sorted(os.listdir(img_dir)):
        if file.startswith('.'):
            continue  # skip hidden files such as .DS_Store
        process_file(file, img_dir, gt_dir, output_dir, dilation_radii)

    df_results = pd.DataFrame(results[first_new_row:])
    df_results.to_excel(os.path.join(output_dir, "MRCP_Results.xlsx"), index=False)

def run_batch(img_dir, gt_dir, output_dir, dilation_radii):
    """Process one dataset (image folder + ground-truth folder) and report completion.

    Thin wrapper around process_files that prints a confirmation message.
    """
    process_files(
        img_dir=img_dir,
        gt_dir=gt_dir,
        output_dir=output_dir,
        dilation_radii=dilation_radii,
    )
    print("Results saved to Excel and .nii.gz files.")

# Run the batch pipeline for two datasets (CPTAC, and the one under
# data/NII_img), each with dilation kernel radii [6, 6, 6] and [12, 12, 12].
for img_folder, gt_folder, out_folder in (
    ("/Users/ziling/Desktop/MRCP/data/NII_img_CPTAC",
     "/Users/ziling/Desktop/MRCP/data/NII_GT_CPTAC",
     "/Users/ziling/Desktop/MRCP/results/results_CPTAC"),
    ("/Users/ziling/Desktop/MRCP/data/NII_img",
     "/Users/ziling/Desktop/MRCP/data/NII_GT",
     "/Users/ziling/Desktop/MRCP/results/results_LocData"),
):
    run_batch(img_folder, gt_folder, out_folder, [[6, 6, 6], [12, 12, 12]])

Results saved to Excel and .nii.gz files.


In [None]:
import SimpleITK as sitk
import numpy as np
import os
from scipy.ndimage import distance_transform_edt, binary_fill_holes
import skimage.morphology as morph
from scipy.spatial import cKDTree
from skimage.morphology import remove_small_objects
import pandas as pd

def resample_label(label_image, new_spacing=(1, 1, 1)):
    """
    Resample a label image to ``new_spacing`` using nearest-neighbour interpolation.

    The output keeps the input's origin and direction (taken from the
    reference image). It covers about the same physical extent: per axis,
    size = round(size * old_spacing / new_spacing), with a minimum of 1. The
    result is cast to ``sitkUInt8``.

    Parameters:
    label_image: SimpleITK label image.
    new_spacing: target spacing per axis, in the image's spacing units.
    """
    original_spacing = label_image.GetSpacing()
    original_size = label_image.GetSize()

    # round() instead of int(): int() truncates, so a product such as
    # 20.999999999999996 (floating-point error) would silently drop a slice.
    new_size = tuple(
        max(1, int(round(size * old / new)))
        for size, old, new in zip(original_size, original_spacing, new_spacing)
    )

    resample_filter = sitk.ResampleImageFilter()
    resample_filter.SetReferenceImage(label_image)
    resample_filter.SetSize(new_size)
    resample_filter.SetTransform(sitk.Transform(label_image.GetDimension(), sitk.sitkIdentity))
    resample_filter.SetInterpolator(sitk.sitkNearestNeighbor)
    resample_filter.SetOutputSpacing(new_spacing)

    resampled_label_image = resample_filter.Execute(label_image)
    return sitk.Cast(resampled_label_image, sitk.sitkUInt8)

def calculate_center(nii_img):
    """
    Extract a centerline (skeleton) mask from a label image.

    Steps:
    1. Binarise: any non-zero voxel is foreground.
    2. Drop connected components smaller than 100 voxels.
    3. Fill enclosed holes.
    4. Skeletonise.
    5. Drop skeleton fragments smaller than 50 voxels.

    Returns a UInt8 (0/1) SimpleITK image with the same meta-data as ``nii_img``.
    """
    binary_image = sitk.GetArrayFromImage(nii_img).astype(bool)

    # Remove small objects (noise), then fill internal holes.
    binary_image = remove_small_objects(binary_image, min_size=100)
    binary_image = binary_fill_holes(binary_image)

    # skeletonize_3d is deprecated and removed in recent scikit-image
    # releases; its replacement, skeletonize, also accepts 3D input.
    skeletonize_fn = getattr(morph, "skeletonize_3d", None) or morph.skeletonize
    # Depending on the scikit-image version the skeleton may be uint8 (0/255).
    # remove_small_objects treats integer input as an already-labelled image,
    # so a non-bool skeleton would count as one object and never be pruned.
    skeleton = skeletonize_fn(binary_image).astype(bool)

    # Use full connectivity (26-neighbourhood in 3D): skeleton voxels are often
    # linked only diagonally, and the default face connectivity would split the
    # skeleton and delete valid pieces. This removes small disconnected
    # fragments; it does not prune short spurs attached to the main skeleton.
    skeleton_pruned = remove_small_objects(skeleton, min_size=50, connectivity=skeleton.ndim)

    centerline_nii = sitk.GetImageFromArray(skeleton_pruned.astype(np.uint8))
    centerline_nii.CopyInformation(nii_img)
    return centerline_nii


def calculate_vessel_surface(label_image):
    """Return the boundary shell of a binary label image.

    The shell is ``label - erode(label)``, where the binary erosion uses a
    ball-shaped kernel of radius 1. The difference is cast to ``sitkUInt8``.
    """
    eroder = sitk.BinaryErodeImageFilter()
    eroder.SetKernelType(sitk.sitkBall)
    eroder.SetKernelRadius(1)
    interior = eroder.Execute(label_image)

    shell = sitk.Subtract(label_image, interior)
    return sitk.Cast(shell, sitk.sitkUInt8)

def calculate_vessel_radius(vessel, centerline):
    """
    Label each vessel voxel with its dilation-step distance from the centerline.

    Starting from the centerline (clipped to the vessel), the region is
    repeatedly dilated with a radius-1 ball kernel and clipped to the vessel
    mask. Voxels first reached in step ``k`` get the value ``k`` in the
    returned UInt8 image. Centerline voxels and voxels outside the vessel stay
    0. Values are in dilation steps (voxels); they read as mm only when the
    inputs use 1 mm spacing (TODO confirm for other spacings). Values above
    255 cannot be represented in the UInt8 output.

    Parameters:
    vessel: label image of the vessel; any non-zero voxel is foreground.
    centerline: label image of the vessel centerline, with the same geometry
        as ``vessel``; any non-zero voxel is foreground.
    """
    # Binarise both inputs. BinaryDilate's foreground value is 1, and the mask
    # products below assume 0/1 values.
    vessel_mask = sitk.Cast(vessel != 0, sitk.sitkUInt8)
    # Clip the seed to the vessel. Centerline voxels outside the mask (possible
    # after hole filling) would otherwise make "grown - current" underflow in
    # UInt8 and corrupt the stopping test.
    current_region = sitk.Cast(centerline != 0, sitk.sitkUInt8) * vessel_mask

    radius_map = sitk.Image(vessel.GetSize(), sitk.sitkUInt8)
    radius_map.CopyInformation(vessel)

    kernel_radius = [1] * vessel.GetDimension()
    current_count = int(sitk.GetArrayFromImage(current_region).sum())
    radius = 1

    while True:
        dilated = sitk.BinaryDilate(current_region,
                                    kernel_radius,
                                    sitk.sitkBall,
                                    0,  # backgroundValue
                                    1)  # foregroundValue
        grown = dilated * vessel_mask
        grown_count = int(sitk.GetArrayFromImage(grown).sum())

        # current_region is a subset of the vessel, so each step can only add
        # voxels; an unchanged count means nothing new was reached.
        if grown_count == current_count:
            break

        new_region = grown - current_region  # 0/1, since current is a subset of grown
        radius_map = radius_map + (new_region * radius)

        current_region = grown
        current_count = grown_count
        radius += 1

    return radius_map


def calculate2table(nii, surface, visualization_path, table_path):
    """
    Bin the radius value of every surface voxel into size classes and export the result.

    Parameters:
    nii: radius-map image (e.g. the output of calculate_vessel_radius).
    surface: surface-shell mask with the same geometry as ``nii``.
    visualization_path: output path for the class map (values 0-5).
    table_path: output path for the one-column CSV summary.

    Class codes for surface voxels with radius value r:
    1 for 0 < r < 3, 2 for 3 <= r < 5, 3 for 5 <= r < 7, 4 for 7 <= r < 10,
    and 5 for r >= 10. 0 marks voxels that are off the surface or have r == 0.
    The "mm" in the column labels holds when the radius map is in 1 mm voxel
    steps (TODO confirm for other spacings).

    CSV rows:
    volume: number of voxels with a non-zero radius value times the voxel
        volume from ``nii``'s spacing. Voxels whose value is 0 (e.g.
        centerline voxels of a calculate_vessel_radius map) are not counted.
    <class>%: share of classified surface voxels in each class; NaN when no
        surface voxel is classified.
    """
    nii_array = sitk.GetArrayFromImage(nii)
    surface_array = sitk.GetArrayFromImage(surface)
    surface_num = nii_array * surface_array

    # Reclassify in place, in this order. Each step writes a code (1-5) that
    # no later step's range matches, so values cannot be re-binned twice.
    surface_num[(surface_num > 0) & (surface_num < 3)] = 1
    surface_num[(surface_num >= 3) & (surface_num < 5)] = 2
    surface_num[(surface_num >= 5) & (surface_num < 7)] = 3
    surface_num[(surface_num >= 7) & (surface_num < 10)] = 4
    surface_num[surface_num >= 10] = 5
    surface_visualization = sitk.GetImageFromArray(surface_num)
    surface_visualization.CopyInformation(nii)
    sitk.WriteImage(surface_visualization, visualization_path)

    columns = ["<3mm%", "3-5mm%", "5-7mm%", "7-10mm%", ">10mm%"]
    df = pd.Series(index=["volume"] + columns, dtype=float)
    # Count voxels instead of summing radius values (a sum of radii is not a
    # volume), and scale by the physical voxel volume.
    df["volume"] = np.count_nonzero(nii_array) * float(np.prod(nii.GetSpacing()))

    n_classified = np.count_nonzero(surface_num > 0)
    for code, column in enumerate(columns, start=1):
        if n_classified:
            df[column] = np.count_nonzero(surface_num == code) / n_classified * 100
        else:
            df[column] = np.nan  # no classified surface voxels: share undefined

    df.to_frame(name='Value').to_csv(table_path)


# Image preprocessing: for every label volume in ./label, resample to 1 mm
# spacing, extract the surface shell and centerline, compute the radius map,
# and export a class map plus a CSV summary.
path = './label'
new_spacing = (1, 1, 1)
resample_dir = './preprocess/resample'
surface_dir = './preprocess/surface'
centerline_dir = './preprocess/centerline'
nifti_dir = './result/nifti'
visualization_dir = './result/visualization'
table_dir = './result/table'

# sitk.WriteImage does not create missing folders, so create them up front.
for out_dir in (resample_dir, surface_dir, centerline_dir,
                nifti_dir, visualization_dir, table_dir):
    os.makedirs(out_dir, exist_ok=True)

# Sorted for a reproducible processing order.
for name in sorted(os.listdir(path)):
    if not name.endswith('.nii.gz'):
        continue
    label_image = sitk.ReadImage(os.path.join(path, name))

    # Resample
    resampled_label_image = resample_label(label_image, new_spacing)
    sitk.WriteImage(resampled_label_image, os.path.join(resample_dir, name))

    # Vessel surface shell
    surface = calculate_vessel_surface(resampled_label_image)
    sitk.WriteImage(surface, os.path.join(surface_dir, name))

    # Centerline
    centerline_image = calculate_center(resampled_label_image)
    sitk.WriteImage(centerline_image, os.path.join(centerline_dir, name))

    # Dilation-step distance of vessel voxels from the centerline
    new_nii = calculate_vessel_radius(resampled_label_image, centerline_image)
    sitk.WriteImage(new_nii, os.path.join(nifti_dir, name))

    # Export the table and save the visualization of the surface classes
    stem = name[:-len('.nii.gz')]
    calculate2table(
        new_nii,
        surface,
        visualization_path=os.path.join(visualization_dir, name),
        table_path=os.path.join(table_dir, stem + '.csv'),
    )