In [1]:
# ============================================================================
# MINIMAL INFERENCE SETUP FOR JUPYTER NOTEBOOK
# ============================================================================

import torch
import torch.nn as nn
import numpy as np
import logging
from pathlib import Path

import monai
from monai import transforms
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDIMScheduler, DDPMScheduler

# Import your custom utilities (adjust paths as needed)
from utils.data import MaskToSDFd, sdf_to_mask
from utils.monai_transforms import (
    HarmonizeLabelsd,
    AddSpacingTensord,
    FilterAndRelabeld,
    EnsureAllTorchd,
    CropForegroundAxisd,
)

from monai.transforms import Transform


class ProbeTransform(Transform):
    def __init__(self, message="ProbeTransform called"):
        super().__init__()
        self.message = message

    def __call__(self, data):
        print(self.message)
        return data


# ============================================================================
# 1. CONFIGURATION
# ============================================================================
class VolumeSpacingEmbedding(nn.Module):
    def __init__(self, embed_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(4, 64),  # 4 = volume (1) + spacing (3)
            nn.SiLU(),
            nn.Linear(64, 128),
            nn.SiLU(),
            nn.Linear(128, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, volume, spacing):
        """
        Args:
            volume: (B,) or (B, 1) - organ volume in ml
            spacing: (B, 3) - [spacing_x, spacing_y, spacing_z] in mm
        Returns:
            embedding: (B, embed_dim)
        """
        if volume.dim() == 1:
            volume = volume.unsqueeze(-1)  # (B, 1)
        # Concatenate volume and spacing
        volume_spacing = torch.cat([volume, spacing], dim=-1)  # (B, 4)
        return self.mlp(volume_spacing)


class InferenceConfig:
    # Model params
    spatial_dims = 3
    in_channels = 1  # image SDF + conditioning
    out_channels = 1  # target organ SDF
    features = [32, 64, 64, 128, 256]  # adjust based on your trained model
    attention_levels = [False, False, False, False, False]
    num_head_channels = [0, 0, 0, 64, 64]
    with_conditioning = True
    cross_attention_dim = 128  # adjust based on your trained model
    volume_embedding_dim = 128

    # Diffusion params
    diffusion_steps = 1000
    ddim_steps = 20
    beta_schedule = "scaled_linear_beta"
    model_mean_type = "sample"  # or "sample"
    guidance_scale = 1.0  # CFG scale
    condition_drop_prob = 0.1

    # Data params
    pixdim = (1.5, 1.5, 2.0)
    orientation = "RAS"
    roi_size = (128, 128, 128)

    # Paths
    checkpoint_path = None
    # checkpoint_path = "/home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.2/DiffUnet-binary-iterative_liver_latest_checkpoint_97.pt"
    device = "cuda:1"


config = InferenceConfig()

# ============================================================================
# 2. ORGAN MAPPING (from your script)
# ============================================================================

ORGAN_NAMES = {
    1: "colon",
    2: "rectum",
    3: "small_bowel",
    4: "stomach",
    5: "liver",
    6: "spleen",
    7: "kidneys",
    9: "pancreas",
    10: "urinary_bladder",
    11: "duodenum",
    12: "gallbladder",
}
NAME_TO_INDEX = {v: k for k, v in ORGAN_NAMES.items()}


def get_conditioning_organs(generation_order, target_organ_index):
    """Get list of organs to condition on"""
    if target_organ_index not in generation_order:
        raise ValueError(f"Target organ {target_organ_index} not in order")
    pos = generation_order.index(target_organ_index)
    return generation_order[:pos]


# ============================================================================
# 3. BUILD PREPROCESSING TRANSFORM
# ============================================================================


def build_inference_transform(config, target_organ="liver", generation_order=None):
    """Simplified transform for single-sample inference"""

    target_organ_index = NAME_TO_INDEX.get(target_organ)
    if generation_order is None:
        generation_order = [5, 6, 7, 9, 3, 1, 2, 4, 10, 11, 12]  # default order

    conditioning_organs = get_conditioning_organs(generation_order, target_organ_index)

    data_keys = ["image", "label", "body_filled_channel"]

    transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=data_keys),
            transforms.EnsureChannelFirstd(keys=data_keys),
            transforms.Spacingd(keys=data_keys, pixdim=config.pixdim, mode="nearest"),
            transforms.Orientationd(keys=data_keys, axcodes=config.orientation),
            ProbeTransform(message="üêî After Orientationd"),
            # transforms.KeepLargestConnectedComponentd(keys=data_keys),
            # ProbeTransform(message="üê∏ After KeepLargestConnectedComponentd"),
            HarmonizeLabelsd(keys=["image", "label"], kidneys_same_index=True),
            CropForegroundAxisd(
                keys=data_keys,
                source_key="image",
                axis=2,
                margin=5,
            ),
            transforms.CropForegroundd(
                keys=data_keys, source_key="body_filled_channel", margin=5
            ),
            ProbeTransform(message="üê¢ After CropForegroundd"),
            transforms.Resized(
                keys=data_keys, spatial_size=config.roi_size, mode="nearest"
            ),
            AddSpacingTensord(ref_key="image"),
            FilterAndRelabeld(
                image_key="image",
                label_key="label",
                conditioning_organs=conditioning_organs,
                target_organ=target_organ_index,
            ),
            ProbeTransform(message="üêç After FilterAndRelabeld"),
            MaskToSDFd(
                keys=data_keys,
                spacing_key="spacing_tensor",
                device=torch.device("cpu"),
            ),
            ProbeTransform(message="üêô After MaskToSDFd"),
            EnsureAllTorchd(print_changes=False),
            transforms.EnsureTyped(
                keys=data_keys + ["spacing_tensor"],
                track_meta=True,
            ),
        ]
    )

    return transform


# ============================================================================
# 4. BUILD MODEL
# ============================================================================


def build_model(config):
    """Create and load model"""

    model = DiffusionModelUNet(
        spatial_dims=config.spatial_dims,
        in_channels=config.in_channels
        + config.out_channels,  # concat image during inference
        out_channels=config.out_channels,
        channels=config.features,
        attention_levels=config.attention_levels,
        num_res_blocks=1,
        transformer_num_layers=0,
        num_head_channels=config.num_head_channels,
        with_conditioning=config.with_conditioning,
        cross_attention_dim=config.cross_attention_dim,
    )

    vol_embed = VolumeSpacingEmbedding(embed_dim=config.volume_embedding_dim)

    # Load checkpoint
    checkpoint = torch.load(config.checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    vol_embed.load_state_dict(checkpoint["vol_embed_net"])

    model = model.to(config.device)
    vol_embed = vol_embed.to(config.device)
    model.eval()
    vol_embed.eval()

    print(f"‚úì Model loaded from {config.checkpoint_path}")
    print(f"‚úì Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    return model, vol_embed


def build_scheduler(config):
    """Create DDIM scheduler for inference"""

    scheduler = DDIMScheduler(
        num_train_timesteps=config.diffusion_steps,
        beta_start=0.0001,
        beta_end=0.02,
        schedule=config.beta_schedule,
        clip_sample=False,
        prediction_type=config.model_mean_type,
    )
    scheduler.set_timesteps(num_inference_steps=config.ddim_steps)

    return scheduler


@torch.no_grad()
def run_inference(
    model, vol_embed, scheduler, image_sdf, body_filled_sdf, volume, spacing, config
):
    """
    Run DDIM sampling to generate organ mask

    Args:
        model: DiffusionModelUNet
        scheduler: DDIMScheduler
        image_sdf: conditioning image SDF [B, 1, H, W, D]
        config: InferenceConfig

    Returns:
        pred_mask: binary mask [B, 1, H, W, D]
        pred_sdf: signed distance field [B, 1, H, W, D]
    """

    device = config.device
    image_sdf = image_sdf.to(device).float()
    pred = torch.randn_like(image_sdf)

    # Initialize with random noise
    if body_filled_sdf is not None:
        body_filled_sdf = body_filled_sdf.to(device).float()
        image = torch.cat([image_sdf, body_filled_sdf], dim=1)
    else:
        image = image_sdf

    vol_context = vol_embed(volume.to(device), spacing.to(device))

    # Get all timesteps
    all_next_timesteps = torch.cat(
        [scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)]
    )

    # DDIM sampling loop
    for i, (t, next_t) in enumerate(zip(scheduler.timesteps, all_next_timesteps)):
        pred_before = pred.clone()
        # Concatenate conditioning
        model_input = torch.cat([pred, image], dim=1)

        # Predict
        # t_tensor = torch.full((image.shape[0],), t, device=device).long()
        # model_output = model(x=model_input, timesteps=t_tensor, context=vol_context)
        model_output = model(
            x=model_input,
            timesteps=torch.Tensor((t,)).to(image.device),
            context=vol_embed,
        )

        # Classifier-free guidance (if scale != 1.0)
        if config.guidance_scale != 1.0:
            image_uncond = torch.zeros_like(image)
            model_input_uncond = torch.cat([pred, image_uncond], dim=1)
            uncond_output = model(
                x=model_input_uncond,
                timesteps=torch.Tensor((t,)).to(image.device),
                context=vol_embed,
            )
            model_output = uncond_output + config.guidance_scale * (
                model_output - uncond_output
            )

        # DDIM step
        pred, _ = scheduler.step(model_output, t, pred)

        change = (pred - pred_before).abs().mean().item()
        print(f"Step {i}: t={t}, change={change:.6f}, pred_mean={pred.mean():.4f}")

        if (i + 1) % 5 == 0:
            print(f"  Step {i+1}/{len(scheduler.timesteps)}")

    # Convert SDF to mask
    pred_sdf = pred.clone()
    pred_mask = sdf_to_mask(pred * 10.0)  # scale factor from your training

    return pred_mask, pred_sdf

In [3]:
test_jsonl_path = "/home/yb107/cvpr2025/DukeDiffSeg/data/mobina_mixed_colon_dataset/mobina_mixed_colon_dataset_with_body_filled_test.jsonl"
import json


def load_jsonl_inference(jsonl_path):
    data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            data.append(json.loads(line))
    return data


test_data = load_jsonl_inference(test_jsonl_path)
test_data = test_data[0]

In [4]:
# 1. Preprocess
config = InferenceConfig()
# config.checkpoint_path = "/home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.1/checkpoints/final_unet.pth"
config.checkpoint_path = "/home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.1/checkpoints/liver/DiffUnet-binary-iterative_liver_latest_checkpoint_2000.pt"
config.in_channels = 2

print("üì¶ Preprocessing data...")
transform = build_inference_transform(
    config, "liver", [5, 12, 6, 7, 4, 9, 11, 10, 2, 1, 3]
)

data_dict = {
    "image": test_data["mask"],
    "label": test_data["mask"],
    "body_filled_channel": test_data["body_filled_mask"],
}
print("üèãÔ∏è‚Äç‚ôÄÔ∏è Applying transforms...")
data_dict = transform(data_dict)

# SAVE data_dict as .pt for caching
# torch.save(data_dict, "tmp/data_dict.pt")

# load data_dict from .pt
data_dict = torch.load("tmp/data_dict.pt", weights_only=False)

üì¶ Preprocessing data...
üèãÔ∏è‚Äç‚ôÄÔ∏è Applying transforms...
üêî After Orientationd
üê¢ After CropForegroundd
üêç After FilterAndRelabeld
üêô After MaskToSDFd


In [5]:
# 2. Build model & scheduler
print("üèóÔ∏è  Building model...")
model, vol_net = build_model(config)
scheduler = build_scheduler(config)

üèóÔ∏è  Building model...
‚úì Model loaded from /home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.1/checkpoints/liver/DiffUnet-binary-iterative_liver_latest_checkpoint_2000.pt
‚úì Parameters: 28.98M


In [18]:
image_sdf = data_dict["image"].unsqueeze(0)  # [1, 1, H, W, D]
label_sdf = data_dict["label"].unsqueeze(0)
label_mask = sdf_to_mask(label_sdf)
body_filled_sdf = data_dict["body_filled_channel"].unsqueeze(0)
zero_filled_sdf = torch.zeros_like(body_filled_sdf)
spacing = data_dict["spacing_tensor"].unsqueeze(0)

voxel_volume = spacing.prod(dim=1)
label_volume = label_mask.sum(dim=(1, 2, 3, 4)) * voxel_volume
label_volume / 1000.0  # convert to ml

metatensor([1685.4402])

In [37]:
print("üé® Running DDIM sampling...")
pred_mask, pred_sdf = run_inference(
    model,
    vol_net,
    scheduler,
    image_sdf,
    zero_filled_sdf,
    label_volume + 5.4e3,
    spacing,
    config,
)

üé® Running DDIM sampling...
Step 0: t=950, change=0.009661, pred_mean=-0.0089
Step 1: t=900, change=0.013250, pred_mean=-0.0210
Step 2: t=850, change=0.017306, pred_mean=-0.0369
Step 3: t=800, change=0.021589, pred_mean=-0.0567
Step 4: t=750, change=0.026396, pred_mean=-0.0802
  Step 5/20
Step 5: t=700, change=0.031898, pred_mean=-0.1071
Step 6: t=650, change=0.038110, pred_mean=-0.1368
Step 7: t=600, change=0.044929, pred_mean=-0.1683
Step 8: t=550, change=0.051992, pred_mean=-0.2004
Step 9: t=500, change=0.058836, pred_mean=-0.2321
  Step 10/20
Step 10: t=450, change=0.064867, pred_mean=-0.2623
Step 11: t=400, change=0.069573, pred_mean=-0.2898
Step 12: t=350, change=0.072536, pred_mean=-0.3142
Step 13: t=300, change=0.073518, pred_mean=-0.3349
Step 14: t=250, change=0.072468, pred_mean=-0.3518
  Step 15/20
Step 15: t=200, change=0.069479, pred_mean=-0.3651
Step 16: t=150, change=0.064911, pred_mean=-0.3752
Step 17: t=100, change=0.059776, pred_mean=-0.3828
Step 18: t=50, change=0.

In [38]:
# Calculate pred_mask original volume
pred_mask_volume = pred_mask.sum(dim=(1, 2, 3, 4)).cpu() * voxel_volume
pred_mask_volume / 1000.0  # convert to ml

metatensor([1619.9487])

In [30]:
monai.transforms.SaveImage(
    output_dir="tmp/",
    output_postfix="_pred_71_smaller",
    separate_folder=False,
)(pred_mask.squeeze(0))
# monai.transforms.SaveImage(
#     output_dir="tmp/",
#     output_postfix="_pred_sdf",
#     separate_folder=False,
# )(pred_sdf.squeeze(0))
# monai.transforms.SaveImage(
#     output_dir="tmp/",
#     output_postfix="_img_sdf",
#     separate_folder=False,
# )(label_sdf.squeeze(0))
# monai.transforms.SaveImage(
#     output_dir="tmp/",
#     output_postfix="_body_filled_sdf",
#     separate_folder=False,
# )(body_filled_sdf.squeeze(0))

2025-11-07 14:12:53,489 INFO image_writer.py:197 - writing: tmp/Patient_00074_Study_78614_Series_03__pred_71_smaller.nii.gz


metatensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 