In [44]:
## Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import random
from torchsummary import summary
import numpy as np

In [41]:
# Define Stochastic Depth Residual Block
# This block is a standard residual block but is used within a stochastic depth framework
class StochasticDepthResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(StochasticDepthResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

In [42]:
# Define ResNet with Stochastic Depth
# The model consists of 6 residual blocks, but only 4 of them are selected randomly per forward pass, in sequence
class StochasticDepthResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(StochasticDepthResNet, self).__init__()
        self.init = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.residual_blocks = nn.ModuleList([
            StochasticDepthResidualBlock(64, 64),
            StochasticDepthResidualBlock(64, 128, stride=2, downsample=nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False), nn.BatchNorm2d(128))),
            StochasticDepthResidualBlock(128, 128),
            StochasticDepthResidualBlock(128, 128),
            StochasticDepthResidualBlock(128, 128),
            StochasticDepthResidualBlock(128, 128)
        ])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)  # Adjusted final layer input size

    def forward(self, x):
        x = self.init(x)
        # Randomly select a starting point, ensuring sequential execution of 4 blocks
        layer_idx = np.random.choice([2,3,4,5], size=2, replace=False)
        layer_idx.sort()
        all_layers = [0,1] + list(layer_idx)
        print(f"Selected layer indicies: {all_layers}")
        selected_blocks = [self.residual_blocks[int(idx)] for idx in  all_layers]
        for block in selected_blocks:
            x = block(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


In [43]:
# Instantiate and print model summary
model = StochasticDepthResNet()
summary(model, (3, 32, 32))

Selected layer indicies: [0, 1, 3, 5]
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,864
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
            Conv2d-7           [-1, 64, 32, 32]          36,864
       BatchNorm2d-8           [-1, 64, 32, 32]             128
              ReLU-9           [-1, 64, 32, 32]               0
StochasticDepthResidualBlock-10           [-1, 64, 32, 32]               0
           Conv2d-11          [-1, 128, 16, 16]          73,728
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 16, 16]       