diff --git a/rpyc/core/async_.py b/rpyc/core/async_.py index 0af147db..816626a6 100644 --- a/rpyc/core/async_.py +++ b/rpyc/core/async_.py @@ -44,16 +44,19 @@ def wait(self): """Waits for the result to arrive. If the AsyncResult object has an expiry set, and the result did not arrive within that timeout, an :class:`AsyncResultTimeout` exception is raised""" - while not (self._is_ready or self.expired): + while self._waiting(): # Serve the connection since we are not ready. Suppose # the reply for our seq is served. The callback is this class # so __call__ sets our obj and _is_ready to true. - self._conn.serve(self._ttl) + self._conn.serve(self._ttl, waiting=self._waiting) # Check if we timed out before result was ready if not self._is_ready: raise AsyncResultTimeout("result expired") + def _waiting(self): + return not (self._is_ready or self.expired) + def add_callback(self, func): """Adds a callback to be invoked when the result arrives. The callback function takes a single argument, which is the current AsyncResult diff --git a/rpyc/core/protocol.py b/rpyc/core/protocol.py index 69643c72..8b3ae94b 100644 --- a/rpyc/core/protocol.py +++ b/rpyc/core/protocol.py @@ -260,7 +260,7 @@ def _get_seq_id(self): # IO return next(self._seqcounter) def _send(self, msg, seq, args): # IO - data = brine.dump((msg, seq, args)) + data = brine.I1.pack(msg) + brine.dump((seq, args)) # see _dispatch if self._bind_threads: this_thread = self._get_thread() data = brine.I8I8.pack(this_thread.id, this_thread._remote_thread_id) + data @@ -392,10 +392,13 @@ def _seq_request_callback(self, msg, seq, is_exc, obj): self._config["logger"].debug(debug_msg.format(msg, seq)) def _dispatch(self, data): # serving---dispatch? - msg, seq, args = brine.load(data) + msg, = brine.I1.unpack(data[:1]) # unpack just msg to minimize time to release if msg == consts.MSG_REQUEST: if self._bind_threads: self._get_thread()._occupation_count += 1 + else: + self._recvlock.release() + seq, args = brine.load(data[1:]) self._dispatch_request(seq, args) else: if self._bind_threads: @@ -404,15 +407,21 @@ def _dispatch(self, data): # serving---dispatch? if this_thread._occupation_count == 0: this_thread._remote_thread_id = UNBOUND_THREAD_ID if msg == consts.MSG_REPLY: + seq, args = brine.load(data[1:]) obj = self._unbox(args) self._seq_request_callback(msg, seq, False, obj) + if not self._bind_threads: + self._recvlock.release() # releasing here fixes race condition with AsyncResult.wait elif msg == consts.MSG_EXCEPTION: + if not self._bind_threads: + self._recvlock.release() + seq, args = brine.load(data[1:]) obj = self._unbox_exc(args) self._seq_request_callback(msg, seq, True, obj) else: raise ValueError(f"invalid message type: {msg!r}") - def serve(self, timeout=1, wait_for_lock=True): # serving + def serve(self, timeout=1, wait_for_lock=True, waiting=lambda: True): # serving """Serves a single request or reply that arrives within the given time frame (default is 1 sec). Note that the dispatching of a request might trigger multiple (nested) requests, thus this function may be @@ -427,10 +436,17 @@ def serve(self, timeout=1, wait_for_lock=True): # serving # Exit early if we cannot acquire the recvlock if not self._recvlock.acquire(False): if wait_for_lock: + if not waiting(): # unlikely, but the result could've arrived and another thread could've won the race to acquire + return False # Wait condition for recvlock release; recvlock is not underlying lock for condition return self._recv_event.wait(timeout.timeleft()) else: return False + if not waiting(): # the result arrived and we won the race to acquire, unlucky + self._recvlock.release() + with self._recv_event: + self._recv_event.notify_all() + return False # Assume the receive rlock is acquired and incremented # We must release once BEFORE dispatch, dispatch any data, and THEN notify all (see issue #527 and #449) try: @@ -442,11 +458,11 @@ def serve(self, timeout=1, wait_for_lock=True): # serving self.close() # sends close async request raise else: - self._recvlock.release() if data: self._dispatch(data) # Dispatch will unbox, invoke callbacks, etc. return True else: + self._recvlock.release() return False finally: with self._recv_event: diff --git a/tests/test_race.py b/tests/test_race.py new file mode 100644 index 00000000..f71d5c6b --- /dev/null +++ b/tests/test_race.py @@ -0,0 +1,70 @@ +import rpyc +import rpyc.core.async_ as rc_async_ +import rpyc.core.protocol as rc_protocol +import contextlib +import signal +import threading +import time +import unittest + + +class TestRace(unittest.TestCase): + def setUp(self): + self.connection = rpyc.classic.connect_thread() + + self.a_str = rpyc.async_(self.connection.builtin.str) + + def tearDown(self): + self.connection.close() + + def test_asyncresult_race(self): + with _patch(): + def hook(): + time.sleep(0.2) # loose race + + _AsyncResult._HOOK = hook + + threading.Thread(target=self.connection.serve_all).start() + time.sleep(0.1) # wait for thread to serve + + # schedule KeyboardInterrupt + thread_id = threading.get_ident() + _ = lambda: signal.pthread_kill(thread_id, signal.SIGINT) + timer = threading.Timer(1, _) + timer.start() + + a_result = self.a_str("") # request + time.sleep(0.1) # wait for race to start + try: + a_result.wait() + except KeyboardInterrupt: + raise Exception("deadlock") + + timer.cancel() + + +class _AsyncResult(rc_async_.AsyncResult): + _HOOK = None + + def __call__(self, *args, **kwargs): + hook = type(self)._HOOK + if hook is not None: + hook() + return super().__call__(*args, **kwargs) + + +@contextlib.contextmanager +def _patch(): + AsyncResult = rc_async_.AsyncResult + try: + rc_async_.AsyncResult = _AsyncResult + rc_protocol.AsyncResult = _AsyncResult # from import + yield + + finally: + rc_async_.AsyncResult = AsyncResult + rc_protocol.AsyncResult = AsyncResult + + +if __name__ == "__main__": + unittest.main()