In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block that rescales the channels of its input.

    The input is globally average-pooled to one value per channel, passed
    through a two-layer bottleneck of bias-free 1x1 convolutions
    (ReLU in between), squashed with a sigmoid, and the resulting
    per-channel weights are multiplied back onto the input.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Factor by which the bottleneck narrows the channel
            dimension. The bottleneck width is ``in_channels // reduction``
            but never less than 1, so the block stays usable when
            ``in_channels < reduction``.

    Raises:
        ValueError: If ``reduction`` is not a positive number.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        if reduction <= 0:
            raise ValueError(f"reduction must be positive, got {reduction}")

        # Clamp to at least one hidden channel: plain floor division yields
        # 0 for narrow inputs (e.g. 8 channels with reduction=16), which
        # would create zero-channel convolutions.
        hidden_channels = max(1, in_channels // reduction)

        # Squeeze: collapse each channel's spatial map to a single value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck implemented as two 1x1 convolutions.
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(hidden_channels, in_channels, kernel_size=1, bias=False)
        # Sigmoid maps the excitation output to per-channel weights in (0, 1).
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned attention weights.

        Args:
            x: Feature map with ``in_channels`` channels, in a layout accepted
                by ``nn.Conv2d`` (e.g. ``(N, C, H, W)``).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Squeeze: global average pooling down to 1x1 per channel.
        squeezed = self.avg_pool(x)
        # Excitation: reduce -> ReLU -> expand -> sigmoid gate.
        hidden = F.relu(self.fc1(squeezed))
        weights = self.sigmoid(self.fc2(hidden))
        # The 1x1 spatial weights broadcast over the spatial dimensions of x.
        return x * weights