Skip to content

Commit

Permalink
Make TcpClientAdapter public
Browse files Browse the repository at this point in the history
Improve test for SocketFactory

Follow-up to:
* #1414
* #1415
* #1416

https://groups.google.com/g/rabbitmq-users/c/9_ohuUbX9NY
  • Loading branch information
lukebakken committed Nov 16, 2023
1 parent b2586ed commit 56bd2c9
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 43 deletions.
3 changes: 2 additions & 1 deletion projects/RabbitMQ.Client/client/api/ITcpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
namespace RabbitMQ.Client
{
/// <summary>
/// Wrapper interface for standard TCP-client. Provides socket for socket frame handler class.
/// Wrapper interface for <see cref="Socket"/>.
/// Provides the socket for socket frame handler class.
/// </summary>
/// <remarks>Contains all methods that are currently in use in rabbitmq client.</remarks>
public interface ITcpClient : IDisposable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;

namespace RabbitMQ.Client.Impl
namespace RabbitMQ.Client
{
/// <summary>
/// Simple wrapper around TcpClient.
/// Simple wrapper around <see cref="Socket"/>.
/// </summary>
internal class TcpClientAdapter : ITcpClient
public class TcpClientAdapter : ITcpClient
{
private Socket _sock;

Expand All @@ -21,7 +23,7 @@ public virtual async Task ConnectAsync(string host, int port)
{
AssertSocket();
IPAddress[] adds = await Dns.GetHostAddressesAsync(host).ConfigureAwait(false);
IPAddress ep = TcpClientAdapterHelper.GetMatchingHost(adds, _sock.AddressFamily);
IPAddress ep = GetMatchingHost(adds, _sock.AddressFamily);
if (ep == default(IPAddress))
{
throw new ArgumentException($"No ip address could be resolved for {host}");
Expand All @@ -38,12 +40,11 @@ public virtual Task ConnectAsync(IPAddress ep, int port)

public virtual void Close()
{
_sock?.Dispose();
_sock.Dispose();
_sock = null;
}

[Obsolete("Override Dispose(bool) instead.")]
public virtual void Dispose()
public void Dispose()
{
Dispose(true);
}
Expand All @@ -52,11 +53,8 @@ protected virtual void Dispose(bool disposing)
{
if (disposing)
{
// dispose managed resources
Close();
}

// dispose unmanaged resources
}

public virtual NetworkStream GetStream()
Expand Down Expand Up @@ -106,5 +104,15 @@ private void AssertSocket()
throw new InvalidOperationException("Cannot perform operation as socket is null");
}
}

public static IPAddress GetMatchingHost(IReadOnlyCollection<IPAddress> addresses, AddressFamily addressFamily)
{
IPAddress ep = addresses.FirstOrDefault(a => a.AddressFamily == addressFamily);
if (ep is null && addresses.Count == 1 && addressFamily == AddressFamily.Unspecified)
{
return addresses.Single();
}
return ep;
}
}
}
4 changes: 2 additions & 2 deletions projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ internal sealed class SocketFrameHandler : IFrameHandler

// Resolve the hostname to know if it's even possible to even try IPv6
IPAddress[] adds = Dns.GetHostAddresses(endpoint.HostName);
IPAddress ipv6 = TcpClientAdapterHelper.GetMatchingHost(adds, AddressFamily.InterNetworkV6);
IPAddress ipv6 = TcpClientAdapter.GetMatchingHost(adds, AddressFamily.InterNetworkV6);

if (ipv6 == default(IPAddress))
{
Expand All @@ -141,7 +141,7 @@ internal sealed class SocketFrameHandler : IFrameHandler

if (_socket is null)
{
IPAddress ipv4 = TcpClientAdapterHelper.GetMatchingHost(adds, AddressFamily.InterNetwork);
IPAddress ipv4 = TcpClientAdapter.GetMatchingHost(adds, AddressFamily.InterNetwork);
if (ipv4 == default(IPAddress))
{
throw new ConnectFailureException("Connection failed", new ArgumentException($"No ip address could be resolved for {endpoint.HostName}"));
Expand Down
20 changes: 0 additions & 20 deletions projects/RabbitMQ.Client/client/impl/TcpClientAdapterHelper.cs

This file was deleted.

14 changes: 14 additions & 0 deletions projects/Unit/APIApproval.Approve.verified.txt
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,20 @@ namespace RabbitMQ.Client
public string ServerName { get; set; }
public System.Security.Authentication.SslProtocols Version { get; set; }
}
public class TcpClientAdapter : RabbitMQ.Client.ITcpClient, System.IDisposable
{
public TcpClientAdapter(System.Net.Sockets.Socket socket) { }
public virtual System.Net.Sockets.Socket Client { get; }
public virtual bool Connected { get; }
public virtual System.TimeSpan ReceiveTimeout { get; set; }
public virtual void Close() { }
public virtual System.Threading.Tasks.Task ConnectAsync(System.Net.IPAddress ep, int port) { }
public virtual System.Threading.Tasks.Task ConnectAsync(string host, int port) { }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
public virtual System.Net.Sockets.NetworkStream GetStream() { }
public static System.Net.IPAddress GetMatchingHost(System.Collections.Generic.IReadOnlyCollection<System.Net.IPAddress> addresses, System.Net.Sockets.AddressFamily addressFamily) { }
}
public class TimerBasedCredentialRefresher : RabbitMQ.Client.ICredentialsRefresher
{
public TimerBasedCredentialRefresher() { }
Expand Down
18 changes: 12 additions & 6 deletions projects/Unit/TestConnectionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
using System.Collections.Generic;
using System.Net.Sockets;
using RabbitMQ.Client.Exceptions;
using RabbitMQ.Client.Impl;
using Xunit;

namespace RabbitMQ.Client.Unit
Expand Down Expand Up @@ -74,16 +73,23 @@ public void TestProperties()
[Fact]
public void TestConnectionFactoryWithCustomSocketFactory()
{
const int bufsz = 1024;
const int testBufsz = 1024;
int defaultReceiveBufsz = 0;
int defaultSendBufsz = 0;
using (var defaultSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.IP))
{
defaultReceiveBufsz = defaultSocket.ReceiveBufferSize;
defaultSendBufsz = defaultSocket.SendBufferSize;
}

ConnectionFactory cf = new()
{
SocketFactory = (AddressFamily af) =>
{
var socket = new Socket(af, SocketType.Stream, ProtocolType.Tcp)
{
SendBufferSize = bufsz,
ReceiveBufferSize = bufsz,
SendBufferSize = testBufsz,
ReceiveBufferSize = testBufsz,
NoDelay = false
};
return new TcpClientAdapter(socket);
Expand All @@ -94,8 +100,8 @@ public void TestConnectionFactoryWithCustomSocketFactory()
Assert.IsType<TcpClientAdapter>(c);
TcpClientAdapter tcpClientAdapter = (TcpClientAdapter)c;
Socket s = tcpClientAdapter.Client;
Assert.Equal(bufsz, s.ReceiveBufferSize);
Assert.Equal(bufsz, s.SendBufferSize);
Assert.NotEqual(defaultReceiveBufsz, s.ReceiveBufferSize);
Assert.NotEqual(defaultSendBufsz, s.SendBufferSize);
Assert.False(s.NoDelay);
}

Expand Down
7 changes: 3 additions & 4 deletions projects/Unit/TestTcpClientAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

using System.Net;
using System.Net.Sockets;
using RabbitMQ.Client.Impl;
using Xunit;

namespace RabbitMQ.Client.Unit
Expand All @@ -42,15 +41,15 @@ public class TestTcpClientAdapter
public void TcpClientAdapterHelperGetMatchingHostReturnNoAddressIfFamilyDoesNotMatch()
{
var address = IPAddress.Parse("127.0.0.1");
IPAddress matchingAddress = TcpClientAdapterHelper.GetMatchingHost(new[] { address }, AddressFamily.InterNetworkV6);
IPAddress matchingAddress = TcpClientAdapter.GetMatchingHost(new[] { address }, AddressFamily.InterNetworkV6);
Assert.Null(matchingAddress);
}

[Fact]
public void TcpClientAdapterHelperGetMatchingHostReturnsSingleAddressIfFamilyIsUnspecified()
{
var address = IPAddress.Parse("1.1.1.1");
IPAddress matchingAddress = TcpClientAdapterHelper.GetMatchingHost(new[] { address }, AddressFamily.Unspecified);
IPAddress matchingAddress = TcpClientAdapter.GetMatchingHost(new[] { address }, AddressFamily.Unspecified);
Assert.Equal(address, matchingAddress);
}

Expand All @@ -59,7 +58,7 @@ public void TcpClientAdapterHelperGetMatchingHostReturnNoAddressIfFamilyIsUnspec
{
var address = IPAddress.Parse("1.1.1.1");
var address2 = IPAddress.Parse("2.2.2.2");
IPAddress matchingAddress = TcpClientAdapterHelper.GetMatchingHost(new[] { address, address2 }, AddressFamily.Unspecified);
IPAddress matchingAddress = TcpClientAdapter.GetMatchingHost(new[] { address, address2 }, AddressFamily.Unspecified);
Assert.Null(matchingAddress);
}
}
Expand Down

0 comments on commit 56bd2c9

Please sign in to comment.