diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 14c5cb4bd9..5822a23aa9 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1734,14 +1734,6 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self.shutdown() raise - # Update the information about tablet support after connection handshake. - self.load_balancing_policy._tablets_routing_v1 = self.control_connection._tablets_routing_v1 - child_policy = self.load_balancing_policy.child_policy if hasattr(self.load_balancing_policy, 'child_policy') else None - while child_policy is not None: - if hasattr(child_policy, '_tablet_routing_v1'): - child_policy._tablet_routing_v1 = self.control_connection._tablets_routing_v1 - child_policy = child_policy.child_policy if hasattr(child_policy, 'child_policy') else None - self.profile_manager.check_supported() # todo: rename this method if self.idle_heartbeat_interval: diff --git a/cassandra/policies.py b/cassandra/policies.py index efef0bb2a5..a679bff877 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -476,7 +476,6 @@ class TokenAwarePolicy(LoadBalancingPolicy): _child_policy = None _cluster_metadata = None - _tablets_routing_v1 = False shuffle_replicas = False """ Yield local replicas in a random order. @@ -488,7 +487,6 @@ def __init__(self, child_policy, shuffle_replicas=False): def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata - self._tablets_routing_v1 = cluster.control_connection._tablets_routing_v1 self._child_policy.populate(cluster, hosts) def check_supported(self): @@ -513,17 +511,16 @@ def make_query_plan(self, working_keyspace=None, query=None): return replicas = [] - if self._tablets_routing_v1: + if self._cluster_metadata._tablets.table_has_tablets(keyspace, query.table): tablet = self._cluster_metadata._tablets.get_tablet_for_key( - keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) + keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) if tablet is not None: replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) child_plan = child.make_query_plan(keyspace, query) replicas = [host for host in child_plan if host.host_id in replicas_mapped] - - if not replicas: + else: replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) if self.shuffle_replicas: diff --git a/cassandra/tablets.py b/cassandra/tablets.py index 457ee93ca4..dca26ab0df 100644 --- a/cassandra/tablets.py +++ b/cassandra/tablets.py @@ -49,6 +49,9 @@ def __init__(self, tablets): self._tablets = tablets self._lock = Lock() + def table_has_tablets(self, keyspace, table) -> bool: + return bool(self._tablets.get((keyspace, table), [])) + def get_tablet_for_key(self, keyspace, table, t): tablet = self._tablets.get((keyspace, table), []) if not tablet: diff --git a/tests/integration/standard/test_tablets.py b/tests/integration/standard/test_tablets.py index 0216f7843a..d9439e5c2c 100644 --- a/tests/integration/standard/test_tablets.py +++ b/tests/integration/standard/test_tablets.py @@ -2,7 +2,7 @@ import pytest -from cassandra.cluster import Cluster +from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy from tests.integration import PROTOCOL_VERSION, use_cluster, get_cluster @@ -163,6 +163,36 @@ def test_tablets_shard_awareness(self): self.query_data_shard_select(self.session) self.query_data_shard_insert(self.session) + def test_tablets_lbp_in_profile(self): + cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + )}, + reconnection_policy=ConstantReconnectionPolicy(1)) + session = cluster.connect() + try: + self.query_data_host_select(self.session) + self.query_data_host_insert(self.session) + finally: + session.shutdown() + cluster.shutdown() + + def test_tablets_shard_awareness_lbp_in_profile(self): + cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + )}, + reconnection_policy=ConstantReconnectionPolicy(1)) + session = cluster.connect() + try: + self.query_data_shard_select(self.session) + self.query_data_shard_insert(self.session) + finally: + session.shutdown() + cluster.shutdown() + def test_tablets_invalidation_drop_ks_while_reconnecting(self): def recreate_while_reconnecting(_): # Kill control connection diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 15970aadfa..e65a89bca7 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -36,6 +36,7 @@ from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint from cassandra.pool import Host from cassandra.query import Statement +from cassandra.tablets import Tablets, Tablet class LoadBalancingPolicyTest(unittest.TestCase): @@ -582,7 +583,8 @@ class TokenAwarePolicyTest(unittest.TestCase): def test_wrap_round_robin(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) - cluster.control_connection._tablets_routing_v1 = False + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.table_has_tablets.return_value = [] hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() @@ -614,7 +616,8 @@ def get_replicas(keyspace, packed_key): def test_wrap_dc_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) - cluster.control_connection._tablets_routing_v1 = False + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.table_has_tablets.return_value = [] hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() @@ -744,9 +747,10 @@ def test_statement_keyspace(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) - cluster.control_connection._tablets_routing_v1 = False + cluster.metadata._tablets = Mock(spec=Tablets) replicas = hosts[2:] cluster.metadata.get_replicas.return_value = replicas + cluster.metadata._tablets.table_has_tablets.return_value = [] child_policy = Mock() child_policy.make_query_plan.return_value = hosts @@ -803,7 +807,8 @@ def test_shuffles_if_given_keyspace_and_routing_key(self): @test_category policy """ - self._assert_shuffle(keyspace='keyspace', routing_key='routing_key') + self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace='keyspace', routing_key='routing_key') + self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key='routing_key') def test_no_shuffle_if_given_no_keyspace(self): """ @@ -814,7 +819,8 @@ def test_no_shuffle_if_given_no_keyspace(self): @test_category policy """ - self._assert_shuffle(keyspace=None, routing_key='routing_key') + self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace=None, routing_key='routing_key') + self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace=None, routing_key='routing_key') def test_no_shuffle_if_given_no_routing_key(self): """ @@ -825,20 +831,38 @@ def test_no_shuffle_if_given_no_routing_key(self): @test_category policy """ - self._assert_shuffle(keyspace='keyspace', routing_key=None) + self._assert_shuffle(cluster=self._prepare_cluster_with_vnodes(), keyspace='keyspace', routing_key=None) + self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key=None) - @patch('cassandra.policies.shuffle') - def _assert_shuffle(self, patched_shuffle, keyspace, routing_key): + def _prepare_cluster_with_vnodes(self): hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata.all_hosts.return_value = hosts + cluster.metadata.get_replicas.return_value = hosts[2:] + cluster.metadata._tablets.table_has_tablets.return_value = False + return cluster + def _prepare_cluster_with_tablets(self): + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) - cluster.control_connection._tablets_routing_v1 = False - replicas = hosts[2:] - cluster.metadata.get_replicas.return_value = replicas + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata.all_hosts.return_value = hosts + cluster.metadata.get_replicas.return_value = hosts[2:] + cluster.metadata._tablets.table_has_tablets.return_value = True + cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]]) + return cluster + @patch('cassandra.policies.shuffle') + def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): + hosts = cluster.metadata.all_hosts() + replicas = cluster.metadata.get_replicas() child_policy = Mock() child_policy.make_query_plan.return_value = hosts child_policy.distance.return_value = HostDistance.LOCAL @@ -846,6 +870,8 @@ def _assert_shuffle(self, patched_shuffle, keyspace, routing_key): policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) policy.populate(cluster, hosts) + is_tablets = cluster.metadata._tablets.table_has_tablets() + cluster.metadata.get_replicas.reset_mock() child_policy.make_query_plan.reset_mock() query = Statement(routing_key=routing_key) @@ -858,7 +884,11 @@ def _assert_shuffle(self, patched_shuffle, keyspace, routing_key): else: assert set(replicas) == set(qplan[:2]) assert hosts[:2] == qplan[2:] - child_policy.make_query_plan.assert_called_once_with(keyspace, query) + if is_tablets: + child_policy.make_query_plan.assert_called_with(keyspace, query) + assert child_policy.make_query_plan.call_count == 2 + else: + child_policy.make_query_plan.assert_called_once_with(keyspace, query) assert patched_shuffle.call_count == 1 @@ -1538,7 +1568,6 @@ def test_query_plan_deferred_to_child(self): def test_wrap_token_aware(self): cluster = Mock(spec=Cluster) - cluster.control_connection._tablets_routing_v1 = False hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] for host in hosts: host.set_up() @@ -1547,6 +1576,8 @@ def get_replicas(keyspace, packed_key): return hosts[:2] cluster.metadata.get_replicas.side_effect = get_replicas + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.table_has_tablets.return_value = [] child_policy = TokenAwarePolicy(RoundRobinPolicy())