From b3cd7abafee15f973d899dddae89d945a247b12e Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Tue, 1 Apr 2025 08:38:27 +0200 Subject: [PATCH 01/14] Send OPTIONS message on every connections to have ShardingInfo --- .../FakeSupportedOptionsInitializer.cs | 7 +++- src/Cassandra/Connections/Connection.cs | 37 +++++++++++++++++++ .../Control/ISupportedOptionsInitializer.cs | 4 +- .../Control/SupportedOptionsInitializer.cs | 7 +++- src/Cassandra/Connections/IConnection.cs | 5 +++ 5 files changed, 57 insertions(+), 3 deletions(-) diff --git a/src/Cassandra.Tests/Connections/TestHelpers/FakeSupportedOptionsInitializer.cs b/src/Cassandra.Tests/Connections/TestHelpers/FakeSupportedOptionsInitializer.cs index 041fea31f..17b664b37 100644 --- a/src/Cassandra.Tests/Connections/TestHelpers/FakeSupportedOptionsInitializer.cs +++ b/src/Cassandra.Tests/Connections/TestHelpers/FakeSupportedOptionsInitializer.cs @@ -1,4 +1,4 @@ -// +// // Copyright (C) DataStax Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +16,7 @@ using System.Threading.Tasks; using Cassandra.Connections; using Cassandra.Connections.Control; +using Cassandra.Responses; using Cassandra.Tasks; namespace Cassandra.Tests.Connections.TestHelpers @@ -34,6 +35,10 @@ public Task ApplySupportedOptionsAsync(IConnection connection) return TaskHelper.Completed; } + public void ApplySupportedFromResponse(Response response) + { + } + public ShardingInfo GetShardingInfo() { return null; diff --git a/src/Cassandra/Connections/Connection.cs b/src/Cassandra/Connections/Connection.cs index f69f82887..fd4d45e37 100644 --- a/src/Cassandra/Connections/Connection.cs +++ b/src/Cassandra/Connections/Connection.cs @@ -24,6 +24,7 @@ using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; +using Cassandra.Connections.Control; using Cassandra.Compression; using Cassandra.Metrics; using Cassandra.Observers.Abstractions; @@ -163,6 +164,8 @@ public string Keyspace public Configuration Configuration { get; set; } + private readonly ISupportedOptionsInitializer _supportedOptionsInitializer; + internal Connection( ISerializer serializer, IConnectionEndPoint endPoint, @@ -183,6 +186,7 @@ internal Connection( _freeOperations = new ConcurrentStack(Enumerable.Range(0, GetMaxConcurrentRequests(Serializer)).Select(s => (short)s).Reverse()); _pendingOperations = new ConcurrentDictionary(); _writeQueue = new ConcurrentQueue(); + _supportedOptionsInitializer = configuration.SupportedOptionsInitializerFactory.Create(null); if (Options.CustomCompressor != null) { @@ -484,6 +488,25 @@ public async Task DoOpen() _tcpSocket.WriteCompleted += WriteCompletedHandler; var protocolVersion = Serializer.ProtocolVersion; await _tcpSocket.Connect().ConfigureAwait(false); + + // Send the OPTIONS message + Response optionsResponse; + try + { + optionsResponse = await SendOptions().ConfigureAwait(false); + } + catch (ProtocolErrorException ex) + { + // As we are starting up, check for protocol version errors. + // There is no other way than checking the error message from Cassandra + if (ex.Message.Contains("Invalid or unsupported protocol version")) + { + throw new UnsupportedProtocolVersionException(protocolVersion, Serializer.ProtocolVersion, ex); + } + throw; + } + _supportedOptionsInitializer.ApplySupportedFromResponse(optionsResponse); + Response response; try { @@ -511,6 +534,11 @@ public async Task DoOpen() throw new DriverInternalError("Expected READY or AUTHENTICATE, obtained " + response.GetType().Name); } + public ShardingInfo ShardingInfo() + { + return _supportedOptionsInitializer.GetShardingInfo(); + } + private void ReadHandler(byte[] buffer, int bytesReceived) { if (_isClosed) @@ -773,6 +801,15 @@ private static bool InvokeReadCallbacks(MemoryStream stream, ICollection + /// Sends a protocol OPTIONS message + /// + private Task SendOptions() + { + var request = new OptionsRequest(); + return Send(request, Configuration.SocketOptions.ConnectTimeoutMillis); + } + /// /// Sends a protocol STARTUP message /// diff --git a/src/Cassandra/Connections/Control/ISupportedOptionsInitializer.cs b/src/Cassandra/Connections/Control/ISupportedOptionsInitializer.cs index f745e80e6..441983217 100644 --- a/src/Cassandra/Connections/Control/ISupportedOptionsInitializer.cs +++ b/src/Cassandra/Connections/Control/ISupportedOptionsInitializer.cs @@ -1,4 +1,4 @@ -// +// // Copyright (C) DataStax Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,12 +14,14 @@ // limitations under the License. using System.Threading.Tasks; +using Cassandra.Responses; namespace Cassandra.Connections.Control { internal interface ISupportedOptionsInitializer { Task ApplySupportedOptionsAsync(IConnection connection); + void ApplySupportedFromResponse(Response response); ShardingInfo GetShardingInfo(); } } \ No newline at end of file diff --git a/src/Cassandra/Connections/Control/SupportedOptionsInitializer.cs b/src/Cassandra/Connections/Control/SupportedOptionsInitializer.cs index 644b637b9..08364fa68 100644 --- a/src/Cassandra/Connections/Control/SupportedOptionsInitializer.cs +++ b/src/Cassandra/Connections/Control/SupportedOptionsInitializer.cs @@ -1,4 +1,4 @@ -// +// // Copyright (C) DataStax Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -48,6 +48,11 @@ public async Task ApplySupportedOptionsAsync(IConnection connection) var request = new OptionsRequest(); var response = await connection.Send(request).ConfigureAwait(false); + ApplySupportedFromResponse(response); + } + + public void ApplySupportedFromResponse(Response response) + { if (response == null) { throw new NullReferenceException("Response can not be null"); diff --git a/src/Cassandra/Connections/IConnection.cs b/src/Cassandra/Connections/IConnection.cs index 27f026520..579093626 100644 --- a/src/Cassandra/Connections/IConnection.cs +++ b/src/Cassandra/Connections/IConnection.cs @@ -137,5 +137,10 @@ internal interface IConnection : IDisposable /// Cancels current requests and invokes Closing event handlers. Doesn't guarantee disposal, the Closing event handlers should do that. /// void Close(); + + /// + /// Returns the current sharding information. + /// + ShardingInfo ShardingInfo(); } } \ No newline at end of file From 1350ff747a87e744e5c67fca727430ed262c8625 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Tue, 1 Apr 2025 08:40:43 +0200 Subject: [PATCH 02/14] Open as many connections as there are shards --- .../Core/SessionTests.cs | 5 +-- .../Connections/HostConnectionPool.cs | 35 +++++++++++++------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/Cassandra.IntegrationTests/Core/SessionTests.cs b/src/Cassandra.IntegrationTests/Core/SessionTests.cs index 5b7f60f49..98796236c 100644 --- a/src/Cassandra.IntegrationTests/Core/SessionTests.cs +++ b/src/Cassandra.IntegrationTests/Core/SessionTests.cs @@ -211,8 +211,9 @@ public void Should_Create_The_Right_Amount_Of_Connections() Thread.Sleep(2000); var pool21 = localSession2.GetOrCreateConnectionPool(hosts2[0], HostDistance.Local); var pool22 = localSession2.GetOrCreateConnectionPool(hosts2[1], HostDistance.Local); - Assert.That(pool21.OpenConnections, Is.EqualTo(1)); - Assert.That(pool22.OpenConnections, Is.EqualTo(1)); + // Should be 2 due to number of shards + Assert.That(pool21.OpenConnections, Is.EqualTo(2)); + Assert.That(pool22.OpenConnections, Is.EqualTo(2)); } } diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index d6ce4a36d..08f5011ae 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -642,7 +642,7 @@ private async Task CreateOrScheduleReconnectAsync(IReconnectionSchedule schedule } /// - /// Opens one connection. + /// Opens one connection. /// If a connection is being opened it yields the same task, preventing creation in parallel. /// /// @@ -943,7 +943,29 @@ private async Task GetConnectionFromHostAsync( public async Task Warmup() { var length = _expectedConnectionLength; - for (var i = 0; i < length; i++) + // Open first connection + try + { + await CreateOpenConnection(false, false).ConfigureAwait(false); + var shardingInfo = _connections.GetSnapshot()[0].ShardingInfo(); + if (shardingInfo != null) + { + var nrShards = shardingInfo.ScyllaNrShards; + if (nrShards > length) + { + // Create the rest of the connections + length = nrShards; + _expectedConnectionLength = nrShards; + } + } + } + catch + { + OnConnectionClosing(); + throw; + } + + for (var i = 1; i < length; i++) { try { @@ -951,14 +973,7 @@ public async Task Warmup() } catch { - if (i > 0) - { - // There is an opened connection, don't mind - break; - } - - OnConnectionClosing(); - throw; + break; } } } From a31ab7a570eeae39617a467b01c35bf9e288b11a Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Tue, 1 Apr 2025 08:41:30 +0200 Subject: [PATCH 03/14] Add test to verify if correct number of connections are open --- .../ShardAwareOptionsTests.cs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs b/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs index 0d24217d6..a426e9565 100644 --- a/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs +++ b/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs @@ -31,5 +31,22 @@ public void Should_Connect_To_Shard_Aware_Cluster() var controlConnection = (ControlConnection)internalCluster.GetControlConnection(); Assert.IsTrue(controlConnection.IsShardAware()); } + + [Test] + public void Should_Have_NrShards_Connections() + { + _realCluster = TestClusterManager.CreateNew(); + var cluster = ClusterBuilder() + .WithSocketOptions(new SocketOptions().SetReadTimeoutMillis(22000).SetConnectTimeoutMillis(60000)) + .AddContactPoint(_realCluster.InitialContactPoint) + .Build(); + var session = cluster.Connect(); + IInternalSession internalSession = (IInternalSession)session; + var pools = internalSession.GetPools(); + foreach (var kvp in pools) + { + Assert.AreEqual(2, kvp.Value.OpenConnections); + } + } } } \ No newline at end of file From de70b9c0b12b3e3484770f5908f77b8059369597 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Tue, 1 Apr 2025 08:48:34 +0200 Subject: [PATCH 04/14] Add newlines at the end of files --- src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs | 2 +- .../Connections/Control/ISupportedOptionsInitializer.cs | 2 +- src/Cassandra/Connections/HostConnectionPoolFactory.cs | 2 +- src/Cassandra/Connections/IConnection.cs | 2 +- src/Cassandra/Connections/IEndPointResolver.cs | 2 +- src/Cassandra/Connections/IHostConnectionPool.cs | 2 +- src/Cassandra/Connections/IHostConnectionPoolFactory.cs | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs b/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs index a426e9565..626ee66eb 100644 --- a/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs +++ b/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs @@ -49,4 +49,4 @@ public void Should_Have_NrShards_Connections() } } } -} \ No newline at end of file +} diff --git a/src/Cassandra/Connections/Control/ISupportedOptionsInitializer.cs b/src/Cassandra/Connections/Control/ISupportedOptionsInitializer.cs index 441983217..4741c708e 100644 --- a/src/Cassandra/Connections/Control/ISupportedOptionsInitializer.cs +++ b/src/Cassandra/Connections/Control/ISupportedOptionsInitializer.cs @@ -24,4 +24,4 @@ internal interface ISupportedOptionsInitializer void ApplySupportedFromResponse(Response response); ShardingInfo GetShardingInfo(); } -} \ No newline at end of file +} diff --git a/src/Cassandra/Connections/HostConnectionPoolFactory.cs b/src/Cassandra/Connections/HostConnectionPoolFactory.cs index 172d83cf9..54166f7aa 100644 --- a/src/Cassandra/Connections/HostConnectionPoolFactory.cs +++ b/src/Cassandra/Connections/HostConnectionPoolFactory.cs @@ -26,4 +26,4 @@ public IHostConnectionPool Create(Host host, Configuration config, ISerializerMa return new HostConnectionPool(host, config, serializerManager, observerFactory); } } -} \ No newline at end of file +} diff --git a/src/Cassandra/Connections/IConnection.cs b/src/Cassandra/Connections/IConnection.cs index 579093626..1fea05f52 100644 --- a/src/Cassandra/Connections/IConnection.cs +++ b/src/Cassandra/Connections/IConnection.cs @@ -143,4 +143,4 @@ internal interface IConnection : IDisposable /// ShardingInfo ShardingInfo(); } -} \ No newline at end of file +} diff --git a/src/Cassandra/Connections/IEndPointResolver.cs b/src/Cassandra/Connections/IEndPointResolver.cs index 482b0875d..6c502e341 100644 --- a/src/Cassandra/Connections/IEndPointResolver.cs +++ b/src/Cassandra/Connections/IEndPointResolver.cs @@ -33,4 +33,4 @@ internal interface IEndPointResolver /// Endpoint. Task GetConnectionEndPointAsync(Host host, bool refreshCache); } -} \ No newline at end of file +} diff --git a/src/Cassandra/Connections/IHostConnectionPool.cs b/src/Cassandra/Connections/IHostConnectionPool.cs index f9acd0091..513a9fc18 100644 --- a/src/Cassandra/Connections/IHostConnectionPool.cs +++ b/src/Cassandra/Connections/IHostConnectionPool.cs @@ -97,4 +97,4 @@ Task GetConnectionFromHostAsync( Task GetExistingConnectionFromHostAsync( IDictionary triedHosts, Func getKeyspaceFunc); } -} \ No newline at end of file +} diff --git a/src/Cassandra/Connections/IHostConnectionPoolFactory.cs b/src/Cassandra/Connections/IHostConnectionPoolFactory.cs index bb671ee9b..9f0ac0089 100644 --- a/src/Cassandra/Connections/IHostConnectionPoolFactory.cs +++ b/src/Cassandra/Connections/IHostConnectionPoolFactory.cs @@ -23,4 +23,4 @@ internal interface IHostConnectionPoolFactory { IHostConnectionPool Create(Host host, Configuration config, ISerializerManager serializer, IObserverFactory observerFactory); } -} \ No newline at end of file +} From 3fd338b172072810daf096fd46c12af973590b0f Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Wed, 2 Apr 2025 10:04:56 +0200 Subject: [PATCH 05/14] Add mechanism to choose which shard to connect to --- src/Cassandra/Connections/Connection.cs | 2 + .../Connections/HostConnectionPool.cs | 61 ++++++++++++++++--- src/Cassandra/Connections/IConnection.cs | 2 + 3 files changed, 55 insertions(+), 10 deletions(-) diff --git a/src/Cassandra/Connections/Connection.cs b/src/Cassandra/Connections/Connection.cs index fd4d45e37..f9b5c4369 100644 --- a/src/Cassandra/Connections/Connection.cs +++ b/src/Cassandra/Connections/Connection.cs @@ -166,6 +166,8 @@ public string Keyspace private readonly ISupportedOptionsInitializer _supportedOptionsInitializer; + public int ShardId { get; } + internal Connection( ISerializer serializer, IConnectionEndPoint endPoint, diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index 08f5011ae..1dd38caef 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -108,6 +108,9 @@ private static class PoolState /// public IConnection[] ConnectionsSnapshot => _connections.GetSnapshot(); + public ShardingInfo shardingInfo { get; private set; } + + private int lastAttemptedShard = 0; public HostConnectionPool(Host host, Configuration config, ISerializerManager serializerManager, IObserverFactory observerFactory) { @@ -649,10 +652,13 @@ private async Task CreateOrScheduleReconnectAsync(IReconnectionSchedule schedule /// Determines whether the Task should be marked as completed when there is a connection already opened. /// /// Determines whether this is a reconnection + /// + /// Determines whether the connection should be added to the pool. + /// /// Throws a SocketException when the connection could not be established with the host /// /// - private async Task CreateOpenConnection(bool satisfyWithAnOpenConnection, bool isReconnection) + private async Task CreateOpenConnection(bool satisfyWithAnOpenConnection, bool isReconnection, bool addToConnections = true) { var concurrentOpenTcs = Volatile.Read(ref _connectionOpenTcs); // Try to exit early (cheap) as there could be another thread creating / finishing creating @@ -705,6 +711,28 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne IConnection c; try { + // Find out to which shard should we connect to + // Console.WriteLine("Decide which shard to connect to"); + var shardID = 0; + // Console.WriteLine("ShardingInfo: {0}", shardingInfo); + if (shardingInfo != null) + { + // Find the shard without a connection + // It's important to start counting from 1 here because we want + // to consider the next shard after the previously attempted one + for (var i = 1; i <= shardingInfo.ScyllaNrShards; i++) + { + // Console.WriteLine("i: {0}", i); + var _shardID = (lastAttemptedShard + i) % shardingInfo.ScyllaNrShards; + if (connectionsSnapshot.Length <= shardID || connectionsSnapshot[shardID] == null) + { + lastAttemptedShard = _shardID; + shardID = _shardID; + break; + } + } + } + // Console.WriteLine("shardID: {0}", shardID); c = await DoCreateAndOpen(isReconnection).ConfigureAwait(false); } catch (Exception ex) @@ -721,9 +749,12 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne return await FinishOpen(tcs, false, HostConnectionPool.GetNotConnectedException()).ConfigureAwait(false); } - var newLength = _connections.AddNew(c); - HostConnectionPool.Logger.Info("Connection to {0} opened successfully, pool #{1} length: {2}", - _host.Address, GetHashCode(), newLength); + if (addToConnections) + { + var newLength = _connections.AddNew(c); + HostConnectionPool.Logger.Info("Connection to {0} opened successfully, pool #{1} length: {2}", + _host.Address, GetHashCode(), newLength); + } if (IsClosing) { @@ -946,11 +977,13 @@ public async Task Warmup() // Open first connection try { - await CreateOpenConnection(false, false).ConfigureAwait(false); - var shardingInfo = _connections.GetSnapshot()[0].ShardingInfo(); - if (shardingInfo != null) + var c = await CreateOpenConnection(false, false, false).ConfigureAwait(false); + var _shardingInfo = c.ShardingInfo(); + if (_shardingInfo != null) { - var nrShards = shardingInfo.ScyllaNrShards; + shardingInfo = _shardingInfo; + Console.WriteLine("TEST CONN shardingInfo: {0}", shardingInfo); + var nrShards = _shardingInfo.ScyllaNrShards; if (nrShards > length) { // Create the rest of the connections @@ -958,6 +991,7 @@ public async Task Warmup() _expectedConnectionLength = nrShards; } } + c.Dispose(); } catch { @@ -965,7 +999,7 @@ public async Task Warmup() throw; } - for (var i = 1; i < length; i++) + for (var i = 0; i < length; i++) { try { @@ -973,7 +1007,14 @@ public async Task Warmup() } catch { - break; + if (i > 0) + { + // There is an opened connection, don't mind + break; + } + + OnConnectionClosing(); + throw; } } } diff --git a/src/Cassandra/Connections/IConnection.cs b/src/Cassandra/Connections/IConnection.cs index 1fea05f52..41fd44097 100644 --- a/src/Cassandra/Connections/IConnection.cs +++ b/src/Cassandra/Connections/IConnection.cs @@ -142,5 +142,7 @@ internal interface IConnection : IDisposable /// Returns the current sharding information. /// ShardingInfo ShardingInfo(); + + int ShardId { get; } } } From abdf149c56b444bdb208481ba5b02ba3239efa98 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Tue, 8 Apr 2025 09:11:14 +0200 Subject: [PATCH 06/14] Connect to shard aware port --- .../HostConnectionPoolTests.cs | 38 +++++++++---------- src/Cassandra/Connections/EndPointResolver.cs | 7 ++++ .../Connections/HostConnectionPool.cs | 25 +++++++----- .../Connections/IEndPointResolver.cs | 12 ++++++ .../Connections/SniEndPointResolver.cs | 9 +++++ src/Cassandra/ShardingInfo.cs | 6 +-- 6 files changed, 65 insertions(+), 32 deletions(-) diff --git a/src/Cassandra.Tests/HostConnectionPoolTests.cs b/src/Cassandra.Tests/HostConnectionPoolTests.cs index 5fe87b6fa..8fe315afc 100644 --- a/src/Cassandra.Tests/HostConnectionPoolTests.cs +++ b/src/Cassandra.Tests/HostConnectionPoolTests.cs @@ -128,7 +128,7 @@ public void MaybeCreateFirstConnection_Should_Yield_The_First_Connection_Opened( var lastByte = 1; //use different addresses for same hosts to differentiate connections: for test only //different connections to same hosts should use the same address - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)lastByte++), 200 - lastByte * 50)); + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)lastByte++), 200 - lastByte * 50)); var pool = _mock.Object; var creation = pool.EnsureCreate(); creation.Wait(); @@ -144,7 +144,7 @@ public async Task EnsureCreate_Should_Yield_A_Connection_If_Any_Fails() var counter = 0; //use different addresses for same hosts to differentiate connections: for test only //different connections to same hosts should use the same address - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { if (++counter == 2) { @@ -162,7 +162,7 @@ public void EnsureCreate_Serial_Calls_Should_Yield_First() { _mock = GetPoolMock(); var lastByte = 1; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { var c = CreateConnection((byte)lastByte++); if (lastByte == 2) @@ -191,7 +191,7 @@ public void EnsureCreate_Parallel_Calls_Should_Yield_First() { _mock = GetPoolMock(); var lastByte = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)++lastByte), 100 + (lastByte > 1 ? 10000 : 0))); + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)++lastByte), 100 + (lastByte > 1 ? 10000 : 0))); var pool = _mock.Object; var creationTasks = new Task[10]; var counter = -1; @@ -217,7 +217,7 @@ public void EnsureCreate_Parallel_Calls_Failing_Should_Only_Attempt_Creation_Onc // Use a reconnection policy that never attempts _mock = GetPoolMock(null, GetConfig(3, 3, new ConstantReconnectionPolicy(int.MaxValue))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref openConnectionAttempts); return TaskHelper.FromException(new Exception("Test Exception")); @@ -254,7 +254,7 @@ public void EnsureCreate_Fail_To_Open_All_Connections_Should_Fault_Task() { _mock = GetPoolMock(); var testException = new Exception("Dummy exception"); - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => TestHelper.DelayedTask(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => TestHelper.DelayedTask(() => { throw testException; })); @@ -270,7 +270,7 @@ public async Task OnHostUp_Recreates_Pool_In_The_Background() _mock = GetPoolMock(null, GetConfig(2, 2)); var creationCounter = 0; var isCreating = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref creationCounter); Interlocked.Exchange(ref isCreating, 1); @@ -293,7 +293,7 @@ public void OnHostUp_Does_Not_Recreates_Pool_For_Ignored_Hosts() { _mock = GetPoolMock(null, GetConfig(2, 2)); var creationCounter = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref creationCounter); return TaskHelper.ToTask(CreateConnection()); @@ -314,7 +314,7 @@ public async Task EnsureCreate_After_Reconnection_Attempt_Waits_Existing() _mock = GetPoolMock(null, GetConfig(2, 2)); var creationCounter = 0; var isCreating = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref creationCounter); Interlocked.Exchange(ref isCreating, 1); @@ -339,7 +339,7 @@ public async Task EnsureCreate_Can_Handle_Multiple_Concurrent_Calls() _mock = GetPoolMock(null, GetConfig(3, 3)); var creationCounter = 0; var isCreating = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref creationCounter); Interlocked.Exchange(ref isCreating, 1); @@ -432,7 +432,7 @@ public void ScheduleReconnection_Should_Raise_AllConnectionClosed() { _mock = GetPoolMock(null, GetConfig(1, 1, new ConstantReconnectionPolicy(100))); var openConnectionsAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref openConnectionsAttempts); return TaskHelper.FromException(new Exception("Test Exception")); @@ -455,7 +455,7 @@ public void ScheduleReconnection_Should_Not_Raise_AllConnectionClosed_When_Host_ host.SetDown(); _mock = GetPoolMock(host, GetConfig(1, 1, new ConstantReconnectionPolicy(100))); var openConnectionsAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref openConnectionsAttempts); return TaskHelper.FromException(new Exception("Test Exception")); @@ -476,7 +476,7 @@ public void ScheduleReconnection_Should_Not_Raise_AllConnectionClosed_When_Host_ public async Task CheckHealth_Removes_Connection() { _mock = GetPoolMock(); - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns( + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns( () => { var cc = HostConnectionPoolTests.GetConnectionMock(0, int.MaxValue); @@ -499,7 +499,7 @@ public async Task CheckHealth_Removes_Connection() public async Task Pool_Increasing_Size_And_Closing_Should_Not_Leave_Connections_Open([Range(0, 29)] int delay) { _mock = GetPoolMock(null, GetConfig(50, 50)); - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(async () => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(async () => { await Task.Yield(); var spinWait = new SpinWait(); @@ -537,7 +537,7 @@ await Task.Run(() => public async Task Dispose_Should_Not_Raise_AllConnections_Closed() { _mock = GetPoolMock(null, GetConfig(4, 4)); - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => TaskHelper.ToTask(CreateConnection())); + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => TaskHelper.ToTask(CreateConnection())); var pool = _mock.Object; Assert.AreEqual(0, pool.OpenConnections); pool.SetDistance(HostDistance.Local); @@ -555,7 +555,7 @@ public async Task Dispose_Should_Cancel_Reconnection_Attempts() { _mock = GetPoolMock(null, GetConfig(4, 4, new ConstantReconnectionPolicy(200))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref openConnectionAttempts); return TaskHelper.ToTask(CreateConnection()); @@ -575,7 +575,7 @@ public void Warmup_Should_Throw_When_The_First_Connection_Can_Not_Be_Opened() { _mock = GetPoolMock(null, GetConfig(4, 4, new ConstantReconnectionPolicy(200))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { var index = Interlocked.Increment(ref openConnectionAttempts); if (index == 1) @@ -596,7 +596,7 @@ public void Warmup_Should_Succeed_When_The_Second_Connection_Can_Not_Be_Opened() { _mock = GetPoolMock(null, GetConfig(4, 4, new ConstantReconnectionPolicy(200))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { var index = Interlocked.Increment(ref openConnectionAttempts); if (index == 2) @@ -617,7 +617,7 @@ public void Warmup_Should_Succeed_When_All_Connections_Can_Be_Opened() { _mock = GetPoolMock(null, GetConfig(4, 4, new ConstantReconnectionPolicy(200))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny())).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => { Interlocked.Increment(ref openConnectionAttempts); return TaskHelper.ToTask(CreateConnection()); diff --git a/src/Cassandra/Connections/EndPointResolver.cs b/src/Cassandra/Connections/EndPointResolver.cs index 33a76f043..449443b5d 100644 --- a/src/Cassandra/Connections/EndPointResolver.cs +++ b/src/Cassandra/Connections/EndPointResolver.cs @@ -15,6 +15,7 @@ // using System; +using System.Net; using System.Threading.Tasks; namespace Cassandra.Connections @@ -28,6 +29,12 @@ public EndPointResolver(IServerNameResolver serverNameResolver) _serverNameResolver = serverNameResolver ?? throw new ArgumentNullException(nameof(serverNameResolver)); } + /// + public Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardID, int shardAwarePort) + { + return Task.FromResult((IConnectionEndPoint)new ConnectionEndPoint(new IPEndPoint(IPAddress.Parse(host.Address.ToString().Split(':')[0]), shardAwarePort), _serverNameResolver, host.ContactPoint)); + } + /// public Task GetConnectionEndPointAsync(Host host, bool refreshCache) { diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index 1dd38caef..7241f3ca7 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -277,9 +277,17 @@ public void Dispose() Interlocked.Exchange(ref _state, PoolState.Shutdown); } - public virtual async Task DoCreateAndOpen(bool isReconnection) + public virtual async Task DoCreateAndOpen(bool isReconnection, int shardID = -1, int shardAwarePort = 0) { - var endPoint = await _config.EndPointResolver.GetConnectionEndPointAsync(_host, isReconnection).ConfigureAwait(false); + IConnectionEndPoint endPoint; + if (shardID != -1 && shardAwarePort != 0) + { + endPoint = await _config.EndPointResolver.GetConnectionShardAwareEndPointAsync(_host, isReconnection, shardID, shardAwarePort).ConfigureAwait(false); + } + else + { + endPoint = await _config.EndPointResolver.GetConnectionEndPointAsync(_host, isReconnection).ConfigureAwait(false); + } var c = _config.ConnectionFactory.Create(_serializerManager.GetCurrentSerializer(), endPoint, _config, _observerFactory.CreateConnectionObserver(_host)); c.Closing += OnConnectionClosing; if (_poolingOptions.GetHeartBeatInterval() > 0) @@ -712,19 +720,18 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne try { // Find out to which shard should we connect to - // Console.WriteLine("Decide which shard to connect to"); - var shardID = 0; - // Console.WriteLine("ShardingInfo: {0}", shardingInfo); + var shardID = -1; + var shardAwarePort = 0; if (shardingInfo != null) { + shardAwarePort = shardingInfo.ScyllaShardAwarePort; // Find the shard without a connection // It's important to start counting from 1 here because we want // to consider the next shard after the previously attempted one for (var i = 1; i <= shardingInfo.ScyllaNrShards; i++) { - // Console.WriteLine("i: {0}", i); var _shardID = (lastAttemptedShard + i) % shardingInfo.ScyllaNrShards; - if (connectionsSnapshot.Length <= shardID || connectionsSnapshot[shardID] == null) + if (connectionsSnapshot.Length <= _shardID || connectionsSnapshot[_shardID] == null) { lastAttemptedShard = _shardID; shardID = _shardID; @@ -732,8 +739,7 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne } } } - // Console.WriteLine("shardID: {0}", shardID); - c = await DoCreateAndOpen(isReconnection).ConfigureAwait(false); + c = await DoCreateAndOpen(isReconnection, shardID, shardAwarePort).ConfigureAwait(false); } catch (Exception ex) { @@ -982,7 +988,6 @@ public async Task Warmup() if (_shardingInfo != null) { shardingInfo = _shardingInfo; - Console.WriteLine("TEST CONN shardingInfo: {0}", shardingInfo); var nrShards = _shardingInfo.ScyllaNrShards; if (nrShards > length) { diff --git a/src/Cassandra/Connections/IEndPointResolver.cs b/src/Cassandra/Connections/IEndPointResolver.cs index 6c502e341..2c983670c 100644 --- a/src/Cassandra/Connections/IEndPointResolver.cs +++ b/src/Cassandra/Connections/IEndPointResolver.cs @@ -32,5 +32,17 @@ internal interface IEndPointResolver /// no round trip will occur. /// Endpoint. Task GetConnectionEndPointAsync(Host host, bool refreshCache); + + /// + /// Gets an instance of to the provided host from the internal cache (if caching is supported by the implementation). + /// It uses provided shard aware port. + /// + /// Host related to the new endpoint. + /// Whether to refresh the internal cache. If it is false and the cache is populated then + /// no round trip will occur. + /// Shard ID. + /// Shard aware port. + /// Endpoint. + Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardID, int shardAwarePort); } } diff --git a/src/Cassandra/Connections/SniEndPointResolver.cs b/src/Cassandra/Connections/SniEndPointResolver.cs index 36f629c9e..04f07a011 100644 --- a/src/Cassandra/Connections/SniEndPointResolver.cs +++ b/src/Cassandra/Connections/SniEndPointResolver.cs @@ -56,6 +56,15 @@ public SniEndPointResolver( { } + public async Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardID, int shardAwarePort) + { + return new SniConnectionEndPoint( + await GetNextEndPointAsync(refreshCache).ConfigureAwait(false), + new IPEndPoint(IPAddress.Parse(host.Address.ToString().Split(':')[0]), shardAwarePort), + host.HostId.ToString("D"), + host.ContactPoint); + } + public async Task GetConnectionEndPointAsync(Host host, bool refreshCache) { return new SniConnectionEndPoint( diff --git a/src/Cassandra/ShardingInfo.cs b/src/Cassandra/ShardingInfo.cs index 646070b23..19771a6f1 100644 --- a/src/Cassandra/ShardingInfo.cs +++ b/src/Cassandra/ShardingInfo.cs @@ -11,12 +11,12 @@ public class ShardingInfo public string ScyllaPartitioner { get; } public string ScyllaShardingAlgorithm { get; } public ulong ScyllaShardingIgnoreMSB { get; } - public ulong ScyllaShardAwarePort { get; } + public int ScyllaShardAwarePort { get; } public ulong ScyllaShardAwarePortSSL { get; } private ShardingInfo(int scyllaShard, int scyllaNrShards, string scyllaPartitioner, string scyllaShardingAlgorithm, ulong scyllaShardingIgnoreMSB, - ulong scyllaShardAwarePort, ulong scyllaShardAwarePortSSL) + int scyllaShardAwarePort, ulong scyllaShardAwarePortSSL) { ScyllaShard = scyllaShard; ScyllaNrShards = scyllaNrShards; @@ -37,7 +37,7 @@ public static ShardingInfo Create(string scyllaShard, string scyllaNrShards, str scyllaPartitioner, scyllaShardingAlgorithm, ulong.Parse(scyllaShardingIgnoreMSB), - ulong.Parse(scyllaShardAwarePort), + int.Parse(scyllaShardAwarePort), ulong.Parse(scyllaShardAwarePortSSL) ); } From ddda012190c3454405ff85cb35fe2d899db6b624 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Wed, 9 Apr 2025 14:02:58 +0200 Subject: [PATCH 07/14] Connect to given shard --- .../HostConnectionPoolTests.cs | 38 +++++----- src/Cassandra/Connections/Connection.cs | 32 +++++++- src/Cassandra/Connections/EndPointResolver.cs | 2 +- .../Connections/HostConnectionPool.cs | 24 ++++-- src/Cassandra/Connections/IConnection.cs | 10 +++ .../Connections/IEndPointResolver.cs | 3 +- src/Cassandra/Connections/ITcpSocket.cs | 2 +- src/Cassandra/Connections/PortAllocator.cs | 74 +++++++++++++++++++ .../Connections/SniEndPointResolver.cs | 2 +- src/Cassandra/Connections/TcpSocket.cs | 8 +- src/Cassandra/ProtocolOptions.cs | 36 ++++++++- 11 files changed, 192 insertions(+), 39 deletions(-) create mode 100644 src/Cassandra/Connections/PortAllocator.cs diff --git a/src/Cassandra.Tests/HostConnectionPoolTests.cs b/src/Cassandra.Tests/HostConnectionPoolTests.cs index 8fe315afc..60fa94aa2 100644 --- a/src/Cassandra.Tests/HostConnectionPoolTests.cs +++ b/src/Cassandra.Tests/HostConnectionPoolTests.cs @@ -128,7 +128,7 @@ public void MaybeCreateFirstConnection_Should_Yield_The_First_Connection_Opened( var lastByte = 1; //use different addresses for same hosts to differentiate connections: for test only //different connections to same hosts should use the same address - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)lastByte++), 200 - lastByte * 50)); + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)lastByte++), 200 - lastByte * 50)); var pool = _mock.Object; var creation = pool.EnsureCreate(); creation.Wait(); @@ -144,7 +144,7 @@ public async Task EnsureCreate_Should_Yield_A_Connection_If_Any_Fails() var counter = 0; //use different addresses for same hosts to differentiate connections: for test only //different connections to same hosts should use the same address - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { if (++counter == 2) { @@ -162,7 +162,7 @@ public void EnsureCreate_Serial_Calls_Should_Yield_First() { _mock = GetPoolMock(); var lastByte = 1; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { var c = CreateConnection((byte)lastByte++); if (lastByte == 2) @@ -191,7 +191,7 @@ public void EnsureCreate_Parallel_Calls_Should_Yield_First() { _mock = GetPoolMock(); var lastByte = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)++lastByte), 100 + (lastByte > 1 ? 10000 : 0))); + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)++lastByte), 100 + (lastByte > 1 ? 10000 : 0))); var pool = _mock.Object; var creationTasks = new Task[10]; var counter = -1; @@ -217,7 +217,7 @@ public void EnsureCreate_Parallel_Calls_Failing_Should_Only_Attempt_Creation_Onc // Use a reconnection policy that never attempts _mock = GetPoolMock(null, GetConfig(3, 3, new ConstantReconnectionPolicy(int.MaxValue))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref openConnectionAttempts); return TaskHelper.FromException(new Exception("Test Exception")); @@ -254,7 +254,7 @@ public void EnsureCreate_Fail_To_Open_All_Connections_Should_Fault_Task() { _mock = GetPoolMock(); var testException = new Exception("Dummy exception"); - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => TestHelper.DelayedTask(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => TestHelper.DelayedTask(() => { throw testException; })); @@ -270,7 +270,7 @@ public async Task OnHostUp_Recreates_Pool_In_The_Background() _mock = GetPoolMock(null, GetConfig(2, 2)); var creationCounter = 0; var isCreating = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref creationCounter); Interlocked.Exchange(ref isCreating, 1); @@ -293,7 +293,7 @@ public void OnHostUp_Does_Not_Recreates_Pool_For_Ignored_Hosts() { _mock = GetPoolMock(null, GetConfig(2, 2)); var creationCounter = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref creationCounter); return TaskHelper.ToTask(CreateConnection()); @@ -314,7 +314,7 @@ public async Task EnsureCreate_After_Reconnection_Attempt_Waits_Existing() _mock = GetPoolMock(null, GetConfig(2, 2)); var creationCounter = 0; var isCreating = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref creationCounter); Interlocked.Exchange(ref isCreating, 1); @@ -339,7 +339,7 @@ public async Task EnsureCreate_Can_Handle_Multiple_Concurrent_Calls() _mock = GetPoolMock(null, GetConfig(3, 3)); var creationCounter = 0; var isCreating = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref creationCounter); Interlocked.Exchange(ref isCreating, 1); @@ -432,7 +432,7 @@ public void ScheduleReconnection_Should_Raise_AllConnectionClosed() { _mock = GetPoolMock(null, GetConfig(1, 1, new ConstantReconnectionPolicy(100))); var openConnectionsAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref openConnectionsAttempts); return TaskHelper.FromException(new Exception("Test Exception")); @@ -455,7 +455,7 @@ public void ScheduleReconnection_Should_Not_Raise_AllConnectionClosed_When_Host_ host.SetDown(); _mock = GetPoolMock(host, GetConfig(1, 1, new ConstantReconnectionPolicy(100))); var openConnectionsAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref openConnectionsAttempts); return TaskHelper.FromException(new Exception("Test Exception")); @@ -476,7 +476,7 @@ public void ScheduleReconnection_Should_Not_Raise_AllConnectionClosed_When_Host_ public async Task CheckHealth_Removes_Connection() { _mock = GetPoolMock(); - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns( + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns( () => { var cc = HostConnectionPoolTests.GetConnectionMock(0, int.MaxValue); @@ -499,7 +499,7 @@ public async Task CheckHealth_Removes_Connection() public async Task Pool_Increasing_Size_And_Closing_Should_Not_Leave_Connections_Open([Range(0, 29)] int delay) { _mock = GetPoolMock(null, GetConfig(50, 50)); - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(async () => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(async () => { await Task.Yield(); var spinWait = new SpinWait(); @@ -537,7 +537,7 @@ await Task.Run(() => public async Task Dispose_Should_Not_Raise_AllConnections_Closed() { _mock = GetPoolMock(null, GetConfig(4, 4)); - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => TaskHelper.ToTask(CreateConnection())); + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => TaskHelper.ToTask(CreateConnection())); var pool = _mock.Object; Assert.AreEqual(0, pool.OpenConnections); pool.SetDistance(HostDistance.Local); @@ -555,7 +555,7 @@ public async Task Dispose_Should_Cancel_Reconnection_Attempts() { _mock = GetPoolMock(null, GetConfig(4, 4, new ConstantReconnectionPolicy(200))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref openConnectionAttempts); return TaskHelper.ToTask(CreateConnection()); @@ -575,7 +575,7 @@ public void Warmup_Should_Throw_When_The_First_Connection_Can_Not_Be_Opened() { _mock = GetPoolMock(null, GetConfig(4, 4, new ConstantReconnectionPolicy(200))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { var index = Interlocked.Increment(ref openConnectionAttempts); if (index == 1) @@ -596,7 +596,7 @@ public void Warmup_Should_Succeed_When_The_Second_Connection_Can_Not_Be_Opened() { _mock = GetPoolMock(null, GetConfig(4, 4, new ConstantReconnectionPolicy(200))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { var index = Interlocked.Increment(ref openConnectionAttempts); if (index == 2) @@ -617,7 +617,7 @@ public void Warmup_Should_Succeed_When_All_Connections_Can_Be_Opened() { _mock = GetPoolMock(null, GetConfig(4, 4, new ConstantReconnectionPolicy(200))); var openConnectionAttempts = 0; - _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0)).Returns(() => + _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => { Interlocked.Increment(ref openConnectionAttempts); return TaskHelper.ToTask(CreateConnection()); diff --git a/src/Cassandra/Connections/Connection.cs b/src/Cassandra/Connections/Connection.cs index f9b5c4369..ba1a4a8a7 100644 --- a/src/Cassandra/Connections/Connection.cs +++ b/src/Cassandra/Connections/Connection.cs @@ -459,11 +459,24 @@ private void IdleTimeoutHandler(object state) /// /// public async Task Open() + { + return await Open(-1, 0); + } + + /// + /// Initializes the connection. + /// + /// Shard ID + /// Shard count + /// Throws a SocketException when the connection could not be established with the host + /// + /// + public async Task Open(int shardID = -1, int shardCount = 0) { try { Connection.Logger.Verbose("Attempting to open Connection #{0} to {1}", GetHashCode(), EndPoint.EndpointFriendlyName); - var response = await DoOpen().ConfigureAwait(false); + var response = await DoOpen(shardID, shardCount).ConfigureAwait(false); Connection.Logger.Verbose("Opened Connection #{0} to {1} with local endpoint {2}.", GetHashCode(), EndPoint.EndpointFriendlyName, _tcpSocket.GetLocalIpEndPoint()?.ToString()); return response; } @@ -480,7 +493,7 @@ public async Task Open() /// Throws a SocketException when the connection could not be established with the host /// /// - public async Task DoOpen() + public async Task DoOpen(int shardID = -1, int shardCount = 0) { //Init TcpSocket _tcpSocket.Error += OnSocketError; @@ -489,7 +502,20 @@ public async Task DoOpen() _tcpSocket.Read += ReadHandler; _tcpSocket.WriteCompleted += WriteCompletedHandler; var protocolVersion = Serializer.ProtocolVersion; - await _tcpSocket.Connect().ConfigureAwait(false); + if (shardID != -1) + { + var localPort = PortAllocator.GetNextAvailablePort(shardCount, shardID, Options.LocalPortLow, Options.LocalPortHigh); + if (localPort == -1) + { + throw new SocketException((int)SocketError.NoData); + } + await _tcpSocket.Connect(localPort).ConfigureAwait(false); + } + else + { + await _tcpSocket.Connect().ConfigureAwait(false); + } + // Send the OPTIONS message Response optionsResponse; diff --git a/src/Cassandra/Connections/EndPointResolver.cs b/src/Cassandra/Connections/EndPointResolver.cs index 449443b5d..1db15806a 100644 --- a/src/Cassandra/Connections/EndPointResolver.cs +++ b/src/Cassandra/Connections/EndPointResolver.cs @@ -30,7 +30,7 @@ public EndPointResolver(IServerNameResolver serverNameResolver) } /// - public Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardID, int shardAwarePort) + public Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardAwarePort) { return Task.FromResult((IConnectionEndPoint)new ConnectionEndPoint(new IPEndPoint(IPAddress.Parse(host.Address.ToString().Split(':')[0]), shardAwarePort), _serverNameResolver, host.ContactPoint)); } diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index 7241f3ca7..e966a94a0 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -277,12 +277,12 @@ public void Dispose() Interlocked.Exchange(ref _state, PoolState.Shutdown); } - public virtual async Task DoCreateAndOpen(bool isReconnection, int shardID = -1, int shardAwarePort = 0) + public virtual async Task DoCreateAndOpen(bool isReconnection, int shardID = -1, int shardAwarePort = 0, int shardCount = 0) { IConnectionEndPoint endPoint; - if (shardID != -1 && shardAwarePort != 0) + if (shardAwarePort != 0) { - endPoint = await _config.EndPointResolver.GetConnectionShardAwareEndPointAsync(_host, isReconnection, shardID, shardAwarePort).ConfigureAwait(false); + endPoint = await _config.EndPointResolver.GetConnectionShardAwareEndPointAsync(_host, isReconnection, shardAwarePort).ConfigureAwait(false); } else { @@ -296,7 +296,15 @@ public virtual async Task DoCreateAndOpen(bool isReconnection, int } try { - await c.Open().ConfigureAwait(false); + if (shardID != -1) + { + await c.Open(shardID, shardCount).ConfigureAwait(false); + } + else + { + await c.Open().ConfigureAwait(false); + } + } catch { @@ -722,15 +730,17 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne // Find out to which shard should we connect to var shardID = -1; var shardAwarePort = 0; + var shardCount = 0; if (shardingInfo != null) { shardAwarePort = shardingInfo.ScyllaShardAwarePort; + shardCount = shardingInfo.ScyllaNrShards; // Find the shard without a connection // It's important to start counting from 1 here because we want // to consider the next shard after the previously attempted one - for (var i = 1; i <= shardingInfo.ScyllaNrShards; i++) + for (var i = 1; i <= shardCount; i++) { - var _shardID = (lastAttemptedShard + i) % shardingInfo.ScyllaNrShards; + var _shardID = (lastAttemptedShard + i) % shardCount; if (connectionsSnapshot.Length <= _shardID || connectionsSnapshot[_shardID] == null) { lastAttemptedShard = _shardID; @@ -739,7 +749,7 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne } } } - c = await DoCreateAndOpen(isReconnection, shardID, shardAwarePort).ConfigureAwait(false); + c = await DoCreateAndOpen(isReconnection, shardID, shardAwarePort, shardCount).ConfigureAwait(false); } catch (Exception ex) { diff --git a/src/Cassandra/Connections/IConnection.cs b/src/Cassandra/Connections/IConnection.cs index 41fd44097..92eee2806 100644 --- a/src/Cassandra/Connections/IConnection.cs +++ b/src/Cassandra/Connections/IConnection.cs @@ -107,6 +107,16 @@ internal interface IConnection : IDisposable /// Task Open(); + /// + /// Initializes the connection. + /// + /// The shard ID + /// The shard count + /// Throws a SocketException when the connection could not be established with the host + /// + /// + Task Open(int shardID, int shardCount); + /// /// Sends a new request if possible. If it is not possible it queues it up. /// diff --git a/src/Cassandra/Connections/IEndPointResolver.cs b/src/Cassandra/Connections/IEndPointResolver.cs index 2c983670c..d9921e73e 100644 --- a/src/Cassandra/Connections/IEndPointResolver.cs +++ b/src/Cassandra/Connections/IEndPointResolver.cs @@ -40,9 +40,8 @@ internal interface IEndPointResolver /// Host related to the new endpoint. /// Whether to refresh the internal cache. If it is false and the cache is populated then /// no round trip will occur. - /// Shard ID. /// Shard aware port. /// Endpoint. - Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardID, int shardAwarePort); + Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardAwarePort); } } diff --git a/src/Cassandra/Connections/ITcpSocket.cs b/src/Cassandra/Connections/ITcpSocket.cs index a0e5acb9f..7e0ae2b7e 100644 --- a/src/Cassandra/Connections/ITcpSocket.cs +++ b/src/Cassandra/Connections/ITcpSocket.cs @@ -57,7 +57,7 @@ internal interface ITcpSocket : IDisposable /// Connects asynchronously to the host and starts reading /// /// Throws a SocketException when the connection could not be established with the host - Task Connect(); + Task Connect(int localPort = -1); /// /// Sends data asynchronously diff --git a/src/Cassandra/Connections/PortAllocator.cs b/src/Cassandra/Connections/PortAllocator.cs new file mode 100644 index 000000000..180323529 --- /dev/null +++ b/src/Cassandra/Connections/PortAllocator.cs @@ -0,0 +1,74 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading; + +static class PortAllocator +{ + private static int lastPort = -1; + + public static int GetNextAvailablePort(int shardCount, int shardId, int lowPort, int highPort) + { + int foundPort = -1; + int lastPortValue; + + do + { + lastPortValue = Volatile.Read(ref lastPort); + + int scanStart = lastPortValue == -1 ? lowPort : lastPortValue; + if (scanStart < lowPort) + { + scanStart = lowPort; + } + + scanStart += (shardCount - scanStart % shardCount) + shardId; + + for (int port = scanStart; port <= highPort; port += shardCount) + { + if (IsTcpPortAvailable(port)) + { + foundPort = port; + break; + } + } + + if (foundPort == -1) + { + scanStart = lowPort + (shardCount - lowPort % shardCount) + shardId; + + for (int port = scanStart; port <= highPort; port += shardCount) + { + if (IsTcpPortAvailable(port)) + { + foundPort = port; + break; + } + } + } + + if (foundPort == -1) + { + return -1; + } + } + while (Interlocked.CompareExchange(ref lastPort, foundPort, lastPortValue) != lastPortValue); + + return foundPort; + } + + public static bool IsTcpPortAvailable(int port) + { + try + { + TcpListener listener = new TcpListener(IPAddress.Loopback, port); + listener.Start(); + listener.Stop(); + return true; + } + catch (SocketException) + { + return false; + } + } +} diff --git a/src/Cassandra/Connections/SniEndPointResolver.cs b/src/Cassandra/Connections/SniEndPointResolver.cs index 04f07a011..64d96d94e 100644 --- a/src/Cassandra/Connections/SniEndPointResolver.cs +++ b/src/Cassandra/Connections/SniEndPointResolver.cs @@ -56,7 +56,7 @@ public SniEndPointResolver( { } - public async Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardID, int shardAwarePort) + public async Task GetConnectionShardAwareEndPointAsync(Host host, bool refreshCache, int shardAwarePort) { return new SniConnectionEndPoint( await GetNextEndPointAsync(refreshCache).ConfigureAwait(false), diff --git a/src/Cassandra/Connections/TcpSocket.cs b/src/Cassandra/Connections/TcpSocket.cs index 602d1c4e1..554fa877a 100644 --- a/src/Cassandra/Connections/TcpSocket.cs +++ b/src/Cassandra/Connections/TcpSocket.cs @@ -137,13 +137,19 @@ public IPEndPoint GetLocalIpEndPoint() /// Connects asynchronously to the host and starts reading /// /// Throws a SocketException when the connection could not be established with the host - public async Task Connect() + public async Task Connect(int localPort = -1) { var tcs = TaskHelper.TaskCompletionSourceWithTimeout( Options.ConnectTimeoutMillis, () => new SocketException((int)SocketError.TimedOut)); var socketConnectTask = tcs.Task; + if (localPort != -1) + { + var localEndPoint = new IPEndPoint(IPAddress.Any, localPort); + _socket.Bind(localEndPoint); + } + using (var eventArgs = new SocketAsyncEventArgs { RemoteEndPoint = EndPoint.SocketIpEndPoint }) { eventArgs.Completed += (sender, e) => { OnConnectComplete(tcs, e); }; diff --git a/src/Cassandra/ProtocolOptions.cs b/src/Cassandra/ProtocolOptions.cs index b20230b96..d69d74d3e 100644 --- a/src/Cassandra/ProtocolOptions.cs +++ b/src/Cassandra/ProtocolOptions.cs @@ -14,6 +14,8 @@ // limitations under the License. // +using System; + namespace Cassandra { /// @@ -36,17 +38,22 @@ public class ProtocolOptions /// public const int DefaultMaxSchemaAgreementWaitSeconds = 10; + public const int DefaultLocalPortLow = 10000; + public const int DefaultLocalPortHigh = 60000; + private readonly int _port; private readonly SSLOptions _sslOptions; private CompressionType _compression = CompressionType.NoCompression; private IFrameCompressor _compressor; private int _maxSchemaAgreementWaitSeconds = ProtocolOptions.DefaultMaxSchemaAgreementWaitSeconds; + private int _localPortLow = ProtocolOptions.DefaultLocalPortLow; + private int _localPortHigh = ProtocolOptions.DefaultLocalPortHigh; private ProtocolVersion? _maxProtocolVersion; /// /// The port used to connect to the Cassandra hosts. /// - /// + /// /// the port used to connect to the Cassandra hosts. public int Port { @@ -56,7 +63,7 @@ public int Port /// /// Specified SSL options used to connect to the Cassandra hosts. /// - /// + /// /// SSL options used to connect to the Cassandra hosts. public SSLOptions SslOptions { @@ -89,6 +96,16 @@ public int MaxSchemaAgreementWaitSeconds get { return _maxSchemaAgreementWaitSeconds; } } + public int LocalPortLow + { + get { return _localPortLow; } + } + + public int LocalPortHigh + { + get { return _localPortHigh; } + } + /// /// Determines whether NO_COMPACT is enabled as startup option. /// @@ -137,8 +154,8 @@ public ProtocolOptions(int port) } - /// - /// Creates a new ProtocolOptions instance using the provided port and SSL context. + /// + /// Creates a new ProtocolOptions instance using the provided port and SSL context. /// /// the port to use for the binary protocol. /// sslOptions the SSL options to use. Use null if SSL is not to be used. @@ -182,6 +199,17 @@ public ProtocolOptions SetMaxSchemaAgreementWaitSeconds(int value) return this; } + public ProtocolOptions SetLocalPortRange(int low, int high) + { + if (low < 1 || 65535 < low || high < 1 || 65535 < high) + { + throw new ArgumentOutOfRangeException("Port numbers must be between 1 and 65535"); + } + _localPortLow = low; + _localPortHigh = high; + return this; + } + /// /// Sets the maximum protocol version to be used. /// When set, it limits the maximum protocol version used to connect to the nodes. From 1dd22c928c07e6d868a8bc6f242a5b59bd3a0268 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Fri, 11 Apr 2025 10:03:03 +0200 Subject: [PATCH 08/14] Use shard information for query routing --- .../ShardAwarenessTests.cs | 79 ++++++++++++ .../Connections/HostConnectionPoolTests.cs | 3 +- src/Cassandra/Connections/Connection.cs | 12 +- .../Connections/HostConnectionPool.cs | 120 +++++++++++------- .../Connections/HostConnectionPoolFactory.cs | 4 +- src/Cassandra/Connections/IConnection.cs | 2 +- .../Connections/IHostConnectionPool.cs | 12 +- .../Connections/IHostConnectionPoolFactory.cs | 2 +- src/Cassandra/Connections/ShardingInfo.cs | 75 +++++++++++ src/Cassandra/Metadata.cs | 9 ++ src/Cassandra/Requests/IReprepareHandler.cs | 8 +- src/Cassandra/Requests/ReprepareHandler.cs | 10 +- src/Cassandra/Requests/RequestHandler.cs | 13 +- src/Cassandra/Session.cs | 2 +- src/Cassandra/ShardingInfo.cs | 45 ------- 15 files changed, 277 insertions(+), 119 deletions(-) create mode 100644 src/Cassandra.IntegrationTests/ShardAwarenessTests.cs create mode 100644 src/Cassandra/Connections/ShardingInfo.cs delete mode 100644 src/Cassandra/ShardingInfo.cs diff --git a/src/Cassandra.IntegrationTests/ShardAwarenessTests.cs b/src/Cassandra.IntegrationTests/ShardAwarenessTests.cs new file mode 100644 index 000000000..459b9301e --- /dev/null +++ b/src/Cassandra.IntegrationTests/ShardAwarenessTests.cs @@ -0,0 +1,79 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Diagnostics; +using Cassandra.Connections.Control; +using Cassandra.IntegrationTests.TestBase; +using Cassandra.IntegrationTests.TestClusterManagement; +using Cassandra.SessionManagement; +using NUnit.Framework; + +namespace Cassandra.IntegrationTests +{ + [TestFixture] + public class ShardAwarenessTest : TestGlobals + { + private ITestCluster _realCluster; + + [TearDown] + public void TestTearDown() + { + TestClusterManager.TryRemove(); + _realCluster = null; + } + + [Test] + public void CorrectShardInTracingTest() + { + _realCluster = TestClusterManager.CreateNew(); + var cluster = ClusterBuilder() + .WithSocketOptions(new SocketOptions().SetReadTimeoutMillis(22000).SetConnectTimeoutMillis(60000)) + .AddContactPoint(_realCluster.InitialContactPoint) + .Build(); + var _session = cluster.Connect(); + + _session.Execute("DROP KEYSPACE IF EXISTS shardawaretest"); + _session.Execute("CREATE KEYSPACE shardawaretest WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}"); + _session.Execute("CREATE TABLE shardawaretest.t (pk text, ck text, v text, PRIMARY KEY (pk, ck))"); + + var populateStatement = _session.Prepare("INSERT INTO shardawaretest.t (pk, ck, v) VALUES (?, ?, ?)"); + _session.Execute(populateStatement.Bind("a", "b", "c")); + _session.Execute(populateStatement.Bind("e", "f", "g")); + _session.Execute(populateStatement.Bind("100002", "f", "g")); + + VerifyCorrectShardSingleRow(_session, "a", "b", "c", "shard 0"); + VerifyCorrectShardSingleRow(_session, "e", "f", "g", "shard 0"); + VerifyCorrectShardSingleRow(_session, "100002", "f", "g", "shard 1"); + } + + private void VerifyCorrectShardSingleRow(ISession _session, string pk, string ck, string v, string shard) + { + var prepared = _session.Prepare("SELECT pk, ck, v FROM shardawaretest.t WHERE pk=? AND ck=?"); + var result = _session.Execute(prepared.Bind(pk, ck).EnableTracing()); + + var row = result.First(); + Assert.IsNotNull(row); + Assert.AreEqual(pk, row.GetValue("pk")); + Assert.AreEqual(ck, row.GetValue("ck")); + Assert.AreEqual(v, row.GetValue("v")); + + var executionInfo = result.Info; + var trace = executionInfo.QueryTrace; + bool anyLocal = false; + foreach (var eventItem in trace.Events) + { + Trace.TraceInformation(" {0} - {1} - [{2}] - {3}", + eventItem.SourceElapsedMicros, + eventItem.Source, + eventItem.ThreadName, + eventItem.Description); + Assert.IsTrue(eventItem.ThreadName.StartsWith(shard)); + if (eventItem.Description.Contains("querying locally")) + { + anyLocal = true; + } + } + Assert.IsTrue(anyLocal); + } + } +} diff --git a/src/Cassandra.Tests/Connections/HostConnectionPoolTests.cs b/src/Cassandra.Tests/Connections/HostConnectionPoolTests.cs index d34c42f23..e24cd370a 100644 --- a/src/Cassandra.Tests/Connections/HostConnectionPoolTests.cs +++ b/src/Cassandra.Tests/Connections/HostConnectionPoolTests.cs @@ -116,7 +116,8 @@ private IHostConnectionPool CreatePool(IEndPointResolver res = null) _host, config, SerializerManager.Default, - new MetricsObserverFactory(new MetricsManager(new NullDriverMetricsProvider(), new DriverMetricsOptions(), false, "s1")) + new MetricsObserverFactory(new MetricsManager(new NullDriverMetricsProvider(), new DriverMetricsOptions(), false, "s1")), + M3PToken.Factory ); pool.SetDistance(HostDistance.Local); // set expected connections length diff --git a/src/Cassandra/Connections/Connection.cs b/src/Cassandra/Connections/Connection.cs index ba1a4a8a7..2ba40a020 100644 --- a/src/Cassandra/Connections/Connection.cs +++ b/src/Cassandra/Connections/Connection.cs @@ -166,7 +166,8 @@ public string Keyspace private readonly ISupportedOptionsInitializer _supportedOptionsInitializer; - public int ShardId { get; } + public int ShardID { get; set; } + private int _requestedShardID { get; set; } internal Connection( ISerializer serializer, @@ -473,6 +474,7 @@ public async Task Open() /// public async Task Open(int shardID = -1, int shardCount = 0) { + _requestedShardID = shardID; try { Connection.Logger.Verbose("Attempting to open Connection #{0} to {1}", GetHashCode(), EndPoint.EndpointFriendlyName); @@ -534,6 +536,14 @@ public async Task DoOpen(int shardID = -1, int shardCount = 0) throw; } _supportedOptionsInitializer.ApplySupportedFromResponse(optionsResponse); + if (_supportedOptionsInitializer.GetShardingInfo() != null) + { + ShardID = _supportedOptionsInitializer.GetShardingInfo().ScyllaShard; + if (_requestedShardID != -1 && ShardID != _requestedShardID) + { + Connection.Logger.Warning("Requested connection to shard {1}, but connected to {2}. Is there a NAT between client and server?", _requestedShardID, ShardID); + } + } Response response; try diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index e966a94a0..55a73f4a7 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -34,6 +34,7 @@ namespace Cassandra.Connections internal class HostConnectionPool : IHostConnectionPool { private static readonly Logger Logger = new Logger(typeof(HostConnectionPool)); + private Random rand = new Random(); private const int ConnectionIndexOverflow = int.MaxValue - 1000000; private const long BetweenResizeDelay = 2000; @@ -112,7 +113,9 @@ private static class PoolState private int lastAttemptedShard = 0; - public HostConnectionPool(Host host, Configuration config, ISerializerManager serializerManager, IObserverFactory observerFactory) + private TokenFactory _tokenFactory; + + public HostConnectionPool(Host host, Configuration config, ISerializerManager serializerManager, IObserverFactory observerFactory, TokenFactory tokenFactory) { _host = host; _host.Down += OnHostDown; @@ -126,10 +129,11 @@ public HostConnectionPool(Host host, Configuration config, ISerializerManager se _timer = config.Timer; _reconnectionSchedule = config.Policies.ReconnectionPolicy.NewSchedule(); _expectedConnectionLength = 1; + _tokenFactory = tokenFactory; } /// - public async Task BorrowConnectionAsync() + public async Task BorrowConnectionAsync(RoutingKey routingKey = null) { var connections = await EnsureCreate().ConfigureAwait(false); if (connections.Length == 0) @@ -137,11 +141,11 @@ public async Task BorrowConnectionAsync() throw new DriverInternalError("No connection could be borrowed"); } - return BorrowLeastBusyConnection(connections); + return BorrowLeastBusyConnection(connections, routingKey); } /// - public IConnection BorrowExistingConnection() + public IConnection BorrowExistingConnection(RoutingKey routingKey) { var connections = GetExistingConnections(); if (connections.Length == 0) @@ -149,12 +153,52 @@ public IConnection BorrowExistingConnection() return null; } - return BorrowLeastBusyConnection(connections); + return BorrowLeastBusyConnection(connections, routingKey); + } + + private IConnection ConnectionForShard(IConnection[] connections, int shardID) + { + for (int i = 0; i < connections.Length; i++) + { + if (connections[i] != null && connections[i].ShardID == shardID) + { + return connections[i]; + } + } + return null; } - private IConnection BorrowLeastBusyConnection(IConnection[] connections) + private IConnection BorrowLeastBusyConnection(IConnection[] connections, RoutingKey routingKey = null) { - var c = HostConnectionPool.MinInFlight(connections, ref _connectionIndex, _maxRequestsPerConnection, out var inFlight); + int shardID = 0; + if (shardingInfo != null) + { + if (routingKey != null) + { + IToken token = _tokenFactory.Hash(routingKey.RawRoutingKey); + shardID = shardingInfo.ShardID(token); + } + else + { + shardID = rand.Next(shardingInfo.ScyllaNrShards); + } + } + + IConnection c = ConnectionForShard(connections, shardID); + var inFlight = 0; + if (c != null) + { + // if we have a connection for the shard, use it if it is not too busy + inFlight = c.InFlight; + if (inFlight >= _maxRequestsPerConnection) + { + c = HostConnectionPool.MinInFlight(connections, ref _connectionIndex, _maxRequestsPerConnection, out inFlight); + } + } + else + { + c = HostConnectionPool.MinInFlight(connections, ref _connectionIndex, _maxRequestsPerConnection, out inFlight); + } if (inFlight >= _maxRequestsPerConnection) { @@ -668,13 +712,10 @@ private async Task CreateOrScheduleReconnectAsync(IReconnectionSchedule schedule /// Determines whether the Task should be marked as completed when there is a connection already opened. /// /// Determines whether this is a reconnection - /// - /// Determines whether the connection should be added to the pool. - /// /// Throws a SocketException when the connection could not be established with the host /// /// - private async Task CreateOpenConnection(bool satisfyWithAnOpenConnection, bool isReconnection, bool addToConnections = true) + private async Task CreateOpenConnection(bool satisfyWithAnOpenConnection, bool isReconnection) { var concurrentOpenTcs = Volatile.Read(ref _connectionOpenTcs); // Try to exit early (cheap) as there could be another thread creating / finishing creating @@ -733,7 +774,7 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne var shardCount = 0; if (shardingInfo != null) { - shardAwarePort = shardingInfo.ScyllaShardAwarePort; + shardAwarePort = _config.ProtocolOptions.SslOptions != null ? shardingInfo.ScyllaShardAwarePortSSL : shardingInfo.ScyllaShardAwarePort; shardCount = shardingInfo.ScyllaNrShards; // Find the shard without a connection // It's important to start counting from 1 here because we want @@ -741,7 +782,7 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne for (var i = 1; i <= shardCount; i++) { var _shardID = (lastAttemptedShard + i) % shardCount; - if (connectionsSnapshot.Length <= _shardID || connectionsSnapshot[_shardID] == null) + if (ConnectionForShard(connectionsSnapshot, _shardID) == null) { lastAttemptedShard = _shardID; shardID = _shardID; @@ -750,6 +791,10 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne } } c = await DoCreateAndOpen(isReconnection, shardID, shardAwarePort, shardCount).ConfigureAwait(false); + if (c != null && c.ShardID != -1) + { + lastAttemptedShard = c.ShardID; + } } catch (Exception ex) { @@ -765,12 +810,9 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne return await FinishOpen(tcs, false, HostConnectionPool.GetNotConnectedException()).ConfigureAwait(false); } - if (addToConnections) - { - var newLength = _connections.AddNew(c); - HostConnectionPool.Logger.Info("Connection to {0} opened successfully, pool #{1} length: {2}", - _host.Address, GetHashCode(), newLength); - } + var newLength = _connections.AddNew(c); + HostConnectionPool.Logger.Info("Connection to {0} opened successfully, pool #{1} length: {2}", + _host.Address, GetHashCode(), newLength); if (IsClosing) { @@ -903,31 +945,31 @@ public void MarkAsDownAndScheduleReconnection() /// public Task GetConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc) + IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey) { - return GetConnectionFromHostAsync(triedHosts, getKeyspaceFunc, true); + return GetConnectionFromHostAsync(triedHosts, getKeyspaceFunc, true, routingKey); } /// public Task GetExistingConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc) + IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey) { - return GetConnectionFromHostAsync(triedHosts, getKeyspaceFunc, false); + return GetConnectionFromHostAsync(triedHosts, getKeyspaceFunc, false, routingKey); } private async Task GetConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc, bool createIfNeeded) + IDictionary triedHosts, Func getKeyspaceFunc, bool createIfNeeded, RoutingKey routingKey) { IConnection c = null; try { if (createIfNeeded) { - c = await BorrowConnectionAsync().ConfigureAwait(false); + c = await BorrowConnectionAsync(routingKey).ConfigureAwait(false); } else { - c = BorrowExistingConnection(); + c = BorrowExistingConnection(routingKey); } } catch (UnsupportedProtocolVersionException ex) @@ -989,24 +1031,15 @@ private async Task GetConnectionFromHostAsync( /// public async Task Warmup() { - var length = _expectedConnectionLength; - // Open first connection try { - var c = await CreateOpenConnection(false, false, false).ConfigureAwait(false); + var c = await CreateOpenConnection(false, false).ConfigureAwait(false); var _shardingInfo = c.ShardingInfo(); if (_shardingInfo != null) { shardingInfo = _shardingInfo; - var nrShards = _shardingInfo.ScyllaNrShards; - if (nrShards > length) - { - // Create the rest of the connections - length = nrShards; - _expectedConnectionLength = nrShards; - } + _expectedConnectionLength = _shardingInfo.ScyllaNrShards; } - c.Dispose(); } catch { @@ -1014,7 +1047,9 @@ public async Task Warmup() throw; } - for (var i = 0; i < length; i++) + var length = _expectedConnectionLength; + + for (var i = 1; i < length; i++) { try { @@ -1022,14 +1057,7 @@ public async Task Warmup() } catch { - if (i > 0) - { - // There is an opened connection, don't mind - break; - } - - OnConnectionClosing(); - throw; + break; } } } diff --git a/src/Cassandra/Connections/HostConnectionPoolFactory.cs b/src/Cassandra/Connections/HostConnectionPoolFactory.cs index 54166f7aa..ae85d3bea 100644 --- a/src/Cassandra/Connections/HostConnectionPoolFactory.cs +++ b/src/Cassandra/Connections/HostConnectionPoolFactory.cs @@ -21,9 +21,9 @@ namespace Cassandra.Connections { internal class HostConnectionPoolFactory : IHostConnectionPoolFactory { - public IHostConnectionPool Create(Host host, Configuration config, ISerializerManager serializerManager, IObserverFactory observerFactory) + public IHostConnectionPool Create(Host host, Configuration config, ISerializerManager serializerManager, IObserverFactory observerFactory, TokenFactory tokenFactory) { - return new HostConnectionPool(host, config, serializerManager, observerFactory); + return new HostConnectionPool(host, config, serializerManager, observerFactory, tokenFactory); } } } diff --git a/src/Cassandra/Connections/IConnection.cs b/src/Cassandra/Connections/IConnection.cs index 92eee2806..698ed5996 100644 --- a/src/Cassandra/Connections/IConnection.cs +++ b/src/Cassandra/Connections/IConnection.cs @@ -153,6 +153,6 @@ internal interface IConnection : IDisposable /// ShardingInfo ShardingInfo(); - int ShardId { get; } + int ShardID { get; set; } } } diff --git a/src/Cassandra/Connections/IHostConnectionPool.cs b/src/Cassandra/Connections/IHostConnectionPool.cs index 513a9fc18..21f4c9a36 100644 --- a/src/Cassandra/Connections/IHostConnectionPool.cs +++ b/src/Cassandra/Connections/IHostConnectionPool.cs @@ -25,12 +25,12 @@ namespace Cassandra.Connections internal interface IHostConnectionPool : IDisposable { /// - /// Gets the total amount of open connections. + /// Gets the total amount of open connections. /// int OpenConnections { get; } /// - /// Gets the total of in-flight requests on all connections. + /// Gets the total of in-flight requests on all connections. /// int InFlight { get; } @@ -55,7 +55,7 @@ internal interface IHostConnectionPool : IDisposable /// /// /// - Task BorrowConnectionAsync(); + Task BorrowConnectionAsync(RoutingKey routingKey = null); /// /// Gets an open connection from the host pool. It does NOT create one if necessary (for that use . @@ -63,7 +63,7 @@ internal interface IHostConnectionPool : IDisposable /// /// /// Not connected. - IConnection BorrowExistingConnection(); + IConnection BorrowExistingConnection(RoutingKey routingKey); void SetDistance(HostDistance distance); @@ -92,9 +92,9 @@ internal interface IHostConnectionPool : IDisposable void MarkAsDownAndScheduleReconnection(); Task GetConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc); + IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey); Task GetExistingConnectionFromHostAsync( - IDictionary triedHosts, Func getKeyspaceFunc); + IDictionary triedHosts, Func getKeyspaceFunc, RoutingKey routingKey); } } diff --git a/src/Cassandra/Connections/IHostConnectionPoolFactory.cs b/src/Cassandra/Connections/IHostConnectionPoolFactory.cs index 9f0ac0089..e40607445 100644 --- a/src/Cassandra/Connections/IHostConnectionPoolFactory.cs +++ b/src/Cassandra/Connections/IHostConnectionPoolFactory.cs @@ -21,6 +21,6 @@ namespace Cassandra.Connections { internal interface IHostConnectionPoolFactory { - IHostConnectionPool Create(Host host, Configuration config, ISerializerManager serializer, IObserverFactory observerFactory); + IHostConnectionPool Create(Host host, Configuration config, ISerializerManager serializer, IObserverFactory observerFactory, TokenFactory tokenFactory); } } diff --git a/src/Cassandra/Connections/ShardingInfo.cs b/src/Cassandra/Connections/ShardingInfo.cs new file mode 100644 index 000000000..cd91beb20 --- /dev/null +++ b/src/Cassandra/Connections/ShardingInfo.cs @@ -0,0 +1,75 @@ +namespace Cassandra.Connections +{ + /// + /// Represents Scylla connection options as sent in SUPPORTED + /// frame. + /// + public class ShardingInfo + { + public int ScyllaShard { get; } + public int ScyllaNrShards { get; } + public string ScyllaPartitioner { get; } + public string ScyllaShardingAlgorithm { get; } + public long ScyllaShardingIgnoreMSB { get; } + public int ScyllaShardAwarePort { get; } + public int ScyllaShardAwarePortSSL { get; } + + private ShardingInfo(int scyllaShard, int scyllaNrShards, string scyllaPartitioner, + string scyllaShardingAlgorithm, long scyllaShardingIgnoreMSB, + int scyllaShardAwarePort, int scyllaShardAwarePortSSL) + { + ScyllaShard = scyllaShard; + ScyllaNrShards = scyllaNrShards; + ScyllaPartitioner = scyllaPartitioner; + ScyllaShardingAlgorithm = scyllaShardingAlgorithm; + ScyllaShardingIgnoreMSB = scyllaShardingIgnoreMSB; + ScyllaShardAwarePort = scyllaShardAwarePort; + ScyllaShardAwarePortSSL = scyllaShardAwarePortSSL; + } + + public static ShardingInfo Create(string scyllaShard, string scyllaNrShards, string scyllaPartitioner, + string scyllaShardingAlgorithm, string scyllaShardingIgnoreMSB, + string scyllaShardAwarePort, string scyllaShardAwarePortSSL) + { + return new ShardingInfo( + int.Parse(scyllaShard), + int.Parse(scyllaNrShards), + scyllaPartitioner, + scyllaShardingAlgorithm, + long.Parse(scyllaShardingIgnoreMSB), + int.Parse(scyllaShardAwarePort), + int.Parse(scyllaShardAwarePortSSL) + ); + } + + internal int ShardID(IToken t) + { + long token = long.Parse(t.ToString()); + token += long.MinValue; + token <<= (int)ScyllaShardingIgnoreMSB; + + ulong tokLo = (ulong)(token & 0xFFFFFFFFL); + ulong tokHi = (ulong)((token >> 32) & 0xFFFFFFFFL); + + ulong mul1 = tokLo * (ulong)ScyllaNrShards; + ulong mul2 = tokHi * (ulong)ScyllaNrShards; // logically shifted 32 bits + + ulong sum = (mul1 >> 32) + mul2; + + return (int)(sum >> 32); + } + + + public override string ToString() + { + return $"ShardingInfo: " + + $"ScyllaShard={ScyllaShard}, " + + $"ScyllaNrShards={ScyllaNrShards}, " + + $"ScyllaPartitioner={ScyllaPartitioner}, " + + $"ScyllaShardingAlgorithm={ScyllaShardingAlgorithm}, " + + $"ScyllaShardingIgnoreMSB={ScyllaShardingIgnoreMSB}, " + + $"ScyllaShardAwarePort={ScyllaShardAwarePort}, " + + $"ScyllaShardAwarePortSSL={ScyllaShardAwarePortSSL}"; + } + } +} diff --git a/src/Cassandra/Metadata.cs b/src/Cassandra/Metadata.cs index cc44745a1..90d3d4ed2 100644 --- a/src/Cassandra/Metadata.cs +++ b/src/Cassandra/Metadata.cs @@ -100,6 +100,15 @@ public void Dispose() ShutDown(); } + internal TokenFactory GetTokenFactory() + { + if (_tokenMap == null) + { + throw new DriverInternalError("Token map is not initialized"); + } + return _tokenMap.Factory; + } + internal KeyspaceMetadata GetKeyspaceFromCache(string keyspace) { _keyspaces.TryGetValue(keyspace, out var ks); diff --git a/src/Cassandra/Requests/IReprepareHandler.cs b/src/Cassandra/Requests/IReprepareHandler.cs index ad6f20f02..46e95f992 100644 --- a/src/Cassandra/Requests/IReprepareHandler.cs +++ b/src/Cassandra/Requests/IReprepareHandler.cs @@ -1,12 +1,12 @@ -// +// // Copyright (C) DataStax Inc. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/src/Cassandra/Requests/ReprepareHandler.cs b/src/Cassandra/Requests/ReprepareHandler.cs index 326634817..3e963d67d 100644 --- a/src/Cassandra/Requests/ReprepareHandler.cs +++ b/src/Cassandra/Requests/ReprepareHandler.cs @@ -1,12 +1,12 @@ -// +// // Copyright (C) DataStax Inc. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -83,7 +83,7 @@ private static async Task GetConnectionFromHostInternalAsync( { try { - return await pool.GetExistingConnectionFromHostAsync(triedHosts, () => ps.Keyspace).ConfigureAwait(false); + return await pool.GetExistingConnectionFromHostAsync(triedHosts, () => ps.Keyspace, ps.RoutingKey).ConfigureAwait(false); } catch (SocketException) { diff --git a/src/Cassandra/Requests/RequestHandler.cs b/src/Cassandra/Requests/RequestHandler.cs index be1044f5e..4407a0709 100644 --- a/src/Cassandra/Requests/RequestHandler.cs +++ b/src/Cassandra/Requests/RequestHandler.cs @@ -402,7 +402,7 @@ public async Task ValidateHostAndGetConnectionAsync(Host host, Dict /// public Task GetConnectionToValidHostAsync(ValidHost validHost, IDictionary triedHosts) { - return RequestHandler.GetConnectionFromHostAsync(validHost.Host, validHost.Distance, _session, triedHosts); + return RequestHandler.GetConnectionFromHostAsync(validHost.Host, validHost.Distance, _session, triedHosts, Statement != null ? Statement.RoutingKey : null); } /// @@ -413,21 +413,22 @@ public Task GetConnectionToValidHostAsync(ValidHost validHost, IDic /// . It is retrieved from the current . /// Session from where a connection will be obtained (or created). /// Hosts for which there were attempts to connect and send the request. + /// Routing key to use for the next host. /// When the keyspace is not valid internal static Task GetConnectionFromHostAsync( - Host host, HostDistance distance, IInternalSession session, IDictionary triedHosts) + Host host, HostDistance distance, IInternalSession session, IDictionary triedHosts, RoutingKey routingKey = null) { - return GetConnectionFromHostInternalAsync(host, distance, session, triedHosts, true); + return GetConnectionFromHostInternalAsync(host, distance, session, triedHosts, true, routingKey); } private static async Task GetConnectionFromHostInternalAsync( - Host host, HostDistance distance, IInternalSession session, IDictionary triedHosts, bool retry) + Host host, HostDistance distance, IInternalSession session, IDictionary triedHosts, bool retry, RoutingKey routingKey) { var hostPool = session.GetOrCreateConnectionPool(host, distance); try { - return await hostPool.GetConnectionFromHostAsync(triedHosts, () => session.Keyspace).ConfigureAwait(false); + return await hostPool.GetConnectionFromHostAsync(triedHosts, () => session.Keyspace, routingKey).ConfigureAwait(false); } catch (SocketException) { @@ -435,7 +436,7 @@ private static async Task GetConnectionFromHostInternalAsync( { // A socket exception on the current connection does not mean that all the pool is closed: // Retry on the same host - return await RequestHandler.GetConnectionFromHostInternalAsync(host, distance, session, triedHosts, false).ConfigureAwait(false); + return await RequestHandler.GetConnectionFromHostInternalAsync(host, distance, session, triedHosts, false, routingKey).ConfigureAwait(false); } throw; diff --git a/src/Cassandra/Session.cs b/src/Cassandra/Session.cs index be732eaca..57ac46428 100644 --- a/src/Cassandra/Session.cs +++ b/src/Cassandra/Session.cs @@ -359,7 +359,7 @@ IHostConnectionPool IInternalSession.GetOrCreateConnectionPool(Host host, HostDi var hostPool = _connectionPool.GetOrAdd(host.Address, address => { var newPool = Configuration.HostConnectionPoolFactory.Create( - host, Configuration, _serializerManager, _observerFactory); + host, Configuration, _serializerManager, _observerFactory, Cluster.Metadata.GetTokenFactory()); newPool.AllConnectionClosed += InternalRef.OnAllConnectionClosed; newPool.SetDistance(distance); _metricsManager.GetOrCreateNodeMetrics(host).InitializePoolGauges(newPool); diff --git a/src/Cassandra/ShardingInfo.cs b/src/Cassandra/ShardingInfo.cs deleted file mode 100644 index 19771a6f1..000000000 --- a/src/Cassandra/ShardingInfo.cs +++ /dev/null @@ -1,45 +0,0 @@ -namespace Cassandra -{ - /// - /// Represents Scylla connection options as sent in SUPPORTED - /// frame. - /// - public class ShardingInfo - { - public int ScyllaShard { get; } - public int ScyllaNrShards { get; } - public string ScyllaPartitioner { get; } - public string ScyllaShardingAlgorithm { get; } - public ulong ScyllaShardingIgnoreMSB { get; } - public int ScyllaShardAwarePort { get; } - public ulong ScyllaShardAwarePortSSL { get; } - - private ShardingInfo(int scyllaShard, int scyllaNrShards, string scyllaPartitioner, - string scyllaShardingAlgorithm, ulong scyllaShardingIgnoreMSB, - int scyllaShardAwarePort, ulong scyllaShardAwarePortSSL) - { - ScyllaShard = scyllaShard; - ScyllaNrShards = scyllaNrShards; - ScyllaPartitioner = scyllaPartitioner; - ScyllaShardingAlgorithm = scyllaShardingAlgorithm; - ScyllaShardingIgnoreMSB = scyllaShardingIgnoreMSB; - ScyllaShardAwarePort = scyllaShardAwarePort; - ScyllaShardAwarePortSSL = scyllaShardAwarePortSSL; - } - - public static ShardingInfo Create(string scyllaShard, string scyllaNrShards, string scyllaPartitioner, - string scyllaShardingAlgorithm, string scyllaShardingIgnoreMSB, - string scyllaShardAwarePort, string scyllaShardAwarePortSSL) - { - return new ShardingInfo( - int.Parse(scyllaShard), - int.Parse(scyllaNrShards), - scyllaPartitioner, - scyllaShardingAlgorithm, - ulong.Parse(scyllaShardingIgnoreMSB), - int.Parse(scyllaShardAwarePort), - ulong.Parse(scyllaShardAwarePortSSL) - ); - } - } -} \ No newline at end of file From 6cf096cfc100aaca6da19b0c541acfcf8969d1c6 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Fri, 11 Apr 2025 10:28:04 +0200 Subject: [PATCH 09/14] Add option to disable shard awareness --- .../Core/SessionTests.cs | 34 ++++++++++++------- .../Connections/HostConnectionPool.cs | 11 +++--- src/Cassandra/PoolingOptions.cs | 29 +++++++++++----- 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/src/Cassandra.IntegrationTests/Core/SessionTests.cs b/src/Cassandra.IntegrationTests/Core/SessionTests.cs index 98796236c..cf97f8097 100644 --- a/src/Cassandra.IntegrationTests/Core/SessionTests.cs +++ b/src/Cassandra.IntegrationTests/Core/SessionTests.cs @@ -170,14 +170,18 @@ public void Session_Keyspace_Create_Case_Sensitive() }); } - [Test] - public void Should_Create_The_Right_Amount_Of_Connections() + [TestCase(true)] + [TestCase(false)] + public void Should_Create_The_Right_Amount_Of_Connections(bool useShardAwareness) { + var poolingOptions1 = new PoolingOptions().SetCoreConnectionsPerHost(HostDistance.Local, 3); + if (!useShardAwareness) + { + poolingOptions1.DisableShardAwareness(); + } var localCluster1 = GetNewTemporaryCluster( builder => builder - .WithPoolingOptions( - new PoolingOptions() - .SetCoreConnectionsPerHost(HostDistance.Local, 3))); + .WithPoolingOptions(poolingOptions1)); var localSession1 = (IInternalSession)localCluster1.Connect(); var hosts1 = localCluster1.AllHosts().ToList(); @@ -191,12 +195,18 @@ public void Should_Create_The_Right_Amount_Of_Connections() Thread.Sleep(2000); var pool11 = localSession1.GetOrCreateConnectionPool(hosts1[0], HostDistance.Local); var pool12 = localSession1.GetOrCreateConnectionPool(hosts1[1], HostDistance.Local); - Assert.That(pool11.OpenConnections, Is.EqualTo(3)); - Assert.That(pool12.OpenConnections, Is.EqualTo(3)); + var expectedConnections1 = useShardAwareness ? 2 : 3; + Assert.That(pool11.OpenConnections, Is.EqualTo(expectedConnections1)); + Assert.That(pool12.OpenConnections, Is.EqualTo(expectedConnections1)); + var poolingOptions2 = new PoolingOptions().SetCoreConnectionsPerHost(HostDistance.Local, 1); + if (!useShardAwareness) + { + poolingOptions2.DisableShardAwareness(); + } using (var localCluster2 = ClusterBuilder() .AddContactPoint(TestCluster.InitialContactPoint) - .WithPoolingOptions(new PoolingOptions().SetCoreConnectionsPerHost(HostDistance.Local, 1)) + .WithPoolingOptions(poolingOptions2) .Build()) { var localSession2 = (IInternalSession)localCluster2.Connect(); @@ -211,9 +221,9 @@ public void Should_Create_The_Right_Amount_Of_Connections() Thread.Sleep(2000); var pool21 = localSession2.GetOrCreateConnectionPool(hosts2[0], HostDistance.Local); var pool22 = localSession2.GetOrCreateConnectionPool(hosts2[1], HostDistance.Local); - // Should be 2 due to number of shards - Assert.That(pool21.OpenConnections, Is.EqualTo(2)); - Assert.That(pool22.OpenConnections, Is.EqualTo(2)); + var expectedConnections2 = useShardAwareness ? 2 : 1; + Assert.That(pool21.OpenConnections, Is.EqualTo(expectedConnections2)); + Assert.That(pool22.OpenConnections, Is.EqualTo(expectedConnections2)); } } @@ -233,7 +243,7 @@ public async Task Session_With_Host_Changing_Distance() var builder = ClusterBuilder() .AddContactPoint(TestCluster.InitialContactPoint) .WithLoadBalancingPolicy(lbp) - .WithPoolingOptions(new PoolingOptions().SetCoreConnectionsPerHost(HostDistance.Local, 3)) + .WithPoolingOptions(new PoolingOptions().SetCoreConnectionsPerHost(HostDistance.Local, 3).DisableShardAwareness()) .WithReconnectionPolicy(new ConstantReconnectionPolicy(1000)); var counter = 0; using (var localCluster = builder.Build()) diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index 55a73f4a7..7508ee99a 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -1034,11 +1034,14 @@ public async Task Warmup() try { var c = await CreateOpenConnection(false, false).ConfigureAwait(false); - var _shardingInfo = c.ShardingInfo(); - if (_shardingInfo != null) + if (!_poolingOptions.GetDisableShardAwareness()) { - shardingInfo = _shardingInfo; - _expectedConnectionLength = _shardingInfo.ScyllaNrShards; + var _shardingInfo = c.ShardingInfo(); + if (_shardingInfo != null) + { + shardingInfo = _shardingInfo; + _expectedConnectionLength = _shardingInfo.ScyllaNrShards; + } } } catch diff --git a/src/Cassandra/PoolingOptions.cs b/src/Cassandra/PoolingOptions.cs index 7e5ab42a9..18d585109 100644 --- a/src/Cassandra/PoolingOptions.cs +++ b/src/Cassandra/PoolingOptions.cs @@ -21,12 +21,12 @@ namespace Cassandra /// /// Represents the options related to connection pooling. /// - /// For each host selected by the load balancing policy, the driver keeps a core amount of - /// connections open at all times + /// For each host selected by the load balancing policy, the driver keeps a core amount of + /// connections open at all times /// (). - /// If the use of those connections reaches a configurable threshold - /// (), - /// more connections are created up to the configurable maximum number of connections + /// If the use of those connections reaches a configurable threshold + /// (), + /// more connections are created up to the configurable maximum number of connections /// (). /// /// @@ -92,6 +92,8 @@ public class PoolingOptions private int _maxRequestsPerConnection = DefaultMaxRequestsPerConnection; private bool _warmup = true; + private bool _disableShardAwareness = false; + /// /// DEPRECATED: It will be removed in future versions. Use instead. /// @@ -135,7 +137,7 @@ public int GetMinSimultaneousRequestsPerConnectionTreshold(HostDistance distance /// the for which to configure this /// threshold. /// the value to set. - /// + /// /// this PoolingOptions. public PoolingOptions SetMinSimultaneousRequestsPerConnectionTreshold(HostDistance distance, int minSimultaneousRequests) { @@ -236,6 +238,11 @@ public bool GetWarmup() return _warmup; } + public bool GetDisableShardAwareness() + { + return _disableShardAwareness; + } + /// /// Sets the core number of connections per host. /// @@ -266,7 +273,7 @@ public PoolingOptions SetCoreConnectionsPerHost(HostDistance distance, int coreC /// /// the HostDistance for which to return this threshold. /// - /// + /// /// the maximum number of connections per host at distance /// distance. public int GetMaxConnectionPerHost(HostDistance distance) @@ -288,7 +295,7 @@ public int GetMaxConnectionPerHost(HostDistance distance) /// the HostDistance for which to set this threshold. /// /// the value to set - /// + /// /// this PoolingOptions. public PoolingOptions SetMaxConnectionsPerHost(HostDistance distance, int maxConnections) { @@ -375,6 +382,12 @@ public PoolingOptions SetWarmup(bool doWarmup) return this; } + public PoolingOptions DisableShardAwareness() + { + _disableShardAwareness = true; + return this; + } + /// /// Creates a new instance of using the default amount of connections /// and settings based on the protocol version. From 8d52ba80478bda51b05dcaa0aacef2cb7dc5aceb Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Wed, 23 Apr 2025 09:38:59 +0200 Subject: [PATCH 10/14] Store connections as a list with shardID as an index to improve access --- .../Connections/HostConnectionPool.cs | 77 +++++++++++++++---- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index 7508ee99a..7da0dc749 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -73,6 +73,9 @@ private static class PoolState private readonly ISerializerManager _serializerManager; private readonly IObserverFactory _observerFactory; private readonly CopyOnWriteList _connections = new CopyOnWriteList(); + private volatile int _connectionsPerShard; + private volatile CopyOnWriteList[] _connectionsForShard = new CopyOnWriteList[0]; + private volatile HostDistance _distance; private readonly HashedWheelTimer _timer; private readonly SemaphoreSlim _allConnectionClosedEventLock = new SemaphoreSlim(1, 1); private readonly Host _host; @@ -156,21 +159,9 @@ public IConnection BorrowExistingConnection(RoutingKey routingKey) return BorrowLeastBusyConnection(connections, routingKey); } - private IConnection ConnectionForShard(IConnection[] connections, int shardID) - { - for (int i = 0; i < connections.Length; i++) - { - if (connections[i] != null && connections[i].ShardID == shardID) - { - return connections[i]; - } - } - return null; - } - private IConnection BorrowLeastBusyConnection(IConnection[] connections, RoutingKey routingKey = null) { - int shardID = 0; + int shardID = -1; if (shardingInfo != null) { if (routingKey != null) @@ -184,7 +175,21 @@ private IConnection BorrowLeastBusyConnection(IConnection[] connections, Routing } } - IConnection c = ConnectionForShard(connections, shardID); + IConnection c = null; + if (shardID != -1) + { + var minInFlight = int.MaxValue; + var localInFlight = 0; + foreach (var connection in _connectionsForShard[shardID]) + { + localInFlight = connection.InFlight; + if (localInFlight < minInFlight) + { + minInFlight = localInFlight; + c = connection; + } + } + } var inFlight = 0; if (c != null) { @@ -309,6 +314,10 @@ public void Dispose() { c.Dispose(); } + for (int i = 0; i < _connectionsForShard.Length; i++) + { + _connectionsForShard[i].Clear(); + } _host.Up -= OnHostUp; _host.Down -= OnHostDown; _host.DistanceChanged -= OnDistanceChanged; @@ -443,7 +452,12 @@ internal void OnConnectionClosing(IConnection c = null) int currentLength; if (c != null) { + var shardID = c.ShardID; var removalInfo = _connections.RemoveAndCount(c); + if (shardID != -1 && _connectionsForShard.Length > shardID) + { + _connectionsForShard[shardID].Remove(c); + } currentLength = removalInfo.Item2; var removed = removalInfo.Item1; if (!removed) @@ -566,6 +580,10 @@ private void DrainConnectionsTimer(IConnection[] connections, Action afterDrainH GetHashCode(), _host.Address, connections.Length, drained ? "successful" : "unsuccessful"); foreach (var c in connections) { + if (c.ShardID != -1 && _connectionsForShard.Length > c.ShardID) + { + _connectionsForShard[c.ShardID].Remove(c); + } c.Dispose(); } afterDrainHandler?.Invoke(); @@ -782,7 +800,7 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne for (var i = 1; i <= shardCount; i++) { var _shardID = (lastAttemptedShard + i) % shardCount; - if (ConnectionForShard(connectionsSnapshot, _shardID) == null) + if (_connectionsForShard.Length > _shardID && _connectionsForShard[_shardID].Count < _connectionsPerShard) { lastAttemptedShard = _shardID; shardID = _shardID; @@ -811,6 +829,10 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne } var newLength = _connections.AddNew(c); + if (c.ShardID != -1 && _connectionsForShard.Length > c.ShardID) + { + _connectionsForShard[c.ShardID].Add(c); + } HostConnectionPool.Logger.Info("Connection to {0} opened successfully, pool #{1} length: {2}", _host.Address, GetHashCode(), newLength); @@ -821,6 +843,10 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne HostConnectionPool.Logger.Info("Connection to {0} opened successfully and added to the pool #{1} but the pool was being closed", _host.Address, GetHashCode()); _connections.Remove(c); + if (c.ShardID != -1 && _connectionsForShard.Length > c.ShardID) + { + _connectionsForShard[c.ShardID].Remove(c); + } c.Dispose(); return await FinishOpen(tcs, false, HostConnectionPool.GetNotConnectedException()).ConfigureAwait(false); } @@ -831,6 +857,10 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne HostConnectionPool.Logger.Info("Connection to {0} opened successfully and added to the pool #{1} but it got closed", _host.Address, GetHashCode()); _connections.Remove(c); + if (c.ShardID != -1 && _connectionsForShard.Length > c.ShardID) + { + _connectionsForShard[c.ShardID].Remove(c); + } c.Dispose(); return await FinishOpen(tcs, true, HostConnectionPool.GetNotConnectedException()).ConfigureAwait(false); } @@ -927,6 +957,7 @@ private IConnection[] GetExistingConnections() public void SetDistance(HostDistance distance) { + _distance = distance; _expectedConnectionLength = _poolingOptions.GetCoreConnectionsPerHost(distance); _maxInflightThresholdToConsiderResizing = _poolingOptions.GetMaxSimultaneousRequestsPerConnectionTreshold(distance); _maxConnectionLength = _poolingOptions.GetMaxConnectionPerHost(distance); @@ -1040,8 +1071,22 @@ public async Task Warmup() if (_shardingInfo != null) { shardingInfo = _shardingInfo; - _expectedConnectionLength = _shardingInfo.ScyllaNrShards; } + + var coreSize = _poolingOptions.GetCoreConnectionsPerHost(_distance); + var shardsCount = shardingInfo == null ? 1 : shardingInfo.ScyllaNrShards; + + _connectionsPerShard = coreSize / shardsCount + (coreSize % shardsCount > 0 ? 1 : 0); + _connectionsForShard = new CopyOnWriteList[shardsCount]; + for (int i = 0; i < shardsCount; i++) + { + _connectionsForShard[i] = new CopyOnWriteList(); + } + if (!_connectionsForShard[c.ShardID].Contains(c)) + { + _connectionsForShard[c.ShardID].Add(c); + } + _expectedConnectionLength = shardsCount * _connectionsPerShard; } } catch From fa2bffeae318922e86e5bf3277756f6db391143d Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Thu, 24 Apr 2025 11:58:20 +0200 Subject: [PATCH 11/14] fix test --- src/Cassandra.IntegrationTests/Core/SessionTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Cassandra.IntegrationTests/Core/SessionTests.cs b/src/Cassandra.IntegrationTests/Core/SessionTests.cs index cf97f8097..85e8b3afe 100644 --- a/src/Cassandra.IntegrationTests/Core/SessionTests.cs +++ b/src/Cassandra.IntegrationTests/Core/SessionTests.cs @@ -195,7 +195,7 @@ public void Should_Create_The_Right_Amount_Of_Connections(bool useShardAwareness Thread.Sleep(2000); var pool11 = localSession1.GetOrCreateConnectionPool(hosts1[0], HostDistance.Local); var pool12 = localSession1.GetOrCreateConnectionPool(hosts1[1], HostDistance.Local); - var expectedConnections1 = useShardAwareness ? 2 : 3; + var expectedConnections1 = useShardAwareness ? 4 : 3; Assert.That(pool11.OpenConnections, Is.EqualTo(expectedConnections1)); Assert.That(pool12.OpenConnections, Is.EqualTo(expectedConnections1)); From c49f1710919090457629e3af9da1efc045628269 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Fri, 25 Apr 2025 12:13:10 +0200 Subject: [PATCH 12/14] Add option to update sharding info every time connection is opened --- .../Connections/HostConnectionPool.cs | 48 ++++++++++--------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index 7da0dc749..71fa2afae 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -693,6 +693,7 @@ private async Task CreateOrScheduleReconnectAsync(IReconnectionSchedule schedule try { var t = await CreateOpenConnection(false, schedule != null).ConfigureAwait(false); + UpdateShardingInfo(t); StartCreatingConnection(null); _host.BringUpIfDown(); } @@ -923,6 +924,7 @@ public async Task EnsureCreate() // It's the first time accessing or it has been recently set as UP // CreateOpenConnection() supports concurrent calls c = await CreateOpenConnection(true, false).ConfigureAwait(false); + UpdateShardingInfo(c); } catch (Exception) { @@ -1055,6 +1057,28 @@ private async Task GetConnectionFromHostAsync( return c; } + private void UpdateShardingInfo(IConnection c) + { + if (!_poolingOptions.GetDisableShardAwareness() && shardingInfo == null && c.ShardingInfo() != null) + { + shardingInfo = c.ShardingInfo(); + var coreSize = _poolingOptions.GetCoreConnectionsPerHost(_distance); + var shardsCount = shardingInfo == null ? 1 : shardingInfo.ScyllaNrShards; + + _connectionsPerShard = coreSize / shardsCount + (coreSize % shardsCount > 0 ? 1 : 0); + _connectionsForShard = new CopyOnWriteList[shardsCount]; + for (int i = 0; i < shardsCount; i++) + { + _connectionsForShard[i] = new CopyOnWriteList(); + } + if (!_connectionsForShard[c.ShardID].Contains(c)) + { + _connectionsForShard[c.ShardID].Add(c); + } + _expectedConnectionLength = shardsCount * _connectionsPerShard; + } + } + /// /// Creates the required connections to the hosts and awaits for all connections to be open. /// The task is completed when at least 1 of the connections is opened successfully. @@ -1065,29 +1089,7 @@ public async Task Warmup() try { var c = await CreateOpenConnection(false, false).ConfigureAwait(false); - if (!_poolingOptions.GetDisableShardAwareness()) - { - var _shardingInfo = c.ShardingInfo(); - if (_shardingInfo != null) - { - shardingInfo = _shardingInfo; - } - - var coreSize = _poolingOptions.GetCoreConnectionsPerHost(_distance); - var shardsCount = shardingInfo == null ? 1 : shardingInfo.ScyllaNrShards; - - _connectionsPerShard = coreSize / shardsCount + (coreSize % shardsCount > 0 ? 1 : 0); - _connectionsForShard = new CopyOnWriteList[shardsCount]; - for (int i = 0; i < shardsCount; i++) - { - _connectionsForShard[i] = new CopyOnWriteList(); - } - if (!_connectionsForShard[c.ShardID].Contains(c)) - { - _connectionsForShard[c.ShardID].Add(c); - } - _expectedConnectionLength = shardsCount * _connectionsPerShard; - } + UpdateShardingInfo(c); } catch { From 8f862ad45268bb2b88a30a5fa3d7a9de4a197067 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Mon, 28 Apr 2025 07:24:15 -0400 Subject: [PATCH 13/14] Introduce new CopyOnWriteListSharded to keep track of connections --- src/Cassandra.Tests/CollectionTests.cs | 271 ++++++++++++++++-- .../HostConnectionPoolTests.cs | 13 +- src/Cassandra.sln.DotSettings | 1 + .../Collections/CopyOnWriteListSharded.cs | 209 ++++++++++++++ .../Connections/HostConnectionPool.cs | 57 +--- src/Cassandra/Connections/IConnection.cs | 5 +- 6 files changed, 481 insertions(+), 75 deletions(-) create mode 100644 src/Cassandra/Collections/CopyOnWriteListSharded.cs diff --git a/src/Cassandra.Tests/CollectionTests.cs b/src/Cassandra.Tests/CollectionTests.cs index 5a63c18e1..c8dbe0a3b 100644 --- a/src/Cassandra.Tests/CollectionTests.cs +++ b/src/Cassandra.Tests/CollectionTests.cs @@ -73,17 +73,16 @@ public void CopyOnWriteList_Should_Allow_Parallel_Calls_To_Add() for (var i = 0; i < 100; i++) { var item = i; - actions.Add(() => - { - list.Add(item); - }); + actions.Add(() => { list.Add(item); }); } + TestHelper.ParallelInvoke(actions); Assert.AreEqual(100, list.Count); for (var i = 0; i < 100; i++) { Assert.True(list.Contains(i)); } + var counter = 0; CollectionAssert.AreEquivalent(Enumerable.Repeat(0, 100).Select(_ => counter++), list); } @@ -97,15 +96,14 @@ public void CopyOnWriteList_Should_Allow_Parallel_Calls_To_Remove() { list.Add(i); } + Assert.AreEqual(100, list.Count); for (var i = 0; i < 100; i++) { var item = i; - actions.Add(() => - { - list.Remove(item); - }); + actions.Add(() => { list.Remove(item); }); } + TestHelper.ParallelInvoke(actions); Assert.AreEqual(0, list.Count); } @@ -126,10 +124,10 @@ public void CopyOnWriteDictionary_Should_Add_And_Remove() { var map = new CopyOnWriteDictionary { - {"one", 1}, - {"two", 2}, - {"three", 3}, - {"four", 4} + { "one", 1 }, + { "two", 2 }, + { "three", 3 }, + { "four", 4 } }; Assert.AreEqual(4, map.Count); CollectionAssert.AreEquivalent(new[] { "one", "two", "three", "four" }, map.Keys); @@ -154,17 +152,16 @@ public void CopyOnWriteDictionary_Should_Allow_Parallel_Calls_To_Add() for (var i = 0; i < 100; i++) { var item = i; - actions.Add(() => - { - map.Add(item, item * 1000); - }); + actions.Add(() => { map.Add(item, item * 1000); }); } + TestHelper.ParallelInvoke(actions); Assert.AreEqual(100, map.Count); for (var i = 0; i < 100; i++) { Assert.AreEqual(i * 1000, map[i]); } + var counter = 0; CollectionAssert.AreEquivalent(Enumerable.Repeat(0, 100).Select(_ => counter++), map.Keys); } @@ -178,16 +175,15 @@ public void CopyOnWriteDictionary_Should_Allow_Parallel_Calls_To_Remove() { map.Add(i, i * 2000); } + Assert.AreEqual(100, map.Count); //remove everything except 0 and 1 for (var i = 2; i < 100; i++) { var item = i; - actions.Add(() => - { - map.Remove(item); - }); + actions.Add(() => { map.Remove(item); }); } + TestHelper.ParallelInvoke(actions); Assert.AreEqual(2, map.Count); Assert.AreEqual(0, map[0]); @@ -211,5 +207,238 @@ public void CopyOnWriteDictionary_GetOrAdd_Should_Return_The_Current_Value() Assert.AreEqual(2, map.Count); Assert.AreEqual(10, map["key2"]); } + + internal class ShardableItem : IShardable + { + public int ShardID { get; } + + public string Value { get; } + + public ShardableItem(int shardId, string value) + { + ShardID = shardId; + Value = value; + } + + public override bool Equals(object obj) + { + if (obj is ShardableItem other) + { + return ShardID == other.ShardID && Value == other.Value; + } + + return false; + } + + public override int GetHashCode() + { + return (ShardID, Value).GetHashCode(); + } + } + + [TestFixture] + public class ShardedListTests + { + [Test] + public void EmptyConstructor_ShouldInitializeEmpty() + { + var list = new ShardedList(); + + Assert.AreEqual(0, list.Count); + Assert.AreEqual(0, list.Length); + Assert.IsTrue(list.IsReadOnly); + Assert.IsEmpty(list.GetAllItems()); + Assert.IsEmpty(list.GetPerShardSnapshot()); + } + + [Test] + public void Constructor_WithArray_ShouldCopyItems() + { + var items = new[] + { + new ShardableItem(1, "A"), + new ShardableItem(2, "B"), + new ShardableItem(1, "C") + }; + + var list = new ShardedList(items); + + Assert.AreEqual(3, list.Count); + Assert.That(list.GetAllItems(), Is.EquivalentTo(items)); + } + + [Test] + public void Indexer_ShouldReturnCorrectItem() + { + var items = new[] + { + new ShardableItem(0, "X"), + new ShardableItem(1, "Y") + }; + + var list = new ShardedList(items); + + Assert.AreEqual("X", list[0].Value); + Assert.AreEqual("Y", list[1].Value); + } + + [Test] + public void GetItemsForShard_ShouldReturnCorrectShardItems() + { + var items = new[] + { + new ShardableItem(0, "First"), + new ShardableItem(1, "Second"), + new ShardableItem(0, "Third") + }; + + var list = new ShardedList(items); + + var shard0 = list.GetItemsForShard(0); + var shard1 = list.GetItemsForShard(1); + var shard2 = list.GetItemsForShard(2); // Should be empty + + Assert.That(shard0.Select(x => x.Value), Is.EquivalentTo(new[] { "First", "Third" })); + Assert.That(shard1.Select(x => x.Value), Is.EquivalentTo(new[] { "Second" })); + Assert.IsEmpty(shard2); + } + + [Test] + public void GetEnumerator_ShouldEnumerateItems() + { + var items = new[] + { + new ShardableItem(1, "One"), + new ShardableItem(2, "Two") + }; + + var list = new ShardedList(items); + + CollectionAssert.AreEqual(items, list.ToList()); + } + } + + [TestFixture] + public class CopyOnWriteShardedListTests + { + [Test] + public void Add_ShouldAddItem() + { + var list = new CopyOnWriteShardedList(); + + list.Add(new ShardableItem(0, "Alpha")); + + Assert.AreEqual(1, list.Count); + Assert.AreEqual("Alpha", list.GetSnapshot()[0].Value); + } + + [Test] + public void AddRange_ShouldAddMultipleItems() + { + var list = new CopyOnWriteShardedList(); + + list.AddRange(new[] + { + new ShardableItem(1, "Beta"), + new ShardableItem(2, "Gamma") + }); + + Assert.AreEqual(2, list.Count); + Assert.That(list.GetSnapshot().Select(x => x.Value), Is.EquivalentTo(new[] { "Beta", "Gamma" })); + } + + [Test] + public void Remove_ShouldRemoveItem() + { + var list = new CopyOnWriteShardedList(); + + var item = new ShardableItem(0, "Delta"); + list.Add(item); + var removed = list.Remove(item); + + Assert.IsTrue(removed); + Assert.AreEqual(0, list.Count); + } + + [Test] + public void Remove_ShouldReturnFalse_WhenItemNotFound() + { + var list = new CopyOnWriteShardedList(); + + var removed = list.Remove(new ShardableItem(1, "Zeta")); + + Assert.IsFalse(removed); + } + + [Test] + public void Clear_ShouldEmptyTheList() + { + var list = new CopyOnWriteShardedList(); + + list.Add(new ShardableItem(0, "Eta")); + list.Add(new ShardableItem(1, "Theta")); + + list.Clear(); + + Assert.AreEqual(0, list.Count); + } + + [Test] + public void Contains_ShouldFindItem() + { + var list = new CopyOnWriteShardedList(); + var item = new ShardableItem(0, "Iota"); + + list.Add(item); + + Assert.IsTrue(list.Contains(item)); + } + + [Test] + public void Contains_ShouldReturnFalse_WhenNotPresent() + { + var list = new CopyOnWriteShardedList(); + + Assert.IsFalse(list.Contains(new ShardableItem(1, "Kappa"))); + } + + [Test] + public void GetItemsForShard_ShouldReturnShardItems() + { + var list = new CopyOnWriteShardedList(); + + var item1 = new ShardableItem(2, "Lambda"); + var item2 = new ShardableItem(2, "Mu"); + var item3 = new ShardableItem(3, "Nu"); + + list.Add(item1); + list.Add(item2); + list.Add(item3); + + var shardItems = list.GetItemsForShard(2); + + Assert.That(shardItems.Select(x => x.Value), Is.EquivalentTo(new[] { "Lambda", "Mu" })); + + var shard3 = list.GetItemsForShard(3); + Assert.That(shard3.Select(x => x.Value), Is.EquivalentTo(new[] { "Nu" })); + + var nonExistentShard = list.GetItemsForShard(5); + Assert.IsEmpty(nonExistentShard); + } + + [Test] + public void CopyTo_ShouldCopyElements() + { + var list = new CopyOnWriteShardedList(); + + list.Add(new ShardableItem(0, "Xi")); + list.Add(new ShardableItem(0, "Omicron")); + + var array = new ShardableItem[2]; + list.CopyTo(array, 0); + + Assert.That(array.Select(x => x.Value), Is.EquivalentTo(new[] { "Xi", "Omicron" })); + } + } } -} +} \ No newline at end of file diff --git a/src/Cassandra.Tests/HostConnectionPoolTests.cs b/src/Cassandra.Tests/HostConnectionPoolTests.cs index 60fa94aa2..e50fad11c 100644 --- a/src/Cassandra.Tests/HostConnectionPoolTests.cs +++ b/src/Cassandra.Tests/HostConnectionPoolTests.cs @@ -21,6 +21,7 @@ using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; +using Cassandra.Collections; using Cassandra.Connections; using Cassandra.Metrics; using Cassandra.Metrics.Internal; @@ -172,7 +173,7 @@ public void EnsureCreate_Serial_Calls_Should_Yield_First() return TaskHelper.ToTask(c); }); var pool = _mock.Object; - var creationTasks = new Task[4]; + var creationTasks = new Task>[4]; creationTasks[0] = pool.EnsureCreate(); creationTasks[1] = pool.EnsureCreate(); creationTasks[2] = pool.EnsureCreate(); @@ -193,7 +194,7 @@ public void EnsureCreate_Parallel_Calls_Should_Yield_First() var lastByte = 0; _mock.Setup(p => p.DoCreateAndOpen(It.IsAny(), -1, 0, 0)).Returns(() => TestHelper.DelayedTask(CreateConnection((byte)++lastByte), 100 + (lastByte > 1 ? 10000 : 0))); var pool = _mock.Object; - var creationTasks = new Task[10]; + var creationTasks = new Task>[10]; var counter = -1; var initialCreate = pool.EnsureCreate(); TestHelper.ParallelInvoke(() => @@ -365,14 +366,14 @@ public async Task EnsureCreate_Can_Handle_Multiple_Concurrent_Calls() [Test] public void MinInFlight_Returns_The_Min_Inflight_From_Two_Connections() { - var connections = new[] + var connections = new ShardedList(new[] { GetConnectionMock(0), GetConnectionMock(1), GetConnectionMock(1), GetConnectionMock(10), GetConnectionMock(1) - }; + }); var index = 0; var c = HostConnectionPool.MinInFlight(connections, ref index, 100, out int inFlight); Assert.AreEqual(index, 1); @@ -405,14 +406,14 @@ public void MinInFlight_Returns_The_Min_Inflight_From_Two_Connections() [Test] public void MinInFlight_Goes_Through_All_The_Connections_When_Over_Threshold() { - var connections = new[] + var connections = new ShardedList(new[] { GetConnectionMock(10), GetConnectionMock(1), GetConnectionMock(201), GetConnectionMock(200), GetConnectionMock(210) - }; + }); var index = 0; var c = HostConnectionPool.MinInFlight(connections, ref index, 100, out int inFlight); diff --git a/src/Cassandra.sln.DotSettings b/src/Cassandra.sln.DotSettings index c3b7f0f0d..3085868b8 100644 --- a/src/Cassandra.sln.DotSettings +++ b/src/Cassandra.sln.DotSettings @@ -403,6 +403,7 @@ II.2.12 <HandlesEvent /> True True True + True True True True diff --git a/src/Cassandra/Collections/CopyOnWriteListSharded.cs b/src/Cassandra/Collections/CopyOnWriteListSharded.cs new file mode 100644 index 000000000..6f3aac59c --- /dev/null +++ b/src/Cassandra/Collections/CopyOnWriteListSharded.cs @@ -0,0 +1,209 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; + +namespace Cassandra.Collections +{ + internal interface IShardable + { + int ShardID { get; } + } + + internal class ShardedList : IEnumerable where T : IShardable + { + private static readonly T[] EmptyArray = new T[0]; + + private readonly T[] _array; + private readonly T[][] _arrayPerShard; + + internal ShardedList() + { + _array = EmptyArray; + _arrayPerShard = new T[0][]; + } + + public ShardedList(T[] array) + { + if (array == null || array.Length == 0) + { + return; + } + + var maxShardId = array.Select(item => item.ShardID).Concat(new[] { -1 }).Max(); + if (maxShardId < 0) + { + _array = array.Clone() as T[]; + _arrayPerShard = new T[0][]; + return; + } + _arrayPerShard = new T[maxShardId + 1][]; + _array = array.Clone() as T[]; + for (var i = 0; i <= maxShardId; i++) + { + _arrayPerShard[i] = EmptyArray; + } + + foreach (var item in array) + { + var shardId = item.ShardID; + if (shardId < 0) + { + continue; + } + var shardArray = _arrayPerShard[shardId]; + var newShardArray = new T[shardArray.Length + 1]; + shardArray.CopyTo(newShardArray, 0); + newShardArray[shardArray.Length] = item; + _arrayPerShard[shardId] = newShardArray; + } + } + + public T this[int index] => _array[index]; + public IEnumerator GetEnumerator() => ((IEnumerable)_array).GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public T[] GetAllItems() => _array; + public T[] GetItemsForShard(int shardId) + { + if (shardId < 0 || shardId >= _arrayPerShard.Length) + { + return EmptyArray; + } + return _arrayPerShard[shardId]; + } + + public T[][] GetPerShardSnapshot() => _arrayPerShard; + + public int Length => _array.Length; + public int Count => _array.Length; + public bool IsReadOnly => true; + } + + internal class CopyOnWriteShardedList : ICollection where T : IShardable + { + private static readonly T[] EmptyArray = new T[0]; + private volatile ShardedList _shardedList; + private readonly object _writeLock = new object(); + + internal CopyOnWriteShardedList() + { + _shardedList = new ShardedList(); + } + + public int Count => _shardedList.Count; + public int Length => _shardedList.Length; + public bool IsReadOnly => false; + + public IEnumerator GetEnumerator() => _shardedList.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public void Add(T item) + { + AddNew(item); + } + + public int AddNew(T item) + { + lock (_writeLock) + { + var current = _shardedList; + var currentArray = current.GetAllItems(); + + var newArray = new T[currentArray.Length + 1]; + currentArray.CopyTo(newArray, 0); + newArray[currentArray.Length] = item; + + _shardedList = new ShardedList(newArray); + return newArray.Length; + } + } + + public void AddRange(T[] items) + { + if (items == null || items.Length == 0) + return; + + lock (_writeLock) + { + var current = _shardedList; + var currentArray = current.GetAllItems(); + + var newArray = new T[currentArray.Length + items.Length]; + currentArray.CopyTo(newArray, 0); + Array.Copy(items, 0, newArray, currentArray.Length, items.Length); + + _shardedList = new ShardedList(newArray); + } + } + + public void Clear() + { + ClearAndGet(); + } + + public ShardedList ClearAndGet() + { + lock (_writeLock) + { + var items = _shardedList; + _shardedList = new ShardedList(); + return items; + } + } + + public bool Contains(T item) + { + var currentArray = _shardedList.GetAllItems(); + return Array.IndexOf(currentArray, item) >= 0; + } + + public void CopyTo(T[] array, int arrayIndex) + { + _shardedList.GetAllItems().CopyTo(array, arrayIndex); + } + + public bool Remove(T item) + { + return RemoveAndCount(item).Item1; + } + + public Tuple RemoveAndCount(T item) + { + lock (_writeLock) + { + var current = _shardedList; + var currentArray = current.GetAllItems(); + var idx = Array.IndexOf(currentArray, item); + if (idx < 0) + { + return Tuple.Create(false, currentArray.Length); + } + + if (currentArray.Length == 1 && idx == 0) + { + _shardedList = new ShardedList(); + return Tuple.Create(true, 0); + } + + var newArray = new T[currentArray.Length - 1]; + if (idx > 0) + Array.Copy(currentArray, 0, newArray, 0, idx); + if (idx < currentArray.Length - 1) + Array.Copy(currentArray, idx + 1, newArray, idx, currentArray.Length - idx - 1); + + _shardedList = new ShardedList(newArray); + return Tuple.Create(true, newArray.Length); + } + } + + public ShardedList GetSnapshot() + { + return _shardedList; + } + + public T[] GetItemsForShard(int shardId) + { + return _shardedList.GetItemsForShard(shardId); + } + } +} diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index 71fa2afae..c8e053c11 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -72,9 +72,8 @@ private static class PoolState private readonly Configuration _config; private readonly ISerializerManager _serializerManager; private readonly IObserverFactory _observerFactory; - private readonly CopyOnWriteList _connections = new CopyOnWriteList(); + private readonly CopyOnWriteShardedList _connections = new CopyOnWriteShardedList(); private volatile int _connectionsPerShard; - private volatile CopyOnWriteList[] _connectionsForShard = new CopyOnWriteList[0]; private volatile HostDistance _distance; private readonly HashedWheelTimer _timer; private readonly SemaphoreSlim _allConnectionClosedEventLock = new SemaphoreSlim(1, 1); @@ -110,7 +109,7 @@ private static class PoolState private bool IsClosing => Volatile.Read(ref _state) != PoolState.Init; /// - public IConnection[] ConnectionsSnapshot => _connections.GetSnapshot(); + public IConnection[] ConnectionsSnapshot => _connections.GetSnapshot().GetAllItems(); public ShardingInfo shardingInfo { get; private set; } @@ -159,7 +158,7 @@ public IConnection BorrowExistingConnection(RoutingKey routingKey) return BorrowLeastBusyConnection(connections, routingKey); } - private IConnection BorrowLeastBusyConnection(IConnection[] connections, RoutingKey routingKey = null) + private IConnection BorrowLeastBusyConnection(ShardedList connections, RoutingKey routingKey = null) { int shardID = -1; if (shardingInfo != null) @@ -180,7 +179,7 @@ private IConnection BorrowLeastBusyConnection(IConnection[] connections, Routing { var minInFlight = int.MaxValue; var localInFlight = 0; - foreach (var connection in _connectionsForShard[shardID]) + foreach (var connection in _connections.GetItemsForShard(shardID)) { localInFlight = connection.InFlight; if (localInFlight < minInFlight) @@ -314,10 +313,6 @@ public void Dispose() { c.Dispose(); } - for (int i = 0; i < _connectionsForShard.Length; i++) - { - _connectionsForShard[i].Clear(); - } _host.Up -= OnHostUp; _host.Down -= OnHostDown; _host.DistanceChanged -= OnDistanceChanged; @@ -401,7 +396,7 @@ public void OnHostRemoved() /// /// Out parameter containing the amount of in-flight requests of the selected connection. /// - public static IConnection MinInFlight(IConnection[] connections, ref int connectionIndex, int inFlightThreshold, + public static IConnection MinInFlight(ShardedList connections, ref int connectionIndex, int inFlightThreshold, out int inFlight) { if (connections.Length == 1) @@ -452,12 +447,7 @@ internal void OnConnectionClosing(IConnection c = null) int currentLength; if (c != null) { - var shardID = c.ShardID; var removalInfo = _connections.RemoveAndCount(c); - if (shardID != -1 && _connectionsForShard.Length > shardID) - { - _connectionsForShard[shardID].Remove(c); - } currentLength = removalInfo.Item2; var removed = removalInfo.Item1; if (!removed) @@ -562,7 +552,7 @@ private void DrainConnections(Action afterDrainHandler) DrainConnectionsTimer(connections, afterDrainHandler, delay / 1000); } - private void DrainConnectionsTimer(IConnection[] connections, Action afterDrainHandler, int steps) + private void DrainConnectionsTimer(ShardedList connections, Action afterDrainHandler, int steps) { _timer.NewTimeout(_ => { @@ -580,10 +570,6 @@ private void DrainConnectionsTimer(IConnection[] connections, Action afterDrainH GetHashCode(), _host.Address, connections.Length, drained ? "successful" : "unsuccessful"); foreach (var c in connections) { - if (c.ShardID != -1 && _connectionsForShard.Length > c.ShardID) - { - _connectionsForShard[c.ShardID].Remove(c); - } c.Dispose(); } afterDrainHandler?.Invoke(); @@ -801,7 +787,7 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne for (var i = 1; i <= shardCount; i++) { var _shardID = (lastAttemptedShard + i) % shardCount; - if (_connectionsForShard.Length > _shardID && _connectionsForShard[_shardID].Count < _connectionsPerShard) + if (_connections.Length > _shardID && _connections.GetItemsForShard(_shardID).Length < _connectionsPerShard) { lastAttemptedShard = _shardID; shardID = _shardID; @@ -830,10 +816,6 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne } var newLength = _connections.AddNew(c); - if (c.ShardID != -1 && _connectionsForShard.Length > c.ShardID) - { - _connectionsForShard[c.ShardID].Add(c); - } HostConnectionPool.Logger.Info("Connection to {0} opened successfully, pool #{1} length: {2}", _host.Address, GetHashCode(), newLength); @@ -844,10 +826,6 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne HostConnectionPool.Logger.Info("Connection to {0} opened successfully and added to the pool #{1} but the pool was being closed", _host.Address, GetHashCode()); _connections.Remove(c); - if (c.ShardID != -1 && _connectionsForShard.Length > c.ShardID) - { - _connectionsForShard[c.ShardID].Remove(c); - } c.Dispose(); return await FinishOpen(tcs, false, HostConnectionPool.GetNotConnectedException()).ConfigureAwait(false); } @@ -858,10 +836,6 @@ private async Task CreateOpenConnection(bool satisfyWithAnOpenConne HostConnectionPool.Logger.Info("Connection to {0} opened successfully and added to the pool #{1} but it got closed", _host.Address, GetHashCode()); _connections.Remove(c); - if (c.ShardID != -1 && _connectionsForShard.Length > c.ShardID) - { - _connectionsForShard[c.ShardID].Remove(c); - } c.Dispose(); return await FinishOpen(tcs, true, HostConnectionPool.GetNotConnectedException()).ConfigureAwait(false); } @@ -897,7 +871,7 @@ private static SocketException GetNotConnectedException() /// /// /// - public async Task EnsureCreate() + public async Task> EnsureCreate() { var connections = GetExistingConnections(); if (connections.Length > 0) @@ -932,7 +906,7 @@ public async Task EnsureCreate() throw; } StartCreatingConnection(null); - return new[] { c }; + return new ShardedList(new[] { c }); } /// @@ -940,10 +914,10 @@ public async Task EnsureCreate() /// If it's empty then it validates whether the pool is shutting down or the is down (in which case an exception is thrown). /// /// Not connected. - private IConnection[] GetExistingConnections() + private ShardedList GetExistingConnections() { var connections = _connections.GetSnapshot(); - if (connections.Length > 0) + if (connections.Count > 0) { return connections; } @@ -1066,15 +1040,6 @@ private void UpdateShardingInfo(IConnection c) var shardsCount = shardingInfo == null ? 1 : shardingInfo.ScyllaNrShards; _connectionsPerShard = coreSize / shardsCount + (coreSize % shardsCount > 0 ? 1 : 0); - _connectionsForShard = new CopyOnWriteList[shardsCount]; - for (int i = 0; i < shardsCount; i++) - { - _connectionsForShard[i] = new CopyOnWriteList(); - } - if (!_connectionsForShard[c.ShardID].Contains(c)) - { - _connectionsForShard[c.ShardID].Add(c); - } _expectedConnectionLength = shardsCount * _connectionsPerShard; } } diff --git a/src/Cassandra/Connections/IConnection.cs b/src/Cassandra/Connections/IConnection.cs index 698ed5996..97f679988 100644 --- a/src/Cassandra/Connections/IConnection.cs +++ b/src/Cassandra/Connections/IConnection.cs @@ -18,6 +18,7 @@ using System.Net; using System.Net.Sockets; using System.Threading.Tasks; +using Cassandra.Collections; using Cassandra.Requests; using Cassandra.Responses; using Cassandra.Serialization; @@ -27,7 +28,7 @@ namespace Cassandra.Connections /// /// Represents a TCP connection to a Cassandra Node /// - internal interface IConnection : IDisposable + internal interface IConnection : IDisposable, IShardable { /// /// The event that represents a event RESPONSE from a Cassandra node @@ -153,6 +154,6 @@ internal interface IConnection : IDisposable /// ShardingInfo ShardingInfo(); - int ShardID { get; set; } + new int ShardID { get; set; } } } From a71181cc556a85671165230535112a46023502d8 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Tue, 29 Apr 2025 14:56:17 +0200 Subject: [PATCH 14/14] Extend the tests --- .../ShardAwareOptionsTests.cs | 14 +++++++++++--- src/Cassandra/Connections/HostConnectionPool.cs | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs b/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs index 626ee66eb..281898e16 100644 --- a/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs +++ b/src/Cassandra.IntegrationTests/ShardAwareOptionsTests.cs @@ -3,6 +3,7 @@ using Cassandra.IntegrationTests.TestClusterManagement; using Cassandra.SessionManagement; using NUnit.Framework; +using System.Linq; namespace Cassandra.IntegrationTests { @@ -32,12 +33,15 @@ public void Should_Connect_To_Shard_Aware_Cluster() Assert.IsTrue(controlConnection.IsShardAware()); } - [Test] - public void Should_Have_NrShards_Connections() + [TestCase(1)] + [TestCase(4)] + public void Should_Have_NrShards_Connections(int connectionsPerHost) { _realCluster = TestClusterManager.CreateNew(); var cluster = ClusterBuilder() .WithSocketOptions(new SocketOptions().SetReadTimeoutMillis(22000).SetConnectTimeoutMillis(60000)) + .WithPoolingOptions(new PoolingOptions() + .SetCoreConnectionsPerHost(HostDistance.Local, connectionsPerHost)) .AddContactPoint(_realCluster.InitialContactPoint) .Build(); var session = cluster.Connect(); @@ -45,7 +49,11 @@ public void Should_Have_NrShards_Connections() var pools = internalSession.GetPools(); foreach (var kvp in pools) { - Assert.AreEqual(2, kvp.Value.OpenConnections); + var shardCount = 2; + var connectionsPerShard = connectionsPerHost / shardCount + (connectionsPerHost % shardCount > 0 ? 1 : 0); + Assert.AreEqual(shardCount * connectionsPerShard, kvp.Value.OpenConnections); + var shardGroups = kvp.Value.ConnectionsSnapshot.GroupBy(c => c.ShardID); + Assert.IsTrue(shardGroups.All(g => g.Count() == connectionsPerShard)); } } } diff --git a/src/Cassandra/Connections/HostConnectionPool.cs b/src/Cassandra/Connections/HostConnectionPool.cs index c8e053c11..6a8e53546 100644 --- a/src/Cassandra/Connections/HostConnectionPool.cs +++ b/src/Cassandra/Connections/HostConnectionPool.cs @@ -111,7 +111,7 @@ private static class PoolState /// public IConnection[] ConnectionsSnapshot => _connections.GetSnapshot().GetAllItems(); - public ShardingInfo shardingInfo { get; private set; } + private ShardingInfo shardingInfo { get; set; } private int lastAttemptedShard = 0;