In [1]:
import nibabel as nib

In [15]:
import SimpleITK as sitk
import numpy as np

def resample_image(image, reference_image, is_label=False):
    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(reference_image)
    resample.SetOutputSpacing(reference_image.GetSpacing())
    resample.SetSize(reference_image.GetSize())
    resample.SetOutputDirection(reference_image.GetDirection())
    resample.SetOutputOrigin(reference_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(image.GetPixelIDValue())
    
    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkLinear)
    
    return resample.Execute(image)

def register_ct_to_mri(ct_path, mri_path, output_path):
    # Read images
    ct_image = sitk.ReadImage(ct_path, sitk.sitkFloat32)
    mri_image = sitk.ReadImage(mri_path, sitk.sitkFloat32)
    
    # Create a reference image with MRI dimensions but CT voxel size
    reference_image = sitk.Image(mri_image.GetSize(), sitk.sitkFloat32)
    reference_image.SetSpacing(ct_image.GetSpacing())
    reference_image.SetDirection(mri_image.GetDirection())
    reference_image.SetOrigin(mri_image.GetOrigin())
    
    # Resample MRI to match CT resolution
    mri_resampled = resample_image(mri_image, reference_image)
    
    # Initialize registration framework
    registration_method = sitk.ImageRegistrationMethod()
    
    # Set up similarity metric (Mutual Information for multi-modality)
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    
    # Set optimizer (Gradient Descent)
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100)
    registration_method.SetOptimizerScalesFromPhysicalShift()
    
    # Set interpolator
    registration_method.SetInterpolator(sitk.sitkLinear)
    
    # Set up initial transform (center of mass alignment)
    initial_transform = sitk.CenteredTransformInitializer(mri_resampled, ct_image, 
                                                          sitk.Euler3DTransform(), 
                                                          sitk.CenteredTransformInitializerFilter.GEOMETRY)
    registration_method.SetInitialTransform(initial_transform)
    
    # Perform registration
    final_transform = registration_method.Execute(mri_resampled, ct_image)
    
    # Apply transform and resample CT to original MRI space
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(mri_image)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(ct_image.GetPixelIDValue())
    resampler.SetTransform(final_transform)
    
    ct_registered = resampler.Execute(ct_image)
    
    # Save result
    sitk.WriteImage(ct_registered, output_path)
    
    print(f"Final metric value: {registration_method.GetMetricValue()}")
    print(f"Optimizer stop condition: {registration_method.GetOptimizerStopConditionDescription()}")


In [16]:
# Usage
ct_path = "/Users/yibeichen/Desktop/fusi/microCT/reoriented/MASK_Marmoset_brain_B_200micron.nii.gz"
mri_path = "/Users/yibeichen/Desktop/fusi/atlas/template_T1w_brain.nii.gz"
output_path = "/Users/yibeichen/Desktop/fusi/microCT/reoriented/registered_Marmoset_brain_B_200micron.nii.gz"

register_ct_to_mri(ct_path, mri_path, output_path)

Final metric value: -0.00014977091781087233
Optimizer stop condition: GradientDescentOptimizerv4Template: Convergence checker passed at iteration 9.


In [14]:
import SimpleITK as sitk

# Load the CT and MRI images
ct_image = sitk.ReadImage(ct_image_path)
mri_image = sitk.ReadImage(mri_template_path)

ct_image_float32 = sitk.Cast(ct_image, sitk.sitkFloat32)

# Step 1: Perform Registration (rigid registration)

# Initialize transform using the image centers
initial_transform = sitk.CenteredTransformInitializer(mri_image, 
                                                      ct_image_float32, 
                                                      sitk.Euler3DTransform(), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)

# Set up the registration method
registration_method = sitk.ImageRegistrationMethod()

# Similarity metric: Mutual Information is good for multimodal images like CT and MRI
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)

# Optimizer: Regular step gradient descent
registration_method.SetOptimizerAsRegularStepGradientDescent(learningRate=1.0, 
                                                             minStep=0.001, 
                                                             numberOfIterations=200)

# Interpolation method
registration_method.SetInterpolator(sitk.sitkLinear)

# Set initial transform (rigid)
registration_method.SetInitialTransform(initial_transform, inPlace=False)

# Execute the registration
final_transform = registration_method.Execute(mri_image, ct_image_float32)

print(f"Final metric value: {registration_method.GetMetricValue()}")
print(f"Optimizer stop condition: {registration_method.GetOptimizerStopConditionDescription()}")

# Step 2: Resample the CT image to align with the MRI and match the MRI's voxel size

# Get MRI's voxel size (spacing) and size
mri_spacing = mri_image.GetSpacing()
mri_size = mri_image.GetSize()

# Set up the resampler to apply the transformation
resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(mri_image)  # Reference MRI image for resampling
resample.SetTransform(final_transform)  # Use the computed transform from registration
resample.SetInterpolator(sitk.sitkLinear)  # Use linear interpolation
resample.SetOutputSpacing(mri_spacing)  # Match the voxel size of the MRI
resample.SetSize(mri_size)  # Match the size of the MRI image

# Execute resampling
resampled_ct = resample.Execute(ct_image_float32)

# Step 3: Save the registered and resampled CT image
sitk.WriteImage(resampled_ct, 'registered_and_resampled_CT.nii')

# Save the final transform for future use
sitk.WriteTransform(final_transform, 'ct_to_mri_transform.tfm')

Final metric value: -0.012567775187517002
Optimizer stop condition: RegularStepGradientDescentOptimizerv4: Maximum number of iterations (200) exceeded.


In [5]:
ct_img = nib.load(ct_image_path)
mri_img = nib.load(mri_template_path)

In [8]:
def get_img_info(img_path):
    # Load the image
    img = nib.load(img_path)
    # Get image data
    data = img.get_fdata()
    # Get the affine matrix
    affine = img.affine
    # Get image shape
    shape = data.shape
    # Get header information
    header = img.header
    # Extract voxel size (zoom) from the header
    voxel_size = header.get_zooms()

    # Extract any other relevant header information
    data_type = header.get_data_dtype()  # Data type of the image
    dim_info = header['dim']  # Dimension information (includes shape, data type, etc.)
    pixdim = header['pixdim']  # Pixel dimensions
    qform_code = header['qform_code']  # Qform information
    sform_code = header['sform_code']  # Sform information

    # Print all information
    print("Image Shape:", shape)
    print("Affine Matrix:\n", affine)
    print("Voxel Size (Zoom):", voxel_size)
    print("Data Type:", data_type)
    print("Dimension Info:", dim_info)
    print("Pixel Dimensions (pixdim):", pixdim)
    print("Qform Code:", qform_code)
    print("Sform Code:", sform_code)


In [9]:
get_img_info(ct_image_path)

Image Shape: (133, 113, 189)
Affine Matrix:
 [[ 9.97604311e-01  4.87278774e-02  4.91040573e-02 -1.40135498e+02]
 [-3.87489572e-02 -1.94414392e-01  9.80153859e-01 -1.81165329e+02]
 [ 5.73073514e-02 -9.79708433e-01 -1.92060485e-01  9.28858414e+01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]
Voxel Size (Zoom): (1.0, 1.0, 1.0)
Data Type: uint8
Dimension Info: [  3 133 113 189   1   1   1   1]
Pixel Dimensions (pixdim): [1. 1. 1. 1. 0. 0. 0. 0.]
Qform Code: 1
Sform Code: 1


In [10]:
get_img_info(mri_template_path)

Image Shape: (147, 200, 135)
Affine Matrix:
 [[  0.2          0.           0.         -14.60000038]
 [  0.           0.2          0.         -14.40000057]
 [  0.           0.           0.2         -3.60000014]
 [  0.           0.           0.           1.        ]]
Voxel Size (Zoom): (0.2, 0.2, 0.2)
Data Type: float32
Dimension Info: [  3 147 200 135   1   1   1   1]
Pixel Dimensions (pixdim): [1.  0.2 0.2 0.2 0.  0.  0.  0. ]
Qform Code: 1
Sform Code: 0
