# Distributed training of ATLAS RPV CNN Classifier

In this notebook, we extend the Train_rpv example to run distributed training
across multiple nodes using ipyparallel and Horovod.

* TODO: improve documentation.
* TODO: run on a full Haswell batch node.
* TODO: tinker with the TF thread settings.

In [1]:
# System imports
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

# External imports
import ipyparallel as ipp

## Connect to ipyparallel cluster

In [2]:
# Create the ipyparallel client used by the rest of the notebook to talk to
# the engines. No profile is passed, so the default profile is used for now.
c = ipp.Client()

In [3]:
# Show the IDs of the engines the client is connected to (sanity check that
# the expected number of workers is available).
c.ids

[0, 1, 2, 3]

## Setup the workers

In [4]:
%%px

from __future__ import print_function
from __future__ import division
import os
import socket

import h5py
import numpy as np

import keras
from keras import layers, models, callbacks
from keras import backend as K
import horovod.keras as hvd

#from sklearn import metrics

# Initialize horovod on every engine (this cell runs under %%px).
hvd.init()
# Report where each Horovod process landed: its global rank, its rank on the
# local host, and the hostname. Note that the ipyparallel engine ID shown in
# the [stdout:N] prefix is independent of the MPI rank printed here.
print('MPI rank %i, local rank %i, host %s' %
      (hvd.rank(), hvd.local_rank(), socket.gethostname()))

[stdout:0] MPI rank 3, local rank 0, host nid00027
[stdout:1] MPI rank 1, local rank 0, host nid00019
[stdout:2] MPI rank 2, local rank 0, host nid00020
[stdout:3] MPI rank 0, local rank 0, host nid00018


## Load the data

In [5]:
%%px

def load_file(file, n_samples):
    """Load the first `n_samples` events from an RPV image HDF5 file.

    Parameters
    ----------
    file : str
        Path to an HDF5 file containing an 'all_events' group with the
        datasets 'hist', 'y' and 'weight'.
    n_samples : int or None
        Number of leading events to read from each dataset (used as the
        upper bound of a slice, so None reads everything).

    Returns
    -------
    (data, labels, weights)
        `data` is the 'hist' slice with one extra trailing axis of length 1
        appended; `labels` and `weights` are the matching slices of 'y' and
        'weight'.
    """
    # Open explicitly read-only. This cell runs on every worker, so several
    # processes open the same input file at once; relying on the library's
    # default mode can request write access on older h5py versions, which is
    # neither needed nor safe for shared input data.
    with h5py.File(file, 'r') as f:
        data_group = f['all_events']
        # Slice first, then append the trailing axis of length 1.
        data = data_group['hist'][:n_samples][:, :, :, None]
        labels = data_group['y'][:n_samples]
        weights = data_group['weight'][:n_samples]
    return data, labels, weights

In [6]:
%%px

# Data config
# Number of leading events to read from each file. The commented-out values
# are presumably larger / full-dataset sizes used for real runs -- TODO confirm.
n_train = 1024 #64000 #412416
n_valid = 1024 #32000 #137471
n_test = 1024 #32000 #137471
# NOTE(review): hard-coded, user-specific scratch path; must be readable from
# every worker node.
input_dir = '/global/cscratch1/sd/sfarrell/atlas-rpv-images'

# Load the data files
# This cell runs on every engine and the slices are not rank-dependent, so
# each worker loads the same leading n_* events from each file.
train_file = os.path.join(input_dir, 'train.h5')
valid_file = os.path.join(input_dir, 'val.h5')
test_file = os.path.join(input_dir, 'test.h5')
train_input, train_labels, train_weights = load_file(train_file, n_train)
valid_input, valid_labels, valid_weights = load_file(valid_file, n_valid)
test_input, test_labels, test_weights = load_file(test_file, n_test)
# Sanity check: print the shape of each loaded image array.
print('train shape:', train_input.shape)
print('valid shape:', valid_input.shape)
print('test shape: ', test_input.shape)

[stdout:0] 
train shape: (1024, 64, 64, 1)
valid shape: (1024, 64, 64, 1)
test shape:  (1024, 64, 64, 1)
[stdout:1] 
train shape: (1024, 64, 64, 1)
valid shape: (1024, 64, 64, 1)
test shape:  (1024, 64, 64, 1)
[stdout:2] 
train shape: (1024, 64, 64, 1)
valid shape: (1024, 64, 64, 1)
test shape:  (1024, 64, 64, 1)
[stdout:3] 
train shape: (1024, 64, 64, 1)
valid shape: (1024, 64, 64, 1)
test shape:  (1024, 64, 64, 1)


## Build and train the model

In [7]:
%%px

def build_model(input_shape,
                h1=64, h2=128, h3=256, h4=256, h5=512,
                optimizer=keras.optimizers.Adam, lr=0.001):
    """Construct and compile the RPV CNN classifier for Horovod training.

    The network is four 3x3 ReLU convolutions with 'same' padding (strides
    1, 2, 1, 2 and h1..h4 filters), followed by a flatten, a ReLU dense
    layer of width h5 and a single sigmoid output unit.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, passed to `layers.Input`.
    h1, h2, h3, h4 : int
        Number of filters in the four convolutional layers, in order.
    h5 : int
        Width of the hidden dense layer.
    optimizer : callable
        Optimizer class/factory; called as `optimizer(lr=lr)`.
    lr : float
        Learning rate handed to the optimizer.

    Returns
    -------
    The compiled model named 'RPVClassifier', using binary cross-entropy
    loss, the accuracy metric, and the optimizer wrapped in
    `hvd.DistributedOptimizer`.
    """
    # (filters, stride) for each convolutional layer, in network order.
    conv_specs = ((h1, 1), (h2, 2), (h3, 1), (h4, 2))

    # Convolutional feature extractor
    inputs = layers.Input(shape=input_shape)
    features = inputs
    for n_filters, stride in conv_specs:
        conv = layers.Conv2D(n_filters, kernel_size=(3, 3), activation='relu',
                             strides=stride, padding='same')
        features = conv(features)

    # Dense classification head
    flat = layers.Flatten()(features)
    hidden = layers.Dense(h5, activation='relu')(flat)
    outputs = layers.Dense(1, activation='sigmoid')(hidden)

    # Wrap the base optimizer in Horovod's distributed optimizer
    distributed_opt = hvd.DistributedOptimizer(optimizer(lr=lr))

    # Assemble and compile
    model = models.Model(inputs, outputs, 'RPVClassifier')
    model.compile(loss='binary_crossentropy',
                  optimizer=distributed_opt,
                  metrics=['accuracy'])
    return model

In [8]:
%%px

# Model config
# Conv filter counts (h1-h4) and hidden dense width (h5) passed to build_model.
h1, h2, h3, h4, h5 = 64, 128, 256, 256, 512
optimizer = keras.optimizers.Adam
# Scale the base learning rate by the number of Horovod processes.
lr = 0.001 * hvd.size()

# Training config
batch_size = 64
n_epochs = 4
# When True, the training cell passes the per-event weights to model.fit
# (for both the training and validation samples).
use_weights = False

# Build the model
# The input shape is the per-sample shape of the training array (batch axis dropped).
model = build_model(train_input.shape[1:],
                    h1=h1, h2=h2, h3=h3, h4=h4, h5=h5,
                    optimizer=optimizer, lr=lr)
# Every worker builds the model; only Horovod rank 0 prints the summary to
# avoid duplicated output.
if hvd.rank() == 0:
    model.summary()

[stdout:3] 
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 64, 64, 1)         0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 64, 64, 64)        640       
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 32, 32, 128)       73856     
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 32, 32, 256)       295168    
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 16, 16, 256)       590080    
_________________________________________________________________
flatten_3 (Flatten)          (None, 65536)             0         
_________________________________________________________________
dense_5 (Dense)              (None, 512)               33554944 

In [9]:
%%px

# Train the model
# NOTE: the callback list is deliberately not named `callbacks`. That name is
# bound to the `keras.callbacks` module by the worker setup imports
# (`from keras import layers, models, callbacks`), and rebinding it to a list
# here would hide the module from every later cell on the workers.
train_callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
]
# Collect the fit() arguments in a dict so the optional sample weights can be
# added below without duplicating the call.
fit_args = dict(x=train_input, y=train_labels,
                batch_size=batch_size, epochs=n_epochs,
                validation_data=(valid_input, valid_labels),
                callbacks=train_callbacks,
                verbose=2)
if use_weights:
    # Apply the per-event weights to both the training and validation samples.
    fit_args.update(sample_weight=train_weights,
                    validation_data=(valid_input, valid_labels, valid_weights))
# Keep the returned History object; `history.history` is pulled back to the
# client in a later cell.
history = model.fit(**fit_args)

[stdout:0] 
Train on 1024 samples, validate on 1024 samples
Epoch 1/4
 - 13s - loss: 0.8862 - acc: 0.5684 - val_loss: 0.5119 - val_acc: 0.8359
Epoch 2/4
 - 13s - loss: 0.5145 - acc: 0.7910 - val_loss: 0.3830 - val_acc: 0.8984
Epoch 3/4
 - 12s - loss: 0.3330 - acc: 0.9062 - val_loss: 0.4009 - val_acc: 0.8838
Epoch 4/4
 - 12s - loss: 0.1937 - acc: 0.9248 - val_loss: 0.3835 - val_acc: 0.8936
[stdout:1] 
Train on 1024 samples, validate on 1024 samples
Epoch 1/4
 - 13s - loss: 0.9627 - acc: 0.5811 - val_loss: 0.5126 - val_acc: 0.8291
Epoch 2/4
 - 13s - loss: 0.5226 - acc: 0.8076 - val_loss: 0.3836 - val_acc: 0.8975
Epoch 3/4
 - 13s - loss: 0.3216 - acc: 0.9150 - val_loss: 0.3964 - val_acc: 0.8857
Epoch 4/4
 - 12s - loss: 0.1981 - acc: 0.9238 - val_loss: 0.3828 - val_acc: 0.8936
[stdout:2] 
Train on 1024 samples, validate on 1024 samples
Epoch 1/4
 - 13s - loss: 0.9700 - acc: 0.5869 - val_loss: 0.5123 - val_acc: 0.8291
Epoch 2/4
 - 13s - loss: 0.5172 - acc: 0.8232 - val_loss: 0.3836 - val_ac

In [10]:
# Can I get worker-local variables out?
# This cell runs on the client (no %%px): it asks every engine, through the
# direct view c[:], for the value of `history.history` from the training cell
# and collects the per-engine results in `histories`.
histories = c[:].get('history.history')

## Evaluate on the test set

In [11]:
%%px

# Evaluate the trained model on the test sample. This runs on every worker,
# each using its own copy of the model and of the test arrays.
score = model.evaluate(test_input, test_labels, verbose=2)
# The model was compiled with a single 'accuracy' metric, so score[0] is the
# loss and score[1] is the accuracy.
print('Test loss:', score[0])
print('Test accuracy:', score[1])

[stdout:0] 
Test loss: 0.48565021017566323
Test accuracy: 0.8720703125
[stdout:1] 
Test loss: 0.48467844212427735
Test accuracy: 0.873046875
[stdout:2] 
Test loss: 0.48526582261547446
Test accuracy: 0.873046875
[stdout:3] 
Test loss: 0.48574708541855216
Test accuracy: 0.873046875


In [12]:
%%px

# Inspect the model weights. Are they the same?
# Displays, on every worker, the [0, 0, 0] slice of the first weight array
# (one value per filter of the first conv layer) so the per-worker values can
# be compared side by side.
# NOTE(review): in the recorded output the workers agree only to about three
# decimal places, not exactly -- worth investigating.
model.get_weights()[0][0,0,0]

[0;31mOut[0:28]: [0m
array([ 0.0359285 ,  0.14070866, -0.03147585, -0.11712615,  0.13835126,
        0.04000098, -0.0848937 , -0.04932515,  0.19469474,  0.06563387,
       -0.08081567, -0.07108123,  0.00721872,  0.13346985,  0.09964314,
       -0.06333202,  0.01039529,  0.04151731,  0.19975036,  0.0468534 ,
        0.20597549,  0.05762145, -0.1233936 , -0.00195075, -0.0102873 ,
        0.16713418,  0.20150925, -0.07612296, -0.05693804,  0.10351744,
        0.13675688,  0.0919622 ,  0.05040056, -0.05369506,  0.15337764,
        0.00970458,  0.05666427, -0.06198144,  0.0622775 ,  0.06317747,
        0.06043039, -0.01114823, -0.07662099, -0.01320019, -0.10590444,
       -0.06452797,  0.12556276,  0.14275475,  0.09189197,  0.0967575 ,
       -0.03720548, -0.00921574,  0.080351  , -0.02739344,  0.00197626,
        0.12395329,  0.09318297, -0.07751956,  0.20624001, -0.04554509,
        0.17160976, -0.04968931,  0.08846218,  0.03056484], dtype=float32)

[0;31mOut[1:28]: [0m
array([ 0.03594369,  0.14102964, -0.03147377, -0.1175801 ,  0.13863148,
        0.03990909, -0.08479712, -0.04956196,  0.19500998,  0.06632189,
       -0.08059914, -0.07135448,  0.00669504,  0.13377933,  0.0999418 ,
       -0.06358718,  0.01047185,  0.04175984,  0.20004192,  0.04701871,
        0.20626186,  0.05782318, -0.12334074, -0.00207194, -0.01032731,
        0.16741832,  0.20178348, -0.07604172, -0.05747492,  0.10368998,
        0.1370194 ,  0.09211288,  0.05079716, -0.05365899,  0.15360837,
        0.00913636,  0.05697636, -0.06250375,  0.06246425,  0.06321386,
        0.06073523, -0.01094039, -0.07657057, -0.01327501, -0.1058809 ,
       -0.06461395,  0.12592152,  0.14299646,  0.09217875,  0.09654881,
       -0.03697995, -0.0090028 ,  0.08081379, -0.02792548,  0.00210617,
        0.12427051,  0.09343056, -0.07765009,  0.20651025, -0.045669  ,
        0.17190306, -0.05015682,  0.08824555,  0.0307135 ], dtype=float32)

[0;31mOut[2:28]: [0m
array([ 0.03595275,  0.14103054, -0.03144496, -0.11758562,  0.13865495,
        0.03993119, -0.08479556, -0.04956302,  0.19506116,  0.06634358,
       -0.08059914, -0.07135468,  0.00670963,  0.13381214,  0.09994429,
       -0.06357279,  0.01047078,  0.04178128,  0.20009483,  0.0470452 ,
        0.20631383,  0.05783153, -0.12335784, -0.00208892, -0.01031001,
        0.16744767,  0.20183524, -0.07604139, -0.05749213,  0.10373257,
        0.13704225,  0.0921422 ,  0.05081956, -0.05365925,  0.1536581 ,
        0.0091392 ,  0.05696901, -0.06249473,  0.06246135,  0.06322864,
        0.06072824, -0.01093547, -0.07657141, -0.01327415, -0.1058815 ,
       -0.06461382,  0.12594908,  0.14303385,  0.0921834 ,  0.09659387,
       -0.03698721, -0.00899186,  0.08083027, -0.0279142 ,  0.00210904,
        0.12429023,  0.09344826, -0.0776515 ,  0.20656236, -0.04568081,
        0.1719402 , -0.05016229,  0.08825307,  0.0307228 ], dtype=float32)

[0;31mOut[3:28]: [0m
array([ 0.03599238,  0.14103489, -0.03146977, -0.11759035,  0.13866036,
        0.03998362, -0.08479521, -0.0495626 ,  0.19506593,  0.06634719,
       -0.08060362, -0.07135422,  0.00675255,  0.1338333 ,  0.09994949,
       -0.06356803,  0.01047233,  0.04174615,  0.20010321,  0.04703918,
        0.20632108,  0.05789309, -0.12335753, -0.00208309, -0.01031724,
        0.16746059,  0.20184411, -0.07604092, -0.05749106,  0.10374303,
        0.13704681,  0.09214494,  0.05081026, -0.05365882,  0.15366276,
        0.00915343,  0.05697469, -0.06251232,  0.06246971,  0.06323055,
        0.06073349, -0.01093503, -0.076571  , -0.01327649, -0.10588112,
       -0.06461336,  0.12591305,  0.14304489,  0.09218711,  0.09658511,
       -0.03699038, -0.00901619,  0.0808327 , -0.02791135,  0.00210935,
        0.12430234,  0.09345107, -0.07765109,  0.2065683 , -0.04567585,
        0.17194937, -0.05014611,  0.08831865,  0.03070764], dtype=float32)