In [6]:
import os
import random
import shutil
import SimpleITK as sitk
import numpy as np

# ---------------------------
# 1. Preprocess the Test Data
# ---------------------------
# Copy every "patient*" folder from test_dir into output_dir. All later stages
# read from and overwrite files inside output_dir, so the source dataset is
# never modified.

# Define the path to the directory containing the test patient folders.
test_dir = './dataset/ACDC_single_frame/single_frame'  # Adjust this path if necessary
output_dir = './dataset/secret_test_set_extracted'  # Directory to save processed test data

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# List all folders that start with "patient" and are directories, sorted so the
# copy (and its log output) happens in a deterministic order.
patient_dirs = sorted(
    d for d in os.listdir(test_dir)
    if d.startswith("patient") and os.path.isdir(os.path.join(test_dir, d))
)

print("Total number of patient folders in test set:", len(patient_dirs))
if not patient_dirs:
    # Every stage below would otherwise silently process nothing; make that visible.
    print(f"Warning: no 'patient*' folders found under {test_dir}")

# Copy the patient folders into output_dir. This step is required, not optional:
# the stages below operate only on output_dir. With dirs_exist_ok=True, files that
# already exist in output_dir are overwritten. Re-running this cell therefore
# restores the unprocessed originals before the in-place processing runs again.
for patient in patient_dirs:
    src = os.path.join(test_dir, patient)
    dst = os.path.join(output_dir, patient)
    shutil.copytree(src, dst, dirs_exist_ok=True)

print(f"Test data copied to: {output_dir}")

# ---------------------------
# 2. Resample the Images to 1.25 x 1.25 mm In-Plane Resolution
# ---------------------------
# Target in-plane pixel spacing (x, y) in mm.
target_resolution = [1.25, 1.25]

def resample_image(image, target_resolution, interpolator=sitk.sitkLinear):
    """
    Resample a 2D or 3D image to a new in-plane (x, y) pixel spacing.

    The x and y sizes are rescaled so the physical extent is preserved up to
    rounding: ``new_n = round(n * old_spacing / target_spacing)``. Any axis
    beyond the second keeps its original size and spacing. Origin and direction
    are copied unchanged from the input.

    Args:
        image: SimpleITK image to resample.
        target_resolution: sequence whose first two entries are the desired
            x and y spacing, in the same units as ``image.GetSpacing()`` (mm here).
        interpolator: SimpleITK interpolator enum. Use
            ``sitk.sitkNearestNeighbor`` for label maps.

    Returns:
        The resampled SimpleITK image.
    """
    spacing_in = image.GetSpacing()
    size_in = image.GetSize()

    size_out = list(size_in)
    spacing_out = list(spacing_in)
    for axis in (0, 1):
        # Scale the pixel count by the spacing ratio so size * spacing is kept.
        size_out[axis] = int(round(size_in[axis] * spacing_in[axis] / target_resolution[axis]))
        spacing_out[axis] = target_resolution[axis]

    # NOTE(review): the origin (centre of pixel 0) is kept as-is. With a new
    # spacing, the outer pixel edges therefore shift by (new - old) / 2.
    # Confirm this is acceptable.
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetOutputSpacing(spacing_out)
    resampler.SetSize(size_out)
    resampler.SetInterpolator(interpolator)
    return resampler.Execute(image)

# Resample every single-frame volume under output_dir in place. Files whose
# x/y spacing already matches target_resolution are left untouched.
print("Processing test directory:", output_dir)
for patient in os.listdir(output_dir):
    patient_path = os.path.join(output_dir, patient)
    if not os.path.isdir(patient_path):
        continue
    for file in os.listdir(patient_path):
        # Only per-frame NIfTI volumes; never the 4D cine file.
        if not ('frame' in file and '4d' not in file and file.endswith(('.nii', '.nii.gz'))):
            continue
        file_path = os.path.join(patient_path, file)

        image = sitk.ReadImage(file_path)
        spacing = image.GetSpacing()

        # Already at the requested in-plane spacing (within 1e-6)? Nothing to do.
        if all(abs(spacing[k] - target_resolution[k]) < 1e-6 for k in (0, 1)):
            print("Skipping already resampled:", file_path)
            continue

        print("Resampling:", file_path)

        # Ground-truth label maps need nearest neighbour so labels are preserved.
        interp = sitk.sitkNearestNeighbor if '_gt' in file else sitk.sitkLinear

        resampled_image = resample_image(image, target_resolution, interpolator=interp)
        sitk.WriteImage(resampled_image, file_path)  # overwrite in place

print("Resampling for test set complete.")

# ---------------------------
# 3. Intensity Normalization (Z-score)
# ---------------------------
def zscore_normalization(image):
    """
    Return a z-score normalised copy of ``image``: ``(x - mean) / std``.

    Mean and standard deviation are taken over every voxel (no mask). For a
    constant image (std == 0) the divisor falls back to 1, so the result is
    only mean-centred. Spacing, origin and direction are copied from the input.
    """
    voxels = sitk.GetArrayFromImage(image)
    mu = np.mean(voxels)
    sigma = np.std(voxels)

    # Guard against division by zero on flat images.
    divisor = sigma if sigma != 0 else 1

    result = sitk.GetImageFromArray((voxels - mu) / divisor)
    result.CopyInformation(image)
    return result

# Z-score normalise every single-frame image under output_dir in place.
# Ground-truth label maps are skipped so their label values stay intact.
print("Processing directory for normalization:", output_dir)
for patient in os.listdir(output_dir):
    patient_path = os.path.join(output_dir, patient)
    if not os.path.isdir(patient_path):
        continue
    for file in os.listdir(patient_path):
        # Same file filter as the resampling and resizing stages. Both plain and
        # gzipped NIfTI must be accepted, because the frames here are '.nii.gz'.
        # A '.nii'-only check would silently skip every image.
        if not ('frame' in file and '4d' not in file and file.endswith(('.nii', '.nii.gz'))):
            continue
        # Skip ground truth images.
        if '_gt' in file:
            continue

        file_path = os.path.join(patient_path, file)
        print("Normalizing:", file_path)

        image = sitk.ReadImage(file_path)
        normalized_image = zscore_normalization(image)

        # Overwrite the original file with the normalized image.
        sitk.WriteImage(normalized_image, file_path)

print("Intensity normalization complete for test set.")

# ---------------------------
# 4. Resize Images to 352 x 352 Pixels for 2D U-Net Input
# ---------------------------
# Desired in-plane grid size (x, y) for the network input.
target_size = (352, 352)

def resize_image_to_fixed_size(image, output_size, interpolator=sitk.sitkLinear):
    """
    Resample ``image`` so its in-plane (x, y) grid is exactly ``output_size``.

    The x/y spacing is multiplied by ``old_size / new_size``, so the physical
    in-plane extent stays roughly the same. For 3D input, the z size and z
    spacing are carried over unchanged. Origin and direction are copied from
    the input.

    NOTE(review): because the spacing is rescaled, an image previously resampled
    to 1.25 mm ends up at ``1.25 * old_size / 352`` mm. That value varies per
    patient unless the image was already 352 pixels across. Confirm that
    stretching, rather than cropping or padding, is intended.

    Raises:
        ValueError: if ``image`` is neither 2D nor 3D.
    """
    size_in = np.array(image.GetSize(), dtype=int)
    spacing_in = np.array(image.GetSpacing())
    size_xy = np.array(output_size, dtype=int)

    # Stretch the x/y spacing by old_size / new_size so size * spacing is preserved.
    spacing_xy = spacing_in[:2] * (size_in[:2] / size_xy)

    ndim = image.GetDimension()
    if ndim not in (2, 3):
        raise ValueError("Only 2D or 3D images are supported.")

    out_size = [size_xy[0], size_xy[1]]
    out_spacing = [float(spacing_xy[0]), float(spacing_xy[1])]
    if ndim == 3:
        # The slice axis is left exactly as it was.
        out_size.append(size_in[2])
        out_spacing.append(spacing_in[2])

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetOutputSpacing(out_spacing)
    # Cast to built-in int. NumPy int64 scalars are presumably not accepted by
    # the SetSize binding.
    resampler.SetSize([int(n) for n in out_size])
    resampler.SetInterpolator(interpolator)
    return resampler.Execute(image)



# Resize every single-frame volume under output_dir in place to target_size.
print("Processing directory for resizing:", output_dir)
for patient in os.listdir(output_dir):
    patient_path = os.path.join(output_dir, patient)
    if not os.path.isdir(patient_path):
        continue
    for file in os.listdir(patient_path):
        # Only per-frame NIfTI volumes; never the 4D cine file.
        if not ('frame' in file and '4d' not in file and file.endswith(('.nii', '.nii.gz'))):
            continue
        file_path = os.path.join(patient_path, file)
        print("Resizing:", file_path)

        image = sitk.ReadImage(file_path)

        # Label maps (..._gt) use nearest neighbour; intensity images use linear.
        interp = sitk.sitkNearestNeighbor if '_gt' in file else sitk.sitkLinear

        resized_image = resize_image_to_fixed_size(image, target_size, interpolator=interp)
        sitk.WriteImage(resized_image, file_path)  # overwrite in place

print("Resizing complete for test set.")

# ---------------------------
# 5. Extract 2D Slices from Volumes for 2D U-Net Input
# ---------------------------
# Each preprocessed volume (both frame and ground truth) is split into its
# individual z-slices. The slices are written to a separate directory
# (test_2d) and serve as individual test samples.

def extract_slices(volume_image):
    """
    Split a 3D SimpleITK image into a list of 2D images, one per z index.

    Slices are returned in increasing z order. Each slice spans the full x/y
    extent of the volume.
    """
    dims = volume_image.GetSize()  # (width, height, depth)

    extractor = sitk.ExtractImageFilter()
    # A zero-length z extent tells ExtractImageFilter to drop that axis, so each
    # result is a genuine 2D image rather than a one-slice volume.
    extractor.SetSize([dims[0], dims[1], 0])

    slices = []
    for z_index in range(dims[2]):
        extractor.SetIndex([0, 0, z_index])
        slices.append(extractor.Execute(volume_image))
    return slices

# Create output directory for extracted 2D slices.
test_2d_dir = './dataset/secret_test_set_2d'
os.makedirs(test_2d_dir, exist_ok=True)

# Split each preprocessed volume in output_dir into 2D slices. The slices are
# written to test_2d_dir/<patient>/.
print(f"Extracting 2D slices from {output_dir} ...")
for patient in os.listdir(output_dir):
    patient_path = os.path.join(output_dir, patient)
    if not os.path.isdir(patient_path):
        continue
    # Create corresponding patient folder in the output directory.
    out_patient_dir = os.path.join(test_2d_dir, patient)
    os.makedirs(out_patient_dir, exist_ok=True)
    for file in os.listdir(patient_path):
        # Same filter as the earlier stages. The '4d' exclusion matters here
        # because extract_slices assumes a 3D volume.
        if not ('frame' in file and '4d' not in file and file.endswith(('.nii', '.nii.gz'))):
            continue
        file_path = os.path.join(patient_path, file)
        print("Extracting slices from:", file_path)
        image = sitk.ReadImage(file_path)
        slices = extract_slices(image)

        # NOTE(review): os.path.splitext strips only the last suffix. For
        # 'patientXXX_frameYY.nii.gz' the base is therefore 'patientXXX_frameYY.nii',
        # and slices are saved as '..._frameYY.nii_slice00.nii'. This naming is
        # kept unchanged because downstream loaders may depend on it; confirm
        # before normalising the names.
        base_name = os.path.splitext(file)[0]

        # Save each 2D slice as a separate file.
        for i, slice_img in enumerate(slices):
            new_filename = f"{base_name}_slice{i:02d}.nii"
            new_file_path = os.path.join(out_patient_dir, new_filename)
            sitk.WriteImage(slice_img, new_file_path)
            print("Saved slice:", new_file_path)

print("Slice extraction complete for test set.")

Total number of patient folders in test set: 25
Test data copied to: ./dataset/secret_test_set_extracted
Processing test directory: ./dataset/secret_test_set_extracted
Resampling: ./dataset/secret_test_set_extracted/patient151/patient151_frame01.nii.gz
Resampling: ./dataset/secret_test_set_extracted/patient152/patient152_frame01.nii.gz
Skipping already resampled: ./dataset/secret_test_set_extracted/patient153/patient153_frame01.nii.gz
Resampling: ./dataset/secret_test_set_extracted/patient154/patient154_frame01.nii.gz
Resampling: ./dataset/secret_test_set_extracted/patient155/patient155_frame01.nii.gz
Skipping already resampled: ./dataset/secret_test_set_extracted/patient156/patient156_frame01.nii.gz
Resampling: ./dataset/secret_test_set_extracted/patient157/patient157_frame01.nii.gz
Resampling: ./dataset/secret_test_set_extracted/patient158/patient158_frame01.nii.gz
Resampling: ./dataset/secret_test_set_extracted/patient159/patient159_frame01.nii.gz
Resampling: ./dataset/secret_test_s