In [1]:
import torch

In [2]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages added to a shortcut branch.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first convolution (and of the projection
            shortcut, when one is created). Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch: conv -> BN -> ReLU -> conv -> BN. Only the first
        # convolution carries the stride; the second keeps the spatial size.
        first_conv = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1)
        first_norm = nn.BatchNorm2d(out_channels)
        second_conv = nn.Conv2d(out_channels, out_channels,
                                kernel_size=3, stride=1, padding=1)
        second_norm = nn.BatchNorm2d(out_channels)
        self.layer = nn.Sequential(
            first_conv, first_norm, nn.ReLU(), second_conv, second_norm
        )

        # Shortcut branch. When the channel count or the stride changes, the
        # input cannot be added to the main-branch output directly, so it is
        # projected by a conv with the same kernel size, stride and padding as
        # the first main-branch conv, followed by BN. Otherwise the shortcut
        # is an empty Sequential, which returns its input unchanged.
        if in_channels != out_channels or stride > 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_branch(x) + shortcut(x))."""
        residual = self.layer(x)
        skip = self.shortcut(x)
        return F.relu(residual + skip)
        

In [None]:
class ResNet(nn.Module):
    """Image classifier built from a stem convolution and four residual stages.

    The stem maps 3 input channels to 32. Each of the four stages stacks two
    blocks, starts with a stride-2 block, and widens the channels to
    64, 128, 256 and 512 respectively. The final feature map is globally
    average-pooled and passed to a linear classifier.

    Args:
        ResBlock: block class (or factory) called as
            ``ResBlock(in_channels, out_channels, stride)`` that returns an
            ``nn.Module``.
        num_classes: number of outputs of the final linear layer.
            Defaults to 10.
    """

    def __init__(self, ResBlock, num_classes=10):
        super(ResNet, self).__init__()
        block = ResBlock
        # Running channel count; make_layer reads and updates it so that each
        # stage is wired to the output width of the previous one.
        self.in_channel = 32
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(block, 64, 2, 2)
        self.layer2 = self.make_layer(block, 128, 2, 2)
        self.layer3 = self.make_layer(block, 256, 2, 2)
        self.layer4 = self.make_layer(block, 512, 2, 2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, out_channel, stride, num_block):
        """Build one stage of ``num_block`` consecutive blocks.

        Only the first block of the stage uses ``stride``; the remaining
        blocks use stride 1. ``self.in_channel`` is updated to
        ``out_channel`` as a side effect.

        Args:
            block: block class called as ``block(in_ch, out_ch, stride)``.
            out_channel: output channels of every block in the stage.
            stride: stride of the first block.
            num_block: number of blocks in the stage.

        Returns:
            An ``nn.Sequential`` holding the blocks in order.
        """
        layers_list = []
        for i in range(num_block):
            in_stride = stride if i == 0 else 1
            layers_list.append(block(self.in_channel, out_channel, in_stride))
            self.in_channel = out_channel
        return nn.Sequential(*layers_list)

    def forward(self, x):
        """Return the class scores for a batch of 3-channel images ``x``."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # True global average pooling: collapses whatever spatial size is left
        # to 1x1, so the 512-feature classifier works for any input
        # resolution. A fixed 2x2 window only matched a 2x2 final map.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)  # flatten to (batch, 512)
        out = self.fc(out)
        return out
    
     

In [5]:
def resnet():
    """Create a ResNet that uses ResBlock as its residual building block."""
    model = ResNet(ResBlock)
    return model