Skip to content

Commit

Permalink
Simplify send buffer logic
Browse files Browse the repository at this point in the history
  • Loading branch information
parantapa committed Oct 25, 2019
1 parent 2a7b65f commit bf88475
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 69 deletions.
3 changes: 2 additions & 1 deletion test_xactor/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ def main(self):
for node in xa.get_nodes():
for rank in xa.get_node_ranks(node):
greeter_id = "greeter-%d" % rank
msg = xa.Message(greeter_id, "greet", "World")
msg = xa.Message(greeter_id, "greet", "world")
xa.send(rank, msg, flush=False)
xa.flush()


xa.stop()

def test_greeter():
Expand Down
134 changes: 71 additions & 63 deletions xactor/mpi_acomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,108 +7,115 @@
from unittest.mock import sentinel

from mpi4py import MPI
from more_itertools import chunked

COMM_WORLD = MPI.COMM_WORLD
WORLD_RANK = COMM_WORLD.Get_rank()
WORLD_SIZE = COMM_WORLD.Get_size()

RECV_BUFFER_SIZE = 67108864 # 64MB
BUFFER_SIZE = 4194304 # 4MB
HEADER_SIZE = 4096

SEND_BUFFER_SIZE = 100
PENDING_SEND_CLEANUP_AFTER = WORLD_SIZE
MAX_MESSAGE_SIZE = BUFFER_SIZE - HEADER_SIZE

log = logging.getLogger("%s.%d" % (__name__, WORLD_RANK))
LOG = logging.getLogger("%s.%d" % (__name__, WORLD_RANK))

# Message used to stop AsyncReceiver
StopAsyncReceiver = sentinel.StopAsyncReceiver


def pickle_dumps_adaptive(objs):
"""Make sure that the dumped pickles are within BUFFER_SIZE."""
n = len(objs)
assert n >= 1
class MessageBuffer:
"""Buffer for pickling objects."""

while True:
pkls = [pickle.dumps(group) for group in chunked(objs, n)]
max_size = max(map(len, pkls))
if max_size < RECV_BUFFER_SIZE:
return pkls
def __init__(self):
self.buffer = []
self.total_size = 0

def append(self, msg):
"""Add an object to the buffer."""
pkl = pickle.dumps(msg, pickle.HIGHEST_PROTOCOL)
pkl_len = len(pkl)
if pkl_len > MAX_MESSAGE_SIZE:
raise ValueError("Message too large %d > %d" % (pkl_len, MAX_MESSAGE_SIZE))

if self.total_size + pkl_len < MAX_MESSAGE_SIZE:
self.buffer.append(pkl)
self.total_size += pkl_len
return None

imsgs = self.buffer
self.buffer = [pkl]
self.total_size = pkl_len
return imsgs

log.warning("%d objects don't fit into %d bytes", n, RECV_BUFFER_SIZE)
if n == 1:
raise RuntimeError("Cant fit data into buffer")
def flush(self):
"""Flush out the current buffer."""
if not self.buffer:
return None

n = int(n / 2)
imsgs = self.buffer
self.buffer = []
self.total_size = 0
return imsgs


class AsyncSender:
"""Manager for sending messages."""

def __init__(self):
"""Initialize."""
self._buffer = {rank: [] for rank in range(WORLD_SIZE)}
self._pending_sends = []
self.buffer = {rank: MessageBuffer() for rank in range(WORLD_SIZE)}
self.pending_sends = []

def send(self, to, msg):
"""Send a message."""
self._buffer[to].append(msg)
imsgs = self.buffer[to].append(msg)
if imsgs is not None:
self.do_send(to, imsgs)
self.cleanup_finished_sends()

if len(self._buffer[to]) < SEND_BUFFER_SIZE:
return
def do_send(self, to, msgs):
"""Send all messages that have been cached."""
pkl = pickle.dumps(msgs, pickle.HIGHEST_PROTOCOL)
assert len(pkl) <= BUFFER_SIZE
if __debug__:
LOG.debug("Sending %d messages to %d", len(msgs), to)

self._do_send(to)
self._cleanup_finished_sends()
req = COMM_WORLD.Isend([pkl, MPI.CHAR], dest=to)
self.pending_sends.append(req)

def flush(self, wait=True):
"""Flush out message buffers."""
for rank in range(WORLD_SIZE):
self._do_send(rank)
msgs = self.buffer[rank].flush()
if msgs is not None:
self.do_send(rank, msgs)

if wait:
self._wait_pending_sends()
self.wait_pending_sends()
else:
self._cleanup_finished_sends()

def _do_send(self, to):
"""Send all messages that have been cached."""
if not self._buffer[to]:
return

if __debug__:
log.debug("Sending %d messages to %d", len(self._buffer[to]), to)
self.cleanup_finished_sends()

pkls = pickle_dumps_adaptive(self._buffer[to])
for pkl in pkls:
req = COMM_WORLD.Isend([pkl, MPI.CHAR], dest=to)
self._pending_sends.append(req)

self._buffer[to].clear()

def _cleanup_finished_sends(self):
def cleanup_finished_sends(self):
"""Cleanup send requests that have already completed."""
if not self._pending_sends:
return

if len(self._pending_sends) < PENDING_SEND_CLEANUP_AFTER:
if not self.pending_sends:
return

indices = MPI.Request.Waitsome(self._pending_sends)
indices = MPI.Request.Waitsome(self.pending_sends)
if indices is None:
return

indices = set(indices)
self._pending_sends = [
r for i, r in enumerate(self._pending_sends) if i not in indices
self.pending_sends = [
r for i, r in enumerate(self.pending_sends) if i not in indices
]

def _wait_pending_sends(self):
def wait_pending_sends(self):
"""Wait for all pending send requests to finish."""
if not self._pending_sends:
if not self.pending_sends:
return

MPI.Request.Waitall(self._pending_sends)
self._pending_sends.clear()
MPI.Request.Waitall(self.pending_sends)
self.pending_sends.clear()


class AsyncReceiver:
Expand All @@ -117,8 +124,8 @@ class AsyncReceiver:
def __init__(self):
"""Initialize."""
self.msgq = queue.Queue()
self._buf = bytearray(RECV_BUFFER_SIZE)

self._buf = bytearray(BUFFER_SIZE)
self._receiver_thread = threading.Thread(target=self._keep_receiving)
self._receiver_thread.start()

Expand All @@ -134,16 +141,17 @@ def _keep_receiving(self):
"""Code for the receiver thread."""
stop_receiver = False
while not stop_receiver:
COMM_WORLD.Irecv([self._buf, MPI.CHAR]).wait()
messages = pickle.loads(self._buf)
COMM_WORLD.Irecv([self._buf, MPI.CHAR]).Wait()
msgs = pickle.loads(self._buf)
if __debug__:
log.debug("Received %d messages", len(messages))
for message in messages:
if message is StopAsyncReceiver:
LOG.debug("Received %d messages", len(msgs))
for msg in msgs:
msg = pickle.loads(msg)
if msg is StopAsyncReceiver:
stop_receiver = True
continue

self.msgq.put(message)
self.msgq.put(msg)


class AsyncCommunicator:
Expand Down Expand Up @@ -174,6 +182,6 @@ def finish(self):

qsize = self.receiver.msgq.qsize()
if qsize:
log.warning(
LOG.warning(
"Communicator finished with %d messages still in receiver queue", qsize
)
10 changes: 5 additions & 5 deletions xactor/mpi_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

RANK_AID = "rank"

log = logging.getLogger("%s.%d" % (__name__, WORLD_RANK))
LOG = logging.getLogger("%s.%d" % (__name__, WORLD_RANK))


@dataclass(init=False)
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(self):

def _loop(self):
"""Loop through messages."""
log.info("Starting rank loop with %d actors", len(self.local_actors))
LOG.info("Starting rank loop with %d actors", len(self.local_actors))

while not self.stopping:
message = self.acomm.recv()
Expand All @@ -122,22 +122,22 @@ def _loop(self):
try:
method = getattr(actor, message.method)
except AttributeError:
log.exception(
LOG.exception(
"Target actor doesn't have requested method: %r, %r", actor, message
)
raise

try:
method(*message.args, **message.kwargs)
except Exception: # pylint: disable=broad-except
log.exception(
LOG.exception(
"Exception occured while processing message: %r, %r", actor, message
)
raise

def _stop(self):
"""Stop the event loop after processing the current message."""
log.info("Received stop message")
LOG.info("Received stop message.")

self.acomm.finish()
self.stopping = True
Expand Down

0 comments on commit bf88475

Please sign in to comment.