In [7]:
import os
import numpy as np
import nibabel as nib
!pip install SimpleITK nilearn
import SimpleITK as sitk
from nilearn.masking import compute_brain_mask
import cv2 as cv

Collecting nilearn
  Downloading nilearn-0.11.1-py3-none-any.whl.metadata (9.3 kB)
Downloading nilearn-0.11.1-py3-none-any.whl (10.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.5/10.5 MB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nilearn
Successfully installed nilearn-0.11.1


In [8]:
def load_nifti(path):
    """Read a NIfTI file with nibabel.

    Parameters
    ----------
    path : str or path-like
        Location of the image file, passed straight to ``nib.load``.

    Returns
    -------
    tuple
        ``(voxels, affine)`` where ``voxels`` is the array returned by the
        image's ``get_fdata()`` and ``affine`` is the image's ``affine``.
    """
    volume = nib.load(path)
    voxels = volume.get_fdata()
    return voxels, volume.affine


def save_nifti(data, affine, out_path):
    """Write an array to disk as a NIfTI image.

    Parameters
    ----------
    data : array-like
        Voxel data wrapped in a ``nib.Nifti1Image``.
    affine : array-like
        Affine stored alongside the data in the image.
    out_path : str or path-like
        Destination file name handed to ``nib.save``.
    """
    nib.save(nib.Nifti1Image(data, affine), out_path)


In [9]:
def n4_bias_field_correction(img_sitk, mask_sitk=None):
    """Apply N4 bias field correction using SimpleITK.

    Parameters
    ----------
    img_sitk : SimpleITK.Image
        Image to correct. Images whose pixel type is not float32/float64 are
        cast to float32 first, because the N4 filter works on real-valued
        pixels only.
    mask_sitk : SimpleITK.Image, optional
        Mask restricting where the bias field is estimated. When omitted the
        filter is executed on the image alone.

    Returns
    -------
    SimpleITK.Image
        The bias-corrected image returned by the filter.
    """
    if img_sitk.GetPixelID() not in (sitk.sitkFloat32, sitk.sitkFloat64):
        img_sitk = sitk.Cast(img_sitk, sitk.sitkFloat32)

    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    # Compare against None explicitly: whether a mask was supplied must not
    # depend on how the image object happens to evaluate in a boolean context.
    if mask_sitk is None:
        return corrector.Execute(img_sitk)
    return corrector.Execute(img_sitk, mask_sitk)

In [10]:
def skull_strip(data, affine):
    """Mask a volume with the brain mask computed by nilearn.

    Parameters
    ----------
    data : numpy.ndarray
        Voxel data of the scan.
    affine : array-like
        Affine used to wrap ``data`` in a ``nib.Nifti1Image`` before it is
        handed to ``compute_brain_mask``.

    Returns
    -------
    tuple
        ``(stripped, mask)``: ``data`` multiplied element-wise by the mask
        values, and the mask itself converted to a boolean array.
    """
    volume = nib.Nifti1Image(data, affine)
    brain = compute_brain_mask(volume).get_fdata()
    stripped = data * brain
    return stripped, brain.astype(bool)

In [11]:
def register_to_template(moving_sitk, template_path):
    """Rigidly register an image to a template and resample it onto the template grid.

    Only a rigid (``Euler3DTransform``) alignment is optimised; no affine
    stage is run. The transform is initialised by aligning the geometric
    centres of the two images, then refined with Mattes mutual information
    (50 histogram bins) and regular-step gradient descent.

    Parameters
    ----------
    moving_sitk : SimpleITK.Image
        Image to align. It is cast to float32 for the registration.
    template_path : str
        Path of the fixed (template) image, read as float32.

    Returns
    -------
    SimpleITK.Image
        The float32 moving image resampled with linear interpolation onto the
        template's grid (size, spacing, origin and direction of the template).
    """
    # The registration framework operates on floating-point images of the same
    # type, so both inputs are brought to float32 regardless of how they are
    # stored.
    fixed_sitk = sitk.ReadImage(template_path, sitk.sitkFloat32)
    moving_float = sitk.Cast(moving_sitk, sitk.sitkFloat32)

    initial_transform = sitk.CenteredTransformInitializer(
        fixed_sitk, moving_float, sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    registration = sitk.ImageRegistrationMethod()
    registration.SetMetricAsMattesMutualInformation(50)
    registration.SetOptimizerAsRegularStepGradientDescent(1.0, 1e-6, 200)
    # Rotations (radians) and translations (mm) live on very different scales;
    # without per-parameter scales a single step size is applied to both.
    registration.SetOptimizerScalesFromPhysicalShift()
    registration.SetInitialTransform(initial_transform)
    registration.SetInterpolator(sitk.sitkLinear)
    final_transform = registration.Execute(fixed_sitk, moving_float)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed_sitk)
    resampler.SetTransform(final_transform)
    resampler.SetInterpolator(sitk.sitkLinear)
    return resampler.Execute(moving_float)

In [12]:
def resample_image(img_sitk, new_spacing=(1.0, 1.0, 1.0), interpolator=None):
    """Resample an image to a new voxel spacing (isotropic 1 mm by default).

    The output keeps the input's origin and direction; its size along each
    axis is ``round(size * spacing / new_spacing)``, with a minimum of one
    voxel.

    Parameters
    ----------
    img_sitk : SimpleITK.Image
        Image to resample.
    new_spacing : sequence of float, optional
        Target spacing, one value per image axis in SimpleITK axis order.
    interpolator : int, optional
        SimpleITK interpolator constant. Defaults to ``sitk.sitkLinear``;
        pass ``sitk.sitkNearestNeighbor`` for label or mask images.

    Returns
    -------
    SimpleITK.Image
        The resampled image.

    Raises
    ------
    ValueError
        If ``new_spacing`` has the wrong number of entries or contains a
        non-positive value.
    """
    # Resolved at call time so the default does not capture ``sitk`` when the
    # function is defined.
    if interpolator is None:
        interpolator = sitk.sitkLinear

    orig_spacing = img_sitk.GetSpacing()
    orig_size = img_sitk.GetSize()
    target_spacing = [float(s) for s in new_spacing]
    if len(target_spacing) != len(orig_size):
        raise ValueError(
            f"new_spacing has {len(target_spacing)} entries but the image has "
            f"{len(orig_size)} dimensions")
    if any(s <= 0 for s in target_spacing):
        raise ValueError(f"new_spacing must be strictly positive, got {target_spacing}")

    # Never let an axis collapse to zero voxels when the target spacing is
    # larger than the image extent.
    new_size = [
        max(1, int(np.round(osz * ospc / nspc)))
        for osz, ospc, nspc in zip(orig_size, orig_spacing, target_spacing)
    ]

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(target_spacing)
    resampler.SetSize(new_size)
    resampler.SetOutputDirection(img_sitk.GetDirection())
    resampler.SetOutputOrigin(img_sitk.GetOrigin())
    resampler.SetInterpolator(interpolator)
    return resampler.Execute(img_sitk)

In [13]:
def intensity_normalization(data, mask=None):
    """Z-score within a mask, or min-max scale globally.

    Parameters
    ----------
    data : numpy.ndarray
        Intensities to normalise.
    mask : array-like, optional
        Same shape as ``data``; non-zero / True entries select the voxels
        whose mean and standard deviation define the z-score. The mask is
        coerced to boolean so a 0/1 integer mask selects voxels rather than
        being used as an index array.

    Returns
    -------
    numpy.ndarray
        * With a non-empty mask: ``(data - mean) / std`` using the statistics
          of the masked voxels. If that standard deviation is 0, only the
          mean is subtracted.
        * Without a mask, or when the mask selects no voxel:
          ``(data - min) / (max - min)``. If the data are constant, only the
          minimum is subtracted (all zeros).
    """
    if mask is not None:
        selector = np.asarray(mask, dtype=bool)
        brain_pixels = data[selector]
        if brain_pixels.size:
            centred = data - brain_pixels.mean()
            spread = brain_pixels.std()
            # A constant region has zero spread; dividing would give NaN/inf.
            return centred / spread if spread > 0 else centred
        # An empty mask has no statistics; fall through to global scaling.

    lowest = data.min()
    span = data.max() - lowest
    shifted = data - lowest
    # A constant image has zero range; dividing would give NaN.
    return shifted / span if span > 0 else shifted

In [None]:
def CLAHE(meta_data, clip_limit, grid_size):
    """Apply OpenCV CLAHE to every slice taken along the third axis.

    Each slice ``meta_data[:, :, k]`` is min-max scaled to 0..255 on its own,
    converted to ``uint8``, equalised, and written into the output volume.

    Parameters
    ----------
    meta_data : numpy.ndarray
        Volume indexed as ``[:, :, slice]``.
    clip_limit : float
        Forwarded to ``cv.createCLAHE`` as ``clipLimit``.
    grid_size : tuple of int
        Forwarded to ``cv.createCLAHE`` as ``tileGridSize``.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``meta_data`` and NumPy's default float dtype,
        holding the equalised slices.
    """
    equalizer = cv.createCLAHE(clipLimit=clip_limit, tileGridSize=grid_size)

    result = np.zeros(meta_data.shape)
    for k in range(meta_data.shape[2]):
        # Scale this slice independently of the others before equalising it.
        as_bytes = cv.normalize(meta_data[:, :, k], None, 0, 255, cv.NORM_MINMAX).astype(np.uint8)
        result[:, :, k] = equalizer.apply(as_bytes)
    return result

In [14]:
def _as_sitk_volume(volume, spacing):
    """Wrap a NumPy volume as a float32 SimpleITK image with the given spacing."""
    image = sitk.GetImageFromArray(np.asarray(volume).astype(np.float32))
    image.SetSpacing(spacing)
    return image


def preprocess_nifti(input_path, template_path=None, output_dir=None):
    """Run the preprocessing chain on a single 3-D NIfTI scan.

    Steps, in order: load, N4 bias-field correction, skull-stripping,
    slice-wise CLAHE (clip limit 1.5, 16x16 tiles), optional registration to
    a template, resampling to 1 mm spacing, intensity normalisation and an
    optional save.

    Parameters
    ----------
    input_path : str
        Path of the scan to process.
    template_path : str, optional
        Template to register to. Registration is skipped when this is falsy
        (``None`` or an empty string).
    output_dir : str, optional
        Directory to save the result in (created if missing). Nothing is
        written when this is falsy.

    Returns
    -------
    tuple
        ``(norm_img, proc_affine)``: the normalised volume and the affine
        describing it. Without a template the affine is the input affine with
        its voxel-size scaling replaced by the 1 mm target spacing. With a
        template it is the identity matrix.

    Raises
    ------
    ValueError
        If the loaded data are not 3-D or the affine has a zero-length axis.
    """
    import warnings  # local import: keeps this cell runnable on its own

    target_mm = 1.0
    target_spacing = [target_mm, target_mm, target_mm]

    # Load the MRI volume
    data, affine = load_nifti(input_path)
    if data.ndim != 3:
        raise ValueError(f"expected a 3-D volume, got an array with shape {data.shape}")
    affine = np.asarray(affine, dtype=float)

    # Voxel sizes are the lengths of the affine's column vectors. The diagonal
    # alone is negative for flipped axes and can be zero for oblique scans,
    # neither of which is a usable spacing.
    voxel_sizes = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))
    if not np.all(np.isfinite(voxel_sizes)) or np.any(voxel_sizes <= 0):
        raise ValueError(f"affine yields invalid voxel sizes {voxel_sizes.tolist()}")
    # SimpleITK lists axes in the opposite order to the NumPy array it was
    # built from, so the spacing is reversed to stay attached to the right axis.
    sitk_spacing = tuple(float(v) for v in voxel_sizes[::-1])

    # Bias field correction
    corrected_img = sitk.GetArrayFromImage(
        n4_bias_field_correction(_as_sitk_volume(data, sitk_spacing)))

    # Skull-stripping
    stripped_img, mask = skull_strip(corrected_img, affine)

    stripped_img = CLAHE(stripped_img, 1.5, (16, 16))  # contrast enhancement

    # Optional: register to a standard template
    if template_path:
        registered = register_to_template(
            _as_sitk_volume(stripped_img, sitk_spacing), template_path)
        processed = sitk.GetArrayFromImage(registered)
        # NOTE(review): the identity affine discards the template's geometry,
        # and the array comes back in SimpleITK's reversed axis order. Kept as
        # in the original; confirm what downstream consumers expect.
        proc_affine = np.eye(4)
        proc_spacing = (1.0, 1.0, 1.0)
    else:
        processed = stripped_img
        proc_spacing = sitk_spacing
        # The data are resampled to the target spacing below, so the affine
        # must carry that spacing instead of the original voxel sizes. The
        # translation is unchanged because resampling keeps the origin.
        proc_affine = affine.copy()
        proc_affine[:3, :3] = affine[:3, :3] / voxel_sizes * target_mm

    # Resample to standard voxel spacing
    resampled = sitk.GetArrayFromImage(
        resample_image(_as_sitk_volume(processed, proc_spacing), target_spacing))

    # Resample the brain mask on the same grid as the native image. It is
    # resampled as float so linear interpolation yields fractional coverage,
    # then thresholded at one half.
    mask_resampled = sitk.GetArrayFromImage(
        resample_image(_as_sitk_volume(mask, sitk_spacing), target_spacing))
    brain = mask_resampled > 0.5

    # NOTE(review): the mask is computed in native space and is not carried
    # through the registration, so after registering to a template it generally
    # no longer matches the image grid. Rather than fail on the shape mismatch,
    # fall back to global scaling and say so.
    if brain.shape != resampled.shape or not brain.any():
        warnings.warn(
            "brain mask is empty or does not match the processed volume; "
            "falling back to global min-max normalization",
            RuntimeWarning,
        )
        brain = None
    norm_img = intensity_normalization(resampled, mask=brain)

    # Save if path is given
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        base = os.path.basename(input_path)
        fname = base.replace('.nii', '_preproc.nii')
        if fname == base:
            # No '.nii' in the name: still produce a distinct NIfTI file name
            # instead of reusing the input's name.
            fname = base + '_preproc.nii'
        save_nifti(norm_img, proc_affine, os.path.join(output_dir, fname))
        print(f"Saved preprocessed scan to {output_dir}/{fname}")

    return norm_img, proc_affine

In [None]:
# Example run on a single scan. The empty strings look like placeholders that
# should be replaced with real paths before executing — TODO confirm.
# As written, the empty template_path and output_dir are falsy, so
# preprocess_nifti skips registration and saving.
preprocessed_data, new_affine = preprocess_nifti(input_path='', template_path='', output_dir='')