In [2]:
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
import os
from skimage import exposure

In [None]:
# for filename in os.listdir(dataset_path):
#     if filename.endswith('.nii'):
#         file_path = os.path.join(dataset_path, filename)

#         # Read the image
#         image = sitk.ReadImage(file_path)

#         # Get the size of the image
#         size = image.GetSize()
#         origin = image.GetOrigin()
#         spacing = image.GetSpacing()
#         direction = image.GetDirection()
#         pixeltype = image.GetPixelID()

#         print(f"Image size, origin, spacing, direction, pixel type of {filename}: {size}, {origin}, {spacing}, {direction}, {pixeltype}")

# Bias-field correction

In [5]:
import SimpleITK as sitk

def preprocess_image(raw_img_path, out_path, shrink_factor=4):
    """
    Apply N4 bias-field correction to an MRI image and write the result to disk.

    The bias field is estimated on a down-sampled copy of the image (to keep
    the fit fast) and then evaluated on, and divided out of, the
    full-resolution image.

    Args:
        raw_img_path: path of the raw image file to read.
        out_path: path of the corrected image *file* to write (not a folder).
        shrink_factor: integer down-sampling factor applied along every axis
            before the bias field is estimated. Defaults to 4.

    Raises:
        ValueError: if ``shrink_factor`` is not a positive integer.
    """
    shrink_factor = int(shrink_factor)
    if shrink_factor < 1:
        raise ValueError("shrink_factor must be a positive integer")

    # Read as float32 so the division by the bias field below is done in floating point.
    raw_img_sitk = sitk.ReadImage(raw_img_path, sitk.sitkFloat32)

    # Rescale to [0, 255] and binarise with Li's automatic threshold (labels 0 and 1);
    # the result is the mask handed to the N4 filter.
    rescaled = sitk.RescaleIntensity(raw_img_sitk, 0, 255)
    head_mask = sitk.LiThreshold(rescaled, 0, 1)

    # Shrink image and mask by the same factor along every dimension.
    shrink_factors = [shrink_factor] * raw_img_sitk.GetDimension()
    input_image = sitk.Shrink(raw_img_sitk, shrink_factors)
    mask_image = sitk.Shrink(head_mask, shrink_factors)

    # Execute() must run so the filter holds a fitted bias field; its return value
    # (the corrected low-resolution image) is not needed.
    bias_corrector = sitk.N4BiasFieldCorrectionImageFilter()
    bias_corrector.Execute(input_image, mask_image)

    # Evaluate the log bias field on the full-resolution grid and divide it out.
    log_bias_field = bias_corrector.GetLogBiasFieldAsImage(raw_img_sitk)
    corrected_image_full_resolution = raw_img_sitk / sitk.Exp(log_bias_field)

    sitk.WriteImage(corrected_image_full_resolution, out_path)
# Bias-correct one validation volume; the corrected image is written to "output.nii"
# (a relative path, so it lands in the current working directory).
preprocess_image("data/Validation_Set/IBSR_17/IBSR_17.nii", "output.nii")

In [7]:
import os  # Import the os module

def apply_clahe_and_histogram_equalization(input_path, output_folder, output_name="IBRS_17_pp.nii"):
    """
    Apply global histogram equalization to a brain MRI image and save the result.

    Despite the function name, only histogram equalization is performed; the
    CLAHE variant is currently disabled (see the note in the body).

    The image is rescaled to [0, 1], equalized, rescaled back to the original
    intensity range, cast back to the original array dtype and written with
    the spatial metadata of the input image.

    Args:
        input_path (str): Path to the input brain MRI image in .nii format
            after bias-field correction.
        output_folder (str): Folder in which the enhanced image is saved. It is
            created if it does not exist yet.
        output_name (str): File name of the saved image inside
            ``output_folder``. Defaults to the name the notebook used so far.
    """
    # Load the MRI image and get its voxel data as a NumPy array.
    input_image = sitk.ReadImage(input_path, sitk.sitkFloat32)
    array_image = sitk.GetArrayFromImage(input_image)

    # Equalization is run on intensities rescaled to [0, 1].
    image_array_01 = exposure.rescale_intensity(array_image, out_range=(0, 1))
    img_eq = exposure.equalize_hist(image_array_01)
    # CLAHE alternative (disabled): exposure.equalize_adapthist(image_array_01)

    # Map the equalized data back to the intensity range of the input.
    min_val, max_val = array_image.min(), array_image.max()
    image_hist_eq_original_range = exposure.rescale_intensity(img_eq, out_range=(min_val, max_val))

    # Rebuild a SimpleITK image and copy the metadata of the input image onto it
    # so the result stays aligned with the input.
    final_img_hist_eq = sitk.GetImageFromArray(image_hist_eq_original_range.astype(array_image.dtype))
    final_img_hist_eq.CopyInformation(input_image)

    # Writing fails if the target folder is missing, so create it first.
    # An empty folder string means "current directory" and needs no creation.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    output_path_hist_eq = os.path.join(output_folder, output_name)

    sitk.WriteImage(final_img_hist_eq, output_path_hist_eq)

# Equalize the bias-corrected volume ("output.nii", written by the bias-field
# correction cell) and save the result inside the "pp/" folder.
input_image_path = "output.nii"
output_folder = "pp/"
apply_clahe_and_histogram_equalization(input_image_path, output_folder)

# Registration

In [12]:
import os
import SimpleITK as sitk

def _default_parameter_maps():
    """Build the affine + B-spline Elastix parameter maps used for every image pair."""
    parameter_maps = sitk.VectorOfParameterMap()
    for transform_name in ("affine", "bspline"):
        parameter_map = sitk.GetDefaultParameterMap(transform_name)
        parameter_map["Metric"] = ["AdvancedMattesMutualInformation"]
        parameter_map["NumberOfResolutions"] = ["4"]
        parameter_maps.append(parameter_map)
    return parameter_maps


def register_images(fixed_folder, moving_folder, output_folder):
    """
    Register every moving image to every fixed image with Elastix.

    Each '.nii.gz' file in ``moving_folder`` is registered (affine followed by
    B-spline, Mattes mutual information, 4 resolutions) to each '.nii.gz' file
    in ``fixed_folder``. Results are written to ``output_folder`` as
    ``registered_<fixed name>_to_<moving name>``.

    Errors are reported with ``print`` instead of being raised, so the function
    always returns ``None``.

    Args:
        fixed_folder (str): Folder containing the fixed (reference) images.
        moving_folder (str): Folder containing the images to be registered.
        output_folder (str): Folder for the registered images; created if
            it does not exist.
    """
    try:
        # Create the output folder if it doesn't exist
        os.makedirs(output_folder, exist_ok=True)

        # Sorted so the processing order (and the printed indices) is deterministic.
        fixed_images = sorted(name for name in os.listdir(fixed_folder) if name.endswith(".nii.gz"))
        moving_images = sorted(name for name in os.listdir(moving_folder) if name.endswith(".nii.gz"))

        if not fixed_images:
            print("No '.nii.gz' files found in the fixed image folder.")
            return

        if not moving_images:
            print("No '.nii.gz' files found in the moving image folder.")
            return

        # Stock SimpleITK wheels do not ship Elastix; fail early with an actionable message.
        if not hasattr(sitk, "ElastixImageFilter"):
            print("An error occurred: this SimpleITK build has no ElastixImageFilter; "
                  "install a build that includes Elastix (e.g. SimpleITK-SimpleElastix).")
            return

        # The registration settings are identical for every pair, so build them once.
        parameter_maps = _default_parameter_maps()

        for i, fixed_image_name in enumerate(fixed_images):
            # Read each fixed image once instead of once per moving image.
            fixed_image = sitk.ReadImage(os.path.join(fixed_folder, fixed_image_name))

            for j, moving_image_name in enumerate(moving_images):
                moving_image_path = os.path.join(moving_folder, moving_image_name)

                # Perform registration for the current pair of fixed and moving images
                elastix_filter = sitk.ElastixImageFilter()
                elastix_filter.SetFixedImage(fixed_image)
                elastix_filter.SetMovingImage(sitk.ReadImage(moving_image_path))
                elastix_filter.SetParameterMap(parameter_maps)
                registered_image = elastix_filter.Execute()

                # Save the registered image in the output folder
                output_image_name = f"registered_{fixed_image_name}_to_{moving_image_name}"
                output_image_path = os.path.join(output_folder, output_image_name)
                sitk.WriteImage(registered_image, output_image_path)

                print(f"{i+1}st image is registered to {j+1}.")

    except Exception as e:
        # Best-effort notebook helper: report the failure instead of raising.
        print(f"An error occurred: {str(e)}")
# Folders scanned by register_images: every image found in the moving folder is
# registered to every image found in the fixed folder.
fixed_folder = "data/Validation_Set/IBSR_11/"
moving_folder = "data/Training_Set/IBSR_01/"
# NOTE: this rebinds the module-level `output_folder` that the
# histogram-equalization cell set to "pp/".
output_folder = "pp"

register_images(fixed_folder, moving_folder, output_folder)

An error occurred: module 'SimpleITK' has no attribute 'ElastixImageFilter'


# Multi-resolution registration

In [None]:
from scipy.optimize import minimize
from scipy.ndimage import map_coordinates
from skimage.color import rgb2gray
from skimage.io import imread
import matplotlib.pyplot as plt
import numpy as np
import source.fullAffine
import cv2 as cv
import datetime

def create_pyramid_levels(image, num_level = 3):
    """
    Build a down-scaled image pyramid with ``cv.pyrDown``.

    Args:
        image: the full-resolution image (level 0).
        num_level: number of additional, successively down-sampled levels.

    Returns:
        A 1-D NumPy object array of length ``num_level + 1`` whose element
        ``k`` is the image after ``k`` applications of ``cv.pyrDown``.
    """
    levels = [image]  # Level 0
    current = image
    for _ in range(num_level):
        current = cv.pyrDown(current)
        levels.append(current)

    # Fill a pre-allocated object array element by element. Calling
    # np.array(levels, dtype=object) directly would merge a single level into a
    # (1, H, W) array and can fail to broadcast when levels share a leading dimension.
    pyramid = np.empty(len(levels), dtype=object)
    for index, level in enumerate(levels):
        pyramid[index] = level

    return pyramid


def show_pyramid(pyramid):
    """Display all levels of an image pyramid next to each other in one figure row."""
    plt.figure(figsize=(12, 6))

    # The grid has one column more than there are levels, so the last slot stays empty.
    column_count = len(pyramid) + 1

    for slot, level_image in enumerate(pyramid, start=1):
        plt.subplot(1, column_count, slot)
        # Levels may be stored in an object array; convert to a float image for imshow.
        as_float = np.array(level_image, dtype=np.float64)
        plt.imshow(as_float)
        plt.title(f'Level {slot - 1}')

    plt.tight_layout()
    plt.show()

def one_level_registration(Imoving, Ifixed, scale, prev_x = None, mtype='sd', ttype='r', epsilon=0.001, order=3):
    """
    Register ``Imoving`` to ``Ifixed`` at a single resolution level.

    A 7-element parameter vector is optimised with L-BFGS-B and the moving
    image is then transformed with the resulting matrix.

    Args:
        Imoving: moving image at the current resolution.
        Ifixed: fixed image at the current resolution.
        scale: 7 per-parameter scaling factors (array-like). The optimiser
            works on ``x / scale`` and the result is multiplied back.
        prev_x: parameters from a previous (coarser) level. Used as the start
            point only when ``ttype == 'mrr'``; when it is ``None`` the
            identity-like start vector is used instead.
        mtype: metric type, forwarded to ``affine_registration_function``.
        ttype: ``'r'`` to start from the identity-like vector, ``'mrr'`` to
            start from ``prev_x``.
        epsilon: finite-difference step passed to the optimiser as ``eps``.
        order: interpolation order forwarded to the cost function and to
            ``affine_transform_2d_double``.

    Returns:
        Tuple ``(registered_image, x)`` where ``x`` is the optimised parameter
        vector in unscaled units.

    Raises:
        ValueError: if ``ttype`` is unknown or ``scale`` does not have 7 entries.
    """
    if ttype not in ('r', 'mrr'):
        raise ValueError('Unknown registration type')

    # Accept lists/tuples as well as arrays.
    scale = np.asarray(scale, dtype=np.double)
    if scale.size != 7:
        raise ValueError('Invalid scale')

    if ttype == 'mrr' and prev_x is not None:
        # Start parameters from prev x values.
        x = np.asarray(prev_x, dtype=np.double)
    else:
        # Start rotation scaling factors from 1.
        # NOTE(review): entries 3 and 4 are presumably the scale factors -
        # confirm against affine_transform_matrix.
        x = np.array([0, 0, 0, 1, 1, 0, 0], dtype=np.double)

    # Divide by scale because par will be also scaled.
    # (Creates a new array, so the caller's prev_x is left untouched.)
    x = x / scale

    # The cost function is always called with ttype='r': 'mrr' differs from 'r'
    # only in how the start point above is chosen.
    result = minimize(
        lambda par: affine_registration_function(par, scale, Imoving, Ifixed, mtype, ttype = 'r', order = order),
        x, options={'eps':epsilon, 'maxls':200},
        method='L-BFGS-B'
    )

    # Undo the scaling to get the real parameter values.
    x = result.x * scale

    print(f'Parameters: {x}')

    M = affine_transform_matrix(x)

    return affine_transform_2d_double(Imoving, M, order), x



def multiresolution (Imoving_path, Ifixed_path,
                     mtype='sd', ttype='mrr',
                     tscale=[0.001, 0.001, 1, 1, 1, 0.001, 0.001], epsilon = 0.001, order= 3, num_level = 4):
    """
    Coarse-to-fine registration of two image files over an image pyramid.

    Both images are read, converted to grayscale and turned into pyramids
    with ``num_level`` down-sampled levels. Registration runs from the
    coarsest level down to level 0; with ``ttype='mrr'`` each level starts
    from the parameters found at the previous level. The pyramids, the
    per-level results and a final 2x2 comparison figure are displayed.

    Args:
        Imoving_path: path of the moving image file.
        Ifixed_path: path of the fixed image file.
        mtype: metric type forwarded to ``one_level_registration``.
        ttype: ``'r'`` or ``'mrr'`` (see ``one_level_registration``).
        tscale: 7 per-parameter scaling factors.
        epsilon: finite-difference step forwarded to the optimiser.
        order: interpolation order forwarded to the registration.
        num_level: number of down-sampled pyramid levels (0 = full resolution
            only; at most 7 when ``ttype == 'mrr'``).

    Returns:
        Tuple ``(Icor, Ifixed)``: the registered moving image from the last
        processed level (level 0) as float64, and the grayscale fixed image.

    Raises:
        ValueError: for an unknown ``ttype``, a ``tscale`` that does not have
            7 entries, too many levels, or a negative ``num_level``.
    """
    scale = np.array(tscale)

    if ttype not in ('r', 'mrr'):
        raise ValueError('Unknown registration type')
    if scale.size != 7:
        raise ValueError('Invalid scale')
    if ttype == 'mrr' and num_level > 7:
        raise ValueError('No more than 7 levels.')
    if num_level < 0:
        # Without this check the level loop below would never run and the
        # function would fail later on an undefined result.
        raise ValueError('num_level must be non-negative.')

    # Keep the first three channels and convert to grayscale.
    # NOTE(review): the [:, :, 0:3] slice assumes the files decode to arrays
    # with a channel axis; a single-channel image would fail here.
    Imoving = rgb2gray(imread(Imoving_path)[:,:,0:3]).astype(np.double)
    Ifixed = rgb2gray(imread(Ifixed_path)[:,:,0:3]).astype(np.double)

    # Create an image pyramid (downscale)
    pyrImoving = create_pyramid_levels(Imoving, num_level=num_level)
    pyrIfixed = create_pyramid_levels(Ifixed, num_level=num_level)

    show_pyramid(pyrImoving)
    show_pyramid(pyrIfixed)

    # Initialize prev_x to start rotation scaling factors from 1.
    prev_x = np.array([0, 0, 0, 1, 1, 0, 0])

    # Registered image of every level, coarsest first.
    pyrIcor = []
    for level in range(num_level, -1, -1):
        print("I am at Level ", level)

        currentImoving = pyrImoving[level]
        currentIfixed = pyrIfixed[level]

        # Return level registration.
        # NOTE(review): prev_x is handed to the next (finer) level unchanged;
        # if any parameter is a translation in pixels it may need rescaling
        # between levels - confirm against affine_transform_matrix.
        Icor, x = one_level_registration(currentImoving, currentIfixed, scale, prev_x = prev_x,
                                      mtype=mtype, ttype=ttype, epsilon=epsilon, order=order)

        print(Icor.shape)
        pyrIcor.append(Icor)
        prev_x = x

    # A plain list is enough for show_pyramid and avoids converting
    # differently-shaped levels into one object array.
    show_pyramid(pyrIcor)

    Icor = np.array(Icor, dtype=np.float64)
    # Show the registration results
    plt.figure(figsize=(10, 8))

    plt.subplot(2, 2, 1)
    plt.imshow(Ifixed, cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(2, 2, 2)
    plt.imshow(Imoving, cmap='gray')
    plt.title('Moving Image')

    plt.subplot(2, 2, 3)
    plt.imshow(Icor, cmap='gray')
    plt.title('Transformed Moving Image')

    plt.subplot(2, 2, 4)
    plt.imshow(np.abs(Ifixed - Icor), cmap='viridis')
    plt.title(f'Registration Error (mtype={mtype}, ttype={ttype}, scale={scale})')

    plt.tight_layout()
    plt.show()

    return Icor, Ifixed


# Time one multi-resolution registration of "brain3.png" (moving) onto "brain1.png" (fixed).
start = datetime.datetime.now()
Iregistered, Ifixed = multiresolution("brain3.png", "brain1.png", tscale = [1, 1, 1, 1, 1, 0.0001, 0.0001],
                                                mtype = 'ngc', ttype = 'mrr', num_level=4)
end = datetime.datetime.now()

# NOTE: timedelta.seconds is only the whole-seconds component of the duration
# (the sub-second part is dropped and full days are not included).
print(f'\nRegistration took {(end-start).seconds} sec.')

ModuleNotFoundError: No module named 'source'