In [1]:
import monai
import json, os
import numpy as np

import torch

from monai import transforms

jsonl_path = "/home/yb107/cvpr2025/DukeDiffSeg/data/mobina_mixed_colon_dataset/mobina_mixed_colon_dataset_with_body_filled.jsonl_train.jsonl"
# Load the JSONL file and create a list of dictionaries
with open(jsonl_path, "r") as f:
    files = [json.loads(line) for line in f]
print(f"Number of entries in JSONL: {len(files)}")
file = files[0]
file

Number of entries in JSONL: 356


{'mask': '/data/usr/yb107/colon_data/refined_by_mobina/male_cases_refined_by_md/masks/Patient_01799_Study_07874_Series_03.nii.gz',
 'body_filled_mask': '/data/usr/yb107/colon_data/refined_by_mobina/Body_filled_all/Patient_01799_Study_07874_Series_03_Body_filled.nii.gz'}

In [17]:
import numpy as np
import torch
from monai.transforms import MapTransform


class CropForegroundAxisd(MapTransform):
    """
    Crop the tensors in `keys` along a single spatial axis based on the foreground
    of `source_key`. Other axes are left untouched.
    """

    def __init__(self, keys, source_key, axis=0, select_fn=lambda x: x > 0, margin=5):
        if not isinstance(keys, (list, tuple)):
            keys = [keys]
        super().__init__(keys)
        if axis not in (0, 1, 2):
            raise ValueError(f"`axis` must be 0, 1, or 2; got {axis}")
        if margin < 0:
            raise ValueError("`margin` must be >= 0")
        self.keys = list(keys)
        self.source_key = source_key
        self.axis = axis
        self.select_fn = select_fn
        self.margin = margin

    def _to_tensor(self, x):
        return x if isinstance(x, torch.Tensor) else torch.as_tensor(x)

    def _get_spatial_axis_index(self, arr_ndim: int) -> int:
        if arr_ndim < 3:
            raise ValueError(
                f"Input must have at least 3 dims (D,H,W). Got ndim={arr_ndim}"
            )
        # spatial dims are the last 3 dims
        return arr_ndim - 3 + self.axis

    def _compute_crop_indices(self, src):
        t = self._to_tensor(src)

        # Reduce all non-spatial dims to a 3D spatial volume (D,H,W)
        if t.ndim == 3:
            spatial = t
        else:
            n_spatial = 3
            reduce_dims = tuple(range(t.ndim - n_spatial))  # e.g., (0,) for C,D,H,W
            spatial = t.any(dim=reduce_dims).to(t.dtype)

        mask = self.select_fn(spatial)
        mask = mask if isinstance(mask, torch.Tensor) else torch.as_tensor(mask)
        mask = mask.bool()

        if mask.ndim != 3:
            raise ValueError(f"Foreground mask must be 3D; got {tuple(mask.shape)}")

        axis = self.axis
        other = tuple(d for d in (0, 1, 2) if d != axis)
        # ↓↓↓ fix: reduce both non-axis dims at once to get a 1D presence vector
        presence_1d = mask.any(dim=other)

        if not presence_1d.any():
            return None

        idxs = presence_1d.nonzero(as_tuple=False).squeeze(-1)
        start = int(idxs.min().item())
        end_inclusive = int(idxs.max().item())
        size_axis = mask.shape[axis]

        start = max(0, start - self.margin)
        end = min(size_axis, end_inclusive + 1 + self.margin)  # [start, end)

        # Safety: never empty
        if end <= start:
            center = int((idxs.float().mean().round().item()))
            start = max(0, min(center, size_axis - 1))
            end = start + 1

        return start, end

    def __call__(self, data):
        d = dict(data)

        if self.source_key not in d:
            return d

        crop_range = self._compute_crop_indices(d[self.source_key])
        if crop_range is None:
            return d  # nothing to crop

        start, end = crop_range

        def _safe_crop(arr):
            arr_ndim = arr.ndim if hasattr(arr, "ndim") else np.asarray(arr).ndim
            gaxis = self._get_spatial_axis_index(arr_ndim)
            slicers = [slice(None)] * arr_ndim
            slicers[gaxis] = slice(start, end)
            out = arr[tuple(slicers)]
            # --------- NEW: safety net, avoid 0-size dim ----------
            if out.shape[gaxis] == 0:
                return arr  # fallback to no crop for this key
            # ------------------------------------------------------
            return out

        for key in self.keys:
            if key not in d:
                continue
            d[key] = _safe_crop(d[key])

            meta_key = f"{key}_meta_dict"
            if meta_key in d and isinstance(d[meta_key], dict):
                d[meta_key]["spatial_shape"] = np.asarray(
                    d[key].shape[-3:], dtype=np.int64
                )

        # Also crop source_key itself if it's not already included
        if self.source_key not in self.keys and self.source_key in d:
            d[self.source_key] = _safe_crop(d[self.source_key])
            meta_key = f"{self.source_key}_meta_dict"
            if meta_key in d and isinstance(d[meta_key], dict):
                d[meta_key]["spatial_shape"] = np.asarray(
                    d[self.source_key].shape[-3:], dtype=np.int64
                )

        return d


def remove_labels(x: torch.Tensor, labels: list, relabel: bool = False) -> torch.Tensor:
    """Remove the specified labels from the label tensor."""
    for label in labels:
        x[x == label] = 0

    if relabel:
        # get unique values in tensor x
        unique_values = x.unique()
        # Sort the unique values
        sorted_uv = sorted(unique_values)
        # Remap the labels
        for new_label, old_label in enumerate(sorted_uv):
            x[x == old_label] = new_label

    return x

In [20]:
import functools


trans = monai.transforms.Compose(
    [
        transforms.LoadImaged(keys=["mask", "body_filled_mask"]),
        # Get same orientation, spacing, and shape
        transforms.EnsureChannelFirstd(keys=["mask", "body_filled_mask"]),
        transforms.Spacingd(
            keys=["mask", "body_filled_mask"],
            pixdim=(1.0, 1.0, 2.0),
            mode=("nearest", "nearest"),
        ),
        transforms.Orientationd(keys=["mask", "body_filled_mask"], axcodes="RAS"),
        # transforms.CropForegroundd(
        #     keys=["mask", "body_filled_mask"], source_key="mask"
        # ),
        transforms.Lambdad(
            keys=["mask"],
            func=functools.partial(
                remove_labels, labels=[16, 17, 14]
            ),  # Assuming labels 17 and 14 are the organs,
        ),
        CropForegroundAxisd(
            keys=["mask", "body_filled_mask"], source_key="mask", axis=2
        ),
        transforms.CropForegroundd(keys=["mask"], source_key="mask"),
        transforms.SaveImaged(
            keys=["mask", "body_filled_mask"],
            output_dir="tmp",
            output_postfix="",
            separate_folder=False,
        ),
    ]
)
transformed = trans(file)

2025-09-30 21:45:14,021 INFO image_writer.py:197 - writing: tmp/Patient_01799_Study_07874_Series_03.nii.gz
2025-09-30 21:45:14,456 INFO image_writer.py:197 - writing: tmp/Patient_01799_Study_07874_Series_03_Body_filled.nii.gz


## Dialation

In [7]:
# MONAI transform: grow non-largest components until connected, with debug snapshots
from typing import Optional, Dict, Any, List
import numpy as np
import torch
from monai.transforms import MapTransform
from monai.config import KeysCollection
from monai.utils import convert_to_numpy
import cc3d
from scipy.ndimage import binary_dilation, generate_binary_structure


def _to_numpy(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _to_like_type(arr_np, like):
    if isinstance(like, torch.Tensor):
        return torch.from_numpy(arr_np).to(like.device, dtype=like.dtype)
    return arr_np.astype(like.dtype, copy=False)


def _ensure_binary(x, thr=0.5):
    if x.dtype == np.bool_:
        return x
    if np.issubdtype(x.dtype, np.integer):
        return x != 0
    return x > thr


class ConnectMinorByDilationRollbackd(MapTransform):
    """
    Grow ONLY the non-largest components until they connect to the largest,
    then (optionally) roll back dilation. Can export per-iteration snapshots.

    New params vs earlier:
      - cc_diagonal_ok: if True, use 26/8 connectivity for CC counting (diagonal touch counts as connected).
      - dilate_diagonals: if True, dilation uses 26/8 neighborhood (grows diagonally too).
      - debug: if True, write snapshots & logs to dict keys <debug_prefix>*.
      - debug_every: keep every k-th dilation step (and the final one) to limit memory.
      - debug_prefix: key prefix for debug outputs.
      - alt_grow_anchor: if True and not connected after half max_iters, also dilate the largest
                         every other iteration to help bridging wide gaps.
      - if connection never achieved, we DO NOT roll back (you’ll get the last dilation).
    """

    def __init__(
        self,
        keys: KeysCollection,
        min_ratio: float = 0.05,
        binarize_thr: float = 0.5,
        max_iters: int = 512,
        # connectivity controls
        cc_diagonal_ok: bool = False,  # False => 6(3D)/4(2D) CC; True => 26/8 CC
        dilate_diagonals: bool = False,  # False => face-only dilation; True => allow diagonal growth
        # rollback
        rollback_iters: int = 1,
        rollback_mode: str = "partial",  # 'partial' | 'keep' | 'full'
        # strategy
        alt_grow_anchor: bool = False,
        # debug
        debug: bool = False,
        debug_every: int = 1,
        debug_prefix: str = "post_debug",
        allow_missing_keys: bool = False,
    ):
        super().__init__(keys, allow_missing_keys)
        self.min_ratio = float(min_ratio)
        self.binarize_thr = float(binarize_thr)
        self.max_iters = int(max_iters)
        self.cc_diagonal_ok = bool(cc_diagonal_ok)
        self.dilate_diagonals = bool(dilate_diagonals)
        self.rollback_iters = int(max(0, rollback_iters))
        assert rollback_mode in ("partial", "keep", "full")
        self.rollback_mode = rollback_mode
        self.alt_grow_anchor = bool(alt_grow_anchor)
        self.debug = bool(debug)
        self.debug_every = max(1, int(debug_every))
        self.debug_prefix = str(debug_prefix)

    # --- helpers ---
    def _cc_rank(self, ndim):
        # rank 1 ~ face-only (6/4), rank 3/2 ~ full (26/8)
        if ndim == 2:
            return 2 if self.cc_diagonal_ok else 1
        return 3 if self.cc_diagonal_ok else 1

    def _dilate_rank(self, ndim):
        if ndim == 2:
            return 2 if self.dilate_diagonals else 1
        return 3 if self.dilate_diagonals else 1

    def _label(self, mask_bool):
        return cc3d.connected_components(mask_bool.astype(np.uint8), connectivity=6)

    def _num_cc(self, mask_bool) -> int:
        labels = self._label(mask_bool)
        counts = np.bincount(labels.ravel())
        if len(counts) == 0:
            return 0
        counts[0] = 0
        return int((counts > 0).sum())

    def _largest_and_others(self, mask_bool):
        labels = self._label(mask_bool)
        counts = np.bincount(labels.ravel())
        counts[0] = 0
        if (counts > 0).sum() == 0:
            return np.zeros_like(mask_bool, bool), np.zeros_like(mask_bool, bool)
        largest_id = int(np.argmax(counts))
        L = labels == largest_id
        O = mask_bool & (~L)
        return L, O

    def _prune_small(self, mask_bool):
        labels = self._label(mask_bool)
        counts = np.bincount(labels.ravel())
        if len(counts) <= 1:
            return mask_bool
        counts[0] = 0
        if (counts > 0).sum() <= 1:
            return mask_bool
        largest_id = int(np.argmax(counts))
        largest_sz = int(counts[largest_id])
        keep_ids = {
            i
            for i, c in enumerate(counts)
            if (i != 0 and c >= self.min_ratio * largest_sz)
        }
        keep_ids.add(largest_id)
        return np.isin(labels, list(keep_ids))

    # --- core per-channel ---
    def _process_one_channel(
        self, mask_np: np.ndarray, dbg_store: Dict[str, Any]
    ) -> np.ndarray:
        bin_mask = _ensure_binary(mask_np, self.binarize_thr)

        # nothing to do?
        if self._num_cc(bin_mask) <= 1:
            return mask_np

        pruned = self._prune_small(bin_mask)
        if self._num_cc(pruned) <= 1:
            return pruned.astype(mask_np.dtype)

        L, O0 = self._largest_and_others(pruned)
        if not O0.any():
            return pruned.astype(mask_np.dtype)

        st = generate_binary_structure(pruned.ndim, self._dilate_rank(pruned.ndim))
        O_hist: List[np.ndarray] = [O0]
        L_hist: List[np.ndarray] = [L]  # stored only when alt_grow_anchor=True
        cc_log = [self._num_cc(L | O0)]

        O = O0.copy()
        anchor = L.copy()
        connected = False

        for it in range(1, self.max_iters + 1):
            # optionally alternate-grow the anchor to help crossing big gaps
            if self.alt_grow_anchor and (it >= self.max_iters // 2) and (it % 2 == 0):
                anchor = binary_dilation(anchor, structure=st)
                L_hist.append(anchor)

            O = binary_dilation(O, structure=st)
            O_hist.append(O)

            ncc = self._num_cc(anchor | O)
            if (it % self.debug_every) == 0:
                cc_log.append(ncc)

            if ncc <= 1:
                connected = True
                break

        # decide rollback
        if not connected:
            # no rollback if never connected: return last dilation so you can "see the dilation"
            O_final = O_hist[-1]
        else:
            if self.rollback_mode == "keep":
                O_final = O_hist[-1]
            elif self.rollback_mode == "full":
                O_final = O_hist[0]
            else:  # partial
                idx = max(0, len(O_hist) - 1 - self.rollback_iters)
                O_final = O_hist[idx]

        out_bool = anchor | O_final

        # --- debug export ---
        if self.debug:
            # compact stacks: keep every k-th snapshot and always the last
            sel = list(range(0, len(O_hist), self.debug_every))
            if sel[-1] != len(O_hist) - 1:
                sel.append(len(O_hist) - 1)
            O_stack = np.stack([O_hist[i] for i in sel], axis=0).astype(
                np.uint8
            )  # [T,...]
            dbg_store["O_stack"] = O_stack  # others’ dilation snapshots
            dbg_store["O_stack_iters"] = np.array(sel, dtype=np.int32)
            dbg_store["connected"] = np.array([int(connected)], dtype=np.int32)
            dbg_store["cc_log"] = np.array(cc_log, dtype=np.int32)
            if self.alt_grow_anchor:
                # Anchor snapshots at the same cadence (best-effort)
                L_sel = list(range(0, len(L_hist), self.debug_every))
                if L_sel and L_sel[-1] != len(L_hist) - 1:
                    L_sel.append(len(L_hist) - 1)
                if L_hist:
                    L_stack = np.stack(
                        [L_hist[min(i, len(L_hist) - 1)] for i in L_sel], axis=0
                    ).astype(np.uint8)
                    dbg_store["L_stack"] = L_stack
                    dbg_store["L_stack_iters"] = np.array(L_sel, dtype=np.int32)

        return out_bool.astype(mask_np.dtype)

    def _process_any(self, arr: np.ndarray, dbg_root: Dict[str, Any]):
        # channel-first support
        if arr.ndim >= 4 and arr.shape[0] > 1:
            outs = []
            for c in range(arr.shape[0]):
                dbg_store = {}
                outc = self._process_one_channel(arr[c], dbg_store)
                if self.debug and dbg_store:
                    dbg_root[f"{self.debug_prefix}_ch{c}_O_stack"] = dbg_store.get(
                        "O_stack"
                    )
                    dbg_root[f"{self.debug_prefix}_ch{c}_O_stack_iters"] = (
                        dbg_store.get("O_stack_iters")
                    )
                    if "L_stack" in dbg_store:
                        dbg_root[f"{self.debug_prefix}_ch{c}_L_stack"] = dbg_store.get(
                            "L_stack"
                        )
                        dbg_root[f"{self.debug_prefix}_ch{c}_L_stack_iters"] = (
                            dbg_store.get("L_stack_iters")
                        )
                    dbg_root[f"{self.debug_prefix}_ch{c}_connected"] = dbg_store.get(
                        "connected"
                    )
                    dbg_root[f"{self.debug_prefix}_ch{c}_cc_log"] = dbg_store.get(
                        "cc_log"
                    )
                outs.append(outc)
            return np.stack(outs, axis=0)
        if arr.ndim >= 4 and arr.shape[0] == 1:
            dbg_store = {}
            out = self._process_one_channel(arr[0], dbg_store)
            if self.debug and dbg_store:
                dbg_root[f"{self.debug_prefix}_ch0_O_stack"] = dbg_store.get("O_stack")
                dbg_root[f"{self.debug_prefix}_ch0_O_stack_iters"] = dbg_store.get(
                    "O_stack_iters"
                )
                if "L_stack" in dbg_store:
                    dbg_root[f"{self.debug_prefix}_ch0_L_stack"] = dbg_store.get(
                        "L_stack"
                    )
                    dbg_root[f"{self.debug_prefix}_ch0_L_stack_iters"] = dbg_store.get(
                        "L_stack_iters"
                    )
                dbg_root[f"{self.debug_prefix}_ch0_connected"] = dbg_store.get(
                    "connected"
                )
                dbg_root[f"{self.debug_prefix}_ch0_cc_log"] = dbg_store.get("cc_log")
            return out[None]
        # no channel dim
        dbg_store = {}
        out = self._process_one_channel(arr, dbg_store)
        if self.debug and dbg_store:
            dbg_root[f"{self.debug_prefix}_O_stack"] = dbg_store.get("O_stack")
            dbg_root[f"{self.debug_prefix}_O_stack_iters"] = dbg_store.get(
                "O_stack_iters"
            )
            if "L_stack" in dbg_store:
                dbg_root[f"{self.debug_prefix}_L_stack"] = dbg_store.get("L_stack")
                dbg_root[f"{self.debug_prefix}_L_stack_iters"] = dbg_store.get(
                    "L_stack_iters"
                )
            dbg_root[f"{self.debug_prefix}_connected"] = dbg_store.get("connected")
            dbg_root[f"{self.debug_prefix}_cc_log"] = dbg_store.get("cc_log")
        return out

    # --- MONAI entrypoint ---
    def __call__(self, data: Dict[str, Any]):
        d = dict(data)
        for key in self.keys:
            if key not in d:
                if not self.allow_missing_keys:
                    raise KeyError(f"Missing key: {key}")
                continue
            arr = convert_to_numpy(d[key])
            # store debug arrays back into the same dict
            processed = self._process_any(_to_numpy(arr), d if self.debug else {})
            d[key] = _to_like_type(processed, d[key])
        return d

In [30]:
import numpy as np
from typing import Optional, Union, Tuple
from monai.transforms import Transform
from monai.config import KeysCollection
from scipy import ndimage
from scipy.ndimage import (
    binary_dilation,
    binary_erosion,
    label,
    find_objects,
    distance_transform_edt,
    binary_closing,
)
import torch


class ColonMaskPostProcessing(Transform):
    """
    Post-processing transform for colon segmentation masks with natural appearance.

    Steps:
    1. Detect and handle thin unnatural connections (bottlenecks)
    2. Identify connected components
    3. Remove small components below volume ratio threshold
    4. Join disjoint components using morphological operations
    5. Smooth boundaries for natural appearance

    Args:
        volume_ratio_threshold: Minimum volume ratio compared to largest component (default: 0.1)
        max_dilation_iterations: Maximum iterations for component joining (default: 10)
        structuring_element_size: Size of structuring element for morphological ops (default: 3)
        preserve_largest_only: If True, only keep largest component when single component expected (default: False)
        min_neck_thickness: Minimum thickness for natural connections in voxels (default: 3)
        smooth_iterations: Number of smoothing iterations for final result (default: 2)
    """

    def __init__(
        self,
        volume_ratio_threshold: float = 0.1,
        max_dilation_iterations: int = 10,
        structuring_element_size: int = 3,
        preserve_largest_only: bool = False,
        min_neck_thickness: int = 3,
        smooth_iterations: int = 2,
    ):
        self.volume_ratio_threshold = volume_ratio_threshold
        self.max_dilation_iterations = max_dilation_iterations
        self.structuring_element_size = structuring_element_size
        self.preserve_largest_only = preserve_largest_only
        self.min_neck_thickness = min_neck_thickness
        self.smooth_iterations = smooth_iterations

        # Create structuring element for morphological operations
        self.struct_element = self._create_structuring_element()

    def _create_structuring_element(self):
        """Create a spherical/ball structuring element"""
        size = self.structuring_element_size
        struct = ndimage.generate_binary_structure(3, 1)
        struct = ndimage.iterate_structure(struct, size // 2)
        return struct

    def _detect_thin_connections(self, mask: np.ndarray) -> Tuple[np.ndarray, bool]:
        """
        Detect and break thin unnatural connections (bottlenecks).
        Uses distance transform to find regions where the mask is very thin.

        Returns:
            mask_cleaned: Mask with thin connections removed
            has_thin_connections: Boolean indicating if thin connections were found
        """
        # Compute distance transform (distance to nearest background voxel)
        dist_transform = distance_transform_edt(mask)

        # Find voxels that are part of thin structures
        # These are voxels where the distance to background is very small
        thin_voxels = (dist_transform > 0) & (
            dist_transform < self.min_neck_thickness / 2
        )

        # Check if removing these thin voxels would disconnect the mask
        test_mask = mask & ~thin_voxels
        labeled_test, num_components_test = label(test_mask)

        # Original connectivity check
        labeled_original, num_components_original = label(mask)

        # If removing thin voxels creates more components, we found thin connections
        if num_components_test > num_components_original:
            # These thin connections are likely artifacts, break them
            return test_mask, True
        else:
            # No problematic thin connections, or they're essential to single component
            return mask, False

    def _get_connected_components(self, mask: np.ndarray):
        """
        Identify and analyze connected components

        Returns:
            labeled_mask: Array with labeled components
            component_sizes: Dictionary mapping label to size
            num_components: Number of components
        """
        labeled_mask, num_components = label(mask)

        if num_components == 0:
            return labeled_mask, {}, 0

        # Calculate size of each component
        component_sizes = {}
        for i in range(1, num_components + 1):
            component_sizes[i] = np.sum(labeled_mask == i)

        return labeled_mask, component_sizes, num_components

    def _remove_small_components(self, labeled_mask: np.ndarray, component_sizes: dict):
        """
        Remove components below volume ratio threshold

        Returns:
            filtered_mask: Binary mask with small components removed
            kept_labels: List of labels that were kept
        """
        if not component_sizes:
            return labeled_mask.astype(bool), []

        # Find largest component size
        max_size = max(component_sizes.values())
        threshold_size = max_size * self.volume_ratio_threshold

        # Keep only components above threshold
        kept_labels = [
            label for label, size in component_sizes.items() if size >= threshold_size
        ]

        # Create filtered mask
        filtered_mask = np.zeros_like(labeled_mask, dtype=bool)
        for kept_label in kept_labels:
            filtered_mask |= labeled_mask == kept_label

        return filtered_mask, kept_labels

    def _smooth_boundaries(self, mask: np.ndarray) -> np.ndarray:
        """
        Smooth mask boundaries using morphological closing for natural appearance

        Returns:
            smoothed_mask: Mask with smoothed boundaries
        """
        if self.smooth_iterations == 0:
            return mask

        smoothed = mask.copy()

        # Create a smaller structuring element for gentle smoothing
        smooth_struct = ndimage.generate_binary_structure(3, 1)

        for _ in range(self.smooth_iterations):
            # Morphological closing: dilation followed by erosion
            # This fills small holes and smooths boundaries
            smoothed = binary_closing(smoothed, structure=smooth_struct)

        return smoothed

    def _join_components_naturally(self, mask: np.ndarray):
        """
        Join disjoint components using iterative dilation until connection,
        then restore natural boundaries using controlled erosion and smoothing

        Returns:
            joined_mask: Binary mask with components joined naturally
        """
        # Check if components are already connected
        labeled_initial, num_initial = label(mask)
        if num_initial <= 1:
            return mask

        # Store original mask for boundary restoration
        original_mask = mask.copy()

        # Get bounding boxes of components to estimate required dilation
        component_props = find_objects(labeled_initial)

        # Iteratively dilate until components join
        dilated_mask = mask.copy()
        iteration = 0
        dilation_history = [mask.copy()]

        while iteration < self.max_dilation_iterations:
            # Dilate with structuring element
            dilated_mask = binary_dilation(dilated_mask, structure=self.struct_element)
            dilation_history.append(dilated_mask.copy())

            # Check if components have joined
            labeled_temp, num_components = label(dilated_mask)
            if num_components == 1:
                break

            iteration += 1

        # If components didn't join after max iterations, return best attempt
        if num_components > 1:
            return dilated_mask

        # Strategy for natural restoration:
        # Use gradual erosion but stop before disconnection
        restored_mask = dilated_mask.copy()

        # Erode gradually, checking connectivity at each step
        for i in range(iteration):
            eroded_temp = binary_erosion(restored_mask, structure=self.struct_element)

            # Check if erosion would disconnect components
            labeled_check, num_check = label(eroded_temp)
            if num_check > 1:
                # Stop erosion to preserve connectivity
                break

            restored_mask = eroded_temp

        # Alternative: Blend with original to preserve natural boundaries
        # where possible while maintaining connections
        bridge_region = dilated_mask & ~original_mask
        natural_region = original_mask.copy()

        # Combine: use original where it exists, bridges where needed
        final_mask = natural_region | (restored_mask & bridge_region)

        # Verify connectivity one more time
        labeled_final, num_final = label(final_mask)
        if num_final > 1:
            # If still disconnected, use the restored mask
            final_mask = restored_mask

        return final_mask

    def __call__(
        self, mask: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Apply post-processing to colon segmentation mask

        Args:
            mask: Binary segmentation mask (can be numpy array or torch tensor)

        Returns:
            Processed mask in same format as input
        """
        # Handle torch tensors
        is_torch = isinstance(mask, torch.Tensor)
        if is_torch:
            device = mask.device
            dtype = mask.dtype
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = mask.copy()

        # Ensure binary mask
        mask_np = mask_np.astype(bool)

        # Remove channel dimension if present
        squeeze_dim = False
        if mask_np.ndim == 4 and mask_np.shape[0] == 1:
            mask_np = mask_np[0]
            squeeze_dim = True

        # Step 0: Detect and handle thin unnatural connections
        mask_np, had_thin_connections = self._detect_thin_connections(mask_np)

        # Step 1: Get connected components
        labeled_mask, component_sizes, num_components = self._get_connected_components(
            mask_np
        )

        # If no components, return empty mask
        if num_components == 0:
            result = mask_np

        # If single component, just smooth it
        elif num_components == 1:
            result = self._smooth_boundaries(mask_np)

        # Multiple components: process
        else:
            # Step 2: Remove small components
            filtered_mask, kept_labels = self._remove_small_components(
                labeled_mask, component_sizes
            )

            # If only one component remains after filtering, smooth and return
            filtered_labeled, filtered_num = label(filtered_mask)
            if filtered_num <= 1:
                result = self._smooth_boundaries(filtered_mask)
            else:
                # Step 3: Join remaining components naturally
                joined_mask = self._join_components_naturally(filtered_mask)

                # Step 4: Final smoothing for natural appearance
                result = self._smooth_boundaries(joined_mask)

        # Restore dimensions
        if squeeze_dim:
            result = result[np.newaxis, ...]

        # Convert back to torch if needed
        if is_torch:
            result = torch.from_numpy(result.astype(np.float32)).to(
                device=device, dtype=dtype
            )

        return result


class ColonMaskPostProcessingd(Transform):
    """
    Dictionary-based version for use with MONAI dictionary transforms

    Args:
        keys: Keys to apply the transform to
        volume_ratio_threshold: Minimum volume ratio compared to largest component
        max_dilation_iterations: Maximum iterations for component joining
        structuring_element_size: Size of structuring element
        preserve_largest_only: If True, only keep largest component
        min_neck_thickness: Minimum thickness for natural connections in voxels
        smooth_iterations: Number of smoothing iterations for final result
        allow_missing_keys: If True, don't raise error for missing keys
    """

    def __init__(
        self,
        keys: KeysCollection,
        volume_ratio_threshold: float = 0.1,
        max_dilation_iterations: int = 10,
        structuring_element_size: int = 3,
        preserve_largest_only: bool = False,
        min_neck_thickness: int = 4,
        smooth_iterations: int = 10,
        allow_missing_keys: bool = False,
    ):
        self.keys = keys if isinstance(keys, (list, tuple)) else [keys]
        self.transform = ColonMaskPostProcessing(
            volume_ratio_threshold=volume_ratio_threshold,
            max_dilation_iterations=max_dilation_iterations,
            structuring_element_size=structuring_element_size,
            preserve_largest_only=preserve_largest_only,
            min_neck_thickness=min_neck_thickness,
            smooth_iterations=smooth_iterations,
        )
        self.allow_missing_keys = allow_missing_keys

    def __call__(self, data: dict) -> dict:
        """Apply transform to dictionary data"""
        d = dict(data)
        for key in self.keys:
            if key in d:
                d[key] = self.transform(d[key])
            elif not self.allow_missing_keys:
                raise KeyError(f"Key '{key}' not found in data dictionary")
        return d

In [64]:
import numpy as np
from typing import Union
from monai.transforms import Transform
from monai.config import KeysCollection
from scipy.ndimage import binary_dilation, binary_erosion, generate_binary_structure
import torch


class SmoothColonMask(Transform):
    """
    Simple smoothing transform for colon segmentation masks.

    How it works:
    1. DILATE: Expands the mask outward (fills gaps, connects close parts)
    2. ERODE: Shrinks back by same amount (returns to original size but smoothed)

    This removes weird appendages and smooths irregular boundaries.

    Args:
        iterations: Number of dilation/erosion iterations (default: 3)
                   Higher = more aggressive smoothing
                   Typical range: 2-5
        connectivity: Structuring element connectivity (1, 2, or 3)
                     1 = face connectivity (6-connected in 3D)
                     2 = face+edge connectivity (18-connected in 3D)
                     3 = face+edge+corner connectivity (26-connected in 3D)
                     Default: 2 (good balance)
    """

    def __init__(self, iterations: int = 3, connectivity: int = 2):
        self.iterations = iterations
        self.connectivity = connectivity

        # Create structuring element (defines the shape of dilation/erosion)
        # This is a 3D ball-like structure
        self.struct_element = generate_binary_structure(3, connectivity)

    def __call__(
        self, mask: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Apply smoothing to mask

        Args:
            mask: Binary segmentation mask

        Returns:
            Smoothed mask in same format as input
        """
        # Handle torch tensors
        is_torch = isinstance(mask, torch.Tensor)
        if is_torch:
            device = mask.device
            dtype = mask.dtype
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = mask.copy()

        # Ensure binary
        mask_np = mask_np.astype(bool)

        # Handle channel dimension
        squeeze_dim = False
        if mask_np.ndim == 4 and mask_np.shape[0] == 1:
            mask_np = mask_np[0]
            squeeze_dim = True

        # STEP 1: Dilate (expand outward) - fills gaps, smooths bumps outward
        dilated = mask_np.copy()
        for _ in range(self.iterations):
            dilated = binary_dilation(dilated, structure=self.struct_element)

        # STEP 2: Erode (shrink back) - returns to approximately original size
        smoothed = dilated.copy()
        for _ in range(self.iterations):
            smoothed = binary_erosion(smoothed, structure=self.struct_element)

        # Restore dimensions
        if squeeze_dim:
            smoothed = smoothed[np.newaxis, ...]

        # Convert back to torch if needed
        if is_torch:
            smoothed = torch.from_numpy(smoothed.astype(np.float32)).to(
                device=device, dtype=dtype
            )

        return smoothed


class SmoothColonMaskd(Transform):
    """
    Dictionary-based version for MONAI pipelines

    Args:
        keys: Keys to apply transform to
        iterations: Number of dilation/erosion iterations (default: 3)
        connectivity: Structuring element connectivity (1, 2, or 3)
        allow_missing_keys: Don't raise error for missing keys
    """

    def __init__(
        self,
        keys: KeysCollection,
        iterations: int = 3,
        connectivity: int = 2,
        allow_missing_keys: bool = False,
    ):
        self.keys = keys if isinstance(keys, (list, tuple)) else [keys]
        self.transform = SmoothColonMask(
            iterations=iterations, connectivity=connectivity
        )
        self.allow_missing_keys = allow_missing_keys

    def __call__(self, data: dict) -> dict:
        """Apply transform to dictionary"""
        d = dict(data)
        for key in self.keys:
            if key in d:
                d[key] = self.transform(d[key])
            elif not self.allow_missing_keys:
                raise KeyError(f"Key '{key}' not found in data")
        return d

In [65]:
import monai
import json, os
import numpy as np

import torch

from monai import transforms

nii_gz_file = {
    "img": "/home/yb107/cvpr2025/DukeDiffSeg/outputs/diffunet-binary-colon/5.1/inference_c_grade_550_gs_3.0/Patient_00135_Study_03256_Series_03_pred.nii.gz"
}
transforms = monai.transforms.Compose(
    [
        transforms.LoadImaged(keys="img", image_only=True),
        transforms.EnsureChannelFirstd(keys="img"),
        # ColonMaskPostProcessingd(
        #     keys="img",
        # ),
        SmoothColonMaskd(keys="img", iterations=6, connectivity=2),
        transforms.SaveImaged(
            keys="img", output_dir="tmp", output_postfix="", separate_folder=False
        ),
    ]
)
transforms(nii_gz_file)

2025-10-06 23:14:59,280 INFO image_writer.py:197 - writing: tmp/0.nii.gz


{'img': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          ...,
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
