In [2]:
import SimpleITK as sitk
from radiomics import featureextractor
import json
import glob
import pandas as pd
import re
import os
import logging
from tqdm.notebook import tqdm
from pathlib import Path
import pypickle

In [None]:
# Runs the standalone preprocessing script from the shell with hardcoded paths
# (this cell executes before the config cell below defines USERNAME / data_dir).
# NOTE(review): --saveto points to ~/mri/new_atlas_registered, whereas the
# in-notebook pipeline below writes to ./atlas_mapping — confirm which output
# location downstream cells are meant to read.
!python3 preprocessing.py --username="daniil.tikhonov" --data="/home/daniil.tikhonov/mri/data/Lumiere" --saveto="/home/daniil.tikhonov/mri/new_atlas_registered"

--- Starting Preprocessing: Registering all scans to MNI atlas ---
Registering cases to atlas:  18%|█▊        | 66/361 [25:46<1:55:34, 23.51s/case]

In [2]:
# --- Configuration: input locations and atlas output folders ---
USERNAME = "daniil.tikhonov"
MNI_TEMPLATE_PATH = Path("/home", USERNAME, "fsl", "data", "standard", "MNI152_T1_1mm_brain.nii.gz")

# Fail fast: every registration step below needs the MNI template.
if not os.path.exists(MNI_TEMPLATE_PATH):
    raise FileNotFoundError(f"MNI Template not found at: {MNI_TEMPLATE_PATH}. Please update the path.")

data_dir = Path("/home", USERNAME, "mri", "data", "Lumiere")
patients_file = str(data_dir / "patients.json")

# Atlas-space outputs live next to the notebook's working directory.
atlas_dir = Path.cwd() / "atlas_mapping"
atlas_scans_dir = atlas_dir / "scans"
atlas_segs_dir = atlas_dir / "segmentations"
atlas_scans_dir.mkdir(parents=True, exist_ok=True)
atlas_segs_dir.mkdir(parents=True, exist_ok=True)

In [None]:
def get_patient_id_from_path(path):
    """Extract a normalized patient ID from a path containing ``Patient-<digits>``.

    The digits are left-padded with zeros to at least three characters, so
    ``Patient-7`` becomes ``patient_007``.

    Raises:
        ValueError: if the path contains no ``Patient-<digits>`` token.
    """
    found = re.search(r'Patient-(\d+)', path)
    if found is None:
        raise ValueError(f"No valid patient ID found in path: {path}")
    patient_number = found.group(1).zfill(3)
    return f'patient_{patient_number}'

def load_patients(patients_file):
    """Load the patients metadata JSON file and return the parsed object.

    The file is decoded explicitly as UTF-8 (the JSON standard encoding) so the
    result does not depend on the machine's locale settings.

    Args:
        patients_file: Path (str or os.PathLike) to ``patients.json``.

    Returns:
        The deserialized JSON content (presumably a dict of patient -> cases; verify
        against the file).
    """
    with open(patients_file, encoding="utf-8") as f:
        return json.load(f)

In [None]:
def _compute_atlas_transform(fixed_image, moving_t1_image, sampling_seed=42):
    """Estimate the rigid transform that aligns a patient's T1 scan with the atlas.

    Uses Mattes mutual information on a 1% random voxel sample, a gradient-descent
    optimizer and an Euler 3D (rigid) transform initialised by aligning the
    geometric centres of the two images.

    Args:
        fixed_image (sitk.Image): The MNI atlas template image.
        moving_t1_image (sitk.Image): The patient's T1 scan.
        sampling_seed (int): Seed for the random metric sampling so repeated runs
            draw the same sample points. Pass ``sitk.sitkWallClock`` to get the
            previous time-seeded (non-reproducible) behaviour.

    Returns:
        sitk.Transform: The optimized transform, meant to be used with a resampler
        whose reference grid is ``fixed_image``.
    """
    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01, sampling_seed)

    # Interpolator used while evaluating the metric
    registration_method.SetInterpolator(sitk.sitkLinear)

    # Optimizer
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=0.33,
        numberOfIterations=300,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10
    )
    registration_method.SetOptimizerScalesFromPhysicalShift()

    initial_transform = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_t1_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
    registration_method.SetInitialTransform(initial_transform, inPlace=False)

    return registration_method.Execute(
        sitk.Cast(fixed_image, sitk.sitkFloat32),
        sitk.Cast(moving_t1_image, sitk.sitkFloat32),
    )


def _resample_to_atlas(fixed_image, image, transform, is_segmentation=False):
    """Resample ``image`` onto the atlas grid of ``fixed_image`` with a given transform.

    Nearest-neighbour interpolation is used for segmentations (no new label values
    are invented), linear interpolation otherwise. Voxels mapped outside the
    moving image are set to 0.
    """
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed_image)  # output shares the atlas grid
    resampler.SetTransform(transform)
    resampler.SetInterpolator(sitk.sitkNearestNeighbor if is_segmentation else sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)  # pad with black
    return resampler.Execute(image)


def register_to_atlas(fixed_image, moving_t1_image, moving_other_image, is_segmentation=False, transform=None):
    """
    Registers an image to the atlas space using a transform derived from its corresponding T1 scan.

    Args:
        fixed_image (sitk.Image): The MNI atlas template image.
        moving_t1_image (sitk.Image): The patient's T1 scan (used to calculate the transform).
        moving_other_image (sitk.Image): The image to be transformed (can be T1c, T2, FLAIR, or seg).
        is_segmentation (bool): If True, use Nearest Neighbor interpolation.
        transform (sitk.Transform, optional): A transform previously obtained for
            this T1 scan (e.g. from ``_compute_atlas_transform``). When given, the
            registration is skipped and the transform is reused, so all images of a
            scan share exactly the same mapping. When None (default), a new
            registration is run from ``moving_t1_image``.

    Returns:
        sitk.Image: The resampled/registered image.
    """
    if transform is None:
        transform = _compute_atlas_transform(fixed_image, moving_t1_image)
    return _resample_to_atlas(fixed_image, moving_other_image, transform, is_segmentation)


# === Preprocessing Step to Create Atlas-Registered Files ===

# Modality order matches the `_000N` suffix index of the input volumes.
_ATLAS_MODALITIES = ('T1', 'T1CE', 'T2', 'FLAIR')


def _scan_image_path(data_dir, registered_rel, modality_idx):
    """Build the input path of one modality volume for a ``*_registered`` entry."""
    scan_dir = registered_rel.replace('./', '')
    scan_stem = registered_rel.replace('./images_registered/', '')
    return f"{data_dir}/{scan_dir}/{scan_stem}_{modality_idx:04d}.nii.gz"


def _describe_patient(meta):
    """Best-effort patient ID for error messages; never raises."""
    baseline = meta.get('baseline_registered') if isinstance(meta, dict) else None
    if not baseline:
        return 'Unknown'
    try:
        return get_patient_id_from_path(baseline)
    except (TypeError, ValueError):
        return 'Unknown'


def _register_timepoint(fixed_image, data_dir, meta, true_pid, cid, timepoint):
    """Register all modalities and the segmentation of one timepoint to the atlas.

    ``timepoint`` is 'baseline' or 'followup' and selects the
    ``<timepoint>_registered`` / ``<timepoint>_seg_registered`` entries of ``meta``.
    The work is skipped only when every output file already exists, so a
    timepoint interrupted half-way is redone on the next run.
    """
    scan_outputs = [
        os.path.join(atlas_scans_dir, f"{true_pid}_{cid}_{timepoint}_{name}.nii.gz")
        for name in _ATLAS_MODALITIES
    ]
    seg_output = os.path.join(atlas_segs_dir, f"{true_pid}_{cid}_{timepoint}_seg.nii.gz")
    if all(os.path.exists(path) for path in [*scan_outputs, seg_output]):
        return

    registered_rel = meta[f'{timepoint}_registered']
    moving_t1 = sitk.ReadImage(_scan_image_path(data_dir, registered_rel, 0), sitk.sitkFloat32)
    # One registration per timepoint: every modality and the segmentation reuse
    # it, so they remain aligned with each other in atlas space.
    transform = _compute_atlas_transform(fixed_image, moving_t1)

    for modality_idx, output_path in enumerate(scan_outputs):
        if modality_idx == 0:
            moving_image = moving_t1  # index 0 is the T1 volume loaded above
        else:
            moving_image = sitk.ReadImage(
                _scan_image_path(data_dir, registered_rel, modality_idx), sitk.sitkFloat32
            )
        sitk.WriteImage(_resample_to_atlas(fixed_image, moving_image, transform), output_path)

    seg_rel = meta[f'{timepoint}_seg_registered']
    moving_seg = sitk.ReadImage(f"{data_dir}/{seg_rel.replace('./', '')}")
    registered_seg = _resample_to_atlas(fixed_image, moving_seg, transform, is_segmentation=True)
    sitk.WriteImage(registered_seg, seg_output)


def preprocess_and_register_all_scans(patients, data_dir):
    """
    Registers every baseline and followup scan (four modalities plus the
    segmentation) to the MNI atlas and writes the results to ``atlas_scans_dir``
    and ``atlas_segs_dir``.

    The outer keys of ``patients`` are not trusted as patient IDs; the ID used in
    output filenames is parsed from each case's ``baseline_registered`` path.
    A case that fails is reported and skipped so the remaining cases still run.
    """
    print("--- Starting Preprocessing: Registering all scans to MNI atlas ---")
    # str(): hand SimpleITK a plain string path rather than a pathlib.Path.
    fixed_image = sitk.ReadImage(str(MNI_TEMPLATE_PATH), sitk.sitkFloat32)
    total_cases = sum(len(cases) for cases in patients.values())
    failed_cases = 0

    # Context manager closes the progress bar even if an unexpected error escapes.
    with tqdm(total=total_cases, desc="Registering cases to atlas", unit="case") as pbar:
        # The outer `pid` from `patients.items()` is unreliable.
        for unreliable_pid, cases in patients.items():
            for cid, meta in cases.items():
                try:
                    true_pid = get_patient_id_from_path(meta['baseline_registered'])
                    for timepoint in ('baseline', 'followup'):
                        _register_timepoint(fixed_image, data_dir, meta, true_pid, cid, timepoint)
                except Exception as e:
                    failed_cases += 1
                    # _describe_patient never raises, so a bad path cannot abort the whole run here.
                    print(f"ERROR processing case {unreliable_pid}/{cid}. True Patient: {_describe_patient(meta)}. Error: {type(e).__name__}: {e}")
                finally:
                    pbar.update(1)

    if failed_cases:
        print(f"--- Preprocessing finished with {failed_cases} failed case(s); see errors above. ---")
    else:
        print("--- Preprocessing complete. All files saved to atlas_mapping directory. ---")

In [None]:
# Load patient data
patients = load_patients(patients_file)

# STEP 1: one-time preprocessing that registers all images to the MNI atlas.
# The function skips timepoints it finds already registered, but once it has
# completed you can set RUN_ATLAS_PREPROCESSING = False to bypass the loop
# entirely instead of commenting the call out.
RUN_ATLAS_PREPROCESSING = True
if RUN_ATLAS_PREPROCESSING:
    preprocess_and_register_all_scans(patients, data_dir)