In [4]:
# Directory that is scanned for the ".nii.gz" image volumes.
IMAGE_PATH = "/NAS/user_data/user/yb107/abdomen_1k/unzipped/AbdomenCT-1K/"
# Directory in which the matching mask file of every image is looked up.
MASK_PATH = "/NAS/user_data/user/yb107/abdomen_1k/unzipped/Mask/"

import os
import json


def create_jsonl_dataset(image_path, mask_path, output_file):
    """Index every image/mask pair into a JSONL file.

    Every ``.nii.gz`` file in ``image_path`` is paired with the file of the
    same name in ``mask_path``, after dropping a trailing ``_0000`` marker
    from the image name (``Case_00001_0000.nii.gz`` -> ``Case_00001.nii.gz``).
    Images without a mask are reported on stdout and left out of the index.

    Args:
        image_path: Directory containing the image volumes.
        mask_path: Directory containing the mask volumes.
        output_file: Path of the JSONL file to write; one
            ``{"image": ..., "mask": ...}`` object per line.
    """
    extension = ".nii.gz"
    modality_suffix = "_0000"
    dataset = []

    # os.listdir() returns names in arbitrary order; sorting keeps the index
    # (and any positional split derived from it) reproducible between runs.
    for filename in sorted(os.listdir(image_path)):
        if not filename.endswith(extension):
            continue

        # Strip the marker only where it directly precedes the extension.
        # A blanket str.replace() would also eat the "_0000" inside case IDs
        # such as "Case_00001", producing a mask name that never exists.
        stem = filename[: -len(extension)]
        if stem.endswith(modality_suffix):
            stem = stem[: -len(modality_suffix)]

        image_file = os.path.join(image_path, filename)
        mask_file = os.path.join(mask_path, stem + extension)

        if os.path.exists(mask_file):
            dataset.append({"image": image_file, "mask": mask_file})
        else:
            print(f"Warning: Mask file for {filename} not found.")

    with open(output_file, "w") as f:
        for entry in dataset:
            f.write(json.dumps(entry) + "\n")


# Build the master index of all image/mask pairs.
create_jsonl_dataset(
    image_path=IMAGE_PATH,
    mask_path=MASK_PATH,
    output_file="/home/yb107/cvpr2025/DukeDiffSeg/data/json/abdomen_1k.jsonl",
)

# Jsonl file format:
# {"image": "/path/to/image.nii.gz", "mask": "/path/to/mask.nii.gz"}


def divide_jsonl_train_val_test(
    jsonl_file, train_file, val_file, test_file, train_ratio=0.7, val_ratio=0.2
):
    """Divide the JSONL dataset into train, validation, and test sets.

    The split is positional: the first ``train_ratio`` share of the entries
    goes to ``train_file``, the next ``val_ratio`` share to ``val_file`` and
    everything that is left to ``test_file``. Entries are not shuffled.

    Args:
        jsonl_file: Input JSONL file, one JSON object per line. Blank lines
            are ignored.
        train_file: Output path for the training subset.
        val_file: Output path for the validation subset.
        test_file: Output path for the test subset.
        train_ratio: Fraction of entries assigned to the training subset.
        val_ratio: Fraction of entries assigned to the validation subset.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    with open(jsonl_file, "r") as f:
        # A trailing or stray blank line is not a record; json.loads("") raises.
        entries = [json.loads(line) for line in f if line.strip()]

    total_entries = len(entries)
    # Binary floats make e.g. 10 * (0.7 + 0.2) evaluate to 8.999999999999998,
    # which a bare int() truncates to 8. Rounding away that representation
    # noise first keeps the intended floor semantics.
    train_end = int(round(total_entries * train_ratio, 9))
    val_end = int(round(total_entries * (train_ratio + val_ratio), 9))

    subsets = (
        (train_file, entries[:train_end]),
        (val_file, entries[train_end:val_end]),
        (test_file, entries[val_end:]),
    )
    for path, subset in subsets:
        with open(path, "w") as f:
            for entry in subset:
                f.write(json.dumps(entry) + "\n")


# Split the master index into train / validation / test files using the
# function's default ratios.
divide_jsonl_train_val_test(
    jsonl_file="/home/yb107/cvpr2025/DukeDiffSeg/data/json/abdomen_1k.jsonl",
    train_file="/home/yb107/cvpr2025/DukeDiffSeg/data/json/abdomen_1k_train.jsonl",
    val_file="/home/yb107/cvpr2025/DukeDiffSeg/data/json/abdomen_1k_val.jsonl",
    test_file="/home/yb107/cvpr2025/DukeDiffSeg/data/json/abdomen_1k_test.jsonl",
)



In [7]:
# Given a JSONL file, load the image and mask referenced on every line, apply the preprocessing
# transforms, and add the shape of the transformed volume to the line under the "shape" key.
# Multiprocessing is used to speed up the process.

import multiprocessing
import nibabel as nib
import monai

from monai.transforms import LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd
import json
import time

# Preprocessing pipeline applied to an {"image": path, "mask": path} dict before
# the resulting shapes are recorded. The transforms run in the order listed.
transforms = monai.transforms.Compose(
    [
        # Read both volumes from the file paths stored under the two keys.
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        # Resample both volumes to the same voxel spacing. The modes pair up
        # with the keys: bilinear for the image, nearest for the mask, so that
        # mask values are not blended into new intermediate values.
        Spacingd(
            keys=["image", "mask"],
            pixdim=(1.0, 1.0, 3.0),
            mode=("bilinear", "nearest"),
        ),
        # NOTE(review): reorientation happens after resampling, so pixdim is
        # applied along the axes as stored on disk. MONAI examples usually put
        # Orientationd before Spacingd - confirm this order is intended and
        # matches the training pipeline.
        Orientationd(keys=["image", "mask"], axcodes="RAS"),
    ]
)

def process_entry(entry):
    """Run the preprocessing pipeline on one entry and record the result shape.

    Args:
        entry: Dict with at least the keys ``"image"`` and ``"mask"``, which
            are handed to the module-level ``transforms`` pipeline.

    Returns:
        The same ``entry`` dict, updated in place with a ``"shape"`` key that
        holds the transformed image's shape as a plain list.

    Raises:
        ValueError: If the transformed image and mask differ in shape. The
            message names both files so that a failure inside a worker pool
            can be traced back to the offending entry.
    """
    loaded = transforms({"image": entry["image"], "mask": entry["mask"]})
    image = loaded["image"]
    mask = loaded["mask"]

    # Check if they have same shape
    if image.shape != mask.shape:
        raise ValueError(
            f"Image and mask shapes do not match: {image.shape} vs {mask.shape} "
            f"(image: {entry['image']}, mask: {entry['mask']})"
        )

    # list() turns the shape object into something json.dumps can serialise.
    entry["shape"] = list(image.shape)
    return entry

def update_jsonl_with_slices(jsonl_file, output_file):
    """Annotate every record of a JSONL file and save the annotated copy.

    Each line of ``jsonl_file`` is parsed as JSON and passed through
    ``process_entry`` in a worker pool; the returned records are written to
    ``output_file``, one JSON object per line, in input order.

    Args:
        jsonl_file: Path of the JSONL file to read.
        output_file: Path of the JSONL file to write.
    """
    # Parse the whole input up front so the records can be fanned out.
    with open(jsonl_file, "r") as src:
        records = list(map(json.loads, src))

    # Pool() without arguments uses the default number of worker processes.
    with multiprocessing.Pool() as workers:
        annotated = workers.map(process_entry, records)

    with open(output_file, "w") as dst:
        dst.writelines(json.dumps(record) + "\n" for record in annotated)


# Time the shape-annotation pass over the training split. perf_counter() is a
# monotonic, high-resolution clock, so the measured interval cannot be skewed
# by system clock adjustments the way time.time() differences can.
start = time.perf_counter()
update_jsonl_with_slices(
    "/home/yb107/cvpr2025/DukeDiffSeg/data/json/abdomen_1k_train.jsonl",
    "/home/yb107/cvpr2025/DukeDiffSeg/data/json/abdomen_1k_train_updated.jsonl",
)
end = time.perf_counter()
print(f"Time taken to update JSONL with slices: {end - start} seconds")

Time taken to update JSONL with slices: 631.5122039318085 seconds
