# StrassenNets: Deep Learning with a Multiplication Budget

This notebook provides a step-by-step implementation to reproduce part of the experiments in *StrassenNets: Deep Learning with a Multiplication Budget* by Michael Tschannen, Aran Khanna, and Anima Anandkumar (http://bit.ly/2BJZj5a).

#### Paper abstract

A large fraction of the arithmetic operations required to evaluate deep neural networks (DNNs) are due to matrix multiplications, both in convolutional and fully connected layers. Matrix multiplications can be cast as $2$-layer sum-product (SP) networks (arithmetic circuits), disentangling multiplications and additions. We leverage this observation for end-to-end learning of low-cost (in terms of multiplications) approximations of linear operations in DNN layers. Specifically, we propose to replace matrix multiplication operations by sum-product networks, with widths corresponding to the budget of multiplications we want to allocate to each layer, and learning the edges of the SP networks from data. Experiments on CIFAR-10 and ImageNet show that this method applied to ResNet yields significantly higher accuracy than existing methods for a given multiplication budget, or leads to the same or higher accuracy compared to existing methods while using significantly fewer multiplications. Furthermore, our approach allows fine-grained control of the tradeoff between arithmetic complexity and accuracy of DNN models. Finally, we demonstrate that the proposed framework can rediscover Strassen's algorithm, i.e., it can learn to multiply $2 \times 2$ matrices using only $7$ multiplications instead of $8$. 


#### Learning fast matrix multiplications via sum-product networks
$\renewcommand{\vec}{\mathrm{vec}}$Given square matrices $A, B \in \mathbb{R}^{n \times n}$, the product $C = A B$ can be represented as a $2$-layer sum-product (SP) network (or "Strassen network")

$$
    \label{eq:strnet}
    \vec(C) = W_C [ (W_B \vec(B)) \odot (W_A \vec(A))] \qquad \qquad (1)
$$

where $W_A, W_B \in \mathbb{K}^{r \times n^2}$ and $W_C \in \mathbb{K}^{n^2 \times r}$, $\mathbb{K}:= \{-1,0,1\}$ are fixed, and $\odot$ denotes the element-wise product.

<img src="files/spnetwork.png" alt="spnetwork" style="width: 400px;"/>

1. $A, B\in \mathbb{R}^{2 \times 2}$: Strassen's matrix multiplication algorithm (https://en.wikipedia.org/wiki/Strassen_algorithm) tells us ternary weights that satisfy (1) for $r=7$ (instead of $r = 8$).
2. General case $A \in \mathbb{R}^{k \times m}$, $B \in \mathbb{R}^{m \times n}$: $C = A B$ can be written in the form (1) if $r \geq nmk$.
3. If $A \in \mathbb{R}^{k \times m}$ is fixed and $B \in \mathbb{R}^{m \times n}$ concentrates on a low-dimensional subspace of $\mathbb{R}^{m \times n}$: we can find $W_A, W_B$ such that (1) holds approximately even when $r \ll nmk$.
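
As a concrete check of item 1, the following sketch writes Strassen's algorithm in the form (1) and verifies it on one integer example. It is plain Python and independent of the notebook's code; the row-major convention for $\mathrm{vec}(\cdot)$ and the helper names `matvec` and `strassen_sp` are assumptions made for this illustration.

```python
# Ternary weights of Strassen's algorithm, with vec() taken row-major:
# vec(A) = [a11, a12, a21, a22] (this ordering is an assumption of the sketch).
W_A = [[ 1, 0, 0,  1],   # m1 uses a11 + a22
       [ 0, 0, 1,  1],   # m2 uses a21 + a22
       [ 1, 0, 0,  0],   # m3 uses a11
       [ 0, 0, 0,  1],   # m4 uses a22
       [ 1, 1, 0,  0],   # m5 uses a11 + a12
       [-1, 0, 1,  0],   # m6 uses a21 - a11
       [ 0, 1, 0, -1]]   # m7 uses a12 - a22
W_B = [[ 1, 0, 0,  1],   # m1 uses b11 + b22
       [ 1, 0, 0,  0],   # m2 uses b11
       [ 0, 1, 0, -1],   # m3 uses b12 - b22
       [-1, 0, 1,  0],   # m4 uses b21 - b11
       [ 0, 0, 0,  1],   # m5 uses b22
       [ 1, 1, 0,  0],   # m6 uses b11 + b12
       [ 0, 0, 1,  1]]   # m7 uses b21 + b22
W_C = [[1,  0, 0, 1, -1, 0, 1],   # c11 = m1 + m4 - m5 + m7
       [0,  0, 1, 0,  1, 0, 0],   # c12 = m3 + m5
       [0,  1, 0, 1,  0, 0, 0],   # c21 = m2 + m4
       [1, -1, 1, 0,  0, 1, 0]]   # c22 = m1 - m2 + m3 + m6


def matvec(W, x):
    """Multiply a matrix (list of rows) with a vector (list)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def strassen_sp(A, B):
    """Compute the 2x2 product A B via Eq. (1) with r = 7 multiplications."""
    a = [A[0][0], A[0][1], A[1][0], A[1][1]]
    b = [B[0][0], B[0][1], B[1][0], B[1][1]]
    # the 7 element-wise products are the only multiplications;
    # applying the ternary matrices needs additions and subtractions only
    hidden = [u * v for u, v in zip(matvec(W_A, a), matvec(W_B, b))]
    c = matvec(W_C, hidden)
    return [[c[0], c[1]], [c[2], c[3]]]


result = strassen_sp([[1, 2], [3, 4]], [[5, 6], [7, 8]])
```

Here `result` agrees with the ordinary matrix product of the two inputs, while the hidden layer has width 7 rather than 8.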

*Idea:* Leverage item 3 to learn low-cost (in terms of multiplications) approximate linear operations in neural network layers by associating $A$ with (pretrained) weights/filters and $B$ with corresponding activations/feature maps, and by learning $W_A$, $W_B$, $W_C$ end-to-end. Please see the paper for a more in-depth description of the method.
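
To make item 2 and the role of the fixed matrix $A$ concrete, here is a minimal sketch of the trivial exact construction with $r = nmk$, using one-hot weights. It is not the learned low-cost approximation of item 3. It only illustrates that, once $A$ is fixed, $\tilde a = W_A \mathrm{vec}(A)$ can be computed once and each new $B$ then costs $r$ multiplications. All helper names (`naive_sp_weights`, `sp_matmul`, and so on) are hypothetical.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) with a vector (list)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def flatten(M):
    """Row-major vec() of a matrix given as a list of rows."""
    return [v for row in M for v in row]


def naive_sp_weights(k, m, n):
    """One-hot W_A, W_B, W_C realizing C = A B exactly with r = k*m*n."""
    r = k * m * n
    W_A = [[0] * (k * m) for _ in range(r)]
    W_B = [[0] * (m * n) for _ in range(r)]
    W_C = [[0] * r for _ in range(k * n)]
    idx = 0
    for i in range(k):
        for j in range(n):
            for l in range(m):
                W_A[idx][i * m + l] = 1   # selects A[i][l]
                W_B[idx][l * n + j] = 1   # selects B[l][j]
                W_C[i * n + j][idx] = 1   # adds the product to C[i][j]
                idx += 1
    return W_A, W_B, W_C


# A plays the role of fixed (e.g. pretrained) weights, B of the activations
A = [[1, 2, 3], [4, 5, 6]]
W_A, W_B, W_C = naive_sp_weights(2, 3, 2)
a_tilde = matvec(W_A, flatten(A))   # computed once, reused for every B


def sp_matmul(B):
    """Return vec(A B) using len(a_tilde) element-wise multiplications."""
    hidden = [a * b for a, b in zip(a_tilde, matvec(W_B, flatten(B)))]
    return matvec(W_C, hidden)


C_vec = sp_matmul([[1, 0], [0, 1], [1, 1]])
```

In the learned setting the width $r$ is chosen much smaller than $nmk$ and the three weight matrices are trained, so the result is only approximate.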

In this notebook, we demonstrate this idea by compressing the 2D convolution layers of ResNet-20 trained on CIFAR-10.

#### Outline of the notebook:
1. Implement the ternary quantization operator to learn ternary weights $W_B$ and $W_C$
2. Implement the SP 2D convolution layer (i.e. a Strassen network for 2D convolutions)
3. Implement residual blocks with SP layers
4. Use these blocks to construct a SP-ResNet-20
5. Define iterators and the training loop, and train the network

#### Requirements

The code is based on Apache MXNet 0.12 (http://mxnet.incubator.apache.org/install/index.html) and the gluon package.

## 1. Implement the ternary quantization operator

Following the tutorial [1] on defining custom gluon operators, we define an operator implementing the ternarization method proposed in [2]. Using the notation from [2], this method quantizes the entries of a given input tensor to values in $\{-\alpha, 0, \alpha\}$ ($\alpha$ corresponds to `scale` below) based on a threshold $\Delta$ (corresponding to `thresh` below) on the full-precision entries. $\alpha$ and $\Delta$ are determined by approximately solving an optimization problem, see [2]. The operator will be used to quantize $W_B$ and $W_C$ (the factor $\alpha$ can be absorbed into $\tilde a = W_A \vec(A)$ after training to ensure that the entries of $W_B$ and $W_C$ are in $\{-1,0,1\}$).

[1] https://github.com/apache/incubator-mxnet/blob/master/docs/tutorials/gluon/customop.md

[2] F. Li, B. Zhang, B. Liu, Ternary Weight Networks, 2016, https://arxiv.org/abs/1605.04711v2
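
Before wrapping the quantizer as an MXNet operator, it may help to see the same rule on a flat list of numbers. The sketch below is plain Python. The function name `ternarize`, the `thresh_factor` argument, and the guard against an all-zero result are assumptions of this illustration and are not part of the operator defined in this notebook.

```python
def ternarize(weights, thresh_factor=0.7):
    """Return (ternary, scale) for a flat list of full-precision weights."""
    # threshold Delta: a fixed fraction of the mean absolute weight
    thresh = thresh_factor * sum(abs(w) for w in weights) / len(weights)
    # entries above Delta map to +1, entries below -Delta to -1, the rest to 0
    ternary = [1 if w > thresh else -1 if w < -thresh else 0 for w in weights]
    # scale alpha: mean absolute value of the weights that were not zeroed
    nonzero = sum(abs(t) for t in ternary)
    scale = sum(w * t for w, t in zip(weights, ternary)) / nonzero if nonzero else 0.0
    return ternary, scale


w = [0.9, -0.8, 0.1, -0.05, 0.5]
ternary, scale = ternarize(w)
# the operator outputs the rescaled tensor; keeping `scale` separate shows
# that it is a single factor which can later be moved into a_tilde
quantized = [scale * t for t in ternary]
```

For this input the small entries `0.1` and `-0.05` fall below the threshold and are set to zero, and `scale` is the mean magnitude of the three remaining weights.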


In [1]:
import mxnet as mx
import numpy as np

# define the quantization operator
class TerQuant(mx.operator.CustomOp):
    # define forward pass as described in [1]
    def forward(self, is_train, req, in_data, out_data, aux):
        full_data = in_data[0]

        # compute $\Delta$ as in [2], Eq. (6)
        thresh = mx.nd.mean(mx.nd.abs(full_data), axis=()).asnumpy()[0] * 0.7

        # ternarize weights
        quant_data = mx.nd.greater(full_data, thresh*mx.nd.ones(full_data.shape, full_data.context))\
            - mx.nd.lesser(full_data, (-1.0)*thresh*mx.nd.ones(full_data.shape, full_data.context))

        # compute $\alpha$ as in [2], Eq. (5)
        scale = mx.nd.sum(mx.nd.multiply(full_data, quant_data), axis=()).asnumpy()[0]\
            / mx.nd.sum(mx.nd.abs(quant_data), axis=()).asnumpy()[0]

        # rescale ternary weights
        quant_data = scale * quant_data
        self.assign(out_data[0], req[0], quant_data)

    # define backward pass: simply pass on gradient from previous operation (derivative of the identity function)
    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        self.assign(in_grad[0], req[0], out_grad[0])

# wrap the operator and provide functions to get information about input, output etc.
@mx.operator.register('ter_quant')
class TerQuantProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(TerQuantProp, self).__init__(True)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shapes):
        data_shape = in_shapes[0]
        output_shape = data_shape
        return (data_shape,), (output_shape,), ()

    def create_operator(self, ctx, in_shapes, in_dtypes):
        return TerQuant()

## 2. Implement the SP 2DConvLayer
$\newcommand{\cin}{c_\mathrm{in}}\newcommand{\cout}{c_\mathrm{out}}$We are now ready to implement the SP 2D convolution layer (as a gluon `HybridBlock` object). Specifically, we compress the computation of $\cout \times p \times p$ output elements of the convolution from $\cin \times (p - 1 + k) \times (p - 1 + k)$ input elements by applying (see figure below)

1. a $p$-strided 2D convolution with a *ternary* filter of size $r \times \cin \times (p - 1 + k) \times (p - 1 + k)$ (corresponding to multiplication with $W_B$),
2. a channel-wise multiplication with an $r \times 1 \times 1$ *full-precision* filter $\tilde a = W_A \vec(A)$ (corresponding to the filters of the original convolution),
3. a $1/p$-strided transposed 2D convolution with a *ternary* filter of size $\cout \times r \times p \times p$ (corresponding to multiplication with $W_C$).

<img src="files/sp2dconv.png" alt="spnetwork" style="width: 600px;"/>

As we will pretrain the convolutions described in items 1 and 3 above with full-precision weights, we include a variable `mode` that determines whether quantization is active or not. Furthermore, we include the original convolution to selectively replace the SP convolutions with ordinary ones (not used in this notebook).
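
The geometry of the three steps above, and the resulting multiplication budget, can be sketched with a little arithmetic. The helper names below are hypothetical. The count is a rough one: it ignores batch normalization and assumes the scale $\alpha$ of the ternary filters has been absorbed into $\tilde a$, so that only step 2 involves multiplications.

```python
def sp_conv_geometry(k, p, stride=1):
    """Kernel size and stride of the first (W_B) convolution along one axis."""
    kernel = (p - 1) * stride + k   # input extent feeding one p-wide output patch
    step = p * stride               # advance by one whole output patch
    return kernel, step


def mults_per_patch(c_in, c_out, k, p, r):
    """Multiplications for one c_out x p x p output patch: (ordinary conv, SP layer)."""
    ordinary = c_out * p * p * c_in * k * k
    sp = r   # only the element-wise product with a_tilde multiplies
    return ordinary, sp


# a 3x3 convolution with 16 input and 16 output channels, p = 1, r = c_out
geometry = sp_conv_geometry(k=3, p=1)
budget = mults_per_patch(c_in=16, c_out=16, k=3, p=1, r=16)
```

With $p = 1$ the first convolution keeps the original kernel size and stride, and the saving comes entirely from replacing the $\cin k^2$ multiplications per output element by $r / \cout$.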

In [2]:
from mxnet import gluon, autograd
from mxnet.gluon.block import HybridBlock
from mxnet.base import numeric_types
from math import ceil

class SumProd2DConv(mx.gluon.HybridBlock):
    r"""
    nbr_mul : int
        Width of the sum-product network, corresponding to the number of multiplications r used to compute a
        $p_x$ x $p_y$ x channels activation element
    target_layer_shape : tuple of int
        Shape of the filter to be replaced by the sum-product network
    target_layer_key : str
        Name of the target layer in the original ResNet (to enable loading from a pretrained network)
    out_patch : tuple of int
        Size ($p_x$, $p_y$) of the output patch for spatial compression
    kernel, stride, pad : specifics of the 2D convolution
    """
    
    def __init__(self, nbr_mul, target_layer_shape, target_layer_key, out_patch=(1,1), kernel=(3,3), stride=(1,1), pad=(1,1), prefix=None, **kwargs):
        # Use same prefix as parent layer for loading the shared parameters to be compressed
        super(SumProd2DConv, self).__init__(prefix=prefix, **kwargs)

        if isinstance(stride, numeric_types):
            stride = (stride,)*len(kernel)
        if isinstance(pad, numeric_types):
            pad = (pad,)*len(kernel)

        self._nbr_mul = nbr_mul
        # corresponds to $p$ in the description above
        self._out_patch = out_patch
        self._kernel = kernel
        self._stride = stride
        self._pad = pad

        # compute kernel shapes and stride for $W_B$
        self._sp_data_weight_kernel = ((self._out_patch[0]-1)*self._stride[0]+self._kernel[0], (self._out_patch[1]-1)*self._stride[1]+self._kernel[1])
        self._sp_data_weight_stride = (self._out_patch[0]*self._stride[0], self._out_patch[1]*self._stride[1])

        # filters for the original convolution
        self.filter_weights = self.params.get(target_layer_key, shape=target_layer_shape)

        # Assuming the default NCHW layout, determine shapes of $W_B$ and $W_C$
        filter_weights_shape = list(target_layer_shape)
        sp_data_weights_shape = filter_weights_shape.copy()
        sp_data_weights_shape[0] = nbr_mul
        sp_data_weights_shape[2] = self._sp_data_weight_kernel[0]
        sp_data_weights_shape[3] = self._sp_data_weight_kernel[1]
        sp_out_weights_shape = filter_weights_shape.copy()
        sp_out_weights_shape[0] = nbr_mul
        sp_out_weights_shape[1] = filter_weights_shape[0]
        sp_out_weights_shape[2] = self._out_patch[0]
        sp_out_weights_shape[3] = self._out_patch[1]

        ## SP network weights
        # $\tilde a = W_A \vec(A)$
        self.sp_in_weights = self.params.get('sp_in_weights', shape=(1,self._nbr_mul,1,1))
        # $W_B$
        self.sp_data_weights = self.params.get('sp_data_weights', shape=sp_data_weights_shape)
        # $W_C$
        self.sp_out_weights =  self.params.get('sp_out_weights', shape=sp_out_weights_shape)

        # SP network batchnorm parameters
        self.sp_batchnorm_gamma = self.params.get('sp_batchnorm_gamma', shape=(1,self._nbr_mul,1,1))
        self.sp_batchnorm_beta = self.params.get('sp_batchnorm_beta', shape=(1,self._nbr_mul,1,1))
        self.sp_batchnorm_running_mean = self.params.get('sp_batchnorm_running_mean', shape=(1,self._nbr_mul,1,1))
        self.sp_batchnorm_running_var = self.params.get('sp_batchnorm_running_var', shape=(1,self._nbr_mul,1,1))

        # mode: 0: original convolution; 1: SP network without ternary quantization; 2: SP network with ternary quantization
        self.sp_mode = self.params.get('sp_mode', shape=(1,), dtype=np.uint8)


    def forward(self, data):
        with data.context:
            mode = self.sp_mode.data().asnumpy()[0]
            filter_weights = self.filter_weights.data()

            if mode == 0:
                # perform original convolution
                conv_out = mx.nd.Convolution(data=data,
                                  weight=filter_weights,
                                  bias=None,
                                  no_bias=True,
                                  kernel=self._kernel,
                                  stride=self._stride,
                                  pad=self._pad,
                                  num_filter=filter_weights.shape[0])

                return conv_out
            else:
                sp_data_weights = self.sp_data_weights.data()
                sp_out_weights = self.sp_out_weights.data()

                # quantize $W_B$ and $W_C$ if desired
                if mode == 2:
                    sp_data_weights = mx.nd.Custom(sp_data_weights, op_type='ter_quant')
                    sp_out_weights = mx.nd.Custom(sp_out_weights, op_type='ter_quant')

                pad_x = self._pad[0] + ceil((data.shape[2]%self._sp_data_weight_stride[0])/2)
                pad_y = self._pad[1] + ceil((data.shape[3]%self._sp_data_weight_stride[1])/2)

                # compute $W_B \vec(B)$
                sp_data_mul = mx.nd.Convolution(data=data,
                                  weight=sp_data_weights,
                                  bias=None,
                                  no_bias=True,
                                  kernel=self._sp_data_weight_kernel,
                                  stride=self._sp_data_weight_stride,
                                  pad=(pad_x, pad_y),
                                  num_filter = self._nbr_mul)

                # apply batchnorm to $W_B \vec(B)$
                sp_data_mul_norm = mx.nd.BatchNorm(data=sp_data_mul,
                                              gamma=self.sp_batchnorm_gamma.data(),
                                              beta=self.sp_batchnorm_beta.data(),
                                              moving_mean=self.sp_batchnorm_running_mean.data(),
                                              moving_var=self.sp_batchnorm_running_var.data(),
                                              fix_gamma=True)

                # apply non-negativity constraint to $\tilde a = W_A \vec(A)$
                with mx.autograd.pause():
                    self.sp_in_weights.set_data(mx.nd.clip(data=self.sp_in_weights.data(), a_min=0.0, a_max=100.0))
                
                # compute $(W_A \vec(A)) \odot (W_B \vec(B))$
                sp_mul = mx.nd.multiply(self.sp_in_weights.data(), sp_data_mul_norm)
                
                # compute $W_C[(W_A \vec(A)) \odot (W_B \vec(B))]$
                sp_out = mx.nd.Deconvolution(data = sp_mul,
                                  weight = sp_out_weights,
                                  bias = None,
                                  no_bias = True,
                                  kernel = self._out_patch,
                                  stride = self._out_patch,
                                  target_shape = (data.shape[2]//self._stride[0], data.shape[3]//self._stride[1]),
                                  num_filter = filter_weights.shape[0])

                return sp_out

## 3. Implement the sum-product residual unit

We continue by defining the residual unit as proposed in [3] using sum-product 2D convolutions. The code block below is adapted from [4] and looks lengthy, but there are only small modifications compared to [4]. Specifically, we replace ordinary 2D convolutions by sum-product 2D convolutions and add parameters for the sum-product convolutions.

[3] K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, 2016, http://arxiv.org/abs/1512.03385

[4] https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/model_zoo/vision/resnet.py

In [3]:
from mxnet.gluon import nn

class BasicBlockV1SumProd(HybridBlock):
    r"""
    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride size.
    nbr_mul : int
        Width of the sum-product network, corresponding to the number of multiplications r used to compute a
        $p_x$ x $p_y$ x channels activation element
    out_patch : tuple of int
        Size ($p_x$, $p_y$) of the output patch for spatial compression
    conv_idx : int
        index of the convolution, used for parameter names
    downsample : bool, default False
        Whether to downsample the input.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    quant_down : bool
        Whether to quantize ResNet 1x1 convolution projection layers
    """
    def __init__(self, channels, stride, nbr_mul, out_patch=(1,1), conv_idx=0, downsample=False, in_channels=0, quant_down=True, **kwargs):
        super(BasicBlockV1SumProd, self).__init__(**kwargs)
        self.body = nn.HybridSequential(prefix='')

        prefix1 = 'conv%i_'%conv_idx
        # shape of the original convolution
        orig_shape1 = (channels, in_channels, 3, 3)
        # kernel of the original convolution
        kernel = (3,3)

        self.body.add(SumProd2DConv(nbr_mul=nbr_mul,
                                      target_layer_shape=orig_shape1,
                                      target_layer_key='weight',
                                      out_patch=out_patch,
                                      kernel=kernel,
                                      stride=stride,
                                      prefix=prefix1))

        self.body.add(nn.BatchNorm())
        self.body.add(nn.Activation('relu'))
        
        prefix2 = 'conv%i_'%(conv_idx+1)
        orig_shape2 = (channels, channels, 3, 3)

        self.body.add(SumProd2DConv(nbr_mul=nbr_mul,
                                      target_layer_shape=orig_shape2,
                                      target_layer_key='weight',
                                      out_patch=out_patch,
                                      kernel=kernel,
                                      stride=1,
                                      prefix=prefix2))

        self.body.add(nn.BatchNorm())
        
        if downsample:
            self.downsample = nn.HybridSequential(prefix='')

            # Quantize projection layer if required
            if quant_down:
                prefix_ds = 'conv%i_'%(conv_idx+2)
                orig_shape_ds = (channels, in_channels, 1, 1)

                self.downsample.add(SumProd2DConv(nbr_mul=nbr_mul,
                                              target_layer_shape=orig_shape_ds,
                                              target_layer_key='weight',
                                              out_patch=(1,1),
                                              kernel=(1,1),
                                              stride=stride,
                                              pad=(0,0),
                                              prefix=prefix_ds))
            else:
                self.downsample.add(nn.Conv2D(channels, kernel_size=1, strides=stride,
                                              use_bias=False, in_channels=in_channels))

            self.downsample.add(nn.BatchNorm())
        else:
            self.downsample = None

    def hybrid_forward(self, F, x):
        residual = x

        x = self.body(x)

        if self.downsample:
            residual = self.downsample(residual)

        x = F.Activation(residual+x, act_type='relu')

        return x

## 4. Implement the sum-product ResNet

Again, we modify the code to construct ResNet (as a gluon `HybridBlock` object) for various depths from [4], replacing ordinary 2D convolutions by sum-product 2D convolutions, and the fully connected output layer by a sum-product network if desired. As before, we add in some parameters for the sum-product networks.

In [4]:
from mxnet.gluon.model_zoo.vision.resnet import _conv3x3

class ResNetV1SumProd(HybridBlock):
    r"""ResNet V1 model from
    `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    Parameters
    ----------
    layers : list of int
        Number of residual units in each stage
    channels : list of int
        Number of channels in each stage. Length should be one larger than the layers list.
    nbr_mul : list of int
        Widths of the sum-product networks (first convolution, then one entry per stage), each corresponding to
        the number of multiplications r used to compute a $p_x$ x $p_y$ x channels activation element
    out_patch : list of tuple of int
        Sizes ($p_x$, $p_y$) of the output patches for spatial compression (first convolution, then one entry per stage)
    classes : int, default 1000
        Number of classification classes.
    thumbnail : bool, default False
        Enable thumbnail.
    quant_down : bool
        Whether to quantize ResNet 1x1 convolution projection layers
    """
    def __init__(self, layers, channels, nbr_mul, out_patch, classes=1000, thumbnail=False, quant_down=True, prefix='resnetv10_', **kwargs):
        super(ResNetV1SumProd, self).__init__(prefix, **kwargs)
        
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            # Thumbnail determines the characteristics of the first layer. If thumbnail==True, 3x3 convolutions are used
            # and no pooling is performed (this setting is used for CIFAR). Otherwise, 7x7 filters with max-pooling are used.
            if thumbnail:
                # replace first layer with a sum-product layer if desired
                if nbr_mul[0] > 0:
                    orig_shape = (channels[0], 3, 3, 3)

                    self.features.add(SumProd2DConv(nbr_mul = nbr_mul[0],
                                                  target_layer_shape = orig_shape,
                                                  target_layer_key = 'weight',
                                                  kernel = (3, 3),
                                                  stride = (1, 1),
                                                  out_patch = out_patch[0],
                                                  prefix = 'conv0_'))
                else:
                    self.features.add(_conv3x3(channels[0], 1, 3))

            else:
                # replace first layer with a sum-product layer if desired
                if nbr_mul[0] > 0:
                    orig_shape = (channels[0], 3, 7, 7)

                    self.features.add(SumProd2DConv(nbr_mul = nbr_mul[0],
                                                  target_layer_shape = orig_shape,
                                                  target_layer_key = 'weight',
                                                  kernel = (7, 7),
                                                  stride = (2, 2),
                                                  pad = (3,3),
                                                  out_patch = out_patch[0],
                                                  prefix = 'conv0_'))
                else:
                    self.features.add(nn.Conv2D(channels[0], 7, 2, 3, use_bias=False,
                                                in_channels=3))
                self.features.add(nn.BatchNorm())
                self.features.add(nn.Activation('relu'))
                self.features.add(nn.MaxPool2D(3, 2, 1))

            for i, num_layer in enumerate(layers):
                stride = 1 if i == 0 else 2
                self.features.add(self._make_sumprod_layer(num_layer, channels[i+1], stride,
                                                   nbr_mul[i+1], out_patch[i+1], i+1, in_channels=channels[i], quant_down=quant_down))

            self.classifier = nn.HybridSequential(prefix='')
            self.classifier.add(nn.GlobalAvgPool2D())
            self.classifier.add(nn.Flatten())
            self.classifier.add(nn.Dense(classes, in_units=channels[-1]))

    # stack residual blocks as described in [3]
    def _make_sumprod_layer(self, layers, channels, stride, nbr_mul, out_patch, stage_index, in_channels=0, quant_down=True):
        layer = nn.HybridSequential(prefix='stage%d_'%stage_index)
        with layer.name_scope():
            layer.add(BasicBlockV1SumProd(channels, stride, nbr_mul, out_patch, 0, channels != in_channels, in_channels=in_channels, quant_down=quant_down, prefix=''))
            nbr_conv = 2 if channels == in_channels else 3
            for i in range(layers-1):
                layer.add(BasicBlockV1SumProd(channels, 1, nbr_mul, out_patch, nbr_conv, False, in_channels=channels, quant_down=quant_down, prefix=''))
                nbr_conv += 2
        return layer

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.classifier(x)

        return x
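
The index bookkeeping in `_make_sumprod_layer` decides which `conv%i_` prefix each sum-product convolution inside a stage receives, presumably so that parameter names line up with those of a pretrained ResNet. The standalone sketch below reproduces only that arithmetic for `quant_down=True`. The helper name `stage_conv_prefixes` is hypothetical and the function does not touch MXNet.

```python
def stage_conv_prefixes(num_units, channels, in_channels):
    """List the conv prefixes used by each residual unit of one stage."""
    downsample = channels != in_channels
    # first unit: two 3x3 convolutions, plus a 1x1 projection when downsampling
    first = ['conv0_', 'conv1_'] + (['conv2_'] if downsample else [])
    units = [first]
    # the next free index depends on whether the projection consumed one
    nbr_conv = 3 if downsample else 2
    for _ in range(num_units - 1):
        units.append(['conv%i_' % nbr_conv, 'conv%i_' % (nbr_conv + 1)])
        nbr_conv += 2
    return units


# ResNet-20 on CIFAR-10: the first stage keeps 16 channels, the second goes 16 -> 32
stage1 = stage_conv_prefixes(3, 16, 16)
stage2 = stage_conv_prefixes(3, 32, 16)
```

A stage that changes the number of channels therefore uses seven prefixes (`conv0_` to `conv6_`), and a stage that keeps it uses six.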

## 5. Training the network

We are now ready to train and evaluate the SP-ResNet-20 on CIFAR-10 as described in the paper.

We start by defining the SP-ResNet-20. The parameters below realize the configuration for $p=1$ and $r=c_\mathrm{out}$ in Table 1 (first row, first column) in the paper, and can be modified for other configurations. Furthermore, the context can be modified below (use `ctx = [mx.cpu(0)]` if you don't have a GPU available, or extend to a list of devices for parallel training).

In [5]:
# CIFAR-10 has 10 classes
classes = 10
# structure of ResNet-20
res_units = [3, 3, 3]
res_channels = [16, 16, 32, 64]
# choose the width of the sum-product convolution networks (i.e., $r$) to be equal to the number of output channels
# don't compress the last layer (encoded by nbr_mul=0)
nbr_mul = [16, 16, 32, 64, 0]
# don't do spatial compression (out. patch=1x1)
out_patch = [(1,1)]*4
# quantize ResNet 1x1 convolution projection layers
quant_down = True
# use 3x3 kernels for the first convolution layer, don't do pooling after the first convolution layer
thumbnail = True

# construct ResNet HybridBlock object
sp_resnet = ResNetV1SumProd(res_units, res_channels, nbr_mul, out_patch, classes, thumbnail, quant_down)

# define context
ctx = [mx.gpu(0)]

Next, we initialize the network parameters. As we first train all weights at full precision, `sp_mode` is set to 1. `sp_in_weights` (corresponding to $\tilde a = W_A \mathrm{vec}(A)$) is initialized to 1s and Xavier initialization is used for `sp_data_weights` and `sp_out_weights` (corresponding to $W_B$ and $W_C$, respectively), and for the fully connected output layer. Furthermore, optimization is deactivated for the auxiliary batchnorm variables.

In [6]:
# helper function to filter and initialize parameters whose name contains a given substring
def initialize_by_key(net, key, init, context, force_reinit=False):
    for k, p in net.collect_params().items():
        if key in k:
            p.initialize(init=init, ctx=context, force_reinit=force_reinit)

# initialize SP networks
initialize_by_key(sp_resnet, 'sp_mode', mx.init.One(), ctx)
initialize_by_key(sp_resnet, 'sp_in_weights', mx.init.One(), ctx)
initialize_by_key(sp_resnet, 'sp_data_weights', mx.init.Xavier(magnitude=2), ctx)
initialize_by_key(sp_resnet, 'sp_out_weights', mx.init.Xavier(magnitude=2), ctx)

# initialize batchnorm parameters
initialize_by_key(sp_resnet, 'gamma', mx.init.One(), ctx)
initialize_by_key(sp_resnet, 'beta', mx.init.Zero(), ctx)
initialize_by_key(sp_resnet, 'running_mean', mx.init.Zero(), ctx)
initialize_by_key(sp_resnet, 'running_var', mx.init.Zero(), ctx)

# initialize fully connected layer
initialize_by_key(sp_resnet, 'dense0_weight', mx.init.Xavier(magnitude=2), ctx) #mx.init.Normal(1)
initialize_by_key(sp_resnet, 'dense0_bias', mx.init.Zero(), ctx)

# deactivate optimization for original convolutions as they are not used here
for k, p in sp_resnet.collect_params().items():
    if 'conv' in k and 'sp' not in k:
        p.initialize(init=mx.init.Xavier(magnitude=2), ctx=ctx)
        p.grad_req = 'null'

# deactivate optimization for batchnorm auxiliary variables and the `sp_mode` flag
for k, p in sp_resnet.collect_params().items():
    if 'running_mean' in k or 'running_var' in k or 'sp_mode' in k:
        p.grad_req = 'null'

We continue by defining a testing routine and a training loop. This again looks lengthy but is standard and adapted from [5] with slight modifications.

[5] https://github.com/apache/incubator-mxnet/blob/master/example/gluon/image_classification.py

In [7]:
import time

# testing routine
def test(net, val_data, context):
    metric = mx.metric.Accuracy()
    val_data.reset()
    for batch in val_data:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=context, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=context, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    return metric.get()


# training routine
def train(net, lr, mom, wd, epochs, scheduler, context, train_data, val_data, log_interval):
    # define trainer, metric, loss
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'wd': wd, 'momentum': mom, 'lr_scheduler': scheduler},
                            kvstore = 'device')
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    for epoch in range(epochs):
        tic = time.time()
        train_data.reset()
        metric.reset()
        btic = time.time()
        # loop through batches
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=context, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=context, batch_axis=0)
            outputs = []
            Ls = []
            with autograd.record():
                # forward pass 
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    Ls.append(L)
                    outputs.append(z)
                # backward pass
                for L in Ls:
                    L.backward()
            # SGD step
            trainer.step(batch.data[0].shape[0])
            metric.update(label, outputs)
            if log_interval and not (i+1)%log_interval:
                name, acc = metric.get()
                print('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f'%(
                               epoch, i, batch.data[0].shape[0]/(time.time()-btic), name, acc))
            btic = time.time()

        name, acc = metric.get()
        print('[Epoch %d] training: %s=%f'%(epoch, name, acc))
        print('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
        name, val_acc = test(net, val_data, context)
        print('[Epoch %d] validation: %s=%f'%(epoch, name, val_acc))

Now, we define the iterators, employing the same data augmentation procedures as described in [3], Sec. 4.2. The CIFAR-10 data set should be downloaded automatically if not available locally (otherwise, you can download it manually from http://data.mxnet.io/mxnet/data/cifar10.zip, create a subfolder 'data' in the folder containing this notebook, and extract it there).

In [8]:
mx.random.seed(321)
np.random.seed(321)

data_shape = (3, 32, 32)
batch_size = 128

# download CIFAR-10 if necessary
import os.path
if (not os.path.isfile("data/cifar/train.rec")) or (not os.path.isfile("data/cifar/test.rec")):
    zip_file_path = mx.test_utils.download('http://data.mxnet.io/mxnet/data/cifar10.zip', dirname='data')
    import zipfile
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall('data')

# training set iterator
train_data = mx.io.ImageRecordIter(
    path_imgrec   = "data/cifar/train.rec",
    data_shape    = data_shape,
    batch_size    = batch_size,
    mean_r        = 125.3,
    mean_g        = 123.0,
    mean_b        = 113.9,
    std_r         = 63.0,
    std_g         = 62.1,
    std_b         = 66.7,
    dtype         = 'float32',
    rand_crop     = True,
    max_crop_size = 32,
    min_crop_size = 32,
    pad           = 4,
    fill_value    = 0,
    shuffle       = True,
    rand_mirror   = True,
    shuffle_chunk_seed  = 123)

# validation set iterator
val_data = mx.io.ImageRecordIter(
    mean_r      = 125.3,
    mean_g      = 123.0,
    mean_b      = 113.9,
    std_r       = 63.0,
    std_g       = 62.1,
    std_b       = 66.7,
    dtype       = 'float32',
    path_imgrec = "data/cifar/test.rec",
    rand_crop   = False,
    rand_mirror = False,
    data_shape  = data_shape,
    batch_size  = batch_size)
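
To make the augmentation settings above more concrete, here is a minimal pure-Python sketch of the procedure from [3], Sec. 4.2, applied to a single-channel toy image: pad 4 pixels on each side, sample a random 32x32 crop, mirror horizontally with probability 0.5, then normalize with a channel mean and standard deviation. The helpers `pad_crop_mirror` and `normalize` are hypothetical names. The sketch only illustrates the idea and does not reproduce the exact order or implementation of the operations inside `mx.io.ImageRecordIter`.

```python
import random

def pad_crop_mirror(image, pad=4, crop=32, fill=0, rng=random):
    """Pad a 2-D image (list of rows), take a random crop, maybe mirror it."""
    height, width = len(image), len(image[0])
    padded_width = width + 2 * pad
    # rows of fill values above and below, fill values left and right
    padded = [[fill] * padded_width for _ in range(pad)]
    padded += [[fill] * pad + list(row) + [fill] * pad for row in image]
    padded += [[fill] * padded_width for _ in range(pad)]
    # random top-left corner of the crop (randint bounds are inclusive)
    top = rng.randint(0, height + 2 * pad - crop)
    left = rng.randint(0, padded_width - crop)
    out = [row[left:left + crop] for row in padded[top:top + crop]]
    # horizontal mirroring with probability 0.5
    if rng.random() < 0.5:
        out = [row[::-1] for row in out]
    return out

def normalize(value, mean, std):
    """Per-channel normalization of a single pixel value."""
    return (value - mean) / std

rng = random.Random(321)
# toy 32x32 image with pixel values in 0..255 (stands in for one channel)
image = [[(r * 32 + c) % 256 for c in range(32)] for r in range(32)]
augmented = pad_crop_mirror(image, rng=rng)
# normalize with the red-channel statistics used in the iterators
red = [[normalize(v, 125.3, 63.0) for v in row] for row in augmented]
print(len(red), len(red[0]))
```

The validation iterator skips the random crop and mirroring and applies only the normalization.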

Specify optimizer parameters, regularization, and learning rate schedule (250 epochs in total, multiplying the learning rate by 0.1 after epochs 150 and 200)...

In [9]:
epochs = 250
learning_rate = 0.1
momentum = 0.9
weight_decay = 0.0001

train_set_size = 50000
# introduce auxiliary variable to translate steps into epochs
schedule_factor = ceil(train_set_size/batch_size)
scheduler = mx.lr_scheduler.MultiFactorScheduler([150*schedule_factor, 200*schedule_factor, 250*schedule_factor], factor=0.1)

log_interval = 50

... and train the SP-ResNet-20 from scratch with full precision entries for $W_B$, $W_C$.

In [10]:
train(sp_resnet, learning_rate, momentum, weight_decay, epochs, scheduler, ctx, train_data, val_data, log_interval)

Epoch[0] Batch [49]	Speed: 3995.050839 samples/sec	accuracy=0.201719
Epoch[0] Batch [99]	Speed: 3854.505270 samples/sec	accuracy=0.241641
Epoch[0] Batch [149]	Speed: 3657.756799 samples/sec	accuracy=0.273281
Epoch[0] Batch [199]	Speed: 3812.353803 samples/sec	accuracy=0.295234
Epoch[0] Batch [249]	Speed: 3524.694630 samples/sec	accuracy=0.314375
Epoch[0] Batch [299]	Speed: 3733.663291 samples/sec	accuracy=0.330885
Epoch[0] Batch [349]	Speed: 2852.707068 samples/sec	accuracy=0.346763
[Epoch 0] training: accuracy=0.358376
[Epoch 0] time cost: 16.123975
[Epoch 0] validation: accuracy=0.440961
Epoch[1] Batch [49]	Speed: 3640.888889 samples/sec	accuracy=0.485156
Epoch[1] Batch [99]	Speed: 3708.799027 samples/sec	accuracy=0.492422
Epoch[1] Batch [149]	Speed: 3799.484165 samples/sec	accuracy=0.505104
Epoch[1] Batch [199]	Speed: 3762.656724 samples/sec	accuracy=0.511836
Epoch[1] Batch [249]	Speed: 3774.295661 samples/sec	accuracy=0.518594
Epoch[1] Batch [299]	Speed: 3988.017560 samples/sec	acc

Epoch[21] Batch [149]	Speed: 3833.696887 samples/sec	accuracy=0.839792
Epoch[21] Batch [199]	Speed: 3585.856919 samples/sec	accuracy=0.840742
Epoch[21] Batch [249]	Speed: 2880.641470 samples/sec	accuracy=0.839781
Epoch[21] Batch [299]	Speed: 3884.176762 samples/sec	accuracy=0.839922
Epoch[21] Batch [349]	Speed: 3509.373077 samples/sec	accuracy=0.840223
[Epoch 21] training: accuracy=0.839423
[Epoch 21] time cost: 13.661896
[Epoch 21] validation: accuracy=0.792668
Epoch[22] Batch [49]	Speed: 3322.611644 samples/sec	accuracy=0.846250
Epoch[22] Batch [99]	Speed: 3803.279343 samples/sec	accuracy=0.845000
Epoch[22] Batch [149]	Speed: 3834.217096 samples/sec	accuracy=0.846250
Epoch[22] Batch [199]	Speed: 3686.844429 samples/sec	accuracy=0.843789
Epoch[22] Batch [249]	Speed: 3614.489117 samples/sec	accuracy=0.843375
Epoch[22] Batch [299]	Speed: 3867.305216 samples/sec	accuracy=0.844063
Epoch[22] Batch [349]	Speed: 3826.210585 samples/sec	accuracy=0.843616
[Epoch 22] training: accuracy=0.843470

[Epoch 34] training: accuracy=0.859034
[Epoch 34] time cost: 14.262979
[Epoch 34] validation: accuracy=0.842748
Epoch[35] Batch [49]	Speed: 3935.542106 samples/sec	accuracy=0.868594
Epoch[35] Batch [99]	Speed: 3804.600010 samples/sec	accuracy=0.865156
Epoch[35] Batch [149]	Speed: 3641.012350 samples/sec	accuracy=0.865833
Epoch[35] Batch [199]	Speed: 3784.085483 samples/sec	accuracy=0.862773
Epoch[35] Batch [249]	Speed: 3797.334239 samples/sec	accuracy=0.862156
Epoch[35] Batch [299]	Speed: 3832.930519 samples/sec	accuracy=0.862083
Epoch[35] Batch [349]	Speed: 3739.775225 samples/sec	accuracy=0.862634
[Epoch 35] training: accuracy=0.861533
[Epoch 35] time cost: 14.051878
[Epoch 35] validation: accuracy=0.837640
Epoch[36] Batch [49]	Speed: 3774.799873 samples/sec	accuracy=0.857656
Epoch[36] Batch [99]	Speed: 3826.537840 samples/sec	accuracy=0.859375
Epoch[36] Batch [149]	Speed: 3756.758978 samples/sec	accuracy=0.860625
Epoch[36] Batch [199]	Speed: 3480.908184 samples/sec	accuracy=0.861680

Epoch[48] Batch [199]	Speed: 3579.354175 samples/sec	accuracy=0.872227
Epoch[48] Batch [249]	Speed: 2934.233921 samples/sec	accuracy=0.873031
Epoch[48] Batch [299]	Speed: 3870.093005 samples/sec	accuracy=0.873516
Epoch[48] Batch [349]	Speed: 3782.032870 samples/sec	accuracy=0.873638
[Epoch 48] training: accuracy=0.873681
[Epoch 48] time cost: 14.680833
[Epoch 48] validation: accuracy=0.832971
Epoch[49] Batch [49]	Speed: 3800.344817 samples/sec	accuracy=0.875469
Epoch[49] Batch [99]	Speed: 3724.829928 samples/sec	accuracy=0.870625
Epoch[49] Batch [149]	Speed: 3921.427772 samples/sec	accuracy=0.870052
Epoch[49] Batch [199]	Speed: 3927.998010 samples/sec	accuracy=0.868242
Epoch[49] Batch [249]	Speed: 3784.805758 samples/sec	accuracy=0.868656
Epoch[49] Batch [299]	Speed: 3837.341320 samples/sec	accuracy=0.868620
Epoch[49] Batch [349]	Speed: 3914.850930 samples/sec	accuracy=0.871116
[Epoch 49] training: accuracy=0.870584
[Epoch 49] time cost: 13.562549
[Epoch 49] validation: accuracy=0.8383

[Epoch 61] validation: accuracy=0.825621
Epoch[62] Batch [49]	Speed: 3851.325418 samples/sec	accuracy=0.880469
Epoch[62] Batch [99]	Speed: 3926.733898 samples/sec	accuracy=0.877812
Epoch[62] Batch [149]	Speed: 3781.127230 samples/sec	accuracy=0.878646
Epoch[62] Batch [199]	Speed: 3224.721070 samples/sec	accuracy=0.879687
Epoch[62] Batch [249]	Speed: 3846.055677 samples/sec	accuracy=0.879906
Epoch[62] Batch [299]	Speed: 3935.022883 samples/sec	accuracy=0.880208
Epoch[62] Batch [349]	Speed: 3929.924472 samples/sec	accuracy=0.880469
[Epoch 62] training: accuracy=0.880874
[Epoch 62] time cost: 13.465948
[Epoch 62] validation: accuracy=0.814503
Epoch[63] Batch [49]	Speed: 3823.757956 samples/sec	accuracy=0.872656
Epoch[63] Batch [99]	Speed: 1714.590657 samples/sec	accuracy=0.878516
Epoch[63] Batch [149]	Speed: 3696.337969 samples/sec	accuracy=0.878542
Epoch[63] Batch [199]	Speed: 4061.941818 samples/sec	accuracy=0.878477
Epoch[63] Batch [249]	Speed: 3592.743937 samples/sec	accuracy=0.879031

Epoch[75] Batch [249]	Speed: 3058.832077 samples/sec	accuracy=0.879313
Epoch[75] Batch [299]	Speed: 3934.273135 samples/sec	accuracy=0.880495
Epoch[75] Batch [349]	Speed: 2674.578974 samples/sec	accuracy=0.881942
[Epoch 75] training: accuracy=0.881674
[Epoch 75] time cost: 13.902251
[Epoch 75] validation: accuracy=0.822516
Epoch[76] Batch [49]	Speed: 3825.229156 samples/sec	accuracy=0.877812
Epoch[76] Batch [99]	Speed: 4099.816052 samples/sec	accuracy=0.883281
Epoch[76] Batch [149]	Speed: 3941.725613 samples/sec	accuracy=0.882552
Epoch[76] Batch [199]	Speed: 3860.825221 samples/sec	accuracy=0.881484
Epoch[76] Batch [249]	Speed: 3878.032289 samples/sec	accuracy=0.880188
Epoch[76] Batch [299]	Speed: 2346.576592 samples/sec	accuracy=0.880885
Epoch[76] Batch [349]	Speed: 3434.654929 samples/sec	accuracy=0.881830
[Epoch 76] training: accuracy=0.882373
[Epoch 76] time cost: 13.659685
[Epoch 76] validation: accuracy=0.840445
Epoch[77] Batch [49]	Speed: 3807.163102 samples/sec	accuracy=0.88796

Epoch[89] Batch [49]	Speed: 3890.425311 samples/sec	accuracy=0.884062
Epoch[89] Batch [99]	Speed: 4026.270132 samples/sec	accuracy=0.883672
Epoch[89] Batch [149]	Speed: 2888.856728 samples/sec	accuracy=0.884062
Epoch[89] Batch [199]	Speed: 3869.256247 samples/sec	accuracy=0.884414
Epoch[89] Batch [249]	Speed: 3650.271028 samples/sec	accuracy=0.885906
Epoch[89] Batch [299]	Speed: 3933.293126 samples/sec	accuracy=0.886875
Epoch[89] Batch [349]	Speed: 3705.394557 samples/sec	accuracy=0.888103
[Epoch 89] training: accuracy=0.888827
[Epoch 89] time cost: 14.023120
[Epoch 89] validation: accuracy=0.850361
Epoch[90] Batch [49]	Speed: 4139.903086 samples/sec	accuracy=0.878437
Epoch[90] Batch [99]	Speed: 3574.492573 samples/sec	accuracy=0.885938
Epoch[90] Batch [149]	Speed: 3800.748382 samples/sec	accuracy=0.889323
Epoch[90] Batch [199]	Speed: 3845.229279 samples/sec	accuracy=0.888398
Epoch[90] Batch [249]	Speed: 3785.766552 samples/sec	accuracy=0.887125
Epoch[90] Batch [299]	Speed: 3865.439643

Epoch[102] Batch [249]	Speed: 3711.132008 samples/sec	accuracy=0.889906
Epoch[102] Batch [299]	Speed: 3824.956626 samples/sec	accuracy=0.890469
Epoch[102] Batch [349]	Speed: 4085.620121 samples/sec	accuracy=0.891094
[Epoch 102] training: accuracy=0.890205
[Epoch 102] time cost: 14.042636
[Epoch 102] validation: accuracy=0.854667
Epoch[103] Batch [49]	Speed: 3968.238418 samples/sec	accuracy=0.886094
Epoch[103] Batch [99]	Speed: 3803.521845 samples/sec	accuracy=0.890000
Epoch[103] Batch [149]	Speed: 3869.451458 samples/sec	accuracy=0.891979
Epoch[103] Batch [199]	Speed: 3888.819036 samples/sec	accuracy=0.890703
Epoch[103] Batch [249]	Speed: 3670.435376 samples/sec	accuracy=0.890656
Epoch[103] Batch [299]	Speed: 3002.130023 samples/sec	accuracy=0.890391
Epoch[103] Batch [349]	Speed: 3747.031400 samples/sec	accuracy=0.891272
[Epoch 103] training: accuracy=0.891286
[Epoch 103] time cost: 13.973863
[Epoch 103] validation: accuracy=0.841647
Epoch[104] Batch [49]	Speed: 1884.134413 samples/sec

[Epoch 115] training: accuracy=0.892703
[Epoch 115] time cost: 13.683660
[Epoch 115] validation: accuracy=0.820212
Epoch[116] Batch [49]	Speed: 3543.515273 samples/sec	accuracy=0.889531
Epoch[116] Batch [99]	Speed: 3857.579502 samples/sec	accuracy=0.889062
Epoch[116] Batch [149]	Speed: 3847.185662 samples/sec	accuracy=0.889323
Epoch[116] Batch [199]	Speed: 4005.841668 samples/sec	accuracy=0.889961
Epoch[116] Batch [249]	Speed: 2660.740488 samples/sec	accuracy=0.889312
Epoch[116] Batch [299]	Speed: 3964.165605 samples/sec	accuracy=0.890104
Epoch[116] Batch [349]	Speed: 3795.052606 samples/sec	accuracy=0.891473
[Epoch 116] training: accuracy=0.891664
[Epoch 116] time cost: 14.062855
[Epoch 116] validation: accuracy=0.846655
Epoch[117] Batch [49]	Speed: 3990.477872 samples/sec	accuracy=0.887969
Epoch[117] Batch [99]	Speed: 3890.566275 samples/sec	accuracy=0.886094
Epoch[117] Batch [149]	Speed: 3918.594164 samples/sec	accuracy=0.890104
Epoch[117] Batch [199]	Speed: 3743.008317 samples/sec	

Epoch[129] Batch [99]	Speed: 2373.728454 samples/sec	accuracy=0.890391
Epoch[129] Batch [149]	Speed: 3942.130819 samples/sec	accuracy=0.889479
Epoch[129] Batch [199]	Speed: 4059.423317 samples/sec	accuracy=0.890000
Epoch[129] Batch [249]	Speed: 3798.032698 samples/sec	accuracy=0.890875
Epoch[129] Batch [299]	Speed: 3896.552587 samples/sec	accuracy=0.891615
Epoch[129] Batch [349]	Speed: 3961.036108 samples/sec	accuracy=0.893013
[Epoch 129] training: accuracy=0.892703
[Epoch 129] time cost: 13.522790
[Epoch 129] validation: accuracy=0.852063
Epoch[130] Batch [49]	Speed: 3930.327255 samples/sec	accuracy=0.894375
Epoch[130] Batch [99]	Speed: 3778.067247 samples/sec	accuracy=0.893516
Epoch[130] Batch [149]	Speed: 3526.685839 samples/sec	accuracy=0.894583
Epoch[130] Batch [199]	Speed: 3646.948339 samples/sec	accuracy=0.894844
Epoch[130] Batch [249]	Speed: 3624.273702 samples/sec	accuracy=0.894312
Epoch[130] Batch [299]	Speed: 4002.556526 samples/sec	accuracy=0.894401
Epoch[130] Batch [349]	S

Epoch[142] Batch [249]	Speed: 3672.971594 samples/sec	accuracy=0.893750
Epoch[142] Batch [299]	Speed: 3899.269434 samples/sec	accuracy=0.893542
Epoch[142] Batch [349]	Speed: 2706.261749 samples/sec	accuracy=0.894643
[Epoch 142] training: accuracy=0.894501
[Epoch 142] time cost: 14.177687
[Epoch 142] validation: accuracy=0.853165
Epoch[143] Batch [49]	Speed: 3920.482781 samples/sec	accuracy=0.892188
Epoch[143] Batch [99]	Speed: 3745.306163 samples/sec	accuracy=0.892969
Epoch[143] Batch [149]	Speed: 3957.065554 samples/sec	accuracy=0.897135
Epoch[143] Batch [199]	Speed: 3847.158094 samples/sec	accuracy=0.897070
Epoch[143] Batch [249]	Speed: 3252.600052 samples/sec	accuracy=0.896344
Epoch[143] Batch [299]	Speed: 4002.496846 samples/sec	accuracy=0.896536
Epoch[143] Batch [349]	Speed: 3964.311964 samples/sec	accuracy=0.897232
[Epoch 143] training: accuracy=0.897736
[Epoch 143] time cost: 13.924367
[Epoch 143] validation: accuracy=0.853165
Epoch[144] Batch [49]	Speed: 4016.059964 samples/sec

[Epoch 155] training: accuracy=0.954963
[Epoch 155] time cost: 13.801093
[Epoch 155] validation: accuracy=0.907051
Epoch[156] Batch [49]	Speed: 3800.990562 samples/sec	accuracy=0.961406
Epoch[156] Batch [99]	Speed: 2300.336401 samples/sec	accuracy=0.957187
Epoch[156] Batch [149]	Speed: 3811.947770 samples/sec	accuracy=0.956406
Epoch[156] Batch [199]	Speed: 3990.922831 samples/sec	accuracy=0.956094
Epoch[156] Batch [249]	Speed: 3684.086765 samples/sec	accuracy=0.955688
Epoch[156] Batch [299]	Speed: 2220.750654 samples/sec	accuracy=0.955625
Epoch[156] Batch [349]	Speed: 3842.972270 samples/sec	accuracy=0.956272
[Epoch 156] training: accuracy=0.956522
[Epoch 156] time cost: 13.768610
[Epoch 156] validation: accuracy=0.907652
Epoch[157] Batch [49]	Speed: 3581.407638 samples/sec	accuracy=0.959063
Epoch[157] Batch [99]	Speed: 3952.986526 samples/sec	accuracy=0.956641
Epoch[157] Batch [149]	Speed: 4229.461398 samples/sec	accuracy=0.956562
Epoch[157] Batch [199]	Speed: 3959.838264 samples/sec	

Epoch[169] Batch [99]	Speed: 3813.599511 samples/sec	accuracy=0.966328
Epoch[169] Batch [149]	Speed: 3925.298394 samples/sec	accuracy=0.965990
Epoch[169] Batch [199]	Speed: 3793.175673 samples/sec	accuracy=0.965508
Epoch[169] Batch [249]	Speed: 3715.498197 samples/sec	accuracy=0.966031
Epoch[169] Batch [299]	Speed: 4053.660966 samples/sec	accuracy=0.966276
Epoch[169] Batch [349]	Speed: 4083.382736 samples/sec	accuracy=0.967121
[Epoch 169] training: accuracy=0.967172
[Epoch 169] time cost: 14.012477
[Epoch 169] validation: accuracy=0.909555
Epoch[170] Batch [49]	Speed: 2724.956411 samples/sec	accuracy=0.969063
Epoch[170] Batch [99]	Speed: 3710.593368 samples/sec	accuracy=0.966875
Epoch[170] Batch [149]	Speed: 3800.183415 samples/sec	accuracy=0.966823
Epoch[170] Batch [199]	Speed: 2592.627402 samples/sec	accuracy=0.966055
Epoch[170] Batch [249]	Speed: 3768.784657 samples/sec	accuracy=0.966344
Epoch[170] Batch [299]	Speed: 3929.090398 samples/sec	accuracy=0.966458
Epoch[170] Batch [349]	S

Epoch[182] Batch [249]	Speed: 3781.846379 samples/sec	accuracy=0.971531
Epoch[182] Batch [299]	Speed: 4023.373492 samples/sec	accuracy=0.971042
Epoch[182] Batch [349]	Speed: 4052.375867 samples/sec	accuracy=0.971339
[Epoch 182] training: accuracy=0.971407
[Epoch 182] time cost: 13.521734
[Epoch 182] validation: accuracy=0.911959
Epoch[183] Batch [49]	Speed: 3094.909823 samples/sec	accuracy=0.968906
Epoch[183] Batch [99]	Speed: 3712.851575 samples/sec	accuracy=0.969297
Epoch[183] Batch [149]	Speed: 2227.828038 samples/sec	accuracy=0.969792
Epoch[183] Batch [199]	Speed: 4015.128874 samples/sec	accuracy=0.970273
Epoch[183] Batch [249]	Speed: 3847.930162 samples/sec	accuracy=0.970625
Epoch[183] Batch [299]	Speed: 3795.159916 samples/sec	accuracy=0.970260
Epoch[183] Batch [349]	Speed: 4254.127670 samples/sec	accuracy=0.970804
[Epoch 183] training: accuracy=0.970913
[Epoch 183] time cost: 14.140009
[Epoch 183] validation: accuracy=0.909455
Epoch[184] Batch [49]	Speed: 4021.745962 samples/sec

[Epoch 195] training: accuracy=0.975803
[Epoch 195] time cost: 13.264545
[Epoch 195] validation: accuracy=0.904347
Epoch[196] Batch [49]	Speed: 3988.017560 samples/sec	accuracy=0.975938
Epoch[196] Batch [99]	Speed: 3109.068920 samples/sec	accuracy=0.975078
Epoch[196] Batch [149]	Speed: 3784.992541 samples/sec	accuracy=0.974323
Epoch[196] Batch [199]	Speed: 3894.630443 samples/sec	accuracy=0.974258
Epoch[196] Batch [249]	Speed: 3596.329870 samples/sec	accuracy=0.974156
Epoch[196] Batch [299]	Speed: 3972.025717 samples/sec	accuracy=0.973594
Epoch[196] Batch [349]	Speed: 3029.198519 samples/sec	accuracy=0.974464
[Epoch 196] training: accuracy=0.974864
[Epoch 196] time cost: 13.872098
[Epoch 196] validation: accuracy=0.910457
Epoch[197] Batch [49]	Speed: 3719.333490 samples/sec	accuracy=0.971250
Epoch[197] Batch [99]	Speed: 3968.385078 samples/sec	accuracy=0.972969
Epoch[197] Batch [149]	Speed: 3812.732846 samples/sec	accuracy=0.973802
Epoch[197] Batch [199]	Speed: 3874.841483 samples/sec	

Epoch[209] Batch [99]	Speed: 3399.466288 samples/sec	accuracy=0.984375
Epoch[209] Batch [149]	Speed: 2732.556863 samples/sec	accuracy=0.984896
Epoch[209] Batch [199]	Speed: 3999.395939 samples/sec	accuracy=0.984492
Epoch[209] Batch [249]	Speed: 3767.356546 samples/sec	accuracy=0.985156
Epoch[209] Batch [299]	Speed: 4041.181122 samples/sec	accuracy=0.984896
Epoch[209] Batch [349]	Speed: 4046.298006 samples/sec	accuracy=0.985134
[Epoch 209] training: accuracy=0.985454
[Epoch 209] time cost: 13.563575
[Epoch 209] validation: accuracy=0.913762
Epoch[210] Batch [49]	Speed: 3987.425168 samples/sec	accuracy=0.988125
Epoch[210] Batch [99]	Speed: 4028.143097 samples/sec	accuracy=0.986875
Epoch[210] Batch [149]	Speed: 4086.895283 samples/sec	accuracy=0.986250
Epoch[210] Batch [199]	Speed: 3962.117711 samples/sec	accuracy=0.985586
Epoch[210] Batch [249]	Speed: 2522.000761 samples/sec	accuracy=0.985375
Epoch[210] Batch [299]	Speed: 3856.415702 samples/sec	accuracy=0.985313
Epoch[210] Batch [349]	S

Epoch[222] Batch [249]	Speed: 3802.148072 samples/sec	accuracy=0.987469
Epoch[222] Batch [299]	Speed: 3451.059106 samples/sec	accuracy=0.987344
Epoch[222] Batch [349]	Speed: 3828.174955 samples/sec	accuracy=0.987366
[Epoch 222] training: accuracy=0.987432
[Epoch 222] time cost: 13.469655
[Epoch 222] validation: accuracy=0.912660
Epoch[223] Batch [49]	Speed: 3755.418770 samples/sec	accuracy=0.987344
Epoch[223] Batch [99]	Speed: 3261.472037 samples/sec	accuracy=0.988125
Epoch[223] Batch [149]	Speed: 2419.296442 samples/sec	accuracy=0.987396
Epoch[223] Batch [199]	Speed: 3783.312160 samples/sec	accuracy=0.986875
Epoch[223] Batch [249]	Speed: 3763.975714 samples/sec	accuracy=0.987062
Epoch[223] Batch [299]	Speed: 3804.465206 samples/sec	accuracy=0.987240
Epoch[223] Batch [349]	Speed: 3734.338522 samples/sec	accuracy=0.987433
[Epoch 223] training: accuracy=0.987440
[Epoch 223] time cost: 13.595052
[Epoch 223] validation: accuracy=0.912260
Epoch[224] Batch [49]	Speed: 3964.956072 samples/sec

[Epoch 235] training: accuracy=0.987732
[Epoch 235] time cost: 13.529035
[Epoch 235] validation: accuracy=0.911659
Epoch[236] Batch [49]	Speed: 3724.623195 samples/sec	accuracy=0.988750
Epoch[236] Batch [99]	Speed: 3720.467575 samples/sec	accuracy=0.988359
Epoch[236] Batch [149]	Speed: 2823.541014 samples/sec	accuracy=0.987969
Epoch[236] Batch [199]	Speed: 3250.965303 samples/sec	accuracy=0.988008
Epoch[236] Batch [249]	Speed: 2816.061769 samples/sec	accuracy=0.987844
Epoch[236] Batch [299]	Speed: 3859.437494 samples/sec	accuracy=0.987370
Epoch[236] Batch [349]	Speed: 3940.394809 samples/sec	accuracy=0.987857
[Epoch 236] training: accuracy=0.987972
[Epoch 236] time cost: 13.742738
[Epoch 236] validation: accuracy=0.913762
Epoch[237] Batch [49]	Speed: 3817.857304 samples/sec	accuracy=0.989219
Epoch[237] Batch [99]	Speed: 3997.311493 samples/sec	accuracy=0.988516
Epoch[237] Batch [149]	Speed: 3271.927257 samples/sec	accuracy=0.989219
Epoch[237] Batch [199]	Speed: 4053.201909 samples/sec	

Epoch[249] Batch [99]	Speed: 2727.517525 samples/sec	accuracy=0.990703
Epoch[249] Batch [149]	Speed: 3851.656984 samples/sec	accuracy=0.989896
Epoch[249] Batch [199]	Speed: 4098.814433 samples/sec	accuracy=0.989570
Epoch[249] Batch [249]	Speed: 3799.107752 samples/sec	accuracy=0.989500
Epoch[249] Batch [299]	Speed: 3805.894614 samples/sec	accuracy=0.989167
Epoch[249] Batch [349]	Speed: 3747.292938 samples/sec	accuracy=0.989420
[Epoch 249] training: accuracy=0.989170
[Epoch 249] time cost: 13.792157
[Epoch 249] validation: accuracy=0.913662
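
The jump in training and validation accuracy between epochs 143 and 155 in the log above coincides with the first learning rate drop at epoch 150. As a rough check of how the step boundaries passed to `MultiFactorScheduler` translate into epochs, the following sketch recomputes them in plain Python. The helper `lr_at_update` is hypothetical and only mimics a multi-step schedule, under the assumption that the rate is multiplied by the factor once the update count exceeds a boundary; it is not MXNet's implementation. It also assumes true division (Python 3) in `train_set_size / batch_size`.

```python
from math import ceil

def lr_at_update(num_update, base_lr, boundaries, factor):
    """Learning rate after `num_update` updates for a multi-step schedule.

    Assumption: the rate is multiplied by `factor` once the update count
    exceeds a boundary (simplified stand-in, not MXNet's own code).
    """
    drops = sum(1 for b in boundaries if num_update > b)
    return base_lr * factor ** drops

train_set_size = 50000
batch_size = 128
# number of parameter updates per epoch (391 for these values)
updates_per_epoch = ceil(train_set_size / batch_size)
# boundaries in units of updates, as passed to the scheduler
boundaries = [e * updates_per_epoch for e in (150, 200, 250)]

for epoch in (0, 149, 150, 199, 200, 249):
    # first update of the given (0-based) epoch
    first_update = epoch * updates_per_epoch + 1
    print(epoch, lr_at_update(first_update, 0.1, boundaries, 0.1))
```

Under these assumptions, the third boundary (250 epochs) coincides with the last update of training, so only the drops at epochs 150 and 200 affect this run.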


The training should have led to a validation accuracy similar to that reported in [3], Sec. 4.2. We next activate quantization for $W_B$, $W_C$ (by setting `sp_mode` to 2) and continue training for 40 epochs.

In [11]:
initialize_by_key(sp_resnet, 'sp_mode', mx.init.Constant(2), ctx, True)

epochs = 40
learning_rate = 0.01
scheduler = mx.lr_scheduler.MultiFactorScheduler([10*schedule_factor, 20*schedule_factor, 30*schedule_factor, 40*schedule_factor], factor=0.1)

train(sp_resnet, learning_rate, momentum, weight_decay, epochs, scheduler, ctx, train_data, val_data, log_interval)

Epoch[0] Batch [49]	Speed: 1015.650668 samples/sec	accuracy=0.752188
Epoch[0] Batch [99]	Speed: 929.617626 samples/sec	accuracy=0.793750
Epoch[0] Batch [149]	Speed: 1117.950040 samples/sec	accuracy=0.815729
Epoch[0] Batch [199]	Speed: 1347.828049 samples/sec	accuracy=0.829961
Epoch[0] Batch [249]	Speed: 1122.817951 samples/sec	accuracy=0.841406
Epoch[0] Batch [299]	Speed: 1005.670031 samples/sec	accuracy=0.848385
Epoch[0] Batch [349]	Speed: 1155.652136 samples/sec	accuracy=0.854799
[Epoch 0] training: accuracy=0.859135
[Epoch 0] time cost: 48.579850
[Epoch 0] validation: accuracy=0.852264
Epoch[1] Batch [49]	Speed: 1139.712970 samples/sec	accuracy=0.900312
Epoch[1] Batch [99]	Speed: 851.494855 samples/sec	accuracy=0.899609
Epoch[1] Batch [149]	Speed: 1065.097493 samples/sec	accuracy=0.897292
Epoch[1] Batch [199]	Speed: 1061.093588 samples/sec	accuracy=0.895664
Epoch[1] Batch [249]	Speed: 1176.695632 samples/sec	accuracy=0.896156
Epoch[1] Batch [299]	Speed: 1138.280025 samples/sec	accur

Epoch[15] Batch [199]	Speed: 1047.945595 samples/sec	accuracy=0.943320
Epoch[15] Batch [249]	Speed: 779.633052 samples/sec	accuracy=0.943312
Epoch[15] Batch [299]	Speed: 1155.361158 samples/sec	accuracy=0.944010
Epoch[15] Batch [349]	Speed: 1056.688977 samples/sec	accuracy=0.944933
[Epoch 15] training: accuracy=0.945352
[Epoch 15] time cost: 47.566040
[Epoch 15] validation: accuracy=0.896134
Epoch[16] Batch [49]	Speed: 1096.899568 samples/sec	accuracy=0.951094
Epoch[16] Batch [99]	Speed: 1193.534494 samples/sec	accuracy=0.948828
Epoch[16] Batch [149]	Speed: 1115.413679 samples/sec	accuracy=0.948073
Epoch[16] Batch [199]	Speed: 1149.454064 samples/sec	accuracy=0.947031
Epoch[16] Batch [249]	Speed: 1061.750415 samples/sec	accuracy=0.946500
Epoch[16] Batch [299]	Speed: 1185.255392 samples/sec	accuracy=0.946328
Epoch[16] Batch [349]	Speed: 948.221898 samples/sec	accuracy=0.947701
[Epoch 16] training: accuracy=0.947997
[Epoch 16] time cost: 47.947451
[Epoch 16] validation: accuracy=0.893329

[Epoch 28] validation: accuracy=0.900441
Epoch[29] Batch [49]	Speed: 750.140300 samples/sec	accuracy=0.955156
Epoch[29] Batch [99]	Speed: 1207.390267 samples/sec	accuracy=0.950078
Epoch[29] Batch [149]	Speed: 1092.602325 samples/sec	accuracy=0.951979
Epoch[29] Batch [199]	Speed: 1128.612717 samples/sec	accuracy=0.951094
Epoch[29] Batch [249]	Speed: 987.579466 samples/sec	accuracy=0.952000
Epoch[29] Batch [299]	Speed: 1187.194919 samples/sec	accuracy=0.953281
Epoch[29] Batch [349]	Speed: 1133.554986 samples/sec	accuracy=0.953438
[Epoch 29] training: accuracy=0.953726
[Epoch 29] time cost: 48.473119
[Epoch 29] validation: accuracy=0.899840
Epoch[30] Batch [49]	Speed: 1127.799510 samples/sec	accuracy=0.953281
Epoch[30] Batch [99]	Speed: 1044.894380 samples/sec	accuracy=0.952109
Epoch[30] Batch [149]	Speed: 1052.502327 samples/sec	accuracy=0.952604
Epoch[30] Batch [199]	Speed: 1088.157738 samples/sec	accuracy=0.953477
Epoch[30] Batch [249]	Speed: 1216.649471 samples/sec	accuracy=0.954937
E

Finally, we fix $W_B$, $W_C$ and only train $\tilde a$ for another 10 epochs.

In [12]:
for k, p in sp_resnet.collect_params().items():
    if 'sp_data_weights' in k or 'sp_out_weights' in k:
        p.grad_req = 'null'

epochs = 10
learning_rate = 0.001
scheduler = None

train(sp_resnet, learning_rate, momentum, weight_decay, epochs, scheduler, ctx, train_data, val_data, log_interval)

Epoch[0] Batch [49]	Speed: 801.545118 samples/sec	accuracy=0.953125
Epoch[0] Batch [99]	Speed: 849.393277 samples/sec	accuracy=0.954844
Epoch[0] Batch [149]	Speed: 1027.907432 samples/sec	accuracy=0.955573
Epoch[0] Batch [199]	Speed: 1284.460067 samples/sec	accuracy=0.955156
Epoch[0] Batch [249]	Speed: 1283.637014 samples/sec	accuracy=0.955969
Epoch[0] Batch [299]	Speed: 1266.013885 samples/sec	accuracy=0.956536
Epoch[0] Batch [349]	Speed: 1211.157310 samples/sec	accuracy=0.957790
[Epoch 0] training: accuracy=0.958894
[Epoch 0] time cost: 52.003801
[Epoch 0] validation: accuracy=0.907953
Epoch[1] Batch [49]	Speed: 1138.062861 samples/sec	accuracy=0.957344
Epoch[1] Batch [99]	Speed: 1211.865394 samples/sec	accuracy=0.956172
Epoch[1] Batch [149]	Speed: 1195.919335 samples/sec	accuracy=0.956146
Epoch[1] Batch [199]	Speed: 1226.970852 samples/sec	accuracy=0.956211
Epoch[1] Batch [249]	Speed: 1302.910806 samples/sec	accuracy=0.956781
Epoch[1] Batch [299]	Speed: 1214.184067 samples/sec	accur
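
The name-based filter used in the cell above to freeze parameters can be illustrated without MXNet. In the sketch below, the dictionary `grad_req` stands in for the network's parameters, and the full parameter names are hypothetical; only the substrings `sp_data_weights` and `sp_out_weights` and the value `'null'` are taken from the notebook. Every parameter whose name contains one of the two substrings is frozen, and everything else stays trainable.

```python
# hypothetical parameter names mapped to their gradient request;
# 'write' means gradients are computed, 'null' means the parameter is frozen
grad_req = {
    "stage1_conv0_sp_data_weights": "write",
    "stage1_conv0_sp_out_weights": "write",
    "stage1_conv0_weight": "write",
    "stage1_batchnorm0_gamma": "write",
}

# substrings used in the notebook to select the parameters to freeze
frozen_keys = ("sp_data_weights", "sp_out_weights")

for name in grad_req:
    if any(key in name for key in frozen_keys):
        grad_req[name] = "null"

# parameters that still receive gradient updates
trainable = sorted(n for n, req in grad_req.items() if req != "null")
print(trainable)
```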

The model with quantized $W_B$, $W_C$ should reach a validation accuracy slightly below that of the unquantized model.

Finally, we save the parameters of the quantized model.

In [13]:
sp_resnet.collect_params().save('sp_resnet.params')

Done!

### References

[1] https://github.com/apache/incubator-mxnet/blob/master/docs/tutorials/gluon/customop.md

[2] F. Li, B. Zhang, B. Liu, Ternary Weight Networks, 2016, https://arxiv.org/abs/1605.04711v2

[3] K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, 2016, http://arxiv.org/abs/1512.03385

[4] https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/model_zoo/vision/resnet.py

[5] https://github.com/apache/incubator-mxnet/blob/master/example/gluon/image_classification.py