In [1]:
import numpy as np
from monai.data import Dataset, DataLoader, CacheDataset

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord, RandCropByLabelClassesd
)
import torch


import os

from monai.losses import DiceLoss
from monai.metrics import DiceMetric
# from monai.inferers import sliding_window_inference
# inputs = MetaTensor or Tensor
# outputs = sliding_window_inference(
#     inputs=inputs,
#     roi_size=(96, 96, 96),
#     sw_batch_size=4,
#     predictor=model.forward
# )
# The output shape is the same as the input shape.
# Must use the cached dataset for the validation (task) data.


In [2]:
TRAIN_IMG_DIR = "./datasets/train/images"
TRAIN_LABEL_DIR = "./datasets/train/labels"
VAL_IMG_DIR = "./datasets/val/images"
VAL_LABEL_DIR = "./datasets/val/labels"


def list_image_files(img_dir):
    """Return the non-hidden entries of ``img_dir`` in sorted order.

    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order;
    sorting makes dataset indexing reproducible across runs and machines.
    Hidden entries (e.g. ``.DS_Store``, ``.ipynb_checkpoints``) are skipped
    because ``np.load`` cannot read them.
    """
    return sorted(name for name in os.listdir(img_dir) if not name.startswith("."))


def load_image_label_pairs(names, img_dir, label_dir):
    """Load each image file in ``names`` together with its matching label file.

    The label filename is derived by replacing ``"image"`` with ``"label"`` in
    the image filename.

    Args:
        names: image filenames (relative to ``img_dir``).
        img_dir: directory holding the image arrays.
        label_dir: directory holding the label arrays.

    Returns:
        A list of ``{"image": ndarray, "label": ndarray}`` dicts, in the order
        of ``names``.
    """
    pairs = []
    for name in names:
        image = np.load(os.path.join(img_dir, name))
        # Plain str.replace instead of an f-string with nested double quotes:
        # reusing the enclosing quote type inside an f-string is only legal
        # from Python 3.12 onward.
        label = np.load(os.path.join(label_dir, name.replace("image", "label")))
        pairs.append({"image": image, "label": label})
    return pairs


train_list = list_image_files(TRAIN_IMG_DIR)
val_list = list_image_files(VAL_IMG_DIR)
train_files = load_image_label_pairs(train_list, TRAIN_IMG_DIR, TRAIN_LABEL_DIR)
valid_files = load_image_label_pairs(val_list, VAL_IMG_DIR, VAL_LABEL_DIR)

In [3]:
# Deterministic preprocessing. None of these steps is random, so their output
# can be computed once and kept in memory by CacheDataset below.
non_random_transforms = Compose(
    transforms=[
        # The raw arrays have no channel axis; prepend one to image and label.
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        # Intensity normalisation is applied to the image only.
        NormalizeIntensityd(keys="image"),
        # NOTE(review): inputs are bare numpy arrays without an affine, so this
        # presumably reorients against a default (identity) affine — confirm.
        Orientationd(keys=["image", "label"], axcodes="RAS"),
    ]
)

# Cache every validation sample (cache_rate=1.0) after the deterministic steps.
val_ds = CacheDataset(
    data=valid_files,
    transform=non_random_transforms,
    cache_rate=1.0,
)


Loading dataset: 100%|██████████| 4/4 [00:00<00:00,  9.53it/s]


In [4]:
from monai.transforms import Compose, RandCropByLabelClassesd, RandRotate90d, RandFlipd
from monai.data import Dataset, DataLoader
import torch

# On-the-fly random augmentations, applied on top of the cached preprocessing.
random_transforms = Compose(
    transforms=[
        # Draw 8 patches of size 3 x 96 x 96 per volume, with patch centres
        # chosen by label class (7 classes); each input item yields 8 samples.
        RandCropByLabelClassesd(
            keys=["image", "label"],
            label_key="label",
            spatial_size=[3, 96, 96],
            num_classes=7,
            num_samples=8,  # must be a positive integer
        ),
        # With probability 0.5, rotate by a multiple of 90 degrees in the plane
        # of spatial axes 1 and 2 (the two 96-pixel axes of the patch).
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[1, 2]),
        # With probability 0.5, flip along spatial axis 0 (the size-3 axis).
        # NOTE(review): this reverses slice order, swapping the "previous" and
        # "next" context slices extracted later — confirm this is intended.
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    ]
)

# Wrap the cached, deterministically preprocessed dataset with the random
# augmentations. The result gets its own name instead of rebinding `val_ds`,
# so re-running this cell always wraps the cached dataset exactly once.
val_aug_ds = Dataset(data=val_ds, transform=random_transforms)

# MONAI's DataLoader (num_workers=0) derives the transforms' random state from
# torch's default generator when it is constructed, and the shuffling sampler
# draws from that generator as well. Seeding torch here therefore makes both
# the sampled crops and the batch order reproducible.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): despite its name, this loader iterates over the *validation*
# volumes — confirm this is intended (e.g. a quick pipeline smoke test).
train_loader = DataLoader(val_aug_ds, batch_size=2, shuffle=True, num_workers=0)

In [12]:

# Pull a single batch to sanity-check the tensor shapes coming out of the
# loader (each batch holds batch_size * num_samples patches).
batch = next(iter(train_loader))
images = batch["image"]
labels = batch["label"]
print(images.shape, labels.shape)

torch.Size([16, 1, 3, 96, 96]) torch.Size([16, 1, 3, 96, 96])


In [6]:
out_img = images
print(out_img.shape)
out_meta = images.meta

# Split dim 2 of the patch batch (the size-3 axis, per the printed shape)
# into the slices before, at, and after the middle slice.
center = out_img.shape[2] // 2
prev_image, image, next_image = (
    out_img[:, :, :center, :, :],             # slices before the centre
    out_img[:, :, center:center + 1, :, :],   # centre slice, dim kept
    out_img[:, :, center + 1:, :, :],         # slices after the centre
)

torch.Size([16, 1, 3, 96, 96])


In [13]:
# Keep only the centre slice (index `center` along the size-3 axis) of each
# label patch, drop the singleton channel axis, and keep the sliced axis as a
# length-1 dim. Derived from `batch` rather than from `labels` itself, so
# re-running this cell gives the same result instead of squeezing and slicing
# an already-sliced tensor again.
labels = batch["label"].squeeze(1)[:, center:center + 1, :, :]
labels.shape

torch.Size([16, 1, 96, 96])