In [None]:
import numpy as np
import nibabel as nib
import os
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from glob import glob
import json
import shutil
import subprocess
from multiprocessing import Process

# Short alias for os.path.join, used throughout the notebook to build paths.
from os.path import join

# Preparing the data in nnU-Net v2 format
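Before proceeding, ensure that your dataset is properly formatted and ready for integration with the nnU-Net framework.
It is assumed that all preprocessing steps have already been completed and that the data conforms to the required structure: for each patient, the conformed FA, trace, sphericity and maximum-eigenvalue maps are stored under `data/<patient_id>/DATA/images/`. The next cell copies them into nnU-Net's `{patient_id}_{XXXX}.nii.gz` naming, where `XXXX` is the four-digit channel index.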

In [None]:
# One patient ID per line. Blank lines (e.g. a trailing newline at EOF) are
# ignored so that no empty ID ends up in the paths built below.
with open("../splits/CNP_patients.txt", "r") as f:
    patient_ids = sorted(line.strip() for line in f if line.strip())

In [None]:
root_path = "/hdd/CNP_dataset"

# Input images in channel order: the list index becomes the nnU-Net channel
# suffix (_0000 = FA, _0001 = trace, _0002 = sphericity, _0003 = max eigenvalue).
modality_files = [
    "FA_ras_MNI_conform.nii.gz",
    "trace_from_eigs_ras_MNI_conform.nii.gz",
    "sphericity_ras_MNI_conform.nii.gz",
    "maxEig_ras_MNI_conform.nii.gz",
]

folder_to_predict = join(root_path, "to_pred_T+F+S+Max")
os.makedirs(folder_to_predict, exist_ok=True)
for patient_id in tqdm(patient_ids, desc="Copying files to predict"):
    image_dir = join(root_path, "data", patient_id, "DATA", "images")
    source_paths = [join(image_dir, name) for name in modality_files]
    # nnU-Net expects every channel of a case to be present, so skip the patient
    # if any input is missing rather than crashing mid-copy or leaving a partial case.
    missing = [path for path in source_paths if not os.path.exists(path)]
    if missing:
        print(f"File(s) {', '.join(missing)} do not exist. Skipping {patient_id}...")
        continue
    for channel, source_path in enumerate(source_paths):
        shutil.copy(
            source_path, join(folder_to_predict, f"{patient_id}_{channel:04d}.nii.gz")
        )

# First stage: seven-class prediction

In [None]:
# First stage: seven-class prediction on the four-channel inputs copied above.
input_path = folder_to_predict  # alternative: "/hdd/HPC_retest_preprocess/to_predict/"
output_path = "/hdd/tmp/DMRI_PRED/CNP_T+F+S+Max_SwinUNETR/seven_class/"
os.makedirs(output_path, exist_ok=True)
model_path = "/hdd/nnunet_dataset/Parcel_final/SwinUNETR/Dataset630_dMRI_FA_trace_sphericity_maxEig_V3_7_classes/SwinUnetrBratsTrainer__nnUNetPlans__3d_fullres/"
# Arguments are passed as a list (no shell), so paths are never re-parsed by /bin/sh.
seven_class_command = [
    "nnUNetv2_predict_from_modelfolder",
    "-i", input_path,
    "-o", output_path,
    "-m", model_path,
    "-f", "0",
    "--disable_progress_bar",
]
subprocess.run(seven_class_command, check=True)

# Preparing the second-stage inputs (image channels plus the seven-class prediction)

In [None]:
# Stage-1 output, referenced by its fixed location rather than via `output_path`:
# the second-stage cells rebind `output_path` in their loops, so re-running this
# cell later would otherwise copy from the wrong folder.
seven_class_dir = join("/hdd/tmp/DMRI_PRED/CNP_T+F+S+Max_SwinUNETR", "seven_class")

# Image channels in order (_0000.._0003); the seven-class prediction is appended as _0004.
modality_files = [
    "FA_ras_MNI_conform.nii.gz",
    "trace_from_eigs_ras_MNI_conform.nii.gz",
    "sphericity_ras_MNI_conform.nii.gz",
    "maxEig_ras_MNI_conform.nii.gz",
]

# the new prediction folder for the subsequent steps
new_prediction_path = (
    "/hdd/tmp/DMRI_PRED/CNP_T+F+S+Max_SwinUNETR/to_predict_with_seven_class"
)
os.makedirs(new_prediction_path, exist_ok=True)
for patient_id in tqdm(patient_ids, desc="Copying files to predict"):
    image_dir = join(root_path, "data", patient_id, "DATA", "images")
    source_paths = [join(image_dir, name) for name in modality_files]
    source_paths.append(join(seven_class_dir, f"{patient_id}.nii.gz"))
    # Patients skipped in stage 1 have no seven-class prediction; skip them here
    # instead of letting shutil.copy raise FileNotFoundError.
    missing = [path for path in source_paths if not os.path.exists(path)]
    if missing:
        print(f"File(s) {', '.join(missing)} do not exist. Skipping {patient_id}...")
        continue
    for channel, source_path in enumerate(source_paths):
        shutil.copy(
            source_path, join(new_prediction_path, f"{patient_id}_{channel:04d}.nii.gz")
        )

# If you have enough RAM and VRAM, you can run the second stage with multiprocessing (all models in parallel); otherwise, use the single-process version. Run only one of the two cells below.

## Multiprocess

In [None]:
def run_subprocess(command):
    """Run an external command, raising ``CalledProcessError`` on a non-zero exit.

    ``command`` may be an argument list (executed directly, no shell) or a single
    string (executed through the shell, as earlier versions of this notebook did).
    """
    subprocess.run(command, shell=isinstance(command, str), check=True)


input_path = new_prediction_path
# Same SwinUNETR models and output root as the single-process cell: the fusion
# step reads its inputs from this folder.
second_stage_root = "/hdd/tmp/DMRI_PRED/CNP_T+F+S+Max_SwinUNETR"
# (output sub-folder, model dataset) pairs, kept together so they cannot drift apart.
second_stage_models = [
    ("center", "Dataset641_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_center_classes"),
    ("left", "Dataset642_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_left_classes"),
    ("right", "Dataset644_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_right_classes"),
    ("right_small", "Dataset645_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_right_small_classes"),
    ("left_small", "Dataset643_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_left_small_classes"),
]

processes = []
for output_name, dataset_name in second_stage_models:
    output_path = join(second_stage_root, output_name)
    os.makedirs(output_path, exist_ok=True)
    model_path = join(
        "/hdd/nnunet_dataset/Parcel_final/SwinUNETR",
        dataset_name,
        "SwinUnetrBratsTrainer__nnUNetPlans__3d_fullres",
    )
    command = [
        "nnUNetv2_predict_from_modelfolder",
        "-i", input_path,
        "-o", output_path,
        "-m", model_path,
        "-f", "0",
        "--c",
        "--disable_progress_bar",
    ]
    # Popen returns immediately, so all models predict concurrently.
    processes.append(subprocess.Popen(command))

# Wait for every prediction and surface failures instead of ignoring exit codes.
failed = []
for (output_name, _), process in zip(second_stage_models, processes):
    if process.wait() != 0:
        failed.append(output_name)
if failed:
    raise RuntimeError(f"Second-stage prediction failed for: {', '.join(failed)}")

## Single Process

In [None]:
# Redefined here on purpose: this cell is an alternative to the multiprocess cell
# and must run on its own.
def run_subprocess(command):
    """Run an external command, raising ``CalledProcessError`` on a non-zero exit.

    ``command`` may be an argument list (executed directly, no shell) or a single
    string (executed through the shell, as earlier versions of this notebook did).
    """
    subprocess.run(command, shell=isinstance(command, str), check=True)


input_path = new_prediction_path
second_stage_root = "/hdd/tmp/DMRI_PRED/CNP_T+F+S+Max_SwinUNETR"
# (output sub-folder, model dataset) pairs, kept together so they cannot drift apart.
second_stage_models = [
    ("center", "Dataset641_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_center_classes"),
    ("left", "Dataset642_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_left_classes"),
    ("right", "Dataset644_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_right_classes"),
    ("right_small", "Dataset645_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_right_small_classes"),
    ("left_small", "Dataset643_dMRI_SwinUNETR_FA_trace_sphericity_maxEig_V3_left_small_classes"),
]

# Models run one after another, so only one is resident at a time.
for output_name, dataset_name in second_stage_models:
    output_path = join(second_stage_root, output_name)
    os.makedirs(output_path, exist_ok=True)
    model_path = join(
        "/hdd/nnunet_dataset/Parcel_final/SwinUNETR",
        dataset_name,
        "SwinUnetrBratsTrainer__nnUNetPlans__3d_fullres",
    )
    command = [
        "nnUNetv2_predict_from_modelfolder",
        "-i", input_path,
        "-o", output_path,
        "-m", model_path,
        "-f", "0",
        "--c",
        "--disable_progress_bar",
    ]
    print(output_path)
    run_subprocess(command)

# Managing the predicted folders

In [None]:
# Root folder holding one sub-folder of predictions per stage.
pred_root = "/hdd/tmp/DMRI_PRED/CNP_T+F+S+Max_SwinUNETR"
# Folder order matters and must be kept. The fusion step passes each tuple
# positionally to process_and_save, whose last two parameters are
# (left_small_path, right_small_path), and `mapping_dict` swaps the two small
# mappings to compensate.
pred_folders = ["seven_class", "center", "left", "right", "right_small", "left_small"]

paths_by_folder = {}
for folder in pred_folders:
    paths_by_folder[folder] = sorted(glob(join(pred_root, folder, "*.nii.gz")))
    print(f"Number of files in {folder}: {len(paths_by_folder[folder])}")

seven_class_paths = paths_by_folder["seven_class"]
center_paths = paths_by_folder["center"]
left_paths = paths_by_folder["left"]
right_paths = paths_by_folder["right"]
right_small_paths = paths_by_folder["right_small"]
left_small_paths = paths_by_folder["left_small"]

# Pair files by case file name, not by sorted position. With a positional zip, a
# single missing prediction shifts every later tuple, so different patients'
# predictions get fused together.
case_names_per_folder = [
    {os.path.basename(path) for path in paths_by_folder[folder]}
    for folder in pred_folders
]
complete_cases = sorted(set.intersection(*case_names_per_folder))
incomplete_cases = sorted(set.union(*case_names_per_folder) - set(complete_cases))
if incomplete_cases:
    print(
        f"Skipping {len(incomplete_cases)} case(s) missing from at least one folder: "
        f"{incomplete_cases}"
    )

# A list (not a one-shot zip iterator), so the fusion cell can be re-run.
zipped_objects = [
    tuple(join(pred_root, folder, case_name) for folder in pred_folders)
    for case_name in complete_cases
]

# Loading the label mapping files

In [None]:
# Load every label-mapping JSON in `mapping_root`, keyed by the file name up to the
# first "." (e.g. "center_mapping.json" -> "center_mapping").
mappings = {}
sets = set()  # union of all mapped values; only used for the sanity print below
mapping_root = "/home/say26747/Desktop/git/DDParcel_Yousef/mappings"
# `mapping_file` (not `map`) so the builtin map() is not shadowed in the kernel.
for mapping_file in sorted(glob(join(mapping_root, "*.json"))):
    with open(mapping_file, "r") as f:
        mapping = json.load(f)
    mappings[os.path.basename(mapping_file).split(".")[0]] = mapping
    sets.update(mapping.values())

print(f"Number of unique values in all mappings: {len(sets)}")

with open(
    "/home/say26747/Desktop/git/DDParcel_Yousef/Preprocess_HCP/wmparc_values_to_sequence.json",
    "r",
) as f:
    wmparc_values_to_sequence = json.load(f)

# Inverse lookup, used after fusion to turn sequence indices back into wmparc values.
sequence_to_wmparc = {int(v): int(k) for k, v in wmparc_values_to_sequence.items()}

# NOTE(review): the two *small* entries are cross-wired on purpose. The prediction
# tuples are ordered (..., right_small, left_small) and passed positionally to
# process_and_save(..., left_small_path, right_small_path). Its `left_small_path`
# therefore holds the right_small prediction. Swapping the mappings here cancels
# that out. Change both places together, or neither.
mapping_dict = {
    "seven_mapping": mappings["seven_mapping"],
    "center_mapping": mappings["center_mapping"],
    "left_mapping": mappings["left_mapping"],
    "right_mapping": mappings["right_mapping"],
    "right_small_mapping": mappings["left_small_mapping"],
    "left_small_mapping": mappings["right_small_mapping"],
}

results_dir = "/hdd/CNP_dataset/predicted_SwinUNETR/"
os.makedirs(results_dir, exist_ok=True)

# Fusion and post-processing

In [None]:
def process_and_save(
    seven_path, center_path, left_path, right_path, left_small_path, right_small_path
):
    """Fuse the six stage predictions of one case into a single wmparc label map.

    Each prediction's labels are looked up in its ``mapping_dict`` entry and then in
    ``wmparc_values_to_sequence``. They are written into one volume in the order
    seven -> center -> left -> right -> left_small -> right_small, so later stages
    overwrite earlier ones wherever they assign a label. The sequence indices are
    then converted back with ``sequence_to_wmparc``. The result is saved to
    ``results_dir`` under the center prediction's file name. Cases whose output
    already exists are skipped.

    Returns None.
    """
    output_filename = join(results_dir, os.path.basename(center_path))
    if os.path.exists(output_filename):
        print(f"File {os.path.basename(center_path)} already exists. Skipping...")
        return

    def load_labels(path):
        # Read labels in their stored dtype instead of get_fdata()'s float64, which
        # keeps six volumes per job affordable under Parallel(n_jobs=-1).
        return np.asanyarray(nib.load(path).dataobj)

    center_img = nib.load(center_path)
    seven_data = load_labels(seven_path)
    center_data = np.asanyarray(center_img.dataobj)
    left_data = load_labels(left_path)
    right_data = load_labels(right_path)
    left_small_data = load_labels(left_small_path)
    right_small_data = load_labels(right_small_path)

    # Fused volume in "sequence" label space; 0 wherever no stage assigns a label.
    combined_data = np.zeros(center_data.shape, dtype=np.int32)

    def process_data(data, mapping_name):
        # Write every label of `data` that has a mapping into combined_data,
        # overwriting whatever an earlier stage wrote there.
        mapping = mapping_dict[mapping_name]
        for label in np.unique(data):
            key = str(int(label))
            if key in mapping:
                value = wmparc_values_to_sequence[str(mapping[key])]
                combined_data[data == label] = value

    process_data(seven_data, "seven_mapping")
    process_data(center_data, "center_mapping")
    process_data(left_data, "left_mapping")
    process_data(right_data, "right_mapping")
    process_data(left_small_data, "left_small_mapping")
    process_data(right_small_data, "right_small_mapping")

    # Transform the prediction back to wmparc values. Start from zeros, not
    # empty_like, so any value without an inverse entry becomes 0 instead of
    # uninitialised memory.
    mapped = np.zeros_like(combined_data, dtype=np.int32)
    for k, v in sequence_to_wmparc.items():
        mapped[combined_data == k] = v

    # set_data_dtype() mutates in place and returns None, so copy the header first.
    # Passing its return value would silently discard the source header.
    header = center_img.header.copy()
    header.set_data_dtype(np.int32)
    combined_img = nib.Nifti1Image(mapped, center_img.affine, header=header)
    nib.save(combined_img, output_filename)

In [None]:
# Materialise the cases so the progress-bar total matches what is dispatched.
# If `zipped_objects` is a one-shot zip that was already consumed, this is empty.
cases_to_fuse = list(zipped_objects)
if not cases_to_fuse:
    print("No cases to fuse. Re-run the 'Managing the predicted folders' cell first.")

# Each tuple is in (seven, center, left, right, right_small, left_small) folder
# order and is passed positionally. The right_small file therefore lands in
# process_and_save's `left_small_path` parameter, which `mapping_dict` compensates
# for by swapping the two small mappings.
results = Parallel(n_jobs=-1)(
    delayed(process_and_save)(*case_paths)
    for case_paths in tqdm(cases_to_fuse, desc="Processing files")
)

# Transforming the predictions back out of MNI space (reverse transform)

In [None]:
# Command prefix for Slicer's ResampleScalarVectorDWIVolume CLI module, started via
# `Slicer --launch`. It is later split on whitespace, so the path must not
# contain spaces.
slicer_home = "/home/say26747/Downloads/Slicer-5.2.2-linux-amd64"
ResampleScalarVectorDWIVolume = " ".join(
    [
        f"{slicer_home}/Slicer",
        "--launch",
        f"{slicer_home}/lib/Slicer-5.2/cli-modules/ResampleScalarVectorDWIVolume",
    ]
)

In [None]:
# Input: the fused labels written by the fusion step. The previous hard-coded
# "predicted_SwinUNETR_T+F" folder is never written by this notebook.
# Output: the same labels resampled back out of MNI space by worker_MNI_to_RAS.
conformed_pred_path = results_dir
normal_pred_path = os.path.normpath(results_dir) + "_normal"
os.makedirs(normal_pred_path, exist_ok=True)

In [None]:
def worker_MNI_to_RAS(patient_id):
    """Resample one patient's fused prediction out of MNI space with Slicer.

    Runs the ResampleScalarVectorDWIVolume CLI with nearest-neighbour interpolation
    (labels must not be blended). It uses the patient's ``brain_mask_ras.nii.gz`` as
    ``--Reference`` and ``b0_to_MNI_Brain_Inverse.h5`` as the transformation. Reads
    ``<patient_id>.nii.gz`` from ``conformed_pred_path`` and writes the result under
    the same name to ``normal_pred_path``.

    Returns:
        The CLI exit code, or None when the input prediction does not exist
        (e.g. the patient was skipped earlier in the pipeline).
    """
    conformed_pred = join(conformed_pred_path, f"{patient_id}.nii.gz")
    normal_pred = join(normal_pred_path, f"{patient_id}.nii.gz")
    if not os.path.exists(conformed_pred):
        print(f"File {conformed_pred} does not exist. Skipping...")
        return None

    patient_dir = join("/hdd/CNP_dataset/data/", patient_id, "DATA")
    ResampleScalarVectorDWIVolume_command = ResampleScalarVectorDWIVolume.split() + [
        "--interpolation",
        "nn",
        "--Reference",
        join(patient_dir, "labels", "brain_mask_ras.nii.gz"),
        "--transformationFile",
        join(patient_dir, "transformations", "b0_to_MNI_Brain_Inverse.h5"),
        conformed_pred,
        normal_pred,
    ]
    completed = subprocess.run(ResampleScalarVectorDWIVolume_command)
    # Best-effort per patient: report failures without aborting the other jobs.
    if completed.returncode != 0:
        print(f"Resampling failed for {patient_id} (exit code {completed.returncode})")
    return completed.returncode

In [None]:
# One reverse-transform job per patient, dispatched across all available cores.
reverse_jobs = (
    delayed(worker_MNI_to_RAS)(pid)
    for pid in tqdm(patient_ids, desc="Processing files", total=len(patient_ids))
)
results = Parallel(n_jobs=-1)(reverse_jobs)