In [None]:
## Imports
import torch
import torch.nn as nn

In [4]:
# 1. Original Residual Block (Standard ResNet)
class OriginalResidualBlock(nn.Module):
    """Post-activation residual block: conv -> BN -> ReLU -> conv -> BN -> add -> ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both 3x3 convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional callable applied to the block input to build the
            shortcut branch. When ``None`` the input itself is the shortcut, so
            it must already be addable to the main-branch output.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both convolutions: 3x3 kernel, padding 1, no bias.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # One in-place ReLU module is reused for both activations.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``."""
        # Shortcut branch: the raw input, or its projection when one was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main branch: two conv/BN stages with a ReLU in between.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Merge the branches, then apply the final activation.
        residual += shortcut
        return self.relu(residual)

In [5]:
# 2. BatchNorm After Addition
class BatchNormAfterAdditionResidualBlock(nn.Module):
    """Residual block whose second BatchNorm is moved to after the addition.

    Layout: conv -> BN -> ReLU -> conv -> add -> BN -> ReLU.

    The main branch ends with the second convolution; the normalisation that a
    standard block applies to that convolution's output is applied to the sum
    of the main branch and the shortcut instead. There is therefore no ``bn2``
    layer: only ``bn1`` (after the first convolution) and ``bn_out`` (after the
    addition).

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both 3x3 convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional callable applied to the block input to build the
            shortcut branch. When ``None`` the input itself is the shortcut, so
            it must already be addable to the main-branch output.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        # Normalises the merged signal (main branch + shortcut).
        self.bn_out = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn_out(conv2(relu(bn1(conv1(x)))) + shortcut))``."""
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)

        # BatchNorm is applied to the sum, not to the main branch on its own.
        out = out + identity
        out = self.bn_out(out)
        out = self.relu(out)
        return out

In [6]:
# 3. ReLU Before Addition
class ReLUBeforeAdditionResidualBlock(nn.Module):
    """Residual block whose final ReLU sits on the main branch, before the addition.

    Layout: conv -> BN -> ReLU -> conv -> BN -> ReLU -> add.

    The block output is the sum itself; no activation is applied after the
    shortcut has been added.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both 3x3 convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional callable applied to the block input to build the
            shortcut branch. When ``None`` the input itself is the shortcut, so
            it must already be addable to the main-branch output.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x)))))) + shortcut``."""
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # The addition must be out-of-place: ReLU's backward pass reads its own
        # output, so `out += identity` would overwrite a tensor autograd still
        # needs and make backward() raise a RuntimeError.
        out = out + identity
        return out

In [7]:
# 4. ReLU-Only Pre-activation
class ReLUOnlyPreActivationResidualBlock(nn.Module):
    """ReLU is applied before both convolutions without BatchNorm pre-activation.

    Layout: ReLU -> conv -> BN -> ReLU -> conv -> BN -> add.

    The shortcut carries the unmodified block input (or, when ``downsample`` is
    given, the projection of the ReLU-activated input). No activation follows
    the addition.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both 3x3 convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional callable applied to the activated input to build
            the shortcut branch. When ``None`` the raw input is the shortcut,
            so it must already be addable to the main-branch output.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Deliberately NOT in-place: the first ReLU acts directly on the block
        # input, and an in-place op there would overwrite the caller's tensor
        # and, through aliasing, the identity shortcut as well.
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``bn2(conv2(relu(bn1(conv1(relu(x)))))) + shortcut``."""
        identity = x  # stays the raw input; `x` is never modified below

        out = self.relu(x)
        if self.downsample is not None:
            identity = self.downsample(out)

        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)  # pre-activation for the second convolution
        out = self.conv2(out)
        out = self.bn2(out)

        out = out + identity
        return out

In [8]:
# 5. Full Pre-Activation Residual Block
class FullPreActivationResidualBlock(nn.Module):
    """Full pre-activation residual block: BN -> ReLU -> conv -> BN -> ReLU -> conv -> add.

    Every convolution is preceded by its own BatchNorm and ReLU, and nothing is
    applied after the shortcut has been added.

    Args:
        in_channels: Channel count of the incoming feature map; also the size
            of the first BatchNorm.
        out_channels: Channel count produced by both 3x3 convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional callable applied to the pre-activated input
            (after the first BN and ReLU) to build the shortcut branch. When
            ``None`` the raw input is the shortcut, so it must already be
            addable to the main-branch output.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both convolutions: 3x3 kernel, padding 1, no bias.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        # First pre-activation stage works on the input channels.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)

        # Second pre-activation stage works on the output channels.
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)

        self.downsample = downsample

    def forward(self, x):
        """Return ``conv2(relu2(bn2(conv1(relu1(bn1(x)))))) + shortcut``."""
        # The in-place ReLU acts on the BatchNorm output, so `x` stays intact.
        preactivated = self.relu1(self.bn1(x))

        # A supplied projection consumes the pre-activated tensor, not `x`.
        shortcut = x if self.downsample is None else self.downsample(preactivated)

        residual = self.conv1(preactivated)
        residual = self.conv2(self.relu2(self.bn2(residual)))
        residual += shortcut
        return residual


In [9]:
# Example Usage: push one random feature map through every block variant and
# report the output shape. With 64 -> 64 channels and stride 1 no projection
# shortcut is needed, so each block is built without a `downsample`.
example_input = torch.randn(1, 64, 32, 32)
blocks = [OriginalResidualBlock(64, 64), BatchNormAfterAdditionResidualBlock(64, 64),
          ReLUBeforeAdditionResidualBlock(64, 64), ReLUOnlyPreActivationResidualBlock(64, 64),
          FullPreActivationResidualBlock(64, 64)]
block_names = ["Original", "Batch-Norm-After-Addition", "ReLU-Before-Addition", "ReLU-Only-Pre-Activation", "Full-Pre-Activation"]

for block_name, block in zip(block_names, blocks):
    # Each block gets its own copy of the input, so an in-place operation
    # inside one block cannot change what the following blocks receive.
    output = block(example_input.clone())
    print(f"{block_name} Output Shape: {output.shape}")

Original Output Shape: torch.Size([1, 64, 32, 32])
Batch-Norm-After-Addition Output Shape: torch.Size([1, 64, 32, 32])
ReLU-Before-Addition Output Shape: torch.Size([1, 64, 32, 32])
ReLU-nly-Pre-Activation Output Shape: torch.Size([1, 64, 32, 32])
Full-Pre-Activation Output Shape: torch.Size([1, 64, 32, 32])
