# Demo

In this section, we perform inference on a subject from the CNP cohort.
All preprocessing steps have already been completed, and the corresponding transformations have been saved along with the diffusion-derived parameters (FA, Trace, Sphericity, and Maximum Eigenvalue).

In [None]:
!unzip sub-10217.zip -d ./

## Libraries and packages

In [2]:
import numpy as np
import nibabel as nib
import os
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from glob import glob
import json
import shutil
import subprocess
from multiprocessing import Process
from copy import deepcopy

from os.path import join  # short alias used for all path building below

# Paths to model weights and external tools

In [3]:
# Locations of the nnU-Net model weights and of the Slicer CLI module used for
# the reverse transform. Edit the defaults below for your system, or override
# them without editing the notebook via the PARCEL_WEIGHT_PATH and
# SLICER_RESAMPLE_COMMAND environment variables.
weight_path = os.environ.get(
    "PARCEL_WEIGHT_PATH", "/hdd/nnunet_dataset/Parcel_final/nnunet"
)
trainer_name = "nnUNetTrainer__nnUNetPlans__3d_fullres"

# Full command line that launches Slicer's ResampleScalarVectorDWIVolume module.
ResampleScalarVectorDWIVolume = os.environ.get(
    "SLICER_RESAMPLE_COMMAND",
    "/home/say26747/Downloads/Slicer-5.2.2-linux-amd64/Slicer --launch /home/say26747/Downloads/Slicer-5.2.2-linux-amd64/lib/Slicer-5.2/cli-modules/ResampleScalarVectorDWIVolume",
)

# Preparing the input data in the format nnU-Net expects

In [None]:
root_path = os.getcwd()
folder_to_predict = join(root_path, "tmp_to_pred_T+F+S+Max")
os.makedirs(folder_to_predict, exist_ok=True)
patient_id = "sub-10217"

# Diffusion-derived maps of the subject, already registered/conformed to MNI.
images_dir = join(root_path, patient_id, "images")
fa_path = join(images_dir, "FA_ras_MNI_conform.nii.gz")
trace_path = join(images_dir, "trace_from_eigs_ras_MNI_conform.nii.gz")
sphericity_path = join(images_dir, "sphericity_ras_MNI_conform.nii.gz")
max_eigenvalue_path = join(images_dir, "maxEig_ras_MNI_conform.nii.gz")

# nnU-Net identifies input channels by the 4-digit suffix (_0000, _0001, ...),
# so the tuple order below is the channel order fed to the models.
for channel, source in enumerate(
    (fa_path, trace_path, sphericity_path, max_eigenvalue_path)
):
    shutil.copy(source, join(folder_to_predict, f"{patient_id}_{channel:04d}.nii.gz"))

In [None]:
input_path = folder_to_predict
output_path = join(root_path, "tmp_pred_7_classes")
os.makedirs(output_path, exist_ok=True)
# The model folder must contain the weights of the first-stage (seven-class) model.
model_path = join(
    weight_path,
    "Dataset630_dMRI_FA_trace_sphericity_maxEig_V3_7_classes",
    trainer_name,
)
# Pass an argument list (no shell) so paths containing spaces or shell
# metacharacters reach nnU-Net verbatim; check=True stops the notebook if
# prediction fails.
seven_class_command = [
    "nnUNetv2_predict_from_modelfolder",
    "-i", input_path,
    "-o", output_path,
    "-m", model_path,
    "-f", "0",
    "--disable_progress_bar",
]
subprocess.run(seven_class_command, check=True)

# Adding the predicted 7-class segmentation as an extra input channel alongside the diffusion-derived parameters

In [None]:
# The first-stage segmentation becomes input channel _0004 for the
# second-stage models.
seven_class_path = join(output_path, patient_id + ".nii.gz")
stage_two_channel = join(folder_to_predict, f"{patient_id}_0004.nii.gz")
shutil.copy(seven_class_path, stage_two_channel)

# Predict the second stage

In [None]:
def run_subprocess(command):
    """Run an external command, raising CalledProcessError if it exits non-zero.

    ``command`` may be a string (executed through the shell, as before) or a
    list of arguments (executed directly, without a shell).
    """
    subprocess.run(command, shell=isinstance(command, str), check=True)


# Check that the output folder and model folder are correctly paired for each
# of the 5 second-stage models.
output_paths = ["center", "left", "right", "right_small", "left_small"]
model_paths = [
    "Dataset631_dMRI_nnUNet_FA_trace_sphericity_maxEig_V3_center_classes",
    "Dataset632_dMRI_nnUNet_FA_trace_sphericity_maxEig_V3_left_classes",
    "Dataset634_dMRI_nnUNet_FA_trace_sphericity_maxEig_V3_right_classes",
    "Dataset635_dMRI_nnUNet_FA_trace_sphericity_maxEig_V3_right_small_classes",
    "Dataset633_dMRI_nnUNet_FA_trace_sphericity_maxEig_V3_left_small_classes",
]

# Loop variables get their own names so this cell does not overwrite the
# module-level `output_path` / `model_path` of the seven-class step. Earlier
# cells read those names, so overwriting them breaks re-runs.
for region, dataset_name in zip(output_paths, model_paths):
    region_output_path = join(root_path, f"tmp_{region}")
    os.makedirs(region_output_path, exist_ok=True)
    region_model_path = join(weight_path, dataset_name, trainer_name)
    command = [
        "nnUNetv2_predict_from_modelfolder",
        "-i", input_path,
        "-o", region_output_path,
        "-m", region_model_path,
        "-f", "0",
        "--c",
        "--disable_progress_bar",
    ]
    print(region_output_path)
    run_subprocess(command)

# Fusion and post-processing

In [None]:
# folder that should be used for the prediction
seven_class_paths = sorted(glob(join(root_path, "tmp_pred_7_classes", "*.nii.gz")))
print(f"Number of files in tmp_pred_7_classes: {len(seven_class_paths)}")
center_paths = sorted(glob(join(root_path, "tmp_center", "*.nii.gz")))
print(f"Number of files in center: {len(center_paths)}")
left_paths = sorted(glob(join(root_path, "tmp_left", "*.nii.gz")))
print(f"Number of files in left: {len(left_paths)}")
right_paths = sorted(glob(join(root_path, "tmp_right", "*.nii.gz")))
print(f"Number of files in right: {len(right_paths)}")
right_small_paths = sorted(glob(join(root_path, "tmp_right_small", "*.nii.gz")))
print(f"Number of files in right_small: {len(right_small_paths)}")
left_small_paths = sorted(glob(join(root_path, "tmp_left_small", "*.nii.gz")))
print(f"Number of files in left_small: {len(left_small_paths)}")
zipped_objects = zip(
    seven_class_paths,
    center_paths,
    left_paths,
    right_paths,
    right_small_paths,
    left_small_paths,
)

In [9]:
# Load every label mapping in ../utils/mappings. Each JSON maps a model's
# output label (as a string) to a wmparc label. The file stem becomes the key
# in `mappings` (e.g. "seven_mapping.json" -> "seven_mapping").
mappings = {}
sets = set()  # all distinct wmparc labels referenced by any mapping
mapping_root = join(root_path, "..", "utils", "mappings")
# `mapping_file` rather than `map`, which would shadow the builtin for the
# rest of the notebook.
for mapping_file in sorted(glob(join(mapping_root, "*.json"))):
    with open(mapping_file, "r") as f:
        mapping = json.load(f)
    mappings[os.path.basename(mapping_file).split(".")[0]] = mapping
    sets.update(mapping.values())

print(f"Number of unique values in all mappings: {len(sets)}")

# wmparc label -> sequential index. The inverse is used to turn the fused
# sequential volume back into wmparc labels.
with open(
    join(root_path, "..", "utils", "wmparc_values_to_sequence.json"),
    "r",
) as f:
    wmparc_values_to_sequence = json.load(f)

sequence_to_wmparc = {int(v): int(k) for k, v in wmparc_values_to_sequence.items()}

# NOTE(review): the two "small" entries are cross-assigned. The
# process_and_save call in this notebook passes the right_small prediction
# into its `left_small_path` parameter (and vice versa). This swap appears to
# compensate, so each prediction is decoded with its own mapping. If that
# argument order is ever corrected, remove this swap at the same time.
mapping_dict = {
    "seven_mapping": mappings["seven_mapping"],
    "center_mapping": mappings["center_mapping"],
    "left_mapping": mappings["left_mapping"],
    "right_mapping": mappings["right_mapping"],
    "right_small_mapping": mappings["left_small_mapping"],
    "left_small_mapping": mappings["right_small_mapping"],
}

results_dir = join(root_path, "tmp_prediction")
os.makedirs(results_dir, exist_ok=True)

Number of unique values in all mappings: 101


In [10]:
def process_and_save(
    seven_path, center_path, left_path, right_path, left_small_path, right_small_path
):
    """Fuse one subject's first- and second-stage predictions into a wmparc volume.

    Each prediction's labels are translated through its entry in ``mapping_dict``
    to wmparc labels, then to sequential indices via ``wmparc_values_to_sequence``,
    and written into one combined volume. Labels missing from a mapping are
    ignored. Later predictions overwrite earlier ones where they overlap, in the
    order seven, center, left, right, left_small, right_small. The combined
    volume is converted back to wmparc labels and saved as int32 into
    ``results_dir`` under the basename of ``center_path``, using the center
    prediction's affine and header. Nothing is done if that file already exists.
    """
    output_filename = join(results_dir, os.path.basename(center_path))
    if os.path.exists(output_filename):
        print(f"File {os.path.basename(center_path)} already exists. Skipping...")
        return

    center_img = nib.load(center_path)
    combined_data = np.zeros(center_img.shape, dtype=np.int32)

    def process_data(data, mapping_name):
        """Write the sequential index of every mapped label in ``data`` into ``combined_data``."""
        mapping = mapping_dict[mapping_name]
        for label in np.unique(data):
            label = int(label)
            if str(label) in mapping:
                value = wmparc_values_to_sequence[str(mapping[str(label)])]
                combined_data[data == label] = value

    # Load and fuse one prediction at a time so only a single float volume is
    # held in memory alongside the combined result. The order matters, because
    # later predictions win on overlap.
    for path, mapping_name in (
        (seven_path, "seven_mapping"),
        (center_path, "center_mapping"),
        (left_path, "left_mapping"),
        (right_path, "right_mapping"),
        (left_small_path, "left_small_mapping"),
        (right_small_path, "right_small_mapping"),
    ):
        process_data(nib.load(path).get_fdata(), mapping_name)

    # Transform the prediction back to wmparc values. Start from zeros rather
    # than np.empty_like: any index without an entry in sequence_to_wmparc
    # would otherwise keep uninitialised memory.
    mapped = np.zeros_like(combined_data, dtype=np.int32)
    for k, v in sequence_to_wmparc.items():
        mapped[combined_data == k] = v

    # Header.set_data_dtype() mutates in place and returns None, so the header
    # must be prepared before it is handed to Nifti1Image. Work on a copy so the
    # loaded center image is left untouched.
    header = center_img.header.copy()
    header.set_data_dtype(np.int32)
    combined_img = nib.Nifti1Image(mapped, center_img.affine, header=header)
    nib.save(combined_img, output_filename)

In [11]:
# Fuse the predictions of the single demo subject. Keyword arguments make the
# wiring explicit: the right_small prediction goes into `left_small_path` and
# the left_small prediction into `right_small_path`. The "small" entries of
# mapping_dict are cross-assigned in the same way, so the two swaps cancel.
process_and_save(
    seven_path=seven_class_paths[0],
    center_path=center_paths[0],
    left_path=left_paths[0],
    right_path=right_paths[0],
    left_small_path=right_small_paths[0],
    right_small_path=left_small_paths[0],
)

# Moving the prediction back to the subject's original space (reverse transform)

In [12]:
# Reverse transform: the fused predictions are read from tmp_prediction/, and
# the resampled results are written directly into the working directory.
normal_pred_path = root_path
conformed_pred_path = join(normal_pred_path, "tmp_prediction")

In [13]:
def worker_MNI_to_RAS(patient_id):
    """Resample one subject's fused prediction out of MNI space with Slicer.

    Reads ``<conformed_pred_path>/<patient_id>.nii.gz``. It applies the subject's
    ``transformations/b0_to_MNI_Brain_Inverse.h5`` with nearest-neighbour
    interpolation, so labels are not blended. The subject's
    ``labels/brain_mask_ras.nii.gz`` serves as the reference image. The result
    is written to ``<normal_pred_path>/<patient_id>.nii.gz``.

    Raises:
        subprocess.CalledProcessError: if the Slicer command exits non-zero.
    """
    import shlex  # local import: shlex is not in the notebook's import cell

    subject_dir = join(root_path, patient_id)
    conformed_pred = join(conformed_pred_path, f"{patient_id}.nii.gz")
    normal_pred = join(normal_pred_path, f"{patient_id}.nii.gz")
    # shlex.split honours quoting, so a Slicer install path containing
    # (quoted) spaces is still split into the correct argv.
    command = shlex.split(ResampleScalarVectorDWIVolume) + [
        "--interpolation",
        "nn",
        "--Reference",
        join(subject_dir, "labels", "brain_mask_ras.nii.gz"),
        "--transformationFile",
        join(subject_dir, "transformations", "b0_to_MNI_Brain_Inverse.h5"),
        conformed_pred,
        normal_pred,
    ]
    # check=True, like the other external calls in this notebook. Otherwise a
    # failed resample only surfaces later, as a missing-file error when the
    # output is renamed.
    subprocess.run(command, check=True)

In [None]:
# Resample the fused prediction back to the subject's space. Then rename it
# with a descriptive "-wmparc" suffix. Uses `patient_id` from the
# data-preparation cell instead of repeating the subject literal.
worker_MNI_to_RAS(patient_id)
shutil.move(
    join(normal_pred_path, f"{patient_id}.nii.gz"),
    join(root_path, f"{patient_id}-wmparc.nii.gz"),
)

In [15]:
# delete all the temporary folders
folders_to_delete = glob(join(root_path, "tmp*"))
for folder in folders_to_delete:
    shutil.rmtree(folder)

# Use freeview to visualize the output

In [None]:
# Overlay the reverse-transformed parcellation on the subject's FA_ras map.
# `:colormap=lut` tells freeview to colour the second volume with a lookup
# table. Requires FreeSurfer's `freeview` to be on PATH.
!freeview -v sub-10217/images/FA_ras.nii.gz -v sub-10217-wmparc.nii.gz:colormap=lut