In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchsummary import summary


In [None]:
# Define ResNeXt Block with Grouped Convolutions
class ResNeXtBlock(nn.Module):
    """Bottleneck residual block whose 3x3 convolution is a grouped convolution.

    Main branch: 1x1 conv -> BN -> ReLU -> grouped 3x3 conv -> BN -> ReLU ->
    1x1 conv -> BN. The shortcut branch is added to the main branch and the
    sum is passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor. Each group of
            the 3x3 convolution is ``out_channels // 2`` channels wide, so the
            bottleneck has ``(out_channels // 2) * cardinality`` channels.
        cardinality: Number of groups of the 3x3 convolution.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut. When omitted while ``stride != 1`` or
            ``in_channels != out_channels``, a 1x1 conv (with ``stride``) + BN
            projection is created so the shortcut matches the main branch.

    Raises:
        ValueError: If ``out_channels < 2`` or ``cardinality < 1``, which
            would produce a bottleneck with zero channels.
    """

    def __init__(self, in_channels, out_channels, cardinality=32, stride=1, downsample=None):
        super(ResNeXtBlock, self).__init__()
        D = out_channels // 2  # Width per group
        if D < 1 or cardinality < 1:
            raise ValueError(
                "ResNeXtBlock needs out_channels >= 2 and cardinality >= 1, "
                "got out_channels={} and cardinality={}".format(out_channels, cardinality)
            )
        width = D * cardinality  # Total channels inside the bottleneck

        # 1x1 conv expands the input to the bottleneck width.
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        # Grouped 3x3 conv; this is the only layer of the main branch that strides.
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        # 1x1 conv projects the bottleneck back to out_channels.
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Without a projection, a shortcut whose shape differs from the main
        # branch makes the residual add fail, or -- for a single-channel
        # input -- silently broadcast across every output channel.
        if downsample is None and (stride != 1 or in_channels != out_channels):
            downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        # The shortcut is evaluated before the main branch.
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out

In [None]:
# Define ResNeXt Network
class ResNeXt(nn.Module):
    """Classifier built from four stages of ``ResNeXtBlock`` modules.

    The stages contain 3, 4, 6 and 3 blocks with 64, 128, 256 and 512 output
    channels; stages 2-4 start with a stride-2 block. There is no stem
    convolution: the input is fed directly to the first stage. The stage
    output is globally average-pooled and mapped to ``num_classes`` logits by
    a linear layer.

    Args:
        num_classes: Number of output logits.
        cardinality: Number of convolution groups used in every block.
        in_channels: Number of channels of the input (3 by default).
    """

    def __init__(self, num_classes=10, cardinality=32, in_channels=3):
        super(ResNeXt, self).__init__()
        self.layer1 = self._make_layer(in_channels, 64, 3, cardinality)
        self.layer2 = self._make_layer(64, 128, 4, cardinality, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, cardinality, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, cardinality, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, cardinality, stride=1):
        """Build one stage of ``blocks`` consecutive ``ResNeXtBlock`` modules.

        Only the first block changes the channel count and applies ``stride``;
        the remaining blocks map ``out_channels`` to ``out_channels`` with
        stride 1.

        Args:
            in_channels: Channels entering the stage.
            out_channels: Channels produced by every block of the stage.
            blocks: Number of blocks in the stage.
            cardinality: Number of convolution groups per block.
            stride: Stride of the first block.

        Returns:
            An ``nn.Sequential`` holding the blocks in order.
        """
        downsample = None
        # The first block needs a projected shortcut whenever it changes the
        # spatial size or the channel count.
        if stride != 1 or in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        layers = [ResNeXtBlock(in_channels, out_channels, cardinality, stride, downsample)]
        for _ in range(1, blocks):
            layers.append(ResNeXtBlock(out_channels, out_channels, cardinality))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the four stages, pool, flatten from dim 1 and apply the linear head.

        ``x`` is expected to be a batch of images shaped
        ``(N, in_channels, H, W)``.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # pooled 1x1 maps -> one 512-vector per sample
        x = self.fc(x)
        return x


In [None]:
# Instantiate and print model summary
# torchsummary's `device` argument defaults to "cuda", and it then creates its
# dummy input on the GPU whenever one is available. Keep the model and the
# summary on the same, explicitly chosen device so the cell runs on both
# CPU-only and GPU machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ResNeXt().to(device)
summary(model, (3, 64, 64), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]             192
       BatchNorm2d-2           [-1, 64, 64, 64]             128
            Conv2d-3         [-1, 1024, 64, 64]           3,072
       BatchNorm2d-4         [-1, 1024, 64, 64]           2,048
              ReLU-5         [-1, 1024, 64, 64]               0
            Conv2d-6         [-1, 1024, 64, 64]         294,912
       BatchNorm2d-7         [-1, 1024, 64, 64]           2,048
              ReLU-8         [-1, 1024, 64, 64]               0
            Conv2d-9           [-1, 64, 64, 64]          65,536
      BatchNorm2d-10           [-1, 64, 64, 64]             128
             ReLU-11           [-1, 64, 64, 64]               0
     ResNeXtBlock-12           [-1, 64, 64, 64]               0
           Conv2d-13         [-1, 1024, 64, 64]          65,536
      BatchNorm2d-14         [-1, 1024,