In [7]:

from pathlib import Path
import zarr

def find_all_zarr_files():
    """각 실험 디렉토리의 모든 zarr 파일 찾기"""
    base_path = Path('./kaggle/input/czii-cryo-et-object-identification/train')
    experiments = {}
    
    for exp_dir in base_path.glob("**/VoxelSpacing10.000"):
        exp_name = exp_dir.parent.name  # 실험 이름 (예: TS_5_4)
        zarr_files = list(exp_dir.glob("*.zarr"))
        
        if zarr_files:
            experiments[exp_name] = {
                'path': str(exp_dir),
                'files': [f.name for f in zarr_files]
            }
    
    return experiments

def test_all_zarr_files():
    """모든 실험의 zarr 파일 정보 출력"""
    experiments = find_all_zarr_files()
    
    for exp_name, info in experiments.items():
        print(f"\n실험: {exp_name}")
        print(f"경로: {info['path']}")
        print("zarr 파일들:")
        for zarr_file in info['files']:
            zarr_path = Path(info['path']) / zarr_file
            try:
                root = zarr.open(str(zarr_path), mode='r')
                sample = root[0]
                print(f"- {zarr_file}: shape={sample.shape}")
            except Exception as e:
                print(f"- {zarr_file}: 오류 발생 ({str(e)})")

if __name__ == "__main__":
    test_all_zarr_files()


실험: TS_5_4
경로: kaggle\input\czii-cryo-et-object-identification\train\static\ExperimentRuns\TS_5_4\VoxelSpacing10.000
zarr 파일들:
- ctfdeconvolved.zarr: shape=(184, 630, 630)
- denoised.zarr: shape=(184, 630, 630)
- isonetcorrected.zarr: shape=(184, 630, 630)
- wbp.zarr: shape=(184, 630, 630)

실험: TS_69_2
경로: kaggle\input\czii-cryo-et-object-identification\train\static\ExperimentRuns\TS_69_2\VoxelSpacing10.000
zarr 파일들:
- ctfdeconvolved.zarr: shape=(184, 630, 630)
- denoised.zarr: shape=(184, 630, 630)
- isonetcorrected.zarr: shape=(184, 630, 630)
- wbp.zarr: shape=(184, 630, 630)

실험: TS_6_4
경로: kaggle\input\czii-cryo-et-object-identification\train\static\ExperimentRuns\TS_6_4\VoxelSpacing10.000
zarr 파일들:
- ctfdeconvolved.zarr: shape=(184, 630, 630)
- denoised.zarr: shape=(184, 630, 630)
- isonetcorrected.zarr: shape=(184, 630, 630)
- wbp.zarr: shape=(184, 630, 630)

실험: TS_6_6
경로: kaggle\input\czii-cryo-et-object-identification\train\static\ExperimentRuns\TS_6_6\VoxelSpacing10.000
zarr

In [12]:
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Tuple
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    NormalizeIntensityd,
    RandFlipd,
    RandRotate90d
)

class CryoET_2_5D_Dataset(Dataset):
    def __init__(self, 
                 data_dir: str, 
                 mode: str = 'train',
                 slice_depth: int = 3,
                 crop_size: Tuple[int, int] = (96, 96),
                 transform = None):
        """
        Args:
            data_dir: 데이터셋 기본 경로
            mode: 'train' 또는 'val'
            slice_depth: 중심 기준 슬라이스 수 (항상 홀수)
            crop_size: (height, width) crop 크기
            transform: MONAI transform
        """
        self.data_dir = Path(data_dir) / mode
        self.slice_depth = slice_depth
        self.crop_size = crop_size
        self.transform = transform
        self.pad_size = slice_depth // 2
        
        # 이미지와 라벨 파일 리스트
        self.image_files = sorted(list((self.data_dir / 'images').glob('*.npy')))
        self.label_files = sorted(list((self.data_dir / 'labels').glob('*.npy')))
        
        assert len(self.image_files) == len(self.label_files), "이미지와 라벨 수가 다릅니다"
        
    def __len__(self):
        return len(self.image_files)
        
    def __getitem__(self, idx) -> Dict:
        # 이미지와 라벨 로드
        image = np.load(self.image_files[idx])  # (D, H, W)
        label = np.load(self.label_files[idx])  # (D, H, W)
        
        # 원본 shape 저장
        original_shape = image.shape
        
        # 중심 슬라이스 선택
        center_idx = image.shape[0] // 2
        start_idx = center_idx - self.pad_size
        end_idx = center_idx + self.pad_size + 1
        
        # 패딩 처리
        if start_idx < 0 or end_idx > image.shape[0]:
            padded_image = np.pad(
                image,
                ((self.pad_size, self.pad_size), (0, 0), (0, 0)),
                mode='edge'
            )
            start_idx += self.pad_size
            end_idx += self.pad_size
            slices = padded_image[start_idx:end_idx]
        else:
            slices = image[start_idx:end_idx]
        
        # 중심 슬라이스의 라벨만 사용
        center_label = label[center_idx]
        
        # 중앙 crop 좌표 계산
        start_h = (image.shape[1] - self.crop_size[0]) // 2
        start_w = (image.shape[2] - self.crop_size[1]) // 2
        
        # crop 적용
        slices_cropped = slices[:, 
                              start_h:start_h + self.crop_size[0], 
                              start_w:start_w + self.crop_size[1]]
        label_cropped = center_label[start_h:start_h + self.crop_size[0],
                                   start_w:start_w + self.crop_size[1]]
        
        data_dict = {
            'image': slices_cropped[None, ...],  # (1, D, H, W)
            'label': label_cropped[None, ...],  # (1, H, W)
            'crop_info': {
                'z_start': start_idx,
                'z_center': center_idx,
                'y_start': start_h,
                'x_start': start_w,
                'crop_size': self.crop_size,
                'original_shape': original_shape,
                'file_name': self.image_files[idx].stem
            }
        }
        
        if self.transform:
            data_dict = self.transform(data_dict)
            
        return data_dict

# Transform 설정 수정
train_transforms = Compose([
    EnsureChannelFirstd(
        keys=['image', 'label'],
        channel_dim=0,  # 첫 번째 차원을 채널로 사용
        allow_missing_keys=True
    ),
    NormalizeIntensityd(
        keys=['image'],
        allow_missing_keys=True
    ),
    RandFlipd(
        keys=['image', 'label'],
        spatial_axis=[1, 2],  # H, W 축에서 뒤집기
        prob=0.5,
        allow_missing_keys=True
    ),
    RandRotate90d(
        keys=['image', 'label'],
        spatial_axes=[1, 2],  # H, W 평면에서 회전
        prob=0.5,
        allow_missing_keys=True
    )
])

def create_dataloaders(data_dir: str, batch_size: int = 8):
    train_ds = CryoET_2_5D_Dataset(
        data_dir=data_dir,
        mode='train',
        transform=train_transforms,
        slice_depth=3,
        crop_size=(96, 96)
    )
    
    val_ds = CryoET_2_5D_Dataset(
        data_dir=data_dir,
        mode='val',
        transform=None,
        slice_depth=3,
        crop_size=(96, 96)
    )
    
    train_loader = DataLoader(
        train_ds, 
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    
    return train_loader, val_loader

# 테스트
if __name__ == "__main__":
    data_dir = "./datasets"
    train_loader, val_loader = create_dataloaders(data_dir)
    
    # 데이터 확인
    batch = next(iter(train_loader))
    print("이미지 shape:", batch['image'].shape)
    print("라벨 shape:", batch['label'].shape)
    print("Crop 정보:", batch['crop_info'])

RuntimeError: applying transform <monai.transforms.spatial.dictionary.RandFlipd object at 0x00000299EAA2ED50>