In [3]:
import torch
import torch.nn as nn
from monai.networks.nets import UNet

class UNet2_5D_v2(nn.Module):
    """2.5D segmentation network: one 3D convolution followed by a MONAI 2D UNet.

    The initial 3D convolution uses a kernel that spans ``depth_slices`` along
    the depth axis with no depth padding, so an input stack of exactly
    ``depth_slices`` slices is collapsed to depth 1. That singleton depth axis
    is then squeezed away and the resulting 64-channel 2D feature map is passed
    to the 2D UNet.

    Args:
        out_channels: Number of output channels produced by the 2D UNet.
        depth_slices: Number of slices along the depth axis of the input; used
            as the depth of the initial 3D convolution kernel. Defaults to 11.
    """

    def __init__(self, out_channels=6, depth_slices=11):
        super().__init__()
        self.depth_slices = depth_slices

        # Initial 3D stage: merges all depth slices into one 64-channel map.
        self.init_3d = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(depth_slices, 3, 3), padding=(0, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )

        # 2D UNet operating on the output of the 3D convolution.
        self.unet = UNet(
            spatial_dims=2,
            in_channels=64,  # channel count produced by the 3D convolution
            out_channels=out_channels,
            channels=(64, 128, 256, 512),
            # MONAI documents len(strides) == len(channels) - 1; a fourth
            # stride would never be used by the network.
            strides=(2, 2, 2),
            num_res_units=2,
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape (batch, 1, depth_slices, H, W).

        Returns:
            The output of the 2D UNet for the depth-collapsed feature map.

        Raises:
            ValueError: If ``x`` is not 5-D or its depth axis (dim 2) does not
                equal ``depth_slices``.
        """
        # Without this check a wrong depth leaves dim 2 with size != 1, the
        # squeeze below becomes a silent no-op, and the 2D UNet then fails
        # with an unrelated-looking error.
        if x.dim() != 5 or x.size(2) != self.depth_slices:
            raise ValueError(
                f"expected input of shape (batch, 1, {self.depth_slices}, H, W), "
                f"got {tuple(x.shape)}"
            )

        x = self.init_3d(x)  # (batch, 64, 1, H, W)
        x = x.squeeze(2)     # (batch, 64, H, W)

        return self.unet(x)

# Smoke test: push one random batch through the model and check the output shape.
if __name__ == "__main__":
    model = UNet2_5D_v2(out_channels=6)
    x = torch.randn(8, 1, 11, 256, 256)
    # Only the shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        output = model(x)
    print(f"Output shape: {output.shape}")
    # Enforce the expectation instead of leaving it in a comment.
    assert output.shape == (8, 6, 256, 256), f"unexpected output shape {tuple(output.shape)}"

Output shape: torch.Size([8, 6, 256, 256])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two back-to-back 2D stages of Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. With 3x3 kernels and padding 1 the spatial size is kept.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Input width of each of the two stages; both emit out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend(
                [
                    nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        out = self.conv(x)
        return out


class Encoder(nn.Module):
    """One U-Net encoder stage: a DoubleConv followed by 2x2 max-pooling.

    ``forward`` returns a pair ``(pooled, skip)`` where ``skip`` is the
    DoubleConv output before pooling and ``pooled`` is that same tensor after
    2x2 max-pooling with stride 2.
    """

    def __init__(self, in_channels, features):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, features)

    def forward(self, x):
        """Return ``(pooled, skip)`` for input ``x``."""
        skip = self.conv(x)
        return self.pool(skip), skip


class Decoder(nn.Module):
    """One U-Net decoder stage: upsample, align, concatenate the skip, DoubleConv.

    Args:
        in_channels: Channels of the tensor entering the transposed convolution.
        skip_channels: Channels of the skip-connection tensor.
        out_channels: Channels produced by the transposed convolution and by
            the DoubleConv applied after concatenation.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # After concatenation the tensor carries upsampled + skip channels.
        merged_channels = out_channels + skip_channels
        self.conv = DoubleConv(merged_channels, out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, pad it to ``skip``'s spatial size, concatenate and convolve."""
        upsampled = self.upconv(x)

        if upsampled.shape[-2:] != skip.shape[-2:]:
            gap_h = skip.size(2) - upsampled.size(2)
            gap_w = skip.size(3) - upsampled.size(3)
            left, top = gap_w // 2, gap_h // 2
            # F.pad takes pads for the last dimension first:
            # (left, right, top, bottom).
            upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)


class UNet2_5D(nn.Module):
    """2.5D U-Net: a depth-collapsing 3D convolution followed by a 2D U-Net.

    The initial 3D convolution uses a kernel of depth ``depth_slices`` with no
    depth padding, so an input with exactly ``depth_slices`` slices is reduced
    to depth 1; that axis is squeezed and the rest of the network is 2D.

    Args:
        in_channels: Channels of the 5-D input volume.
        out_channels: Channels of the final 1x1 convolution (e.g. class count).
        depth_slices: Number of slices along the input depth axis (dim 2).
        features: Channel widths of the encoder levels, shallowest first.
            Must contain at least one entry.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels, out_channels, depth_slices=11, features=(64, 128, 256, 512)):
        super().__init__()
        # Copy into a list so any sequence type is accepted and the default
        # is never shared mutable state.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")
        self.depth_slices = depth_slices

        # Initial 3D convolution that merges the depth slices.
        self.init_conv3d = nn.Conv3d(
            in_channels, features[0], kernel_size=(depth_slices, 3, 3), padding=(0, 1, 1)
        )
        self.init_bn3d = nn.BatchNorm3d(features[0])

        # Encoder: the first level receives features[0] channels from the 3D
        # stage; each later level receives the previous level's width.
        self.encoders = nn.ModuleList()
        level_in = features[0]
        for width in features:
            self.encoders.append(Encoder(level_in, width))
            level_in = width

        # Bottleneck doubles the deepest width.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: walk the levels deepest-first. Each decoder consumes the
        # previous stage's output (the bottleneck for the first one), merges
        # the skip of its level, and emits the width of the next shallower
        # level (the shallowest level keeps features[0]).
        self.decoders = nn.ModuleList()
        up_in = features[-1] * 2
        for level in range(len(features) - 1, -1, -1):
            up_out = features[max(level - 1, 0)]
            self.decoders.append(
                Decoder(
                    in_channels=up_in,
                    skip_channels=features[level],
                    out_channels=up_out,
                )
            )
            up_in = up_out

        # Final 1x1 convolution to the requested number of output channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on a 5-D volume.

        Args:
            x: Tensor of shape (batch, in_channels, depth_slices, H, W).

        Returns:
            The output of the final 1x1 convolution.

        Raises:
            ValueError: If ``x`` is not 5-D or its depth axis (dim 2) does not
                equal ``depth_slices``.
        """
        # A wrong depth would leave dim 2 with size != 1, turning the squeeze
        # below into a silent no-op and failing later inside a 2D layer.
        if x.dim() != 5 or x.size(2) != self.depth_slices:
            raise ValueError(
                f"expected input of shape (batch, C, {self.depth_slices}, H, W), "
                f"got {tuple(x.shape)}"
            )

        # 3D convolution over the depth slices, then drop the singleton depth.
        x = F.relu(self.init_bn3d(self.init_conv3d(x)))
        x = x.squeeze(2)

        # Encoder path, collecting the pre-pooling tensors as skips.
        skips = []
        for encoder in self.encoders:
            x, skip = encoder(x)
            skips.append(skip)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path: the deepest skip pairs with the first decoder.
        for decoder, skip in zip(self.decoders, reversed(skips)):
            x = decoder(x, skip)

        return self.final_conv(x)


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class CryoETDataset(Dataset):
    """Index-aligned pairing of image stacks and their label masks.

    Args:
        images: Indexable, sized collection of samples. The example in this
            notebook uses a tensor of shape (N, 1, 11, H, W).
        masks: Indexable, sized collection of label masks, one per sample.
            The example in this notebook uses a tensor of shape (N, H, W)
            holding class labels.

    Raises:
        ValueError: If ``images`` and ``masks`` have different lengths.
    """

    def __init__(self, images, masks):
        # Fail at construction rather than with an IndexError mid-epoch (or
        # by silently ignoring surplus masks).
        if len(images) != len(masks):
            raise ValueError(
                f"images and masks must have the same length, "
                f"got {len(images)} and {len(masks)}"
            )
        self.images = images
        self.masks = masks

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``."""
        return self.images[idx], self.masks[idx]


# Example data (random dummy tensors standing in for real volumes and labels).
NUM_SAMPLES = 100   # number of dummy samples
NUM_CLASSES = 6     # labels are drawn from [0, NUM_CLASSES)
DEPTH_SLICES = 11   # slices per sample along the depth axis
IMAGE_SIZE = 256    # height and width of each slice

# Dedicated seeded generator: the dummy data is identical on every run and the
# global RNG state is left untouched.
data_rng = torch.Generator().manual_seed(0)
images = torch.randn(NUM_SAMPLES, 1, DEPTH_SLICES, IMAGE_SIZE, IMAGE_SIZE, generator=data_rng)
masks = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES, IMAGE_SIZE, IMAGE_SIZE), generator=data_rng)
dataset = CryoETDataset(images, masks)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)


In [None]:

# Model setup. Pick the device at runtime so the cell also runs on CPU-only
# machines instead of crashing on a hard-coded 'cuda'.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10

model = UNet2_5D(in_channels=1, out_channels=6, depth_slices=11).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    # Distinct batch names: reusing `images`/`masks` here would overwrite the
    # notebook-level dataset tensors with the last batch.
    for batch_images, batch_masks in dataloader:
        batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)
        optimizer.zero_grad()
        outputs = model(batch_images)
        loss = criterion(outputs, batch_masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than only the last batch's
    # loss; NaN marks an epoch with no batches (previously a NameError).
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
