In [None]:
# import cv2
# import numpy as np
# import os
# import matplotlib.pyplot as plt

# path = './datasets/labels/val/'
# # Get list of all png files in the directory
# png_files = [f for f in os.listdir(path) if f.endswith('.png')]
# unique_values = set()
# for file in png_files:
#     img = cv2.imread(os.path.join(path, file), cv2.IMREAD_GRAYSCALE)
#     unique_values.update(np.unique(img))

# print(unique_values)

# path = './datasets/labels/train/'
# # Get list of all png files in the directory
# png_files = [f for f in os.listdir(path) if f.endswith('.png')]
# unique_values = set()
# for file in png_files:
#     img = cv2.imread(os.path.join(path, file), cv2.IMREAD_GRAYSCALE)
#     unique_values.update(np.unique(img))

# print(unique_values)


{0, 128, 102, 42, 170, 204, 51, 85, 213, 153, 255}
{0, 128, 64, 102, 42, 170, 204, 51, 85, 213, 153, 255}


In [None]:
# import torch
# import torch.nn as nn
# from monai.networks.nets import UNet

# class UNet2_5D_v2(nn.Module):
#     def __init__(self, out_channels=6):
#         super().__init__()
        
#         # 초기 3D 처리 레이어
#         self.init_3d = nn.Sequential(
#             nn.Conv3d(1, 64, kernel_size=(11, 3, 3), padding=(0, 1, 1)),
#             nn.BatchNorm3d(64),
#             nn.ReLU(inplace=True)
#         )
        
#         # 2D UNet
#         self.unet = UNet(
#             spatial_dims=2,
#             in_channels=64,  # 3D 컨볼루션 출력 채널
#             out_channels=out_channels,
#             channels=(64, 128, 256, 512),
#             strides=(2, 2, 2, 2),
#             num_res_units=2
#         )

#     def forward(self, x):
#         # x shape: (batch, 1, 11, H, W)
#         # 3D 처리
#         x = x.unsqueeze(1)
#         x = self.init_3d(x)  # (batch, 64, 1, H, W)
#         x = x.squeeze(2)     # (batch, 64, H, W)
        
#         # 2D UNet
#         return self.unet(x)

# # 테스트 코드
# if __name__ == "__main__":
#     model = UNet2_5D_v2(out_channels=6)
#     x = torch.randn(8, 11, 256, 256)
#     output = model(x)
#     print(f"Output shape: {output.shape}")  # Expected: (8, 6, 256, 256)

  from .autonotebook import tqdm as notebook_tqdm


Output shape: torch.Size([8, 6, 256, 256])


In [None]:
# import torch
# from monai.losses import DiceLoss
# from monai.metrics import DiceMetric
# from monai.networks.utils import one_hot

# class SegmentationTrainer:
#     def __init__(self, model, device='cuda'):
#         self.model = model.to(device)
#         self.device = device
        
#         # Loss와 Metric 초기화
#         self.loss_function = DiceLoss(
#             include_background=True, 
#             to_onehot_y=True, 
#             softmax=True
#         )
#         self.dice_metric = DiceMetric(
#             include_background=True,
#             reduction="mean",
#             get_not_nans=False
#         )
        
#     def train_step(self, images, labels):
#         """
#         images: (B, 11, H, W)
#         labels: (B, H, W) with class indices
#         """
#         images = images.to(self.device)
#         labels = labels.to(self.device)
        
#         # Forward pass
#         outputs = self.model(images)  # (B, num_classes, H, W)
        
#         # Loss 계산
#         loss = self.loss_function(outputs, labels)
        
#         # Metric 계산 (예측값을 클래스 인덱스로 변환)
#         preds = torch.argmax(outputs, dim=1)  # (B, H, W)
#         self.dice_metric(y_pred=preds, y=labels)
        
#         return loss

# # 사용 예시
# if __name__ == "__main__":
#     model = UNet2_5D_v2(out_channels=6)  # 6개 클래스
#     trainer = SegmentationTrainer(model)
    
#     # 예시 데이터
#     images = torch.randn(8, 11, 256, 256)  # (B, slice, H, W)
#     labels = torch.randint(0, 6, (8, 256, 256))  # (B, H, W)
    
#     # 학습 단계
#     loss = trainer.train_step(images, labels)
#     dice_score = trainer.dice_metric.aggregate().item()
#     trainer.dice_metric.reset()
    
#     print(f"Loss: {loss.item():.4f}")
#     print(f"Dice Score: {dice_score:.4f}")

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import zarr
import numpy as np

class CryoETDataset2_5D(Dataset):
    def __init__(self, zarr_path, crop_size=(126, 126), slice_depth=11, transform=None):
        """
        Cryo-ET 데이터셋을 2.5D 입력으로 준비하는 Dataset 클래스.
        
        Args:
            zarr_path (str): Zarr 파일 경로.
            crop_size (tuple): (height, width)로 크기를 정의.
            slice_depth (int): 병합할 슬라이스 수 (e.g., 11).
            transform (callable, optional): 슬라이스에 적용할 Transform.
        """
        self.dataset = zarr.open(zarr_path, mode='r')  # (D, H, W)
        self.crop_size = crop_size
        self.slice_depth = slice_depth
        self.transform = transform

        self.depth, self.height, self.width = self.dataset.shape
        self.pad_size = slice_depth // 2

    def __len__(self):
        # 모든 슬라이스를 대상으로 하되, 경계를 고려하여 슬라이스 개수 반환
        return self.depth

    def __getitem__(self, idx):
        # 경계 복제를 위해 padding 적용
        padded_data = np.pad(
            self.dataset,
            pad_width=((self.pad_size, self.pad_size), (0, 0), (0, 0)),
            mode='edge'
        )  # (D + 2*pad_size, H, W)

        # 11개의 슬라이스 추출 (idx 중심으로)
        slices = padded_data[idx:idx + self.slice_depth]  # (11, H, W)

        # 중앙 crop (H, W -> crop_size)
        start_h = (slices.shape[1] - self.crop_size[0]) // 2
        start_w = (slices.shape[2] - self.crop_size[1]) // 2
        slices_cropped = slices[:, start_h:start_h + self.crop_size[0], start_w:start_w + self.crop_size[1]]

        # Transform 적용
        if self.transform:
            slices_cropped = self.transform(slices_cropped)

        # PyTorch Tensor로 변환 (C=11, H, W)
        slices_tensor = torch.tensor(slices_cropped, dtype=torch.float32)

        return slices_tensor  # (11, H, W)

# Dataset 및 DataLoader 준비
zarr_file_path = "path/to/cryoet_dataset.zarr"
dataset = CryoETDataset2_5D(zarr_file_path)

# DataLoader
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# 데이터 확인
for batch in dataloader:
    print("Batch shape:", batch.shape)  # (B, 11, 126, 126)
    break


In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
import zarr
import numpy as np

class CryoETDataset2_5D(Dataset):
    def __init__(self, zarr_path, crop_size=(126, 126), slice_depth=11, transform=None):
        """
        Cryo-ET 데이터셋을 2.5D 입력으로 준비하는 Dataset 클래스.
        
        Args:
            zarr_path (str): Zarr 파일 경로.
            crop_size (tuple): (height, width)로 크기를 정의.
            slice_depth (int): 병합할 슬라이스 수 (e.g., 11).
            transform (callable, optional): 슬라이스에 적용할 Transform.
        """
        self.dataset = zarr.open(zarr_path, mode='r')  # (D, H, W)
        self.crop_size = crop_size
        self.slice_depth = slice_depth
        self.transform = transform

        self.depth, self.height, self.width = self.dataset.shape
        self.pad_size = slice_depth // 2

    def __len__(self):
        # 모든 슬라이스를 대상으로 하되, 경계를 고려하여 슬라이스 개수 반환
        return self.depth

    def __getitem__(self, idx):
        # 경계 복제를 위해 padding 적용
        padded_data = np.pad(
            self.dataset,
            pad_width=((self.pad_size, self.pad_size), (0, 0), (0, 0)),
            mode='edge'
        )  # (D + 2*pad_size, H, W)

        # 11개의 슬라이스 추출 (idx 중심으로)
        slices = padded_data[idx:idx + self.slice_depth]  # (11, H, W)

        # 중앙 crop (H, W -> crop_size)
        start_h = (slices.shape[1] - self.crop_size[0]) // 2
        start_w = (slices.shape[2] - self.crop_size[1]) // 2
        slices_cropped = slices[:, start_h:start_h + self.crop_size[0], start_w:start_w + self.crop_size[1]]

        # Transform 적용
        if self.transform:
            slices_cropped = self.transform(slices_cropped)

        # PyTorch Tensor로 변환 (C=11, H, W)
        slices_tensor = torch.tensor(slices_cropped, dtype=torch.float32)

        return slices_tensor  # (11, H, W)

# Dataset 및 DataLoader 준비
zarr_file_path = './kaggle/input/czii-cryo-et-object-identification/train/static/ExperimentRuns/TS_5_4/VoxelSpacing10.000/ctfdeconvolved.zarr'

dataset = CryoETDataset2_5D(zarr_file_path)

# DataLoader
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# 데이터 확인
for batch in dataloader:
    print("Batch shape:", batch.shape)  # (B, 11, 126, 126)
    break


AttributeError: 

In [1]:
import pickle
with open('data_dicts.pkl', 'rb') as f:
    data_dicts = pickle.load(f)

In [None]:
import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import cv2
from tqdm import tqdm
import torch
from monai.transforms import (
    Compose, 
    EnsureChannelFirstd,
    NormalizeIntensityd,
    RandFlipd,
    RandRotate90d
)

def save_data_as_png(data_dicts, output_dir='./data'):
    """데이터를 PNG 파일로 저장"""
    output_dir = Path(output_dir)
    (output_dir / 'images').mkdir(parents=True, exist_ok=True)
    (output_dir / 'labels').mkdir(parents=True, exist_ok=True)
    
    for idx, data in tqdm(enumerate(data_dicts), desc='Saving data'):
        volume = data['image']  # (D, H, W)
        label = data['label']   # (D, H, W)
        
        for slice_idx in range(volume.shape[0]):
            img_slice = volume[slice_idx]
            img_slice = ((img_slice - img_slice.min()) / (img_slice.max() - img_slice.min()) * 255).astype(np.uint8)
            cv2.imwrite(str(output_dir / 'images' / f'volume_{idx:03d}_slice_{slice_idx:03d}.png'), img_slice)
            
            label_slice = label[slice_idx].astype(np.uint8)
            cv2.imwrite(str(output_dir / 'labels' / f'volume_{idx:03d}_slice_{slice_idx:03d}.png'), label_slice)

class CryoET_2_5D_Dataset(Dataset):
    def __init__(self, data_dir, transform=None, slice_thickness=11):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.slice_thickness = slice_thickness
        self.half_thickness = slice_thickness // 2
        
        self.image_files = sorted(list((self.data_dir / 'images').glob('*.png')))
        self.volume_slices = self._group_slices()
        
    def _group_slices(self):
        groups = {}
        for img_path in self.image_files:
            volume_id = img_path.stem.split('_slice_')[0]
            if volume_id not in groups:
                groups[volume_id] = []
            groups[volume_id].append(img_path)
        return groups
    
    def __len__(self):
        return sum(len(slices) - self.slice_thickness + 1 for slices in self.volume_slices.values())
    
    def __getitem__(self, idx):
        for volume_id, slices in self.volume_slices.items():
            if idx < len(slices) - self.slice_thickness + 1:
                center_idx = idx + self.half_thickness
                break
            idx -= len(slices) - self.slice_thickness + 1
        
        input_slices = []
        for i in range(center_idx - self.half_thickness, center_idx + self.half_thickness + 1):
            img = cv2.imread(str(slices[i]), cv2.IMREAD_GRAYSCALE)
            img = img.astype(np.float32) / 255.0
            input_slices.append(img)
        
        label_path = str(slices[center_idx]).replace('images', 'labels')
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        
        data_dict = {
            'image': np.stack(input_slices)[None, ...],  # (1, 11, H, W)
            'label': label[None, ...].astype(np.int64)   # (1, H, W)
        }
        
        if self.transform:
            data = self.transform(data_dict)
            return data['image'].squeeze(0), data['label'].squeeze(0)
        
        return torch.from_numpy(data_dict['image']), torch.from_numpy(data_dict['label'])

transforms = Compose([
    EnsureChannelFirstd(
        keys=['image', 'label'],
        channel_dim=None,
        allow_missing_keys=True
    ),
    NormalizeIntensityd(
        keys=['image'],
        allow_missing_keys=True
    ),
    RandFlipd(
        keys=['image', 'label'],
        spatial_axis=[1, 2],
        prob=0.5,
        allow_missing_keys=True
    ),
    RandRotate90d(
        keys=['image', 'label'],
        spatial_axes=[1, 2],
        prob=0.5,
        allow_missing_keys=True
    ),
])

def create_dataloader(data_dir, batch_size=8, train_ratio=0.8):
    dataset = CryoET_2_5D_Dataset(
        data_dir=data_dir,
        transform=transforms,
        slice_thickness=11
    )
    
    train_size = int(train_ratio * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    return train_loader, val_loader

def validate_data_format(loader):
    for batch in loader:
        images, labels = batch
        print("데이터 형식 검증:")
        print(f"이미지 shape: {images.shape}")
        print(f"이미지 dtype: {images.dtype}")
        print(f"라벨 shape: {labels.shape}")
        print(f"라벨 dtype: {labels.dtype}")
        print(f"라벨 고유값: {torch.unique(labels)}")
        break

if __name__ == "__main__":
    with open('data_dicts.pkl', 'rb') as f:
        data_dicts = pickle.load(f)
    
    save_data_as_png(data_dicts)
    train_loader, val_loader = create_dataloader('./data', batch_size=8)
    validate_data_format(train_loader)

Saving data: 7it [00:08,  1.20s/it]


RuntimeError: applying transform <monai.transforms.spatial.dictionary.RandFlipd object at 0x000002957D74BE90>