Update pipeline parallel example to use RRef helpers
mrshenli committed Jul 1, 2020
1 parent 13acec6 commit 6cf0bdf
Showing 3 changed files with 45 additions and 77 deletions.
8 changes: 7 additions & 1 deletion distributed/rpc/pipeline/README.md
@@ -1,7 +1,13 @@
Distributed Pipeline Parallel Example

This example shows how to distribute a ResNet50 model on two RPC workers and
then implement distributed pipeline parallelism using RPC.
then implement distributed pipeline parallelism using RPC. With pipeline
parallelism, every input batch is divided into micro-batches and these
micro-batches are fed into the model in a pipelined fashion to increase the
amortized device utilization. Note that this example only parallelizes the
forward pass which can be viewed as the distributed counterpart of the
[single machine pipeline parallel](https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html#speed-up-by-pipelining-inputs)
example.

```
pip install -r requirements.txt
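The paragraph added to the README describes splitting every input batch into micro-batches that move through the two pipeline stages concurrently. A rough, self-contained illustration of that idea (not part of the commit; the two thread pools merely stand in for the two RPC workers, and `stage1`/`stage2` are hypothetical placeholders for the model shards):

```python
from concurrent.futures import ThreadPoolExecutor

import torch


def stage1(x):
    # placeholder for the first model shard
    return x * 2


def stage2(y):
    # placeholder for the second model shard
    return y + 1


batch = torch.randn(8, 4)
split_size = 2

# one single-threaded executor per "worker", so stage1 of micro-batch i+1
# can overlap with stage2 of micro-batch i
worker1 = ThreadPoolExecutor(max_workers=1)
worker2 = ThreadPoolExecutor(max_workers=1)

futures = []
for micro_batch in batch.split(split_size, dim=0):
    f1 = worker1.submit(stage1, micro_batch)
    # chain stage2 onto stage1's result without blocking the submission loop
    futures.append(worker2.submit(lambda f=f1: stage2(f.result())))

out = torch.cat([f.result() for f in futures])
print(out.shape)  # torch.Size([8, 4])
```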
110 changes: 36 additions & 74 deletions distributed/rpc/pipeline/main.py
@@ -15,59 +15,15 @@
from torchvision.models.resnet import Bottleneck


#########################################################
# helper functions #
#########################################################


def _call_method(method, rref, *args, **kwargs):
r"""
a helper function to call a method on the given RRef
"""
return method(rref.local_value(), *args, **kwargs)


def _remote_on_rref(method, rref, *args, **kwargs):
r"""
a helper function to run method on the owner of rref and return an RRef
of the result.
"""
return rpc.remote(
rref.owner(),
_call_method,
args=[method, rref] + list(args),
kwargs=kwargs
)


def _async_on_rref(method, rref, *args, **kwargs):
r"""
a helper function to run method on the owner of rref and fetch back the
result using RPC
"""
return rpc.rpc_async(
rref.owner(),
_call_method,
args=[method, rref] + list(args),
kwargs=kwargs
)


def _parameter_rrefs(module):
r"""
Create one RRef for each parameter in the given local module, and return a
list of RRefs.
"""
param_rrefs = []
for param in module.parameters():
param_rrefs.append(RRef(param))
return param_rrefs


#########################################################
# Define Model Parallel ResNet50 #
#########################################################

# In order to split the ResNet50 and place it on two different workers, we
# implement it in two model shards. The ResNetBase class defines common
# attributes and methods shared by two shards. ResNetShard1 and ResNetShard2
# contain two partitions of the model layers respectively.


num_classes = 1000
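The comment above explains that the model is split into two shard classes that share a common base. A toy, non-RPC sketch of that structure (not part of the commit; `TinyBase`, `TinyShard1`, and `TinyShard2` are made-up stand-ins for `ResNetBase` and the two ResNet shards):

```python
import torch
import torch.nn as nn


class TinyBase(nn.Module):
    # attributes shared by both shards
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)


class TinyShard1(TinyBase):
    # first slice of the network
    def forward(self, x):
        return torch.relu(self.fc(x))


class TinyShard2(TinyBase):
    # second slice of the network
    def forward(self, x):
        return self.fc(x)


shard1, shard2 = TinyShard1(8, 16), TinyShard2(16, 4)
out = shard2(shard1(torch.randn(2, 8)))  # chained just like the two RPC shards
print(out.shape)  # torch.Size([2, 4])
```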

@@ -76,9 +32,8 @@ def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNetBase(nn.Module):
def __init__(self, block, inplanes, num_classes=1000,
def __init__(self, block, inplanes, num_classes=1000,
groups=1, width_per_group=64, norm_layer=None):
super(ResNetBase, self).__init__()

@@ -111,13 +66,20 @@ def _make_layer(self, planes, blocks, stride=1):

return nn.Sequential(*layers)

def parameter_rrefs(self):
r"""
Create one RRef for each parameter in the given local module, and return a
list of RRefs.
"""
return [RRef(p) for p in self.parameters()]


class ResNetPart1(ResNetBase):
class ResNetShard1(ResNetBase):
"""
The first part of ResNet.
"""
def __init__(self, device, *args, **kwargs):
super(ResNetPart1, self).__init__(
super(ResNetShard1, self).__init__(
Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

self.device = device
@@ -144,12 +106,12 @@ def forward(self, x_rref):
return out.cpu()


class ResNetPart2(ResNetBase):
class ResNetShard2(ResNetBase):
"""
The second part of ResNet.
"""
def __init__(self, device, *args, **kwargs):
super(ResNetPart2, self).__init__(
super(ResNetShard2, self).__init__(
Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

self.device = device
@@ -180,15 +142,15 @@ def __init__(self, split_size, workers, *args, **kwargs):
# Put the first part of the ResNet50 on workers[0]
self.p1_rref = rpc.remote(
workers[0],
ResNetPart1,
ResNetShard1,
args = ("cuda:0",) + args,
kwargs = kwargs
)

# Put the second part of the ResNet50 on workers[1]
self.p2_rref = rpc.remote(
workers[1],
ResNetPart2,
ResNetShard2,
args = ("cuda:1",) + args,
kwargs = kwargs
)
@@ -199,22 +161,19 @@ def forward(self, xs):
out_futures = []
for x in iter(xs.split(self.split_size, dim=0)):
x_rref = RRef(x)
y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
y_rref = self.p1_rref.remote().forward(x_rref)
z_fut = self.p2_rref.rpc_async().forward(y_rref)
out_futures.append(z_fut)

# wait for all RPC to finish
outs = [fut.wait() for fut in out_futures]
# cat all tensors into one tensor.
out = torch.cat(outs)
return out

# collect and cat all output tensors into one tensor.
return torch.cat(torch.futures.wait_all(out_futures))

def parameter_rrefs(self):
remote_params = []
remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
return remote_params


#########################################################
# Run RPC Processes #
@@ -248,6 +207,9 @@ def run_master(split_size):
labels = torch.zeros(batch_size, num_classes) \
.scatter_(1, one_hot_indices, 1)

# The distributed autograd context is the dedicated scope for the
# distributed backward pass to store gradients, which can later be
# retrieved using the context_id by the distributed optimizer.
with dist_autograd.context() as context_id:
outputs = model(inputs)
dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
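The new comment above summarizes the role of the distributed autograd context. As a hedged sketch of how that context ties the forward pass, the distributed backward pass, and the distributed optimizer together (this mirrors the surrounding example but is not part of the diff; `model`, `inputs`, and `labels` are assumed to exist, and `model.parameter_rrefs()` is the method shown earlier):

```python
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer


def train_step(model, inputs, labels):
    loss_fn = nn.MSELoss()
    # the distributed optimizer holds RRefs to parameters living on all workers
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )
    with dist_autograd.context() as context_id:
        outputs = model(inputs)
        # gradients are accumulated inside this context, not in param.grad
        dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
        # the optimizer retrieves those gradients via the same context_id
        opt.step(context_id)
```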
@@ -261,17 +223,17 @@ def run_worker(rank, world_size, num_split):

if rank == 0:
rpc.init_rpc(
"master",
rank=rank,
world_size=world_size,
"master",
rank=rank,
world_size=world_size,
rpc_backend_options=options
)
run_master(num_split)
else:
rpc.init_rpc(
f"worker{rank}",
rank=rank,
world_size=world_size,
f"worker{rank}",
rank=rank,
world_size=world_size,
rpc_backend_options=options
)
pass
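The substantive change in `main.py` is visible in `forward` and `parameter_rrefs` above: the hand-rolled `_call_method`, `_remote_on_rref`, and `_async_on_rref` helpers are replaced with the RRef helper proxies available from torch 1.6, where `rref.remote()` runs a method on the RRef's owner and returns an RRef of the result, and `rref.rpc_async()` does the same but returns a Future. A minimal two-process sketch of that API (not part of the commit; `Doubler` and the worker names are illustrative):

```python
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


class Doubler:
    def forward(self, x):
        return x * 2


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    if rank == 0:
        # create a Doubler instance on worker1 and get back an RRef to it
        d_rref = rpc.remote("worker1", Doubler)

        # equivalent of the removed _remote_on_rref helper:
        # run forward on the owner and return an RRef of the result
        y_rref = d_rref.remote().forward(torch.ones(2))

        # equivalent of the removed _async_on_rref helper:
        # run forward on the owner and return a Future of the result
        z_fut = d_rref.rpc_async().forward(torch.ones(2))

        print(y_rref.to_here())  # tensor([2., 2.])
        print(z_fut.wait())      # tensor([2., 2.])

    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)
```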
4 changes: 2 additions & 2 deletions distributed/rpc/pipeline/requirements.txt
@@ -1,2 +1,2 @@
torch==1.5.0
torchvision==0.6.0
torch==1.6.0
torchvision==0.7.0
