In [38]:
import numpy as np
from monai.data import Dataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, NormalizeIntensityd,
    Orientationd, CropForegroundd, GaussianSmoothd, ScaleIntensityd,
    RandSpatialCropd, RandRotate90d, RandFlipd, RandGaussianNoised,
    ToTensord
)
import torch
from torch.utils.data import DataLoader

import os

from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference


In [33]:
TRAIN_IMG_DIR = "./datasets/train/images"
TRAIN_LABEL_DIR = "./datasets/train/labels"
VAL_IMG_DIR = "./datasets/val/images"
VAL_LABEL_DIR = "./datasets/val/labels"

train_list = os.listdir(TRAIN_IMG_DIR)
val_list = os.listdir(VAL_IMG_DIR)
train_files = []
valid_files = []


for name in train_list:
    train_image = np.load(os.path.join(TRAIN_IMG_DIR, f"{name}"))    
    train_label = np.load(os.path.join(TRAIN_LABEL_DIR, f"{name.replace("image", "label")}"))

    train_files.append({"image": train_image, "label": train_label})    

for name in val_list:
    valid_image = np.load(os.path.join(VAL_IMG_DIR, f"{name}"))
    valid_label = np.load(os.path.join(VAL_LABEL_DIR, f"{name.replace("image", "label")}"))

    valid_files.append({"image": valid_image, "label": valid_label})


class CryoETDataset(Dataset):
    def __init__(self, data, transforms, slice_thickness=5, voxel_size=(10, 10, 10), origin=(0, 0, 0)):
        """
        Args:
            data: 리스트 형태의 데이터셋 ({"image": np.ndarray, "label": np.ndarray} 형태).
            transforms: MONAI 변환 객체.
            slice_thickness: 슬라이스 두께.
            voxel_size: 각 축의 voxel 크기 (z, x, y).
            origin: 데이터 원점 (z, x, y).
        """
        super().__init__(data, transforms)
        self.slice_thickness = slice_thickness
        self.half_thickness = slice_thickness // 2
        self.voxel_size = voxel_size
        self.origin = origin
        self.slices = []  # 전체 슬라이스 저장

        # 모든 슬라이스를 미리 생성
        for data_dict in data:
            image = data_dict["image"]
            label = data_dict.get("label", None)

            # Z 축에 패딩 추가
            image = np.pad(image, ((self.half_thickness, self.half_thickness), (0, 0), (0, 0)), mode="constant", constant_values=0)
            if label is not None:
                label = np.pad(label, ((self.half_thickness, self.half_thickness), (0, 0), (0, 0)), mode="constant", constant_values=0)

            depth = image.shape[0]

            for center_idx in range(self.half_thickness, depth - self.half_thickness):
                self.slices.append({
                    "image": image[center_idx - self.half_thickness:center_idx + self.half_thickness + 1],
                    "label": label[center_idx] if label is not None else None,
                    "original_index": center_idx - self.half_thickness,  # 패딩 이전의 원래 인덱스
                    "real_position": self.compute_voxel_position(center_idx - self.half_thickness, 0, 0),
                })

    def compute_voxel_position(self, z, x, y):
        dz, dx, dy = self.voxel_size
        oz, ox, oy = self.origin
        z_real = z * dz + oz
        x_real = x * dx + ox
        y_real = y * dy + oy
        return z_real, x_real, y_real

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, index):
        slice_data = self.slices[index]
        return {
            "image": slice_data["image"],  # 슬라이스 (slice_thickness, H, W)
            "label": slice_data["label"],  # 중앙 라벨 (H, W) 또는 None
            "original_index": slice_data["original_index"],  # 중앙 슬라이스 인덱스
            "real_position": slice_data["real_position"],  # 실제 좌표
        }

In [30]:
# 배치 데이터 확인
def inspect_batch(loader):
    batch = next(iter(loader))
    print("=== 배치 데이터 확인 ===")
    print(f"Batch image shape: {batch['image'].shape}")
    print(f"Batch label shape: {batch['label'].shape if batch['label'] is not None else 'None'}")
    print(f"Image dtype: {batch['image'].dtype}")
    if batch["label"] is not None:
        print(f"Label dtype: {batch['label'].dtype}")
        print(f"Label unique values: {torch.unique(batch['label'])}")

In [11]:
train_files[0]['image'].shape

(184, 630, 630)

In [None]:
train_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    Orientationd(keys=["image", "label"], axcodes="SRA")
])

# 데이터셋 생성
train_dataset = CryoETDataset(data=train_files, transforms=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)


        
for batch in train_loader:
    print(f"Batch image shape: {batch['image'].shape}")
    print(f"Batch label shape: {batch['label'].shape if batch['label'] is not None else 'None'}")
    print(f"Batch original indices: {batch['original_index']}")
    print(f"Batch real positions: {batch['real_position']}")
    break

Batch image shape: torch.Size([1, 5, 630, 630])
Batch label shape: torch.Size([1, 630, 630])
Batch original indices: tensor([113])
Batch real positions: [tensor([1130]), tensor([0]), tensor([0])]


# Validation DataLoader

In [None]:
val_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    NormalizeIntensityd(keys="image"),
    ScaleIntensityd(keys="image", minv=0.0, maxv=1.0),
    ToTensord(keys=["image", "label"]),
])
val_dataset = CryoETDataset(data=valid_files, transforms=val_transforms)

Loading dataset: 100%|██████████| 26496/26496 [00:07<00:00, 3439.67it/s]


# Task

In [None]:
no_label_transforms = Compose([
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    CropForegroundd(keys=["image"], source_key="image"),
    NormalizeIntensityd(keys="image"),
    ScaleIntensityd(keys="image", minv=0.0, maxv=1.0),
    ToTensord(keys=["image"]),
])
# no_label_dataset = CryoETDataset(data=no_label_data, transforms=no_label_transforms)

In [41]:
import torch
import torch.nn as nn
from monai.networks.nets import UNet

class UNet2_5D_v2(nn.Module):
    def __init__(self, init_ch=3 ,out_channels=6):
        super().__init__()
        
        # 초기 3D 처리 레이어
        self.init_3d = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(init_ch, 3, 3), padding=(0, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True)
        )
        
        # 2D UNet
        self.unet = UNet(
            spatial_dims=2,
            in_channels=64,  # 3D 컨볼루션 출력 채널
            out_channels=out_channels,
            channels=(128, 256, 512, 1024),
            strides=(2, 2, 2, 2),
            num_res_units=2
        )

    def forward(self, x):
        # x shape: (batch, 1, 11, H, W)
        # 3D 처리
        
        x = self.init_3d(x)  # (batch, 64, 1, H, W)
        x = x.squeeze(2)     # (batch, 64, H, W)
        
        # 2D UNet
        return self.unet(x)

# 테스트 코드

model = UNet2_5D_v2(init_ch=3,out_channels=7)





In [43]:
x = torch.randn(2, 1, 3, 96, 96)
print(f"Input Shape: {x.shape}")

Input Shape: torch.Size([2, 1, 3, 96, 96])


In [45]:
output = model(x)
print(f"Output shape: {output.shape}")  # Expected: (8, 6, 256, 256)

Output shape: torch.Size([2, 7, 96, 96])


In [None]:
import torch
from monai.losses import DiceLoss
from torch import optim

# Loss and Optimizer
criterion = DiceLoss(to_onehot_y=True, softmax=True)  # Dice Loss with softmax
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        
        # Ensure targets are long integers
        targets = batch['label'].long()  # 라벨: (B, H, W)

        
        print(f"Modified Targets shape: {targets.shape}")  # (B, 1, H, W)

        # Forward pass
        outputs = model(batch['image'])  # 모델 출력: (B, 7, H, W)
        print(f"Outputs shape: {outputs.shape}, Targets shape: {targets.shape}")

        # Compute loss
        loss = criterion(outputs, targets)  # Dice Loss
        print(f"Loss: {loss.item()}")

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print epoch loss
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

Modified Targets shape: torch.Size([1, 1, 3, 96, 96])
Outputs shape: torch.Size([1, 7, 96, 96]), Targets shape: torch.Size([1, 1, 3, 96, 96])


AssertionError: ground truth has different shape (torch.Size([1, 7, 3, 96, 96])) from input (torch.Size([1, 7, 96, 96]))