In [9]:
# Import your custom utilities (adjust paths as needed)
from utils.data import MaskToSDFd, sdf_to_mask
from utils.monai_transforms import (
    HarmonizeLabelsd,
    AddSpacingTensord,
    FilterAndRelabeld,
    EnsureAllTorchd,
    CropForegroundAxisd,
)
import monai

from monai.transforms import Transform


class ProbeTransform(Transform):
    def __init__(self, message="ProbeTransform called"):
        super().__init__()
        self.message = message

    def __call__(self, data):
        print(self.message)
        return data


import torch
from monai import transforms

ORGAN_NAMES = {
    1: "colon",
    2: "rectum",
    3: "small_bowel",
    4: "stomach",
    5: "liver",
    6: "spleen",
    7: "kidneys",
    9: "pancreas",
    10: "urinary_bladder",
    11: "duodenum",
    12: "gallbladder",
}
NAME_TO_INDEX = {v: k for k, v in ORGAN_NAMES.items()}


def get_conditioning_organs(generation_order, target_organ_index):
    """Get list of organs to condition on"""
    if target_organ_index not in generation_order:
        raise ValueError(f"Target organ {target_organ_index} not in order")
    pos = generation_order.index(target_organ_index)
    return generation_order[:pos]


# ============================================================================
# 3. BUILD PREPROCESSING TRANSFORM
# ============================================================================


def build_inference_transform(config, target_organ="liver", generation_order=None):
    """Simplified transform for single-sample inference"""

    target_organ_index = NAME_TO_INDEX.get(target_organ)
    if generation_order is None:
        generation_order = [5, 6, 7, 9, 3, 1, 2, 4, 10, 11, 12]  # default order

    conditioning_organs = get_conditioning_organs(generation_order, target_organ_index)

    data_keys = ["image", "label", "body_filled_channel"]

    transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=data_keys),
            transforms.EnsureChannelFirstd(keys=data_keys),
            transforms.Spacingd(keys=data_keys, pixdim=config.pixdim, mode="nearest"),
            transforms.Orientationd(keys=data_keys, axcodes=config.orientation),
            ProbeTransform(message="üêî After Orientationd"),
            # transforms.KeepLargestConnectedComponentd(keys=data_keys),
            # ProbeTransform(message="üê∏ After KeepLargestConnectedComponentd"),
            HarmonizeLabelsd(
                keys=["image", "label"],
                kidneys_same_index=True,
                split_colon=True,
                split_colon_method="skeleton",
            ),
            # CropForegroundAxisd(
            #     keys=data_keys,
            #     source_key="image",
            #     axis=2,
            #     margin=5,
            # ),
            # transforms.CropForegroundd(
            #     keys=data_keys, source_key="body_filled_channel", margin=5
            # ),
            # ProbeTransform(message="üê¢ After CropForegroundd"),
            # transforms.Resized(
            #     keys=data_keys, spatial_size=config.roi_size, mode="nearest"
            # ),
            # AddSpacingTensord(ref_key="image"),
            # FilterAndRelabeld(
            #     image_key="image",
            #     label_key="label",
            #     conditioning_organs=conditioning_organs,
            #     target_organ=target_organ_index,
            # ),
            # ProbeTransform(message="üêç After FilterAndRelabeld"),
            # MaskToSDFd(
            #     keys=data_keys,
            #     spacing_key="spacing_tensor",
            #     device=torch.device("cpu"),
            # ),
            # ProbeTransform(message="üêô After MaskToSDFd"),
            # EnsureAllTorchd(print_changes=False),
            # transforms.EnsureTyped(
            #     keys=data_keys + ["spacing_tensor"],
            #     track_meta=True,
            # ),
        ]
    )

    return transform

In [2]:
class InferenceConfig:
    # Model params
    spatial_dims = 3
    in_channels = 1  # image SDF + conditioning
    out_channels = 1  # target organ SDF
    features = [32, 64, 64, 128, 256]  # adjust based on your trained model
    attention_levels = [False, False, False, False, False]
    num_head_channels = [0, 0, 0, 64, 64]
    with_conditioning = True
    cross_attention_dim = 128  # adjust based on your trained model
    volume_embedding_dim = 128

    # Diffusion params
    diffusion_steps = 1000
    ddim_steps = 20
    beta_schedule = "scaled_linear_beta"
    model_mean_type = "sample"  # or "sample"
    guidance_scale = 1.0  # CFG scale
    condition_drop_prob = 0.1

    # Data params
    pixdim = (1.5, 1.5, 2.0)
    orientation = "RAS"
    roi_size = (128, 128, 128)

    # Paths
    checkpoint_path = None
    # checkpoint_path = "/home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.2/DiffUnet-binary-iterative_liver_latest_checkpoint_97.pt"
    device = "cuda:1"

In [3]:
test_jsonl_path = "/home/yb107/cvpr2025/DukeDiffSeg/data/mobina_mixed_colon_dataset/mobina_mixed_colon_dataset_with_body_filled_test.jsonl"
import json


def load_jsonl_inference(jsonl_path):
    data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            data.append(json.loads(line))
    return data


test_data = load_jsonl_inference(test_jsonl_path)
test_data = test_data[0]

In [8]:
config = InferenceConfig()

print("üì¶ Preprocessing data...")
transform = build_inference_transform(
    config, "liver", [5, 12, 6, 7, 4, 9, 11, 10, 2, 1, 3]
)

data_dict = {
    "image": test_data["mask"],
    "label": test_data["mask"],
    "body_filled_channel": test_data["body_filled_mask"],
}
print("üèãÔ∏è‚Äç‚ôÄÔ∏è Applying transforms...")
data_dict = transform(data_dict)

üì¶ Preprocessing data...
üèãÔ∏è‚Äç‚ôÄÔ∏è Applying transforms...




üêî After Orientationd
üê¢ After CropForegroundd


In [None]:
colon_101 = data_dict["label"].clone()
colon_101[colon_101 == 101] = 0
colon_102 = data_dict["label"].clone()
colon_102[colon_102 == 102] = 0
colon_103 = data_dict["label"].clone()
colon_103[colon_103 == 103] = 0

In [7]:
label = data_dict["label"]
monai.transforms.SaveImage(
    output_dir="tmp/",
    output_postfix="_colon_101",
    separate_folder=False,
)(colon_101)

2025-11-07 14:58:59,841 INFO image_writer.py:197 - writing: tmp/Patient_00074_Study_78614_Series_03__colon_101.nii.gz


metatensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 