8 changes: 0 additions & 8 deletions cassandra/cluster.py
@@ -1734,14 +1734,6 @@
self.shutdown()
raise

# Update the information about tablet support after connection handshake.
self.load_balancing_policy._tablets_routing_v1 = self.control_connection._tablets_routing_v1
child_policy = self.load_balancing_policy.child_policy if hasattr(self.load_balancing_policy, 'child_policy') else None
while child_policy is not None:
if hasattr(child_policy, '_tablet_routing_v1'):
child_policy._tablet_routing_v1 = self.control_connection._tablets_routing_v1
child_policy = child_policy.child_policy if hasattr(child_policy, 'child_policy') else None

self.profile_manager.check_supported() # todo: rename this method

if self.idle_heartbeat_interval:
@@ -4323,7 +4315,7 @@
self._scheduled_tasks.discard(task)
fn, args, kwargs = task
kwargs = dict(kwargs)
future = self._executor.submit(fn, *args, **kwargs)

Check failure on line 4318 in cassandra/cluster.py (GitHub Actions / test libev (3.12)): cannot schedule new futures after shutdown
future.add_done_callback(self._log_if_failed)
else:
self._queue.put_nowait((run_at, i, task))
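The deleted block above was the post-handshake step that copied the control connection's tablet-routing flag onto the cluster-level load-balancing policy and then walked its `child_policy` chain. A minimal sketch of that walk, paraphrased from the removed lines (the function name here is ours, not the driver's), shows that it only reached the policy configured on the `Cluster` itself, which helps explain why the new integration tests exercise a policy configured through an `ExecutionProfile`:

```python
def _propagate_tablets_flag(root_policy, tablets_routing_v1):
    # Paraphrase of the removed post-handshake walk: set the flag on the
    # top-level policy, then follow child_policy links downward.
    root_policy._tablets_routing_v1 = tablets_routing_v1
    child = getattr(root_policy, 'child_policy', None)
    while child is not None:
        # Note: the removed code probed the singular '_tablet_routing_v1'
        # name on children, unlike the plural attribute set on the root.
        if hasattr(child, '_tablet_routing_v1'):
            child._tablet_routing_v1 = tablets_routing_v1
        child = getattr(child, 'child_policy', None)
```

With `TokenAwarePolicy` now consulting `cluster.metadata._tablets` at query-plan time (see the policies.py hunk below), this propagation and the cached flag are no longer needed.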
9 changes: 3 additions & 6 deletions cassandra/policies.py
@@ -476,7 +476,6 @@ class TokenAwarePolicy(LoadBalancingPolicy):

_child_policy = None
_cluster_metadata = None
_tablets_routing_v1 = False
shuffle_replicas = False
"""
Yield local replicas in a random order.
@@ -488,7 +487,6 @@ def __init__(self, child_policy, shuffle_replicas=False):

def populate(self, cluster, hosts):
self._cluster_metadata = cluster.metadata
self._tablets_routing_v1 = cluster.control_connection._tablets_routing_v1
self._child_policy.populate(cluster, hosts)

def check_supported(self):
@@ -513,17 +511,16 @@ def make_query_plan(self, working_keyspace=None, query=None):
return

replicas = []
if self._tablets_routing_v1:
if self._cluster_metadata._tablets.table_has_tablets(keyspace, query.table):
tablet = self._cluster_metadata._tablets.get_tablet_for_key(
keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))
keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))

if tablet is not None:
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
child_plan = child.make_query_plan(keyspace, query)

replicas = [host for host in child_plan if host.host_id in replicas_mapped]

if not replicas:
else:
replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)

if self.shuffle_replicas:
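Read together, the hunk above reshapes the replica-selection step of `make_query_plan` so that the policy asks the tablets metadata directly instead of checking a cached `_tablets_routing_v1` flag. A consolidated paraphrase of the new branch, for readability only (yield logic and surrounding method code omitted; `child` is the wrapped child policy, as in the diff):

```python
replicas = []
if self._cluster_metadata._tablets.table_has_tablets(keyspace, query.table):
    # Tablet path: resolve the tablet owning this routing key's token and
    # keep only the hosts from the child plan that replicate that tablet.
    token = self._cluster_metadata.token_map.token_class.from_key(query.routing_key)
    tablet = self._cluster_metadata._tablets.get_tablet_for_key(keyspace, query.table, token)
    if tablet is not None:
        replicas_mapped = {r[0] for r in tablet.replicas}   # replica host_ids
        child_plan = child.make_query_plan(keyspace, query)  # child = self._child_policy
        replicas = [host for host in child_plan if host.host_id in replicas_mapped]
else:
    # Token-ring (vnode) path, unchanged from before.
    replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)
```

Note the behavioural shift visible in the diff: the old `if not replicas:` fallback becomes an `else:`, so a table known to use tablets no longer falls back to token-ring replicas when no matching tablet or replica is found.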
3 changes: 3 additions & 0 deletions cassandra/tablets.py
@@ -49,6 +49,9 @@ def __init__(self, tablets):
self._tablets = tablets
self._lock = Lock()

def table_has_tablets(self, keyspace, table) -> bool:
return bool(self._tablets.get((keyspace, table), []))

def get_tablet_for_key(self, keyspace, table, t):
tablet = self._tablets.get((keyspace, table), [])
if not tablet:
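The new `table_has_tablets` helper reports whether tablet metadata is currently held for a `(keyspace, table)` pair, letting callers branch before attempting a per-key lookup. A small illustrative use, mirroring the policy change above (names such as `ks`, `users`, `token`, and `routing_key` are placeholders assumed to be in scope):

```python
tablets = cluster.metadata._tablets

if tablets.table_has_tablets("ks", "users"):
    # Tablet-backed table: look up the tablet that owns this token.
    tablet = tablets.get_tablet_for_key("ks", "users", token)
else:
    # No tablet metadata for this table: use token-ring replicas instead.
    replicas = cluster.metadata.get_replicas("ks", routing_key)
```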
32 changes: 31 additions & 1 deletion tests/integration/standard/test_tablets.py
@@ -2,7 +2,7 @@

import pytest

from cassandra.cluster import Cluster
from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile
from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy

from tests.integration import PROTOCOL_VERSION, use_cluster, get_cluster
@@ -163,6 +163,36 @@ def test_tablets_shard_awareness(self):
self.query_data_shard_select(self.session)
self.query_data_shard_insert(self.session)

def test_tablets_lbp_in_profile(self):
cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
execution_profiles={
EXEC_PROFILE_DEFAULT: ExecutionProfile(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
)},
reconnection_policy=ConstantReconnectionPolicy(1))
session = cluster.connect()
try:
self.query_data_host_select(self.session)
self.query_data_host_insert(self.session)
finally:
session.shutdown()
cluster.shutdown()

def test_tablets_shard_awareness_lbp_in_profile(self):
cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
execution_profiles={
EXEC_PROFILE_DEFAULT: ExecutionProfile(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
)},
reconnection_policy=ConstantReconnectionPolicy(1))
session = cluster.connect()
try:
self.query_data_shard_select(self.session)
self.query_data_shard_insert(self.session)
finally:
session.shutdown()
cluster.shutdown()

def test_tablets_invalidation_drop_ks_while_reconnecting(self):
def recreate_while_reconnecting(_):
# Kill control connection
57 changes: 44 additions & 13 deletions tests/unit/test_policies.py
@@ -36,6 +36,7 @@
from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint
from cassandra.pool import Host
from cassandra.query import Statement
from cassandra.tablets import Tablets, Tablet


class LoadBalancingPolicyTest(unittest.TestCase):
@@ -582,7 +583,8 @@ class TokenAwarePolicyTest(unittest.TestCase):
def test_wrap_round_robin(self):
cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
cluster.control_connection._tablets_routing_v1 = False
cluster.metadata._tablets = Mock(spec=Tablets)
cluster.metadata._tablets.table_has_tablets.return_value = []
hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)]
for host in hosts:
host.set_up()
@@ -614,7 +616,8 @@ def get_replicas(keyspace, packed_key):
def test_wrap_dc_aware(self):
cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
cluster.control_connection._tablets_routing_v1 = False
cluster.metadata._tablets = Mock(spec=Tablets)
cluster.metadata._tablets.table_has_tablets.return_value = []
hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)]
for host in hosts:
host.set_up()
@@ -744,9 +747,10 @@ def test_statement_keyspace(self):

cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
cluster.control_connection._tablets_routing_v1 = False
cluster.metadata._tablets = Mock(spec=Tablets)
replicas = hosts[2:]
cluster.metadata.get_replicas.return_value = replicas
cluster.metadata._tablets.table_has_tablets.return_value = []

child_policy = Mock()
child_policy.make_query_plan.return_value = hosts
@@ -803,7 +807,8 @@ def test_shuffles_if_given_keyspace_and_routing_key(self):

@test_category policy
"""
self._assert_shuffle(keyspace='keyspace', routing_key='routing_key')
self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace='keyspace', routing_key='routing_key')
self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key='routing_key')

def test_no_shuffle_if_given_no_keyspace(self):
"""
@@ -814,7 +819,8 @@ def test_no_shuffle_if_given_no_keyspace(self):

@test_category policy
"""
self._assert_shuffle(keyspace=None, routing_key='routing_key')
self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace=None, routing_key='routing_key')
self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace=None, routing_key='routing_key')

def test_no_shuffle_if_given_no_routing_key(self):
"""
@@ -825,27 +831,47 @@ def test_no_shuffle_if_given_no_routing_key(self):

@test_category policy
"""
self._assert_shuffle(keyspace='keyspace', routing_key=None)
self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace='keyspace', routing_key=None)
self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key=None)

@patch('cassandra.policies.shuffle')
def _assert_shuffle(self, patched_shuffle, keyspace, routing_key):
def _prepare_cluster_with_vnodes(self):
hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)]
for host in hosts:
host.set_up()
cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
cluster.metadata._tablets = Mock(spec=Tablets)
cluster.metadata.all_hosts.return_value = hosts
cluster.metadata.get_replicas.return_value = hosts[2:]
cluster.metadata._tablets.table_has_tablets.return_value = False
return cluster

def _prepare_cluster_with_tablets(self):
hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)]
for host in hosts:
host.set_up()
cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
cluster.control_connection._tablets_routing_v1 = False
replicas = hosts[2:]
cluster.metadata.get_replicas.return_value = replicas
cluster.metadata._tablets = Mock(spec=Tablets)
cluster.metadata.all_hosts.return_value = hosts
cluster.metadata.get_replicas.return_value = hosts[2:]
cluster.metadata._tablets.table_has_tablets.return_value = True
cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]])
return cluster

@patch('cassandra.policies.shuffle')
def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key):
hosts = cluster.metadata.all_hosts()
replicas = cluster.metadata.get_replicas()
child_policy = Mock()
child_policy.make_query_plan.return_value = hosts
child_policy.distance.return_value = HostDistance.LOCAL

policy = TokenAwarePolicy(child_policy, shuffle_replicas=True)
policy.populate(cluster, hosts)

is_tablets = cluster.metadata._tablets.table_has_tablets()

cluster.metadata.get_replicas.reset_mock()
child_policy.make_query_plan.reset_mock()
query = Statement(routing_key=routing_key)
@@ -858,7 +884,11 @@ def _assert_shuffle(self, patched_shuffle, keyspace, routing_key):
else:
assert set(replicas) == set(qplan[:2])
assert hosts[:2] == qplan[2:]
child_policy.make_query_plan.assert_called_once_with(keyspace, query)
if is_tablets:
child_policy.make_query_plan.assert_called_with(keyspace, query)
assert child_policy.make_query_plan.call_count == 2
else:
child_policy.make_query_plan.assert_called_once_with(keyspace, query)
assert patched_shuffle.call_count == 1


@@ -1538,7 +1568,6 @@ def test_query_plan_deferred_to_child(self):

def test_wrap_token_aware(self):
cluster = Mock(spec=Cluster)
cluster.control_connection._tablets_routing_v1 = False
hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)]
for host in hosts:
host.set_up()
@@ -1547,6 +1576,8 @@ def get_replicas(keyspace, packed_key):
return hosts[:2]

cluster.metadata.get_replicas.side_effect = get_replicas
cluster.metadata._tablets = Mock(spec=Tablets)
cluster.metadata._tablets.table_has_tablets.return_value = []

child_policy = TokenAwarePolicy(RoundRobinPolicy())
