In [None]:
!pip install einops
!pip install timm
!pip install thop
!pip install mmcv==1.4.0

In [17]:
from model.SUNet_detail import *

In [24]:
# Dummy batch: 4 images, 3 channels, 256x256 pixels, values sampled uniformly from [0, 1).
img = torch.rand((4, 3, 256, 256))
img.shape

torch.Size([4, 3, 256, 256])

In [29]:
# Lift the 3-channel batch to 64 feature channels. A 3x3 kernel with
# stride 1 and padding 1 leaves the spatial size unchanged.
conv = torch.nn.Conv2d(
    in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1
)
feature_map = conv(img)  # Shape: (4, 64, 256, 256)

# Pass the features through a ContextBlock configured with attention pooling
# and both the additive and multiplicative channel fusion branches.
context_block = ContextBlock(
    inplanes=64,
    ratio=0.25,
    pooling_type='att',
    fusion_types=('channel_add', 'channel_mul'),
)
output = context_block(feature_map)

# The block is expected to preserve the input shape; print both to confirm.
print("Input shape:", feature_map.shape)
print("Output shape:", output.shape)

Input shape: torch.Size([4, 64, 256, 256])
Output shape: torch.Size([4, 64, 256, 256])


In [30]:
class StridedConvolutionDownsampling(nn.Module):
    """Downsample a feature map with a strided convolution.

    The input goes through a 3x3 convolution (stride 2, padding 1), then a
    normalization layer, then an in-place ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        norm_layer: Callable that receives ``out_channels`` and returns the
            normalization module. Defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Stride 2 with a 3x3 kernel and padding 1 halves each spatial
        # dimension (rounding up for odd sizes).
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.conv_down = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.norm = norm_layer(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` after convolution, normalization and ReLU, in that order."""
        return self.activation(self.norm(self.conv_down(x)))

In [38]:
# Feed the ContextBlock output through the strided-convolution downsampler,
# going from 64 to 128 channels, and compare shapes before and after.
x = output
downsample = StridedConvolutionDownsampling(in_channels=64, out_channels=128)
x_downsampled = downsample(x)
print("x shape:", x.shape)
print("x_downsampled shape:", x_downsampled.shape)


x shape: torch.Size([4, 64, 256, 256])
x_downsampled shape: torch.Size([4, 128, 128, 128])


In [33]:
class AttentionDownsampling(nn.Module):
    """Channel-attention gated convolutional downsampling with a residual shortcut.

    The input is rescaled per channel by a squeeze-and-excitation style gate
    (global average pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid), then
    passed through a (by default strided) convolution. A shortcut computed
    from the un-gated input is added before the final ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        kernel_size: Kernel size of the main convolution.
        stride: Stride of the main convolution and of the shortcut projection.
        padding: Padding of the main convolution.

    Note:
        The shortcut projection is a 1x1 convolution without padding, so the
        two branches only produce matching spatial sizes when
        ``2 * padding == kernel_size - 1`` (true for the defaults).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()

        # Bottleneck width of the gate; clamp to 1 so that inputs with fewer
        # than 4 channels do not produce a zero-channel convolution.
        reduced_channels = max(in_channels // 4, 1)

        # Channel attention mechanism
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global average pooling -> (B, C, 1, 1)
            nn.Conv2d(in_channels, reduced_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, in_channels, 1),
            nn.Sigmoid()
        )
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding
        )

        # Residual connection. A projection is required whenever the main
        # branch changes the channel count OR the spatial size; an identity
        # shortcut is only shape-compatible when neither changes.
        strides = stride if isinstance(stride, (tuple, list)) else (stride,)
        is_strided = any(s != 1 for s in strides)
        if in_channels != out_channels or is_strided:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride
                ),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Return the ReLU of the gated, convolved input plus the shortcut."""
        attention_map = self.channel_attention(x)  # (B, C, 1, 1), one weight per channel
        x_attended = x * attention_map  # broadcast over H and W
        x_downsampled = self.conv(x_attended)
        res = self.residual(x)
        # nn.functional is reachable through the already-used `nn` name, so
        # the class does not depend on a separately imported `F` alias.
        return nn.functional.relu(x_downsampled + res)

In [39]:
# Re-display the shape of `output` (the ContextBlock result) before reusing it as input below.
output.shape

torch.Size([4, 64, 256, 256])

In [41]:
import torch.nn.functional as F

# Feed the ContextBlock output through the attention-gated downsampler,
# going from 64 to 128 channels, and compare shapes before and after.
x = output
downsample = AttentionDownsampling(in_channels=64, out_channels=128)
x_downsampled = downsample(x)
print("x shape:", x.shape)
print("x_downsampled shape:", x_downsampled.shape)

x shape: torch.Size([4, 64, 256, 256])
x_downsampled shape: torch.Size([4, 128, 128, 128])
