In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Apply Conv2d, then BatchNorm2d, then an in-place ReLU.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of filters produced by the convolution.
        kernel_size: Convolution kernel size (default 3).
        padding: Zero-padding forwarded to ``nn.Conv2d`` (default 1; with the
            default 3x3 kernel and stride 1 this keeps H and W unchanged).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Stored under ``conv`` so state_dict keys stay conv.0.* / conv.1.*.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run the conv -> batch-norm -> ReLU stack on ``x``."""
        features = self.conv(x)
        return features

class ResidualBlock(nn.Module):
    """Residual block: ConvBlock -> Dropout2d -> ConvBlock, plus a skip path.

    The skip path is the identity when ``in_channels == out_channels`` and a
    1x1 convolution otherwise, so it can be summed with the main path. A ReLU
    is applied to the sum.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        dropout: Probability for the ``nn.Dropout2d`` between the two conv blocks.
    """

    def __init__(self, in_channels, out_channels, dropout=0.3):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, out_channels)
        self.conv2 = ConvBlock(out_channels, out_channels)
        self.dropout = nn.Dropout2d(p=dropout)
        # 1x1 conv to match channels if needed
        self.skip_conv = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        identity = self.skip_conv(x)
        out = self.conv1(x)
        out = self.dropout(out)
        out = self.conv2(out)
        # Out-of-place add on purpose: ``out`` is the output of conv2's ReLU,
        # which autograd saves for ReLU's backward. An in-place ``+=`` would
        # bump its version counter and make ``backward()`` raise a RuntimeError.
        out = out + identity
        return F.relu(out)

class AttentionBlock(nn.Module):
    """Additive attention gate that re-weights a skip-connection feature map.

    The gating tensor ``g`` and the skip tensor ``x`` are each projected to
    ``F_int`` channels by a 1x1 conv + BatchNorm. The projections are summed,
    passed through a ReLU, and reduced by a 1x1 conv + BatchNorm + sigmoid to a
    single-channel map. ``x`` is multiplied by that map, which broadcasts over
    its channels.

    Args:
        F_g: Channels of the gating tensor ``g``.
        F_l: Channels of the skip tensor ``x``.
        F_int: Channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(in_ch, out_ch):
            # 1x1 conv followed by batch norm, no activation.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1),
                nn.BatchNorm2d(out_ch),
            )

        self.W_g = project(F_g, F_int)
        self.W_x = project(F_l, F_int)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention map computed from ``g`` and ``x``.

        ``g`` is the gating signal (decoder feature) and ``x`` the skip
        connection (encoder feature). Their projections are summed, so their
        spatial sizes must be broadcast-compatible (in practice, equal).
        """
        combined = self.W_g(g) + self.W_x(x)
        gate = self.psi(self.relu(combined))
        return x * gate

class AER_UNet(nn.Module):
    """Attention-gated residual U-Net with three down/up-sampling levels.

    Encoder: three ResidualBlocks (32, 64, 128 channels), each followed by 2x2
    max-pooling. Bottleneck: a 256-channel ResidualBlock. Decoder: at each level
    a stride-2 transposed conv upsamples, an AttentionBlock gates the matching
    encoder map, and the two are concatenated along channels and fused by a
    ResidualBlock. A 1x1 conv plus sigmoid produce the output.

    Input H and W do not need to be divisible by 8. When max-pooling floors an
    odd size, the upsampled map is zero-padded on the bottom/right to match its
    skip connection. The output therefore has the input's spatial size.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map (values in (0, 1) via sigmoid).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder filter sizes
        self.filters = [32, 64, 128]

        # Encoder
        self.encoder1 = ResidualBlock(in_channels, self.filters[0])
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder2 = ResidualBlock(self.filters[0], self.filters[1])
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder3 = ResidualBlock(self.filters[1], self.filters[2])
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = ResidualBlock(self.filters[2], 256, dropout=0.3)

        # Decoder: decoderN takes the gated skip concatenated with the
        # upsampled map, hence twice the level's channel count.
        self.up3 = nn.ConvTranspose2d(256, self.filters[2], kernel_size=2, stride=2)
        self.att3 = AttentionBlock(F_g=self.filters[2], F_l=self.filters[2], F_int=self.filters[2] // 2)
        self.decoder3 = ResidualBlock(self.filters[2] * 2, self.filters[2])

        self.up2 = nn.ConvTranspose2d(self.filters[2], self.filters[1], kernel_size=2, stride=2)
        self.att2 = AttentionBlock(F_g=self.filters[1], F_l=self.filters[1], F_int=self.filters[1] // 2)
        self.decoder2 = ResidualBlock(self.filters[1] * 2, self.filters[1])

        self.up1 = nn.ConvTranspose2d(self.filters[1], self.filters[0], kernel_size=2, stride=2)
        self.att1 = AttentionBlock(F_g=self.filters[0], F_l=self.filters[0], F_int=self.filters[0] // 2)
        self.decoder1 = ResidualBlock(self.filters[0] * 2, self.filters[0])

        # Output
        self.final_conv = nn.Conv2d(self.filters[0], out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _pad_to(up, skip):
        """Zero-pad ``up`` so its last two dims (H, W) match those of ``skip``.

        MaxPool2d floors odd sizes, so a stride-2 ConvTranspose2d can return a
        map one row/column smaller than the encoder map it is joined with.
        """
        diff_h = skip.size(-2) - up.size(-2)
        diff_w = skip.size(-1) - up.size(-1)
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(
            up,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def _decode(self, below, skip, up, att, decoder):
        """Upsample ``below``, gate ``skip`` with it, concatenate and fuse."""
        d = self._pad_to(up(below), skip)
        skip_att = att(g=d, x=skip)
        # Keep the order (gated skip, upsampled): decoder weights depend on it.
        return decoder(torch.cat((skip_att, d), dim=1))

    def forward(self, x):
        # Encoder
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))

        # Bottleneck
        b = self.bottleneck(self.pool3(e3))

        # Decoder levels 3 -> 1
        d3 = self._decode(b, e3, self.up3, self.att3, self.decoder3)
        d2 = self._decode(d3, e2, self.up2, self.att2, self.decoder2)
        d1 = self._decode(d2, e1, self.up1, self.att1, self.decoder1)

        # Output
        out = self.final_conv(d1)
        return self.sigmoid(out)

In [None]:
# Example: optimizer with L2 regularization (weight decay).
# Adam's ``weight_decay`` adds an L2 penalty term to the gradients (coupled);
# use torch.optim.AdamW instead if decoupled weight decay is wanted.
# Fall back to CPU so this cell also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AER_UNet(in_channels=3, out_channels=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)