Parle: parallelizing stochastic gradient descent
This is the code for Parle: parallelizing stochastic gradient descent. We demonstrate an algorithm for parallel training of deep neural networks that trains multiple copies of the same network, called "replicas", in parallel, with a special coupling between their weights. This yields significantly better generalization than a single network, as well as 2-5x faster convergence than a data-parallel implementation of SGD for a single network.
High-performance multi-GPU version coming soon.
We have two versions, both of which are written using PyTorch:
- A parallel version that uses MPI (mpi4py) for synchronizing weights.
- A more efficient version that can be executed on a single computer with multiple GPUs. The synchronization of weights is done explicitly here using inter-GPU messages.
In both cases, we construct an optimizer class that initializes the requisite buffers on the different GPUs and handles all the updates after each mini-batch. As examples, we provide code for the MNIST and CIFAR-10 datasets with two prototypical networks, LeNet and All-CNN, respectively. The MNIST and CIFAR-10/100 datasets are downloaded and pre-processed (and stored in the proc folder) the first time parle is run.
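The optimizer structure described above can be sketched on a toy problem. The class below is hypothetical (ToyParle is not a name from this repository) and runs in a single process. It keeps a master weight, per-replica weights and per-replica checkpoints, updates every replica after each mini-batch, and synchronizes every L mini-batches. Scalar weights stand in for network parameters. The 1/gamma and 1/rho coupling strengths are assumptions of this sketch, not the repository's exact update. gamma, rho and L are the hyper-parameters described in the instructions.

```python
import random


class ToyParle:
    """Hypothetical toy Parle optimizer on scalar weights (not the repository's class)."""

    def __init__(self, n, lr, gamma, rho, L, init=0.0):
        self.lr, self.gamma, self.rho, self.L = lr, gamma, rho, L
        # Buffers: master weight, per-replica weights, per-replica checkpoints
        self.master = init
        self.weights = [init] * n
        self.checkpoints = [init] * n
        self.k = 0  # number of mini-batches processed so far

    def step(self, grads):
        """One mini-batch update; grads[a] is the gradient seen by replica a."""
        for a, g in enumerate(grads):
            w, c = self.weights[a], self.checkpoints[a]
            # Gradient step plus a proximal pull towards the last checkpoint
            # (assumed strength 1/gamma, with gamma acting as a step-size).
            self.weights[a] = w - self.lr * (g + (w - c) / self.gamma)
        self.k += 1
        if self.k % self.L == 0:
            self.synchronize()

    def synchronize(self):
        # Elastic pull of each replica towards the master (assumed strength 1/rho).
        self.weights = [w - self.lr * (w - self.master) / self.rho for w in self.weights]
        # The master becomes the average of the replicas; checkpoints are reset here.
        self.master = sum(self.weights) / len(self.weights)
        self.checkpoints = list(self.weights)


rng = random.Random(0)
opt = ToyParle(n=3, lr=0.1, gamma=1.0, rho=1.0, L=25)
for _ in range(500):
    # Noisy gradients of the toy loss f(w) = 0.5 * (w - 3)^2, one per replica
    grads = [(w - 3.0) + rng.gauss(0.0, 0.5) for w in opt.weights]
    opt.step(grads)
print(round(opt.master, 2))
```

The master should end up close to the minimizer w = 3 of the toy loss. Setting gamma=0 or rho=0 is not supported by this sketch's 1/gamma and 1/rho scalings.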
Instructions for running the code
The MPI version works well for small experiments and prototyping, while the second version is a good alternative for larger networks, e.g., the wide residual networks used in the paper.
Parle is robust to the choice of hyper-parameters. A description of some of the parameters, and the intuition behind them, follows.
- the learning rate lr is set to be the same as for SGD, along with the same drop schedule. It is advisable to first train with SGD for a few epochs and then use the same lr for Parle.
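A minimal sketch of a shared step-drop schedule, so that the SGD baseline and Parle use the same learning rate at every epoch. The function name step_lr, the drop epochs and the drop factor are hypothetical placeholders, not values taken from this repository.

```python
def step_lr(epoch, base_lr=0.1, drop_epochs=(3, 6), factor=0.2):
    """Hypothetical step-drop schedule: multiply base_lr by `factor` at each drop epoch."""
    drops = sum(1 for e in drop_epochs if epoch >= e)
    return base_lr * factor ** drops


# The same function drives both runs, so the schedules match epoch by epoch.
sgd_lrs = [step_lr(e) for e in range(8)]
parle_lrs = [step_lr(e) for e in range(8)]
print(parle_lrs)
```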
gamma controls how far successive gradient updates on each replica are allowed to go from the previous checkpoint, i.e., the last instant when the weights were synchronized with the master. This is the same as the step-size in proximal point iteration.
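To see why gamma acts as a step-size, consider one proximal point step x_{k+1} = argmin_y f(y) + (1/(2 gamma)) |y - x_k|^2 on a toy quadratic f(y) = (a/2)(y - c)^2. For this quadratic the minimizer has the closed form (a*gamma*c + x_k)/(a*gamma + 1). The quadratic and its constants are assumptions chosen for illustration. The sketch only shows that a larger gamma lets the iterate travel farther from the previous point x_k.

```python
def prox_step(x, gamma, a=1.0, c=5.0):
    """One proximal point step on f(y) = a/2 * (y - c)**2 starting from x.

    Solves argmin_y a/2 (y - c)^2 + 1/(2 gamma) (y - x)^2 in closed form.
    """
    return (a * gamma * c + x) / (a * gamma + 1.0)


x0 = 0.0
for gamma in (0.1, 1.0, 10.0, 100.0):
    x1 = prox_step(x0, gamma)
    print(f"gamma={gamma:6.1f}  distance moved={abs(x1 - x0):.3f}")
```

As gamma grows, the distance moved approaches |c - x0| = 5, i.e., the step lands close to the unconstrained minimizer.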
rho controls how far each replica moves from the master. The weights of the master are the average of the weights of all the replicas, while each replica is pulled towards this average with a force whose strength is set by rho.
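A minimal sketch of the master/replica relationship described above: the master is the element-wise average of the replica weights, and each replica takes a step towards it. The coefficient coupling is a hypothetical stand-in. How exactly it is derived from rho in the Parle code is not specified here, so treat that mapping as an assumption.

```python
def master_average(replicas):
    """Element-wise average of a list of weight vectors (lists of floats)."""
    n = len(replicas)
    return [sum(ws) / n for ws in zip(*replicas)]


def pull_towards(weights, master, lr, coupling):
    """Move one replica a fraction lr * coupling of the way towards the master."""
    return [w - lr * coupling * (w - m) for w, m in zip(weights, master)]


replicas = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]
master = master_average(replicas)  # [2.0, 2.0]
replicas = [pull_towards(r, master, lr=0.1, coupling=5.0) for r in replicas]
print(replicas)
```

Pulling every replica towards the average shrinks their spread but leaves the average itself unchanged.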
L is the number of gradient updates performed on each replica (worker) before synchronizing its weights with the master. You can safely fix this to 25. Alternatively, you can set L = gamma x lr, which has the advantage of being slightly faster towards the end of training.
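The two options for L can be written as a small helper. sync_interval is a hypothetical name. The rounding to an integer and the lower bound of 1 are assumptions of this sketch, since the text only states L = gamma x lr.

```python
def sync_interval(gamma=None, lr=None, fixed=25):
    """Number of gradient updates per synchronization.

    With no gamma/lr given, fall back to the fixed value of 25 suggested above;
    otherwise use L = gamma * lr, rounded and kept at least 1 (assumption).
    """
    if gamma is None or lr is None:
        return fixed
    return max(1, round(gamma * lr))


print(sync_interval())                   # fixed choice
print(sync_interval(gamma=100, lr=0.1))  # L = gamma x lr
```

Because gamma decays under the default schedule and lr is dropped as well, L = gamma x lr shrinks as training proceeds.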
- Proximal point iteration is insensitive to both gamma and rho, and the code uses default decaying schedules for these, which should typically work. In particular, we set gamma = rho = 100*(1 - 1/(2 nb))^(k/L), where nb is the number of mini-batches per epoch, k is the current iteration number and L is the number of weight updates per synchronization, as above.
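A sketch of the default decaying schedule, reading the formula above as gamma = rho = 100*(1 - 1/(2 nb))^(k/L). The values nb = 391 (CIFAR-10's 50,000 training images at an assumed batch size of 128) and L = 25 are example values only.

```python
def coupling_schedule(k, nb, L, scale=100.0):
    """gamma = rho = scale * (1 - 1/(2 nb))^(k/L) at iteration k."""
    return scale * (1.0 - 1.0 / (2.0 * nb)) ** (k / L)


nb, L = 391, 25
for epoch in (0, 1, 5, 10):
    k = epoch * nb
    print(f"epoch {epoch:2d}: gamma = rho = {coupling_schedule(k, nb, L):.2f}")
```

At these example values the decay is gentle: after one epoch gamma and rho are still roughly 98.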
n is the number of replicas. The code distributes these replicas across all available GPUs. For the MPI version, this is controlled by MPI.RANK. In general, the larger n is, the better Parle works. Each replica can itself be data-parallel using multiple GPUs.
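One simple way to spread n replicas over the available GPUs is round-robin assignment, sketched below. This policy and the function name assign_devices are assumptions for illustration; the repository may place replicas differently. With PyTorch the GPU count would typically come from torch.cuda.device_count(). Here it is passed in so the sketch runs without a GPU.

```python
def assign_devices(n, num_gpus):
    """Round-robin map of replica index -> device string (hypothetical policy)."""
    if num_gpus == 0:
        return ["cpu"] * n
    return [f"cuda:{a % num_gpus}" for a in range(n)]


print(assign_devices(5, 2))  # ['cuda:0', 'cuda:1', 'cuda:0', 'cuda:1', 'cuda:0']
```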
The number of epochs B for Parle is typically much smaller than for SGD; 5-10 epochs are sufficient to train on MNIST or CIFAR-10/100.
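A quick count of how much communication a short Parle run involves. With nb mini-batches per epoch and L updates per synchronization, each replica synchronizes with the master about nb/L times per epoch. The numbers below (nb = 391 as for CIFAR-10 at an assumed batch size of 128, L = 25, 10 epochs) are illustrative.

```python
def num_synchronizations(epochs, nb, L):
    """Total master synchronizations over a run (a trailing partial round is not counted)."""
    return (epochs * nb) // L


print(num_synchronizations(epochs=10, nb=391, L=25))  # 3910 // 25 = 156
```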
Run python parle_mpi.py -h to get a list of all arguments and their defaults. You can train LeNet on MNIST with 3 replicas using
python parle_mpi.py -n 3
You can train All-CNN on CIFAR-10 with 3 replicas using
python parle_mpi.py -n 3 -m allcnn
You can run the MPI version with 12 replicas as
mpirun -n 12 python parle_mpi.py
Setting n=1, L=1, gamma=0, rho=0 makes Parle equivalent to SGD; the implementation here uses Nesterov's momentum.
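With n=1, L=1, gamma=0 and rho=0 only the plain gradient step is left. The sketch below shows SGD with Nesterov momentum in the form used by PyTorch's torch.optim.SGD(..., nesterov=True) with zero dampening and no weight decay. It uses a hypothetical scalar loss f(w) = w^2 / 2 and is an illustration, not the repository's code.

```python
def sgd_nesterov(grad_fn, w, lr=0.1, momentum=0.9, steps=100):
    """Scalar SGD with Nesterov momentum (PyTorch-style, dampening = 0)."""
    buf = 0.0
    for _ in range(steps):
        g = grad_fn(w)
        buf = momentum * buf + g  # momentum buffer
        d = g + momentum * buf    # Nesterov look-ahead direction
        w = w - lr * d
    return w


# f(w) = w**2 / 2 has gradient w and its minimum at w = 0
w_final = sgd_nesterov(lambda w: w, w=5.0)
print(w_final)
```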
Setting n=1, rho=0 decouples the replica from the master. In this case, Parle is equivalent to Entropy-SGD (from "Entropy-SGD: biasing gradient descent into wide valleys"); see the code for the latter here.
Setting L=1, gamma=0 makes Parle equivalent to Elastic-SGD; the code for the latter by the original authors is here. Parle, however, uses an annealing schedule on rho, which makes it faster and lets it generalize better than vanilla Elastic-SGD.