# nnU-Net Pipeline Walkthrough

This notebook replicates the end-to-end nnU-Net pipeline (data preparation, training, inference) in three separate stages.
Update the configuration cell, then execute the subsequent sections sequentially.

## Imports

In [None]:

import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import SimpleITK as sitk
from skimage import io

## Configuration
Edit the fields of `PipelineConfig` or override `cfg` attributes after instantiation to match your setup.

In [None]:

@dataclass
class PipelineConfig:
    """Settings shared by the conversion, training and inference cells.

    Every field has a default, so ``PipelineConfig()`` is a valid starting
    point; attributes can be overridden afterwards.
    """

    # Folder with the source PNG splits and ``spacing_mm.txt``.
    data_root: Path = field(default_factory=lambda: Path("public_leaderboard_data"))
    # Combined into the dataset folder name ``Dataset<id:03d>_<name>``.
    dataset_id: int = 500
    dataset_name: str = "AbdominalCTMultiOrgan"
    # Folders exported through the nnUNet_raw / nnUNet_preprocessed /
    # nnUNet_results environment variables by ``configure_environment``.
    nnunet_raw: Path = field(default_factory=lambda: Path("./nnUNet_raw"))
    nnunet_preprocessed: Path = field(default_factory=lambda: Path("./nnUNet_preprocessed"))
    nnunet_results: Path = field(default_factory=lambda: Path("./nnUNet_results"))
    # Configurations to preprocess/train/predict (see ``only_configuration``).
    configurations: Sequence[str] = ("3d_fullres",)
    trainer_class: str = "nnUNetTrainer"
    # Used when planning is skipped or the planner reports no identifier.
    plans_identifier: str = "nnUNetPlans"
    # Fold handed to training and inference.
    fold: str = "0"
    # Passed to ``torch.device``.
    device: str = "cuda"
    num_gpus: int = 1
    num_processes_fingerprint: int = 8
    num_processes_preprocess: int = 8
    # Root folder for predictions.
    prediction_output: Optional[Path] = None
    checkpoint_name: str = "checkpoint_final.pth"
    planner_class: str = "nnUNetPlannerResEncM"
    # Forwarded to the planner as ``gpu_memory_target_in_gb``.
    gpu_memory_target: Optional[float] = None
    preprocessor_class: str = "DefaultPreprocessor"
    # Forwarded as ``check_dataset_integrity`` during fingerprint extraction.
    verify_dataset: bool = False
    # Stage switches.
    skip_conversion: bool = False
    skip_preprocessing: bool = False
    skip_training: bool = False
    skip_validation_inference: bool = False
    skip_test_inference: bool = False
    # PNG export of the test predictions and its optional target folder.
    export_test_pngs: bool = True
    png_output_root: Optional[Path] = None
    # Re-create converted volumes / split files / predictions that already exist.
    overwrite: bool = False
    save_probabilities: bool = False
    export_validation_probabilities: bool = False
    # Optional text file with bounding-box prompts, exported as JSON.
    bounding_box_prompts: Optional[Path] = None
    # When set, this single configuration replaces ``configurations``.
    only_configuration: Optional[str] = None
    log_to_stdout: bool = True

    def clone(self) -> "PipelineConfig":
        """Return a new ``PipelineConfig`` built from this instance's current attributes."""
        current_values = dict(vars(self))
        return PipelineConfig(**current_values)


# Default configuration for the notebook; override attributes on ``cfg`` as needed.
cfg = PipelineConfig()

# Settle on a concrete prediction folder (./notebook_predictions when unset)
# and make sure it exists before any stage writes into it.
cfg.prediction_output = (
    Path("./notebook_predictions") if cfg.prediction_output is None else Path(cfg.prediction_output)
)
cfg.prediction_output.mkdir(parents=True, exist_ok=True)

## Utility Functions & Dataset Helpers

In [None]:

from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json

def configure_environment(cfg: PipelineConfig) -> None:
    """Export the nnU-Net folder variables and create the configured folders.

    ``nnUNet_raw``, ``nnUNet_preprocessed`` and ``nnUNet_results`` are set with
    ``os.environ.setdefault``, so a value that is already present in the
    environment is left untouched. The folders named in ``cfg`` are created
    regardless of which value ends up in the environment.
    """
    folder_by_variable = {
        "nnUNet_raw": cfg.nnunet_raw,
        "nnUNet_preprocessed": cfg.nnunet_preprocessed,
        "nnUNet_results": cfg.nnunet_results,
    }
    for variable, folder in folder_by_variable.items():
        os.environ.setdefault(variable, str(folder.resolve()))
    for folder in folder_by_variable.values():
        Path(folder).mkdir(parents=True, exist_ok=True)


def ensure_dependencies() -> None:
    """Check that PyTorch can be imported.

    Raises:
        RuntimeError: If importing ``torch`` fails; the original ``ImportError``
            is attached as the cause.
    """
    failure_message = "PyTorch is required to run the pipeline. Install project dependencies first."
    try:
        import torch  # noqa: F401
    except ImportError as import_error:
        raise RuntimeError(failure_message) from import_error


def parse_spacing_map(spacing_file: Path) -> Dict[str, Tuple[float, float, float]]:
    """Read per-case spacings from a text file of ``<case>: <triple>`` lines.

    Blank lines, lines starting with ``#`` and lines without a ``:`` are
    ignored. The case id (text before the first ``:``) is left-padded with
    zeros to at least two characters. The value must be a Python literal list
    or tuple holding exactly three numbers.

    Args:
        spacing_file: Path of the spacing text file.

    Returns:
        Mapping from zero-padded case id to a tuple of three floats.

    Raises:
        FileNotFoundError: If ``spacing_file`` does not exist.
        ValueError: If a value is not a literal, or not a three-element
            list/tuple.
    """
    # Local import keeps this function self-contained within the notebook cell.
    import ast

    if not spacing_file.exists():
        raise FileNotFoundError(f"Spacing file not found: {spacing_file}")
    mapping: Dict[str, Tuple[float, float, float]] = {}
    with spacing_file.open("r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#") or ":" not in line:
                continue
            key, value = line.split(":", 1)
            case_id = key.strip().zfill(2)
            # SECURITY: this used to be eval() on file content. literal_eval only
            # accepts literals, so arbitrary expressions in the file are rejected
            # instead of being executed.
            try:
                spacing = ast.literal_eval(value.strip())
            except (ValueError, TypeError, SyntaxError) as exc:
                raise ValueError(f"Unexpected spacing entry for case {case_id}: {value}") from exc
            if not isinstance(spacing, (list, tuple)) or len(spacing) != 3:
                raise ValueError(f"Unexpected spacing entry for case {case_id}: {value}")
            mapping[case_id] = tuple(float(v) for v in spacing)
    return mapping


def sorted_slice_paths(case_folder: Path) -> List[Path]:
    """Return the ``*.png`` files of ``case_folder`` ordered by their integer stem.

    Raises:
        FileNotFoundError: If the folder contains no PNG files.
        ValueError: If a PNG file name (without extension) is not an integer.
    """
    png_files = list(case_folder.glob("*.png"))
    # Numeric ordering, so that e.g. "10.png" comes after "2.png".
    png_files.sort(key=lambda path: int(path.stem))
    if len(png_files) == 0:
        raise FileNotFoundError(f"No PNG slices found in {case_folder}")
    return png_files


def load_stack(slice_paths: Sequence[Path]) -> np.ndarray:
    """Read every path with ``skimage.io.imread`` and stack the results along a new first axis."""
    slice_arrays = []
    for slice_path in slice_paths:
        slice_arrays.append(io.imread(str(slice_path)))
    return np.stack(slice_arrays, axis=0)


def write_nifti(volume: np.ndarray, spacing: Tuple[float, float, float], output_path: Path, dtype: np.dtype) -> None:
    """Write ``volume`` to ``output_path`` as a SimpleITK image.

    The array is cast to ``dtype``; the image gets the given spacing, a zero
    origin and an identity direction matrix. Missing parent folders of
    ``output_path`` are created.

    NOTE(review): ``spacing`` is handed to ``SetSpacing`` unchanged — confirm
    that the spacing file uses the axis order SimpleITK expects.
    """
    identity_direction = (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
    typed_volume = volume.astype(dtype, copy=False)
    image = sitk.GetImageFromArray(typed_volume)
    image.SetSpacing(tuple(map(float, spacing)))
    image.SetOrigin((0.0, 0.0, 0.0))
    image.SetDirection(identity_direction)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    sitk.WriteImage(image, str(output_path))


def convert_split_to_nnunet(
    case_ids: Iterable[str],
    image_root: Path,
    label_root: Optional[Path],
    output_images: Path,
    output_labels: Optional[Path],
    spacing_map: Dict[str, Tuple[float, float, float]],
    prefix: str,
    overwrite: bool,
) -> Tuple[List[str], Dict[str, str]]:
    """Convert the PNG slice folders of one split into NIfTI volumes.

    Cases are processed in ascending numeric order of their id. Each case is
    written to ``output_images/<prefix>_<id padded to 3>_0000.nii.gz`` as
    int16; when ``label_root`` is given, the label stack is written to
    ``output_labels/<prefix>_<id padded to 3>.nii.gz`` as uint8. A case whose
    image file already exists is left untouched unless ``overwrite`` is true.

    Args:
        case_ids: Case folder names; each must be convertible with ``int``.
        image_root: Folder containing one sub-folder of PNG slices per case.
        label_root: Folder with the matching label slices, or ``None``.
        output_images: Target folder for image volumes.
        output_labels: Target folder for label volumes (required when
            ``label_root`` is given).
        spacing_map: Spacing per case id. The id is looked up as given and, if
            absent, zero-padded to two characters.
        prefix: Prefix of the generated case names.
        overwrite: Re-create volumes that already exist.

    Returns:
        The generated case names, and a mapping from case name to the case id
        zero-padded to two characters.

    Raises:
        FileNotFoundError: If a case's image folder is missing (or a slice
            folder has no PNGs).
        KeyError: If no spacing is available for a case.
        ValueError: If labels are requested without ``output_labels`` or the
            image and label slice counts differ.
    """
    case_identifiers: List[str] = []
    case_mapping: Dict[str, str] = {}
    for case_id in sorted(case_ids, key=int):
        image_case_dir = image_root / case_id
        if not image_case_dir.is_dir():
            raise FileNotFoundError(f"Missing image folder for case {case_id}: {image_case_dir}")

        # parse_spacing_map keys its entries by the id padded to two characters,
        # while folder names may be unpadded ("1" vs "01"); accept either form.
        padded_id = str(case_id).zfill(2)
        spacing = spacing_map.get(case_id)
        if spacing is None:
            spacing = spacing_map.get(padded_id)
        if spacing is None:
            raise KeyError(f"No spacing metadata for case {case_id} in spacing file")

        case_name = f"{prefix}_{case_id.zfill(3)}"
        image_output_path = output_images / f"{case_name}_0000.nii.gz"
        if overwrite or not image_output_path.exists():
            slices = sorted_slice_paths(image_case_dir)
            volume = load_stack(slices).astype(np.int16, copy=False)
            write_nifti(volume, spacing, image_output_path, np.int16)
            if label_root is not None:
                if output_labels is None:
                    raise ValueError("Label root provided but output label directory missing.")
                label_slices = sorted_slice_paths(label_root / case_id)
                if len(label_slices) != len(slices):
                    raise ValueError(
                        f"Mismatched slice count for case {case_id}: {len(slices)} images vs {len(label_slices)} labels"
                    )
                label_volume = load_stack(label_slices).astype(np.uint8, copy=False)
                label_output_path = output_labels / f"{case_name}.nii.gz"
                write_nifti(label_volume, spacing, label_output_path, np.uint8)

        # Recorded for converted and for already-present cases alike.
        case_identifiers.append(case_name)
        case_mapping[case_name] = padded_id
    return case_identifiers, case_mapping


def parse_bbox_prompts(bbox_file: Optional[Path], output_json: Path) -> None:
    """Convert a bounding-box prompt text file into a nested JSON file.

    Expected line format: ``<case, slice, organ>: <literal>``. Blank lines,
    lines not starting with ``<`` or lacking ``>:``, and lines whose key does
    not have exactly three comma-separated parts are skipped. The result is
    written as ``{case (zero-padded to 2): {slice: {organ: value}}}``.

    Nothing is written when ``bbox_file`` is empty/``None`` or does not exist.

    Args:
        bbox_file: Path (or path string) of the prompt file, or ``None``.
        output_json: Destination JSON file; parent folders are created.

    Raises:
        ValueError: If the value part of a prompt line is not a Python literal.
    """
    # Local import keeps this function self-contained within the notebook cell.
    import ast

    if not bbox_file:
        return
    bbox_path = Path(bbox_file)  # tolerate a plain string from user overrides
    if not bbox_path.exists():
        return

    prompts: Dict[str, Dict[str, Dict[str, Sequence[int]]]] = {}
    with bbox_path.open("r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or not line.startswith("<") or ">:" not in line:
                continue
            # Split on the first ">:" only, so the value part may contain it.
            key, value = line.split(">:", 1)
            triplet = key.strip("<>").split(",")
            if len(triplet) != 3:
                continue
            case_id = triplet[0].strip().zfill(2)
            slice_idx = triplet[1].strip()
            organ_idx = triplet[2].strip()
            # SECURITY: this used to be eval() on file content. literal_eval only
            # accepts literals, so expressions in the file are never executed.
            try:
                coords = ast.literal_eval(value.strip())
            except (ValueError, TypeError, SyntaxError) as exc:
                raise ValueError(f"Unexpected bounding box entry for {key.strip()}>: {value}") from exc
            prompts.setdefault(case_id, {}).setdefault(slice_idx, {})[organ_idx] = coords

    output_json.parent.mkdir(parents=True, exist_ok=True)
    with output_json.open("w") as f:
        json.dump(prompts, f, indent=2)


def generate_dataset_json_file(
    dataset_dir: Path,
    num_training_cases: int,
    labels: Dict[str, int],
    dataset_name: str,
    metadata: Dict[str, object],
) -> None:
    """Call nnU-Net's ``generate_dataset_json`` for a single-channel CT dataset.

    The channel is named ``"CT"`` and the file ending is ``".nii.gz"``. The
    entries of ``metadata`` are forwarded as additional keyword arguments; a
    key that repeats one of the fixed arguments raises ``TypeError``.
    """
    fixed_arguments = {
        "channel_names": {0: "CT"},
        "labels": labels,
        "num_training_cases": num_training_cases,
        "file_ending": ".nii.gz",
        "dataset_name": dataset_name,
    }
    output_folder = str(dataset_dir)
    generate_dataset_json(output_folder, **fixed_arguments, **metadata)


def prepare_raw_dataset(cfg: PipelineConfig, dataset_dir: Path) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
    """Build the nnU-Net raw dataset (volumes, dataset.json, split file) from the PNG data.

    Reads ``spacing_mm.txt`` and the ``train_images``/``train_labels``,
    ``val_images``/``val_labels`` and ``test1_images`` folders below
    ``cfg.data_root``. Training and validation cases both go to ``imagesTr`` /
    ``labelsTr``; test cases go to ``imagesTs``. The train/val partition is
    stored as a one-fold ``splits_final.json``.

    Args:
        cfg: Pipeline configuration (``data_root``, ``overwrite``,
            ``bounding_box_prompts`` and ``nnunet_preprocessed`` are used).
        dataset_dir: Raw dataset folder to populate.

    Returns:
        ``({"train": [...], "val": [...], "test": [...]}, case_folder_map)``
        where ``case_folder_map`` maps generated case names to source ids.
    """
    spacing_map = parse_spacing_map(cfg.data_root / "spacing_mm.txt")
    train_ids = [p.name for p in (cfg.data_root / "train_images").iterdir() if p.is_dir()]
    val_ids = [p.name for p in (cfg.data_root / "val_images").iterdir() if p.is_dir()]
    test_ids = [p.name for p in (cfg.data_root / "test1_images").iterdir() if p.is_dir()]

    images_tr = dataset_dir / "imagesTr"
    labels_tr = dataset_dir / "labelsTr"
    images_ts = dataset_dir / "imagesTs"
    for folder in (images_tr, labels_tr, images_ts):
        folder.mkdir(parents=True, exist_ok=True)

    train_cases, train_map = convert_split_to_nnunet(
        train_ids,
        cfg.data_root / "train_images",
        cfg.data_root / "train_labels",
        images_tr,
        labels_tr,
        spacing_map,
        prefix="ct",
        overwrite=cfg.overwrite,
    )
    val_cases, val_map = convert_split_to_nnunet(
        val_ids,
        cfg.data_root / "val_images",
        cfg.data_root / "val_labels",
        images_tr,
        labels_tr,
        spacing_map,
        prefix="ct",
        overwrite=cfg.overwrite,
    )
    test_cases, test_map = convert_split_to_nnunet(
        test_ids,
        cfg.data_root / "test1_images",
        None,
        images_ts,
        None,
        spacing_map,
        prefix="ct",
        overwrite=cfg.overwrite,
    )

    case_folder_map = {**train_map, **val_map, **test_map}
    metadata = {
        "training_cases": train_cases,
        "validation_cases": val_cases,
        "test_cases": test_cases,
        "spacing_file": str((cfg.data_root / "spacing_mm.txt").resolve()),
        "case_folder_map": case_folder_map,
    }

    labels = {"background": 0}
    for organ_idx in range(1, 13):
        labels[f"organ_{organ_idx:02d}"] = organ_idx

    generate_dataset_json_file(
        dataset_dir=dataset_dir,
        # Validation cases live in imagesTr too, so they count as training cases here.
        num_training_cases=len(train_cases) + len(val_cases),
        labels=labels,
        dataset_name=dataset_dir.name,
        metadata=metadata,
    )

    splits = [{"train": train_cases, "val": val_cases}]
    # nnU-Net v2 reads a manual split from <nnUNet_preprocessed>/<dataset>/splits_final.json.
    # A copy that only exists in the raw folder is ignored, and nnU-Net would
    # then generate its own split. The raw copy is still written for reference.
    preprocessed_root = Path(os.environ.get("nnUNet_preprocessed", str(cfg.nnunet_preprocessed)))
    split_targets = (
        dataset_dir / "splits_final.json",
        preprocessed_root / dataset_dir.name / "splits_final.json",
    )
    for splits_file in split_targets:
        # An existing split file is only replaced when cfg.overwrite is set.
        if not splits_file.exists() or cfg.overwrite:
            splits_file.parent.mkdir(parents=True, exist_ok=True)
            with splits_file.open("w") as f:
                json.dump(splits, f, indent=2)

    if cfg.bounding_box_prompts:
        parse_bbox_prompts(cfg.bounding_box_prompts, dataset_dir / "test_bboxes.json")

    return {"train": train_cases, "val": val_cases, "test": test_cases}, dict(case_folder_map)

## Training & Inference Utilities

In [None]:

def run_planning_and_preprocessing(cfg: PipelineConfig, configurations: Sequence[str]) -> str:
    """Run fingerprint extraction, experiment planning and preprocessing for ``cfg.dataset_id``.

    Args:
        cfg: Pipeline configuration (process counts, planner/preprocessor
            class names, GPU memory target, fallback plans identifier).
        configurations: Configurations to preprocess; each one uses
            ``cfg.num_processes_preprocess`` processes.

    Returns:
        The value returned by ``plan_experiments`` or, when that is falsy,
        ``cfg.plans_identifier``.
    """
    from nnunetv2.experiment_planning.plan_and_preprocess_api import (
        extract_fingerprints,
        plan_experiments,
        preprocess,
    )

    target_datasets = [cfg.dataset_id]

    extract_fingerprints(
        target_datasets,
        num_processes=cfg.num_processes_fingerprint,
        check_dataset_integrity=cfg.verify_dataset,
        clean=True,
        verbose=True,
    )

    planned_identifier = plan_experiments(
        target_datasets,
        experiment_planner_class_name=cfg.planner_class,
        preprocess_class_name=cfg.preprocessor_class,
        gpu_memory_target_in_gb=cfg.gpu_memory_target,
    )
    # Fall back to the configured identifier if the planner reports none.
    effective_identifier = planned_identifier or cfg.plans_identifier

    processes_per_configuration = tuple([cfg.num_processes_preprocess] * len(configurations))
    preprocess(
        target_datasets,
        plans_identifier=effective_identifier,
        configurations=tuple(configurations),
        num_processes=processes_per_configuration,
        verbose=True,
    )
    return effective_identifier


def build_model_output_dir(dataset_name: str, trainer_class: str, plans_identifier: str, configuration: str) -> Path:
    """Return ``$nnUNet_results/<dataset_name>/<trainer>__<plans>__<configuration>``.

    Raises:
        KeyError: If the ``nnUNet_results`` environment variable is not set.
    """
    results_root = os.environ["nnUNet_results"]
    run_folder = f"{trainer_class}__{plans_identifier}__{configuration}"
    return Path(results_root) / dataset_name / run_folder


def run_training_stage(
    cfg: PipelineConfig,
    dataset_name: str,
    configuration: str,
    plans_identifier: str,
) -> None:
    """Train one configuration through nnU-Net's ``run_training``.

    Fold, trainer class, GPU count, device and the validation-probability
    export flag are taken from ``cfg``.
    """
    import torch
    from nnunetv2.run.run_training import run_training

    training_options = {
        "configuration": configuration,
        "fold": cfg.fold,
        "trainer_class_name": cfg.trainer_class,
        "plans_identifier": plans_identifier,
        "num_gpus": cfg.num_gpus,
        "device": torch.device(cfg.device),
        "export_validation_probabilities": cfg.export_validation_probabilities,
    }
    run_training(dataset_name, **training_options)


def run_inference(
    model_dir: Path,
    fold: str,
    inputs: Sequence[List[str]],
    output_dir: Path,
    device: str,
    checkpoint_name: str,
    save_probabilities: bool,
    overwrite: bool,
) -> None:
    """Predict ``inputs`` with a trained model and store the results in ``output_dir``.

    Args:
        model_dir: Trained model folder handed to the predictor.
        fold: The single fold to load.
        inputs: One list of input file paths per case.
        output_dir: Folder for the predictions; created if missing.
        device: Device string passed to ``torch.device``.
        checkpoint_name: Checkpoint file name to load.
        save_probabilities: Forwarded to ``predict_from_files``.
        overwrite: Forwarded to ``predict_from_files``.
    """
    import torch
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    target_device = torch.device(device)
    predictor = nnUNetPredictor(device=target_device)
    predictor.initialize_from_trained_model_folder(
        str(model_dir),
        use_folds=(fold,),
        checkpoint_name=checkpoint_name,
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    case_inputs = list(inputs)
    predictor.predict_from_files(
        case_inputs,
        str(output_dir),
        save_probabilities=save_probabilities,
        overwrite=overwrite,
    )


def compute_validation_metrics(
    predictions_dir: Path,
    dataset_dir: Path,
    plans_identifier: str,
    output_filename: Optional[Path] = None,
) -> Dict[str, object]:
    """Evaluate predictions against ``<dataset_dir>/labelsTr`` with nnU-Net's evaluator.

    Args:
        predictions_dir: Folder with the predicted segmentations.
        dataset_dir: Raw dataset folder (provides ``dataset.json`` and
            ``labelsTr``).
        plans_identifier: Name of the plans file (without ``.json``) inside
            ``$nnUNet_preprocessed/<dataset name>``.
        output_filename: Where the summary should be written; ``None`` leaves
            the choice to the evaluator.

    Returns:
        The evaluator's return value or, if it returns ``None``, the content
        of the summary JSON file when that file exists.
    """
    from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder2

    dataset_json = dataset_dir / "dataset.json"
    plans_file = Path(os.environ["nnUNet_preprocessed"]) / dataset_dir.name / f"{plans_identifier}.json"
    gt_folder = dataset_dir / "labelsTr"
    # labelsTr holds the training AND validation labels, while predictions
    # exist for the validation cases only. chill=True tells the evaluator to
    # score just the files present in predictions_dir instead of requiring a
    # prediction for every reference file.
    metrics = compute_metrics_on_folder2(
        str(gt_folder),
        str(predictions_dir),
        str(dataset_json),
        str(plans_file),
        output_file=str(output_filename) if output_filename else None,
        chill=True,
    )
    if metrics is None:
        # NOTE(review): the evaluator appears to only write the summary file
        # (defaulting to <predictions_dir>/summary.json) without returning it;
        # load it so callers receive the metrics — confirm against the
        # installed nnunetv2 version.
        summary_path = Path(output_filename) if output_filename else predictions_dir / "summary.json"
        if summary_path.is_file():
            with summary_path.open("r") as f:
                metrics = json.load(f)
    return metrics


def build_inference_input_lists(image_dir: Path, case_ids: Sequence[str], file_ending: str) -> List[List[str]]:
    """Return one single-file input list per case, pointing at ``<case>_0000<file_ending>``.

    Raises:
        FileNotFoundError: For the first case (in the given order) whose image
            file does not exist in ``image_dir``.
    """
    expected_files = [image_dir / f"{case}_0000{file_ending}" for case in case_ids]
    for image_file in expected_files:
        if not image_file.exists():
            raise FileNotFoundError(f"Missing input volume for inference: {image_file}")
    return [[str(image_file)] for image_file in expected_files]


def convert_nifti_to_png_slices(nifti_file: Path, output_dir: Path) -> None:
    """Write each slice along the first array axis of ``nifti_file`` as ``<n>.png``.

    The volume is cast to uint8 and the files are numbered from 1.
    ``output_dir`` is created if it does not exist.
    """
    image = sitk.ReadImage(str(nifti_file))
    volume = sitk.GetArrayFromImage(image).astype(np.uint8, copy=False)
    output_dir.mkdir(parents=True, exist_ok=True)
    for position in range(len(volume)):
        # File names are 1-based.
        target = output_dir / f"{position + 1}.png"
        io.imsave(str(target), volume[position], check_contrast=False)


def export_predictions_to_png(
    predictions_dir: Path,
    output_root: Path,
    case_folder_map: Dict[str, str],
) -> None:
    """Export every ``*.nii.gz`` prediction as a folder of PNG slices.

    The target folder name comes from ``case_folder_map`` (keyed by the file
    name without ``.nii.gz``), falling back to the text after the last ``_``.
    Purely numeric names are normalised to at least two digits. Returns
    without creating ``output_root`` when there are no predictions.
    """
    nifti_suffix = ".nii.gz"
    prediction_files = sorted(predictions_dir.glob("*.nii.gz"))
    if len(prediction_files) == 0:
        return

    output_root.mkdir(parents=True, exist_ok=True)
    for prediction_file in prediction_files:
        filename = prediction_file.name
        if filename.endswith(nifti_suffix):
            identifier = filename[: -len(nifti_suffix)]
        else:
            identifier = prediction_file.stem
        fallback_name = identifier.split("_")[-1]
        folder_name = case_folder_map.get(identifier, fallback_name)
        if folder_name.isdigit():
            # int() drops any existing padding before it is re-applied.
            folder_name = str(int(folder_name)).zfill(2)
        convert_nifti_to_png_slices(prediction_file, output_root / folder_name)

## Data Preparation
Run this cell to convert the PNG slices into the nnU-Net raw data structure (unless skipped).

In [None]:

# Export the nnU-Net folder variables and verify that PyTorch is importable.
configure_environment(cfg)
ensure_dependencies()

# Dataset folder below nnUNet_raw, named "Dataset<id:03d>_<name>".
dataset_name = f"Dataset{cfg.dataset_id:03d}_{cfg.dataset_name}"
dataset_dir = Path(os.environ["nnUNet_raw"]) / dataset_name
dataset_dir.mkdir(parents=True, exist_ok=True)

dataset_json_path = dataset_dir / "dataset.json"
# Conversion is only skipped when requested AND a dataset.json is already there.
reuse_existing_dataset = cfg.skip_conversion and dataset_json_path.exists()

if not reuse_existing_dataset:
    case_splits, case_folder_map = prepare_raw_dataset(cfg, dataset_dir)
    if cfg.log_to_stdout:
        print(f"Converted dataset stored at {dataset_dir}")
else:
    # Recover the case lists and the folder map stored in dataset.json.
    with dataset_json_path.open("r") as f:
        dataset_meta = json.load(f)
    case_splits = {
        split: dataset_meta.get(meta_key, [])
        for split, meta_key in (
            ("train", "training_cases"),
            ("val", "validation_cases"),
            ("test", "test_cases"),
        )
    }
    raw_map = dataset_meta.get("case_folder_map", {}) or {}
    case_folder_map = {k: str(v).zfill(2) for k, v in raw_map.items()}
    if cfg.log_to_stdout:
        print("Skipping dataset conversion (dataset.json already present).")

# Without a stored map, derive folder names from the test case identifiers.
if not case_folder_map:
    case_folder_map = {identifier: identifier.split("_")[-1] for identifier in case_splits.get("test", [])}

# A single explicitly requested configuration takes precedence over the list.
if cfg.only_configuration:
    active_configurations = [cfg.only_configuration]
else:
    active_configurations = list(cfg.configurations)
case_splits

## Training
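Plan and preprocess the dataset, then train each requested configuration on the fold set in `cfg.fold`. Either step can be skipped with `cfg.skip_preprocessing` / `cfg.skip_training`.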

In [None]:

# Either reuse the configured plans identifier or run planning + preprocessing.
if cfg.skip_preprocessing:
    plans_identifier = cfg.plans_identifier
    if cfg.log_to_stdout:
        print("Skipping planning & preprocessing.")
else:
    plans_identifier = run_planning_and_preprocessing(cfg, active_configurations)

# The model folder of every active configuration is recorded even when
# training is skipped, because the inference cell relies on it.
model_directories: Dict[str, Path] = {}
for configuration in active_configurations:
    model_dir = build_model_output_dir(dataset_name, cfg.trainer_class, plans_identifier, configuration)
    model_directories[configuration] = model_dir
    if not cfg.skip_training:
        if cfg.log_to_stdout:
            print(f"Starting training for configuration {configuration} (fold {cfg.fold})...")
        run_training_stage(cfg, dataset_name, configuration, plans_identifier)
    elif cfg.log_to_stdout:
        print(f"Skipping training for configuration {configuration}.")

model_directories

## Inference & Evaluation
Generate validation metrics and export test predictions (optionally as PNG slices).

In [None]:

# Collected outputs keyed by (configuration, "validation" | "test").
results: Dict[Tuple[str, str], object] = {}
file_ending = ".nii.gz"

# Options that are the same for the validation and the test inference run.
shared_inference_options = {
    "fold": cfg.fold,
    "device": cfg.device,
    "checkpoint_name": cfg.checkpoint_name,
    "save_probabilities": cfg.save_probabilities,
    "overwrite": cfg.overwrite,
}

for configuration, model_dir in model_directories.items():
    fold_dir = model_dir / f"fold_{cfg.fold}"
    if not fold_dir.exists():
        raise FileNotFoundError(f"Expected trained fold directory does not exist: {fold_dir}")

    # --- Validation: predict the validation cases and compute metrics ---
    if not cfg.skip_validation_inference and case_splits.get("val"):
        val_inputs = build_inference_input_lists(dataset_dir / "imagesTr", case_splits["val"], file_ending)
        if cfg.prediction_output:
            val_output_dir = cfg.prediction_output / configuration / "val"
        else:
            val_output_dir = fold_dir / "pipeline_val_predictions"
        run_inference(
            model_dir=model_dir,
            inputs=val_inputs,
            output_dir=val_output_dir,
            **shared_inference_options,
        )
        summary_file = val_output_dir / "summary.json"
        summary = compute_validation_metrics(
            predictions_dir=val_output_dir,
            dataset_dir=dataset_dir,
            plans_identifier=plans_identifier,
            output_filename=summary_file,
        )
        results[(configuration, "validation")] = summary

    # --- Test: predict the test cases and optionally export PNGs / prompts ---
    if not cfg.skip_test_inference and case_splits.get("test"):
        test_inputs = build_inference_input_lists(dataset_dir / "imagesTs", case_splits["test"], file_ending)
        if cfg.prediction_output:
            test_output_dir = cfg.prediction_output / configuration / "test"
        else:
            test_output_dir = fold_dir / "pipeline_test_predictions"
        run_inference(
            model_dir=model_dir,
            inputs=test_inputs,
            output_dir=test_output_dir,
            **shared_inference_options,
        )
        if cfg.export_test_pngs:
            png_root = Path(cfg.png_output_root or (cfg.data_root / "test_labels"))
            # With the default root and several configurations, give each
            # configuration its own sub-folder.
            if len(active_configurations) > 1 and cfg.png_output_root is None:
                png_root = png_root / configuration
            export_predictions_to_png(test_output_dir, png_root, case_folder_map)
            if cfg.log_to_stdout:
                print(f"Test PNG segmentations saved to {png_root}")
        if cfg.bounding_box_prompts:
            bbox_target = test_output_dir / "test_bboxes.json"
            if not bbox_target.exists():
                parse_bbox_prompts(cfg.bounding_box_prompts, bbox_target)
        results[(configuration, "test")] = str(test_output_dir)

results