In [3]:
import torch
import torch.nn as nn

In [None]:
            # check if resnet18 matches official version
            # from torchsummary import summary
            # print(summary(model, (3, 224, 224)))
            # import torchvision.models as md
            # print(summary(md.resnet18(False), (3, 224, 224)))

In [10]:
class BottleneckResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_shortcut, mid_stride, shortcut_stride):
        super().__init__()

        self.use_shortcut = use_shortcut

        self.stride = 2 if in_channels != out_channels else 1

        self.blocks = nn.Sequential(

            # first conv layer
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            #             nn.ReLU(inplace=True),

            # second conv layer
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=mid_stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            #             nn.ReLU(inplace=True),

            # third conv layer
            nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels * 4),
            nn.ReLU(inplace=True)
        )

        #         # shortcut
        #         if in_channels == out_channels:
        #             self.shortcut = nn.Identity()
        #         else:

        if self.use_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * 4, kernel_size=1, stride=shortcut_stride, bias=False),
                nn.BatchNorm2d(out_channels * 4)
            )

    def forward(self, x):

        # blocks
        out = self.blocks(x)

        if self.use_shortcut:
            # shortcut
            shortcut = self.shortcut(x)

            # combine
            activate = nn.ReLU(inplace=True)
            out = activate(out + shortcut)

        return out


class Resnet152(nn.Module):
    def __init__(self, in_features, num_classes):
        super().__init__()

        self.layers = nn.Sequential(
            # conv1
            nn.Conv2d(in_channels=in_features, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # conv2..x
            BottleneckResidualBlock(64, 64, use_shortcut=True, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(256, 64, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(256, 64, use_shortcut=False, mid_stride=1, shortcut_stride=1),

            # conv3..x
            BottleneckResidualBlock(256, 128, use_shortcut=True, mid_stride=2, shortcut_stride=2),
            BottleneckResidualBlock(512, 128, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(512, 128, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(512, 128, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(512, 128, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(512, 128, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(512, 128, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(512, 128, use_shortcut=False, mid_stride=1, shortcut_stride=1),

            # conv4..x
            BottleneckResidualBlock(512, 256, use_shortcut=True, mid_stride=2, shortcut_stride=2),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(1024, 256, use_shortcut=False, mid_stride=1, shortcut_stride=1),

            # conv5..x
            BottleneckResidualBlock(1024, 512, use_shortcut=True, mid_stride=2, shortcut_stride=2),
            BottleneckResidualBlock(2048, 512, use_shortcut=False, mid_stride=1, shortcut_stride=1),
            BottleneckResidualBlock(2048, 512, use_shortcut=False, mid_stride=1, shortcut_stride=1),

            # summary
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        # decoding layer
        self.linear = nn.Sequential(
            nn.Linear(2048, num_classes))

    def forward(self, x, return_embedding=False):
        embedding = self.layers(x)

        if return_embedding:
            return embedding
        else:
            return self.linear(embedding)


In [14]:
model = Resnet152(3,1000)
model
# from torchsummary import summary
# print(summary(model, (3, 224, 224)))
# print()
# import torchvision.models as md
# # print('theirs')
# # print(summary(md.resnet152(False), (3, 224, 224)))
# md.resnet152(False)

Resnet152(
  (layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): BottleneckResidualBlock(
      (blocks): Sequential(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(64, 