# Keras + Horovod + ipyparallel MNIST example

In this notebook we will use ipyparallel to deploy a Keras + Horovod distributed training example.

In [1]:
# System imports
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

# External imports
import ipyparallel as ipp

## Connect to ipyparallel cluster

In [2]:
%%bash
# List the current user's Slurm jobs; the JOBID column is what the
# cluster ID in the next cell is built from.
squeue -u "$USER"

             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
          14772843 interacti       sh sfarrell  R      22:18      4 nid000[11-12,22-23]


In [3]:
# The Slurm job ID reported by squeue above determines the cluster ID
job_id = 14772843
cluster_id = 'cori_{}'.format(job_id)

# Connect through the default profile and list the engines that registered
c = ipp.Client(cluster_id=cluster_id, timeout=60)
print('Worker IDs:', c.ids)

Worker IDs: [0, 1, 2, 3]


## Initialize environment on the workers

In [4]:
%%px

from __future__ import print_function
from __future__ import division

import socket
import math

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import tensorflow as tf

# Horovod for MPI synchronization routines
import horovod.keras as hvd

In [5]:
%%px

# Horovod must be initialized before any rank information is queried
hvd.init()

# Report which MPI rank / local rank this engine got and the host it runs on
print('MPI rank %i, local rank %i, host %s' % (
    hvd.rank(),
    hvd.local_rank(),
    socket.gethostname(),
))

[stdout:0] MPI rank 0, local rank 0, host nid00011
[stdout:1] MPI rank 2, local rank 0, host nid00022
[stdout:2] MPI rank 3, local rank 0, host nid00023
[stdout:3] MPI rank 1, local rank 0, host nid00012


In [6]:
%%px

# Dataset geometry: number of label classes and the image height/width
n_classes = 10
img_rows = img_cols = 28

# Optimization settings: samples per batch and number of training epochs
batch_size, n_epochs = 128, 8

## Load the data on each worker

In [7]:
%%px

# Every worker loads its own full copy of the MNIST train/test splits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Put the single image channel where the active Keras backend expects it
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)

# Add the channel axis and scale pixels to [0, 1], one expression per split
x_train = x_train.reshape((x_train.shape[0],) + input_shape).astype('float32') / 255
x_test = x_test.reshape((x_test.shape[0],) + input_shape).astype('float32') / 255

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert the class vectors to binary class matrices
y_train, y_test = (keras.utils.to_categorical(labels, n_classes)
                   for labels in (y_train, y_test))

[stdout:0] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:1] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:2] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:3] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


## Define the model

In [8]:
%%px

# Convolutional classifier: two conv layers, max pooling, then a dense head,
# with dropout after the pooling stage and after the hidden dense layer
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(n_classes, activation='softmax'),
])

# Horovod: scale the Adadelta learning rate by the number of workers, then
# wrap the optimizer in the Horovod DistributedOptimizer
opt = hvd.DistributedOptimizer(keras.optimizers.Adadelta(1.0 * hvd.size()))

model.compile(optimizer=opt,
              loss=keras.losses.categorical_crossentropy,
              metrics=['accuracy'])

# Print the architecture once instead of once per worker
if hvd.rank() == 0:
    model.summary()

[stdout:0] 
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_19 (Conv2D)           (None, 26, 26, 32)        320       
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 24, 24, 64)        18496     
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 12, 12, 64)        0         
_________________________________________________________________
dropout_19 (Dropout)         (None, 12, 12, 64)        0         
_________________________________________________________________
flatten_10 (Flatten)         (None, 9216)              0         
_________________________________________________________________
dense_19 (Dense)             (None, 128)               1179776   
_________________________________________________________________
dropout_20 (Dropout)         (None, 128)               0        

## Distributed training

Training with Horovod + MPI allows for synchronous distributed batch updates.

We need to register the model synchronization callback and restrict checkpoint writing to a single worker (the checkpoint callback is left commented out in the cell below).

In [None]:
%%px

callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
]

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
# Checkpointing is disabled in this example; uncomment the two lines below to enable it.
#if hvd.rank() == 0:
#    callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

# Horovod: every rank executes this same fit() call, so only rank 0 prints the
# per-epoch progress lines; the other ranks train silently instead of
# repeating the log once per worker.
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    callbacks=callbacks,
                    epochs=n_epochs,
                    verbose=2 if hvd.rank() == 0 else 0,
                    validation_data=(x_test, y_test))

## Evaluate the model

In [10]:
%%px

# The model was compiled with a single 'accuracy' metric, so the first two
# entries of the evaluate() result are printed as loss and accuracy
score = model.evaluate(x_test, y_test, verbose=0)
for label, value in zip(('Test loss:', 'Test accuracy:'), score):
    print(label, value)

[stdout:0] 
Test loss: 0.03895845667719013
Test accuracy: 0.9929
[stdout:1] 
Test loss: 0.03895845667719013
Test accuracy: 0.9929
[stdout:2] 
Test loss: 0.03895845667719013
Test accuracy: 0.9929
[stdout:3] 
Test loss: 0.03895845667719013
Test accuracy: 0.9929
[stdout:4] 
Test loss: 0.03895845667719013
Test accuracy: 0.9929
[stdout:5] 
Test loss: 0.03895845667719013
Test accuracy: 0.9929
[stdout:6] 
Test loss: 0.03895845667719013
Test accuracy: 0.9929
[stdout:7] 
Test loss: 0.03895845667719013
Test accuracy: 0.9929
