In [1]:
from typing import List, Tuple, Union
import numpy as np
import torch
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
import os
from monai.transforms import (
    Compose, 
    EnsureChannelFirstd, 
    Orientationd,  
    AsDiscrete,  
    RandFlipd, 
    RandRotate90d, 
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TRAIN_IMG_DIR = "./datasets/train/images"
TRAIN_LABEL_DIR = "./datasets/train/labels"
VAL_IMG_DIR = "./datasets/val/images"
VAL_LABEL_DIR = "./datasets/val/labels"

train_list = os.listdir(TRAIN_IMG_DIR)
val_list = os.listdir(VAL_IMG_DIR)
train_files = []
valid_files = []


for name in train_list:
    train_image = np.load(os.path.join(TRAIN_IMG_DIR, f"{name}"))    
    train_label = np.load(os.path.join(TRAIN_LABEL_DIR, f"{name.replace("image", "label")}"))

    train_files.append({"image": train_image, "label": train_label})    

for name in val_list:
    valid_image = np.load(os.path.join(VAL_IMG_DIR, f"{name}"))
    valid_label = np.load(os.path.join(VAL_LABEL_DIR, f"{name.replace("image", "label")}"))

    valid_files.append({"image": valid_image, "label": valid_label})

In [3]:
# Non-random transforms to be cached
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="ASR")
])

raw_train_ds = CacheDataset(data=train_files, transform=non_random_transforms, cache_rate=1.0)


my_num_samples = 1
train_batch_size = 1

xy_patch = 96
z_patch = 3
# Random transforms to be applied during training
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[z_patch, xy_patch, xy_patch],
        num_classes=7,
        num_samples=my_num_samples
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),    
])


train_ds = Dataset(data=raw_train_ds, transform=random_transforms)


# DataLoader remains the same
train_loader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)


Loading dataset: 100%|██████████| 24/24 [00:03<00:00,  6.64it/s]


In [4]:
# 데이터 검증 함수
def inspect_batch(loader):
    # 첫 번째 배치 가져오기
    batch = next(iter(loader))
    
    print("=== 배치 데이터 검증 ===")
    print(f"이미지 shape: {batch['image'].shape}")
    print(f"이미지 dtype: {batch['image'].dtype}")
    print(f"이미지 값 범위: [{batch['image'].min():.3f}, {batch['image'].max():.3f}]")
    print("\n")
    print(f"라벨 shape: {batch['label'].shape}")
    print(f"라벨 dtype: {batch['label'].dtype}")
    print(f"라벨 고유값: {torch.unique(batch['label'])}")

# 실행
inspect_batch(train_loader)

=== 배치 데이터 검증 ===
이미지 shape: torch.Size([1, 1, 96, 96, 3])
이미지 dtype: torch.float32
이미지 값 범위: [-9.597, 4.349]


라벨 shape: torch.Size([1, 1, 96, 96, 3])
라벨 dtype: torch.uint8
라벨 고유값: tensor([0, 1, 5, 6], dtype=torch.uint8)


# Validation DataLoader

In [5]:
import os
import numpy as np
import json
from pathlib import Path

valid_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image")
])

def load_validation_patches(patch_dir):
    """
    저장된 validation 패치들을 로드하여 데이터셋용 리스트 생성
    Args:
        patch_dir: 패치가 저장된 디렉토리 (images, labels 서브디렉토리와 coordinates.json 포함)
    """
    patch_dir = Path(patch_dir)
    val_patched_data = []
    
    # coordinates.json 로드
    with open(patch_dir / "coordinates.json", 'r') as f:
        coordinates = json.load(f)
    
    # 각 패치에 대해
    for coord in coordinates:
        image = np.load(patch_dir / "images" / coord["patch_file"])
        label = np.load(patch_dir / "labels" / coord["patch_file"])
        
        val_patched_data.append({
            "image": image,      # shape: (11, 96, 96)
            "label": label,      # shape: (96, 96)
            "coords": coord      # 원본 위치 정보 (옵션)
        })
    
    return val_patched_data

# 패치 데이터 로드
val_patched_data = load_validation_patches("./datasets/val_patches")

# Dataset과 DataLoader 설정
valid_ds = CacheDataset(data=val_patched_data, transform=valid_transforms, cache_rate=1.0)

valid_batch_size = 1
valid_loader = DataLoader(
    valid_ds,
    batch_size=valid_batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)


Loading dataset: 100%|██████████| 26496/26496 [00:15<00:00, 1699.15it/s]


In [6]:

# raw_valid_ds = CacheDataset(data=valid_files, transform=non_random_transforms, cache_rate=1.0)
# valid_ds = Dataset(data=raw_valid_ds, transform=random_transforms)
# valid_batch_size = 1

# # DataLoader remains the same
# valid_loader = DataLoader(
#     valid_ds,
#     batch_size=valid_batch_size,
#     shuffle=False,
#     num_workers=0,
#     pin_memory=torch.cuda.is_available()
# )


In [12]:
# 데이터 검증 함수
def inspect_batch(loader):
    # 첫 번째 배치 가져오기
    batch = next(iter(loader))
    
    print("=== 배치 데이터 검증 ===")
    print(f"이미지 shape: {batch['image'].shape}")
    print(f"이미지 dtype: {batch['image'].dtype}")
    print(f"이미지 값 범위: [{batch['image'].min():.3f}, {batch['image'].max():.3f}]")
    print("\n")
    print(f"라벨 shape: {batch['label'].shape}")
    print(f"라벨 dtype: {batch['label'].dtype}")
    print(f"라벨 고유값: {torch.unique(batch['label'])}")

# 실행
inspect_batch(valid_loader)

=== 배치 데이터 검증 ===
이미지 shape: torch.Size([1, 1, 3, 96, 96])
이미지 dtype: torch.float32
이미지 값 범위: [-4.900, 5.728]


라벨 shape: torch.Size([1, 1, 96, 96])
라벨 dtype: torch.uint8
라벨 고유값: tensor([0, 5], dtype=torch.uint8)
