* ResNeXt
* 논문 : https://arxiv.org/pdf/1611.05431
* ResNet(Skip Connection) + GoogLeNet(Inception Module) + AlexNet(Grouped Conv) / cardinality 개념
* cardinality는 분할되는 group의 수
* 구현 참조 : https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html

In [141]:
import torch
from torch import nn, Tensor
from torchinfo import summary

<img src = 'https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2FvYT9N%2Fbtrh0EsRYzB%2Fxy8KN41PfUxn15kvwaiF2k%2Fimg.png'>
그림 중 (c) 구조(여러 경로를 grouped convolution 하나로 구현한 형태)를 사용함

In [181]:
class Bottleneck(nn.Module):
    """ResNeXt bottleneck block in its grouped-convolution form.

    Main path: 1x1 conv -> 3x3 grouped conv -> 1x1 conv, each followed by
    BatchNorm, with ReLU after the first two. The output has
    ``inner_channels * expansion_factor`` channels. It is added to the
    shortcut and then passed through a final ReLU.

    The spatial stride is applied on the 3x3 grouped conv (the first 1x1 conv
    always uses stride 1), matching torchvision's implementation. Some other
    implementations put the stride on the first 1x1 conv instead.

    Args:
        in_channels: channel count of the incoming feature map.
        inner_channels: base planes of the block; the output channel count is
            ``inner_channels * expansion_factor``.
        cardinality: number of groups in the 3x3 convolution.
        base_width: width per group, relative to 64 (e.g. 4 for "32x4d").
        stride: stride of the 3x3 conv and of the projection shortcut.
    """
    expansion_factor = 4

    def __init__(self, in_channels, inner_channels, cardinality = 32, base_width = 64, stride = 1):
        super().__init__()
        # Width of the grouped conv, e.g. inner_channels=64, base_width=4,
        # cardinality=32 -> int(64 * 4 / 64) * 32 = 128 channels.
        group_width = int(inner_channels * (base_width / 64.)) * cardinality
        out_channels = inner_channels * self.expansion_factor

        main_path = [
            nn.Conv2d(in_channels, group_width, kernel_size = 1, stride = 1, bias = False),
            nn.BatchNorm2d(group_width),
            nn.ReLU(inplace = True),
            nn.Conv2d(group_width, group_width, kernel_size = 3, stride = stride, padding = 1,
                      groups = cardinality, bias = False),
            nn.BatchNorm2d(group_width),
            nn.ReLU(inplace = True),
            nn.Conv2d(group_width, out_channels, kernel_size = 1, bias = False),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*main_path)

        # A 1x1 projection is needed whenever the shortcut's shape would not
        # match the main path (spatial downsampling or channel change).
        needs_projection = stride != 1 or in_channels != out_channels
        if not needs_projection:
            # Empty Sequential returns its input unchanged (identity shortcut).
            self.residual = nn.Sequential()
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm2d(out_channels),
            )
        self.relu = nn.ReLU(inplace = True)

    def forward(self, x):
        """Return ``relu(block(x) + residual(x))``."""
        y = self.block(x)
        y += self.residual(x)
        return self.relu(y)

block = Bottleneck(in_channels = 32, inner_channels = 64, cardinality = 32, base_width = 4)
summary(block, (1, 32,112,112))

Layer (type:depth-idx)                   Output Shape              Param #
Bottleneck                               [1, 256, 112, 112]        --
├─Sequential: 1-1                        [1, 256, 112, 112]        --
│    └─Conv2d: 2-1                       [1, 128, 112, 112]        4,096
│    └─BatchNorm2d: 2-2                  [1, 128, 112, 112]        256
│    └─ReLU: 2-3                         [1, 128, 112, 112]        --
│    └─Conv2d: 2-4                       [1, 128, 112, 112]        4,608
│    └─BatchNorm2d: 2-5                  [1, 128, 112, 112]        256
│    └─ReLU: 2-6                         [1, 128, 112, 112]        --
│    └─Conv2d: 2-7                       [1, 256, 112, 112]        32,768
│    └─BatchNorm2d: 2-8                  [1, 256, 112, 112]        512
├─Sequential: 1-2                        [1, 256, 112, 112]        --
│    └─Conv2d: 2-9                       [1, 256, 112, 112]        8,192
│    └─BatchNorm2d: 2-10                 [1, 256, 112, 112]        512

<img src = 'https://miro.medium.com/v2/resize:fit:852/1*ILkxvajbhiQrRkFrtFs6hg.png'>

In [174]:
class ResNext_blueprint(nn.Module):
    """ResNeXt network assembled from a residual block class.

    Layout:
    - Stem: 7x7 conv (stride 2) + BN + ReLU, then 3x3 max-pool (stride 2).
    - Four stages ``conv2``..``conv5`` with 64/128/256/512 base planes.
      ``conv2`` keeps the resolution; the other stages downsample by 2 in
      their first block.
    - Global average pooling followed by a linear classifier.

    Args:
        block: block class to stack. It must have a class attribute
            ``expansion_factor`` and accept
            ``(in_channels, inner_channels, cardinality=..., base_width=..., stride=...)``.
        block_lst: number of blocks per stage; only the first four entries are used.
        cardinality: group count forwarded to every block.
        base_width: per-group width forwarded to every block.
        num_classes: output size of the final ``nn.Linear``.
    """
    def __init__(self, block, block_lst, cardinality = 32, base_width = 4, num_classes = 1000):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3, bias = False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace = True)
        )
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        # Running channel count: _make_layer reads it as the input width of the
        # next block and advances it to each block's output width.
        self.in_channels = 64
        self.conv2 = self._make_layer(block, 64, block_lst[0], cardinality = cardinality, base_width = base_width, stride = 1)
        self.conv3 = self._make_layer(block, 128, block_lst[1], cardinality = cardinality, base_width = base_width, stride = 2)
        self.conv4 = self._make_layer(block, 256, block_lst[2], cardinality = cardinality, base_width = base_width, stride = 2)
        self.conv5 = self._make_layer(block, 512, block_lst[3], cardinality = cardinality, base_width = base_width, stride = 2)

        self.avg = nn.AdaptiveAvgPool2d((1,1))
        # Size the classifier from what conv5 actually outputs
        # (512 * block.expansion_factor, i.e. 2048 for Bottleneck) instead of a
        # hard-coded 2048, so blocks with another expansion factor also work.
        self.fc = nn.Linear(self.in_channels, num_classes)

        self._init_layer()

    def _make_layer(self, block, inner_channels, num_blocks, cardinality, base_width, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``; the rest use stride 1.
        Side effect: ``self.in_channels`` is advanced to
        ``inner_channels * block.expansion_factor``.
        """
        # NOTE(review): num_blocks < 1 still yields one block, because
        # [stride] + [1] * (negative) == [stride].
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(block(self.in_channels, inner_channels, cardinality = cardinality, base_width = base_width, stride = block_stride))
            self.in_channels = inner_channels * block.expansion_factor
        return nn.Sequential(*layers)

    def _init_layer(self):
        """Initialize the network's weights.

        - Conv2d: Kaiming-normal weights (fan_out, relu); biases, if any, set to 0.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: weights ~ N(0, 0.01), bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a 3-channel image batch (presumably (N, 3, H, W)) to (N, num_classes) logits."""
        out = self.conv1(x)
        out = self.maxpool(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.avg(out)
        out = torch.flatten(out,1)
        out = self.fc(out)
        return out

In [175]:
class ResNeXt:
    """Factory for the standard ResNeXt variants built on ``ResNext_blueprint``.

    Args:
        num_classes: output size of the classifier of every model built by this
            factory (default 1000).
    """
    def __init__(self, num_classes = 1000):
        self.num_classes = num_classes

    def _build(self, block_lst, cardinality, base_width):
        """Build a Bottleneck-based ResNeXt, forwarding this factory's ``num_classes``."""
        return ResNext_blueprint(Bottleneck, block_lst, cardinality = cardinality,
                                 base_width = base_width, num_classes = self.num_classes)

    def ResNext50_32x4d(self):
        """ResNeXt-50: stages [3, 4, 6, 3], cardinality 32, base width 4."""
        return self._build([3,4,6,3], cardinality = 32, base_width = 4)

    # Alias spelled like the other builders (capital "X").
    ResNeXt50_32x4d = ResNext50_32x4d

    def ResNeXt101_32x8d(self):
        """ResNeXt-101: stages [3, 4, 23, 3], cardinality 32, base width 8."""
        return self._build([3,4,23,3], cardinality = 32, base_width = 8)

    def ResNeXt101_64x4d(self):
        """ResNeXt-101: stages [3, 4, 23, 3], cardinality 64, base width 4."""
        return self._build([3,4,23,3], cardinality = 64, base_width = 4)

    def ResNeXt152_32x4d(self):
        """ResNeXt-152: stages [3, 8, 36, 3], cardinality 32, base width 4."""
        return self._build([3,8,36,3], cardinality = 32, base_width = 4)

* ResNeXt50_32x4d

In [176]:
model = ResNeXt(1000).ResNext50_32x4d()
summary(model, (1,3,224,224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNext_blueprint                        [1, 1000]                 --
├─Sequential: 1-1                        [1, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [1, 64, 112, 112]         128
│    └─ReLU: 2-3                         [1, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [1, 64, 56, 56]           --
├─Sequential: 1-3                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-4                   [1, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [1, 256, 56, 56]          46,592
│    │    └─Sequential: 3-2              [1, 256, 56, 56]          16,896
│    │    └─ReLU: 3-3                    [1, 256, 56, 56]          --
│    └─Bottleneck: 2-5                   [1, 256, 56, 56]          --
│    │    └─Sequential: 3-4              [1, 256, 56, 56]          71,168

In [177]:
from torchvision.models import resnext50_32x4d
model = resnext50_32x4d()
summary(model, (1,3,224,224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 128, 56, 56]          8,192
│    │    └─BatchNorm2d: 3-2             [1, 128, 56, 56]          256
│    │    └─ReLU: 3-3                    [1, 128, 56, 56]          --
│    │    └─Conv2d: 3-4                  [1, 128, 56, 56]          4,608
│    │    └─BatchNorm2d: 3-5             [1, 128, 56, 56]          256
│    │    └─ReLU: 3-6                    [1, 128, 56, 56]          --
│  

* ResNeXt101_32x8d

In [178]:
model = ResNeXt(1000).ResNeXt101_32x8d()
summary(model, (1,3,224,224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNext_blueprint                        [1, 1000]                 --
├─Sequential: 1-1                        [1, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [1, 64, 112, 112]         128
│    └─ReLU: 2-3                         [1, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [1, 64, 56, 56]           --
├─Sequential: 1-3                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-4                   [1, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [1, 256, 56, 56]          101,888
│    │    └─Sequential: 3-2              [1, 256, 56, 56]          16,896
│    │    └─ReLU: 3-3                    [1, 256, 56, 56]          --
│    └─Bottleneck: 2-5                   [1, 256, 56, 56]          --
│    │    └─Sequential: 3-4              [1, 256, 56, 56]          151,040

In [179]:
from torchvision.models import resnext101_32x8d
model = resnext101_32x8d()
summary(model, (1,3,224,224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-2             [1, 256, 56, 56]          512
│    │    └─ReLU: 3-3                    [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-4                  [1, 256, 56, 56]          18,432
│    │    └─BatchNorm2d: 3-5             [1, 256, 56, 56]          512
│    │    └─ReLU: 3-6                    [1, 256, 56, 56]          --
│

* ResNeXt101_64x4d

In [180]:
model = ResNeXt(1000).ResNeXt101_64x4d()
summary(model, (1,3,224,224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNext_blueprint                        [1, 1000]                 --
├─Sequential: 1-1                        [1, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [1, 64, 112, 112]         128
│    └─ReLU: 2-3                         [1, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [1, 64, 56, 56]           --
├─Sequential: 1-3                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-4                   [1, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [1, 256, 56, 56]          92,672
│    │    └─Sequential: 3-2              [1, 256, 56, 56]          16,896
│    │    └─ReLU: 3-3                    [1, 256, 56, 56]          --
│    └─Bottleneck: 2-5                   [1, 256, 56, 56]          --
│    │    └─Sequential: 3-4              [1, 256, 56, 56]          141,824

In [110]:
from torchvision.models import resnext101_64x4d
model = resnext101_64x4d()
summary(model, (1,3,224,224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-2             [1, 256, 56, 56]          512
│    │    └─ReLU: 3-3                    [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-4                  [1, 256, 56, 56]          9,216
│    │    └─BatchNorm2d: 3-5             [1, 256, 56, 56]          512
│    │    └─ReLU: 3-6                    [1, 256, 56, 56]          --
│ 