In [17]:
import numpy as np
from monai.data import Dataset, DataLoader, CacheDataset

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
import torch


import os

from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference


In [2]:
TRAIN_IMG_DIR = "./datasets/train/images"
TRAIN_LABEL_DIR = "./datasets/train/labels"
VAL_IMG_DIR = "./datasets/val/images"
VAL_LABEL_DIR = "./datasets/val/labels"


def list_npy_files(directory):
    """Return the ``.npy`` file names in ``directory`` in sorted order.

    ``os.listdir`` returns entries in an arbitrary, filesystem-dependent order,
    so the names are sorted to make sample order reproducible across machines.
    Non-``.npy`` entries are skipped because ``np.load`` cannot read them.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        Sorted list of file names (not full paths).
    """
    return sorted(entry for entry in os.listdir(directory) if entry.endswith(".npy"))


def load_image_label_pairs(img_dir, label_dir, names):
    """Load each image in ``names`` together with its matching label array.

    The label file name is derived from the image file name by replacing every
    occurrence of ``"image"`` with ``"label"`` (so ``x_image.npy`` maps to
    ``x_label.npy``).

    Args:
        img_dir: Directory holding the image ``.npy`` files.
        label_dir: Directory holding the label ``.npy`` files.
        names: Image file names (not paths) to load, in the desired order.

    Returns:
        A list of ``{"image": ndarray, "label": ndarray}`` dicts in ``names`` order.

    Raises:
        FileNotFoundError: If an image or its derived label file is missing.
        ValueError: If an image and its label do not have the same shape.
    """
    pairs = []
    for name in names:
        label_name = name.replace("image", "label")
        image = np.load(os.path.join(img_dir, name))
        label = np.load(os.path.join(label_dir, label_name))
        # Image and label are cropped/rotated together downstream, so their
        # shapes must agree; fail here with the offending file names.
        if image.shape != label.shape:
            raise ValueError(
                f"Shape mismatch: {name} {image.shape} vs {label_name} {label.shape}"
            )
        pairs.append({"image": image, "label": label})
    return pairs


train_list = list_npy_files(TRAIN_IMG_DIR)
val_list = list_npy_files(VAL_IMG_DIR)

train_files = load_image_label_pairs(TRAIN_IMG_DIR, TRAIN_LABEL_DIR, train_list)
valid_files = load_image_label_pairs(VAL_IMG_DIR, VAL_LABEL_DIR, val_list)

# An empty training set would otherwise only surface later as a confusing
# StopIteration from the DataLoader.
assert train_files, f"No .npy images found in {TRAIN_IMG_DIR!r}"

In [27]:
# Deterministic preprocessing: computed once per volume and cached in memory.
# channel_dim="no_channel" treats the loaded arrays as having no channel axis and
# prepends one to both image and label; only the image intensities are normalised.
non_random_transforms = Compose(
    [
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        NormalizeIntensityd(keys="image"),
    ]
)

# cache_rate=1.0 caches the preprocessed result for every training item.
train_ds = CacheDataset(
    data=train_files,
    transform=non_random_transforms,
    cache_rate=1.0,
)


Loading dataset: 100%|██████████| 24/24 [00:02<00:00,  9.35it/s]


In [None]:
from monai.transforms import Compose, RandCropByLabelClassesd, RandRotate90d, RandFlipd
from monai.data import CacheDataset, Dataset, DataLoader
import torch

# Training-patch configuration (tune here rather than inside the transform calls).
PATCH_SIZE = [96, 96, 96]  # one size per spatial axis; 3D data needs exactly 3 entries
NUM_CLASSES = 8            # number of label classes crop centres are sampled from
NUM_SAMPLES = 16           # patches cropped from each volume every time it is fetched
BATCH_SIZE = 2             # volumes per batch; each contributes NUM_SAMPLES patches

# Random augmentations, re-drawn on every access to a sample.
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=PATCH_SIZE,
        num_classes=NUM_CLASSES,
        num_samples=NUM_SAMPLES,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

# The deterministic preprocessing (`non_random_transforms`, defined in the
# caching cell) must run BEFORE the random crops. Wrapping `train_files`
# directly with `random_transforms` would skip EnsureChannelFirstd and
# NormalizeIntensityd, so the cropper would receive channel-less,
# un-normalised volumes.
# The two pipelines are concatenated into one flat Compose rather than nested:
# a Compose object is itself randomizable, which may stop CacheDataset from
# caching anything. CacheDataset caches every transform before the first random
# one and re-applies only the random tail on each access. Rebuilding `train_ds`
# from `train_files` (instead of wrapping the previous `train_ds`) keeps this
# cell safe to re-run.
train_transforms = Compose(
    [*non_random_transforms.transforms, *random_transforms.transforms]
)
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Smoke-test one batch. MONAI's DataLoader flattens each volume's list of
# patches, so a batch holds up to BATCH_SIZE * NUM_SAMPLES patches.
batch = next(iter(train_loader))
images, labels = batch["image"].to(device), batch["label"].to(device)
images.shape, labels.shape