<img src="imgs/ResNet34.png">

<img src="imgs/ResNet.png">

In [9]:
import torch as t
import torch.nn as nn

In [37]:
class BlockA(nn.Module):
    """ResNet basic block: two 3x3 convolutions with a residual shortcut.

    Args:
        in1: number of input channels.
        mid1: output channels of the first 3x3 convolution.
        out1: output channels of the second 3x3 convolution (the block output).
        strides: stride of the first convolution; a value > 1 downsamples
            the spatial dimensions.

    A ReLU follows the first convolution and the residual sum. The shortcut
    has no parameters: the input is subsampled with the first convolution's
    stride and zero-padded along the channel axis when ``out1 > in1``
    (shortcut option A in He et al., 2015). If ``out1 < in1`` the shortcut
    cannot be formed without parameters and is skipped. No batch
    normalization is applied.
    """

    def __init__(self, in1, mid1, out1, strides=1):
        super(BlockA, self).__init__()
        self.strides = strides
        self.in1 = in1
        self.out1 = out1
        self.conv1 = nn.Conv2d(in_channels=in1, out_channels=mid1, kernel_size=3, stride=strides, padding=1)
        self.conv2 = nn.Conv2d(in_channels=mid1, out_channels=out1, kernel_size=3, padding=1)

    def _shortcut(self, x):
        """Map *x* onto the residual branch's shape, or return None if impossible."""
        if self.out1 < self.in1:
            return None
        # Conv2d stores its stride as an (h, w) tuple even when given an int.
        stride_h, stride_w = self.conv1.stride
        # Strided slicing keeps ceil(H / s) rows, the same count a 3x3,
        # padding-1 convolution with stride s produces.
        shortcut = x[..., ::stride_h, ::stride_w]
        extra = self.out1 - self.in1
        if extra:
            # F.pad takes (before, after) pairs starting from the last dim:
            # (W, W, H, H, C, C) -> append `extra` zero channels.
            shortcut = nn.functional.pad(shortcut, (0, 0, 0, 0, 0, extra))
        return shortcut

    def forward(self, x):
        """Apply the block to ``x`` of shape (N, in1, H, W); returns (N, out1, H', W')."""
        output = nn.functional.relu(self.conv1(x))
        output = self.conv2(output)
        shortcut = self._shortcut(x)
        if shortcut is not None:
            output = output + shortcut
        return nn.functional.relu(output)

    
class BlockB(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with a residual shortcut.

    Args:
        in1: number of input channels.
        mid1: output channels of the first 1x1 convolution (which carries the stride).
        mid2: output channels of the 3x3 convolution.
        out1: output channels of the last 1x1 convolution (the block output).
        strides: stride of the first convolution; a value > 1 downsamples
            the spatial dimensions.

    A ReLU follows the first two convolutions and the residual sum. The
    shortcut has no parameters: the input is subsampled with the first
    convolution's stride and zero-padded along the channel axis when
    ``out1 > in1`` (shortcut option A in He et al., 2015). If ``out1 < in1``
    the shortcut is skipped. No batch normalization is applied.
    """

    def __init__(self, in1, mid1, mid2, out1, strides=1):
        super(BlockB, self).__init__()
        self.strides = strides
        self.in1 = in1
        self.out1 = out1
        self.conv1 = nn.Conv2d(in_channels=in1, out_channels=mid1, kernel_size=1, stride=strides)
        self.conv2 = nn.Conv2d(in_channels=mid1, out_channels=mid2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=mid2, out_channels=out1, kernel_size=1)

    def _shortcut(self, x):
        """Map *x* onto the residual branch's shape, or return None if impossible."""
        if self.out1 < self.in1:
            return None
        # Conv2d stores its stride as an (h, w) tuple even when given an int.
        stride_h, stride_w = self.conv1.stride
        # Strided slicing keeps ceil(H / s) rows, the same count an unpadded
        # 1x1 convolution with stride s produces.
        shortcut = x[..., ::stride_h, ::stride_w]
        extra = self.out1 - self.in1
        if extra:
            # F.pad takes (before, after) pairs starting from the last dim:
            # (W, W, H, H, C, C) -> append `extra` zero channels.
            shortcut = nn.functional.pad(shortcut, (0, 0, 0, 0, 0, extra))
        return shortcut

    def forward(self, x):
        """Apply the block to ``x`` of shape (N, in1, H, W); returns (N, out1, H', W')."""
        output = nn.functional.relu(self.conv1(x))
        output = nn.functional.relu(self.conv2(output))
        output = self.conv3(output)
        shortcut = self._shortcut(x)
        if shortcut is not None:
            output = output + shortcut
        return nn.functional.relu(output)

In [38]:
# Smoke test: a stride-1 bottleneck block with 256 input and 256 output channels.
net = BlockB(256, 64, 64, 256, strides=1)
input_data = t.randn(10, 256, 56, 56)  # (batch, channels, height, width)
output = net(input_data)
print(output.shape)

torch.Size([10, 256, 56, 56])


In [12]:
class ResNet18(nn.Module):
    """18-layer ResNet assembled from ``BlockA`` basic blocks.

    Layout: 7x7 stride-2 stem convolution + ReLU -> 3x3 stride-2 max pool ->
    four stages of two ``BlockA`` blocks (64, 128, 256, 512 channels; the
    first block of stages 2-4 has stride 2) -> global average pool -> linear
    classifier.

    Args:
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.block3 = nn.Sequential(BlockA(64, 64, 64, 1),
                                    BlockA(64, 64, 64, 1))

        self.block4 = nn.Sequential(BlockA(64, 128, 128, 2),
                                    BlockA(128, 128, 128, 1))

        self.block5 = nn.Sequential(BlockA(128, 256, 256, 2),
                                    BlockA(256, 256, 256, 1))

        self.block6 = nn.Sequential(BlockA(256, 512, 512, 2),
                                    BlockA(512, 512, 512, 1))

        # Global average pooling reduces any spatial size to 1x1, so fc8 sees
        # a 512-vector and the network is not tied to one input resolution.
        self.pool7 = nn.AdaptiveAvgPool2d(1)
        self.fc8 = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        """Map images ``x`` of shape (N, 3, H, W) to outputs of shape (N, num_classes)."""
        x = nn.functional.relu(self.conv1(x))
        x = self.pool2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.pool7(x)          # (N, 512, 1, 1)
        x = x.view(x.size(0), -1)  # (N, 512)
        x = self.fc8(x)
        return x

In [13]:
# Smoke test: ten 3-channel 224x224 images should give 1000 outputs each.
net = ResNet18()
input_data = t.randn(10, 3, 224, 224)  # (batch, channels, height, width)
output = net(input_data)
print(output.shape)

torch.Size([10, 1000])


In [14]:
class ResNet34(nn.Module):
    """34-layer ResNet assembled from ``BlockA`` basic blocks.

    Layout: 7x7 stride-2 stem convolution + ReLU -> 3x3 stride-2 max pool ->
    four stages of 3, 4, 6 and 3 ``BlockA`` blocks (64, 128, 256, 512
    channels; the first block of stages 2-4 has stride 2) -> global average
    pool -> linear classifier.

    Args:
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super(ResNet34, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.block3 = nn.Sequential(BlockA(64, 64, 64, 1),
                                    BlockA(64, 64, 64, 1),
                                    BlockA(64, 64, 64, 1))

        self.block4 = nn.Sequential(BlockA(64, 128, 128, 2),
                                    BlockA(128, 128, 128, 1),
                                    BlockA(128, 128, 128, 1),
                                    BlockA(128, 128, 128, 1))

        self.block5 = nn.Sequential(BlockA(128, 256, 256, 2),
                                    BlockA(256, 256, 256, 1),
                                    BlockA(256, 256, 256, 1),
                                    BlockA(256, 256, 256, 1),
                                    BlockA(256, 256, 256, 1),
                                    BlockA(256, 256, 256, 1))

        self.block6 = nn.Sequential(BlockA(256, 512, 512, 2),
                                    BlockA(512, 512, 512, 1),
                                    BlockA(512, 512, 512, 1))

        # Global average pooling reduces any spatial size to 1x1, so fc8 sees
        # a 512-vector and the network is not tied to one input resolution.
        self.pool7 = nn.AdaptiveAvgPool2d(1)
        self.fc8 = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        """Map images ``x`` of shape (N, 3, H, W) to outputs of shape (N, num_classes)."""
        x = nn.functional.relu(self.conv1(x))
        x = self.pool2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.pool7(x)          # (N, 512, 1, 1)
        x = x.view(x.size(0), -1)  # (N, 512)
        x = self.fc8(x)
        return x

In [15]:
# Smoke test: ten 3-channel 224x224 images should give 1000 outputs each.
net = ResNet34()
input_data = t.randn(10, 3, 224, 224)  # (batch, channels, height, width)
output = net(input_data)
print(output.shape)

torch.Size([10, 1000])


In [41]:
class ResNet50(nn.Module):
    """50-layer ResNet assembled from ``BlockB`` bottleneck blocks.

    Layout: 7x7 stride-2 stem convolution + ReLU -> 3x3 stride-2 max pool ->
    four stages of 3, 4, 6 and 3 ``BlockB`` blocks (256, 512, 1024, 2048
    output channels; the first block of stages 2-4 has stride 2) -> global
    average pool -> linear classifier.

    Args:
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super(ResNet50, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.block3 = nn.Sequential(BlockB(64, 64, 64, 256, 1),
                                    BlockB(256, 64, 64, 256, 1),
                                    BlockB(256, 64, 64, 256, 1))

        self.block4 = nn.Sequential(BlockB(256, 128, 128, 512, 2),
                                    BlockB(512, 128, 128, 512, 1),
                                    BlockB(512, 128, 128, 512, 1),
                                    BlockB(512, 128, 128, 512, 1))

        self.block5 = nn.Sequential(BlockB(512, 256, 256, 1024, 2),
                                    BlockB(1024, 256, 256, 1024, 1),
                                    BlockB(1024, 256, 256, 1024, 1),
                                    BlockB(1024, 256, 256, 1024, 1),
                                    BlockB(1024, 256, 256, 1024, 1),
                                    BlockB(1024, 256, 256, 1024, 1))

        self.block6 = nn.Sequential(BlockB(1024, 512, 512, 2048, 2),
                                    BlockB(2048, 512, 512, 2048, 1),
                                    BlockB(2048, 512, 512, 2048, 1))

        # Global average pooling reduces any spatial size to 1x1, so fc8 sees
        # a 2048-vector instead of a 2048*7*7 one (~100M weights) and the
        # network is not tied to one input resolution.
        self.pool7 = nn.AdaptiveAvgPool2d(1)
        self.fc8 = nn.Linear(in_features=2048, out_features=num_classes)

    def forward(self, x):
        """Map images ``x`` of shape (N, 3, H, W) to outputs of shape (N, num_classes)."""
        x = nn.functional.relu(self.conv1(x))
        x = self.pool2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.pool7(x)          # (N, 2048, 1, 1)
        x = x.view(x.size(0), -1)  # (N, 2048)
        x = self.fc8(x)
        return x

In [42]:
# Smoke test: ten 3-channel 224x224 images should give 1000 outputs each.
net = ResNet50()
input_data = t.randn(10, 3, 224, 224)  # (batch, channels, height, width)
output = net(input_data)
print(output.shape)

torch.Size([10, 1000])
