# MXNet Distributed Training (MNIST Sample) with Azure ML service

See "[MXNet Distributed Training Example for Azure ML service](https://tsmatz.wordpress.com/2019/01/17/azure-machine-learning-service-custom-amlcompute-and-runconfig-for-mxnet-distributed-training/)" for details.

## 1. Preparation

Before starting, you must create an Azure Machine Learning service workspace and prepare its config settings.

See [here](https://github.com/tsmatz/azure-ml-tensorflow-complete-sample/blob/master/Readme.md) for details.

## 2. Save Script for MXNet Distributed Training (mnist_distributed.py)

Note: To debug locally (with 1 CPU device), use the commented-out lines in place of the active ones.

In [1]:
import os
# Local folder that collects the training scripts written by the next cells.
script_folder = './script'
os.makedirs(name=script_folder, exist_ok=True)

In [2]:
%%writefile script/mnist_distributed.py
import os, random
import mxnet as mx
from mxnet import kv, gluon, autograd, nd
from mxnet.gluon import nn

# Distributed key-value store shared by all participating processes.
store = kv.create('dist')

# Training hyper-parameters.
num_epochs = 5
gpus_per_machine = 1
batch_size_per_gpu = 64
batch_size = gpus_per_machine * batch_size_per_gpu

# One device context per local GPU.
ctx = list(map(mx.gpu, range(gpus_per_machine)))
# Local debugging (1 CPU device): use this instead of the GPU list above.
# ctx = mx.cpu(0)

class SplitSampler(gluon.data.sampler.Sampler):
    """Sample, in shuffled order, the indices of one contiguous part of a dataset.

    The index range [0, length) is cut into num_parts parts of
    length // num_parts indices each (a remainder at the end, if any, belongs
    to no part). Only the part selected by part_index is sampled.

    length: Number of examples in the dataset
    num_parts: Partition the data into multiple parts
    part_index: The index of the part to read from
    """

    def __init__(self, length, num_parts, part_index):
        size = length // num_parts
        first = size * part_index
        self.part_len = size
        self.start = first
        self.end = first + size

    def __iter__(self):
        # A fresh permutation of this part's indices is drawn on every pass.
        order = [*range(self.start, self.end)]
        random.shuffle(order)
        return iter(order)

    def __len__(self):
        return self.part_len

# Seed MXNet's random number generator with a fixed value.
# NOTE(review): SplitSampler shuffles with Python's `random` module, which is
# presumably not affected by this seed -- confirm if reproducible shuffling is wanted.
mx.random.seed(42)
def data_xform(data):
    """Return data with axis 2 moved to the front, cast to float32 and divided by 255.

    Used as the per-sample image transform for the MNIST datasets below
    (channel axis first, pixel values scaled to [0, 1]).
    """
    channel_first = nd.moveaxis(data, 2, 0)
    as_float = channel_first.astype('float32')
    return as_float / 255
# Build the transformed training set once so that its real size can be used
# when splitting it across workers.
train_dataset = gluon.data.vision.MNIST(train=True).transform_first(data_xform)
train_data = gluon.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    # Each worker reads only its own part of the training set. The size comes
    # from the dataset itself instead of a hard-coded count, so no training
    # examples are silently left out of every part.
    sampler=SplitSampler(len(train_dataset), store.num_workers, store.rank))
# The test set is not split: every worker evaluates on all of it, in order.
test_data = gluon.data.DataLoader(
    gluon.data.vision.MNIST(train=False).transform_first(data_xform),
    batch_size=batch_size,
    shuffle=False)
# Local debugging (1 CPU device): use these loaders instead of the ones above.
# train_data = gluon.data.DataLoader(
#     gluon.data.vision.MNIST(train=True, root='./data').transform_first(data_xform),
#     batch_size=batch_size)
# test_data = gluon.data.DataLoader(
#     gluon.data.vision.MNIST(train=False, root='./data').transform_first(data_xform),
#     batch_size=batch_size,
#     shuffle=False)

# Multi-layer perceptron: flatten -> Dense(128, relu) -> Dense(64, relu) -> Dense(10).
net = nn.HybridSequential(prefix='MLP_')
with net.name_scope():
    # The layers are created inside the name scope of the network.
    layers = [
        nn.Flatten(),
        nn.Dense(128, activation='relu'),
        nn.Dense(64, activation='relu'),
        nn.Dense(10, activation=None),
    ]
    net.add(*layers)

net.hybridize()

net.initialize(mx.init.Xavier(), ctx=ctx)

loss_function = gluon.loss.SoftmaxCrossEntropyLoss()

# SGD trainer whose updates go through the distributed key-value store.
trainer = gluon.Trainer(
    net.collect_params(),
    'sgd',
    {'learning_rate': 0.07},
    kvstore=store)
# Local debugging (1 CPU device): use this trainer (no kvstore) instead.
# trainer = gluon.Trainer(
#     params=net.collect_params(),
#     optimizer='sgd',
#     optimizer_params={'learning_rate': 0.07},
# )

for epoch in range(num_epochs):
    # ---- Train ----
    for batch in train_data:
        # Slice the batch (batch[0]: inputs, batch[1]: labels) across the contexts.
        inputs = gluon.utils.split_and_load(batch[0], ctx)
        labels = gluon.utils.split_and_load(batch[1], ctx)
        # Local debugging (1 CPU device):
        # inputs = batch[0].as_in_context(ctx)
        # labels = batch[1].as_in_context(ctx)
        with autograd.record():
            loss = [loss_function(net(X), Y) for X, Y in zip(inputs, labels)]
            # Local debugging (1 CPU device):
            # loss = loss_function(net(inputs), labels)
        for device_loss in loss:
            device_loss.backward()
        # Local debugging (1 CPU device):
        # loss.backward()
        trainer.step(batch_size=batch[0].shape[0])

    # ---- Evaluate and output ----
    metric = mx.metric.Accuracy()
    for test_input, test_label in test_data:
        # Evaluation runs on the first context only.
        test_input = test_input.as_in_context(ctx[0])
        test_label = test_label.as_in_context(ctx[0])
        # Local debugging (1 CPU device):
        # test_input = test_input.as_in_context(ctx)
        # test_label = test_label.as_in_context(ctx)
        test_pred = nd.argmax(net(test_input), axis=1)
        metric.update(preds=test_pred, labels=test_label)
    _, accuracy = metric.get()
    print('Epoch %d: Accuracy %f' % (epoch, accuracy))

""" Save Model (both architecture and parameters) """
# Only the process with rank 0 in the key-value store writes the model files.
is_first_rank = store.rank == 0
if is_first_rank:
    os.makedirs('./outputs', exist_ok=True)
    net.export('./outputs/test', epoch=1)
# os.makedirs('./outputs', exist_ok=True)
# net.export('./outputs/test', epoch=1)

Writing script/mnist_distributed.py


## 3. Save Script to Start Each Role (start_mx_role.py)

In [3]:
%%writefile script/start_mx_role.py
import argparse
import os
from mpi4py import MPI

# Command-line options: (flag, type, default, help text).
parser = argparse.ArgumentParser()
for flag, value_type, fallback, description in (
        ('--num_workers', int, 0, 'Specifies how many worker roles'),
        ('--num_servers', int, 0, 'Specifies how many server roles'),
        ('--scheduler_host', str, '10.0.0.4', 'Specifies the IP of the scheduler'),
):
    parser.add_argument(flag, type=value_type, default=fallback, help=description)

# Unrecognized arguments are collected in `unparsed` instead of raising an error.
FLAGS, unparsed = parser.parse_known_args()

#
# See https://mxnet.incubator.apache.org/faq/distributed_training.html
#

mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.Get_rank()

# Role by MPI rank:
#   0                       -> scheduler
#   1, ..., num_servers     -> server
#   all remaining ranks     -> worker (their count must equal FLAGS.num_workers)
if mpi_rank == 0:
    dmlc_role = 'scheduler'
elif mpi_rank <= FLAGS.num_servers:
    dmlc_role = 'server'
else:
    dmlc_role = 'worker'

# Export the settings as environment variables before the training script is imported.
os.environ.update({
    'DMLC_ROLE': dmlc_role,
    'DMLC_PS_ROOT_URI': FLAGS.scheduler_host,
    'DMLC_PS_ROOT_PORT': '9092',
    'DMLC_NUM_WORKER': str(FLAGS.num_workers),
    'DMLC_NUM_SERVER': str(FLAGS.num_servers),
})

#
# Run previous script !
#
import mnist_distributed

Writing script/start_mx_role.py


## 4. Read Workspace Config

In [4]:
from azureml.core import Workspace
import azureml.core

# Load the workspace from the config file prepared in "1. Preparation".
ws = Workspace.from_config()

Found the config file in: /data/home/tsmatsuz/mxnet/aml_config/config.json


## 5. Create Cluster (Nodes)

4 nodes in total: scheduler, parameter server, worker 1, worker 2

In [5]:
from azureml.core import Workspace
import azureml.core
from azureml.core.compute import AmlCompute, ComputeTarget
from azureml.core.compute_target import ComputeTargetException
 
# Get the AML compute cluster if it already exists, otherwise create it.
# (4 nodes in total: scheduler, server, worker1, worker2)
cluster_name = 'cluster01'
try:
    compute_target = ComputeTarget(workspace=ws, name=cluster_name)
except ComputeTargetException:
    print('creating new.')
    compute_config = AmlCompute.provisioning_configuration(
        vm_size='Standard_NC6', min_nodes=4, max_nodes=4)
    compute_target = ComputeTarget.create(ws, cluster_name, compute_config)
    # Block until provisioning has finished, printing progress.
    compute_target.wait_for_completion(show_output=True)
else:
    print('found existing:', compute_target.name)

creating new.
Creating
Succeeded..................
AmlCompute wait for completion finished
Minimum number of nodes requested have been provisioned


## 6. Generate Config

In [6]:
from azureml.core import ScriptRunConfig, Experiment, Run
from azureml.core.runconfig import RunConfiguration
from azureml.core.conda_dependencies import CondaDependencies
 
# Pip packages to add to the run's conda environment.
conda_dep = CondaDependencies.create()
for pip_package in ('mxnet-cu90', 'mpi4py'):
    conda_dep.add_pip_package(pip_package)

run_config = RunConfiguration(framework='python', conda_dependencies=conda_dep)
run_config.target = compute_target.name

# Run inside a GPU-enabled Docker container based on a custom OpenMPI image.
run_config.environment.docker.base_image = 'tsmatz/azureml-openmpi:0.1.0-gpu'
run_config.environment.docker.gpu_support = True
run_config.environment.docker.enabled = True

# MPI job: one process on each of the 4 nodes.
run_config.communicator = 'OpenMpi'
run_config.mpi.process_count_per_node = 1
run_config.node_count = 4

# Arguments for start_mx_role.py. The scheduler host expression extracts the
# master node's ip like "10.0.0.4" (or use $AZ_BATCHAI_MPI_MASTER_NODE).
script_arguments = [
    '--num_workers', 2,
    '--num_servers', 1,
    '--scheduler_host', '`cut -d ":" -f 1 <<< $AZ_BATCH_MASTER_NODE`',
]

src = ScriptRunConfig(
    source_directory='./script',
    script='start_mx_role.py',
    run_config=run_config,
    arguments=script_arguments)

## 7. Run !

In [7]:
# Submit the script run configuration under the experiment
# 'mnist_mxnet_distributed' and wait for the run to complete, showing its output.
exp = Experiment(workspace=ws, name='mnist_mxnet_distributed')
run = exp.submit(config=src)
run.wait_for_completion(show_output=True)

RunId: mnist_mxnet_distributed_1547784855230

Streaming azureml-logs/60_control_log_rank_0.txt

This is an MPI job. Rank:0
Streaming log file azureml-logs/60_control_log_rank_0.txt
Streaming log file azureml-logs/80_driver_log_rank_0.txt

Streaming azureml-logs/80_driver_log_rank_2.txt

Downloading /root/.mxnet/datasets/mnist/train-images-idx3-ubyte.gz from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/train-images-idx3-ubyte.gz...
Downloading /root/.mxnet/datasets/mnist/train-labels-idx1-ubyte.gz from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/train-labels-idx1-ubyte.gz...
Downloading /root/.mxnet/datasets/mnist/t10k-images-idx3-ubyte.gz from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/t10k-images-idx3-ubyte.gz...
Downloading /root/.mxnet/datasets/mnist/t10k-labels-idx1-ubyte.gz from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/t10k-labels-idx1-ubyt

{'runId': 'mnist_mxnet_distributed_1547784855230',
 'target': 'cluster01',
 'status': 'Finalizing',
 'startTimeUtc': '2019-01-18T04:14:32.472192Z',
 'properties': {'azureml.runsource': 'experiment',
  'ContentSnapshotId': 'a56c6fac-3a7e-4788-840a-fd38eef9810b'},
 'runDefinition': {'Script': 'start_mx_role.py',
  'Arguments': ['--num_workers',
   '2',
   '--num_servers',
   '1',
   '--scheduler_host',
   '`cut -d ":" -f 1 <<< $AZ_BATCH_MASTER_NODE`'],
  'SourceDirectoryDataStore': None,
  'Framework': 0,
  'Communicator': 5,
  'Target': 'cluster01',
  'DataReferences': {},
  'JobName': None,
  'AutoPrepareEnvironment': True,
  'MaxRunDurationSeconds': None,
  'NodeCount': 4,
  'Environment': {'Python': {'InterpreterPath': 'python',
    'UserManagedDependencies': False,
    'CondaDependencies': {'name': 'project_environment',
     'dependencies': ['python=3.6.2',
      {'pip': ['azureml-defaults==1.0.6', 'mxnet-cu90', 'mpi4py']}]}},
   'EnvironmentVariables': {'EXAMPLE_ENV_VAR': 'EXAMPLE

## 8. See Generated Model

In [8]:
# List the files recorded for the run. The exported model appears under
# outputs/ (test-symbol.json, test-0001.params) and can be downloaded.
run.get_file_names()

['azureml-logs/60_control_log_rank_0.txt',
 'azureml-logs/60_control_log_rank_1.txt',
 'azureml-logs/60_control_log_rank_3.txt',
 'azureml-logs/60_control_log_rank_2.txt',
 'azureml-logs/80_driver_log_rank_3.txt',
 'azureml-logs/80_driver_log_rank_2.txt',
 'outputs/test-0001.params',
 'outputs/test-symbol.json',
 'driver_log',
 'azureml-logs/azureml.log',
 'azureml-logs/80_driver_log_rank_1.txt',
 'azureml-logs/80_driver_log_rank_0.txt',
 'azureml-logs/56_batchai_stderr.txt',
 'azureml-logs/55_batchai_execution-tvm-676767296_4-20190118t040928z.txt']

## 9. Remove Cluster (Clean-up)

In [9]:
# Delete cluster (nodes) and remove from AML workspace.
# 'cluster01' is the cluster created in "5. Create Cluster (Nodes)".
mycompute = AmlCompute(workspace=ws, name='cluster01')
mycompute.delete()