# Fast Linear Algebra in Stacked Strassen Networks

This notebook provides a step-by-step implementation to reproduce the experiment in *Fast Linear Algebra in Stacked Strassen Networks* by Michael Tschannen, Aran Khanna, and Anima Anandkumar, submitted to the Machine Learning on the Phone and other Consumer Devices (MLPCD) NIPS 2017 workshop.

#### Abstract

Matrix multiplications can be cast as $2$-layer sum-product (SP) networks with ternary weight matrices, disentangling multiplications and additions. We leverage this observation for end-to-end learning of low cost (in terms of multiplications) approximations for linear operations in neural network layers. Specifically, we propose to replace matrix multiplication operations by SP networks, with widths corresponding to the budget of multiplications we want to allocate to each layer, and to learn them end-to-end. We showcase this approach by compressing the convolution layers of ResNet and evaluating it on CIFAR-10. We obtain a $98.96\%$ reduction in the number of multiplications with an accuracy loss of only $0.82\%$. 


#### Learning fast matrix multiplications via sum-product networks
$\renewcommand{\vec}{\mathrm{vec}}$Given square matrices $A, B \in \mathbb{R}^{n \times n}$, the product $C = A B$ can be represented as a $2$-layer sum-product (SP) network (or "Strassen network")

$$
    \label{eq:strnet}
    \vec(C) = W_C [ (W_B \vec(B)) \odot (W_A \vec(A))] \qquad \qquad (1)
$$

where $W_A, W_B \in \mathbb{K}^{r \times n^2}$ and $W_C \in \mathbb{K}^{n^2 \times r}$, $\mathbb{K}:= \{-1,0,1\}$, and $\odot$ denotes the element-wise product.

<img src="files/spnetwork.png" alt="spnetwork" style="width: 400px;"/>

1. $A, B\in \mathbb{R}^{2 \times 2}$: Strassen's algorithm tells us ternary weights that satisfy (1) for $r=7$ (instead of $r = 8$).
2. General case $A \in \mathbb{R}^{k \times m}$, $B \in \mathbb{R}^{m \times n}$: $C = A B$ can be written in the form (1) if $r \geq nmk$.
3. If $A \in \mathbb{R}^{k \times m}$ is fixed and $B \in \mathbb{R}^{m \times n}$ concentrates on a low-dimensional subspace of $\mathbb{R}^{m \times n}$: we can find $W_A, W_B$ such that (1) holds approximately even when $r \ll nmk$.
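Item 1 can be checked numerically. The sketch below writes Strassen's seven products as ternary matrices and evaluates the right-hand side of (1) for $2 \times 2$ inputs. It assumes a row-major $\mathrm{vec}(\cdot)$, i.e. $(x_{11}, x_{12}, x_{21}, x_{22})$; the helper names `matvec`, `strassen_sp`, and `naive_2x2` are illustrative and not part of the notebook's code.

```python
import random

# Rows of W_A / W_B: coefficients of (x11, x12, x21, x22) in each of the 7 products.
W_A = [
    [ 1, 0, 0,  1],   # a11 + a22
    [ 0, 0, 1,  1],   # a21 + a22
    [ 1, 0, 0,  0],   # a11
    [ 0, 0, 0,  1],   # a22
    [ 1, 1, 0,  0],   # a11 + a12
    [-1, 0, 1,  0],   # a21 - a11
    [ 0, 1, 0, -1],   # a12 - a22
]
W_B = [
    [ 1, 0, 0,  1],   # b11 + b22
    [ 1, 0, 0,  0],   # b11
    [ 0, 1, 0, -1],   # b12 - b22
    [-1, 0, 1,  0],   # b21 - b11
    [ 0, 0, 0,  1],   # b22
    [ 1, 1, 0,  0],   # b11 + b12
    [ 0, 0, 1,  1],   # b21 + b22
]
# Rows of W_C: how (c11, c12, c21, c22) are assembled from the 7 products.
W_C = [
    [1,  0, 0, 1, -1, 0, 1],
    [0,  0, 1, 0,  1, 0, 0],
    [0,  1, 0, 1,  0, 0, 0],
    [1, -1, 1, 0,  0, 1, 0],
]

def matvec(W, x):
    # With ternary W this only adds and subtracts entries of x.
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def strassen_sp(vec_a, vec_b):
    """Evaluate W_C [(W_B vec(B)) * (W_A vec(A))] as in Eq. (1)."""
    # The element-wise product is the only place with true multiplications (r = 7).
    prods = [p * q for p, q in zip(matvec(W_B, vec_b), matvec(W_A, vec_a))]
    return matvec(W_C, prods)

def naive_2x2(vec_a, vec_b):
    """Reference 2x2 product using 8 multiplications."""
    a11, a12, a21, a22 = vec_a
    b11, b12, b21, b22 = vec_b
    return [a11 * b11 + a12 * b21, a11 * b12 + a12 * b22,
            a21 * b11 + a22 * b21, a21 * b12 + a22 * b22]

random.seed(0)
vec_a = [random.randint(-9, 9) for _ in range(4)]
vec_b = [random.randint(-9, 9) for _ in range(4)]
assert strassen_sp(vec_a, vec_b) == naive_2x2(vec_a, vec_b)
```

Integer inputs are used so that the comparison is exact. The width of the hidden layer (the number of rows of `W_A` and `W_B`) is the multiplication budget $r$.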

*Idea:* Leverage item 3 to learn low-cost (in terms of multiplications) approximate linear operations in neural network layers by associating $A$ with (pretrained) weights/filters and $B$ with corresponding activations/feature maps, and by learning $W_A$, $W_B$, $W_C$ end-to-end.

In this notebook, we demonstrate this idea by compressing the 2D convolution layers of ResNet.

#### Outline of the notebook:
1. Implement the ternary quantization operator to learn ternary weights $W_B$ and $W_C$
2. Implement the SP 2D convolution layer (i.e. a Strassen network for 2D convolutions)
3. Implement residual blocks with SP layers and use these to construct a SP ResNet
4. Define iterators and the training loop
5. Train the network

The code is based on Apache MXNet 0.12 and the gluon package.

## 1. Implement the ternary quantization operator

Following the tutorial in [1] on defining custom gluon operators, we define an operator implementing the ternarization method proposed in [2]. Using the notation from [2], this method quantizes the entries of a given input tensor to values in $\{-\alpha, 0, \alpha\}$ ($\alpha$ corresponds to `scale` below) based on a threshold $\Delta$ (corresponding to `thresh` below) on the full-precision entries. $\alpha$ and $\Delta$ are determined by approximately solving an optimization problem, see [2]. The operator will be used to quantize $W_B$ and $W_C$ (the factor $\alpha$ can be absorbed into $\tilde a = W_A \vec(A)$ after training to ensure that the entries of $W_B$ and $W_C$ are in $\{-1,0,1\}$).

[1] https://github.com/apache/incubator-mxnet/blob/master/docs/tutorials/gluon/customop.md

[2] F. Li, B. Zhang, B. Liu, Ternary Weight Networks, 2016, https://arxiv.org/abs/1605.04711v2
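As a small worked example of the rule described above, the following plain-Python sketch ternarizes a short list of weights. It assumes the same approximations as the operator in the next cell: $\Delta$ is $0.7$ times the mean absolute value, and $\alpha$ is the mean absolute value of the entries whose magnitude exceeds $\Delta$. The function name `ternarize` and the sample weights are made up for illustration.

```python
def ternarize(weights, thresh_factor=0.7):
    """Return (alpha, ternary) for a non-empty list of floats.

    ternary has entries in {-1, 0, 1}; alpha * ternary approximates weights.
    """
    abs_w = [abs(w) for w in weights]
    # threshold Delta: a fixed fraction of the mean absolute weight
    delta = thresh_factor * sum(abs_w) / len(abs_w)
    ternary = [1 if w > delta else -1 if w < -delta else 0 for w in weights]
    # scale alpha: mean magnitude of the entries that were not zeroed out
    kept = [a for a, t in zip(abs_w, ternary) if t != 0]
    alpha = sum(kept) / len(kept) if kept else 0.0
    return alpha, ternary

weights = [0.9, -0.8, 0.1, -0.05, 0.6]
alpha, ternary = ternarize(weights)
quantized = [alpha * t for t in ternary]
print(ternary, alpha)
```

For these weights the mean absolute value is $0.49$, so $\Delta = 0.343$; the two small entries are set to zero and $\alpha$ is the mean of $0.9$, $0.8$, and $0.6$.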


In [None]:
!pip install mxnet

In [1]:
import mxnet as mx
import numpy as np

# define the quantization operator
class TerQuant(mx.operator.CustomOp):
    # define forward pass (ternarization as described in [2])
    def forward(self, is_train, req, in_data, out_data, aux):
        full_data = in_data[0]

        # compute $\Delta$ as in [2], Eq. (6)
        thresh = mx.nd.mean(mx.nd.abs(full_data), axis=()).asnumpy()[0] * 0.7

        # ternarize weights
        quant_data = mx.nd.greater(full_data, thresh*mx.nd.ones(full_data.shape, full_data.context))\
            - mx.nd.lesser(full_data, (-1.0)*thresh*mx.nd.ones(full_data.shape, full_data.context))

        # compute $\alpha$ as in [2], Eq. (5)
        scale = mx.nd.sum(mx.nd.multiply(full_data, quant_data), axis=()).asnumpy()[0]\
            / mx.nd.sum(mx.nd.abs(quant_data), axis=()).asnumpy()[0]

        # rescale ternary weights
        quant_data = scale * quant_data
        self.assign(out_data[0], req[0], quant_data)

    # define backward pass: simply pass on gradient from previous operation (derivative of the identity function)
    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        self.assign(in_grad[0], req[0], out_grad[0])

# wrap the operator and provide functions to get information about input, output etc.
@mx.operator.register('ter_quant')
class TerQuantProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(TerQuantProp, self).__init__(True)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shapes):
        data_shape = in_shapes[0]
        output_shape = data_shape
        return (data_shape,), (output_shape,), ()

    def create_operator(self, ctx, in_shapes, in_dtypes):
        return TerQuant()

## 2. Implement the SP 2DConvLayer
$\newcommand{\cin}{c_\mathrm{in}}\newcommand{\cout}{c_\mathrm{out}}$We are now ready to implement the SP 2D convolution layer (as a gluon `HybridBlock` object). Specifically, we compress the computation of $\cout \times p \times p$ output elements of the convolution from $\cin \times (p + 2\lfloor k/2\rfloor) \times (p + 2\lfloor k/2\rfloor)$ input elements by applying (see figure below)

1. a $p$-strided 2D convolution with a *ternary* filter of size $r \times \cin \times (p + 2\lfloor k/2\rfloor) \times (p + 2\lfloor k/2\rfloor)$ (corresponding to multiplication with $W_B$),
2. a channel-wise multiplication with an $r \times 1 \times 1$ *full-precision* filter $\tilde a = W_A \vec(A)$ (corresponding to the filters of the original convolution),
3. a $1/p$-strided transposed 2D convolution with a *ternary* filter of size $\cout \times r \times p \times p$ (corresponding to multiplication with $W_C$).

<img src="files/sp2dconv.png" alt="spnetwork" style="width: 600px;"/>

As we will pretrain the convolutions described in items 1 and 3 above with full-precision weights, we include a variable `mode` that determines whether quantization is active or not. Furthermore, we include the original convolution to selectively replace the SP convolutions with ordinary ones (not used in this notebook).
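To make the shapes in items 1 and 3 concrete, the sketch below computes the kernel size and stride of the $W_B$ convolution from the output patch size, using the same formulas as the layer defined in the next cell, and compares multiplication counts per output patch. The count for the SP layer assumes that only the element-wise product with $\tilde a$ costs multiplications (the ternary convolutions are treated as additions, and batchnorm and the scale $\alpha$ are ignored). Both function names are hypothetical.

```python
def sp_conv_geometry(out_patch, kernel, stride=(1, 1)):
    """Kernel size and stride of the first (W_B) convolution, per spatial axis."""
    # input window needed to produce an out_patch-sized block of outputs
    wb_kernel = tuple((p - 1) * s + k for p, k, s in zip(out_patch, kernel, stride))
    # consecutive output patches do not overlap, so the window moves by p * s
    wb_stride = tuple(p * s for p, s in zip(out_patch, stride))
    return wb_kernel, wb_stride

def multiplication_counts(c_in, c_out, kernel, out_patch, r):
    """Multiplications per output patch: (ordinary convolution, SP layer)."""
    dense = c_out * out_patch[0] * out_patch[1] * c_in * kernel[0] * kernel[1]
    return dense, r

# 2x2 output patches with a 3x3 kernel: a 4x4 window moved with stride 2
print(sp_conv_geometry(out_patch=(2, 2), kernel=(3, 3)))

# a 16 -> 16 channel 3x3 layer with p = 1 and r = c_out
dense, sp = multiplication_counts(c_in=16, c_out=16, kernel=(3, 3), out_patch=(1, 1), r=16)
print(dense, sp, 1.0 - sp / dense)
```

For unit stride and odd $k$, the window size $(p - 1) + k$ equals $p + 2\lfloor k/2\rfloor$, matching the filter size in item 1.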

In [2]:
from mxnet import gluon, autograd
from mxnet.gluon.block import HybridBlock
from mxnet.base import numeric_types
from math import ceil

class SumProd2DConv(mx.gluon.HybridBlock):
    r"""
    nbr_mul : int
        Width of the sum-product network, corresponding to the number of multiplications r used to compute an
        $p_x$ x $p_y$ x channels activation element
    target_layer_shape:
        Shape of the filter to be replaced by the sum-product network
    target_layer_key:
        Name of the target layer in the original ResNet (to enable loading from a pretrained network)
    out_patch : tuple of int
        size ($p_x$, $p_y$) of the output patch for spatial compression
    kernel, stride, pad : specifics of the 2D convolution
    """
    
    def __init__(self, nbr_mul, target_layer_shape, target_layer_key, out_patch=(1,1), kernel=(3,3), stride=(1,1), pad=(1,1), prefix=None, **kwargs):
        # Use same prefix as parent layer for loading the shared parameters to be compressed
        super(SumProd2DConv, self).__init__(prefix=prefix, **kwargs)

        if isinstance(stride, numeric_types):
            stride = (stride,)*len(kernel)
        if isinstance(pad, numeric_types):
            pad = (pad,)*len(kernel)

        self._nbr_mul = nbr_mul
        # corresponds to $p$ in the description above
        self._out_patch = out_patch
        self._kernel = kernel
        self._stride = stride
        self._pad = pad

        # compute kernel shapes and stride for $W_B$
        self._sp_data_weight_kernel = ((self._out_patch[0]-1)*self._stride[0]+self._kernel[0], (self._out_patch[1]-1)*self._stride[1]+self._kernel[1])
        self._sp_data_weight_stride = (self._out_patch[0]*self._stride[0], self._out_patch[1]*self._stride[1])

        # filters for the original convolution
        self.filter_weights = self.params.get(target_layer_key, shape=target_layer_shape)

        # Assuming the default NCHW layout, determine shapes of $W_B$ and $W_C$
        filter_weights_shape = list(target_layer_shape)
        sp_data_weights_shape = filter_weights_shape.copy()
        sp_data_weights_shape[0] = nbr_mul
        sp_data_weights_shape[2] = self._sp_data_weight_kernel[0]
        sp_data_weights_shape[3] = self._sp_data_weight_kernel[1]
        sp_out_weights_shape = filter_weights_shape.copy()
        sp_out_weights_shape[0] = nbr_mul
        sp_out_weights_shape[1] = filter_weights_shape[0]
        sp_out_weights_shape[2] = self._out_patch[0]
        sp_out_weights_shape[3] = self._out_patch[1]

        ## SP network weights
        # $\tilde a = W_A \vec(A)$
        self.sp_in_weights = self.params.get('sp_in_weights', shape=(1,self._nbr_mul,1,1))
        # $W_B$
        self.sp_data_weights = self.params.get('sp_data_weights', shape=sp_data_weights_shape)
        # $W_C$
        self.sp_out_weights =  self.params.get('sp_out_weights', shape=sp_out_weights_shape)

        # SP network batchnorm parameters
        self.sp_batchnorm_gamma = self.params.get('sp_batchnorm_gamma', shape=(1,self._nbr_mul,1,1))
        self.sp_batchnorm_beta = self.params.get('sp_batchnorm_beta', shape=(1,self._nbr_mul,1,1))
        self.sp_batchnorm_running_mean = self.params.get('sp_batchnorm_running_mean', shape=(1,self._nbr_mul,1,1))
        self.sp_batchnorm_running_var = self.params.get('sp_batchnorm_running_var', shape=(1,self._nbr_mul,1,1))

        # mode: 0: original convolution; 1: SP network without ternary quantization; 2: SP network with ternary quantization
        self.sp_mode = self.params.get('sp_mode', shape=(1,), dtype=np.uint8)


    def forward(self, data):
        with data.context:
            mode = self.sp_mode.data().asnumpy()[0]
            filter_weights = self.filter_weights.data()

            if mode == 0:
                # perform original convolution
                conv_out = mx.nd.Convolution(data=data,
                                  weight=filter_weights,
                                  bias=None,
                                  no_bias=True,
                                  kernel=self._kernel,
                                  stride=self._stride,
                                  pad=self._pad,
                                  num_filter=filter_weights.shape[0])

                return conv_out
            else:
                sp_data_weights = self.sp_data_weights.data()
                sp_out_weights = self.sp_out_weights.data()

                # quantize $W_B$ and $W_C$ if desired
                if mode == 2:
                    sp_data_weights = mx.nd.Custom(sp_data_weights, op_type='ter_quant')
                    sp_out_weights = mx.nd.Custom(sp_out_weights, op_type='ter_quant')

                pad_x = self._pad[0] + ceil((data.shape[2]%self._sp_data_weight_stride[0])/2)
                pad_y = self._pad[1] + ceil((data.shape[3]%self._sp_data_weight_stride[1])/2)

                # compute $W_B \vec(B)$
                sp_data_mul = mx.nd.Convolution(data=data,
                                  weight=sp_data_weights,
                                  bias=None,
                                  no_bias=True,
                                  kernel=self._sp_data_weight_kernel,
                                  stride=self._sp_data_weight_stride,
                                  pad=(pad_x, pad_y),
                                  num_filter = self._nbr_mul)

                # apply batchnorm to $W_B \vec(B)$
                sp_data_mul_norm = mx.nd.BatchNorm(data=sp_data_mul,
                                              gamma=self.sp_batchnorm_gamma.data(),
                                              beta=self.sp_batchnorm_beta.data(),
                                              moving_mean=self.sp_batchnorm_running_mean.data(),
                                              moving_var=self.sp_batchnorm_running_var.data(),
                                              fix_gamma=True)

                # apply non-negativity constraint to $\tilde a = W_A \vec(A)$
                with mx.autograd.pause():
                    self.sp_in_weights.set_data(mx.nd.clip(data=self.sp_in_weights.data(), a_min=0.0, a_max=100.0))
                
                # compute $(W_A \vec(A)) \odot (W_B \vec(B))$
                sp_mul = mx.nd.multiply(self.sp_in_weights.data(), sp_data_mul_norm)
                
                # compute $W_C[(W_A \vec(A)) \odot (W_B \vec(B))]$
                sp_out = mx.nd.Deconvolution(data = sp_mul,
                                  weight = sp_out_weights,
                                  bias = None,
                                  no_bias = True,
                                  kernel = self._out_patch,
                                  stride = self._out_patch,
                                  target_shape = (data.shape[2]//self._stride[0], data.shape[3]//self._stride[1]),
                                  num_filter = filter_weights.shape[0])

                return sp_out

## 3. Implement the sum-product residual unit

We continue by defining the residual unit as proposed in [3] using sum-product 2D convolutions. The code block below is adapted from [4] and looks lengthy, but there are only small modifications compared to [4]. Specifically, we replace ordinary 2D convolutions by sum-product 2D convolutions and add parameters for the sum-product convolutions.

[3] K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, 2016, http://arxiv.org/abs/1512.03385

[4] https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/model_zoo/vision/resnet.py

In [3]:
from mxnet.gluon import nn

class BasicBlockV1SumProd(HybridBlock):
    r"""
    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride size.
    nbr_mul : int
        Width of the sum-product network, corresponding to the number of multiplications r used to compute an
        $p_x$ x $p_y$ x channels activation element
    out_patch : tuple of int
        size ($p_x$, $p_y$) of the output patch for spatial compression
    conv_idx : int
        index of the convolution, used for parameter names
    downsample : bool, default False
        Whether to downsample the input.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    quant_down : bool
        Whether to quantize ResNet 1x1 convolution projection layers
    """
    def __init__(self, channels, stride, nbr_mul, out_patch=(1,1), conv_idx=0, downsample=False, in_channels=0, quant_down=True, **kwargs):
        super(BasicBlockV1SumProd, self).__init__(**kwargs)
        self.body = nn.HybridSequential(prefix='')

        prefix1 = 'conv%i_'%conv_idx
        # shape of the original convolution
        orig_shape1 = (channels, in_channels, 3, 3)
        # kernel of the original convolution
        kernel = (3,3)

        self.body.add(SumProd2DConv(nbr_mul=nbr_mul,
                                      target_layer_shape=orig_shape1,
                                      target_layer_key='weight',
                                      out_patch=out_patch,
                                      kernel=kernel,
                                      stride=stride,
                                      prefix=prefix1))

        self.body.add(nn.BatchNorm())
        self.body.add(nn.Activation('relu'))
        
        prefix2 = 'conv%i_'%(conv_idx+1)
        orig_shape2 = (channels, channels, 3, 3)

        self.body.add(SumProd2DConv(nbr_mul=nbr_mul,
                                      target_layer_shape=orig_shape2,
                                      target_layer_key='weight',
                                      out_patch=out_patch,
                                      kernel=kernel,
                                      stride=1,
                                      prefix=prefix2))

        self.body.add(nn.BatchNorm())
        
        if downsample:
            self.downsample = nn.HybridSequential(prefix='')

            # Quantize projection layer if required
            if quant_down:
                prefix_ds = 'conv%i_'%(conv_idx+2)
                orig_shape_ds = (channels, in_channels, 1, 1)

                self.downsample.add(SumProd2DConv(nbr_mul=nbr_mul,
                                              target_layer_shape=orig_shape_ds,
                                              target_layer_key='weight',
                                              out_patch=(1,1),
                                              kernel=(1,1),
                                              stride=stride,
                                              pad=(0,0),
                                              prefix=prefix_ds))
            else:
                self.downsample.add(nn.Conv2D(channels, kernel_size=1, strides=stride,
                                              use_bias=False, in_channels=in_channels))

            self.downsample.add(nn.BatchNorm())
        else:
            self.downsample = None

    def hybrid_forward(self, F, x):
        residual = x

        x = self.body(x)

        if self.downsample:
            residual = self.downsample(residual)

        x = F.Activation(residual+x, act_type='relu')

        return x

## 4. Implement the sum-product ResNet

Again, we modify the code to construct ResNet (as a gluon `HybridBlock` object) for various depths from [4], replacing ordinary 2D convolutions by sum-product 2D convolutions, and the fully connected output layer by a sum-product network if desired. As before, we add in some parameters for the sum-product networks.

In [4]:
from mxnet.gluon.model_zoo.vision.resnet import _conv3x3

class ResNetV1SumProd(HybridBlock):
    r"""ResNet V1 model from
    `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    Parameters
    ----------
    block : HybridBlock
        Class for the residual block. Options are BasicBlockV1, BottleneckV1.
    layers : list of int
        Numbers of layers in each block
    channels : list of int
        Numbers of channels in each block. Length should be one larger than layers list.
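    nbr_mul : list of int
        Widths of the sum-product networks, corresponding to the number of multiplications r used to compute a
        $p_x$ x $p_y$ x channels activation element (entry 0 is for the first convolution, the following entries for the residual stages)
    out_patch : list of tuple of int
        sizes ($p_x$, $p_y$) of the output patches for spatial compression (same ordering as nbr_mul)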
    classes : int, default 1000
        Number of classification classes.
    thumbnail : bool, default False
        Enable thumbnail.
    quant_down : bool
        Whether to quantize ResNet 1x1 convolution projection layers
    """
    def __init__(self, layers, channels, nbr_mul, out_patch, classes=1000, thumbnail=False, quant_down=True, prefix='resnetv10_', **kwargs):
        super(ResNetV1SumProd, self).__init__(prefix, **kwargs)
        
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            # Thumbnail determines the characteristics of the first layer. If thumbnail==True, 3x3 convolutions are used
            # and no pooling is performed (this setting is used for CIFAR). Otherwise, 7x7 filters with max-pooling are used.
            if thumbnail:
                # replace first layer with a sum-product layer if desired
                if nbr_mul[0] > 0:
                    orig_shape = (channels[0], 3, 3, 3)

                    self.features.add(SumProd2DConv(nbr_mul = nbr_mul[0],
                                                  target_layer_shape = orig_shape,
                                                  target_layer_key = 'weight',
                                                  kernel = (3, 3),
                                                  stride = (1, 1),
                                                  out_patch = out_patch[0],
                                                  prefix = 'conv0_'))
                else:
                    self.features.add(_conv3x3(channels[0], 1, 3))

            else:
                # replace first layer with a sum-product layer if desired
                if nbr_mul[0] > 0:
                    orig_shape = (channels[0], 3, 7, 7)

                    self.features.add(SumProd2DConv(nbr_mul = nbr_mul[0],
                                                  target_layer_shape = orig_shape,
                                                  target_layer_key = 'weight',
                                                  kernel = (7, 7),
                                                  stride = (2, 2),
                                                  pad = (3,3),
                                                  out_patch = out_patch[0],
                                                  prefix = 'conv0_'))
                else:
                    self.features.add(nn.Conv2D(channels[0], 7, 2, 3, use_bias=False,
                                                in_channels=3))
                self.features.add(nn.BatchNorm())
                self.features.add(nn.Activation('relu'))
                self.features.add(nn.MaxPool2D(3, 2, 1))

            for i, num_layer in enumerate(layers):
                stride = 1 if i == 0 else 2
                self.features.add(self._make_sumprod_layer(num_layer, channels[i+1], stride,
                                                   nbr_mul[i+1], out_patch[i+1], i+1, in_channels=channels[i], quant_down=quant_down))

            self.classifier = nn.HybridSequential(prefix='')
            self.classifier.add(nn.GlobalAvgPool2D())
            self.classifier.add(nn.Flatten())
            self.classifier.add(nn.Dense(classes, in_units=channels[-1]))

    # stack residual blocks as described in [3]
    def _make_sumprod_layer(self, layers, channels, stride, nbr_mul, out_patch, stage_index, in_channels=0, quant_down=True):
        layer = nn.HybridSequential(prefix='stage%d_'%stage_index)
        with layer.name_scope():
            layer.add(BasicBlockV1SumProd(channels, stride, nbr_mul, out_patch, 0, channels != in_channels, in_channels=in_channels, quant_down=quant_down, prefix=''))
            nbr_conv = 2 if channels == in_channels else 3
            for i in range(layers-1):
                layer.add(BasicBlockV1SumProd(channels, 1, nbr_mul, out_patch, nbr_conv, False, in_channels=channels, quant_down=quant_down, prefix=''))
                nbr_conv += 2
        return layer

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.classifier(x)

        return x

## 5. Training the network

We are now ready to train and evaluate the SP-ResNet-20 on CIFAR-10 as described in the paper.

We start by defining the SP-ResNet-20. The parameters below realize the configuration for $p=1$ and $r=c_\mathrm{out}$ in Table 1 (first row, first column) in the paper, and can be modified for other configurations. Furthermore, the context can be modified below (use `ctx = [mx.cpu(0)]` if you don't have a GPU available, or extend to a list of devices for parallel training).

In [5]:
# CIFAR-10 has 10 classes
classes = 10
# structure of ResNet-20
res_units = [3, 3, 3]
res_channels = [16, 16, 32, 64]
# choose the width of the sum-product convolution networks (i.e., $r$) to be equal to the number of output channels
# don't compress the last layer (encoded by nbr_mul=0)
nbr_mul = [16, 16, 32, 64, 0]
# don't do spatial compression (out. patch=1x1)
out_patch = [(1,1)]*4
# quantize ResNet 1x1 convolution projection layers
quant_down = True
# use 3x3 kernels for the first convolution layer, don't do pooling after the first convolution layer
thumbnail = True

# construct ResNet HybridBlock object
sp_resnet = ResNetV1SumProd(res_units, res_channels, nbr_mul, out_patch, classes, thumbnail, quant_down)

# define context
ctx = [mx.gpu(0)]

Next, we initialize the network parameters. As we first train all weights at full precision, `sp_mode` is set to 1. `sp_in_weights` (corresponding to $\tilde a = W_A \mathrm{vec}(A)$) is initialized to 1s and Xavier initialization is used for `sp_data_weights` and `sp_out_weights` (corresponding to $W_B$ and $W_C$, respectively), and for the fully connected output layer. Furthermore, optimization is deactivated for the auxiliary batchnorm variables.

In [6]:
# helper function to filter and initialize parameters whose name contains a given substring
def initialize_by_key(net, key, init, context, force_reinit=False):
    for k, p in net.collect_params().items():
        if key in k:
            p.initialize(init=init, ctx=context, force_reinit=force_reinit)

# initialize SP networks
initialize_by_key(sp_resnet, 'sp_mode', mx.init.One(), ctx)
initialize_by_key(sp_resnet, 'sp_in_weights', mx.init.One(), ctx)
initialize_by_key(sp_resnet, 'sp_data_weights', mx.init.Xavier(magnitude=2), ctx)
initialize_by_key(sp_resnet, 'sp_out_weights', mx.init.Xavier(magnitude=2), ctx)

# initialize batchnorm parameters
initialize_by_key(sp_resnet, 'gamma', mx.init.One(), ctx)
initialize_by_key(sp_resnet, 'beta', mx.init.Zero(), ctx)
initialize_by_key(sp_resnet, 'running_mean', mx.init.Zero(), ctx)
# running variances start at one (the usual batchnorm default)
initialize_by_key(sp_resnet, 'running_var', mx.init.One(), ctx)

# initialize fully connected layer
initialize_by_key(sp_resnet, 'dense0_weight', mx.init.Xavier(magnitude=2), ctx) #mx.init.Normal(1)
initialize_by_key(sp_resnet, 'dense0_bias', mx.init.Zero(), ctx)

# deactivate optimization for original convolutions as they are not used here
for k, p in sp_resnet.collect_params().items():
    if 'conv' in k and 'sp' not in k:
        p.initialize(init=mx.init.Xavier(magnitude=2), ctx=ctx)
        p.grad_req = 'null'

# deactivate optimization for batchnorm auxiliary variables
for k, p in sp_resnet.collect_params().items():
    if 'running_mean' in k or 'running_var' in k or 'sp_mode' in k:
        p.grad_req = 'null'

We continue by defining a testing routine and a training loop. This again looks lengthy but is standard and adapted from [5] with slight modifications.

[5] https://github.com/apache/incubator-mxnet/blob/master/example/gluon/image_classification.py

In [7]:
import time

def test(net, val_data, context):
    metric = mx.metric.Accuracy()
    val_data.reset()
    for batch in val_data:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=context, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=context, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    return metric.get()


def train(net, lr, mom, wd, epochs, scheduler, context, train_data, val_data, log_interval):
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'wd': wd, 'momentum': mom, 'lr_scheduler': scheduler},
                            kvstore = 'device')
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    for epoch in range(epochs):
        tic = time.time()
        train_data.reset()
        metric.reset()
        btic = time.time()
        for i, batch in enumerate(train_data):
            # split the batch across the devices passed in via `context`
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=context, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=context, batch_axis=0)
            outputs = []
            Ls = []
            with autograd.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    Ls.append(L)
                    outputs.append(z)
                for L in Ls:
                    L.backward()
            trainer.step(batch.data[0].shape[0])
            metric.update(label, outputs)
            if log_interval and not (i+1)%log_interval:
                name, acc = metric.get()
                print('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f'%(
                               epoch, i, batch_size/(time.time()-btic), name, acc))
            btic = time.time()

        name, acc = metric.get()
        print('[Epoch %d] training: %s=%f'%(epoch, name, acc))
        print('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
        name, val_acc = test(net, val_data, context)
        print('[Epoch %d] validation: %s=%f'%(epoch, name, val_acc))

Now, we define the iterators, employing the same data augmentation procedures as described in [3], Sec. 4.2. The CIFAR-10 data set should be downloaded automatically if not available locally (otherwise, you can download it manually from http://data.mxnet.io/mxnet/data/cifar10.zip, create a subfolder 'data' in the folder containing this notebook, and extract it there).

In [8]:
mx.random.seed(321)

data_shape = (3, 32, 32)
batch_size = 128

# download CIFAR-10 if necessary
zip_file_path = mx.test_utils.download('http://data.mxnet.io/mxnet/data/cifar10.zip', dirname='data')
import zipfile
with zipfile.ZipFile(zip_file_path) as zf:
    zf.extractall('data')

# training set iterator
train_data = mx.io.ImageRecordIter(
    path_imgrec   = "data/cifar/train.rec",
    data_shape    = data_shape,
    batch_size    = batch_size,
    mean_r        = 125.3,
    mean_g        = 123.0,
    mean_b        = 113.9,
    std_r         = 63.0,
    std_g         = 62.1,
    std_b         = 66.7,
    dtype         = 'float32',
    rand_crop     = True,
    max_crop_size = 32,
    min_crop_size = 32,
    pad           = 4,
    fill_value    = 0,
    shuffle       = True,
    rand_mirror   = True,
    shuffle_chunk_seed  = 123)

# validation set iterator
val_data = mx.io.ImageRecordIter(
    mean_r      = 125.3,
    mean_g      = 123.0,
    mean_b      = 113.9,
    std_r       = 63.0,
    std_g       = 62.1,
    std_b       = 66.7,
    dtype       = 'float32',
    path_imgrec = "data/cifar/test.rec",
    rand_crop   = False,
    rand_mirror = False,
    data_shape  = data_shape,
    batch_size  = batch_size)

Specify optimizer parameters, regularization, and learning rate schedule (150-50-50 epochs x0.1)...

In [9]:
epochs = 250
learning_rate = 0.1
momentum = 0.9
weight_decay = 0.0001

train_set_size = 50000
# introduce auxiliary variable to translate steps into epochs
schedule_factor = ceil(train_set_size/batch_size)
scheduler = mx.lr_scheduler.MultiFactorScheduler([150*schedule_factor, 200*schedule_factor, 250*schedule_factor], factor=0.1)

log_interval = 50

... and train the SP-ResNet-20 from scratch with full precision entries for $W_B$, $W_C$.

In [10]:
train(sp_resnet, learning_rate, momentum, weight_decay, epochs, scheduler, ctx, train_data, val_data, log_interval)

Epoch[0] Batch [49]	Speed: 3439.187414 samples/sec	accuracy=0.207344
Epoch[0] Batch [99]	Speed: 2777.957850 samples/sec	accuracy=0.246641
Epoch[0] Batch [149]	Speed: 3822.723345 samples/sec	accuracy=0.276719
Epoch[0] Batch [199]	Speed: 3925.585411 samples/sec	accuracy=0.295195
Epoch[0] Batch [249]	Speed: 3826.946987 samples/sec	accuracy=0.312875
Epoch[0] Batch [299]	Speed: 3414.449149 samples/sec	accuracy=0.328333
Epoch[0] Batch [349]	Speed: 3957.182222 samples/sec	accuracy=0.344062
[Epoch 0] training: accuracy=0.357197
[Epoch 0] time cost: 16.649981
[Epoch 0] validation: accuracy=0.449862
Epoch[1] Batch [49]	Speed: 3053.334804 samples/sec	accuracy=0.496250
Epoch[1] Batch [99]	Speed: 3770.902368 samples/sec	accuracy=0.506641
Epoch[1] Batch [149]	Speed: 3821.009302 samples/sec	accuracy=0.516354
Epoch[1] Batch [199]	Speed: 3916.021707 samples/sec	accuracy=0.525742
Epoch[1] Batch [249]	Speed: 3827.929298 samples/sec	accuracy=0.535312
Epoch[1] Batch [299]	Speed: 3848.454241 samples/sec	acc

Training should have reached a validation accuracy similar to that reported in [3], Sec. 4.2. We next activate quantization for $W_B$, $W_C$ (by setting `sp_mode` to 2) and continue training for 40 epochs.

In [11]:
initialize_by_key(sp_resnet, 'sp_mode', mx.init.Constant(2), ctx, True)

epochs = 40
learning_rate = 0.01
scheduler = mx.lr_scheduler.MultiFactorScheduler([10*schedule_factor, 20*schedule_factor, 30*schedule_factor, 40*schedule_factor], factor=0.1)

train(sp_resnet, learning_rate, momentum, weight_decay, epochs, scheduler, ctx, train_data, val_data, log_interval)

Epoch[0] Batch [49]	Speed: 1069.645083 samples/sec	accuracy=0.729375
Epoch[0] Batch [99]	Speed: 1072.894238 samples/sec	accuracy=0.781250
Epoch[0] Batch [149]	Speed: 1066.555374 samples/sec	accuracy=0.808385
Epoch[0] Batch [199]	Speed: 1063.055985 samples/sec	accuracy=0.824844
Epoch[0] Batch [249]	Speed: 1017.921068 samples/sec	accuracy=0.834281
Epoch[0] Batch [299]	Speed: 1146.652895 samples/sec	accuracy=0.842630
Epoch[0] Batch [349]	Speed: 973.466941 samples/sec	accuracy=0.850022
[Epoch 0] training: accuracy=0.854667
[Epoch 0] time cost: 51.520275
[Epoch 0] validation: accuracy=0.853165
Epoch[1] Batch [49]	Speed: 782.601532 samples/sec	accuracy=0.905937
Epoch[1] Batch [99]	Speed: 1060.586311 samples/sec	accuracy=0.903750
Epoch[1] Batch [149]	Speed: 1118.213161 samples/sec	accuracy=0.901771
Epoch[1] Batch [199]	Speed: 1114.038837 samples/sec	accuracy=0.899258
Epoch[1] Batch [249]	Speed: 1104.597655 samples/sec	accuracy=0.900094
Epoch[1] Batch [299]	Speed: 1102.599046 samples/sec	accur
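As a rough illustration of what quantizing a weight matrix to ternary values involves, here is a minimal sketch of the threshold heuristic from Ternary Weight Networks [2]. It assumes the threshold is 0.7 times the mean absolute weight and that the scale is the mean magnitude of the surviving weights; the quantizer actually used by this notebook is defined earlier and may differ. The function name `ternarize` and the sample weights are hypothetical.

```python
def ternarize(weights, threshold_ratio=0.7):
    """Map a flat list of real weights to values in {-1, 0, 1} plus a scale.

    Sketch of the Ternary Weight Networks heuristic: weights whose magnitude
    is at most `threshold_ratio * mean(|w|)` become 0, the rest keep their sign.
    """
    if not weights:
        return [], 0.0
    abs_w = [abs(w) for w in weights]
    delta = threshold_ratio * sum(abs_w) / len(abs_w)
    ternary = [0 if a <= delta else (1 if w > 0 else -1)
               for w, a in zip(weights, abs_w)]
    # Scale: mean magnitude of the weights that were not zeroed out.
    kept = [a for a in abs_w if a > delta]
    alpha = sum(kept) / len(kept) if kept else 0.0
    return ternary, alpha

# Hypothetical example weights.
ternary, alpha = ternarize([0.9, -0.8, 0.05, -0.1, 0.6, 0.02])
print(ternary, alpha)
```

The quantized weights are then approximated by `alpha * ternary`, so the multiplications by the weight matrix reduce to additions, subtractions, and a single scaling.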

Finally, we fix $W_B$, $W_C$ and only train $\tilde a$ for another 10 epochs.

In [12]:
for k, p in sp_resnet.collect_params().items():
    if 'sp_data_weights' in k or 'sp_out_weights' in k:
        p.grad_req = 'null'

epochs = 10
learning_rate = 0.001
scheduler = None

train(sp_resnet, learning_rate, momentum, weight_decay, epochs, scheduler, ctx, train_data, val_data, log_interval)

Epoch[0] Batch [49]	Speed: 803.767555 samples/sec	accuracy=0.960000
Epoch[0] Batch [99]	Speed: 1187.712323 samples/sec	accuracy=0.957031
Epoch[0] Batch [149]	Speed: 999.350195 samples/sec	accuracy=0.957187
Epoch[0] Batch [199]	Speed: 1145.439462 samples/sec	accuracy=0.957500
Epoch[0] Batch [249]	Speed: 1179.003233 samples/sec	accuracy=0.958031
Epoch[0] Batch [299]	Speed: 1135.910189 samples/sec	accuracy=0.959167
Epoch[0] Batch [349]	Speed: 1168.841071 samples/sec	accuracy=0.959888
[Epoch 0] training: accuracy=0.960377
[Epoch 0] time cost: 47.362762
[Epoch 0] validation: accuracy=0.909255
Epoch[1] Batch [49]	Speed: 1119.558642 samples/sec	accuracy=0.960938
Epoch[1] Batch [99]	Speed: 1220.438536 samples/sec	accuracy=0.959375
Epoch[1] Batch [149]	Speed: 1152.179929 samples/sec	accuracy=0.958542
Epoch[1] Batch [199]	Speed: 1152.669731 samples/sec	accuracy=0.958477
Epoch[1] Batch [249]	Speed: 892.311663 samples/sec	accuracy=0.958812
Epoch[1] Batch [299]	Speed: 925.125899 samples/sec	accurac
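The freezing step above selects parameters purely by substring match on their names. The sketch below mimics that selection on plain strings so the effect can be inspected without a model; the parameter names in the example are hypothetical, and only the two substrings `sp_data_weights` and `sp_out_weights` are taken from the notebook.

```python
FROZEN_MARKERS = ('sp_data_weights', 'sp_out_weights')

def split_by_freeze(param_names, markers=FROZEN_MARKERS):
    """Partition parameter names into (frozen, trainable) by substring match."""
    frozen = [n for n in param_names if any(m in n for m in markers)]
    trainable = [n for n in param_names if not any(m in n for m in markers)]
    return frozen, trainable

# Hypothetical parameter names, for illustration only.
names = [
    'conv0_sp_data_weights',
    'conv0_sp_out_weights',
    'conv0_weight',
    'batchnorm0_gamma',
]
frozen, trainable = split_by_freeze(names)
print('frozen:', frozen)
print('trainable:', trainable)
```

In the notebook, the parameters in the frozen group are the ones whose `grad_req` is set to `'null'`, so the optimizer no longer updates them.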

The model with quantized $W_B$, $W_C$ should have a validation accuracy slightly below that of the unquantized model.

Finally, we save the parameters of the quantized model.

In [13]:
sp_resnet.collect_params().save('sp_resnet.params')

Done!

### References

[1] https://github.com/apache/incubator-mxnet/blob/master/docs/tutorials/gluon/customop.md

[2] F. Li, B. Zhang, B. Liu, Ternary Weight Networks, 2016, https://arxiv.org/abs/1605.04711v2

[3] K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, 2016, http://arxiv.org/abs/1512.03385

[4] https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/model_zoo/vision/resnet.py

[5] https://github.com/apache/incubator-mxnet/blob/master/example/gluon/image_classification.py