In [1]:
import torch
import torch.nn as nn

In [16]:
class Block(nn.Module):
    """Bottleneck residual unit: 1x1 conv -> 3x3 conv -> 1x1 conv plus a shortcut.

    Each convolution is followed by batch normalisation; a ReLU follows the
    first two and a final ReLU is applied after the shortcut is added.

    Args:
        in_ch: Number of channels of the incoming tensor.
        out_ch: Width of the two inner convolutions. The unit emits
            ``out_ch * 4`` channels.
        identity_downsample: Optional module applied to the input before it is
            added to the main path. When ``None`` the input is added as-is, so
            it must already have the same shape as the main-path output.
        stride: Stride of the 3x3 convolution.
    """

    def __init__(self, in_ch, out_ch, identity_downsample=None, stride=1):
        super().__init__()
        self.expansion = 4
        wide_ch = out_ch * self.expansion

        # Stored before the convolutions so that, when it is a module, it keeps
        # its position as the first registered submodule.
        self.identity_downsample = identity_downsample

        # 1x1 convolution: in_ch -> out_ch.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

        # 3x3 convolution; the only layer of the main path that uses `stride`.
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # 1x1 convolution: out_ch -> out_ch * expansion.
        self.conv3 = nn.Conv2d(out_ch, wide_ch, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(wide_ch)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.identity_downsample is not None:
            shortcut = self.identity_downsample(shortcut)

        return self.relu(out + shortcut)


class ResnetBlock(nn.Module):
    """One ResNet stage: a chain of bottleneck ``Block`` units.

    The first unit receives ``stride`` and, when the input shape would not
    match the unit's output, a projection shortcut (1x1 convolution followed
    by batch normalisation). The remaining ``num_repeat - 1`` units use the
    default stride and no projection.

    Args:
        inp: Number of channels entering the stage.
        out: Inner width of each unit; the stage emits ``out * 4`` channels.
        num_repeat: Requested number of units. The first unit is always
            built, so values below 1 still yield a single unit.
        stride: Stride applied by the first unit.
    """

    def __init__(self, inp, out, num_repeat, stride):
        super(ResnetBlock, self).__init__()
        self.num_repeat = num_repeat
        expanded = out * 4

        # The attribute is always defined. Previously it was only assigned
        # inside the `if`, so a stage that needs no projection
        # (stride == 1 and inp == out * 4) raised AttributeError when the
        # attribute was read while building the first unit.
        downsample = None
        if stride != 1 or inp != expanded:
            downsample = nn.Sequential(
                nn.Conv2d(inp, expanded, kernel_size=1, stride=stride),
                nn.BatchNorm2d(expanded),
            )
        # Kept as an attribute of the stage (and also handed to the first
        # unit) so that existing submodule names remain unchanged.
        self.identity_downsample = downsample

        blocks = [Block(inp, out, downsample, stride)]

        # Every later unit consumes the expanded channel count.
        blocks.extend(Block(expanded, out) for _ in range(self.num_repeat - 1))

        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every unit of the stage in order."""
        return self.layers(x)


class Resnet(nn.Module):
    """Bottleneck ResNet: stem, four ``ResnetBlock`` stages, pooled linear head.

    Args:
        layers: Sequence whose first four entries give the number of units in
            stages 1 to 4 respectively.
        img_ch: Number of channels of the input images.
        num_class: Size of the output of the final linear layer.
    """

    def __init__(self, layers, img_ch, num_class):
        super().__init__()
        expansion = 4  # each stage emits `width * expansion` channels

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(img_ch, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages; each one takes the expanded output of the previous one.
        self.layer1 = ResnetBlock(64, 64, layers[0], 1)
        self.layer2 = ResnetBlock(64 * expansion, 128, layers[1], 2)
        self.layer3 = ResnetBlock(128 * expansion, 256, layers[2], 2)
        self.layer4 = ResnetBlock(256 * expansion, 512, layers[3], 2)

        # Head: pool to 1x1, flatten, then a single linear layer.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(512 * expansion, num_class)

    def forward(self, x):
        """Apply the stem, the four stages in order, and the head to ``x``."""
        x = self.maxpool1(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        return self.fc1(self.flat(self.avgpool(x)))

In [17]:
# Stage depths 3/4/6/3, three input channels, 1000 outputs.
model = Resnet(layers=[3, 4, 6, 3], img_ch=3, num_class=1000)

In [18]:
# Smoke test: push a batch of three all-ones 3x224x224 images through the
# model and display the shape of the result.
inp = torch.ones(3, 3, 224, 224)
out = model(inp)
out.shape

torch.Size([3, 256, 56, 56])
torch.Size([3, 512, 28, 28])
torch.Size([3, 1024, 14, 14])
torch.Size([3, 2048, 7, 7])
torch.Size([3, 2048, 1, 1])


torch.Size([3, 1000])

In [15]:
# Total number of elements across all parameters that require gradients.
num_para = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
num_para

25583592