From e3db20421bd409f4b5c996b5c44ec0016579df95 Mon Sep 17 00:00:00 2001
From: Dmitry Kropachev
Date: Tue, 4 Nov 2025 08:13:09 -0400
Subject: [PATCH] token-aware-policy: drop _tablets_routing_v1 flag

This flag was introduced to check whether the server supports tablets.
As a result, the driver had to sync it into the policy every time a
control connection was established, which is an unwanted complication.
Let's rely on the presence of tablets for the given table instead,
which does not require any syncing.
---
 cassandra/cluster.py                        |  8 ---
 cassandra/policies.py                       |  9 ++--
 cassandra/tablets.py                        |  3 ++
 tests/integration/standard/test_tablets.py | 32 +++++++++++-
 tests/unit/test_policies.py                 | 57 +++++++++++++++++-----
 5 files changed, 81 insertions(+), 28 deletions(-)

diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 14c5cb4bd9..5822a23aa9 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -1734,14 +1734,6 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
                 self.shutdown()
                 raise
 
-            # Update the information about tablet support after connection handshake.
-            self.load_balancing_policy._tablets_routing_v1 = self.control_connection._tablets_routing_v1
-            child_policy = self.load_balancing_policy.child_policy if hasattr(self.load_balancing_policy, 'child_policy') else None
-            while child_policy is not None:
-                if hasattr(child_policy, '_tablet_routing_v1'):
-                    child_policy._tablet_routing_v1 = self.control_connection._tablets_routing_v1
-                child_policy = child_policy.child_policy if hasattr(child_policy, 'child_policy') else None
-
             self.profile_manager.check_supported()  # todo: rename this method
 
             if self.idle_heartbeat_interval:
diff --git a/cassandra/policies.py b/cassandra/policies.py
index efef0bb2a5..a679bff877 100644
--- a/cassandra/policies.py
+++ b/cassandra/policies.py
@@ -476,7 +476,6 @@ class TokenAwarePolicy(LoadBalancingPolicy):
 
     _child_policy = None
     _cluster_metadata = None
-    _tablets_routing_v1 = False
     shuffle_replicas = False
     """
     Yield local replicas in a random order.
@@ -488,7 +487,6 @@ def __init__(self, child_policy, shuffle_replicas=False):
 
     def populate(self, cluster, hosts):
         self._cluster_metadata = cluster.metadata
-        self._tablets_routing_v1 = cluster.control_connection._tablets_routing_v1
         self._child_policy.populate(cluster, hosts)
 
     def check_supported(self):
@@ -513,17 +511,16 @@ def make_query_plan(self, working_keyspace=None, query=None):
             return
 
         replicas = []
-        if self._tablets_routing_v1:
+        if self._cluster_metadata._tablets.table_has_tablets(keyspace, query.table):
             tablet = self._cluster_metadata._tablets.get_tablet_for_key(
-                keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))
+                keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))
 
             if tablet is not None:
                 replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
                 child_plan = child.make_query_plan(keyspace, query)
 
                 replicas = [host for host in child_plan if host.host_id in replicas_mapped]
-
-        if not replicas:
+        else:
             replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)
 
         if self.shuffle_replicas:
diff --git a/cassandra/tablets.py b/cassandra/tablets.py
index 457ee93ca4..dca26ab0df 100644
--- a/cassandra/tablets.py
+++ b/cassandra/tablets.py
@@ -49,6 +49,9 @@ def __init__(self, tablets):
         self._tablets = tablets
         self._lock = Lock()
 
+    def table_has_tablets(self, keyspace, table) -> bool:
+        return bool(self._tablets.get((keyspace, table), []))
+
     def get_tablet_for_key(self, keyspace, table, t):
         tablet = self._tablets.get((keyspace, table), [])
         if not tablet:
diff --git a/tests/integration/standard/test_tablets.py b/tests/integration/standard/test_tablets.py
index 0216f7843a..d9439e5c2c 100644
--- a/tests/integration/standard/test_tablets.py
+++ b/tests/integration/standard/test_tablets.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from cassandra.cluster import Cluster
+from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile
 from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy
 
 from tests.integration import PROTOCOL_VERSION, use_cluster, get_cluster
@@ -163,6 +163,36 @@ def test_tablets_shard_awareness(self):
         self.query_data_shard_select(self.session)
         self.query_data_shard_insert(self.session)
 
+    def test_tablets_lbp_in_profile(self):
+        cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
+                          execution_profiles={
+                              EXEC_PROFILE_DEFAULT: ExecutionProfile(
+                                  load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
+                              )},
+                          reconnection_policy=ConstantReconnectionPolicy(1))
+        session = cluster.connect()
+        try:
+            self.query_data_host_select(session)
+            self.query_data_host_insert(session)
+        finally:
+            session.shutdown()
+            cluster.shutdown()
+
+    def test_tablets_shard_awareness_lbp_in_profile(self):
+        cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
+                          execution_profiles={
+                              EXEC_PROFILE_DEFAULT: ExecutionProfile(
+                                  load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
+                              )},
+                          reconnection_policy=ConstantReconnectionPolicy(1))
+        session = cluster.connect()
+        try:
+            self.query_data_shard_select(session)
+            self.query_data_shard_insert(session)
+        finally:
+            session.shutdown()
+            cluster.shutdown()
+
     def test_tablets_invalidation_drop_ks_while_reconnecting(self):
         def recreate_while_reconnecting(_):
             # Kill control connection
diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py
index 15970aadfa..e65a89bca7 100644
--- a/tests/unit/test_policies.py
+++ b/tests/unit/test_policies.py
@@ -36,6 +36,7 @@
 from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint
 from cassandra.pool import Host
 from cassandra.query import Statement
+from cassandra.tablets import Tablets, Tablet
 
 
 class LoadBalancingPolicyTest(unittest.TestCase):
@@ -582,7 +583,8 @@ class TokenAwarePolicyTest(unittest.TestCase):
     def test_wrap_round_robin(self):
         cluster = Mock(spec=Cluster)
         cluster.metadata = Mock(spec=Metadata)
-        cluster.control_connection._tablets_routing_v1 = False
+        cluster.metadata._tablets = Mock(spec=Tablets)
+        cluster.metadata._tablets.table_has_tablets.return_value = False
         hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)]
         for host in hosts:
             host.set_up()
@@ -614,7 +616,8 @@ def get_replicas(keyspace, packed_key):
     def test_wrap_dc_aware(self):
         cluster = Mock(spec=Cluster)
         cluster.metadata = Mock(spec=Metadata)
-        cluster.control_connection._tablets_routing_v1 = False
+        cluster.metadata._tablets = Mock(spec=Tablets)
+        cluster.metadata._tablets.table_has_tablets.return_value = False
         hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)]
         for host in hosts:
             host.set_up()
@@ -744,9 +747,10 @@ def test_statement_keyspace(self):
 
         cluster = Mock(spec=Cluster)
         cluster.metadata = Mock(spec=Metadata)
-        cluster.control_connection._tablets_routing_v1 = False
+        cluster.metadata._tablets = Mock(spec=Tablets)
         replicas = hosts[2:]
         cluster.metadata.get_replicas.return_value = replicas
+        cluster.metadata._tablets.table_has_tablets.return_value = False
 
         child_policy = Mock()
         child_policy.make_query_plan.return_value = hosts
@@ -803,7 +807,8 @@ def test_shuffles_if_given_keyspace_and_routing_key(self):
 
         @test_category policy
         """
-        self._assert_shuffle(keyspace='keyspace', routing_key='routing_key')
+        self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace='keyspace', routing_key='routing_key')
+        self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key='routing_key')
 
     def test_no_shuffle_if_given_no_keyspace(self):
         """
@@ -814,7 +819,8 @@ def test_no_shuffle_if_given_no_keyspace(self):
 
         @test_category policy
         """
-        self._assert_shuffle(keyspace=None, routing_key='routing_key')
+        self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace=None, routing_key='routing_key')
+        self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace=None, routing_key='routing_key')
 
     def test_no_shuffle_if_given_no_routing_key(self):
         """
@@ -825,20 +831,38 @@ def test_no_shuffle_if_given_no_routing_key(self):
 
         @test_category policy
         """
-        self._assert_shuffle(keyspace='keyspace', routing_key=None)
+        self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace='keyspace', routing_key=None)
+        self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key=None)
 
-    @patch('cassandra.policies.shuffle')
-    def _assert_shuffle(self, patched_shuffle, keyspace, routing_key):
+    def _prepare_cluster_with_vnodes(self):
         hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)]
         for host in hosts:
             host.set_up()
+        cluster = Mock(spec=Cluster)
+        cluster.metadata = Mock(spec=Metadata)
+        cluster.metadata._tablets = Mock(spec=Tablets)
+        cluster.metadata.all_hosts.return_value = hosts
+        cluster.metadata.get_replicas.return_value = hosts[2:]
+        cluster.metadata._tablets.table_has_tablets.return_value = False
+        return cluster
 
+    def _prepare_cluster_with_tablets(self):
+        hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)]
+        for i, host in enumerate(hosts):
+            host.host_id = i  # distinct ids so the tablet replica filter can match these hosts
+            host.set_up()
         cluster = Mock(spec=Cluster)
         cluster.metadata = Mock(spec=Metadata)
-        cluster.control_connection._tablets_routing_v1 = False
-        replicas = hosts[2:]
-        cluster.metadata.get_replicas.return_value = replicas
+        cluster.metadata._tablets = Mock(spec=Tablets)
+        cluster.metadata.all_hosts.return_value = hosts
+        cluster.metadata.get_replicas.return_value = hosts[2:]
+        cluster.metadata._tablets.table_has_tablets.return_value = True
+        cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]])
+        return cluster
 
+    @patch('cassandra.policies.shuffle')
+    def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key):
+        hosts = cluster.metadata.all_hosts()
+        replicas = cluster.metadata.get_replicas()
         child_policy = Mock()
         child_policy.make_query_plan.return_value = hosts
         child_policy.distance.return_value = HostDistance.LOCAL
@@ -846,6 +870,8 @@ def _assert_shuffle(self, patched_shuffle, keyspace, routing_key):
 
         policy = TokenAwarePolicy(child_policy, shuffle_replicas=True)
         policy.populate(cluster, hosts)
+        is_tablets = cluster.metadata._tablets.table_has_tablets()
+
         cluster.metadata.get_replicas.reset_mock()
         child_policy.make_query_plan.reset_mock()
         query = Statement(routing_key=routing_key)
@@ -858,7 +884,11 @@ def _assert_shuffle(self, patched_shuffle, keyspace, routing_key):
         else:
             assert set(replicas) == set(qplan[:2])
             assert hosts[:2] == qplan[2:]
-            child_policy.make_query_plan.assert_called_once_with(keyspace, query)
+            if is_tablets:
+                child_policy.make_query_plan.assert_called_with(keyspace, query)
+                assert child_policy.make_query_plan.call_count == 2
+            else:
+                child_policy.make_query_plan.assert_called_once_with(keyspace, query)
 
             assert patched_shuffle.call_count == 1
 
@@ -1538,7 +1568,6 @@ def test_query_plan_deferred_to_child(self):
 
     def test_wrap_token_aware(self):
         cluster = Mock(spec=Cluster)
-        cluster.control_connection._tablets_routing_v1 = False
         hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)]
         for host in hosts:
             host.set_up()
@@ -1547,6 +1576,8 @@ def get_replicas(keyspace, packed_key):
             return hosts[:2]
 
         cluster.metadata.get_replicas.side_effect = get_replicas
+        cluster.metadata._tablets = Mock(spec=Tablets)
+        cluster.metadata._tablets.table_has_tablets.return_value = False
 
         child_policy = TokenAwarePolicy(RoundRobinPolicy())
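
Reviewer note (illustrative only, not part of the patch): a minimal sketch of the lookup
this change relies on, using only the classes touched in this diff and requiring the
patched driver. The keyspace/table names and the "host-a" replica id are made-up
placeholder values.

    from cassandra.tablets import Tablet, Tablets

    # Tablets is keyed by (keyspace, table); a missing or empty entry means
    # "no tablets known for this table", so TokenAwarePolicy falls back to
    # token-ring replicas via metadata.get_replicas().
    tablets = Tablets({("ks", "tbl"): [Tablet(replicas=[("host-a", 0)])]})

    assert tablets.table_has_tablets("ks", "tbl")          # tablet-aware routing path
    assert not tablets.table_has_tablets("ks", "missing")  # token-ring fallback

Because the policy now asks the cluster metadata it already holds, there is no
per-connection flag to copy into the load-balancing policy (or its child policies)
after the handshake.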