In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: channel-wise feature recalibration.

    Global-average-pools each channel, feeds the pooled vector through a
    two-layer bottleneck MLP (ReLU, then sigmoid), and multiplies the input
    by the resulting per-channel gates, which lie in (0, 1).

    Args:
        in_channels: Number of channels ``C`` of the input.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``in_channels // reduction``, clamped to at least 1.

    Shape:
        Input ``(B, C, H, W)`` (``forward`` unpacks exactly four dims);
        output has the same shape as the input.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction

        # Clamp to >= 1: when in_channels < reduction the floor division is 0,
        # which would build zero-width Linear layers and turn every gate into
        # a constant sigmoid(0) = 0.5 with nothing to learn.
        hidden = max(1, in_channels // reduction)

        # Squeeze operation: global average pooling to (B, C, 1, 1).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Excitation operation: fully connected bottleneck C -> hidden -> C.
        self.fc1 = nn.Linear(in_channels, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, in_channels, bias=False)

    def forward(self, x):
        """Return ``x`` scaled by learned per-channel gates."""
        b, c, _, _ = x.size()

        # Squeeze: one scalar per (sample, channel).
        squeeze = self.avg_pool(x).view(b, c)

        # Excitation: bottleneck MLP producing gates in (0, 1).
        excitation = F.relu(self.fc1(squeeze))
        excitation = torch.sigmoid(self.fc2(excitation)).view(b, c, 1, 1)

        # (B, C, 1, 1) broadcasts over the spatial dims of x.
        return x * excitation

# Memory Module
class MemoryModule(nn.Module):
    """Rolling buffer of recent inputs whose mean is returned.

    Each call appends ``x`` to the buffer along dim 0, keeps only the last
    ``memory_size`` entries, and returns their mean over dim 0 (so the output
    has shape ``x.shape[1:]``).

    Stored entries are detached from the autograd graph, so gradients flow
    only into the *current* ``x``. Keeping the old graphs alive would retain
    every past activation (unbounded memory growth). It would also make a
    later ``backward()`` fail with "Trying to backward through the graph a
    second time".

    Args:
        in_channels: Unused; kept for interface compatibility.
        memory_size: Maximum number of dim-0 entries kept (must be >= 1).

    Raises:
        ValueError: If ``memory_size`` is less than 1.

    Note:
        ``self.memory`` is a plain attribute, not a registered buffer. It is
        therefore not part of ``state_dict`` and is not moved by ``.to()``.
        Instead, it is cast to ``x``'s device/dtype on every call.
    """

    def __init__(self, in_channels, memory_size=1):
        super(MemoryModule, self).__init__()
        if memory_size < 1:
            raise ValueError(f"memory_size must be >= 1, got {memory_size}")
        self.memory_size = memory_size
        self.memory = None

    def forward(self, x):
        """Push ``x`` into the buffer and return the mean of the buffer."""
        if self.memory is None:
            window = x
        else:
            # Match x's device/dtype so the module keeps working after
            # model.to(...) or under mixed precision.
            window = torch.cat([self.memory.to(x), x], dim=0)

        # Slice instead of dropping a single row: a batch of B > 1 adds B rows
        # per call, and the very first call must respect the cap too.
        window = window[-self.memory_size:]

        # Persist without the graph; the returned value still depends on x.
        self.memory = window.detach()
        return window.mean(dim=0)