diff --git a/cassandra/policies.py b/cassandra/policies.py index cb83238e87..efef0bb2a5 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -14,7 +14,6 @@ import random from collections import namedtuple -from functools import lru_cache from itertools import islice, cycle, groupby, repeat import logging from random import randint, shuffle @@ -254,7 +253,7 @@ def _dc(self, host): def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): - self._dc_live_hosts[dc] = tuple(set(dc_hosts)) + self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])}) if not self.local_dc: self._endpoints = [ @@ -374,9 +373,9 @@ def _dc(self, host): def populate(self, cluster, hosts): for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))): - self._live_hosts[(dc, rack)] = tuple(set(rack_hosts)) + self._live_hosts[(dc, rack)] = tuple({*rack_hosts, *self._live_hosts.get((dc, rack), [])}) for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)): - self._dc_live_hosts[dc] = tuple(set(dc_hosts)) + self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])}) self._position = randint(0, len(hosts) - 1) if hosts else 0 diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index c98511ab34..15970aadfa 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import random import unittest from itertools import islice, cycle @@ -199,6 +199,8 @@ def test_no_remote(self, policy_specialization, constructor_args): h.set_location_info("dc1", "rack1") hosts.append(h) + random.shuffle(hosts) + policy = policy_specialization(*constructor_args) policy.populate(None, hosts) qplan = list(policy.make_query_plan()) @@ -213,6 +215,8 @@ def test_with_remotes(self, policy_specialization, constructor_args): for h in hosts[4:]: h.set_location_info("dc2", "rack1") + random.shuffle(hosts) + local_rack_hosts = set(h for h in hosts if h.datacenter == "dc1" and h.rack == "rack1") local_hosts = set(h for h in hosts if h.datacenter == "dc1" and h.rack != "rack1") remote_hosts = set(h for h in hosts if h.datacenter != "dc1")