In [1]:
import mxnet as mx
from mxnet import ndarray as nd
from mxnet.gluon import nn
import gluoncv as gcv

In [None]:
class Residual(nn.HybridBlock):
    """Two-convolution residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus shortcut.

    Parameters
    ----------
    num_channels : int
        Number of output channels of both 3x3 convolutions (and of the
        optional 1x1 projection).
    use_1x1conv : bool, default False
        When True, the shortcut is passed through a 1x1 convolution (with the
        same ``strides``) so that it matches the main branch; when False the
        input is added unchanged.
    strides : int, default 1
        Stride of the first 3x3 convolution and of the 1x1 projection.
    **kwargs
        Forwarded to ``nn.HybridBlock``.
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1, strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1, strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def hybrid_forward(self, F, X):
        """Compute ``relu(main_branch(X) + shortcut(X))``.

        ``F`` is the operator namespace supplied by Gluon (``mxnet.ndarray`` or
        ``mxnet.symbol``). Using it instead of a hard-coded ``nd`` lets the
        block be called with either NDArray or Symbol input, which is what a
        HybridBlock needs in order to be hybridized.
        """
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check: the projection is optional, and we should not
        # rely on the truthiness of a Block instance.
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)

In [None]:
# Sanity check: a stride-1 block without projection on a 3-channel input.
blk = Residual(num_channels=3)
blk.initialize()
X = nd.random.uniform(shape=(4, 3, 6, 6))
blk(X).shape

In [None]:
# Same input, now through a stride-2 block with a 1x1 projection shortcut
# and 6 output channels.
blk = Residual(num_channels=6, strides=2, use_1x1conv=True)
blk.initialize()
blk(X).shape

In [None]:
# ResNet stem: 7x7 stride-2 convolution, batch norm, ReLU, then 3x3 stride-2 max pooling.
net = nn.HybridSequential()
stem_layers = [
    nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
    nn.BatchNorm(),
    nn.Activation('relu'),
    nn.MaxPool2D(pool_size=3, strides=2, padding=1),
]
net.add(*stem_layers)

In [None]:
 def resnet_block(num_channels, num_residuals, first_block=False):
     """Stack ``num_residuals`` Residual blocks into one HybridSequential stage.

     The leading residual of a stage is built with stride 2 and a 1x1
     projection shortcut, unless ``first_block`` is True, in which case every
     residual in the stage is a plain ``Residual(num_channels)``.
     """
     stage = nn.HybridSequential()
     for idx in range(num_residuals):
         plain = first_block or idx != 0
         if plain:
             stage.add(Residual(num_channels))
         else:
             stage.add(Residual(num_channels, use_1x1conv=True, strides=2))
     return stage

In [None]:
# Four residual stages of two blocks each; only the first stage skips the
# stride-2 / projection residual. Each run of this cell appends to `net`.
net.add(resnet_block(64, 2, first_block=True))
for stage_channels in (128, 256, 512):
    net.add(resnet_block(stage_channels, 2))

In [None]:
# Head: global average pooling followed by a 10-unit dense layer.
net.add(nn.GlobalAvgPool2D())
net.add(nn.Dense(10))

In [None]:
# Push one random single-channel 224x224 input through the network layer by
# layer and report the output shape after each top-level child.
X = nd.random.uniform(shape=(1, 1, 224, 224))
net.initialize(init=mx.init.Xavier())
for layer in net:
    X = layer(X)
    print(layer.name, 'output shape:\t', X.shape)

In [8]:
# Plot a ResNet-50 style residual block

class ResidualHybridBlock(nn.HybridBlock):
    """Three-convolution residual unit (1x1 -> 3x3 -> 1x1) with 256 output channels.

    Parameters
    ----------
    num_channels : int
        Output channels of the first two convolutions.
    kernel_size : int, default 1
        Kernel size of the first convolution.
    strides : int, default 1
        Stride of the first convolution.
    first_block : bool, default False
        When True, the input is first normalised and activated
        (``bn0`` + ReLU) and the shortcut is projected to 256 channels by a
        1x1 convolution; when False the input is added to the main branch
        unchanged.
    **kwargs
        Forwarded to ``nn.HybridBlock``.

    NOTE(review): ``conv1`` and ``conv3`` use ``padding=1`` and the shortcut
    uses ``padding=2``, so with the default 1x1 kernels the output is larger
    than the input, and the un-projected path (``first_block=False``) does not
    look shape-compatible with the main branch. The geometry is kept as it
    was so existing outputs do not change -- confirm what was intended.
    """

    def __init__(self, num_channels, kernel_size=1, strides=1, first_block=False, **kwargs):
        super(ResidualHybridBlock, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=kernel_size, padding=1, strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2D(256, kernel_size=1, padding=1)

        self.bn0 = nn.BatchNorm()
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

        self.first_block = first_block
        # The shortcut projection must be a registered child: a layer built
        # inside the forward pass is never initialized (so NDArray input
        # fails) and would get brand-new weights on every call.
        if first_block:
            self.shortcut = nn.Conv2D(256, kernel_size=1, padding=2)
        else:
            self.shortcut = None

    def hybrid_forward(self, F, X):
        """Return ``main_branch(X) + shortcut(X)`` (no activation after the sum).

        ``F`` is the operator namespace supplied by Gluon (``mxnet.ndarray`` or
        ``mxnet.symbol``), so the block works for both eager execution and
        symbolic tracing.
        """
        if self.first_block:
            X = F.Activation(self.bn0(X), act_type='relu')
        Y = F.Activation(self.bn1(self.conv1(X)), act_type='relu')
        Y = self.bn2(self.conv2(Y))
        Y = self.conv3(F.Activation(Y, act_type='relu'))
        if self.shortcut is not None:
            X = self.shortcut(X)
        return Y + X

In [9]:
def get_net():
    """Build and initialize a HybridSequential holding one first-block residual unit."""
    model = nn.HybridSequential()
    first_unit = ResidualHybridBlock(num_channels=64, first_block=True)
    model.add(first_unit)
    model.initialize()
    return model


t = get_net()

In [10]:
# Draw the computation graph of `t` with GluonCV's visualisation helper.
# No input shape is passed, so the helper's default is used -- TODO confirm
# that default suits this block.
gcv.utils.viz.plot_network(t)

In [None]:
# `summary` is only referenced here, not called, so this cell displays the
# bound method object rather than a per-layer table.
# NOTE(review): Gluon's Block.summary expects sample input(s), e.g.
# t.summary(nd.random.uniform(shape=(1, 3, 224, 224))) -- confirm the block
# runs on NDArray input before switching to the call.
t.summary