In [1]:
import numpy as np
import torch
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
import os
from monai.transforms import (
    Compose, 
    EnsureChannelFirstd, 
    Orientationd,  
    AsDiscrete,  
    RandFlipd, 
    RandRotate90d, 
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TRAIN_IMG_DIR = "./datasets/train/images"
TRAIN_LABEL_DIR = "./datasets/train/labels"
VAL_IMG_DIR = "./datasets/val/images"
VAL_LABEL_DIR = "./datasets/val/labels"

train_list = os.listdir(TRAIN_IMG_DIR)
val_list = os.listdir(VAL_IMG_DIR)
train_files = []
valid_files = []


for name in train_list:
    image = np.load(os.path.join(TRAIN_IMG_DIR, f"{name}"))    
    label = np.load(os.path.join(TRAIN_LABEL_DIR, f"{name.replace("image", "label")}"))

    train_files.append({"image": image, "label": label})    

for name in val_list:
    image = np.load(os.path.join(VAL_IMG_DIR, f"{name}"))
    label = np.load(os.path.join(VAL_LABEL_DIR, f"{name.replace("image", "label")}"))

    valid_files.append({"image": image, "label": label})

In [3]:
# Non-random transforms to be cached
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="ASR")
])

raw_train_ds = CacheDataset(data=train_files, transform=non_random_transforms, cache_rate=1.0)


my_num_samples = 1
train_batch_size = 1


# Random transforms to be applied during training
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[11, 96, 96],
        num_classes=7,
        num_samples=my_num_samples
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),    
])


train_ds = Dataset(data=raw_train_ds, transform=random_transforms)


# DataLoader remains the same
train_loader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)


Loading dataset: 100%|██████████| 24/24 [00:03<00:00,  6.74it/s]


In [4]:
# 데이터 검증 함수
def inspect_batch(loader):
    # 첫 번째 배치 가져오기
    batch = next(iter(loader))
    
    print("=== 배치 데이터 검증 ===")
    print(f"이미지 shape: {batch['image'].shape}")
    print(f"이미지 dtype: {batch['image'].dtype}")
    print(f"이미지 값 범위: [{batch['image'].min():.3f}, {batch['image'].max():.3f}]")
    print("\n")
    print(f"라벨 shape: {batch['label'].shape}")
    print(f"라벨 dtype: {batch['label'].dtype}")
    print(f"라벨 고유값: {torch.unique(batch['label'])}")

# 실행
inspect_batch(train_loader)

=== 배치 데이터 검증 ===
이미지 shape: torch.Size([1, 1, 11, 96, 96])
이미지 dtype: torch.float32
이미지 값 범위: [-10.437, 2.000]


라벨 shape: torch.Size([1, 1, 11, 96, 96])
라벨 dtype: torch.uint8
라벨 고유값: tensor([0, 4, 5], dtype=torch.uint8)


In [None]:

raw_valid_ds = CacheDataset(data=valid_files, transform=non_random_transforms, cache_rate=1.0)
valid_ds = Dataset(data=raw_valid_ds, transform=random_transforms)
valid_batch_size = 1

# DataLoader remains the same
valid_loader = DataLoader(
    valid_ds,
    batch_size=valid_batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available()
)


Loading dataset: 100%|██████████| 4/4 [00:02<00:00,  1.55it/s]


In [15]:
# 첫 번째 배치를 가져옵니다.
for batch in valid_loader:
    # 배치의 입력 데이터와 레이블을 가져옵니다.
    x, y = batch['image'], batch['label']
    
    # 입력 데이터와 레이블의 모양을 출력합니다.
    print(f'입력 데이터의 모양: {x.shape}')
    print(f'레이블의 모양: {y.shape}')
    
    # 첫 번째 배치만 확인하면 되므로 break로 루프를 종료합니다.
    break

> collate dict key "image" out of 2 keys
>> collate/stack a list of tensors
>> E: stack expects each tensor to be equal size, but got [1, 7, 96, 96] at entry 0 and [1, 96, 96, 7] at entry 3, shape [torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96]), torch.Size([1, 96, 96, 7]), torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96]), torch.Size([1, 96, 96, 7]), torch.Size([1, 96, 96, 7]), torch.Size([1, 96, 96, 7]), torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96]), torch.Size([1, 7, 96, 96])] in collate([metatensor([[[[-5.1154e+00, -5.2550e+00, -3.4445e+00,  ..., -1.5089e-02,
           -9.2563e-01, -1.0091e+00],
          [-4.8709e+00, -4.5988e+00, -3.7460e+00,  ...,  2.4953e-01,
           -4.1819e-01, -2.6483e-01],
          [-4.8062e+00, -3.6248e+00, -3.2961e+00,  ...,  4.9671e-01,
           -8.3564e-01, -8.8224e-01],
          ...,
  

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/monai/data/utils.py", line 519, in list_data_collate
    ret = collate_fn(data)
          ^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 277, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 129, in collate
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 121, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/monai/data/utils.py", line 456, in collate_meta_tensor_fn
    collated = collate_tensor_fn(batch)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 174, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/torch/_tensor.py", line 1418, in __torch_function__
    ret = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [1, 7, 96, 96] at entry 0 and [1, 96, 96, 7] at entry 3

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seungwoo/anaconda3/envs/dust/lib/python3.12/site-packages/monai/data/utils.py", line 532, in list_data_collate
    raise RuntimeError(re_str) from re
RuntimeError: stack expects each tensor to be equal size, but got [1, 7, 96, 96] at entry 0 and [1, 96, 96, 7] at entry 3

MONAI hint: if your transforms intentionally create images of different shapes, creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its documentation).



           -1.4228e+00,  3.2713e-02],
          [-8.4190e-01,  2.9665e-01,  2.3930e+00,  ...,  1.5405e+00,
           -3.2178e-01, -4.3048e-01],
          [-3.5362e-01,  9.1296e-01,  1.3811e+00,  ...,  1.3989e+00,
           -4.7794e-01, -1.2878e-01]],

         [[ 7.7764e-01,  7.3444e-01,  3.5155e-01,  ...,  5.0070e+00,
            1.9922e+00, -4.2962e-01],
          [ 1.2200e+00,  7.2342e-01, -3.3823e-01,  ..., -3.8807e-01,
            2.2631e-01,  2.5034e+00],
          [ 7.6491e-01,  4.5611e-01,  5.5405e-01,  ...,  2.0146e-01,
            3.1584e-01,  2.1275e-01],
          ...,
          [ 5.5214e-01,  1.9965e+00,  5.9690e-02,  ...,  5.2932e-01,
           -4.7120e-02, -1.0298e+00],
          [-8.1410e-01, -1.0441e-02,  2.8859e+00,  ...,  6.7769e-01,
            9.4744e-01, -7.2454e-01],
          [ 1.8791e-01, -1.3932e-01,  1.1908e+00,  ..., -2.0222e-01,
            1.6572e-01, -7.0603e-01]],

         [[ 1.3471e+00,  5.0151e-01,  8.2111e-02,  ...,  4.4096e+00,
            2.174

In [None]:
# import torch
# from PIL import Image
# import numpy as np
# from pathlib import Path
# from torch.utils.data import Dataset, DataLoader
# from collections import defaultdict


# import torch
# from torch.utils.data import Dataset
# import numpy as np
# from pathlib import Path


# class NumpyCryoETDataset(Dataset):
#     def __init__(self, data_dir, num_channels=11, slice_size=(224, 224), stride=112, transform=None):
#         self.data_dir = Path(data_dir)
#         self.num_channels = num_channels
#         self.slice_size = slice_size
#         self.stride = stride
#         self.transform = transform

#         # Collect all numpy files
#         self.image_files = sorted(list(self.data_dir.glob("*_image.npy")))
#         self.label_files = sorted(list(self.data_dir.glob("*_label.npy")))
#         assert len(self.image_files) == len(self.label_files), "Mismatch between image and label files!"

#         # Generate indices for individual patches
#         self.data_indices = self._generate_indices()

#     def _generate_indices(self):
#         """Generate indices for individual patches."""
#         indices = []
#         for file_idx, (image_file, label_file) in enumerate(zip(self.image_files, self.label_files)):
#             image = np.load(image_file)
#             label = np.load(label_file)
#             D, H, W = image.shape

#             for z in range(self.num_channels // 2, D - self.num_channels // 2):
#                 for y in range(0, H - self.slice_size[0] + 1, self.stride):
#                     for x in range(0, W - self.slice_size[1] + 1, self.stride):
#                         indices.append((file_idx, z, y, x))
#         return indices

#     def __len__(self):
#         return len(self.data_indices)

#     def __getitem__(self, idx):
#         # Retrieve patch index
#         file_idx, z, y, x = self.data_indices[idx]

#         # Load numpy arrays
#         image = np.load(self.image_files[file_idx])
#         label = np.load(self.label_files[file_idx])

#         # Extract 2.5D patch
#         input_patch = image[z - self.num_channels // 2:z + self.num_channels // 2 + 1, y:y + self.slice_size[0], x:x + self.slice_size[1]]
#         label_patch = label[z, y:y + self.slice_size[0], x:x + self.slice_size[1]]

#         # Apply transformations if provided
#         if self.transform:
#             input_patch, label_patch = self.transform(input_patch, label_patch)

#         # Convert to PyTorch tensors
#         input_patch = torch.tensor(input_patch, dtype=torch.float32)
#         label_patch = torch.tensor(label_patch, dtype=torch.long)

#         return input_patch, label_patch

# # Example usage
# data_dir = "./datasets/numpy"
# dataset = NumpyCryoETDataset(data_dir, num_channels=11, slice_size=(224, 224), stride=112)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# for inputs, labels in dataloader:
#     print(f"Inputs shape: {inputs.shape}")  # (B, num_channels, H, W)
#     print(f"Labels shape: {labels.shape}")  # (B, H, W)
#     break

Inputs shape: torch.Size([1, 11, 224, 224])
Labels shape: torch.Size([1, 224, 224])


In [8]:
import torch
import torch.nn as nn
from monai.networks.nets import UNet

class UNet2_5D_v2(nn.Module):
    def __init__(self, out_channels=6):
        super().__init__()
        
        # 초기 3D 처리 레이어
        self.init_3d = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(11, 3, 3), padding=(0, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True)
        )
        
        # 2D UNet
        self.unet = UNet(
            spatial_dims=2,
            in_channels=64,  # 3D 컨볼루션 출력 채널
            out_channels=out_channels,
            channels=(128, 256, 512, 1024),
            strides=(2, 2, 2, 2),
            num_res_units=2
        )

    def forward(self, x):
        # x shape: (batch, 1, 11, H, W)
        # 3D 처리
        x = x.unsqueeze(1)
        x = self.init_3d(x)  # (batch, 64, 1, H, W)
        x = x.squeeze(2)     # (batch, 64, H, W)
        
        # 2D UNet
        return self.unet(x)

# 테스트 코드

model = UNet2_5D_v2(out_channels=7)
x = torch.randn(2, 11, 224, 224)
output = model(x)
print(f"Output shape: {output.shape}")  # Expected: (8, 6, 256, 256)



Output shape: torch.Size([2, 7, 224, 224])


In [None]:
for inputs, targets in dataloader:
    print(f"Max label value: {targets.max().item()}")
    print(f"Num classes (model output): {output.shape[1]}")
    break

Max label value: 6
Num classes (model output): 7


In [None]:
import torch
from monai.losses import DiceLoss
from torch import optim

# Loss and Optimizer
criterion = DiceLoss(to_onehot_y=True, softmax=True)  # Dice Loss with softmax
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, targets in dataloader:
        # Ensure targets are long integers
        targets = targets.long()  # 라벨: (B, H, W)

        # Add channel dimension to targets: (B, H, W) -> (B, 1, H, W)
        targets = targets.unsqueeze(1)
        print(f"Modified Targets shape: {targets.shape}")  # (B, 1, H, W)

        # Forward pass
        outputs = model(inputs)  # 모델 출력: (B, 7, H, W)
        print(f"Outputs shape: {outputs.shape}, Targets shape: {targets.shape}")

        # Compute loss
        loss = criterion(outputs, targets)  # Dice Loss
        print(f"Loss: {loss.item()}")

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print epoch loss
    epoch_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

Modified Targets shape: torch.Size([1, 1, 224, 224])
Outputs shape: torch.Size([1, 7, 224, 224]), Targets shape: torch.Size([1, 1, 224, 224])
Loss: 0.9693675637245178
Modified Targets shape: torch.Size([1, 1, 224, 224])
Outputs shape: torch.Size([1, 7, 224, 224]), Targets shape: torch.Size([1, 1, 224, 224])
Loss: 0.9682437181472778
Modified Targets shape: torch.Size([1, 1, 224, 224])
Outputs shape: torch.Size([1, 7, 224, 224]), Targets shape: torch.Size([1, 1, 224, 224])
Loss: 0.9593340754508972
Modified Targets shape: torch.Size([1, 1, 224, 224])
Outputs shape: torch.Size([1, 7, 224, 224]), Targets shape: torch.Size([1, 1, 224, 224])
Loss: 0.9628776907920837
Modified Targets shape: torch.Size([1, 1, 224, 224])
Outputs shape: torch.Size([1, 7, 224, 224]), Targets shape: torch.Size([1, 1, 224, 224])
Loss: 0.957891047000885
Modified Targets shape: torch.Size([1, 1, 224, 224])
Outputs shape: torch.Size([1, 7, 224, 224]), Targets shape: torch.Size([1, 1, 224, 224])
Loss: 0.9588175415992737

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader, Dataset
# import numpy as np

# # Define the model (U-Net)
# class DoubleConv(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(DoubleConv, self).__init__()
#         self.double_conv = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         return self.double_conv(x)


# class UNet(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(UNet, self).__init__()
#         self.enc1 = DoubleConv(in_channels, 64)
#         self.pool1 = nn.MaxPool2d(2)
#         self.enc2 = DoubleConv(64, 128)
#         self.pool2 = nn.MaxPool2d(2)
#         self.enc3 = DoubleConv(128, 256)
#         self.pool3 = nn.MaxPool2d(2)
#         self.bridge = DoubleConv(256, 512)
#         self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
#         self.dec3 = DoubleConv(512, 256)
#         self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
#         self.dec2 = DoubleConv(256, 128)
#         self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
#         self.dec1 = DoubleConv(128, 64)
#         self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

#     def forward(self, x):
#         enc1 = self.enc1(x)
#         enc2 = self.enc2(self.pool1(enc1))
#         enc3 = self.enc3(self.pool2(enc2))
#         bridge = self.bridge(self.pool3(enc3))
#         dec3 = self.dec3(torch.cat([self.upconv3(bridge), enc3], dim=1))
#         dec2 = self.dec2(torch.cat([self.upconv2(dec3), enc2], dim=1))
#         dec1 = self.dec1(torch.cat([self.upconv1(dec2), enc1], dim=1))
#         return self.out_conv(dec1)


# # Instantiate model, loss, and optimizer
# model = UNet(in_channels=11, out_channels=2)  # 2 classes for segmentation
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=1e-4)


# # Training Loop
# for epoch in range(5):  # 5 epochs for demonstration
#     model.train()
#     running_loss = 0.0
#     for inputs, targets in dataloader:
#         optimizer.zero_grad()
#         outputs = model(inputs)  # Forward pass
#         loss = criterion(outputs, targets)  # Compute loss
#         loss.backward()  # Backpropagation
#         optimizer.step()  # Update weights
#         running_loss += loss.item()

#     print(f"Epoch [{epoch+1}/5], Loss: {running_loss/len(dataloader):.4f}")