In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_activation(activation_type):
    """Helper function to get activation layer"""
    if activation_type.lower() == 'elu':
        return nn.ELU()
    elif activation_type.lower() == 'relu':
        return nn.ReLU()
    elif activation_type.lower() == 'leaky_relu':
        return nn.LeakyReLU()
    elif activation_type.lower() == 'sigmoid':
        return nn.Sigmoid()
    elif activation_type.lower() == 'tanh':
        return nn.Tanh()
    else:
        raise ValueError(f"Unsupported activation type: {activation_type}")

class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int, batch_norm=False):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int) if batch_norm else nn.Identity()
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int) if batch_norm else nn.Identity()
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1) if batch_norm else nn.Identity(),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class Att_YNet(nn.Module):
    def __init__(self, image_shape, activation='elu', feature_maps=[16, 32, 64, 128, 256],
                 drop_values=[0.1, 0.1, 0.2, 0.2, 0.3], spatial_dropout=False, batch_norm=False,
                 n_classes=1):
        super(Att_YNet, self).__init__()

        self.depth = len(feature_maps) - 1
        self.activation = get_activation(activation)  # 使用辅助函数获取激活层
        self.spatial_dropout = spatial_dropout
        self.batch_norm = batch_norm

        # Encoder
        self.encoder_blocks = nn.ModuleList()
        in_channels = image_shape[0]  # 使用image_shape的第一个维度作为输入通道
        for i in range(self.depth):
            block = nn.Sequential(
                nn.Conv2d(in_channels if i == 0 else feature_maps[i-1], feature_maps[i], 3, padding=1),
                nn.BatchNorm2d(feature_maps[i]) if batch_norm else nn.Identity(),
                self.activation,
                nn.Dropout2d(drop_values[i]) if spatial_dropout else nn.Dropout(drop_values[i]),
                nn.Conv2d(feature_maps[i], feature_maps[i], 3, padding=1),
                nn.BatchNorm2d(feature_maps[i]) if batch_norm else nn.Identity(),
                self.activation
            )
            self.encoder_blocks.append(block)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(feature_maps[-2], feature_maps[-1], 3, padding=1),
            nn.BatchNorm2d(feature_maps[-1]) if batch_norm else nn.Identity(),
            self.activation,
            nn.Dropout2d(drop_values[-1]) if spatial_dropout else nn.Dropout(drop_values[-1]),
            nn.Conv2d(feature_maps[-1], feature_maps[-1], 3, padding=1),
            nn.BatchNorm2d(feature_maps[-1]) if batch_norm else nn.Identity(),
            self.activation
        )

        # Decoder (UNet)
        self.unet_decoder_blocks = nn.ModuleList()
        self.attention_blocks = nn.ModuleList()
        for i in range(self.depth-1, -1, -1):
            self.attention_blocks.append(AttentionBlock(feature_maps[i], feature_maps[i], feature_maps[i]//2, batch_norm))
            block = nn.Sequential(
                nn.Conv2d(feature_maps[i]*2, feature_maps[i], 3, padding=1),
                nn.BatchNorm2d(feature_maps[i]) if batch_norm else nn.Identity(),
                self.activation,
                nn.Dropout2d(drop_values[i]) if spatial_dropout else nn.Dropout(drop_values[i]),
                nn.Conv2d(feature_maps[i], feature_maps[i], 3, padding=1),
                nn.BatchNorm2d(feature_maps[i]) if batch_norm else nn.Identity(),
                self.activation
            )
            self.unet_decoder_blocks.append(block)

        # Decoder (AutoEncoder)
        self.ae_decoder_blocks = nn.ModuleList()
        for i in range(self.depth-1, -1, -1):
            block = nn.Sequential(
                nn.Conv2d(feature_maps[i+1] if i != self.depth-1 else feature_maps[-1], feature_maps[i], 3, padding=1),
                nn.BatchNorm2d(feature_maps[i]) if batch_norm else nn.Identity(),
                self.activation,
                nn.Dropout2d(drop_values[i]) if spatial_dropout else nn.Dropout(drop_values[i]),
                nn.Conv2d(feature_maps[i], feature_maps[i], 3, padding=1),
                nn.BatchNorm2d(feature_maps[i]) if batch_norm else nn.Identity(),
                self.activation
            )
            self.ae_decoder_blocks.append(block)

        self.final_conv_mask = nn.Conv2d(feature_maps[0], n_classes, 1)
        self.final_conv_img = nn.Conv2d(feature_maps[0], image_shape[0], 1)  # 输出通道数应与输入图像通道数相同

    def forward(self, x):
        # Encoder
        encoder_features = []
        for block in self.encoder_blocks:
            x = block(x)
            encoder_features.append(x)
            x = F.max_pool2d(x, 2)

        # Bottleneck
        x = self.bottleneck(x)

        # UNet Decoder
        unet_x = x
        for i, block in enumerate(self.unet_decoder_blocks):
            unet_x = F.interpolate(unet_x, scale_factor=2, mode='bilinear', align_corners=True)
            attn = self.attention_blocks[i](unet_x, encoder_features[-(i+1)])
            unet_x = torch.cat([unet_x, attn], dim=1)
            unet_x = block(unet_x)

        # AutoEncoder Decoder
        ae_x = x
        for block in self.ae_decoder_blocks:
            ae_x = F.interpolate(ae_x, scale_factor=2, mode='bilinear', align_corners=True)
            ae_x = block(ae_x)

        mask = torch.sigmoid(self.final_conv_mask(unet_x))
        img = self.final_conv_img(ae_x)

        return img, mask


In [6]:
# 测试代码
image_shape = (3, 256, 256)  # (C, H, W)
model = Att_YNet(image_shape)
x = torch.randn(1, *image_shape)
img, mask = model(x)
print(f"Image output shape: {img.shape}")
print(f"Mask output shape: {mask.shape}")

RuntimeError: Given groups=1, weight of size [64, 128, 1, 1], expected input[1, 256, 32, 32] to have 128 channels, but got 256 channels instead