
Commit

Merge pull request #507 from notEvil/thread_bind
Optional thread binding
comrumino committed Oct 24, 2022
2 parents fc3ec13 + bb40f03 commit 0a3c654
Showing 1 changed file with 222 additions and 7 deletions.
229 changes: 222 additions & 7 deletions rpyc/core/protocol.py
@@ -6,6 +6,12 @@
import time # noqa: F401
import gc # noqa: F401

import collections
import concurrent.futures as c_futures
import os
import struct
import threading

from threading import Lock, Condition, RLock
from rpyc.lib import spawn, Timeout, get_methods, get_id_pack, hasattr_static
from rpyc.lib.compat import pickle, next, maxint, select_error, acquire_lock # noqa: F401
@@ -62,6 +68,7 @@ class PingError(Exception):
    sync_request_timeout=30,
    before_closed=None,
    close_catchall=False,
    bind_threads=os.environ.get('RPYC_BIND_THREADS', '0') != '0',
)
"""
The default configuration dictionary of the protocol. You can override these parameters
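As a usage sketch (not part of this commit): the new option can be enabled per connection through the config dict, or process-wide through the environment. The host, port, and service below are placeholder assumptions; only `bind_threads` and `RPYC_BIND_THREADS` come from this change.

import rpyc

# per connection: pass bind_threads through the protocol config
conn = rpyc.connect("localhost", 18861, config={"bind_threads": True})

# process-wide: export RPYC_BIND_THREADS=1 *before* the interpreter imports rpyc,
# since the default above is evaluated when protocol.py is first imported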
@@ -157,6 +164,13 @@ def __init__(self, root, channel, config={}):
        self._send_queue = []
        self._local_root = root
        self._closed = False
        self._bind_threads = self._config['bind_threads']
        if self._bind_threads:
            self._lock = threading.Lock()  # guards _receiving, _thread_pool, and the per-thread queues
            self._threads = {}  # thread ident -> _Thread bookkeeping object
            self._receiving = False  # True while some thread is receiving from the channel
            self._thread_pool = []  # threads parked in serve(), available to take over receiving
            self._thread_pool_executor = c_futures.ThreadPoolExecutor()  # runs _serve_temporary when no thread can take a request

    def __del__(self):
        self.close()
@@ -187,6 +201,8 @@ def _cleanup(self, _anyway=True): # IO
        # self._seqcounter = None
        # self._config.clear()
        del self._HANDLERS
        if self._bind_threads:
            self._thread_pool_executor.shutdown(wait=False)  # TODO where?

    def close(self):  # IO
        """closes the connection, releasing all held resources"""
@@ -235,6 +251,15 @@ def _get_seq_id(self): # IO

    def _send(self, msg, seq, args):  # IO
        data = brine.dump((msg, seq, args))
        if self._bind_threads:
            this_thread = self._get_thread()
            data = struct.pack('<QQ', this_thread.id, this_thread._remote_thread_id) + data
            if msg == consts.MSG_REQUEST:
                this_thread._occupation_count += 1
            else:
                this_thread._occupation_count -= 1
                if this_thread._occupation_count == 0:
                    this_thread._remote_thread_id = 0
        # GC might run while sending data
        # if so, a BaseNetref.__del__ might be called
        # BaseNetref.__del__ must call asyncreq,
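For reference, a small sketch of the framing that _send applies when binding is enabled: a 16-byte little-endian header carrying (sending thread id, bound peer thread id), followed by the usual brine payload. The helper name below is hypothetical; it only mirrors the struct packing and unpacking done by _send and _serve_bound.

import struct

_THREAD_HEADER = struct.Struct('<QQ')  # two unsigned 64-bit ids, little-endian

def _split_bound_message(raw):  # hypothetical helper for illustration
    """Split a thread-bound message into (remote_thread_id, local_thread_id, payload)."""
    remote_thread_id, local_thread_id = _THREAD_HEADER.unpack(raw[:16])
    return remote_thread_id, local_thread_id, raw[16:]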
@@ -359,15 +384,23 @@ def _seq_request_callback(self, msg, seq, is_exc, obj):
    def _dispatch(self, data):  # serving---dispatch?
        msg, seq, args = brine.load(data)
        if msg == consts.MSG_REQUEST:
            if self._bind_threads:
                self._get_thread()._occupation_count += 1
            self._dispatch_request(seq, args)
        else:
            if self._bind_threads:
                this_thread = self._get_thread()
                this_thread._occupation_count -= 1
                if this_thread._occupation_count == 0:
                    this_thread._remote_thread_id = 0
            if msg == consts.MSG_REPLY:
                obj = self._unbox(args)
                self._seq_request_callback(msg, seq, False, obj)
            elif msg == consts.MSG_EXCEPTION:
                obj = self._unbox_exc(args)
                self._seq_request_callback(msg, seq, True, obj)
            else:
                raise ValueError(f"invalid message type: {msg!r}")
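Reading the bind_threads branches of _send and _dispatch together: the per-thread occupation count rises when a thread sends or starts dispatching a request, falls when the matching reply or exception is sent or dispatched, and once it returns to zero the thread's _remote_thread_id binding is released. For one bound round trip on a thread T this reads roughly as: send request (0 to 1), dispatch the reply (1 to 0, binding cleared); a nested callback dispatched while T is waiting pushes the count to 2 and back to 1 when its reply goes out.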

    def serve(self, timeout=1, wait_for_lock=True):  # serving
        """Serves a single request or reply that arrives within the given
@@ -379,6 +412,8 @@ def serve(self, timeout=1, wait_for_lock=True): # serving
                  otherwise.
        """
        timeout = Timeout(timeout)
        if self._bind_threads:
            return self._serve_bound(timeout, wait_for_lock)
        with self._recv_event:
            # Exit early if we cannot acquire the recvlock
            if not self._recvlock.acquire(False):
@@ -410,6 +445,174 @@ def serve(self, timeout=1, wait_for_lock=True): # serving
            self._recvlock.release()
        return False
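A small usage sketch, under assumed host and port: driving the protocol manually with serve(). With bind_threads enabled, the call routes into _serve_bound() instead of the recvlock path shown above. In practice rpyc.BgServingThread(conn) runs a loop much like this one on a background thread.

import rpyc

conn = rpyc.connect("localhost", 18861, config={"bind_threads": True})  # assumed endpoint
while not conn.closed:
    conn.serve(timeout=1)  # handles at most one request or reply per call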

    def _serve_bound(self, timeout, wait_for_lock):
        # bind_threads variant of serve(): route each incoming message to the thread it is bound to
        this_thread = self._get_thread()
        wait = False

        with self._lock:
            message_available = this_thread._event.is_set() and len(this_thread._deque) != 0

            if message_available:
                remote_thread_id, message = this_thread._deque.popleft()
                if len(this_thread._deque) == 0:
                    this_thread._event.clear()

            else:
                if self._receiving:  # enter pool
                    self._thread_pool.append(this_thread)
                    wait = True

                else:
                    self._receiving = True

        if message_available:  # just process
            this_thread._remote_thread_id = remote_thread_id
            self._dispatch(message)
            return True

        if wait:
            while True:
                if wait_for_lock:
                    this_thread._event.wait(timeout.timeleft())

                with self._lock:
                    if this_thread._event.is_set():
                        message_available = len(this_thread._deque) != 0

                        if message_available:
                            remote_thread_id, message = this_thread._deque.popleft()
                            if len(this_thread._deque) == 0:
                                this_thread._event.clear()

                        else:
                            this_thread._event.clear()

                            if self._receiving:  # another thread was faster
                                continue

                            self._receiving = True

                        self._thread_pool.remove(this_thread)  # leave pool
                        break

                    else:  # timeout
                        return False

            if message_available:
                this_thread._remote_thread_id = remote_thread_id
                self._dispatch(message)
                return True

        # this thread is now the receiver: read messages and hand them to their bound threads
        while True:
            # from upstream
            try:
                message = self._channel.poll(timeout) and self._channel.recv()

            except Exception as exception:
                if isinstance(exception, EOFError):
                    self.close()  # sends close async request

                with self._lock:
                    self._receiving = False

                    for thread in self._thread_pool:
                        thread._event.set()
                        break

                raise

            if not message:  # timeout; from upstream
                with self._lock:
                    for thread in self._thread_pool:
                        if not thread._event.is_set():
                            self._receiving = False
                            thread._event.set()
                            break

                    else:  # stop receiving
                        self._receiving = False

                return False

            remote_thread_id, local_thread_id = struct.unpack('<QQ', message[:16])
            message = message[16:]

            this = False

            if local_thread_id == 0:  # root request
                if this_thread._occupation_count == 0:  # this
                    this = True

                else:  # other
                    new = False

                    with self._lock:
                        for thread in self._thread_pool:
                            if thread._occupation_count == 0 and not thread._event.is_set():
                                thread._deque.append((remote_thread_id, message))
                                thread._event.set()
                                break

                        else:
                            new = True

                    if new:
                        self._thread_pool_executor.submit(self._serve_temporary, remote_thread_id, message)

            elif local_thread_id == this_thread.id:
                this = True

            else:  # sub request
                thread = self._get_thread(id=local_thread_id)
                with self._lock:
                    thread._deque.append((remote_thread_id, message))
                    thread._event.set()

            if this:
                with self._lock:
                    for thread in self._thread_pool:
                        if not thread._event.is_set():
                            self._receiving = False
                            thread._event.set()
                            break

                    else:  # stop receiving
                        self._receiving = False

                this_thread._remote_thread_id = remote_thread_id
                self._dispatch(message)
                return True

    def _serve_temporary(self, remote_thread_id, message):
        # runs on a pool-executor thread when no existing thread can take a new root request
        thread = self._get_thread()
        thread._deque.append((remote_thread_id, message))
        thread._event.set()

        # from upstream
        try:
            while not self.closed:
                self.serve(None)

                if thread._occupation_count == 0:
                    break

        except (socket.error, select_error, IOError):
            if not self.closed:
                raise
        except EOFError:
            pass

    def _get_thread(self, id=None):
        # bookkeeping object for the given thread id (defaults to the calling thread)
        if id is None:
            id = threading.get_ident()

        thread = self._threads.get(id)
        if thread is None:
            thread = _Thread(id)
            self._threads[id] = thread

        return thread

    def poll(self, timeout=0):  # serving
        """Serves a single transaction, should one arrive within the given
        interval. Note that handling a request/reply may trigger nested
@@ -686,3 +889,15 @@ def _handle_oldslicing(self, obj, attempt, fallback, start, stop, args): # requ
                stop = maxint
            getslice = self._handle_getattr(obj, fallback)
            return getslice(start, stop, *args)


class _Thread:
    def __init__(self, id):
        super().__init__()

        self.id = id

        self._remote_thread_id = 0  # peer thread this thread is currently bound to (0 = unbound)
        self._occupation_count = 0  # outstanding requests/dispatches binding this thread
        self._event = threading.Event()  # set when a message is queued or the receive role is offered
        self._deque = collections.deque()  # messages routed to this thread
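To close, a hedged end-to-end sketch of the behavior thread binding is meant to provide: a callback invoked by the peer while a thread waits on its own request is dispatched on that same thread, rather than on whichever thread happens to be serving. The service and its invoke method are assumptions for illustration; only the bind_threads option comes from this commit.

import threading
import rpyc

conn = rpyc.connect("localhost", 18861, config={"bind_threads": True})  # assumed endpoint

def where_am_i():
    # with bind_threads enabled, this callback is expected to run on the thread
    # that issued the conn.root.invoke() request below
    return threading.get_ident()

# hypothetical remote method that simply calls the callable it is given
assert conn.root.invoke(where_am_i) == threading.get_ident()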
