Skip to content

Commit

Permalink
HostConnection: implement replacing overloaded connections
Browse files Browse the repository at this point in the history
In a situation of very high overload or poor networking conditions, it
might happen that there is a large number of outstanding requests on a
single connection. Each request reserves a stream ID which cannot be
reused until a response for it arrives, even if the request already
timed out on the client side. Because the pool of available stream IDs
for a single connection is limited, such situation might cause the set
of free stream IDs to shrink a very small size (including zero), which
will drastically reduce the available concurrency on the connection, or
even render it unusable for some time.

In order to prevent this, the following strategy is adopted: when the
number of orphaned stream IDs reaches a certain threshold (e.g. 75% of
all available stream IDs), the connection becomes marked as overloaded.
Meanwhile, a new connection is opened - when it becomes available, it
replaces the old one, and the old connection is moved to "trash" where
it waits until all its outstanding requests either respond or time out.

Because there is no guarantee that the new connection will have the same
shard assigned as the old connection, this strategy uses the excess
connection pool to increase the chances of getting the right shard after
several attempts.

This fix is heavily inspired by the fix for JAVA-1519.
  • Loading branch information
piodul committed Oct 11, 2021
1 parent e2b1feb commit 3210a9e
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 8 deletions.
2 changes: 2 additions & 0 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4360,6 +4360,8 @@ def _on_timeout(self, _attempts=0):
# query could get a response from the old query
with self._connection.lock:
self._connection.orphaned_request_ids.add(self._req_id)
if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold:
self._connection.orphaned_threshold_reached = True

pool.return_connection(self._connection, stream_was_orphaned=True)

Expand Down
9 changes: 9 additions & 0 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,15 @@ class Connection(object):
# on this stream
orphaned_request_ids = None

# Set to true if the orphaned stream ID count cross configured threshold
# and the connection will be replaced
orphaned_threshold_reached = False

# If the number of orphaned streams reaches this threshold, this connection
# will become marked and will be replaced with a new connection by the
# owning pool (currently, only HostConnection supports this)
orphaned_threshold = 3 * max_in_flight // 4

is_defunct = False
is_closed = False
lock = None
Expand Down
77 changes: 70 additions & 7 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ def __init__(self, host, host_distance, session):
# After we get at least one connection for each shard, we can close
# the additional connections.
self._excess_connections = set()
# Contains connections which shouldn't be used anymore
# and are waiting until all requests time out or complete
# so that we can dispose of them.
self._trash = set()

if host_distance == HostDistance.IGNORED:
log.debug("Not opening connection to ignored host %s", self.host)
Expand All @@ -425,7 +429,7 @@ def __init__(self, host, host_distance, session):
first_connection.set_keyspace_blocking(self._keyspace)

if first_connection.sharding_info:
self.host.sharding_info = weakref.proxy(first_connection.sharding_info)
self.host.sharding_info = first_connection.sharding_info
self._open_connections_for_all_shards()

log.debug("Finished initializing connection for host %s", self.host)
Expand Down Expand Up @@ -455,6 +459,19 @@ def borrow_connection(self, timeout, routing_key=None):
self.host,
routing_key
)
if conn.orphaned_threshold_reached and shard_id not in self._connecting:
# The connection has met its orphaned stream ID limit
# and needs to be replaced. Start opening a connection
# to the same shard and replace when it is opened.
self._connecting.add(shard_id)
self._session.submit(self._open_connection_to_missing_shard, shard_id)
log.debug(
"Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
shard_id,
self.host,
len(self._connections.keys()),
self.host.sharding_info.shards_count
)
elif shard_id not in self._connecting:
# rate controlled optimistic attempt to connect to a missing shard
self._connecting.add(shard_id)
Expand Down Expand Up @@ -521,6 +538,16 @@ def return_connection(self, connection, stream_was_orphaned=False):
return
self._is_replacing = True
self._session.submit(self._replace, connection)
else:
if connection in self._trash:
with connection.lock:
if connection.in_flight == len(connection.orphaned_request_ids):
with self._lock:
if connection in self._trash:
self._trash.remove(connection)
log.debug("Closing trashed connection (%s) to %s", id(connection), self.host)
connection.close()
return

def on_orphaned_stream_released(self):
"""
Expand Down Expand Up @@ -572,6 +599,9 @@ def shutdown(self):

self._close_excess_connections()

for conn in self._trash:
conn.close()

def _close_excess_connections(self):
if self._excess_connections:
for c in self._excess_connections:
Expand Down Expand Up @@ -606,31 +636,59 @@ def _open_connection_to_missing_shard(self, shard_id):
if self.is_shutdown:
log.debug("Pool for host %s is in shutdown, closing the new connection", self.host)
conn.close()
elif conn.shard_id not in self._connections.keys():
elif conn.shard_id not in self._connections.keys() or self._connections[conn.shard_id].orphaned_threshold_reached:
log.debug(
"New connection created to shard_id=%i on host %s",
conn.shard_id,
self.host
)
old_conn = None
if conn.shard_id in self._connections.keys():
# Move the current connection to the trash and use the new one from now on
log.debug(
"Replacing overloaded connection for shard %i for host %s",
conn.shard_id,
self.host
)
old_conn = self._connections[conn.shard_id]
self._connections[conn.shard_id] = conn
if old_conn is not None:
with old_conn.lock:
remaining = len(old_conn.orphaned_request_ids) - old_conn.in_flight
if remaining == 0:
log.debug(
"Immediately closing the old connection for shard %i on host %s",
conn.shard_id,
self.host
)
old_conn.close()
else:
log.debug(
"Moving the connection for shard %i to trash on host %s, %i requests remaining",
conn.shard_id,
self.host,
remaining,
)
with self._lock:
self._trash.add(old_conn)
if self._keyspace:
self._connections[conn.shard_id].set_keyspace_blocking(self._keyspace)
num_missing = self.host.sharding_info.shards_count - len(self._connections.keys())
needing = self.num_missing_or_needing_replacement
log.debug(
"Connected to %s/%i shards on host %s (%i missing)",
"Connected to %s/%i shards on host %s (%i missing or needs replacement)",
len(self._connections.keys()),
self.host.sharding_info.shards_count,
self.host,
num_missing
needing
)
if num_missing == 0:
if needing == 0:
log.debug(
"All shards of host %s have at least one connection, closing %i excess connections",
self.host,
len(self._excess_connections)
)
self._close_excess_connections()
elif self.host.sharding_info.shards_count == len(self._connections.keys()):
elif self.host.sharding_info.shards_count == len(self._connections.keys()) and self.num_missing_or_needing_replacement == 0:
log.debug(
"All shards are already covered, closing newly opened excess connection for host %s",
self.host
Expand Down Expand Up @@ -702,6 +760,11 @@ def get_state(self):
in_flights = [c.in_flight for c in self._connections.values()]
return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}

@property
def num_missing_or_needing_replacement(self):
return self.host.sharding_info.shards_count \
- sum(1 for c in self._connections.values() if not c.orphaned_threshold_reached)

@property
def open_count(self):
return sum([1 if c and not (c.is_closed or c.is_defunct) else 0 for c in self._connections.values()])
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_response_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,8 @@ def test_timeout_does_not_release_stream_id(self):
session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(endpoint='ip1'), Mock(endpoint='ip2')]
pool = self.make_pool()
session._pools.get.return_value = pool
connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque())
connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(),
orphaned_request_ids=set(), orphaned_threshold=256)
pool.borrow_connection.return_value = (connection, 1)

rf = self.make_response_future(session)
Expand Down

0 comments on commit 3210a9e

Please sign in to comment.