In [1]:
from gluoncv.model_zoo import get_model
import mxnet as mx

from mxboard import SummaryWriter
from mxnet import gluon, nd, image
from mxnet.gluon.data.vision import transforms
from mxnet import autograd as ag
from mxnet.gluon import nn

In [2]:
def vgg_block(num_convs, channels):
    """Build one VGG-style block ending in a 2x2 max-pool.

    Each of the ``num_convs`` repetitions appends a 3x3 convolution
    (padding 1) followed by a 1x1 convolution (padding 0); every
    convolution is followed by batch normalisation over axis 1 and a ReLU.
    The block therefore holds ``2 * num_convs`` convolution layers.

    Args:
        num_convs: number of (3x3 conv, 1x1 conv) repetitions.
        channels: output channel count used by every convolution.

    Returns:
        An ``nn.HybridSequential`` containing the layers in order.
    """
    block = nn.HybridSequential()
    # (kernel_size, padding) of the two convolutions in each repetition.
    conv_specs = ((3, 1), (1, 0))
    for _ in range(num_convs):
        for kernel_size, padding in conv_specs:
            block.add(
                nn.Conv2D(channels=channels, kernel_size=kernel_size, padding=padding),
                nn.BatchNorm(axis=1),
                nn.Activation(activation='relu'),
            )
    block.add(nn.MaxPool2D(pool_size=2, strides=2))
    return block


def vgg_stack(architecture):
    """Chain one ``vgg_block`` per entry of ``architecture``.

    Args:
        architecture: iterable of ``(num_convs, channels)`` pairs; each pair
            is forwarded to ``vgg_block``.

    Returns:
        An ``nn.HybridSequential`` holding the blocks in iteration order.
    """
    stack = nn.HybridSequential()
    stack.add(*(vgg_block(depth, width) for depth, width in architecture))
    return stack


# Channel count of the final 1x1 convolution in the Xt head, i.e. the number
# of scores produced per sample.
n_classes = 10

# Layout of the VGG branch: one (num_convs, channels) pair per vgg_block.
architecture = (
    (2, 64),
    (2, 256),
    (2, 512),
)

class Xt(nn.HybridBlock):
    """Two-branch network whose branches share a single 1x1-convolution head.

    Branches:
        * ``cifar_wideresnet`` -- the first eight children of the pretrained
          ``cifar_wideresnet40_8`` feature extractor, followed by a 2x2
          average pool with stride 2.
        * ``vgg`` -- a ``vgg_stack`` built from the module-level
          ``architecture``.

    The same ``net`` head (1x1 conv -> batch norm -> ReLU -> 1x1 conv with
    ``n_classes`` channels -> global average pool -> flatten) is applied to
    the output of each branch, and the two results are summed.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # NOTE(review): `d` is not defined in the visible part of this file;
        # it has to exist at module scope (exposing `.ctx`) before Xt is
        # instantiated -- confirm where it comes from.
        pretrained_wrn = get_model(
            'cifar_wideresnet40_8', classes=10, pretrained=True, ctx=d.ctx,
        )
        # Slicing yields a new sequential container, so the extra pooling
        # layer is appended to the slice rather than to the loaded model.
        wrn_branch = pretrained_wrn.features[:8]
        wrn_branch.add(nn.AvgPool2D(pool_size=(2, 2), strides=2, padding=0))
        self.cifar_wideresnet = wrn_branch

        self.vgg = vgg_stack(architecture)

        # Head shared by both branches.
        head = nn.HybridSequential()
        head.add(
            nn.Conv2D(channels=128, kernel_size=1, padding=0),
            nn.BatchNorm(axis=1),
            nn.Activation(activation='relu'),

            nn.Conv2D(channels=n_classes, kernel_size=1, padding=0),
            nn.GlobalAvgPool2D(),
            nn.Flatten(),
        )
        self.net = head

    def hybrid_forward(self, F, x):
        """Return the sum of the shared head applied to each branch of ``x``."""
        wrn_scores = self.net(self.cifar_wideresnet(x))
        vgg_scores = self.net(self.vgg(x))
        return wrn_scores + vgg_scores