# Keras + Horovod + ipyparallel MNIST example

In this notebook we will use ipyparallel to deploy a Keras + Horovod distributed training example.

In [1]:
# System imports
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

# External imports
import ipyparallel as ipp

## Connect to ipyparallel cluster

In [2]:
%%bash
# List the current user's Slurm jobs; the JOBID shown is needed in the next cell.
# $USER is used instead of a hard-coded account name so the cell works for anyone.
squeue -u "$USER"

             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
          24175753 interacti       sh sfarrell  R       1:35      8 nid00[163-170]


In [3]:
# The cluster ID embeds the Slurm job ID reported by squeue above
job_id = 24175753
cluster_id = 'cori_%i' % job_id

# Connect with the default profile, using a 60 second timeout
c = ipp.Client(cluster_id=cluster_id, timeout=60)
print('Worker IDs:', c.ids)

Worker IDs: [0, 1, 2, 3, 4, 5, 6, 7]


## Initialize environment on the workers

In [4]:
%%px

from __future__ import print_function
from __future__ import division

import socket
import math

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import tensorflow as tf

# Horovod for MPI synchronization routines
import horovod.keras as hvd

[stderr:0] Using TensorFlow backend.
[stderr:1] Using TensorFlow backend.
[stderr:2] Using TensorFlow backend.
[stderr:3] Using TensorFlow backend.
[stderr:4] Using TensorFlow backend.
[stderr:5] Using TensorFlow backend.
[stderr:6] Using TensorFlow backend.
[stderr:7] Using TensorFlow backend.


In [5]:
%%px

# Start Horovod on this worker, then report where it is running:
# its MPI rank, its local rank, and the hostname of the node.
hvd.init()
print('MPI rank %i, local rank %i, host %s' % (
    hvd.rank(), hvd.local_rank(), socket.gethostname()))

[stdout:0] MPI rank 0, local rank 0, host nid00163
[stdout:1] MPI rank 6, local rank 0, host nid00169
[stdout:2] MPI rank 4, local rank 0, host nid00167
[stdout:3] MPI rank 1, local rank 0, host nid00164
[stdout:4] MPI rank 7, local rank 0, host nid00170
[stdout:5] MPI rank 5, local rank 0, host nid00168
[stdout:6] MPI rank 3, local rank 0, host nid00166
[stdout:7] MPI rank 2, local rank 0, host nid00165


In [6]:
%%px

# Dataset description: number of target classes and image dimensions in pixels
n_classes = 10
img_rows = img_cols = 28

# Training hyperparameters, defined identically on every worker
batch_size, n_epochs = 128, 8

## Load the data on each worker

In [7]:
%%px

# Every worker loads the complete MNIST train and test splits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Decide where the single grayscale channel axis goes for the active backend
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)

# Add the channel axis to the images, keeping the sample count as first axis
x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Convert to float32 and scale pixels to [0, 1]
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# One-hot encode the integer class labels
y_train = keras.utils.to_categorical(y_train, n_classes)
y_test = keras.utils.to_categorical(y_test, n_classes)

[stdout:0] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:1] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:2] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:3] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:4] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:5] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:6] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
[stdout:7] 
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


## Define the model

In [8]:
%%px

# Small convolutional classifier: two conv layers, max pooling, and a dense
# head, with dropout after the pooling and the hidden dense layer.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(n_classes, activation='softmax'),
])

# Scale the Adadelta learning rate by the number of workers, then wrap the
# optimizer in the Horovod distributed optimizer.
opt = hvd.DistributedOptimizer(keras.optimizers.Adadelta(1.0 * hvd.size()))

model.compile(optimizer=opt,
              loss=keras.losses.categorical_crossentropy,
              metrics=['accuracy'])

# Print the layer summary once, from rank 0 only
if hvd.rank() == 0:
    model.summary()

[stdout:0] 
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 26, 26, 32)        320       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 24, 24, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 12, 12, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 12, 12, 64)        0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 9216)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               1179776   
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0        

## Distributed training

Training with Horovod + MPI allows for synchronous distributed batch updates across all workers.

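We need to register the model synchronization callback, which broadcasts the initial variable states from rank 0 to all other workers. Checkpoint writing should be restricted to a single worker; in the cell below the checkpoint callback is left commented out.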

In [9]:
%%px

# Horovod: broadcast the initial variable states from rank 0 to every other
# process, so all workers start consistently whether training begins from
# random weights or from a restored checkpoint.
callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]

# Horovod: checkpoints should be saved only on worker 0 so that other workers
# cannot corrupt them. Checkpointing is currently disabled; uncomment to enable.
#if hvd.rank() == 0:
#    callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

# Train on every worker, validating against the test split after each epoch;
# verbose=2 prints one summary line per epoch.
history = model.fit(
    x=x_train, y=y_train,
    validation_data=(x_test, y_test),
    epochs=n_epochs,
    batch_size=batch_size,
    verbose=2,
    callbacks=callbacks,
)

[stdout:0] 
Train on 60000 samples, validate on 10000 samples
Epoch 1/8
 - 12s - loss: 0.1705 - acc: 0.9485 - val_loss: 0.0311 - val_acc: 0.9891
Epoch 2/8
 - 11s - loss: 0.0374 - acc: 0.9884 - val_loss: 0.0306 - val_acc: 0.9911
Epoch 3/8
 - 11s - loss: 0.0261 - acc: 0.9912 - val_loss: 0.0501 - val_acc: 0.9877
Epoch 4/8
 - 11s - loss: 0.0191 - acc: 0.9935 - val_loss: 0.0358 - val_acc: 0.9911
Epoch 5/8
 - 11s - loss: 0.0157 - acc: 0.9949 - val_loss: 0.0322 - val_acc: 0.9921
Epoch 6/8
 - 11s - loss: 0.0128 - acc: 0.9956 - val_loss: 0.0301 - val_acc: 0.9929
Epoch 7/8
 - 11s - loss: 0.0103 - acc: 0.9965 - val_loss: 0.0328 - val_acc: 0.9935
Epoch 8/8
 - 11s - loss: 0.0091 - acc: 0.9969 - val_loss: 0.0336 - val_acc: 0.9932
[stdout:1] 
Train on 60000 samples, validate on 10000 samples
Epoch 1/8
 - 12s - loss: 0.1703 - acc: 0.9483 - val_loss: 0.0311 - val_acc: 0.9891
Epoch 2/8
 - 11s - loss: 0.0371 - acc: 0.9879 - val_loss: 0.0306 - val_acc: 0.9911
Epoch 3/8
 - 11s - loss: 0.0249 - acc: 0.9919 

## Evaluate the model

In [10]:
%%px

# Evaluate silently on the test split; score holds the loss followed by the
# compiled metrics (accuracy), and only those first two values are reported.
score = model.evaluate(x_test, y_test, verbose=0)
for label, value in zip(('Test loss:', 'Test accuracy:'), score):
    print(label, value)

[stdout:0] 
Test loss: 0.033566167738372045
Test accuracy: 0.9932
[stdout:1] 
Test loss: 0.033566167738372045
Test accuracy: 0.9932
[stdout:2] 
Test loss: 0.033566167738372045
Test accuracy: 0.9932
[stdout:3] 
Test loss: 0.033566167738372045
Test accuracy: 0.9932
[stdout:4] 
Test loss: 0.033566167738372045
Test accuracy: 0.9932
[stdout:5] 
Test loss: 0.033566167738372045
Test accuracy: 0.9932
[stdout:6] 
Test loss: 0.033566167738372045
Test accuracy: 0.9932
[stdout:7] 
Test loss: 0.033566167738372045
Test accuracy: 0.9932
