In [2]:
import torch
import torch.nn as nn
from typing import Sequence
import math
from monai.networks.nets import DiffusionModelUNet


class FiLMLayer(nn.Module):
    """Applies Feature-wise Linear Modulation: out = gamma * x + beta"""

    def forward(
        self, x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: (B, C, D, H, W) feature maps
            gamma: (B, C) scale parameters
            beta: (B, C) shift parameters
        Returns:
            modulated features (B, C, D, H, W)
        """
        # Reshape for broadcasting: (B, C) -> (B, C, 1, 1, 1)
        gamma = gamma.view(gamma.shape[0], gamma.shape[1], 1, 1, 1)
        beta = beta.view(beta.shape[0], beta.shape[1], 1, 1, 1)

        return gamma * x + beta


class FiLMAdapter(nn.Module):
    """
    Generates FiLM parameters (gamma, beta) from volume and spacing information.
    Outputs modulation parameters for each U-Net resolution level.
    """

    def __init__(
        self,
        unet_channels: Sequence[int] = (32, 64, 64, 64),
        embed_dim: int = 256,
        volume_mean: float = 150.0,
        volume_std: float = 100.0,
        use_log_volume: bool = False,
    ):
        """
        Args:
            unet_channels: Channel dimensions at each U-Net level (should match your U-Net)
            embed_dim: Dimension of the intermediate embedding
            volume_mean: Mean volume for normalization (compute from your dataset)
            volume_std: Std volume for normalization (compute from your dataset)
            use_log_volume: If True, use log normalization instead of standard normalization
        """
        super().__init__()
        self.unet_channels = list(unet_channels)
        self.volume_mean = volume_mean
        self.volume_std = volume_std
        self.use_log_volume = use_log_volume

        # Base embedding network for volume + spacing
        self.embedding_net = nn.Sequential(
            nn.Linear(4, 64),  # 4 = volume (1) + spacing (3)
            nn.SiLU(),
            nn.Linear(64, 128),
            nn.SiLU(),
            nn.Linear(128, embed_dim),
            nn.SiLU(),
        )

        # FiLM parameter generators for each U-Net level
        # Each generates both gamma and beta (2 * channels)
        self.film_generators = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(embed_dim, embed_dim),
                    nn.SiLU(),
                    nn.Linear(embed_dim, 2 * channels),
                )
                for channels in self.unet_channels
            ]
        )

        # Learnable scale factors for each level (helps with initialization)
        self.gamma_scales = nn.ParameterList(
            [
                nn.Parameter(torch.ones(channels) * 0.1)
                for channels in self.unet_channels
            ]
        )
        self.beta_scales = nn.ParameterList(
            [
                nn.Parameter(torch.ones(channels) * 0.1)
                for channels in self.unet_channels
            ]
        )

    def normalize_volume(self, volume: torch.Tensor) -> torch.Tensor:
        """Normalize volume to a reasonable range"""
        if self.use_log_volume:
            # Log normalization (better for volumes with large variance)
            volume = torch.log(volume + 1.0)
            log_mean = torch.log(torch.tensor(self.volume_mean + 1.0))
            log_std = torch.log(torch.tensor(self.volume_std + 1.0))
            return (volume - log_mean) / log_std
        else:
            # Standard normalization
            return (volume - self.volume_mean) / self.volume_std

    def forward(
        self, volume: torch.Tensor, spacing: torch.Tensor
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            volume: (B,) or (B, 1) - organ volumes in ml
            spacing: (B, 3) - [spacing_x, spacing_y, spacing_z] in mm

        Returns:
            List of (gamma, beta) tuples, one for each U-Net level
            Each gamma and beta has shape (B, C) where C is channels at that level
        """
        # Normalize volume
        if volume.dim() == 1:
            volume = volume.unsqueeze(-1)  # (B, 1)

        volume_normalized = self.normalize_volume(volume)

        # Concatenate volume and spacing
        volume_spacing = torch.cat([volume_normalized, spacing], dim=-1)  # (B, 4)

        # Generate base embedding
        base_embed = self.embedding_net(volume_spacing)  # (B, embed_dim)

        # Generate FiLM parameters for each level
        film_params = []
        for i, (film_gen, gamma_scale, beta_scale) in enumerate(
            zip(self.film_generators, self.gamma_scales, self.beta_scales)
        ):
            params = film_gen(base_embed)  # (B, 2*C)
            gamma_raw, beta_raw = params.chunk(2, dim=-1)  # Each (B, C)

            # Scale and shift to initialize near identity transform
            # gamma starts near 1.0, beta starts near 0.0
            gamma = gamma_raw * gamma_scale + 1.0
            beta = beta_raw * beta_scale

            film_params.append((gamma, beta))

        return film_params


class DiffusionModelUNetFiLM(DiffusionModelUNet):
    """
    DiffusionModelUNet with FiLM (Feature-wise Linear Modulation) conditioning.
    Applies FiLM after each down block, middle block, and up block.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # FiLM layers for each resolution level
        # down_blocks + middle + up_blocks
        num_levels = len(self.block_out_channels)

        # FiLM for down blocks
        self.film_down = nn.ModuleList([FiLMLayer() for _ in range(num_levels)])

        # FiLM for middle block
        self.film_mid = FiLMLayer()

        # FiLM for up blocks
        self.film_up = nn.ModuleList([FiLMLayer() for _ in range(num_levels)])

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        film_params: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
        down_block_additional_residuals: tuple[torch.Tensor] | None = None,
        mid_block_additional_residual: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: input tensor (N, C, SpatialDims)
            timesteps: timestep tensor (N,)
            film_params: List of (gamma, beta) tuples for FiLM conditioning.
                        Should have length = len(channels) for down blocks.
                        If None, no FiLM conditioning is applied.
            context: context tensor for cross-attention (N, 1, ContextDim)
            class_labels: class labels (N,)
            down_block_additional_residuals: additional residuals for controlnet
            mid_block_additional_residual: additional residual for controlnet
        """
        # 1. Time embedding
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        # 2. Class embedding
        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError(
                    "class_labels should be provided when num_class_embeds > 0"
                )
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb = emb + class_emb

        # 3. Initial convolution
        h = self.conv_in(x)

        # 4. Down blocks with FiLM
        if context is not None and self.with_conditioning is False:
            raise ValueError(
                "model should have with_conditioning = True if context is provided"
            )

        down_block_res_samples: list[torch.Tensor] = [h]
        for i, downsample_block in enumerate(self.down_blocks):
            h, res_samples = downsample_block(
                hidden_states=h, temb=emb, context=context
            )

            # Apply FiLM modulation after the down block
            if film_params is not None and i < len(film_params):
                gamma, beta = film_params[i]
                h = self.film_down[i](h, gamma, beta)

                # Also apply FiLM to residual connections
                res_samples = [
                    self.film_down[i](res, gamma, beta) for res in res_samples
                ]

            for residual in res_samples:
                down_block_res_samples.append(residual)

        # Additional residuals for ControlNet
        if down_block_additional_residuals is not None:
            new_down_block_res_samples: list[torch.Tensor] = []
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = (
                    down_block_res_sample + down_block_additional_residual
                )
                new_down_block_res_samples.append(down_block_res_sample)
            down_block_res_samples = new_down_block_res_samples

        # 5. Middle block with FiLM
        h = self.middle_block(hidden_states=h, temb=emb, context=context)

        if film_params is not None:
            # Use the last down block's FiLM params for middle block
            gamma, beta = film_params[-1]
            h = self.film_mid(h, gamma, beta)

        # Additional residual for ControlNet
        if mid_block_additional_residual is not None:
            h = h + mid_block_additional_residual

        # 6. Up blocks with FiLM
        for i, upsample_block in enumerate(self.up_blocks):
            idx: int = -len(upsample_block.resnets)  # type: ignore
            res_samples = down_block_res_samples[idx:]
            down_block_res_samples = down_block_res_samples[:idx]

            h = upsample_block(
                hidden_states=h,
                res_hidden_states_list=res_samples,
                temb=emb,
                context=context,
            )

            # Apply FiLM modulation after up block
            # Use corresponding down block's parameters (mirrored)
            if film_params is not None:
                film_idx = len(self.up_blocks) - 1 - i  # Mirror the indices
                if film_idx < len(film_params):
                    gamma, beta = film_params[film_idx]
                    h = self.film_up[i](h, gamma, beta)

        # 7. Output block
        output: torch.Tensor = self.out(h)

        return output


def get_timestep_embedding(
    timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic
    Models" https://arxiv.org/abs/2006.11239.

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
        embedding_dim: the dimension of the output.
        max_period: controls the minimum frequency of the embeddings.
    """
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    freqs = torch.exp(exponent / half_dim)

    args = timesteps[:, None].float() * freqs[None, :]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))

    return embedding


class FiLMAdapter(nn.Module):
    """
    Generates FiLM parameters (gamma, beta) from volume and spacing information.
    Outputs modulation parameters for each U-Net resolution level.
    """

    def __init__(
        self,
        unet_channels: Sequence[int] = (32, 64, 64, 64),
        embed_dim: int = 256,
        volume_mean: float = 150.0,
        volume_std: float = 100.0,
        use_log_volume: bool = False,
    ):
        """
        Args:
            unet_channels: Channel dimensions at each U-Net level (should match your U-Net)
            embed_dim: Dimension of the intermediate embedding
            volume_mean: Mean volume for normalization (compute from your dataset)
            volume_std: Std volume for normalization (compute from your dataset)
            use_log_volume: If True, use log normalization instead of standard normalization
        """
        super().__init__()
        self.unet_channels = list(unet_channels)
        self.volume_mean = volume_mean
        self.volume_std = volume_std
        self.use_log_volume = use_log_volume

        # Base embedding network for volume + spacing
        self.embedding_net = nn.Sequential(
            nn.Linear(4, 64),  # 4 = volume (1) + spacing (3)
            nn.SiLU(),
            nn.Linear(64, 128),
            nn.SiLU(),
            nn.Linear(128, embed_dim),
            nn.SiLU(),
        )

        # FiLM parameter generators for each U-Net level
        # Each generates both gamma and beta (2 * channels)
        self.film_generators = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(embed_dim, embed_dim),
                    nn.SiLU(),
                    nn.Linear(embed_dim, 2 * channels),
                )
                for channels in self.unet_channels
            ]
        )

        # Learnable scale factors for each level (helps with initialization)
        self.gamma_scales = nn.ParameterList(
            [
                nn.Parameter(torch.ones(channels) * 0.1)
                for channels in self.unet_channels
            ]
        )
        self.beta_scales = nn.ParameterList(
            [
                nn.Parameter(torch.ones(channels) * 0.1)
                for channels in self.unet_channels
            ]
        )

    def normalize_volume(self, volume: torch.Tensor) -> torch.Tensor:
        """Normalize volume to a reasonable range"""
        if self.use_log_volume:
            # Log normalization (better for volumes with large variance)
            volume = torch.log(volume + 1.0)
            log_mean = torch.log(torch.tensor(self.volume_mean + 1.0))
            log_std = torch.log(torch.tensor(self.volume_std + 1.0))
            return (volume - log_mean) / log_std
        else:
            # Standard normalization
            return (volume - self.volume_mean) / self.volume_std

    def forward(
        self, volume: torch.Tensor, spacing: torch.Tensor
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            volume: (B,) or (B, 1) - organ volumes in ml
            spacing: (B, 3) - [spacing_x, spacing_y, spacing_z] in mm

        Returns:
            List of (gamma, beta) tuples, one for each U-Net level
            Each gamma and beta has shape (B, C) where C is channels at that level
        """
        # Normalize volume
        if volume.dim() == 1:
            volume = volume.unsqueeze(-1)  # (B, 1)

        volume_normalized = self.normalize_volume(volume)

        # Concatenate volume and spacing
        volume_spacing = torch.cat([volume_normalized, spacing], dim=-1)  # (B, 4)

        # Generate base embedding
        base_embed = self.embedding_net(volume_spacing)  # (B, embed_dim)

        # Generate FiLM parameters for each level
        film_params = []
        for i, (film_gen, gamma_scale, beta_scale) in enumerate(
            zip(self.film_generators, self.gamma_scales, self.beta_scales)
        ):
            params = film_gen(base_embed)  # (B, 2*C)
            gamma_raw, beta_raw = params.chunk(2, dim=-1)  # Each (B, C)

            # Scale and shift to initialize near identity transform
            # gamma starts near 1.0, beta starts near 0.0
            gamma = gamma_raw * gamma_scale + 1.0
            beta = beta_raw * beta_scale

            film_params.append((gamma, beta))

        return film_params

In [3]:
# ============================================================================
# MINIMAL INFERENCE SETUP FOR JUPYTER NOTEBOOK
# ============================================================================

import torch
import torch.nn as nn
import numpy as np
import logging
from pathlib import Path

import monai
from monai import transforms
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDIMScheduler, DDPMScheduler

# Import your custom utilities (adjust paths as needed)
from utils.data import MaskToSDFd, sdf_to_mask
from utils.monai_transforms import (
    HarmonizeLabelsd,
    AddSpacingTensord,
    FilterAndRelabeld,
    EnsureAllTorchd,
    CropForegroundAxisd,
)

from monai.transforms import Transform


class ProbeTransform(Transform):
    def __init__(self, message="ProbeTransform called"):
        super().__init__()
        self.message = message

    def __call__(self, data):
        print(self.message)
        return data


# ============================================================================
# 1. CONFIGURATION
# ============================================================================
class VolumeSpacingEmbedding(nn.Module):
    def __init__(self, embed_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(4, 64),  # 4 = volume (1) + spacing (3)
            nn.SiLU(),
            nn.Linear(64, 128),
            nn.SiLU(),
            nn.Linear(128, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, volume, spacing):
        """
        Args:
            volume: (B,) or (B, 1) - organ volume in ml
            spacing: (B, 3) - [spacing_x, spacing_y, spacing_z] in mm
        Returns:
            embedding: (B, embed_dim)
        """
        if volume.dim() == 1:
            volume = volume.unsqueeze(-1)  # (B, 1)
        # Concatenate volume and spacing
        volume_spacing = torch.cat([volume, spacing], dim=-1)  # (B, 4)
        return self.mlp(volume_spacing)


class InferenceConfig:
    # Model params
    spatial_dims = 3
    in_channels = 1  # image SDF + conditioning
    out_channels = 1  # target organ SDF
    features = [32, 64, 64, 128, 256]  # adjust based on your trained model
    attention_levels = [False, False, False, False, False]
    num_head_channels = [0, 0, 0, 64, 64]
    with_conditioning = True
    cross_attention_dim = 256  # adjust based on your trained model
    volume_embedding_dim = 128

    # Diffusion params
    diffusion_steps = 1000
    ddim_steps = 20
    beta_schedule = "scaled_linear_beta"
    model_mean_type = "sample"  # or "sample"
    guidance_scale = 1.0  # CFG scale
    condition_drop_prob = 0.1

    # Data params
    pixdim = (1.5, 1.5, 2.0)
    orientation = "RAS"
    roi_size = (128, 128, 128)

    # Paths
    checkpoint_path = None
    # checkpoint_path = "/home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.2/DiffUnet-binary-iterative_liver_latest_checkpoint_97.pt"
    device = "cuda:1"


config = InferenceConfig()

# ============================================================================
# 2. ORGAN MAPPING (from your script)
# ============================================================================

ORGAN_NAMES = {
    1: "colon",
    2: "rectum",
    3: "small_bowel",
    4: "stomach",
    5: "liver",
    6: "spleen",
    7: "kidneys",
    9: "pancreas",
    10: "urinary_bladder",
    11: "duodenum",
    12: "gallbladder",
}
NAME_TO_INDEX = {v: k for k, v in ORGAN_NAMES.items()}


def get_conditioning_organs(generation_order, target_organ_index):
    """Get list of organs to condition on"""
    if target_organ_index not in generation_order:
        raise ValueError(f"Target organ {target_organ_index} not in order")
    pos = generation_order.index(target_organ_index)
    return generation_order[:pos]


# ============================================================================
# 3. BUILD PREPROCESSING TRANSFORM
# ============================================================================


def build_inference_transform(config, target_organ="liver", generation_order=None):
    """Simplified transform for single-sample inference"""

    target_organ_index = NAME_TO_INDEX.get(target_organ)
    if generation_order is None:
        generation_order = [5, 6, 7, 9, 3, 1, 2, 4, 10, 11, 12]  # default order

    conditioning_organs = get_conditioning_organs(generation_order, target_organ_index)

    data_keys = ["image", "label", "body_filled_channel"]

    transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=data_keys),
            transforms.EnsureChannelFirstd(keys=data_keys),
            transforms.Spacingd(keys=data_keys, pixdim=config.pixdim, mode="nearest"),
            transforms.Orientationd(keys=data_keys, axcodes=config.orientation),
            ProbeTransform(message="üêî After Orientationd"),
            # transforms.KeepLargestConnectedComponentd(keys=data_keys),
            # ProbeTransform(message="üê∏ After KeepLargestConnectedComponentd"),
            HarmonizeLabelsd(keys=["image", "label"], kidneys_same_index=True),
            CropForegroundAxisd(
                keys=data_keys,
                source_key="image",
                axis=2,
                margin=5,
            ),
            transforms.CropForegroundd(
                keys=data_keys, source_key="body_filled_channel", margin=5
            ),
            ProbeTransform(message="üê¢ After CropForegroundd"),
            transforms.Resized(
                keys=data_keys, spatial_size=config.roi_size, mode="nearest"
            ),
            AddSpacingTensord(ref_key="image"),
            FilterAndRelabeld(
                image_key="image",
                label_key="label",
                conditioning_organs=conditioning_organs,
                target_organ=target_organ_index,
            ),
            ProbeTransform(message="üêç After FilterAndRelabeld"),
            MaskToSDFd(
                keys=data_keys,
                spacing_key="spacing_tensor",
                device=torch.device("cpu"),
            ),
            ProbeTransform(message="üêô After MaskToSDFd"),
            EnsureAllTorchd(print_changes=False),
            transforms.EnsureTyped(
                keys=data_keys + ["spacing_tensor"],
                track_meta=True,
            ),
        ]
    )

    return transform


# ============================================================================
# 4. BUILD MODEL
# ============================================================================


def build_model(config):
    """Create and load model"""

    model = DiffusionModelUNetFiLM(
        spatial_dims=config.spatial_dims,
        in_channels=config.in_channels
        + config.out_channels,  # concat image during inference
        out_channels=config.out_channels,
        channels=config.features,
        attention_levels=config.attention_levels,
        num_res_blocks=1,
        transformer_num_layers=0,
        num_head_channels=config.num_head_channels,
        with_conditioning=False,
        cross_attention_dim=None,
    )
    vol_embed = FiLMAdapter(
        unet_channels=config.features,
        embed_dim=256,
        volume_mean=1625.45,
        volume_std=535.51,
        use_log_volume=True,
    )

    # vol_embed = VolumeSpacingEmbedding(embed_dim=config.volume_embedding_dim)

    # Load checkpoint
    checkpoint = torch.load(config.checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    vol_embed.load_state_dict(checkpoint["vol_embed_net"])

    model = model.to(config.device)
    vol_embed = vol_embed.to(config.device)
    model.eval()
    vol_embed.eval()

    print(f"‚úì Model loaded from {config.checkpoint_path}")
    print(f"‚úì Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    return model, vol_embed


def build_scheduler(config):
    """Create DDIM scheduler for inference"""

    scheduler = DDIMScheduler(
        num_train_timesteps=config.diffusion_steps,
        beta_start=0.0001,
        beta_end=0.02,
        schedule=config.beta_schedule,
        clip_sample=False,
        prediction_type=config.model_mean_type,
    )
    scheduler.set_timesteps(num_inference_steps=config.ddim_steps)

    return scheduler


@torch.no_grad()
def run_inference(
    model, vol_embed, scheduler, image_sdf, body_filled_sdf, volume, spacing, config
):
    """
    Run DDIM sampling to generate organ mask

    Args:
        model: DiffusionModelUNet
        scheduler: DDIMScheduler
        image_sdf: conditioning image SDF [B, 1, H, W, D]
        config: InferenceConfig

    Returns:
        pred_mask: binary mask [B, 1, H, W, D]
        pred_sdf: signed distance field [B, 1, H, W, D]
    """

    device = config.device
    image_sdf = image_sdf.to(device).float()
    pred = torch.randn_like(image_sdf)

    # Initialize with random noise
    if body_filled_sdf is not None:
        body_filled_sdf = body_filled_sdf.to(device).float()
        image = torch.cat([image_sdf, body_filled_sdf], dim=1)
    else:
        image = image_sdf

    # vol_context = vol_embed(volume.to(device), spacing.to(device))
    flim_params = vol_embed(volume.to(device), spacing.to(device))

    # Get all timesteps
    all_next_timesteps = torch.cat(
        [scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)]
    )

    # DDIM sampling loop
    for i, (t, next_t) in enumerate(zip(scheduler.timesteps, all_next_timesteps)):
        pred_before = pred.clone()
        # Concatenate conditioning
        model_input = torch.cat([pred, image], dim=1)

        # Predict
        # t_tensor = torch.full((image.shape[0],), t, device=device).long()
        # model_output = model(x=model_input, timesteps=t_tensor, context=vol_context)
        model_output = model(
            x=model_input,
            timesteps=torch.Tensor((t,)).to(image.device),
            context=None,
            film_params=flim_params,
        )

        print(f"Model output mean at step {i}: {model_output.mean().item():.4f}")

        # Classifier-free guidance (if scale != 1.0)
        if config.guidance_scale != 1.0:
            image_uncond = torch.zeros_like(image)
            model_input_uncond = torch.cat([pred, image_uncond], dim=1)
            uncond_output = model(
                x=model_input_uncond,
                timesteps=torch.Tensor((t,)).to(image.device),
                context=vol_embed,
            )
            model_output = uncond_output + config.guidance_scale * (
                model_output - uncond_output
            )

        # DDIM step
        pred, _ = scheduler.step(model_output, t, pred)

        change = (pred - pred_before).abs().mean().item()
        print(f"Step {i}: t={t}, change={change:.6f}, pred_mean={pred.mean():.4f}")

        if (i + 1) % 5 == 0:
            print(f"  Step {i+1}/{len(scheduler.timesteps)}")

    # Convert SDF to mask
    pred_sdf = pred.clone()
    pred_mask = sdf_to_mask(pred * 10.0)  # scale factor from your training

    return pred_mask, pred_sdf

In [4]:
test_jsonl_path = "/home/yb107/cvpr2025/DukeDiffSeg/data/mobina_mixed_colon_dataset/mobina_mixed_colon_dataset_with_body_filled_test.jsonl"
import json


def load_jsonl_inference(jsonl_path):
    data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            data.append(json.loads(line))
    return data


test_data = load_jsonl_inference(test_jsonl_path)
test_data = test_data[0]

In [57]:
# 1. Preprocess
config = InferenceConfig()
# config.checkpoint_path = "/home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.1/checkpoints/final_unet.pth"
config.checkpoint_path = "/home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.5/checkpoints/liver/DiffUnet-binary-iterative_liver_latest_checkpoint_65.pt"
config.in_channels = 1

print("üì¶ Preprocessing data...")
transform = build_inference_transform(
    config, "liver", [5, 12, 6, 7, 4, 9, 11, 10, 2, 1, 3]
)

data_dict = {
    "image": test_data["mask"],
    "label": test_data["mask"],
    "body_filled_channel": test_data["body_filled_mask"],
}
print("üèãÔ∏è‚Äç‚ôÄÔ∏è Applying transforms...")
# data_dict = transform(data_dict)

# SAVE data_dict as .pt for caching
# torch.save(data_dict, "tmp/data_dict.pt")

# load data_dict from .pt
data_dict = torch.load("tmp/data_dict.pt", weights_only=False)

üì¶ Preprocessing data...
üèãÔ∏è‚Äç‚ôÄÔ∏è Applying transforms...




In [60]:
# 2. Build model & scheduler
print("üèóÔ∏è  Building model...")
model, vol_net = build_model(config)
scheduler = build_scheduler(config)

üèóÔ∏è  Building model...
‚úì Model loaded from /home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-iterative/7.5/checkpoints/liver/DiffUnet-binary-iterative_liver_latest_checkpoint_65.pt
‚úì Parameters: 29.11M


In [61]:
image_sdf = data_dict["image"].unsqueeze(0)  # [1, 1, H, W, D]
label_sdf = data_dict["label"].unsqueeze(0)
label_mask = sdf_to_mask(label_sdf)
body_filled_sdf = data_dict["body_filled_channel"].unsqueeze(0)
zero_filled_sdf = torch.zeros_like(body_filled_sdf)
spacing = data_dict["spacing_tensor"].unsqueeze(0)

voxel_volume = spacing.prod(dim=1)
label_volume = label_mask.sum(dim=(1, 2, 3, 4)) * voxel_volume
label_volume = label_volume / 1000.0  # convert to ml
label_volume

metatensor([1685.4402])

In [67]:
print("üé® Running DDIM sampling...")
pred_mask, pred_sdf = run_inference(
    model,
    vol_net,
    scheduler,
    image_sdf,
    None,
    label_volume,
    spacing,
    config,
)

üé® Running DDIM sampling...
Model output mean at step 0: -0.5166
Step 0: t=950, change=0.012404, pred_mean=-0.0127
Model output mean at step 1: -0.5168
Step 1: t=900, change=0.016991, pred_mean=-0.0296
Model output mean at step 2: -0.5169
Step 2: t=850, change=0.022214, pred_mean=-0.0517
Model output mean at step 3: -0.5166
Step 3: t=800, change=0.027751, pred_mean=-0.0792
Model output mean at step 4: -0.5162
Step 4: t=750, change=0.033291, pred_mean=-0.1121
  Step 5/20
Model output mean at step 5: -0.5159
Step 5: t=700, change=0.038782, pred_mean=-0.1497
Model output mean at step 6: -0.5156
Step 6: t=650, change=0.044451, pred_mean=-0.1910
Model output mean at step 7: -0.5156
Step 7: t=600, change=0.050418, pred_mean=-0.2346
Model output mean at step 8: -0.5155
Step 8: t=550, change=0.056538, pred_mean=-0.2789
Model output mean at step 9: -0.5155
Step 9: t=500, change=0.062423, pred_mean=-0.3222
  Step 10/20
Model output mean at step 10: -0.5156
Step 10: t=450, change=0.067571, pred

In [68]:
# Calculate pred_mask original volume
pred_mask_volume = pred_mask.sum(dim=(1, 2, 3, 4)).cpu() * voxel_volume
pred_mask_volume / 1000.0  # convert to ml

# pred_sdf_logits = torch.sigmoid(pred_sdf * 10)  # scale factor from your training
# pred_sdf_logits_volume = (pred_sdf_logits).sum(dim=(1, 2, 3, 4)).cpu() * voxel_volume
# pred_sdf_logits_volume / 1000.0  # convert to ml

# pred_sdf_logits.mean(), pred_sdf_logits.std(), pred_sdf_logits.min(), pred_sdf_logits.max()

metatensor([861.6476])

In [56]:
monai.transforms.SaveImage(
    output_dir="tmp/",
    output_postfix="_pred_75",
    separate_folder=False,
    # )(pred_mask.squeeze(0))
)(pred_sdf_logits.squeeze(0))
# monai.transforms.SaveImage(
#     output_dir="tmp/",
#     output_postfix="_pred_sdf",
#     separate_folder=False,
# )(pred_sdf.squeeze(0))
# monai.transforms.SaveImage(
#     output_dir="tmp/",
#     output_postfix="_img_sdf",
#     separate_folder=False,
# )(label_sdf.squeeze(0))
# monai.transforms.SaveImage(
#     output_dir="tmp/",
#     output_postfix="_body_filled_sdf",
#     separate_folder=False,
# )(body_filled_sdf.squeeze(0))

2025-11-08 21:24:47,350 INFO image_writer.py:197 - writing: tmp/Patient_00074_Study_78614_Series_03__pred_75.nii.gz


metatensor([[[[0.1673, 0.0775, 0.0510,  ..., 0.0462, 0.0639, 0.1503],
          [0.0795, 0.0242, 0.0133,  ..., 0.0121, 0.0191, 0.0709],
          [0.0564, 0.0151, 0.0079,  ..., 0.0074, 0.0122, 0.0528],
          ...,
          [0.0508, 0.0127, 0.0065,  ..., 0.0067, 0.0113, 0.0496],
          [0.0716, 0.0211, 0.0120,  ..., 0.0126, 0.0197, 0.0710],
          [0.1565, 0.0726, 0.0508,  ..., 0.0529, 0.0719, 0.1590]],

         [[0.0782, 0.0226, 0.0120,  ..., 0.0097, 0.0156, 0.0627],
          [0.0243, 0.0041, 0.0018,  ..., 0.0014, 0.0026, 0.0186],
          [0.0150, 0.0023, 0.0010,  ..., 0.0008, 0.0014, 0.0120],
          ...,
          [0.0123, 0.0017, 0.0007,  ..., 0.0006, 0.0012, 0.0105],
          [0.0202, 0.0034, 0.0016,  ..., 0.0015, 0.0026, 0.0181],
          [0.0701, 0.0209, 0.0123,  ..., 0.0123, 0.0189, 0.0685]],

         [[0.0551, 0.0134, 0.0066,  ..., 0.0051, 0.0085, 0.0420],
          [0.0149, 0.0022, 0.0009,  ..., 0.0007, 0.0012, 0.0106],
          [0.0091, 0.0013, 0.0005,  ..