### resnet50 pytorch 实现   
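各版本每个 stage 的 block 数及每个 block 的卷积层数：ResNet-18 为 [2, 2, 2, 2]（每个 block 2 层卷积），ResNet-34 为 [3, 4, 6, 3]（2 层），ResNet-50 为 [3, 4, 6, 3]（3 层），ResNet-101 为 [3, 4, 23, 3]（3 层）。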
![resnet](https://pic3.zhimg.com/v2-4f6be1baeda43656e98ab58a17edaa1c_1440w.jpg?source=172ae18b)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

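### ResNet 的 BasicBlock 与 Bottleneck
- BasicBlock（ResNet-18/34）：3x3 conv -> BN -> ReLU -> 3x3 conv -> BN，与残差相加后再经 ReLU 输出
- Bottleneck（ResNet-50/101）：1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv -> BN，与残差相加后再经 ReLU 输出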

In [95]:
class BasicBlock(nn.Module):
    """Two-conv residual block (the ResNet-18/34 building block).

    Main path: 3x3 conv (``stride``) -> BN -> ReLU -> 3x3 conv (stride 1) -> BN.
    The shortcut (``x`` or ``downsample(x)``) is added to it and the sum is
    passed through ReLU.

    Args:
        inplanes: number of input channels.
        planes: number of output channels (``expansion`` is 1, so no widening).
        stride: stride of the first 3x3 conv; this is the only place the block
            may reduce the spatial size.
        downsample: optional module applied to ``x`` on the shortcut. It must
            produce the same shape as the main path, so it is needed whenever
            ``stride != 1`` or ``inplanes != planes``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Stride 1 here: striding both convs would shrink the main path twice,
        # so it would no longer match the (once-)downsampled shortcut.
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Explicit None check: module truthiness is not a reliable "is set"
        # test (e.g. an empty nn.Sequential is falsy).
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)

In [96]:
class ResnetBottleneck(nn.Module):
    """Three-conv bottleneck residual block (the ResNet-50/101 building block).

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv (``stride``) -> BN -> ReLU
    -> 1x1 conv (widen to ``planes * expansion``) -> BN. The shortcut
    (``x`` or ``downsample(x)``) is added and the sum is passed through ReLU.

    Args:
        inplanes: number of input channels.
        planes: inner (bottleneck) width; output has ``planes * 4`` channels.
        stride: stride of the 3x3 conv.
        downsample: optional module applied to ``x`` on the shortcut so that it
            matches the main path's output shape.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # No bias: the following BN re-centres the output, so a conv bias is
        # redundant (consistent with conv1/conv2).
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Explicit None check instead of relying on module truthiness.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)

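### ResNeXt 的 Bottleneck
ResNeXt 在 ResNet 的基础上，把 Bottleneck 中的 3x3 卷积换成了分组卷积。受 Inception 启发，论文将残差分支拆成若干条并行支路，支路的数量就是 cardinality 的含义。右图是 ResNeXt 32x4d 的基本结构：32 指 cardinality 为 32，即先用 1x1 卷积降维，再分成 32 条支路；4d 指每条支路中 3x3 卷积的滤波器数量为 4。注意：下面的 ResnextBottleneck 默认直接以 planes 作为分组 3x3 卷积的宽度（例如 planes=64、group=32 时每组只有 2 个通道），因此与论文中 32x4d 的宽度设置并不相同。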
![resnext](https://pic4.zhimg.com/80/v2-e84925a5b925673384c320dec5979227_720w.jpg)
![block](https://pic3.zhimg.com/80/v2-9914acfdf1ebec89eb690ba96de74f42_720w.jpg)

In [106]:
class ResnextBottleneck(nn.Module):
    """ResNeXt bottleneck: the ResNet bottleneck with a grouped 3x3 conv.

    Main path: 1x1 conv -> BN -> ReLU -> grouped 3x3 conv (``stride``) -> BN
    -> ReLU -> 1x1 conv (to ``planes * expansion``) -> BN. The shortcut
    (``x`` or ``downsample(x)``) is added and the sum is passed through ReLU.

    Args:
        inplanes: number of input channels.
        planes: nominal bottleneck width; output has ``planes * 4`` channels.
        stride: stride of the grouped 3x3 conv.
        downsample: optional module applied to ``x`` on the shortcut so that it
            matches the main path's output shape.
        group: cardinality, i.e. ``groups`` of the 3x3 conv. The inner width
            must be divisible by it (``nn.Conv2d`` raises otherwise).
        base_width: channels per group per 64 ``planes``. When ``None``
            (default) the inner width is simply ``planes``. When set, the inner
            width is ``int(planes * base_width / 64) * group``; e.g.
            ``group=32, base_width=4`` gives the paper's 32x4d configuration.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, group=32,
                 base_width=None):
        super().__init__()
        if base_width is None:
            width = planes
        else:
            width = int(planes * (base_width / 64.0)) * group
        self.conv1 = nn.Conv2d(inplanes, width, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, 3, stride, padding=1, groups=group, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        # No bias: the following BN makes a conv bias redundant.
        self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Explicit None check instead of relying on module truthiness.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)

In [107]:
class Resnet(nn.Module):
    """ResNet-style classifier assembled from a residual block class.

    Structure: 7x7/2 conv stem -> BN -> ReLU -> 3x3/2 max-pool, four stages of
    residual blocks (64/128/256/512 nominal planes; stages 2-4 start with
    stride 2), global average pooling, and a linear classifier.

    Args:
        block: residual block class (e.g. BasicBlock, ResnetBottleneck,
            ResnextBottleneck). It must define an ``expansion`` class attribute
            and accept ``(inplanes, planes, stride, downsample)`` positionally.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=100):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages; _make_layer updates self.inplanes as it goes.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and advance ``self.inplanes``.

        Only the first block may change stride/channel count, so only it gets
        a projection shortcut (1x1 conv + BN) when shapes would not match.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                # No bias: the following BN makes a conv bias redundant.
                nn.Conv2d(self.inplanes, out_planes, 1, stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        # flatten also works on non-contiguous tensors, unlike view.
        x = torch.flatten(x, 1)
        return self.fc(x)

In [108]:
# Stage depths: ResNet-18 [2, 2, 2, 2] and ResNet-34 [3, 4, 6, 3] use BasicBlock;
# ResNet-50 [3, 4, 6, 3] and ResNet-101 [3, 4, 23, 3] use a bottleneck block.
# NOTE: both models below use the ResNet-101 depths [3, 4, 23, 3], not the
# ResNet-50 depths named in the notebook title.
Resnet_model = Resnet(ResnetBottleneck, layers=[3, 4, 23, 3])
Resnext_model = Resnet(ResnextBottleneck, layers=[3, 4, 23, 3])

# Smoke test: run a batch of four random 3x224x224 inputs through the ResNet.
test = torch.rand(4, 3, 224, 224)
Resnet_model(test).shape

torch.Size([4, 100])

In [86]:
# Scratch check: split ``test`` into three tensors along dim 1. With ``test``
# created as torch.rand(4, 3, 224, 224), each piece has shape (4, 1, 224, 224).
# (The dangling ``nn.`` that followed here was a SyntaxError and was removed.)
a, b, c = test.chunk(3, dim=1)

### ResNeSt Block
![resenest](https://camo.qiitausercontent.com/9ef6a58f7d75fb7dc8e5f57020d5afe448a34310/68747470733a2f2f71696974612d696d6167652d73746f72652e73332e61702d6e6f727468656173742d312e616d617a6f6e6177732e636f6d2f302f343031302f37323437376562352d336563392d326136322d303531332d3366303936386533316262322e706e67)

In [4]:
class ResnestBottleneck(nn.Module):
    def __init__(self, inplanes, k=2, r=3, downsample=None):
        super().__init__()
        self.k = k
        self.conv1 = nn.Conv2d(inplanes//k, inplanes//(k*r), kernel_size=1, groups=r, bias=False)
        self.bn1 = nn.BatchNorm2d(inplanes//k)
        self.conv2 = nn.Conv2d(inplanes//(k*r), in_channels//k, kernel_size=3, padding=1, groups=r, bias=False)
        self.relu = nn.ReLU()
        self.globalavgpool = nn.AdaptiveAvgPool2d(1)
    def forward(self, x):
        '''x [bs, c, w, h]'''
        r_xs = torch.split(x, self.k, dim=1)
        r_xs = list(r_xs)
        for i in range(len(r_xs)):
            r_xs[i] = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(r_xs[i])))))
        r_xs = torch.concat(r_xs, dim=1)
        r_xs = self.globalavgpool(r_xs)
        r_xs = r_xs.view(r_xs.size(0), -1)              