In [2]:
import os
import pathlib

import SimpleITK as sitk
from monai.data import Dataset, DataLoader
from monai.utils import first
import matplotlib.pyplot as plt
import numpy as np
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    SaveImage,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    Rand3DElasticd,
    RandRotated,
    RandFlipd,
)
from monai.data import CacheDataset, decollate_batch
import torch
import glob
from monai.utils import set_determinism

In [5]:

# Input folders (test split); each is expected to hold one ``*.nii.gz`` file per case.
image_path = "../data/multi_phase_select/final_version_crop_train/test/main/img_crop"
label_path = "../data/multi_phase_select/final_version_crop_train/test/main/label_crop_clip"
heart_path = "../data/multi_phase_select/final_version_crop_train/test/main/heart_connect_crop"

assert os.path.isdir(image_path), "img path not exist"
assert os.path.isdir(label_path), "label path not exist"
assert os.path.isdir(heart_path), "heart path not exist"

train_images = sorted(glob.glob(os.path.join(image_path, "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(label_path, "*.nii.gz")))
train_heart = sorted(glob.glob(os.path.join(heart_path, "*.nii.gz")))

# Image/label/heart files are paired purely by sorted position. zip() would silently drop the
# surplus of the longer lists, and an empty folder would make every later cell a silent no-op,
# so require a non-zero, equal count up front.
# NOTE(review): equal counts do not prove correct pairing; this still assumes the three folders
# list the same cases in the same sorted order.
assert train_images, f"no *.nii.gz files found in {image_path}"
assert len(train_images) == len(train_labels) == len(train_heart), (
    f"file count mismatch: {len(train_images)} images, {len(train_labels)} labels, "
    f"{len(train_heart)} heart files"
)

data_dicts = [
    {"image": image_name, "label": label_name, "heart": heart_seg}
    for image_name, label_name, heart_seg in zip(train_images, train_labels, train_heart)
]
# NOTE(review): `device` is not referenced by any cell visible here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from NoiseTransformD import GaussianNoiseD, GaussianBlurD
from IntensityTransformD import CTNormalizeD, BrightnessMultiplicativeD, ContrastAugmentationD, \
    SimulateLowResolutionD, GammaD
from SpatialTransformD import SpatialZooTransformD, CASTransformD
from AnatomyTransformD import HeartTransformD
# Seed the global RNGs through MONAI so the random anatomy transform is reproducible run to run.
set_determinism(seed=100)

# Root folder for every snapshot written by the SaveImaged steps below.
save_dir = "./transform/test_transform/test_HeartTransformD"
os.makedirs(save_dir, exist_ok=True)

# The three volumes loaded per case; each transform gets its own fresh list copy.
_VOLUME_KEYS = ("image", "label", "heart")


def _save_stage(key, postfix):
    """Build a SaveImaged step that writes volume ``key`` into ``save_dir`` with ``postfix`` appended.

    All snapshots share the same settings: logging enabled and zero padding.
    """
    return SaveImaged(
        keys=[key],
        output_dir=save_dir,
        output_postfix=postfix,
        print_log=True,
        padding_mode="zeros",
    )


# Pipeline: load -> channel-first -> snapshot inputs -> HeartTransformD -> snapshot outputs.
# Other augmentations tried while developing this test (CTNormalizeD, CropForegroundd, Orientationd,
# Spacingd, SpatialZooTransformD, CASTransformD) are intentionally not part of it.
save_transform = Compose(
    [
        LoadImaged(keys=list(_VOLUME_KEYS)),
        EnsureChannelFirstd(keys=list(_VOLUME_KEYS)),
        # inputs, before augmentation
        _save_stage("image", "origin_image"),
        _save_stage("label", "origin_label"),
        _save_stage("heart", "origin_heart"),
        # transform under test
        HeartTransformD(
            keys=list(_VOLUME_KEYS),
            artery_key="label",
            heart_key="heart",
            p_anatomy_per_sample=1,
            dil_ranges=((-300, -500), (30, 30)),
            directions_of_trans=((1, 1, 1), (1, 1, 1)),
            blur=(32, 4),
            mode=("bilinear", "nearest", "nearest"),
        ),
        # the same three volumes after the transform
        _save_stage("image", "spatial_transform_image"),
        _save_stage("label", "spatial_transform_label"),
        _save_stage("heart", "spatial_transform_heart"),
    ]
)
# Single-letter orientation code -> full axis name as used in NRRD "space" strings.
_AXIS_FULL_NAMES = {
    "L": "left",
    "R": "right",
    "A": "anterior",
    "P": "posterior",
    "S": "superior",
    "I": "inferior",
}


def get_space_full_name(space):
    """Expand a 3-letter orientation code into a hyphen-joined full space name.

    Example: ``"RAS"`` -> ``"right-anterior-superior"``, ``"LPS"`` -> ``"left-posterior-superior"``.

    Args:
        space: a sequence of exactly three single-letter axis codes from {L, R, A, P, S, I}.
            A plain string, a ``str``-based enum, or a list of letters all work. Letters are
            matched case-insensitively.

    Returns:
        The three full axis names joined with ``"-"``.

    Raises:
        ValueError: if ``space`` does not have exactly three entries, or contains a code outside
            {L, R, A, P, S, I}. Unknown codes are rejected rather than silently skipped, because
            skipping one would yield an incomplete space name.
    """
    if len(space) != 3:
        raise ValueError(f"expected a 3-letter orientation code, got {space!r}")
    full_name = []
    for code in space:
        axis_name = _AXIS_FULL_NAMES.get(str(code).upper())
        if axis_name is None:
            raise ValueError(f"unknown orientation code {code!r} in {space!r}")
        full_name.append(axis_name)
    return "-".join(full_name)



In [6]:
import nrrd
def _channel_last_numpy(field):
    """Return ``field`` as a CPU numpy array with its leading channel axis moved to the end.

    NOTE(review): assumes the field is channel-first (C, X, Y, Z), as the original permute implied.
    """
    # .cpu() first: np.array() on a CUDA tensor raises.
    return field.detach().cpu().permute(1, 2, 3, 0).numpy()


def _vector_field_nrrd_header(image):
    """Build a pynrrd header for a channel-last vector field on ``image``'s current voxel grid.

    Args:
        image: the transformed MONAI MetaTensor; its ``affine`` and ``meta["space"]`` are read.

    Returns:
        A header dict with 3 spatial "domain" axes followed by one "vector" axis.
    """
    # Use the *current* affine, which matches the grid of the arrays being written. The original
    # (load-time) affine only coincides with it if no transform changed the geometry.
    affine = np.asarray(image.affine.detach().cpu(), dtype=np.float64)
    # Column j of affine[:3, :3] is the world-space step (direction * spacing) of voxel axis j.
    # NRRD "space directions" lists one vector per axis, hence the transpose. The trailing None
    # marks the non-spatial vector axis.
    axis_vectors = affine[:3, :3].T.tolist()
    return {
        'endian': 'little',
        'encoding': 'raw',
        'space': get_space_full_name(image.meta["space"]),
        'space directions': axis_vectors + [None],
        'space origin': affine[:3, 3],
        'kinds': ['domain', 'domain', 'domain', 'vector'],
    }


for case in data_dicts[0:1]:
    d = save_transform(case)
    # Same per-case sub-folder the SaveImaged steps write into (e.g. .../AI01_0043).
    case_name = pathlib.Path(d["image"].meta["filename_or_obj"]).name.split(".")[0]
    save_path = os.path.join(save_dir, case_name)
    pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)

    header = _vector_field_nrrd_header(d["image"])
    # NOTE(review): keys "heart_df"/"label_df" must be produced by HeartTransformD; the output
    # files are named "*_fd" — confirm the key spelling against the transform.
    nrrd.write(os.path.join(save_path, 'heart_fd.nrrd'), _channel_last_numpy(d["heart_df"]), header=header)
    nrrd.write(os.path.join(save_path, 'label_fd.nrrd'), _channel_last_numpy(d["label_df"]), header=header)

2024-03-02 12:48:58,933 INFO image_writer.py:194 - writing: transform\test_transform\test_HeartTransformD\AI01_0043\AI01_0043_origin_image.nii.gz
2024-03-02 12:49:02,205 INFO image_writer.py:194 - writing: transform\test_transform\test_HeartTransformD\AI01_0043\AI01_0043_origin_label.nii.gz
2024-03-02 12:49:03,184 INFO image_writer.py:194 - writing: transform\test_transform\test_HeartTransformD\AI01_0043\AI01_0043_origin_heart.nii.gz


RuntimeError: applying transform <AnatomyTransformD.HeartTransformD object at 0x000002617A5B9808>