diff --git a/src/Cassandra.IntegrationTests/Core/ClusterTests.cs b/src/Cassandra.IntegrationTests/Core/ClusterTests.cs index ea92e6401..2cc623038 100644 --- a/src/Cassandra.IntegrationTests/Core/ClusterTests.cs +++ b/src/Cassandra.IntegrationTests/Core/ClusterTests.cs @@ -245,9 +245,9 @@ public HostDistance Distance(Host host) return HostDistance.Local; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { - return _cluster.AllHosts(); + return _cluster.AllHosts().Select(h => new HostShard(h, -1)); } } } diff --git a/src/Cassandra.IntegrationTests/Core/ConnectionSimulacronTests.cs b/src/Cassandra.IntegrationTests/Core/ConnectionSimulacronTests.cs index 0b569dbe1..0894bea5b 100644 --- a/src/Cassandra.IntegrationTests/Core/ConnectionSimulacronTests.cs +++ b/src/Cassandra.IntegrationTests/Core/ConnectionSimulacronTests.cs @@ -386,10 +386,10 @@ public HostDistance Distance(Host host) return _parent.Distance(host); } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { var plan = _parent.NewQueryPlan(keyspace, query); - return plan.Where(h => !_disallowed.Contains(h.Address)); + return plan.Where(h => !_disallowed.Contains(h.Host.Address)); } } } diff --git a/src/Cassandra.IntegrationTests/Core/SessionTests.cs b/src/Cassandra.IntegrationTests/Core/SessionTests.cs index 85e8b3afe..0b3e44eed 100644 --- a/src/Cassandra.IntegrationTests/Core/SessionTests.cs +++ b/src/Cassandra.IntegrationTests/Core/SessionTests.cs @@ -331,7 +331,7 @@ public HostDistance Distance(Host host) return HostDistance.Local; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { return _childPolicy.NewQueryPlan(keyspace, query); } diff --git a/src/Cassandra.IntegrationTests/Core/SpeculativeExecutionShortTests.cs b/src/Cassandra.IntegrationTests/Core/SpeculativeExecutionShortTests.cs index 9b383e384..269f0152c 100644 --- a/src/Cassandra.IntegrationTests/Core/SpeculativeExecutionShortTests.cs +++ b/src/Cassandra.IntegrationTests/Core/SpeculativeExecutionShortTests.cs @@ -276,14 +276,14 @@ public HostDistance Distance(Host host) return HostDistance.Local; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { var hosts = _cluster.AllHosts().ToArray(); foreach (var addr in _addresses) { var host = hosts.Single(h => h.Address.Address.ToString() == addr); Interlocked.Increment(ref _hostYielded); - yield return host; + yield return new HostShard(host, -1); } } } diff --git a/src/Cassandra.IntegrationTests/MetadataTests/TokenMapTopologyChangeTests.cs b/src/Cassandra.IntegrationTests/MetadataTests/TokenMapTopologyChangeTests.cs index db22f287e..8df93ec31 100644 --- a/src/Cassandra.IntegrationTests/MetadataTests/TokenMapTopologyChangeTests.cs +++ b/src/Cassandra.IntegrationTests/MetadataTests/TokenMapTopologyChangeTests.cs @@ -69,8 +69,8 @@ public void TokenMap_Should_RebuildTokenMap_When_NodeIsDecommissioned() sessionNotSync.ChangeKeyspace(keyspaceName); sessionSync.ChangeKeyspace(keyspaceName); - ICollection replicasSync = null; - ICollection replicasNotSync = null; + ICollection replicasSync = null; + ICollection replicasNotSync = null; TestHelper.RetryAssert(() => { diff --git a/src/Cassandra.IntegrationTests/Policies/Tests/LoadBalancingPolicyShortTests.cs b/src/Cassandra.IntegrationTests/Policies/Tests/LoadBalancingPolicyShortTests.cs index 71d9e0e01..c0fce0374 100644 --- a/src/Cassandra.IntegrationTests/Policies/Tests/LoadBalancingPolicyShortTests.cs +++ b/src/Cassandra.IntegrationTests/Policies/Tests/LoadBalancingPolicyShortTests.cs @@ -382,7 +382,7 @@ public void Token_Aware_Uses_Keyspace_From_Statement_To_Determine_Replication(bo // Get the replicas var replicas = cluster.GetReplicas(ks, routingKey); Assert.AreEqual(metadataSync ? 2 : 1, replicas.Count); - CollectionAssert.AreEquivalent(replicas.Select(h => h.Address), coordinators); + CollectionAssert.AreEquivalent(replicas.Select(h => h.Host.Address), coordinators); } finally { diff --git a/src/Cassandra.IntegrationTests/Policies/Util/CustomLoadBalancingPolicy.cs b/src/Cassandra.IntegrationTests/Policies/Util/CustomLoadBalancingPolicy.cs index 2d98b588c..cdd4c1bbb 100644 --- a/src/Cassandra.IntegrationTests/Policies/Util/CustomLoadBalancingPolicy.cs +++ b/src/Cassandra.IntegrationTests/Policies/Util/CustomLoadBalancingPolicy.cs @@ -39,13 +39,13 @@ public HostDistance Distance(Host host) return HostDistance.Local; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { - var queryPlan = new List(); + var queryPlan = new List(); var allHosts = _cluster.AllHosts(); foreach (var host in _hosts) { - queryPlan.Add(allHosts.Single(h => h.Address.ToString() == host)); + queryPlan.Add(new HostShard(allHosts.Single(h => h.Address.ToString() == host), -1)); } return queryPlan; } diff --git a/src/Cassandra.IntegrationTests/ScyllaTabletTests.cs b/src/Cassandra.IntegrationTests/ScyllaTabletTests.cs index 5a70c19da..4edd35c2c 100644 --- a/src/Cassandra.IntegrationTests/ScyllaTabletTests.cs +++ b/src/Cassandra.IntegrationTests/ScyllaTabletTests.cs @@ -1,6 +1,7 @@ using System; using System.Diagnostics; using System.Linq; +using System.Collections.Generic; using Cassandra.IntegrationTests.TestBase; using Cassandra.IntegrationTests.TestClusterManagement; using NUnit.Framework; @@ -66,5 +67,65 @@ public void CorrectTabletMapTest() Assert.IsTrue(tablets.Count > 0, "Make sure tablets are present in the tablet set"); } } + + [Test] + public void CorrectShardUsingTabletsInTracingTest() + { + _realCluster = TestClusterManager.CreateNew(); + var cluster = ClusterBuilder() + .WithSocketOptions(new SocketOptions().SetReadTimeoutMillis(22000).SetConnectTimeoutMillis(60000)) + .AddContactPoint(_realCluster.InitialContactPoint) + .Build(); + var _session = cluster.Connect(); + + var rf = 1; + _session.Execute("DROP KEYSPACE IF EXISTS tablettest"); + _session.Execute($"CREATE KEYSPACE tablettest WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': '{rf}'}}"); + _session.Execute("CREATE TABLE tablettest.t (pk text, ck text, v text, PRIMARY KEY (pk, ck))"); + + var populateStatement = _session.Prepare("INSERT INTO tablettest.t (pk, ck, v) VALUES (?, ?, ?)"); + //Insert 50 rows to ensure that the tablet map is populated correctly + for (int i = 0; i < 50; i++) + { + _session.Execute(populateStatement.Bind(i.ToString(), "ck" + i, "v" + i)); + } + + for (int i = 0; i < 50; i++) + { + _session.Execute(populateStatement.Bind(i.ToString(), "ck" + i, "v" + i)); + } + } + + private void VerifyCorrectShardSingleRow(ISession _session, string pk, string ck, string v) + { + var prepared = _session.Prepare("SELECT pk, ck, v FROM tablettest.t WHERE pk=? AND ck=?"); + var result = _session.Execute(prepared.Bind(pk, ck).EnableTracing()); + + var row = result.First(); + Assert.IsNotNull(row); + Assert.AreEqual(pk, row.GetValue("pk")); + Assert.AreEqual(ck, row.GetValue("ck")); + Assert.AreEqual(v, row.GetValue("v")); + + var executionInfo = result.Info; + var trace = executionInfo.QueryTrace; + bool anyLocal = false; + var shardSet = new HashSet(); + foreach (var eventItem in trace.Events) + { + Trace.TraceInformation(" {0} - {1} - [{2}] - {3}", + eventItem.SourceElapsedMicros, + eventItem.Source, + eventItem.ThreadName, + eventItem.Description); + shardSet.Add(eventItem.ThreadName); + if (eventItem.Description.Contains("querying locally")) + { + anyLocal = true; + } + } + Assert.IsTrue(shardSet.Count == 1); + Assert.IsTrue(anyLocal); + } } } diff --git a/src/Cassandra.Tests/BaseUnitTest.cs b/src/Cassandra.Tests/BaseUnitTest.cs index fed7c9b55..91d239a9a 100644 --- a/src/Cassandra.Tests/BaseUnitTest.cs +++ b/src/Cassandra.Tests/BaseUnitTest.cs @@ -79,12 +79,12 @@ public HostDistance Distance(Host host) return _distance; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { return new[] { - new Host(new IPEndPoint(101L, 9042), ReconnectionPolicy), - new Host(new IPEndPoint(102L, 9042), ReconnectionPolicy) + new HostShard(new Host(new IPEndPoint(101L, 9042), ReconnectionPolicy), -1), + new HostShard(new Host(new IPEndPoint(102L, 9042), ReconnectionPolicy), -1) }; } } diff --git a/src/Cassandra.Tests/ClusterTests.cs b/src/Cassandra.Tests/ClusterTests.cs index ef4e33f94..0af8d9bd4 100644 --- a/src/Cassandra.Tests/ClusterTests.cs +++ b/src/Cassandra.Tests/ClusterTests.cs @@ -291,9 +291,12 @@ public HostDistance Distance(Host host) return _distances[host.Address.Address.ToString()]; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { - return _cluster.AllHosts().OrderBy(h => Guid.NewGuid().GetHashCode()).Take(_distances.Count); + return _cluster.AllHosts() + .OrderBy(h => Guid.NewGuid().GetHashCode()) + .Take(_distances.Count) + .Select(h => new HostShard(h, -1)); } } } diff --git a/src/Cassandra.Tests/ExecutionProfiles/ClusterTests.cs b/src/Cassandra.Tests/ExecutionProfiles/ClusterTests.cs index 1bd61c669..5baf9d292 100644 --- a/src/Cassandra.Tests/ExecutionProfiles/ClusterTests.cs +++ b/src/Cassandra.Tests/ExecutionProfiles/ClusterTests.cs @@ -362,9 +362,9 @@ public HostDistance Distance(Host host) return HostDistance.Local; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { - return _cluster.AllHosts(); + return _cluster.AllHosts().Select(h => new HostShard(h, -1)); } } } diff --git a/src/Cassandra.Tests/ExecutionProfiles/RequestHandlerTests.cs b/src/Cassandra.Tests/ExecutionProfiles/RequestHandlerTests.cs index 14096c86b..252eb597e 100644 --- a/src/Cassandra.Tests/ExecutionProfiles/RequestHandlerTests.cs +++ b/src/Cassandra.Tests/ExecutionProfiles/RequestHandlerTests.cs @@ -379,13 +379,13 @@ public HostDistance Distance(Host host) return HostDistance.Local; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { Interlocked.Increment(ref Count); - return new List + return new List { - new Host(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 9042), contactPoint: null), - new Host(new IPEndPoint(IPAddress.Parse("127.0.0.2"), 9042), contactPoint: null) // 2 hosts for speculative execution policy + new HostShard(new Host(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 9042), contactPoint: null), -1), + new HostShard(new Host(new IPEndPoint(IPAddress.Parse("127.0.0.2"), 9042), contactPoint: null), -1) // 2 hosts for speculative execution policy }; } } diff --git a/src/Cassandra.Tests/Policies/DefaultLoadBalancingTests.cs b/src/Cassandra.Tests/Policies/DefaultLoadBalancingTests.cs index 42c739837..b6ce89c9f 100644 --- a/src/Cassandra.Tests/Policies/DefaultLoadBalancingTests.cs +++ b/src/Cassandra.Tests/Policies/DefaultLoadBalancingTests.cs @@ -32,7 +32,7 @@ public void Should_Yield_Preferred_Host_First() var hosts = lbp.NewQueryPlan(null, statement); CollectionAssert.AreEqual( new[] { "201.0.0.0:9042", "101.0.0.0:9042", "102.0.0.0:9042" }, - hosts.Select(h => h.Address.ToString())); + hosts.Select(h => h.Host.Address.ToString())); } [Test] @@ -45,7 +45,7 @@ public void Should_Yield_Child_Hosts_When_No_Preferred_Host_Defined() var hosts = lbp.NewQueryPlan(null, statement); CollectionAssert.AreEqual( new[] { "101.0.0.0:9042", "102.0.0.0:9042" }, - hosts.Select(h => h.Address.ToString())); + hosts.Select(h => h.Host.Address.ToString())); } [Test] diff --git a/src/Cassandra.Tests/PoliciesUnitTests.cs b/src/Cassandra.Tests/PoliciesUnitTests.cs index 928e6e574..0f1b03b07 100644 --- a/src/Cassandra.Tests/PoliciesUnitTests.cs +++ b/src/Cassandra.Tests/PoliciesUnitTests.cs @@ -52,11 +52,11 @@ public void RoundRobinIsCyclicTest() foreach (var host in hostList) { //Check that each host appears only once - Assert.AreEqual(1, firstRound.Where(h => h.Equals(host)).Count()); + Assert.AreEqual(1, firstRound.Where(h => h.Host.Equals(host)).Count()); } //test the same but in the following times - var followingRounds = new List(); + var followingRounds = new List(); for (var i = 0; i < 10; i++) { followingRounds.AddRange(policy.NewQueryPlan(null, new SimpleStatement())); @@ -98,7 +98,7 @@ public void RoundRobinIsCyclicTestInParallel() { //Slow down to try to execute it at the same time Thread.Sleep(50); - resultingHosts.Add(h); + resultingHosts.Add(h.Host); } Assert.AreEqual(hostLength, resultingHosts.Count); Assert.AreEqual(hostLength, resultingHosts.Distinct().Count()); @@ -143,11 +143,11 @@ public void DCAwareRoundRobinPolicyNeverHitsRemoteWhenSet() var firstRound = balancedHosts.ToList(); //Returns only local hosts - Assert.AreEqual(hostLength - 2, firstRound.Count(h => h.Datacenter == "local")); - Assert.AreEqual(0, firstRound.Count(h => h.Datacenter != "local")); + Assert.AreEqual(hostLength - 2, firstRound.Count(h => h.Host.Datacenter == "local")); + Assert.AreEqual(0, firstRound.Count(h => h.Host.Datacenter != "local")); //following rounds: test it multiple times - var followingRounds = new List(); + var followingRounds = new List(); for (var i = 0; i < 10; i++) { followingRounds.AddRange(policy.NewQueryPlan(null, new SimpleStatement()).ToList()); @@ -155,7 +155,7 @@ public void DCAwareRoundRobinPolicyNeverHitsRemoteWhenSet() Assert.AreEqual(10 * (hostLength - 2), followingRounds.Count); //Check that there aren't remote nodes. - Assert.AreEqual(0, followingRounds.Count(h => h.Datacenter != "local")); + Assert.AreEqual(0, followingRounds.Count(h => h.Host.Datacenter != "local")); } [Test] @@ -194,11 +194,11 @@ public void DCAwareRoundRobinYieldsRemoteNodesAtTheEnd() var h = hosts[i]; if (i < localHostsLength) { - Assert.AreEqual(localDc, h.Datacenter); + Assert.AreEqual(localDc, h.Host.Datacenter); } else { - Assert.AreNotEqual(localDc, h.Datacenter); + Assert.AreNotEqual(localDc, h.Host.Datacenter); } } }; @@ -278,11 +278,11 @@ public void DCAwareRoundRobinPolicyTestInParallel() .Where(g => g.Count() > 1) .Select(y => y.Key) .Count()); - firstHosts.Add(hosts[0]); + firstHosts.Add(hosts[0].Host); //Add to the general list foreach (var h in hosts) { - allHosts.Add(h); + allHosts.Add(h.Host); } }; @@ -373,7 +373,7 @@ public void DCAwareRoundRobinPolicyWithNodesChanging() policy.Initialize(clusterMock.Object); var hostYielded = new ConcurrentBag>(); - Action action = () => hostYielded.Add(policy.NewQueryPlan(null, null).ToList()); + Action action = () => hostYielded.Add(policy.NewQueryPlan(null, null).ToList().Select(h => h.Host)); //Invoke without nodes changing TestHelper.ParallelInvoke(action, 100); @@ -496,7 +496,7 @@ public void TokenAwarePolicyReturnsLocalReplicasFirst() //The host at with address == k || address == k + n var address = TestHelper.GetLastAddressByte(h); return address == i || address == i + n; - }).ToList(); + }).Select(h => new HostShard(h, -1)).ToList(); }) .Verifiable(); @@ -509,7 +509,7 @@ public void TokenAwarePolicyReturnsLocalReplicasFirst() //5 local hosts + 2 remote hosts Assert.AreEqual(7, hosts.Count); //local replica first - Assert.AreEqual(1, TestHelper.GetLastAddressByte(hosts[0])); + Assert.AreEqual(1, TestHelper.GetLastAddressByte(hosts[0].Host)); clusterMock.Verify(); //key for host :::2 and :::5 @@ -518,11 +518,11 @@ public void TokenAwarePolicyReturnsLocalReplicasFirst() hosts = policy.NewQueryPlan(null, new SimpleStatement().SetRoutingKey(k)).ToList(); Assert.AreEqual(7, hosts.Count); //local replicas first - CollectionAssert.AreEquivalent(new[] { 2, 5 }, hosts.Take(2).Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEquivalent(new[] { 2, 5 }, hosts.Take(2).Select(h => TestHelper.GetLastAddressByte(h.Host))); //next should be local nodes - Assert.AreEqual("dc1", hosts[2].Datacenter); - Assert.AreEqual("dc1", hosts[3].Datacenter); - Assert.AreEqual("dc1", hosts[4].Datacenter); + Assert.AreEqual("dc1", hosts[2].Host.Datacenter); + Assert.AreEqual("dc1", hosts[3].Host.Datacenter); + Assert.AreEqual("dc1", hosts[4].Host.Datacenter); clusterMock.Verify(); } @@ -557,7 +557,7 @@ public void TokenAwarePolicyRoundRobinsOnLocalReplicas() //The host at with address == k and the next one var address = TestHelper.GetLastAddressByte(h); return address == i || address == i + 1; - }).ToList(); + }).Select(h => new HostShard(h, -1)).ToList(); }) .Verifiable(); @@ -571,7 +571,7 @@ public void TokenAwarePolicyRoundRobinsOnLocalReplicas() Action action = () => { var h = policy.NewQueryPlan(null, new SimpleStatement().SetRoutingKey(k)).First(); - firstHosts.Add(h); + firstHosts.Add(h.Host); }; TestHelper.ParallelInvoke(action, times); Assert.AreEqual(times, firstHosts.Count); @@ -605,16 +605,16 @@ public void TokenAwarePolicyReturnsChildHostsIfNoRoutingKey() //No routing key var hosts = policy.NewQueryPlan(null, new SimpleStatement()).ToList(); //2 localhosts - Assert.AreEqual(2, hosts.Count(h => policy.Distance(h) == HostDistance.Local)); - Assert.AreEqual("dc1", hosts[0].Datacenter); - Assert.AreEqual("dc1", hosts[1].Datacenter); + Assert.AreEqual(2, hosts.Count(h => policy.Distance(h.Host) == HostDistance.Local)); + Assert.AreEqual("dc1", hosts[0].Host.Datacenter); + Assert.AreEqual("dc1", hosts[1].Host.Datacenter); clusterMock.Verify(); //No statement hosts = policy.NewQueryPlan(null, null).ToList(); //2 localhosts - Assert.AreEqual(2, hosts.Count(h => policy.Distance(h) == HostDistance.Local)); - Assert.AreEqual("dc1", hosts[0].Datacenter); - Assert.AreEqual("dc1", hosts[1].Datacenter); + Assert.AreEqual(2, hosts.Count(h => policy.Distance(h.Host) == HostDistance.Local)); + Assert.AreEqual("dc1", hosts[0].Host.Datacenter); + Assert.AreEqual("dc1", hosts[1].Host.Datacenter); clusterMock.Verify(); } diff --git a/src/Cassandra.Tests/RequestExecutionTests.cs b/src/Cassandra.Tests/RequestExecutionTests.cs index 34e754499..64d83f582 100644 --- a/src/Cassandra.Tests/RequestExecutionTests.cs +++ b/src/Cassandra.Tests/RequestExecutionTests.cs @@ -130,7 +130,7 @@ public void Should_SendRequest_When_AConnectionIsObtained(bool currentHostRetry) host, HostDistance.Local); Mock.Get(mockParent) - .Setup(m => m.GetConnectionToValidHostAsync(validHost, It.IsAny>())) + .Setup(m => m.GetConnectionToValidHostAsync(validHost, It.IsAny>(), It.IsAny())) .ReturnsAsync(connection); Mock.Get(mockParent) .Setup(m => m.GetNextValidHost(It.IsAny>())) @@ -208,12 +208,12 @@ public void Should_RetryRequestToSameHost_When_ConnectionFailsAndRetryDecisionIs // Setup connection failure Mock.Get(mockParent) - .Setup(m => m.GetConnectionToValidHostAsync(validHost, It.IsAny>())) + .Setup(m => m.GetConnectionToValidHostAsync(validHost, It.IsAny>(), It.IsAny())) .ThrowsAsync(exception); // Setup successful second connection on the same host retry (different method call - ValidateHostAndGetConnectionAsync) Mock.Get(mockParent) - .Setup(m => m.ValidateHostAndGetConnectionAsync(validHost.Host, It.IsAny>())) + .Setup(m => m.ValidateHostAndGetConnectionAsync(new HostShard(validHost.Host, -1), It.IsAny>())) .ReturnsAsync(connection); Mock.Get(mockParent) @@ -234,10 +234,10 @@ public void Should_RetryRequestToSameHost_When_ConnectionFailsAndRetryDecisionIs // Validate that there were 2 connection attempts (1 with each method) Mock.Get(mockParent).Verify( - m => m.GetConnectionToValidHostAsync(validHost, It.IsAny>()), + m => m.GetConnectionToValidHostAsync(validHost, It.IsAny>(), It.IsAny()), Times.Once); Mock.Get(mockParent).Verify( - m => m.ValidateHostAndGetConnectionAsync(validHost.Host, It.IsAny>()), + m => m.ValidateHostAndGetConnectionAsync(new HostShard(validHost.Host, -1), It.IsAny>()), Times.Once); } } diff --git a/src/Cassandra.Tests/RequestHandlerMockTests.cs b/src/Cassandra.Tests/RequestHandlerMockTests.cs index f5a4b32c4..03c5e5190 100644 --- a/src/Cassandra.Tests/RequestHandlerMockTests.cs +++ b/src/Cassandra.Tests/RequestHandlerMockTests.cs @@ -70,8 +70,8 @@ public void Should_ThrowNoHostAvailableException_When_QueryPlanMoveNextReturnsFa var sessionMock = GetMockInternalSession(); var lbpMock = Mock.Of(); Mock.Get(sessionMock).SetupGet(m => m.Cluster.Configuration).Returns(RequestHandlerMockTests.GetConfig(lbpMock)); - var enumerable = Mock.Of>(); - var enumerator = Mock.Of>(); + var enumerable = Mock.Of>(); + var enumerator = Mock.Of>(); Mock.Get(enumerator).Setup(m => m.MoveNext()).Returns(false); Mock.Get(enumerable).Setup(m => m.GetEnumerator()).Returns(enumerator); @@ -93,11 +93,11 @@ public void Should_ThrowNoHostAvailableException_When_QueryPlanMoveNextReturnsTr var sessionMock = GetMockInternalSession(); var lbpMock = Mock.Of(); Mock.Get(sessionMock).SetupGet(m => m.Cluster.Configuration).Returns(RequestHandlerMockTests.GetConfig(lbpMock)); - var enumerable = Mock.Of>(); - var enumerator = Mock.Of>(); + var enumerable = Mock.Of>(); + var enumerator = Mock.Of>(); Mock.Get(enumerator).Setup(m => m.MoveNext()).Returns(true); - Mock.Get(enumerator).SetupGet(m => m.Current).Returns((Host)null); + Mock.Get(enumerator).SetupGet(m => m.Current).Returns((HostShard)null); Mock.Get(enumerable).Setup(m => m.GetEnumerator()).Returns(enumerator); Mock.Get(lbpMock) .Setup(m => m.NewQueryPlan(It.IsAny(), It.IsAny())) @@ -117,23 +117,23 @@ public void Should_ReturnHost_When_QueryPlanMoveNextReturnsTrueAndCurrentReturns var sessionMock = GetMockInternalSession(); var lbpMock = Mock.Of(); Mock.Get(sessionMock).SetupGet(m => m.Cluster.Configuration).Returns(RequestHandlerMockTests.GetConfig(lbpMock)); - var enumerable = Mock.Of>(); - var enumerator = Mock.Of>(); - var host = new Host(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 9047), contactPoint: null); + var enumerable = Mock.Of>(); + var enumerator = Mock.Of>(); + var hostShard = new HostShard(new Host(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 9047), contactPoint: null), -1); Mock.Get(enumerator).Setup(m => m.MoveNext()).Returns(true); - Mock.Get(enumerator).SetupGet(m => m.Current).Returns(host); + Mock.Get(enumerator).SetupGet(m => m.Current).Returns(hostShard); Mock.Get(enumerable).Setup(m => m.GetEnumerator()).Returns(enumerator); Mock.Get(lbpMock) .Setup(m => m.NewQueryPlan(It.IsAny(), It.IsAny())) .Returns(enumerable); - Mock.Get(lbpMock).Setup(m => m.Distance(host)).Returns(HostDistance.Local); + Mock.Get(lbpMock).Setup(m => m.Distance(hostShard.Host)).Returns(HostDistance.Local); var triedHosts = new Dictionary(); var requestTrackingInfoAndObserver = RequestHandler.CreateRequestObserver(sessionMock, null).GetAwaiter().GetResult(); var sut = new RequestHandler(sessionMock, new SerializerManager(ProtocolVersion.V4).GetCurrentSerializer(), requestTrackingInfoAndObserver.Item1, requestTrackingInfoAndObserver.Item2); var validHost = sut.GetNextValidHost(triedHosts); Assert.NotNull(validHost); - Assert.AreEqual(host, validHost.Host); + Assert.AreEqual(hostShard.Host, validHost.Host); } } } diff --git a/src/Cassandra.Tests/Requests/PrepareHandlerTests.cs b/src/Cassandra.Tests/Requests/PrepareHandlerTests.cs index f2968fbb2..b480845b1 100644 --- a/src/Cassandra.Tests/Requests/PrepareHandlerTests.cs +++ b/src/Cassandra.Tests/Requests/PrepareHandlerTests.cs @@ -77,10 +77,10 @@ public async Task Should_NotSendRequestToSecondHost_When_SecondHostDoesntHavePoo }; var queryPlan = mockResult.Session.InternalCluster .GetResolvedEndpoints() - .Select(x => new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null)) + .Select(x => new HostShard(new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null), -1)) .ToList(); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[0], HostDistance.Local).Warmup().ConfigureAwait(false); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2], HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[0].Host, HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2].Host, HostDistance.Local).Warmup().ConfigureAwait(false); var pools = mockResult.Session.GetPools().ToList(); Assert.AreEqual(2, pools.Count); var distanceCount = Interlocked.Read(ref lbpCluster.DistanceCount); @@ -99,8 +99,8 @@ await mockResult.PrepareHandler.Prepare( Assert.AreEqual(distanceCount + 1, Interlocked.Read(ref lbpCluster.DistanceCount), 1); Assert.AreEqual(Interlocked.Read(ref lbpCluster.NewQueryPlanCount), 0); Assert.AreEqual(2, mockResult.ConnectionFactory.CreatedConnections.Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Host.Address].Count); // Assert that each pool contains only one connection that was called send var poolConnections = pools.Select(p => p.Value.ConnectionsSnapshot.Intersect(results.Select(r => r.Connection))).ToList(); Assert.AreEqual(2, poolConnections.Count); @@ -148,11 +148,11 @@ public async Task Should_NotSendRequestToSecondHost_When_SecondHostPoolDoesNotHa }; var queryPlan = mockResult.Session.InternalCluster .GetResolvedEndpoints() - .Select(x => new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null)) + .Select(x => new HostShard(new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null), -1)) .ToList(); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[0], HostDistance.Local).Warmup().ConfigureAwait(false); - mockResult.Session.GetOrCreateConnectionPool(queryPlan[1], HostDistance.Local); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2], HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[0].Host, HostDistance.Local).Warmup().ConfigureAwait(false); + mockResult.Session.GetOrCreateConnectionPool(queryPlan[1].Host, HostDistance.Local); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2].Host, HostDistance.Local).Warmup().ConfigureAwait(false); var pools = mockResult.Session.GetPools().ToList(); Assert.AreEqual(3, pools.Count); var distanceCount = Interlocked.Read(ref lbpCluster.DistanceCount); @@ -171,8 +171,8 @@ await mockResult.PrepareHandler.Prepare( Assert.AreEqual(distanceCount + 1, Interlocked.Read(ref lbpCluster.DistanceCount), 1); Assert.AreEqual(Interlocked.Read(ref lbpCluster.NewQueryPlanCount), 0); Assert.AreEqual(2, mockResult.ConnectionFactory.CreatedConnections.Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Host.Address].Count); // Assert that each pool that contains connections contains only one connection that was called send var poolConnections = pools.Select(p => p.Value.ConnectionsSnapshot.Intersect(results.Select(r => r.Connection))).Where(p => p.Any()).ToList(); Assert.AreEqual(2, poolConnections.Count); @@ -220,11 +220,11 @@ public async Task Should_SendRequestToAllHosts_When_AllHostsHaveConnections() }; var queryPlan = mockResult.Session.InternalCluster .GetResolvedEndpoints() - .Select(x => new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null)) + .Select(x => new HostShard(new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null), -1)) .ToList(); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[0], HostDistance.Local).Warmup().ConfigureAwait(false); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[1], HostDistance.Local).Warmup().ConfigureAwait(false); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2], HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[0].Host, HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[1].Host, HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2].Host, HostDistance.Local).Warmup().ConfigureAwait(false); var pools = mockResult.Session.GetPools().ToList(); Assert.AreEqual(3, pools.Count); @@ -244,9 +244,9 @@ await mockResult.PrepareHandler.Prepare( Assert.AreEqual(distanceCount + 1, Interlocked.Read(ref lbpCluster.DistanceCount), 1); Assert.AreEqual(Interlocked.Read(ref lbpCluster.NewQueryPlanCount), 0); Assert.AreEqual(3, mockResult.ConnectionFactory.CreatedConnections.Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[1].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[1].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Host.Address].Count); // Assert that each pool contains only one connection that was called send var poolConnections = pools.Select(p => p.Value.ConnectionsSnapshot.Intersect(results.Select(r => r.Connection))).ToList(); Assert.AreEqual(3, poolConnections.Count); @@ -294,10 +294,10 @@ public async Task Should_SendRequestToAllHosts_When_AllHostsHaveConnectionsButFi }; var queryPlan = mockResult.Session.InternalCluster .GetResolvedEndpoints() - .Select(x => new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null)) + .Select(x => new HostShard(new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null), -1)) .ToList(); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[1], HostDistance.Local).Warmup().ConfigureAwait(false); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2], HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[1].Host, HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2].Host, HostDistance.Local).Warmup().ConfigureAwait(false); var pools = mockResult.Session.GetPools().ToList(); Assert.AreEqual(2, pools.Count); var distanceCount = Interlocked.Read(ref lbpCluster.DistanceCount); @@ -316,9 +316,9 @@ await mockResult.PrepareHandler.Prepare( Assert.AreEqual(distanceCount + 1, Interlocked.Read(ref lbpCluster.DistanceCount), 1); Assert.AreEqual(Interlocked.Read(ref lbpCluster.NewQueryPlanCount), 0); Assert.AreEqual(3, mockResult.ConnectionFactory.CreatedConnections.Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[1].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[1].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Host.Address].Count); // Assert that each pool contains only one connection that was called send var poolConnections = pools.Select(p => p.Value.ConnectionsSnapshot.Intersect(results.Select(r => r.Connection))).ToList(); Assert.AreEqual(3, poolConnections.Count); @@ -366,10 +366,10 @@ public async Task Should_SendRequestToAllHosts_When_AllHostsHaveConnectionsButFi }; var queryPlan = mockResult.Session.InternalCluster .GetResolvedEndpoints() - .Select(x => new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null)) + .Select(x => new HostShard(new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null), -1)) .ToList(); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[1], HostDistance.Local).Warmup().ConfigureAwait(false); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2], HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[1].Host, HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2].Host, HostDistance.Local).Warmup().ConfigureAwait(false); var pools = mockResult.Session.GetPools().ToList(); Assert.AreEqual(2, pools.Count); var distanceCount = Interlocked.Read(ref lbpCluster.DistanceCount); @@ -388,9 +388,9 @@ await mockResult.PrepareHandler.Prepare( Assert.AreEqual(distanceCount + 1, Interlocked.Read(ref lbpCluster.DistanceCount), 1); Assert.AreEqual(Interlocked.Read(ref lbpCluster.NewQueryPlanCount), 0); Assert.AreEqual(3, mockResult.ConnectionFactory.CreatedConnections.Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[1].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[1].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Host.Address].Count); // Assert that each pool contains only one connection that was called send var poolConnections = pools.Select(p => p.Value.ConnectionsSnapshot.Intersect(results.Select(r => r.Connection))).ToList(); Assert.AreEqual(3, poolConnections.Count); @@ -439,10 +439,10 @@ public async Task Should_SendRequestToFirstHostOnly_When_PrepareOnAllHostsIsFals }; var queryPlan = mockResult.Session.InternalCluster .GetResolvedEndpoints() - .Select(x => new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null)) + .Select(x => new HostShard(new Host(x.Value.First().GetHostIpEndPointWithFallback(), contactPoint: null), -1)) .ToList(); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[1], HostDistance.Local).Warmup().ConfigureAwait(false); - await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2], HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[1].Host, HostDistance.Local).Warmup().ConfigureAwait(false); + await mockResult.Session.GetOrCreateConnectionPool(queryPlan[2].Host, HostDistance.Local).Warmup().ConfigureAwait(false); var pools = mockResult.Session.GetPools().ToList(); Assert.AreEqual(2, pools.Count); var distanceCount = Interlocked.Read(ref lbpCluster.DistanceCount); @@ -461,14 +461,14 @@ await mockResult.PrepareHandler.Prepare( Assert.AreEqual(distanceCount + 1, Interlocked.Read(ref lbpCluster.DistanceCount), 1); Assert.AreEqual(Interlocked.Read(ref lbpCluster.NewQueryPlanCount), 0); Assert.AreEqual(3, mockResult.ConnectionFactory.CreatedConnections.Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[1].Address].Count); - Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[1].Host.Address].Count); + Assert.LessOrEqual(1, mockResult.ConnectionFactory.CreatedConnections[queryPlan[2].Host.Address].Count); // Assert that pool of first host contains only one connection that was called send var poolConnections = pools .Select(p => p.Value.ConnectionsSnapshot.Intersect(results.Select(r => r.Connection))) - .Where(p => mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Address].Contains(p.SingleOrDefault())) + .Where(p => mockResult.ConnectionFactory.CreatedConnections[queryPlan[0].Host.Address].Contains(p.SingleOrDefault())) .ToList(); Assert.AreEqual(1, poolConnections.Count); foreach (var pool in poolConnections) @@ -568,7 +568,7 @@ public HostDistance Distance(Host host) return HostDistance.Local; } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { Interlocked.Increment(ref NewQueryPlanCount); throw new NotImplementedException(); diff --git a/src/Cassandra.Tests/StatementTests.cs b/src/Cassandra.Tests/StatementTests.cs index 1591e9e0c..07a125de4 100644 --- a/src/Cassandra.Tests/StatementTests.cs +++ b/src/Cassandra.Tests/StatementTests.cs @@ -331,7 +331,7 @@ public void BatchStatement_Should_UseRoutingKeyAndKeyspaceOfFirstStatement_When_ var lbp = new TokenAwarePolicy(new ClusterTests.FakeLoadBalancingPolicy()); var clusterMock = Mock.Of(); Mock.Get(clusterMock).Setup(c => c.GetReplicas(It.IsAny(), It.IsAny())) - .Returns(new List()); + .Returns(new List()); Mock.Get(clusterMock).Setup(c => c.AllHosts()) .Returns(new List()); lbp.Initialize(clusterMock); diff --git a/src/Cassandra.Tests/TestHelper.cs b/src/Cassandra.Tests/TestHelper.cs index 9de99fedd..d83a7f824 100644 --- a/src/Cassandra.Tests/TestHelper.cs +++ b/src/Cassandra.Tests/TestHelper.cs @@ -739,9 +739,9 @@ public HostDistance Distance(Host host) return _childPolicy.Distance(host); } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { - return !_useRoundRobin ? _hosts : _childPolicy.NewQueryPlan(keyspace, query); + return !_useRoundRobin ? _hosts.Select(h => new HostShard(h, -1)) : _childPolicy.NewQueryPlan(keyspace, query); } } @@ -773,9 +773,9 @@ public HostDistance Distance(Host host) return _distanceHandler(_cluster, host); } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { - return _queryPlanHandler(_cluster, keyspace, query); + return _queryPlanHandler(_cluster, keyspace, query).Select(h => new HostShard(h, -1)); } } } diff --git a/src/Cassandra.Tests/TokenTests.cs b/src/Cassandra.Tests/TokenTests.cs index 92ad7869a..5f0d58ee1 100644 --- a/src/Cassandra.Tests/TokenTests.cs +++ b/src/Cassandra.Tests/TokenTests.cs @@ -111,34 +111,34 @@ public void TokenMap_SimpleStrategy_With_Keyspace_Test() //the primary replica and the next var replicas = tokenMap.GetReplicas("ks1", new M3PToken(0)); - Assert.AreEqual("0,1", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("0,1", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); replicas = tokenMap.GetReplicas("ks1", new M3PToken(-100)); - Assert.AreEqual("0,1", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("0,1", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //Greater than the greatest token replicas = tokenMap.GetReplicas("ks1", new M3PToken(500000)); - Assert.AreEqual("0,1", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("0,1", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //The next replica should be the first replicas = tokenMap.GetReplicas("ks1", new M3PToken(20)); - Assert.AreEqual("2,0", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("2,0", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //The closest replica and the next replicas = tokenMap.GetReplicas("ks1", new M3PToken(19)); - Assert.AreEqual("2,0", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("2,0", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //Even if the replication factor is greater than the ring, it should return only ring size replicas = tokenMap.GetReplicas("ks2", new M3PToken(5)); - Assert.AreEqual("1,2,0", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("1,2,0", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //The primary replica only as the keyspace was not found replicas = tokenMap.GetReplicas(null, new M3PToken(0)); - Assert.AreEqual("0", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("0", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); replicas = tokenMap.GetReplicas(null, new M3PToken(10)); - Assert.AreEqual("1", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("1", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); replicas = tokenMap.GetReplicas("ks_does_not_exist", new M3PToken(20)); - Assert.AreEqual("2", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("2", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); replicas = tokenMap.GetReplicas(null, new M3PToken(19)); - Assert.AreEqual("2", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("2", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); } [Test] @@ -160,16 +160,16 @@ public void TokenMap_SimpleStrategy_With_Hosts_Without_Tokens() //the primary replica and the next var replicas = tokenMap.GetReplicas("ks1", new M3PToken(0)); //The node without tokens should not be considered - CollectionAssert.AreEqual(new byte[] { 0, 2 }, replicas.Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEqual(new byte[] { 0, 2 }, replicas.Select(h => TestHelper.GetLastAddressByte(h.Host))); replicas = tokenMap.GetReplicas("ks1", new M3PToken(-100)); - CollectionAssert.AreEqual(new byte[] { 0, 2 }, replicas.Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEqual(new byte[] { 0, 2 }, replicas.Select(h => TestHelper.GetLastAddressByte(h.Host))); //Greater than the greatest token replicas = tokenMap.GetReplicas("ks1", new M3PToken(500000)); - CollectionAssert.AreEqual(new byte[] { 0, 2 }, replicas.Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEqual(new byte[] { 0, 2 }, replicas.Select(h => TestHelper.GetLastAddressByte(h.Host))); //The next replica should be the first replicas = tokenMap.GetReplicas("ks1", new M3PToken(20)); - CollectionAssert.AreEqual(new byte[] { 2, 0 }, replicas.Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEqual(new byte[] { 2, 0 }, replicas.Select(h => TestHelper.GetLastAddressByte(h.Host))); } [Test] @@ -200,28 +200,28 @@ public void TokenMap_NetworkTopologyStrategy_With_Keyspace_Test() //KS1 //the primary replica and the next var replicas = tokenMap.GetReplicas("ks1", new M3PToken(0)); - Assert.AreEqual("0,100,1,101", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("0,100,1,101", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //The next replica should be the first replicas = tokenMap.GetReplicas("ks1", new M3PToken(200)); - Assert.AreEqual("2,102,0,100", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("2,102,0,100", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //The closest replica and the next replicas = tokenMap.GetReplicas("ks1", new M3PToken(190)); - Assert.AreEqual("2,102,0,100", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("2,102,0,100", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //KS2 //Simple strategy: 3 tokens no matter which dc replicas = tokenMap.GetReplicas("ks2", new M3PToken(5000)); - Assert.AreEqual("0,100,1", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("0,100,1", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //KS3 replicas = tokenMap.GetReplicas("ks3", new M3PToken(0)); - Assert.AreEqual("0,100,1,2", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("0,100,1,2", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); replicas = tokenMap.GetReplicas("ks3", new M3PToken(201)); - Assert.AreEqual("102,0,1,2", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("102,0,1,2", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); //KS4 replicas = tokenMap.GetReplicas("ks4", new M3PToken(0)); - Assert.AreEqual("0,1,2", String.Join(",", replicas.Select(TestHelper.GetLastAddressByte))); + Assert.AreEqual("0,1,2", String.Join(",", replicas.Select(h => TestHelper.GetLastAddressByte(h.Host)))); } [Test] @@ -240,7 +240,7 @@ public void TokenMap_Build_NetworkTopology_Adjacent_Ranges_Test() var replicas = map.GetReplicas("ks1", new M3PToken(0)); Assert.AreEqual(2, replicas.Count); //It should contain the first host and the second, even though the first host contains adjacent - CollectionAssert.AreEqual(new byte[] { 1, 2 }, replicas.Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEqual(new byte[] { 1, 2 }, replicas.Select(h => TestHelper.GetLastAddressByte(h.Host))); } [Test] @@ -291,7 +291,7 @@ public void TokenMap_Build_NetworkTopology_Multiple_Racks_Test() }); var map = TokenMap.Build("Murmur3Partitioner", hosts, new[] { ks }); var replicas = map.GetReplicas("ks1", new M3PToken(0)); - CollectionAssert.AreEqual(new byte[] { 0, 1, 2, 3, 4 }, replicas.Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEqual(new byte[] { 0, 1, 2, 3, 4 }, replicas.Select(h => TestHelper.GetLastAddressByte(h.Host))); } [Test] @@ -326,7 +326,7 @@ public void TokenMap_Build_NetworkTopology_Multiple_Racks_Skipping_Hosts_Test() foreach (var v in values) { var replicas = map.GetReplicas("ks1", new M3PToken(v.Item1)); - CollectionAssert.AreEqual(v.Item2, replicas.Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEqual(v.Item2, replicas.Select(h => TestHelper.GetLastAddressByte(h.Host))); } } @@ -368,7 +368,7 @@ public void TokenMap_Build_SimpleStrategy_Adjacent_Ranges_Test() var replicas = map.GetReplicas("ks1", new M3PToken(0)); Assert.AreEqual(2, replicas.Count); //It should contain the first host and the second, even though the first host contains adjacent - CollectionAssert.AreEqual(new byte[] { 1, 2 }, replicas.Select(TestHelper.GetLastAddressByte)); + CollectionAssert.AreEqual(new byte[] { 1, 2 }, replicas.Select(h => TestHelper.GetLastAddressByte(h.Host))); } [Test] diff --git a/src/Cassandra/Cluster.cs b/src/Cassandra/Cluster.cs index 07bd4e08f..0c588a7bb 100644 --- a/src/Cassandra/Cluster.cs +++ b/src/Cassandra/Cluster.cs @@ -490,13 +490,13 @@ public Host GetHost(IPEndPoint address) } /// - public ICollection GetReplicas(byte[] partitionKey) + public ICollection GetReplicas(byte[] partitionKey) { return Metadata.GetReplicas(partitionKey); } /// - public ICollection GetReplicas(string keyspace, byte[] partitionKey) + public ICollection GetReplicas(string keyspace, byte[] partitionKey) { return Metadata.GetReplicas(keyspace, partitionKey); } diff --git a/src/Cassandra/Connections/Control/ControlConnection.cs b/src/Cassandra/Connections/Control/ControlConnection.cs index 1083706d3..80451e87e 100644 --- a/src/Cassandra/Connections/Control/ControlConnection.cs +++ b/src/Cassandra/Connections/Control/ControlConnection.cs @@ -216,16 +216,16 @@ private IEnumerable>> DefaultLbpHostsEnume bool refreshContactPoints, bool refreshEndpoints) { - foreach (var host in _config.DefaultRequestOptions.LoadBalancingPolicy.NewQueryPlan(null, null)) + foreach (var hostShard in _config.DefaultRequestOptions.LoadBalancingPolicy.NewQueryPlan(null, null)) { - if (attemptedHosts.TryAdd(host, null)) + if (attemptedHosts.TryAdd(hostShard.Host, null)) { - if (!IsHostValid(host, isInitializing)) + if (!IsHostValid(hostShard.Host, isInitializing)) { continue; } - yield return ResolveHostContactPointOrConnectionEndpointAsync(attemptedContactPoints, host, refreshContactPoints, refreshEndpoints); + yield return ResolveHostContactPointOrConnectionEndpointAsync(attemptedContactPoints, hostShard.Host, refreshContactPoints, refreshEndpoints); } } } diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index 6a8e53546..ce9abea8f 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -135,7 +135,7 @@ public HostConnectionPool(Host host, Configuration config, ISerializerManager se } /// - public async Task BorrowConnectionAsync(RoutingKey routingKey = null) + public async Task BorrowConnectionAsync(RoutingKey routingKey = null, int shardID = -1) { var connections = await EnsureCreate().ConfigureAwait(false); if (connections.Length == 0) @@ -143,11 +143,11 @@ public async Task BorrowConnectionAsync(RoutingKey routingKey = nul throw new DriverInternalError("No connection could be borrowed"); } - return BorrowLeastBusyConnection(connections, routingKey); + return BorrowLeastBusyConnection(connections, routingKey, shardID); } /// - public IConnection BorrowExistingConnection(RoutingKey routingKey) + public IConnection BorrowExistingConnection(RoutingKey routingKey, int shardID = -1) { var connections = GetExistingConnections(); if (connections.Length == 0) @@ -155,18 +155,20 @@ public IConnection BorrowExistingConnection(RoutingKey routingKey) return null; } - return BorrowLeastBusyConnection(connections, routingKey); + return BorrowLeastBusyConnection(connections, routingKey, shardID); } - private IConnection BorrowLeastBusyConnection(ShardedList connections, RoutingKey routingKey = null) + private IConnection BorrowLeastBusyConnection(ShardedList connections, RoutingKey routingKey = null, int shardID = -1) { - int shardID = -1; if (shardingInfo != null) { if (routingKey != null) { IToken token = _tokenFactory.Hash(routingKey.RawRoutingKey); - shardID = shardingInfo.ShardID(token); + if (shardID == -1) + { + shardID = shardingInfo.ShardID(token); + } } else { @@ -178,10 +180,9 @@ private IConnection BorrowLeastBusyConnection(ShardedList connectio if (shardID != -1) { var minInFlight = int.MaxValue; - var localInFlight = 0; foreach (var connection in _connections.GetItemsForShard(shardID)) { - localInFlight = connection.InFlight; + int localInFlight = connection.InFlight; if (localInFlight < minInFlight) { minInFlight = localInFlight; @@ -189,7 +190,7 @@ private IConnection BorrowLeastBusyConnection(ShardedList connectio } } } - var inFlight = 0; + int inFlight; if (c != null) { // if we have a connection for the shard, use it if it is not too busy @@ -952,31 +953,31 @@ public void MarkAsDownAndScheduleReconnection() /// public Task GetConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey) + IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey, int shardID = -1) { - return GetConnectionFromHostAsync(triedHosts, getKeyspaceFunc, true, routingKey); + return GetConnectionFromHostAsync(triedHosts, getKeyspaceFunc, true, routingKey, shardID); } /// public Task GetExistingConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey) + IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey, int shardID = -1) { - return GetConnectionFromHostAsync(triedHosts, getKeyspaceFunc, false, routingKey); + return GetConnectionFromHostAsync(triedHosts, getKeyspaceFunc, false, routingKey, shardID); } private async Task GetConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc, bool createIfNeeded, RoutingKey routingKey) + IDictionary triedHosts, Func getKeyspaceFunc, bool createIfNeeded, RoutingKey routingKey, int shardID = -1) { IConnection c = null; try { if (createIfNeeded) { - c = await BorrowConnectionAsync(routingKey).ConfigureAwait(false); + c = await BorrowConnectionAsync(routingKey, shardID).ConfigureAwait(false); } else { - c = BorrowExistingConnection(routingKey); + c = BorrowExistingConnection(routingKey, shardID); } } catch (UnsupportedProtocolVersionException ex) diff --git a/src/Cassandra/Connections/IHostConnectionPool.cs b/src/Cassandra/Connections/IHostConnectionPool.cs index 21f4c9a36..0b56799a5 100644 --- a/src/Cassandra/Connections/IHostConnectionPool.cs +++ b/src/Cassandra/Connections/IHostConnectionPool.cs @@ -55,7 +55,7 @@ internal interface IHostConnectionPool : IDisposable /// /// /// - Task BorrowConnectionAsync(RoutingKey routingKey = null); + Task BorrowConnectionAsync(RoutingKey routingKey = null, int shardID = -1); /// /// Gets an open connection from the host pool. It does NOT create one if necessary (for that use . @@ -63,7 +63,7 @@ internal interface IHostConnectionPool : IDisposable /// /// /// Not connected. - IConnection BorrowExistingConnection(RoutingKey routingKey); + IConnection BorrowExistingConnection(RoutingKey routingKey, int shardID = -1); void SetDistance(HostDistance distance); @@ -92,9 +92,9 @@ internal interface IHostConnectionPool : IDisposable void MarkAsDownAndScheduleReconnection(); Task GetConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey); + IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey, int shardID); Task GetExistingConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey); + IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey, int shardID); } } diff --git a/src/Cassandra/HostShard.cs b/src/Cassandra/HostShard.cs new file mode 100644 index 000000000..2a77198ae --- /dev/null +++ b/src/Cassandra/HostShard.cs @@ -0,0 +1,36 @@ +namespace Cassandra +{ + public class HostShard + { + public Host Host { get; } + public int Shard { get; } + + public HostShard(Host host, int shard) + { + Host = host; + Shard = shard; + } + + public override string ToString() => $"HostShard {{host={Host.Address}, shard={Shard}}}"; + + public override bool Equals(object obj) + { + if (obj is HostShard other) + { + return Host.Equals(other.Host) && Shard == other.Shard; + } + return false; + } + + public override int GetHashCode() + { + unchecked + { + int hash = 17; + hash = hash * 23 + (Host != null ? Host.GetHashCode() : 0); + hash = hash * 23 + Shard.GetHashCode(); + return hash; + } + } + } +} \ No newline at end of file diff --git a/src/Cassandra/ICluster.cs b/src/Cassandra/ICluster.cs index 0a52e407e..872ffcb2a 100644 --- a/src/Cassandra/ICluster.cs +++ b/src/Cassandra/ICluster.cs @@ -112,7 +112,7 @@ public interface ICluster : IDisposable /// /// Byte array representing the partition key /// - ICollection GetReplicas(byte[] partitionKey); + ICollection GetReplicas(byte[] partitionKey); /// /// Gets a collection of replicas for a given partitionKey on a given keyspace @@ -120,7 +120,7 @@ public interface ICluster : IDisposable /// Byte array representing the partition key /// Byte array representing the partition key /// - ICollection GetReplicas(string keyspace, byte[] partitionKey); + ICollection GetReplicas(string keyspace, byte[] partitionKey); /// /// Shutdown this cluster instance. This closes all connections from all the diff --git a/src/Cassandra/Metadata.cs b/src/Cassandra/Metadata.cs index 723df1cc6..ed152f514 100644 --- a/src/Cassandra/Metadata.cs +++ b/src/Cassandra/Metadata.cs @@ -276,17 +276,17 @@ internal async Task UpdateTokenMapForKeyspace(string name) /// /// Get the replicas for a given partition key and keyspace /// - public ICollection GetReplicas(string keyspaceName, byte[] partitionKey) + public ICollection GetReplicas(string keyspaceName, byte[] partitionKey) { if (_tokenMap == null) { Metadata.Logger.Warning("Metadata.GetReplicas was called but there was no token map."); - return new Host[0]; + return new HostShard[0]; } return _tokenMap.GetReplicas(keyspaceName, _tokenMap.Factory.Hash(partitionKey)); } - public ICollection GetReplicas(byte[] partitionKey) + public ICollection GetReplicas(byte[] partitionKey) { return GetReplicas(null, partitionKey); } diff --git a/src/Cassandra/Policies/DCAwareRoundRobinPolicy.cs b/src/Cassandra/Policies/DCAwareRoundRobinPolicy.cs index d59d38698..86134880f 100644 --- a/src/Cassandra/Policies/DCAwareRoundRobinPolicy.cs +++ b/src/Cassandra/Policies/DCAwareRoundRobinPolicy.cs @@ -202,7 +202,7 @@ public HostDistance Distance(Host host) /// the query for which to build the plan. /// a new query plan, i.e. an iterator indicating which host to try /// first for querying, which one to use as failover, etc... - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { var startIndex = Interlocked.Increment(ref _index); //Simplified overflow protection @@ -216,7 +216,7 @@ public IEnumerable NewQueryPlan(string keyspace, IStatement query) //Round-robin through local nodes for (var i = 0; i < localHosts.Count; i++) { - yield return localHosts[(startIndex + i) % localHosts.Count]; + yield return new HostShard(localHosts[(startIndex + i) % localHosts.Count], -1); } if (_usedHostsPerRemoteDc == 0) @@ -234,7 +234,7 @@ public IEnumerable NewQueryPlan(string keyspace, IStatement query) continue; } dcHosts[dc] = hostYieldedByDc + 1; - yield return h; + yield return new HostShard(h, -1); } } diff --git a/src/Cassandra/Policies/DefaultLoadBalancingPolicy.cs b/src/Cassandra/Policies/DefaultLoadBalancingPolicy.cs index c37fb3f0a..c02bebac2 100644 --- a/src/Cassandra/Policies/DefaultLoadBalancingPolicy.cs +++ b/src/Cassandra/Policies/DefaultLoadBalancingPolicy.cs @@ -87,7 +87,7 @@ public void Initialize(ICluster cluster) /// /// Returns the hosts to used for a query. /// - public IEnumerable NewQueryPlan(string keyspace, IStatement statement) + public IEnumerable NewQueryPlan(string keyspace, IStatement statement) { if (statement is TargettedSimpleStatement targetedStatement && targetedStatement.PreferredHost != null) { @@ -98,9 +98,9 @@ public IEnumerable NewQueryPlan(string keyspace, IStatement statement) return ChildPolicy.NewQueryPlan(keyspace, statement); } - private IEnumerable YieldPreferred(string keyspace, TargettedSimpleStatement statement) + private IEnumerable YieldPreferred(string keyspace, TargettedSimpleStatement statement) { - yield return statement.PreferredHost; + yield return new HostShard(statement.PreferredHost, -1); foreach (var h in ChildPolicy.NewQueryPlan(keyspace, statement)) { yield return h; diff --git a/src/Cassandra/Policies/ILoadBalancingPolicy.cs b/src/Cassandra/Policies/ILoadBalancingPolicy.cs index 00a48f12f..b11ca9ebe 100644 --- a/src/Cassandra/Policies/ILoadBalancingPolicy.cs +++ b/src/Cassandra/Policies/ILoadBalancingPolicy.cs @@ -63,6 +63,6 @@ public interface ILoadBalancingPolicy /// An iterator of Host. The query is tried against the hosts returned /// by this iterator in order, until the query has been sent successfully to one /// of the host. - IEnumerable NewQueryPlan(string keyspace, IStatement query); + IEnumerable NewQueryPlan(string keyspace, IStatement query); } } diff --git a/src/Cassandra/Policies/RetryLoadBalancingPolicy.cs b/src/Cassandra/Policies/RetryLoadBalancingPolicy.cs index c7eaa8ee1..6636fa22e 100644 --- a/src/Cassandra/Policies/RetryLoadBalancingPolicy.cs +++ b/src/Cassandra/Policies/RetryLoadBalancingPolicy.cs @@ -44,13 +44,13 @@ public HostDistance Distance(Host host) return LoadBalancingPolicy.Distance(host); } - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { IReconnectionSchedule schedule = ReconnectionPolicy.NewSchedule(); while (true) { - IEnumerable childQueryPlan = LoadBalancingPolicy.NewQueryPlan(keyspace, query); - foreach (Host host in childQueryPlan) + IEnumerable childQueryPlan = LoadBalancingPolicy.NewQueryPlan(keyspace, query); + foreach (HostShard host in childQueryPlan) yield return host; if (ReconnectionEvent != null) diff --git a/src/Cassandra/Policies/RoundRobinPolicy.cs b/src/Cassandra/Policies/RoundRobinPolicy.cs index 85e2199fc..a428a5061 100644 --- a/src/Cassandra/Policies/RoundRobinPolicy.cs +++ b/src/Cassandra/Policies/RoundRobinPolicy.cs @@ -22,7 +22,7 @@ namespace Cassandra { /// - /// A Round-robin load balancing policy. + /// A Round-robin load balancing policy. /// This policy queries nodes in a /// round-robin fashion. For a given query, if an host fail, the next one /// (following the round-robin order) is tried, until all hosts have been tried. @@ -66,7 +66,7 @@ public HostDistance Distance(Host host) /// the query for which to build the plan. /// a new query plan, i.e. an iterator indicating which host to try /// first for querying, which one to use as failover, etc... - public IEnumerable NewQueryPlan(string keyspace, IStatement query) + public IEnumerable NewQueryPlan(string keyspace, IStatement query) { //shallow copy the all hosts var hosts = (from h in _cluster.AllHosts() select h).ToArray(); @@ -80,7 +80,7 @@ public IEnumerable NewQueryPlan(string keyspace, IStatement query) for (var i = 0; i < hosts.Length; i++) { - yield return hosts[(startIndex + i) % hosts.Length]; + yield return new HostShard(hosts[(startIndex + i) % hosts.Length], -1); } } } diff --git a/src/Cassandra/Policies/TokenAwarePolicy.cs b/src/Cassandra/Policies/TokenAwarePolicy.cs index ddd1ccfe5..4a39b379a 100644 --- a/src/Cassandra/Policies/TokenAwarePolicy.cs +++ b/src/Cassandra/Policies/TokenAwarePolicy.cs @@ -62,7 +62,7 @@ public void Initialize(ICluster cluster) /// Return the HostDistance for the provided host. /// /// the host of which to return the distance of. - /// + /// /// the HostDistance to host as returned by the wrapped /// policy. public HostDistance Distance(Host host) @@ -80,10 +80,10 @@ public HostDistance Distance(Host host) /// Keyspace on which the query is going to be executed /// the query for which to build the plan. /// the new query plan. - public IEnumerable NewQueryPlan(string loggedKeyspace, IStatement query) + public IEnumerable NewQueryPlan(string loggedKeyspace, IStatement query) { var routingKey = query?.RoutingKey; - IEnumerable childIterator; + IEnumerable childIterator; if (routingKey == null) { childIterator = ChildPolicy.NewQueryPlan(loggedKeyspace, query); @@ -95,12 +95,22 @@ public IEnumerable NewQueryPlan(string loggedKeyspace, IStatement query) } var keyspace = query.Keyspace ?? loggedKeyspace; - var replicas = _cluster.GetReplicas(keyspace, routingKey.RawRoutingKey); + var table = query.TableName; + IEnumerable replicas = null; + if (table != null) + { + var token = _cluster.Metadata.GetTokenFactory().Hash(routingKey.RawRoutingKey); + replicas = _cluster.Metadata.TabletMap.GetReplicas(keyspace, table, token); + } + if (replicas == null || !replicas.Any()) + { + replicas = _cluster.GetReplicas(keyspace, routingKey.RawRoutingKey); + } - var localReplicaSet = new HashSet(); - var localReplicaList = new List(replicas.Count); + var localReplicaSet = new HashSet(); + var localReplicaList = new List(replicas.Count()); // We can't do it lazily as we need to balance the load between local replicas - foreach (var localReplica in replicas.Where(h => ChildPolicy.Distance(h) == HostDistance.Local)) + foreach (var localReplica in replicas.Where(hs => ChildPolicy.Distance(hs.Host) == HostDistance.Local)) { localReplicaSet.Add(localReplica); localReplicaList.Add(localReplica); diff --git a/src/Cassandra/Requests/IPrepareHandler.cs b/src/Cassandra/Requests/IPrepareHandler.cs index 4b450de11..9e8bd8e9d 100644 --- a/src/Cassandra/Requests/IPrepareHandler.cs +++ b/src/Cassandra/Requests/IPrepareHandler.cs @@ -24,6 +24,6 @@ namespace Cassandra.Requests internal interface IPrepareHandler { Task Prepare( - InternalPrepareRequest request, IInternalSession session, IEnumerator queryPlan); + InternalPrepareRequest request, IInternalSession session, IEnumerator queryPlan); } } \ No newline at end of file diff --git a/src/Cassandra/Requests/IRequestHandler.cs b/src/Cassandra/Requests/IRequestHandler.cs index 7921ca8e2..0ecf6c7b1 100644 --- a/src/Cassandra/Requests/IRequestHandler.cs +++ b/src/Cassandra/Requests/IRequestHandler.cs @@ -74,9 +74,10 @@ internal interface IRequestHandler /// /// Host to which a connection will be obtained. /// Hosts for which there were attempts to connect and send the request. + /// Shard to use. /// When the keyspace is not valid /// If every host from the query plan is unavailable. - Task GetConnectionToValidHostAsync(ValidHost validHost, IDictionary triedHosts); + Task GetConnectionToValidHostAsync(ValidHost validHost, IDictionary triedHosts, int shardID = -1); /// /// Obtain a connection to the provided . @@ -86,7 +87,7 @@ internal interface IRequestHandler /// Hosts for which there were attempts to connect and send the request. /// When the keyspace is not valid /// If every host from the query plan is unavailable. - Task ValidateHostAndGetConnectionAsync(Host host, Dictionary triedHosts); + Task ValidateHostAndGetConnectionAsync(HostShard host, Dictionary triedHosts); Task SendAsync(); diff --git a/src/Cassandra/Requests/PrepareHandler.cs b/src/Cassandra/Requests/PrepareHandler.cs index 400c32ea0..8e7c3beff 100644 --- a/src/Cassandra/Requests/PrepareHandler.cs +++ b/src/Cassandra/Requests/PrepareHandler.cs @@ -44,7 +44,7 @@ public PrepareHandler(ISerializerManager serializerManager, IInternalCluster clu } public async Task Prepare( - InternalPrepareRequest request, IInternalSession session, IEnumerator queryPlan) + InternalPrepareRequest request, IInternalSession session, IEnumerator queryPlan) { var infoAndObs = await CreateRequestObserverAsync(session, request).ConfigureAwait(false); var observer = infoAndObs.Item2; @@ -77,7 +77,7 @@ public static async Task> CreateRequ } private async Task SendRequestToOneNode( - IInternalSession session, IEnumerator queryPlan, InternalPrepareRequest request, IRequestObserver observer, SessionRequestInfo info) + IInternalSession session, IEnumerator queryPlan, InternalPrepareRequest request, IRequestObserver observer, SessionRequestInfo info) { var triedHosts = new Dictionary(); @@ -132,7 +132,7 @@ private static bool CanBeRetried(Exception ex) } private async Task> GetNextConnection( - IInternalSession session, IEnumerator queryPlan, Dictionary triedHosts) + IInternalSession session, IEnumerator queryPlan, Dictionary triedHosts) { Host host; while ((host = GetNextHost(queryPlan, out HostDistance distance)) != null) @@ -146,12 +146,12 @@ private async Task> GetNextConnection( throw new NoHostAvailableException(triedHosts); } - private Host GetNextHost(IEnumerator queryPlan, out HostDistance distance) + private Host GetNextHost(IEnumerator queryPlan, out HostDistance distance) { distance = HostDistance.Ignored; while (queryPlan.MoveNext()) { - var host = queryPlan.Current; + var host = queryPlan.Current.Host; if (!host.IsUp) { continue; diff --git a/src/Cassandra/Requests/ReprepareHandler.cs b/src/Cassandra/Requests/ReprepareHandler.cs index 3e963d67d..d2b1b718d 100644 --- a/src/Cassandra/Requests/ReprepareHandler.cs +++ b/src/Cassandra/Requests/ReprepareHandler.cs @@ -83,7 +83,7 @@ private static async Task GetConnectionFromHostInternalAsync( { try { - return await pool.GetExistingConnectionFromHostAsync(triedHosts, () => ps.Keyspace, ps.RoutingKey).ConfigureAwait(false); + return await pool.GetExistingConnectionFromHostAsync(triedHosts, () => ps.Keyspace, ps.RoutingKey, -1).ConfigureAwait(false); } catch (SocketException) { diff --git a/src/Cassandra/Requests/RequestExecution.cs b/src/Cassandra/Requests/RequestExecution.cs index 4308d58d0..1655fd2ce 100644 --- a/src/Cassandra/Requests/RequestExecution.cs +++ b/src/Cassandra/Requests/RequestExecution.cs @@ -89,7 +89,7 @@ private async Task SendToCurrentHostAsync() try { // host needs to be re-validated using the load balancing policy - _connection = await _parent.ValidateHostAndGetConnectionAsync(host, _triedHosts).ConfigureAwait(false); + _connection = await _parent.ValidateHostAndGetConnectionAsync(new HostShard(host, -1), _triedHosts).ConfigureAwait(false); if (_connection != null) { await SendAsync(_request, host, HandleResponseAsync).ConfigureAwait(false); diff --git a/src/Cassandra/Requests/RequestHandler.cs b/src/Cassandra/Requests/RequestHandler.cs index 4407a0709..4768ca9c1 100644 --- a/src/Cassandra/Requests/RequestHandler.cs +++ b/src/Cassandra/Requests/RequestHandler.cs @@ -44,7 +44,7 @@ internal class RequestHandler : IRequestHandler private readonly IInternalSession _session; private readonly IRequestResultHandler _requestResultHandler; private long _state; - private readonly IEnumerator _queryPlan; + private readonly IEnumerator _queryPlan; private readonly object _queryPlanLock = new object(); private readonly ICollection _running = new CopyOnWriteList(); private ISpeculativeExecutionPlan _executionPlan; @@ -107,14 +107,14 @@ public RequestHandler(IInternalSession session, ISerializer serializer, SessionR /// In the special case when a Host is provided at Statement level, it will return a query plan with a single /// host. /// - private static IEnumerable GetQueryPlan(ISession session, IStatement statement, ILoadBalancingPolicy lbp) + private static IEnumerable GetQueryPlan(ISession session, IStatement statement, ILoadBalancingPolicy lbp) { // Single host iteration var host = (statement as Statement)?.Host; return host == null ? lbp.NewQueryPlan(session.Keyspace, statement) - : Enumerable.Repeat(host, 1); + : Enumerable.Repeat(new HostShard(host, -1), 1); } /// @@ -317,7 +317,7 @@ public bool HasCompleted() return Interlocked.Read(ref _state) == RequestHandler.StateCompleted; } - private Host GetNextHost() + private HostShard GetNextHost() { // Lock to handle multiple threads from multiple executions to get a new host lock (_queryPlanLock) @@ -333,11 +333,11 @@ private Host GetNextHost() /// public ValidHost GetNextValidHost(Dictionary triedHosts) { - Host host; - while ((host = GetNextHost()) != null && !_session.IsDisposed) + HostShard hostShard; + while ((hostShard = GetNextHost()) != null && !_session.IsDisposed) { - triedHosts[host.Address] = null; - if (!TryValidateHost(host, out var validHost)) + triedHosts[hostShard.Host.Address] = null; + if (!TryValidateHost(hostShard.Host, out var validHost)) { continue; } @@ -366,11 +366,11 @@ private bool TryValidateHost(Host host, out ValidHost validHost) /// public async Task GetNextConnectionAsync(Dictionary triedHosts) { - Host host; + HostShard hostShard; // While there is an available host - while ((host = GetNextHost()) != null) + while ((hostShard = GetNextHost()) != null) { - var c = await ValidateHostAndGetConnectionAsync(host, triedHosts).ConfigureAwait(false); + var c = await ValidateHostAndGetConnectionAsync(hostShard, triedHosts).ConfigureAwait(false); if (c == null) { continue; @@ -382,27 +382,27 @@ public async Task GetNextConnectionAsync(Dictionary - public async Task ValidateHostAndGetConnectionAsync(Host host, Dictionary triedHosts) + public async Task ValidateHostAndGetConnectionAsync(HostShard hostShard, Dictionary triedHosts) { if (_session.IsDisposed) { throw new NoHostAvailableException(triedHosts); } - triedHosts[host.Address] = null; - if (!TryValidateHost(host, out var validHost)) + triedHosts[hostShard.Host.Address] = null; + if (!TryValidateHost(hostShard.Host, out var validHost)) { return null; } - var c = await GetConnectionToValidHostAsync(validHost, triedHosts).ConfigureAwait(false); + var c = await GetConnectionToValidHostAsync(validHost, triedHosts, hostShard.Shard).ConfigureAwait(false); return c; } /// - public Task GetConnectionToValidHostAsync(ValidHost validHost, IDictionary triedHosts) + public Task GetConnectionToValidHostAsync(ValidHost validHost, IDictionary triedHosts, int shardID = -1) { - return RequestHandler.GetConnectionFromHostAsync(validHost.Host, validHost.Distance, _session, triedHosts, Statement != null ? Statement.RoutingKey : null); + return RequestHandler.GetConnectionFromHostAsync(validHost.Host, validHost.Distance, _session, triedHosts, Statement != null ? Statement.RoutingKey : null, shardID); } /// @@ -414,21 +414,22 @@ public Task GetConnectionToValidHostAsync(ValidHost validHost, IDic /// Session from where a connection will be obtained (or created). /// Hosts for which there were attempts to connect and send the request. /// Routing key to use for the next host. + /// Shard to use. /// When the keyspace is not valid internal static Task GetConnectionFromHostAsync( - Host host, HostDistance distance, IInternalSession session, IDictionary triedHosts, RoutingKey routingKey = null) + Host host, HostDistance distance, IInternalSession session, IDictionary triedHosts, RoutingKey routingKey = null, int shardID = -1) { - return GetConnectionFromHostInternalAsync(host, distance, session, triedHosts, true, routingKey); + return GetConnectionFromHostInternalAsync(host, distance, session, triedHosts, true, routingKey, shardID); } private static async Task GetConnectionFromHostInternalAsync( - Host host, HostDistance distance, IInternalSession session, IDictionary triedHosts, bool retry, RoutingKey routingKey) + Host host, HostDistance distance, IInternalSession session, IDictionary triedHosts, bool retry, RoutingKey routingKey, int shardID = -1) { var hostPool = session.GetOrCreateConnectionPool(host, distance); try { - return await hostPool.GetConnectionFromHostAsync(triedHosts, () => session.Keyspace, routingKey).ConfigureAwait(false); + return await hostPool.GetConnectionFromHostAsync(triedHosts, () => session.Keyspace, routingKey, shardID).ConfigureAwait(false); } catch (SocketException) { diff --git a/src/Cassandra/TabletMap.cs b/src/Cassandra/TabletMap.cs index 544a09181..a3fed6faa 100644 --- a/src/Cassandra/TabletMap.cs +++ b/src/Cassandra/TabletMap.cs @@ -11,7 +11,7 @@ namespace Cassandra internal class TabletMap { private static readonly Logger Logger = new Logger(typeof(TabletMap)); - private static readonly IReadOnlyList EMPTY_LIST = new List(); + private static readonly IReadOnlyList EMPTY_LIST = new List(); private readonly ConcurrentDictionary _mapping; private readonly Metadata _metadata; @@ -103,31 +103,36 @@ public static TabletMap EmptyMap(Metadata metadata, Hosts hosts) public IDictionary GetMapping() => _mapping; - public IReadOnlyList GetReplicas(string keyspace, string table, long token) + public IReadOnlyList GetReplicas(string keyspace, string table, IToken token) { + if (token == null) + { + return EMPTY_LIST; + } + var key = new KeyspaceTableNamePair(keyspace, table); if (!_mapping.TryGetValue(key, out var tabletSet)) { - Logger.Info("No tablets for {keyspace}.{table} in mapping.", keyspace, table); + Logger.Info($"No tablets for {keyspace}.{table} in mapping.", keyspace, table); return EMPTY_LIST; } - var row = tabletSet.Tablets.FirstOrDefault(t => t.LastToken >= token); - if (row == null || row.FirstToken >= token) + var row = tabletSet.Tablets.FirstOrDefault(t => token.CompareTo(new M3PToken(t.LastToken)) <= 0); + if (row == null || token.CompareTo(new M3PToken(row.FirstToken)) <= 0) { - Logger.Info("Could not find tablet for {keyspace}.{table} owning token {token}.", keyspace, table, token); + Logger.Info($"Could not find tablet for {keyspace}.{table} owning token {token}.", keyspace, table, token); return EMPTY_LIST; } - var replicas = new List(); + var replicas = new List(); foreach (var hostShardPair in row.Replicas) { Host replica = _metadata.Hosts.ToCollection().FirstOrDefault(h => h.HostId == hostShardPair.HostID); if (replica == null) return EMPTY_LIST; - replicas.Add(replica); + replicas.Add(new HostShard(replica, hostShardPair.Shard)); } return replicas.ToList(); // Return as List, which implements IReadOnlyList } diff --git a/src/Cassandra/TokenMap.cs b/src/Cassandra/TokenMap.cs index 3872627e7..569aa2df1 100644 --- a/src/Cassandra/TokenMap.cs +++ b/src/Cassandra/TokenMap.cs @@ -77,7 +77,7 @@ public void UpdateKeyspace(KeyspaceMetadata ks) sw.Elapsed.TotalMilliseconds); } - public ICollection GetReplicas(string keyspaceName, IToken token) + public ICollection GetReplicas(string keyspaceName, IToken token) { IReadOnlyList readOnlyRing = _ring; @@ -96,12 +96,14 @@ public ICollection GetReplicas(string keyspaceName, IToken token) var closestToken = readOnlyRing[i]; if (keyspaceName != null && _tokenToHostsByKeyspace.ContainsKey(keyspaceName)) { - return _tokenToHostsByKeyspace[keyspaceName][closestToken]; + return _tokenToHostsByKeyspace[keyspaceName][closestToken] + .Select(h => new HostShard(h, -1)) + .ToList(); } TokenMap.Logger.Warning("An attempt to obtain the replicas for a specific token was made but the replicas collection " + "wasn't computed for this keyspace: {0}. Returning the primary replica for the provided token.", keyspaceName); - return new[] { _primaryReplicas[closestToken] }; + return new[] { new HostShard(_primaryReplicas[closestToken], -1) }; } public static TokenMap Build(string partitioner, ICollection hosts, ICollection keyspaces)