In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: rescale each channel by a learned gate.

    The input is globally average-pooled to one value per channel
    ("squeeze"), passed through a two-layer bias-free bottleneck ending in
    a sigmoid ("excitation"), and the resulting per-channel gates in (0, 1)
    are multiplied back onto the input.

    Args:
        in_channels: Number of channels of the 4-D input; must match the
            size of the input's second dimension.
        reduction: Bottleneck ratio. The hidden width is
            ``in_channels // reduction``, clamped to at least 1.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        self.in_channels = in_channels
        self.reduction = reduction

        # Squeeze operation: global average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Excitation operation: fully connected layers (bottleneck structure).
        # Keep at least one hidden unit: with in_channels < reduction the
        # integer division yields 0, and a zero-width layer would turn every
        # gate into the constant sigmoid(0) = 0.5 regardless of the input.
        hidden_channels = max(1, in_channels // reduction)
        self.fc1 = nn.Linear(in_channels, hidden_channels, bias=False)
        self.fc2 = nn.Linear(hidden_channels, in_channels, bias=False)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by its excitation gates.

        Args:
            x: 4-D tensor whose second dimension has ``in_channels`` entries.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Squeeze: global average pooling down to one value per channel
        b, c, _, _ = x.size()
        squeeze = self.avg_pool(x).view(b, c)

        # Excitation: bottleneck MLP followed by a sigmoid gate
        excitation = F.relu(self.fc1(squeeze))
        excitation = torch.sigmoid(self.fc2(excitation)).view(b, c, 1, 1)

        # Scale the input feature map with the excitation values
        return x * excitation.expand_as(x)

# Memory Module
class MemoryModule(nn.Module):
    """Running average over the features seen in the most recent calls.

    Each forward call contributes one entry: the input averaged over its
    first (batch) dimension. At most ``memory_size`` entries are kept
    (never fewer than one), and the module returns their mean.

    The stored entries are detached from the autograd graph, so only the
    current input receives gradient and graphs from earlier calls are not
    kept alive. Keeping them attached makes a second training step fail
    with "Trying to backward through the graph a second time".

    The buffer is a plain attribute, not a registered buffer, so it is not
    part of ``state_dict``.

    Args:
        in_channels: Accepted for interface symmetry; not used.
        memory_size: Maximum number of past entries to keep.
    """

    def __init__(self, in_channels, memory_size=1):
        super(MemoryModule, self).__init__()
        self.memory_size = memory_size
        self.memory = None

    def reset(self):
        """Forget all stored entries (e.g. between unrelated sequences)."""
        self.memory = None

    def forward(self, x):
        """Add ``x`` to the memory and return the mean of the stored entries.

        Args:
            x: Tensor whose first dimension is the batch dimension.

        Returns:
            Tensor of shape ``x.shape[1:]``, broadcastable against ``x``.
        """
        # One entry per call keeps the buffer bounded by memory_size no
        # matter how large the batch is; for a batch of one this is x itself.
        entry = x.mean(dim=0, keepdim=True)

        past = self.memory
        if past is not None and (
            past.shape[1:] != entry.shape[1:]
            or past.device != entry.device
            or past.dtype != entry.dtype
        ):
            # Stale memory cannot be concatenated with the new entry.
            past = None

        window = entry if past is None else torch.cat([past, entry], dim=0)
        # Keep only the newest entries; always retain at least the current one.
        window = window[-max(1, self.memory_size):]

        # Store without graph history; the returned mean still tracks `x`.
        self.memory = window.detach()
        return window.mean(dim=0)

In [None]:
class SReNBlock(nn.Module):
    """Conv -> squeeze-and-excitation -> memory residual -> ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the 3x3 convolution.
        memory_size: Forwarded to the block's ``MemoryModule``.
        reduction: Forwarded to the block's ``SEBlock``.
    """

    def __init__(self, in_channels, out_channels, memory_size=1, reduction=16):
        super().__init__()

        # 3x3 kernel with padding 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.se_block = SEBlock(out_channels, reduction)
        self.memory_module = MemoryModule(out_channels, memory_size)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the block on ``x`` and return the activated features."""
        # Convolve, then re-weight channels with the SE gate.
        features = self.se_block(self.conv(x))

        # Add whatever the memory module returns for these features.
        recalled = self.memory_module(features)
        combined = features + recalled

        return self.relu(combined)

In [None]:
class SReNNetwork(nn.Module):
    """Three stacked ``SReNBlock`` stages followed by a linear classifier.

    Args:
        in_channels: Channels of the input passed to the first stage.
        num_classes: Number of outputs of the final linear layer.
        memory_size: Forwarded to every ``SReNBlock``.
    """

    def __init__(self, in_channels, num_classes, memory_size=3):
        super().__init__()

        # Channel widths per stage: in_channels -> 64 -> 128 -> 256.
        self.conv1 = SReNBlock(in_channels, 64, memory_size)
        self.conv2 = SReNBlock(64, 128, memory_size)
        self.conv3 = SReNBlock(128, 256, memory_size)

        # Classifier over the 256 pooled channel values.
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the classifier output for ``x``."""
        # Run the three SReN stages in order.
        features = x
        for stage in (self.conv1, self.conv2, self.conv3):
            features = stage(features)

        # Global average pooling to 1x1, then flatten everything after dim 0.
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        flat = torch.flatten(pooled, 1)

        # Final fully connected layer.
        return self.fc(flat)