#### [CSPNet: A NEW BACKBONE THAT CAN ENHANCE LEARNING CAPABILITY OF CNN](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w28/Wang_CSPNet_A_New_Backbone_That_Can_Enhance_Learning_Capability_of_CVPRW_2020_paper.pdf)

Why implement CSPNet?
- It reduces the computational load by 10-20% while matching or exceeding the accuracy of ResNet, ResNeXt, and DenseNet.
- It is easy to integrate into existing architectures.
- It distributes computation evenly across the layers of a CNN, which raises the utilization rate of each computation unit and reduces unnecessary energy consumption. On YOLOv3-based models, the paper reports an 80% reduction in computational bottlenecks.
- It reduces the overall memory cost.
- It improves inference speed.

The goal of this blog post is to build
- ResNet10
- CSPResNet10

and understand what changes between the two.

In [1]:
import torch 
import torch.nn as nn
from functools import partial
from typing import Union, Type, List

# Only ConvNormAct is used below. Newer timm releases also expose it from `timm.layers`.
from timm.models.layers import ConvNormAct

### Minimal implementation of ResNet10

In [2]:
class ResNetBlock(nn.Module):
    expansion = 1
    def __init__(self, in_chs, out_chs, stride=1, downsample=None):
        super().__init__()
        block_kwargs = {"norm_layer":nn.BatchNorm2d, "act_layer": nn.ReLU}
        self.conv1 = ConvNormAct(in_chs, out_chs, kernel_size=3, stride=stride, padding=1, bias=False, **block_kwargs)
        self.conv2 = ConvNormAct(out_chs, out_chs, kernel_size=3, padding=1, bias=False, apply_act=False, **block_kwargs)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)
        if self.downsample is None:
            assert stride == 1, "stride cannot be more than 1 when downsample is None"
    
    def forward(self, x: torch.Tensor)-> torch.Tensor:
        residual=x
        out: torch.Tensor = self.conv1(x)
        
        out = self.conv2(out)
        
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

In [3]:
resblock = ResNetBlock(16, 16, 1)
resblock

ResNetBlock(
  (conv1): ConvNormAct(
    (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNormAct2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): ReLU(inplace=True)
    )
  )
  (conv2): ConvNormAct(
    (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNormAct2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): Identity()
    )
  )
  (relu): ReLU(inplace=True)
)

In [4]:
x = torch.randn((2, 16, 24, 24))
resblock(x).shape

torch.Size([2, 16, 24, 24])

### ResNetBottleneck 
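- It has three conv layers: a 1x1 conv, a 3x3 conv (which carries the stride), and a final 1x1 conv.
- The number of feature maps grows from `in_chs` to `out_chs * self.expansion`.
- When `stride > 1` or the channel count changes, the `downsample` layer projects the skip connection so that it matches the shape of the main branch.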
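
The shape bookkeeping behind these bullets can be checked without PyTorch. The sketch below is illustrative only: `conv_out_size` and `bottleneck_shapes` are hypothetical helper names, and it assumes the 3x3 conv uses `padding=1` and the 1x1 convs use no padding, as in the class that follows.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def bottleneck_shapes(in_chs, out_chs, size, stride=1, expansion=4):
    """Return (channels, height/width) after each conv and on the skip path.

    in_chs is listed only for readability; it does not affect the output shapes.
    """
    s1 = conv_out_size(size, kernel=1)                           # 1x1 conv, stride 1
    s2 = conv_out_size(s1, kernel=3, stride=stride, padding=1)   # 3x3 conv carries the stride
    s3 = conv_out_size(s2, kernel=1)                             # 1x1 conv expands the channels
    skip = conv_out_size(size, kernel=1, stride=stride)          # 1x1 projection on the skip path
    return {
        "conv1": (out_chs, s1),
        "conv2": (out_chs, s2),
        "conv3": (out_chs * expansion, s3),
        "downsample": (out_chs * expansion, skip),
    }


if __name__ == "__main__":
    for name, shape in bottleneck_shapes(16, 64, 24, stride=1).items():
        print(name, shape)
```

With `stride=2` the main branch and the skip path still agree, which is why the residual addition works.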

In [5]:
class ResNetBottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_chs, out_chs, stride=1, downsample=None)-> None:
        super().__init__()
        block_kwargs = {"norm_layer":nn.BatchNorm2d, "act_layer": nn.ReLU}
        self.conv1 = ConvNormAct(in_chs, out_chs, kernel_size=1, bias=False, **block_kwargs)
        self.conv2 = ConvNormAct(out_chs, out_chs, kernel_size=3, stride=stride, padding=1, bias=False, **block_kwargs)
        self.conv3 = ConvNormAct(out_chs, out_chs*self.expansion, kernel_size=1, bias=False, **block_kwargs, apply_act=False )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        if self.downsample is None:
            self.downsample = ConvNormAct(in_chs, out_chs*self.expansion, kernel_size=1, stride=self.stride, **block_kwargs, apply_act=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        out: torch.Tensor = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

In [6]:
x = torch.randn((2, 16, 24, 24))
ResNetBottleneck(16, 64, 1)(x).shape

torch.Size([2, 256, 24, 24])

### CSPStage
- This stage wraps either of the two blocks discussed above.
- Given an input tensor, the stage performs these steps:
    - `conv_down`: if `stride > 1`, reduce the spatial size of the feature map (average pooling followed by a 3x3 conv); otherwise use `nn.Identity`.
    - `conv_exp`: a 1x1 conv that sets the number of channels to `out_chs * expand_ratio`.
    - Split the tensor into two halves along the channel dimension: `xs, xb = x.split(self.expand_chs // 2, dim=1)`.
    - Pass `xb` through the `block` layers; `depth` sets how many blocks are stacked.
    - Pass the result through `conv_transition_b`, which brings `xb` back to `expand_chs // 2` channels.
    - Concatenate `xs` and `xb`.
    - Pass the result through the `conv_transition` conv layer, which maps it to `out_chs` channels.
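
The steps above can be traced as plain channel arithmetic. This is a rough sketch, not the stage itself: `csp_stage_trace` is a hypothetical helper, it assumes `expand_chs` is even, and it assumes the blocks keep the spatial size (they are built with stride 1 inside the stage).

```python
def csp_stage_trace(in_chs, out_chs, size, stride, expand_ratio=1.0, block_expansion=1):
    """List (step, channels, height/width) for one pass through a CSP stage."""
    steps = []
    chs = in_chs
    if stride != 1:
        if stride == 2:
            size //= 2       # AvgPool2d(2) halves the spatial size
        chs = out_chs        # the 3x3 conv in conv_down maps to out_chs
    steps.append(("conv_down", chs, size))

    exp_chs = int(round(out_chs * expand_ratio))
    steps.append(("conv_exp", exp_chs, size))

    half = exp_chs // 2      # assumption: exp_chs is even, so the split yields two parts
    steps.append(("split -> xs, xb", half, size))
    steps.append(("blocks(xb)", half * block_expansion, size))
    steps.append(("conv_transition_b(xb)", half, size))
    steps.append(("concat(xs, xb)", 2 * half, size))
    steps.append(("conv_transition", out_chs, size))
    return steps


if __name__ == "__main__":
    # Same configuration as the CSPStage(16, 64, 2, 2, block_fn=ResNetBottleneck) example
    for step in csp_stage_trace(16, 64, size=64, stride=2, block_expansion=4):
        print(step)
```

Only half of the expanded channels ever go through the blocks, which is where the savings in parameters and computation come from.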

In [7]:
class CSPStage(nn.Module):
    def __init__(self, 
                 in_chs, 
                 out_chs, 
                 stride, 
                 depth, 
                 block_ratio=1., 
                 expand_ratio=1.0, 
                 bottle_ratio=1.0, 
                 block_fn: Type[Union[ResNetBlock, ResNetBottleneck]] = ResNetBlock):
        super().__init__()
        # Note: block_ratio and bottle_ratio are accepted but unused in this minimal version.
        block_kwargs = {"norm_layer":nn.BatchNorm2d, "act_layer": nn.ReLU}
        self.in_chs = in_chs
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        
        if stride != 1:
            self.conv_down = nn.Sequential(
                nn.AvgPool2d(2) if stride == 2 else nn.Identity(),
                ConvNormAct(in_chs, out_chs, kernel_size=3, stride=1, **conv_kwargs))
            prev_chs = out_chs
        else:
            self.conv_down = nn.Identity()
            prev_chs = in_chs
        
        self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=True, **conv_kwargs)
        prev_chs = exp_chs // 2  # output of conv_exp is always split in two
        self.in_chs = prev_chs  # _make_layer reads and then updates self.in_chs
        self.blocks = self._make_layer(block_fn, prev_chs, depth, 1)
        # transition convs: self.in_chs is now prev_chs * block_fn.expansion
        self.conv_transition_b = ConvNormAct(self.in_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
        
        
    def _make_layer(self, block: Type[Union[ResNetBlock, ResNetBottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
        downsample: Union[nn.Module, partial, None] = None
        if stride != 1 or self.in_chs != planes * block.expansion:
            downsample = nn.Sequential(
                    nn.Conv2d(self.in_chs, planes * block.expansion, kernel_size=1, stride=stride),
                    nn.BatchNorm2d(planes * block.expansion),
                )
        layers = [
            block(
                in_chs=self.in_chs, out_chs=planes, stride=stride, downsample=downsample
            )
        ]

        self.in_chs = planes * block.expansion
        for _i in range(1, blocks):
            layers.append(block(self.in_chs, planes))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv_down(x)
        x = self.conv_exp(x)
        xs, xb = x.split(self.expand_chs // 2, dim=1)
        xb = self.blocks(xb)
        xb = self.conv_transition_b(xb).contiguous()
        out = self.conv_transition(torch.cat([xs, xb], dim=1))
        return out

In [8]:
x2 = torch.randn((2, 16, 64, 64))
tt = CSPStage(16, 64, 2, 2, block_fn=ResNetBottleneck)

In [9]:
tt(x2).shape

torch.Size([2, 64, 32, 32])

### Build CSPResNet and ResNet

In [10]:
class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[ResNetBlock, ResNetBottleneck]],
        layers: List[int],
        block_inplanes: List[int],
        in_chs: int = 3,
        conv1_t_size: int = 7,
        conv1_t_stride: int = 1,
        no_max_pool: bool = False,
        num_classes: int = 400,
        csp: bool=False 
    ) -> None:

        super().__init__()

        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        # conv-bn-relu
        block_kwargs = {"norm_layer":nn.BatchNorm2d, "act_layer": nn.ReLU}
        self.conv1 = ConvNormAct(in_chs, 
                                self.in_planes, 
                                kernel_size=conv1_t_size, 
                                stride=conv1_t_stride, 
                                padding=tuple(k // 2 for k in (conv1_t_size, conv1_t_size)), 
                                bias=False, 
                                **block_kwargs)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        if csp:
            self.layer1 = CSPStage(self.in_planes, block_inplanes[0], 1, layers[0], block_fn=block)
            self.layer2 = CSPStage(block_inplanes[0], block_inplanes[1], 2, layers[1], block_fn=block)
            self.layer3 = CSPStage(block_inplanes[1], block_inplanes[2], 2, layers[2], block_fn=block)
            self.layer4 = CSPStage(block_inplanes[2], block_inplanes[3], 2, layers[3], block_fn=block)
        else:
            
            self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],)
            self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], stride=2)
            self.layer3 = self._make_layer(block, block_inplanes[2], layers[2], stride=2)
            self.layer4 = self._make_layer(block, block_inplanes[3], layers[3], stride=2)
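        # classifier
        self.avgpool = nn.AdaptiveAvgPool2d([1, 1])
        # A CSPStage always ends with out_chs channels; a plain stage ends with planes * block.expansion.
        fc_in_chs = block_inplanes[3] if csp else block_inplanes[3] * block.expansion
        self.fc = nn.Linear(fc_in_chs, num_classes)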

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(torch.as_tensor(m.bias), 0)

    def _make_layer(self, block: Type[ResNetBlock], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
        downsample: Union[nn.Module, partial, None] = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                    nn.Conv2d(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride),
                    nn.BatchNorm2d(planes * block.expansion),
                )
        layers = [
            block(
                in_chs=self.in_planes, out_chs=planes, stride=stride, downsample=downsample
            )
        ]

        self.in_planes = planes * block.expansion
        for _i in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        if self.fc is not None:
            x = self.fc(x)

        return x

In [11]:
x = torch.randn((2, 3, 128, 128))
x.shape

torch.Size([2, 3, 128, 128])

In [12]:
resnet10 = ResNet(ResNetBlock, [1, 1, 1, 1], [64, 128, 256, 512])

In [13]:
%%time
resnet10(x).shape

CPU times: user 181 ms, sys: 19.7 ms, total: 201 ms
Wall time: 213 ms


torch.Size([2, 400])

In [14]:
csp_resnet10 = ResNet(ResNetBlock, [1, 1, 1, 1], [64, 128, 256, 512], csp=True)

In [15]:
%%time
csp_resnet10(x).shape

CPU times: user 169 ms, sys: 16.8 ms, total: 186 ms
Wall time: 186 ms


torch.Size([2, 400])

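On a MacBook Air M1, this single `%%time` run took 186 ms for `csp_resnet10` versus 213 ms for the plain `resnet10`, so the CSP variant was roughly 13% faster here. A single CPU run is noisy, so treat the figure as indicative only.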
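
A single `%%time` measurement can include one-off costs such as memory allocation on the first forward pass. A minimal sketch of a steadier measurement, assuming a hypothetical helper `median_wall_time_ms` that runs a few warm-up calls and then reports the median over several repeats:

```python
import statistics
import time


def median_wall_time_ms(fn, warmup=3, repeats=10):
    """Median wall time of fn() in milliseconds, measured after warm-up calls."""
    for _ in range(warmup):
        fn()  # warm-up runs are not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


if __name__ == "__main__":
    # Stand-in workload; in the notebook this would be a model forward pass.
    print(median_wall_time_ms(lambda: sum(range(100_000))))
```

In the notebook, this could be called as `median_wall_time_ms(lambda: resnet10(x))` and `median_wall_time_ms(lambda: csp_resnet10(x))`, ideally with both models in `eval()` mode and inside `torch.no_grad()`.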

To build other variants:
- resnet18 use `ResNetBlock & [2, 2, 2, 2]`
- resnet34 use `ResNetBlock & [3, 4, 6, 3]`
- resnet50 use `ResNetBottleneck & [3, 4, 6, 3]`
- resnet101 use `ResNetBottleneck & [3, 4, 23, 3]`
- resnet200 use `ResNetBottleneck & [3, 24, 36, 3]`
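
As a sanity check on these configurations, the number in each name can be recovered by counting the stem conv, the conv layers inside every block, and the final fully connected layer (projection shortcuts are not counted). The names `CONFIGS` and `nominal_depth` below are illustrative only.

```python
# Conv layers inside one block of each type
CONVS_PER_BLOCK = {"ResNetBlock": 2, "ResNetBottleneck": 3}

CONFIGS = {
    "resnet10": ("ResNetBlock", [1, 1, 1, 1]),
    "resnet18": ("ResNetBlock", [2, 2, 2, 2]),
    "resnet34": ("ResNetBlock", [3, 4, 6, 3]),
    "resnet50": ("ResNetBottleneck", [3, 4, 6, 3]),
    "resnet101": ("ResNetBottleneck", [3, 4, 23, 3]),
    "resnet200": ("ResNetBottleneck", [3, 24, 36, 3]),
}


def nominal_depth(block_name, layers):
    """Stem conv + conv layers in all blocks + final fully connected layer."""
    return 1 + sum(layers) * CONVS_PER_BLOCK[block_name] + 1


if __name__ == "__main__":
    for name, (block_name, layers) in CONFIGS.items():
        print(name, nominal_depth(block_name, layers))
```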

### Parameter count

In [16]:
res10 = sum([i.numel() for name, i in resnet10.named_parameters()])
res10

5111888

In [17]:
cspres10 = sum([i.numel() for name, i in csp_resnet10.named_parameters()])
cspres10

4121616

In [18]:
((res10- cspres10)/ res10)*100

19.371942421273705

> We get a ~19% reduction in the number of parameters.
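
Both counts can be reproduced by hand, which shows where the saving comes from. The sketch below is a plain-Python tally under the assumptions of this post: `ConvNormAct` convs have no bias, the `nn.Conv2d` in the plain ResNet shortcut keeps its default bias, BatchNorm contributes a weight and a bias per channel (running statistics are buffers, not parameters), and `expand_ratio` is 1. The helper names are hypothetical.

```python
def conv_bn(in_chs, out_chs, kernel, conv_bias=False):
    """Parameters of a k x k conv followed by BatchNorm (weight + bias per channel)."""
    conv = in_chs * out_chs * kernel * kernel + (out_chs if conv_bias else 0)
    return conv + 2 * out_chs


def basic_block(in_chs, out_chs, downsample):
    """Two 3x3 conv+BN layers, plus an optional 1x1 projection shortcut."""
    total = conv_bn(in_chs, out_chs, 3) + conv_bn(out_chs, out_chs, 3)
    if downsample:
        total += conv_bn(in_chs, out_chs, 1, conv_bias=True)  # nn.Conv2d default bias
    return total


def stem_and_head(first, last, num_classes, in_chs=3, kernel=7):
    """7x7 stem conv+BN plus the final linear layer (weights and bias)."""
    return conv_bn(in_chs, first, kernel) + last * num_classes + num_classes


def resnet10_params(planes=(64, 128, 256, 512), num_classes=400):
    total = stem_and_head(planes[0], planes[-1], num_classes)
    prev = planes[0]
    for i, p in enumerate(planes):
        stride = 1 if i == 0 else 2
        total += basic_block(prev, p, downsample=(stride != 1 or prev != p))
        prev = p
    return total


def csp_resnet10_params(planes=(64, 128, 256, 512), num_classes=400):
    total = stem_and_head(planes[0], planes[-1], num_classes)
    prev = planes[0]
    for i, p in enumerate(planes):
        if i > 0:
            total += conv_bn(prev, p, 3)    # conv_down (AvgPool2d has no parameters)
            prev = p
        total += conv_bn(prev, p, 1)        # conv_exp
        half = p // 2
        total += basic_block(half, half, downsample=False)  # one block on half the channels
        total += conv_bn(half, half, 1)     # conv_transition_b
        total += conv_bn(p, p, 1)           # conv_transition
        prev = p
    return total


if __name__ == "__main__":
    res, csp = resnet10_params(), csp_resnet10_params()
    print(res, csp, round((res - csp) / res * 100, 2))
```

The block in each CSP stage operates on half the channels, so its 3x3 convs cost about a quarter of the plain version; the extra 1x1 and `conv_down` convs give part of that saving back.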