Skip to content

Commit

Permalink
Fix/dataparallel (#618)
Browse files Browse the repository at this point in the history
* Fix some examples

* Add output to amsgrad

* Stash

* Add tests for exceptions

* Fix test pass forward

* Add initial ddp note

* Add apex closure and example

* Tests

* Update changelog
  • Loading branch information
MattPainter01 committed Aug 5, 2019
1 parent 499ea87 commit 1d30401
Show file tree
Hide file tree
Showing 9 changed files with 873 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
### Fixed
- Fixed bug where aggregate predictions couldn't handle empty list
- Fixed a bug where Runtime Errors on forward weren't handled properly

## [0.4.0] - 2019-07-05
### Added
Expand Down
124 changes: 124 additions & 0 deletions docs/_static/examples/distributed_data_parallel.py
@@ -0,0 +1,124 @@
import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import sys
from torch.nn.parallel import DistributedDataParallel as DDP
import torchbearer
import platform
from torchvision import datasets, transforms
import argparse


parser = argparse.ArgumentParser(description='Torchbearer Distributed Data Parallel MNIST')
parser.add_argument('--master-addr', '--master', '--host', '-m', dest='master', help='Address of master node')
parser.add_argument('--rank', '-r', dest='rank', help='Rank of this process')
parser.add_argument('--world-size', dest='world_size', default=2, help='World size')
args = parser.parse_args()


def setup():
os.environ['MASTER_ADDR'] = args.master
os.environ['MASTER_PORT'] = '29500'

# initialize the process group
dist.init_process_group("gloo", rank=args.rank, world_size=args.world_size)

# Explicitly setting seed makes sure that models created in two processes
# start from same random weights and biases. Alternatively, sync models
# on start with the callback below.
#torch.manual_seed(42)


def cleanup():
dist.destroy_process_group()


class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(784, 100)
self.relu = nn.ReLU()
self.net2 = nn.Linear(100, 10)

def forward(self, x):
return self.net2(self.relu(self.net1(x)))


def sync_model(model):
size = float(dist.get_world_size())
for param in model.parameters():
dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
param.data /= size


def average_gradients(model):
size = float(dist.get_world_size())
for param in model.parameters():
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size


@torchbearer.callbacks.on_init
def sync(state):
sync_model(state[torchbearer.MODEL])


@torchbearer.callbacks.on_backward
def grad(state):
average_gradients(state[torchbearer.MODEL])


@torchbearer.callbacks.on_sample
def flatten(state):
state[torchbearer.X] = state[torchbearer.X].view(state[torchbearer.X].shape[0], -1)


def worker():
setup()
print("Rank and node: {}-{}".format(args.rank, platform.node()))

model = ToyModel().to('cpu')
ddp_model = DDP(model)

kwargs = {}

ds = datasets.MNIST('./data/mnist/', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))

train_sampler = torch.utils.data.distributed.DistributedSampler(ds)
train_loader = torch.utils.data.DataLoader(ds,
batch_size=128, sampler=train_sampler, **kwargs)

test_ds = datasets.MNIST('./data/mnist', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
test_sampler = torch.utils.data.distributed.DistributedSampler(test_ds)
test_loader = torch.utils.data.DataLoader(test_ds,
batch_size=128, sampler=test_sampler, **kwargs)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

trial = torchbearer.Trial(ddp_model, optimizer, loss_fn, metrics=['loss', 'acc'],
callbacks=[sync, grad, flatten])
trial.with_train_generator(train_loader)
trial.run(10, verbose=2)

print("Model hash: {}".format(hash(model)))
print('First parameter: {}'.format(next(model.parameters())))

cleanup()


if __name__ == "__main__":
worker()
print('done')
323 changes: 323 additions & 0 deletions docs/_static/notebooks/apex_torchbearer.ipynb

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions docs/examples/notebooks.rst
Expand Up @@ -68,6 +68,13 @@ General

|nbviewer| `Preview <https://nbviewer.jupyter.org/github/pytorchbearer/torchbearer/blob/master/docs/_static/notebooks/pycm.ipynb>`__   :download:`Download Notebook </_static/notebooks/pycm.ipynb>`   |colab| `Run on Colab <https://colab.research.google.com/github/pytorchbearer/torchbearer/blob/master/docs/_static/notebooks/pycm.ipynb>`__


- **Nvidia Apex with Torchbearer**:

This guide shows how we can do half and mixed precision training in torchbearer.

|nbviewer| `Preview <https://nbviewer.jupyter.org/github/pytorchbearer/torchbearer/blob/master/docs/_static/notebooks/apex_torchbearer.ipynb>`__   :download:`Download Notebook </_static/notebooks/apex_torchbearer.ipynb>`   |colab| `Run on Colab <https://colab.research.google.com/github/pytorchbearer/torchbearer/blob/master/docs/_static/notebooks/apex_torchbearer.ipynb>`__

Deep Learning
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
79 changes: 79 additions & 0 deletions docs/notes/distributed.rst
@@ -0,0 +1,79 @@
Using DistributedDataParallel with Torchbearer on CPU
=====================================================

This note will quickly cover how we can use torchbearer to train over multiple nodes.
We shall do this by training a simple model to classify and for a massive amount of overkill we will be doing this on MNIST.
Most of the code for this example is based off the
`Distributed Data Parallel (DDP) tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__ and the
`imagenet example <https://github.com/pytorch/examples/blob/master/imagenet/main.py>`__
from the PyTorch docs.
We recommend you read at least the DDP tutorial before continuing with this note.

Setup, Cleanup and Model
------------------------------------
We keep similar setup, cleanup and model from the DDP tutorial. All that is changed is taking rank, world size and master
address from terminal arguments and changing the model to apply to MNIST.
Note that we are keeping to the GLOO backend since this part of the note will be purely on the CPU.

.. literalinclude:: /_static/examples/distributed_data_parallel.py
:lines: 23-48



Sync Methods
------------------------------------
Since we are working across multiple machines we need a way to synchronise the model itself and its gradients. To do this
we utilise methods similar to that of the `distributed applications tutorial <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
from PyTorch.

.. literalinclude:: /_static/examples/distributed_data_parallel.py
:lines: 51-62

Since we require the gradients to be synced every step we implement both of these methods as Torchbearer callbacks.
We sync the model itself on init and sync the gradients every step after the backward call.

.. literalinclude:: /_static/examples/distributed_data_parallel.py
:lines: 65-72


Worker Function
------------------------------------
Now we need to define the main worker function that each process will be running. We need this to setup the environment,
actually run the training process and cleanup the environment after we finish.
This function outside of calling `setup` and `cleanup` is exactly the same as any Torchbearer training function.

.. literalinclude:: /_static/examples/distributed_data_parallel.py
:lines: 80-119

You might have noticed that we had an extra flatten callback in the Trial, the only purpose of this was to flatten each image.

.. literalinclude:: /_static/examples/distributed_data_parallel.py
:lines: 75-77


Running
------------------------------------
All we need to do now is write a `__main__` function to run the worker function.

.. literalinclude:: /_static/examples/distributed_data_parallel.py
:lines: 122-124

We can then ssh into each node on which we want to run the training and run the following code replacing i with the rank of each process.

.. highlight:: bash

.. code:: bash
python distributed_data_parallel.py --world-size 2 --rank i --host (host address)
Running on machines with GPUs
------------------------------------
Coming soon.


Source Code
------------------------------------

The source code for this example is given below:

:download:`Download Python source code: distributed_data_parallel.py </_static/examples/distributed_data_parallel.py>`

0 comments on commit 1d30401

Please sign in to comment.