In [2]:
# PyTorch modules used to build the model
import torch.nn as nn
import torch.nn.functional as F

# Define a basic convolutional layer
def conv3x3(in_channels, out_channels, stride=1):
    """Create a 3x3 2-D convolution with padding 1 and no bias term.

    With ``stride=1`` the padding of 1 keeps the spatial size unchanged;
    ``stride=2`` halves it (rounding up). The bias is omitted because in
    this file every such convolution is followed by a BatchNorm layer.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        stride: Convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

# Define the ResNet model without using ConvBlock
class ResNet(nn.Module):
    """Small ResNet written out layer by layer (no block/loop abstraction).

    Structure:
      * stem: 3x3 conv (3 -> 16 channels) + BatchNorm + ReLU;
      * three stages of two basic residual blocks each, with 16, 32 and 64
        channels; the first block of stages 2 and 3 uses stride 2 and a 1x1
        projection shortcut;
      * global average pooling and a linear classifier.

    Args:
        num_classes: Number of logits produced by the final linear layer.

    Attribute names are unchanged from the original layout so existing
    ``state_dict`` checkpoints keep loading.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)

        # Layer 1 (16 channels, stride 1)
        self.layer1_conv1 = conv3x3(16, 16)
        self.layer1_bn1 = nn.BatchNorm2d(16)
        self.layer1_conv2 = conv3x3(16, 16)
        self.layer1_bn2 = nn.BatchNorm2d(16)
        # NOTE: shapes already match here (16 -> 16, stride 1), so this
        # projection is not strictly required; it is kept so the parameter
        # set and the first block's behavior are unchanged.
        self.layer1_downsample = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(16)
        )
        self.layer1_extra_conv1 = conv3x3(16, 16)
        self.layer1_extra_bn1 = nn.BatchNorm2d(16)
        self.layer1_extra_conv2 = conv3x3(16, 16)
        self.layer1_extra_bn2 = nn.BatchNorm2d(16)

        # Layer 2 (16 -> 32 channels, first conv has stride 2)
        self.layer2_conv1 = conv3x3(16, 32, stride=2)
        self.layer2_bn1 = nn.BatchNorm2d(32)
        self.layer2_conv2 = conv3x3(32, 32)
        self.layer2_bn2 = nn.BatchNorm2d(32)
        # Projection shortcut for the first block only: it matches the
        # 16 -> 32 channel change and the stride-2 downsampling.
        self.layer2_downsample = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(32)
        )
        self.layer2_extra_conv1 = conv3x3(32, 32)
        self.layer2_extra_bn1 = nn.BatchNorm2d(32)
        self.layer2_extra_conv2 = conv3x3(32, 32)
        self.layer2_extra_bn2 = nn.BatchNorm2d(32)

        # Layer 3 (32 -> 64 channels, first conv has stride 2)
        self.layer3_conv1 = conv3x3(32, 64, stride=2)
        self.layer3_bn1 = nn.BatchNorm2d(64)
        self.layer3_conv2 = conv3x3(64, 64)
        self.layer3_bn2 = nn.BatchNorm2d(64)
        # Projection shortcut for the first block only (32 -> 64, stride 2).
        self.layer3_downsample = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(64)
        )
        self.layer3_extra_conv1 = conv3x3(64, 64)
        self.layer3_extra_bn1 = nn.BatchNorm2d(64)
        self.layer3_extra_conv2 = conv3x3(64, 64)
        self.layer3_extra_bn2 = nn.BatchNorm2d(64)

        # Average over whatever spatial extent remains. For 32x32 inputs the
        # final map is 8x8, so this equals the former fixed AvgPool2d(8), but
        # it also works for other input sizes (no parameters, so the
        # state_dict is unaffected).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    @staticmethod
    def _basic_block(x, conv1, bn1, conv2, bn2, shortcut=None):
        """Run one basic residual block: conv-bn-relu-conv-bn, add skip, relu.

        Args:
            x: Block input.
            conv1, bn1, conv2, bn2: The block's main-path layers.
            shortcut: Module that projects ``x`` to the main path's shape, or
                ``None`` for an identity shortcut (input added unchanged).

        Returns:
            The block output after the final ReLU.
        """
        out = F.relu(bn1(conv1(x)))
        out = bn2(conv2(out))
        residual = x if shortcut is None else shortcut(x)
        out += residual
        return F.relu(out)

    def forward(self, x):
        """Compute class logits for a batch of 3-channel images.

        The output has shape ``(batch, num_classes)``: the pooled features are
        flattened to ``(batch, 64)`` before the linear layer.
        """
        out = F.relu(self.bn(self.conv(x)))

        # Layer 1. The extra block uses an identity shortcut. Previously it
        # re-applied layer1_downsample, which shared that projection's weights
        # and updated its BatchNorm running stats twice per forward pass.
        out = self._basic_block(out, self.layer1_conv1, self.layer1_bn1,
                                self.layer1_conv2, self.layer1_bn2,
                                shortcut=self.layer1_downsample)
        out = self._basic_block(out, self.layer1_extra_conv1,
                                self.layer1_extra_bn1,
                                self.layer1_extra_conv2,
                                self.layer1_extra_bn2)

        # Layer 2. Only the first block needs the 16 -> 32, stride-2
        # projection. Applying it to the extra block's 32-channel input raised
        # a channel-mismatch error, so the extra block uses identity.
        out = self._basic_block(out, self.layer2_conv1, self.layer2_bn1,
                                self.layer2_conv2, self.layer2_bn2,
                                shortcut=self.layer2_downsample)
        out = self._basic_block(out, self.layer2_extra_conv1,
                                self.layer2_extra_bn1,
                                self.layer2_extra_conv2,
                                self.layer2_extra_bn2)

        # Layer 3: same fix as layer 2 (32 -> 64 projection on first block
        # only).
        out = self._basic_block(out, self.layer3_conv1, self.layer3_bn1,
                                self.layer3_conv2, self.layer3_bn2,
                                shortcut=self.layer3_downsample)
        out = self._basic_block(out, self.layer3_extra_conv1,
                                self.layer3_extra_bn1,
                                self.layer3_extra_conv2,
                                self.layer3_extra_bn2)

        out = self.avg_pool(out)
        out = out.flatten(1)
        return self.fc(out)

# Build the network (10 output classes, the constructor's default) and
# print its module tree so the layer layout can be inspected.
model = ResNet(num_classes=10)
print(model)



ResNet(
  (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1_bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1_bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_downsample): Sequential(
    (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layer1_extra_conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1_extra_bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1_extra_conv2): Conv2d