In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """2D Double Convolution Block"""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class Encoder(nn.Module):
    """U-Net Encoder"""
    def __init__(self, in_channels, features):
        super(Encoder, self).__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, features)

    def forward(self, x):
        x = self.conv(x)
        skip = x
        x = self.pool(x)
        return x, skip


class Decoder(nn.Module):
    """U-Net Decoder"""
    def __init__(self, in_channels, skip_channels, out_channels):
        """
        Args:
            in_channels: Up-sampled feature channels
            skip_channels: Skip connection feature channels
            out_channels: Output feature channels after concatenation
        """
        super(Decoder, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(out_channels + skip_channels, out_channels)  # Concatenated channels

    def forward(self, x, skip):
        # Up-sample
        x = self.upconv(x)
        # Ensure spatial dimensions match
        if x.shape[-2:] != skip.shape[-2:]:
            diffY = skip.size(2) - x.size(2)
            diffX = skip.size(3) - x.size(3)
            x = F.pad(x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # Concatenate with skip connection
        x = torch.cat((x, skip), dim=1)
        # Apply convolution
        x = self.conv(x)
        return x


class UNet2_5D(nn.Module):
    """2.5D U-Net"""
    def __init__(self, in_channels, out_channels, depth_slices=11, features=[64, 128, 256, 512]):
        super(UNet2_5D, self).__init__()
        self.depth_slices = depth_slices

        # Initial 3D Convolution to merge depth slices
        self.init_conv3d = nn.Conv3d(in_channels, features[0], kernel_size=(depth_slices, 3, 3), padding=(0, 1, 1))
        self.init_bn3d = nn.BatchNorm3d(features[0])

        # U-Net Encoder
        self.encoders = nn.ModuleList()
        for i, feature in enumerate(features):
            if i == 0:
                self.encoders.append(Encoder(features[i], features[i]))
            else:
                self.encoders.append(Encoder(features[i - 1], features[i]))
        
        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # U-Net Decoder
        self.decoders = nn.ModuleList()
        reversed_features = list(reversed(features))
        for i in range(len(reversed_features) - 1):
            self.decoders.append(
                Decoder(
                    in_channels=reversed_features[i] * 2,  # From previous layer
                    skip_channels=reversed_features[i + 1],  # From encoder skip connection
                    out_channels=reversed_features[i]
                )
            )
        self.decoders.append(
            Decoder(
                in_channels=reversed_features[-1] * 2,  # Last decoder input
                skip_channels=reversed_features[-1] // 2,  # Last skip connection
                out_channels=reversed_features[-1] // 2
            )
        )

        # Final Convolution
        self.final_conv = nn.Conv2d(reversed_features[-1] // 2, out_channels, kernel_size=1)

    def forward(self, x):
        # 3D Convolution for depth slices
        x = F.relu(self.init_bn3d(self.init_conv3d(x)))  # B x F x H x W
        x = x.squeeze(2)  # Collapse depth axis for 2D U-Net: B x F x H x W

        # Encoder path
        skips = []
        for encoder in self.encoders:
            x, skip = encoder(x)
            skips.append(skip)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path
        for decoder, skip in zip(self.decoders, reversed(skips)):
            x = decoder(x, skip)

        # Final Convolution
        return self.final_conv(x)



In [2]:
import torch
from torch.utils.data import Dataset, DataLoader

class CryoETDataset(Dataset):
    def __init__(self, images, masks):
        self.images = images  # Shape: (N, 1, 11, H, W)
        self.masks = masks    # Shape: (N, H, W) with class labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        return image, mask


# 데이터 예제 (더미 데이터)
images = torch.randn(100, 1, 11, 256, 256)  # 100개의 샘플
masks = torch.randint(0, 6, (100, 256, 256))  # 6개 클래스 라벨
dataset = CryoETDataset(images, masks)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)


In [3]:

# 모델 초기화
model = UNet2_5D(in_channels=1, out_channels=6, depth_slices=11).to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 학습 루프
for epoch in range(10):  # 10 epochs
    model.train()
    for images, masks in dataloader:
        images, masks = images.to('cuda'), masks.to('cuda')
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")


RuntimeError: Given groups=1, weight of size [512, 768, 3, 3], expected input[8, 1024, 32, 32] to have 768 channels, but got 1024 channels instead