In [14]:
import monai
import json, os
import numpy as np

import torch

from monai import transforms

jsonl_path = "/home/yb107/cvpr2025/DukeDiffSeg/data/mobina_mixed_colon_dataset/mobina_mixed_colon_dataset_with_body_filled_train.jsonl"
# Load the JSONL manifest: one JSON object per line, collected into a list of dicts.
# The encoding is given explicitly so the result does not depend on the locale.
with open(jsonl_path, "r", encoding="utf-8") as f:
    # Blank lines are skipped: json.loads raises on an empty / whitespace-only string.
    files = [json.loads(line) for line in f if line.strip()]
print(f"Number of entries in JSONL: {len(files)}")
# Pick one entry to inspect; the trailing tuple is the cell's displayed output.
file = files[7]
file, files[6]

Number of entries in JSONL: 356


({'mask': '/data/usr/yb107/colon_data/refined_by_mobina/colon_refined_by_mobina/masks/Patient_00534_Study_65512_Series_03.nii.gz',
  'body_filled_mask': '/data/usr/yb107/colon_data/refined_by_mobina/Body_filled_all/Patient_00534_Study_65512_Series_03_Body_filled.nii.gz'},
 {'mask': '/data/usr/yb107/colon_data/refined_by_mobina/male_cases_refined_by_md/masks/Patient_00543_Study_02713_Series_03.nii.gz',
  'body_filled_mask': '/data/usr/yb107/colon_data/refined_by_mobina/Body_filled_all/Patient_00543_Study_02713_Series_03_Body_filled.nii.gz'})

In [19]:
import numpy as np
import torch
from monai.transforms import MapTransform


class CropForegroundAxisd(MapTransform):
    """
    Crop the tensors in `keys` along a single spatial axis based on the foreground
    of `source_key`. Other axes are left untouched.

    The last three dims of every array are treated as the spatial dims; any
    leading dims (e.g. a channel dim) are kept whole. The crop range is computed
    once from `source_key` and then applied to every entry of `keys` and to
    `source_key` itself.

    Args:
        keys: key or list/tuple of keys whose arrays are cropped.
        source_key: key of the array whose foreground defines the crop range.
        axis: spatial axis to crop (0, 1 or 2, counted within the last three dims).
        select_fn: callable applied to the source data; it must return a
            boolean-like array of the same shape marking foreground elements.
        margin: number of extra slices kept on each side of the foreground,
            clipped to the array bounds.

    Raises:
        ValueError: if `axis` is not 0, 1 or 2, or if `margin` is negative.
    """

    def __init__(self, keys, source_key, axis=0, select_fn=lambda x: x > 0, margin=5):
        if not isinstance(keys, (list, tuple)):
            keys = [keys]
        super().__init__(keys)
        if axis not in (0, 1, 2):
            raise ValueError(f"`axis` must be 0, 1, or 2; got {axis}")
        if margin < 0:
            raise ValueError("`margin` must be >= 0")
        self.keys = list(keys)
        self.source_key = source_key
        self.axis = axis
        self.select_fn = select_fn
        self.margin = margin

    def _to_tensor(self, x):
        """Return `x` unchanged if it is already a tensor, else convert it."""
        return x if isinstance(x, torch.Tensor) else torch.as_tensor(x)

    def _get_spatial_axis_index(self, arr_ndim: int) -> int:
        """Translate the spatial `axis` into an index for an array of rank `arr_ndim`."""
        if arr_ndim < 3:
            raise ValueError(
                f"Input must have at least 3 dims (D,H,W). Got ndim={arr_ndim}"
            )
        # spatial dims are the last 3 dims
        return arr_ndim - 3 + self.axis

    def _compute_crop_indices(self, src):
        """
        Compute the half-open range ``[start, end)`` to keep along `self.axis`.

        Returns:
            ``(start, end)``, or ``None`` when `select_fn` marks no foreground.

        Raises:
            ValueError: if the source (or the mask returned by `select_fn`)
                has fewer than 3 dims.
        """
        t = self._to_tensor(src)
        if t.ndim < 3:
            raise ValueError(f"Foreground mask must be 3D; got {tuple(t.shape)}")

        # Apply select_fn to the ORIGINAL values, and only afterwards merge the
        # non-spatial dims. Binarising first would hand select_fn a 0/1 volume,
        # which breaks any threshold other than "> 0" and treats negative values
        # as foreground.
        mask = self.select_fn(t)
        mask = mask if isinstance(mask, torch.Tensor) else torch.as_tensor(mask)
        mask = mask.bool()

        if mask.ndim < 3:
            raise ValueError(f"Foreground mask must be 3D; got {tuple(mask.shape)}")
        if mask.ndim > 3:
            # Merge all leading (non-spatial) dims: a voxel is foreground if it is
            # foreground in any of them.
            mask = mask.flatten(0, mask.ndim - 4).any(dim=0)

        axis = self.axis
        size_axis = mask.shape[axis]
        # 1D presence vector: True where the slice at that index holds foreground.
        presence_1d = mask.movedim(axis, 0).flatten(1).any(dim=1)

        if not presence_1d.any():
            return None

        idxs = presence_1d.nonzero(as_tuple=False).squeeze(-1)
        first = int(idxs.min().item())
        last = int(idxs.max().item())

        # margin >= 0 is enforced in __init__, so start <= first <= last < end and
        # the range can never be empty.
        start = max(0, first - self.margin)
        end = min(size_axis, last + 1 + self.margin)  # [start, end)
        return start, end

    def _crop_along_axis(self, arr, start, end):
        """Slice `arr` to ``[start, end)`` along the configured spatial axis."""
        arr_ndim = arr.ndim if hasattr(arr, "ndim") else np.asarray(arr).ndim
        gaxis = self._get_spatial_axis_index(arr_ndim)
        slicers = [slice(None)] * arr_ndim
        slicers[gaxis] = slice(start, end)
        out = arr[tuple(slicers)]
        # Safety net: if the range falls outside this array (e.g. it is shorter
        # than the source along the axis), keep it uncropped instead of returning
        # a 0-size dim.
        if out.shape[gaxis] == 0:
            return arr
        return out

    def __call__(self, data):
        """
        Crop every configured key (and `source_key`) in `data` along `self.axis`.

        Returns a shallow copy of `data`. It is returned unmodified when
        `source_key` is missing or has no foreground. Keys missing from `data`
        are skipped.
        """
        d = dict(data)

        if self.source_key not in d:
            return d

        crop_range = self._compute_crop_indices(d[self.source_key])
        if crop_range is None:
            return d  # nothing to crop

        start, end = crop_range

        # The source itself is cropped too, after the requested keys, unless it is
        # already one of them.
        targets = list(self.keys)
        if self.source_key not in targets:
            targets.append(self.source_key)

        for key in targets:
            if key not in d:
                continue
            d[key] = self._crop_along_axis(d[key], start, end)

            # Only `spatial_shape` is refreshed in a dict-style meta entry; no
            # other metadata (e.g. an affine) is adjusted by this transform.
            meta_key = f"{key}_meta_dict"
            if meta_key in d and isinstance(d[meta_key], dict):
                d[meta_key]["spatial_shape"] = np.asarray(
                    d[key].shape[-3:], dtype=np.int64
                )

        return d


def remove_labels(x: torch.Tensor, labels: list, relabel: bool = False) -> torch.Tensor:
    """Remove the specified labels from the label tensor.

    Every element equal to one of `labels` is set to 0. The tensor is modified
    IN PLACE and the same object is returned.

    Args:
        x: label tensor to edit in place.
        labels: label values to reset to 0.
        relabel: if True, the remaining distinct values are renumbered to
            0..K-1 in ascending order of their value.

    Returns:
        `x` itself, after modification.
    """
    for label in labels:
        x[x == label] = 0

    if relabel:
        # Compare against a frozen copy so that freshly written indices are never
        # mistaken for old values that still have to be processed. Renumbering
        # directly on `x` collides when a new index equals a later old value
        # (e.g. values [-1, 0, 5] would come out as [1, 1, 2]).
        snapshot = x.clone()
        # Tensor.unique() returns the distinct values in ascending order.
        for new_label, old_label in enumerate(snapshot.unique().tolist()):
            x[snapshot == old_label] = new_label

    return x


def transform_labels(x: torch.Tensor, label_map: dict) -> torch.Tensor:
    """Transform labels in the tensor according to the provided label_map.

    All entries of `label_map` are applied simultaneously: each element is
    looked up by its ORIGINAL value, so a map such as ``{1: 2, 2: 3}`` sends
    1 -> 2 and 2 -> 3 (it does not chain 1 -> 2 -> 3). Values that are not keys
    of `label_map` are left unchanged.

    The tensor is modified IN PLACE and the same object is returned.

    Args:
        x: label tensor to edit in place.
        label_map: mapping ``{old_label: new_label}``.

    Returns:
        `x` itself, after modification.
    """
    # Masks are computed on a frozen copy so that values written for one entry
    # cannot be picked up again by a later entry.
    snapshot = x.clone()
    for old_label, new_label in label_map.items():
        x[snapshot == old_label] = new_label
    return x


# Source label -> unified label, per dataset. Only labels that map to one of the
# 12 unified organ labels are listed; every other source label becomes 0.
_COLON_REFINED_BY_MOBINA_LABELS = {
    15: 1,
    16: 2,
    13: 3,
    11: 4,
    2: 5,
    7: 6,
    5: 7,
    6: 8,
    8: 9,
    10: 10,
    14: 11,
    9: 12,
}

# Shared by the "female_cases_refined_by_md" and "male_cases_refined_by_md" sets.
_CASES_REFINED_BY_MD_LABELS = {
    13: 1,
    14: 2,
    16: 3,
    9: 4,
    5: 5,
    8: 6,
    3: 7,
    4: 8,
    6: 9,
    12: 10,
    15: 11,
    2: 12,
}

# (path fragment, label table), checked in order. The female fragment must come
# before the male one because "male_cases_refined_by_md" is a substring of
# "female_cases_refined_by_md".
_REFINED_DATASET_LABELS = (
    ("colon_refined_by_mobina", _COLON_REFINED_BY_MOBINA_LABELS),
    ("female_cases_refined_by_md", _CASES_REFINED_BY_MD_LABELS),
    ("male_cases_refined_by_md", _CASES_REFINED_BY_MD_LABELS),
)


def _remap_to_unified_labels_(x, label_map):
    """Remap `x` in place via `label_map`; values not in the map become 0."""
    # Masks come from a frozen copy, so the remap is simultaneous and needs no
    # temporary label offset to avoid collisions.
    snapshot = x.clone()
    x.zero_()
    for old_label, new_label in label_map.items():
        x[snapshot == old_label] = new_label
    return x


def dataset_depended_transform_labels(x, kidneys_same_index=False):
    """
    Apply the transform_labels function to the dependent dataset.
    Resulting label map:
      "1": "colon",
      "2": "rectum",
      "3": "small_bowel",
      "4": "stomach",
      "5": "liver",
      "6": "spleen",
      "7": "kidney_left",
      "8": "kidney_right",
      "9": "pancreas",
      "10": "urinary_bladder",
      "11": "duodenum",
      "12": "gallbladder",

    The dataset is recognised from the path stored in
    ``x.meta["filename_or_obj"]``. The tensor is modified IN PLACE and returned.

    For the "colon_refined_by_mobina", "female_cases_refined_by_md" and
    "male_cases_refined_by_md" datasets, each source label is remapped to the
    unified label above, and every source label without a unified counterpart
    (including labels not listed for the dataset at all) becomes background 0.
    For "a_grade_colons_not_in_refined_by_md" and "c_grade_colons/masks/",
    labels 13..23 are set to 0 and all other values are left unchanged.

    Args:
        x: label tensor exposing ``.meta["filename_or_obj"]``.
        kidneys_same_index: if True, kidney_right (8) is merged into
            kidney_left (7) after the remap.

    Returns:
        `x` itself, after modification.

    Raises:
        ValueError: if the path matches none of the known datasets.
    """
    pathname = str(x.meta["filename_or_obj"])

    for fragment, label_map in _REFINED_DATASET_LABELS:
        if fragment in pathname:
            x = _remap_to_unified_labels_(x, label_map)
            break
    else:
        if ("a_grade_colons_not_in_refined_by_md" in pathname) or (
            "c_grade_colons/masks/" in pathname
        ):
            # Only the extra labels 13..23 are dropped here; everything else is
            # kept as is.
            for extra_label in range(13, 24):
                x[x == extra_label] = 0
        else:
            raise ValueError(f"Unknown dataset for {pathname}")

    if kidneys_same_index:
        # Map kidney_right (8) to kidney_left (7)
        x[x == 8] = 7

    return x

In [24]:
import functools


# Preprocessing pipeline for the "mask" entry of each manifest record.
trans = monai.transforms.Compose(
    [
        transforms.LoadImaged(keys=["mask"]),
        # Bring every mask to a common layout: channel-first, pixdim (2, 2, 2)
        # with nearest-neighbour resampling, then RAS orientation.
        transforms.EnsureChannelFirstd(keys=["mask"]),
        transforms.Spacingd(keys=["mask"], pixdim=(2, 2, 2), mode="nearest"),
        transforms.Orientationd(keys=["mask"], axcodes="RAS"),
        # Harmonise the dataset-specific label ids.
        transforms.Lambdad(
            keys=["mask"],
            func=functools.partial(
                dataset_depended_transform_labels,
                kidneys_same_index=False,
            ),
        ),
        # Zero out labels 1..22 except label 5, which
        # dataset_depended_transform_labels documents as "liver".
        transforms.Lambdad(
            keys=["mask"],
            func=functools.partial(
                remove_labels,
                labels=[label for label in range(1, 23) if label != 5],
            ),
        ),
        # Crop to the bounding box of what is left. CropForegroundAxisd
        # (defined above) is the single-axis alternative.
        transforms.CropForegroundd(keys=["mask"], source_key="mask"),
    ]
)

# Run the pipeline over every record and report the cropped shape. Records with
# any spatial dim (everything after the channel dim) above 96 are printed with
# their file path.
for file in files:
    transformed = trans(file)
    mask = transformed["mask"]
    shape = mask.shape
    print(
        f"File: {file['mask']}, Shape: {shape}"
        if any(extent > 96 for extent in shape[1:])
        else f"Shape: {shape}"
    )

File: /data/usr/yb107/colon_data/refined_by_mobina/male_cases_refined_by_md/masks/Patient_01799_Study_07874_Series_03.nii.gz, Shape: torch.Size([1, 97, 99, 71])
File: /data/usr/yb107/colon_data/refined_by_mobina/a_grade_colons_not_in_refined_by_md/masks/Patient_00370_Study_26660_Series_03.nii.gz, Shape: torch.Size([1, 114, 86, 86])
File: /data/usr/yb107/colon_data/refined_by_mobina/a_grade_colons_not_in_refined_by_md/masks/Patient_00862_Study_65668_Series_03.nii.gz, Shape: torch.Size([1, 111, 97, 72])
File: /data/usr/yb107/colon_data/refined_by_mobina/a_grade_colons_not_in_refined_by_md/masks/Patient_02032_Study_41100_Series_03.nii.gz, Shape: torch.Size([1, 112, 100, 73])
File: /data/usr/yb107/colon_data/refined_by_mobina/a_grade_colons_not_in_refined_by_md/masks/Patient_01470_Study_10344_Series_03.nii.gz, Shape: torch.Size([1, 108, 108, 73])
File: /data/usr/yb107/colon_data/refined_by_mobina/female_cases_refined_by_md/masks/Patient_02225_Study_60027_Series_03.nii.gz, Shape: torch.Size

KeyboardInterrupt: 