This notebook shows how to build and train a handwritten-digit classifier on the MNIST dataset using the MXNet Gluon API with Horovod. It adapts the MXNet Gluon MNIST tutorial (https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/mnist.html) with the changes needed to run under Horovod, as shown in the steps below:

### 1. Import required packages and set training parameters

In [None]:
import logging
import time

import horovod.mxnet as hvd
import mxnet as mx
from mxnet import autograd, gluon, nd


logging.basicConfig(level=logging.INFO)

# Training configuration. batch_size is per worker: each Horovod process
# reads its own shard of the training set in batches of this size.
epochs = 5
dtype = 'float32'
batch_size = 64
data_dir = '/home/ubuntu/mnist/data'

### 2. Initialize Horovod

In [None]:
# Initialize Horovod before any other hvd.* call below (rank/size queries,
# parameter broadcast, DistributedOptimizer).
hvd.init()

### 3. Set the context by pinning each process to the GPU matching its local rank

In [None]:
# Pin this process to the GPU whose index equals its Horovod local rank.
context = mx.gpu(device_id=hvd.local_rank())

### 4. Load MNIST dataset

In [None]:
def get_mnist_iterator():
    """Build the MNIST train/validation iterators and a batch-splitting helper.

    The training set is sharded across Horovod workers (one part per rank),
    while every worker iterates over the full validation set so each can
    report accuracy on the same data.

    Returns:
        A tuple ``(train_iter, val_iter, batch_fn)``. ``batch_fn(batch, ctx)``
        splits the batch's data and labels along axis 0 across the contexts
        in the list ``ctx`` and returns ``(data, label)``.
    """
    def mnist_file(name):
        # MNIST files are looked up directly under the module-level data_dir.
        return "%s/%s" % (data_dir, name)

    def batch_fn(batch, ctx):
        def split(array):
            return gluon.utils.split_and_load(array, ctx_list=ctx, batch_axis=0)
        return split(batch.data[0]), split(batch.label[0])

    # Options shared by both iterators: 1x28x28 images, not flattened.
    common = dict(input_shape=(1, 28, 28), batch_size=batch_size, flat=False)

    # Each worker reads only its own shard of the training set.
    train_iter = mx.io.MNISTIter(
        image=mnist_file("train-images-idx3-ubyte"),
        label=mnist_file("train-labels-idx1-ubyte"),
        shuffle=True,
        num_parts=hvd.size(),
        part_index=hvd.rank(),
        **common
    )

    # Every worker uses the full validation set so results are easy to monitor.
    val_iter = mx.io.MNISTIter(
        image=mnist_file("t10k-images-idx3-ubyte"),
        label=mnist_file("t10k-labels-idx1-ubyte"),
        **common
    )

    return train_iter, val_iter, batch_fn


train_iter, val_iter, batch_fn = get_mnist_iterator()

### 5. Define a Convolutional Neural Network

In [None]:
# Hybrid blocks allow the network to be hybridized below for better performance.
net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(
        # Two convolution + max-pooling stages.
        gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'),
        gluon.nn.MaxPool2D(pool_size=2, strides=2),
        gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'),
        gluon.nn.MaxPool2D(pool_size=2, strides=2),
        # Flatten collapses all axes except the first (batch) axis into one.
        gluon.nn.Flatten(),
        gluon.nn.Dense(512, activation="relu"),
        # One output score per digit class, fed to the softmax loss later.
        gluon.nn.Dense(10),
    )
net.cast(dtype)
net.hybridize()

### 6. Initialize parameters

In [None]:
# Xavier initialization drawing from a Gaussian, scaled by fan-in.
initializer = mx.init.Xavier(
    rnd_type='gaussian',
    factor_type="in",
    magnitude=2,
)
net.initialize(init=initializer, ctx=context)

### 7. Horovod: collect parameters and broadcast them from rank 0

In [None]:
# Broadcast rank 0's initial parameters so every worker starts from the same
# weights. collect_params() returns a ParameterDict, so no None check is needed.
# NOTE(review): the layers above declare no input sizes, so some parameters
# may still be deferred-initialized here; Horovod is expected to broadcast
# them once they are materialized -- confirm for the Horovod version in use.
params = net.collect_params()
hvd.broadcast_parameters(params, root_rank=0)

### 8. Create SGD optimizer

In [None]:
# SGD with momentum. The learning rate is scaled linearly with the number of
# Horovod workers, and gradients are rescaled by 1 / batch_size.
# NOTE(review): trainer.step(batch_size) is also called during training;
# confirm with your MXNet/Horovod versions that gradients are not normalized
# by batch_size twice.
optimizer_params = dict(
    momentum=0.9,
    learning_rate=0.025 * hvd.size(),
    rescale_grad=1.0 / batch_size,
)
if dtype == 'float16':
    # Let the optimizer keep 32-bit weight copies when the model runs in fp16.
    optimizer_params.update(multi_precision=True)
opt = mx.optimizer.create('sgd', **optimizer_params)

### 9. Horovod: wrap optimizer with DistributedOptimizer

In [None]:
# Wrap the SGD optimizer in Horovod's DistributedOptimizer so cross-worker
# gradient communication is handled by Horovod (the Trainer below is
# therefore created with kvstore=None).
opt = hvd.DistributedOptimizer(opt)

### 10. Create trainer and loss function

In [None]:
# No kvstore: gradient communication goes through the Horovod-wrapped optimizer.
trainer = gluon.Trainer(params=params, optimizer=opt, kvstore=None)
# Softmax + cross-entropy on the raw class scores, with class-index labels.
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

### 11. Define a function to evaluate accuracy

In [None]:
def evaluate_accuracy(net, data_iter, context):
    """Return the top-1 accuracy of ``net`` over all batches of ``data_iter``.

    The iterator is reset first, so it can be reused after training. Each
    batch is loaded onto ``context`` via the module-level ``batch_fn`` and
    cast to the module-level ``dtype`` before the forward pass.

    Args:
        net: Gluon network producing one score per class along axis 1.
        data_iter: MXNet data iterator, such as ``mx.io.MNISTIter``.
        context: Device context to run evaluation on.

    Returns:
        The accuracy value reported by ``mx.metric.Accuracy``.
    """
    data_iter.reset()

    metric = mx.metric.Accuracy()
    for batch in data_iter:
        data, label = batch_fn(batch, [context])
        outputs = [net(x.astype(dtype, copy=False)) for x in data]
        # Predicted class = index of the highest score.
        preds = [nd.argmax(output, axis=1) for output in outputs]
        metric.update(label, preds)
    return metric.get()[1]

### 12. Train the CNN

In [None]:
# Train for `epochs` passes over this worker's shard of the training data.
for epoch in range(epochs):
    tic = time.time()

    train_iter.reset()
    nbatch = 0  # keeps the speed computation valid if the shard yields no batches
    for nbatch, batch in enumerate(train_iter, start=1):
        data, label = batch_fn(batch, [context])
        with autograd.record():  # record the forward pass for differentiation
            outputs = [net(x.astype(dtype, copy=False)) for x in data]  # forward pass
            loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]  # compute the loss
        for sample_loss in loss:
            sample_loss.backward()  # backpropagation
        trainer.step(batch_size)

    # Stop the clock before any evaluation so the reported time/speed covers
    # training only (previously it also included the train-accuracy pass).
    elapsed = time.time() - tic

    train_acc = evaluate_accuracy(net, train_iter, context)
    if hvd.rank() == 0:
        # Approximate global throughput: batches on this worker x per-worker
        # batch size x number of workers.
        speed = nbatch * batch_size * hvd.size() / elapsed
        logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f', epoch, speed, elapsed)

    # Evaluate model accuracy (every worker uses the full validation set).
    val_acc = evaluate_accuracy(net, val_iter, context)
    logging.info('Epoch[%d]\tTrain-accuracy=%f\tValidation-accuracy=%f', epoch, train_acc, val_acc)

    if hvd.rank() == 0 and epoch == epochs - 1:
        assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc
        logging.info("Done training with top-1 accuracy %f", val_acc)