Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions distributed/rpc/batch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Batch RPC Server Example

This example shows how to create a batch RPC server using the `torch.distributed.rpc`
package, where multiple clients use RPC to run functions on the server and the
server processes multiple RPC requests together in a batch.

To run an example with three workers, execute the following three commands
to create three processes.


```
python server.py --name="s" --rank=0 --world_size=3
python client.py --name="c1" --rank=1 --world_size=3 --server_name="s"
python client.py --name="c2" --rank=2 --world_size=3 --server_name="s"
```
74 changes: 74 additions & 0 deletions distributed/rpc/batch/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import argparse
import threading
import concurrent.futures as futures

import torch.distributed.rpc as rpc
from torch.distributed.rpc import rpc_sync, rpc_async, remote

_server = None

def _run_on_server(name, *args, **kwargs):
return _server.call(name, *args, **kwargs)


def parse_args():
parser = argparse.ArgumentParser(
description="Batch RPC example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument('--name')
parser.add_argument('--rank', type=int)
parser.add_argument('--world_size', type=int)
parser.add_argument('--server_name')
return parser.parse_args()


class BatchServer:
def __init__(self, batch_size=2):
self._batch_size = batch_size
self._fns = {}
self._inps = {}
self._outs = {}
self.lock = threading.Lock()
global _server
_server = self

def bind(self, fn):
name = fn.__name__
self._fns[name] = fn
self._inps[name] = []
self._outs[name] = futures.Future()

def call(self, fn_name, *args, **kwargs):
with self.lock:
inps = self._inps[fn_name]
fut = self._outs[fn_name]
idx = len(inps)
inps.append((args, kwargs))
if idx + 1 >= self._batch_size:
self._inps[fn_name] = []
self._outs[fn_name] = futures.Future()

if idx + 1 >= self._batch_size:
rets = []
for arg, kwargs in inps:
rets.append(self._fns[fn_name](*arg, **kwargs))
fut.set_result(rets)

return fut.result()[idx]


class BatchClient:
def __init__(self, server_name):
self._server_info = rpc.get_worker_info(worker_name=server_name)

def __getattr__(self, name):
def fn(*args, **kwargs):
return rpc_sync(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this be clearer if we change this to rpc_async? then a single client can eventually trigger a batch, and we could increase the batch size. I suspect this example was created before async was available?

self._server_info,
_run_on_server,
args=(name, *args),
kwargs=kwargs
)
return fn
16 changes: 16 additions & 0 deletions distributed/rpc/batch/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os

import torch.distributed.rpc as rpc
from batch import BatchClient, parse_args


if __name__ == "__main__":
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
args = parse_args()
rpc.init_rpc(args.name, rank=args.rank, world_size=args.world_size)
client = BatchClient(args.server_name)
y = args.rank * 100
for x in range(5):
print("Client {} got result {}".format(args.name, client.foo(x, y)))
rpc.shutdown()
19 changes: 19 additions & 0 deletions distributed/rpc/batch/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os

import torch.distributed.rpc as rpc
from batch import BatchServer, parse_args



def foo(x, y):
return x + y


if __name__ == "__main__":
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
args = parse_args()
rpc.init_rpc(args.name, rank=args.rank, world_size=args.world_size)
server = BatchServer()
server.bind(foo)
rpc.shutdown()