Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async support to SftpClient and SftpFileStream #819

Merged
merged 11 commits into from
Dec 14, 2021
22 changes: 22 additions & 0 deletions src/Renci.SshNet/Abstractions/DnsAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
using System.Net;
using System.Net.Sockets;

#if FEATURE_TAP
using System.Threading.Tasks;
#endif

#if FEATURE_DNS_SYNC
#elif FEATURE_DNS_APM
using Renci.SshNet.Common;
Expand Down Expand Up @@ -87,5 +91,23 @@ public static IPAddress[] GetHostAddresses(string hostNameOrAddress)
#endif // FEATURE_DEVICEINFORMATION_APM
#endif
}

#if FEATURE_TAP
/// <summary>
/// Returns the Internet Protocol (IP) addresses for the specified host.
/// </summary>
/// <param name="hostNameOrAddress">The host name or IP address to resolve</param>
/// <returns>
/// A task with result of an array of type <see cref="IPAddress"/> that holds the IP addresses for the host that
/// is specified by the <paramref name="hostNameOrAddress"/> parameter.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="hostNameOrAddress"/> is <c>null</c>.</exception>
/// <exception cref="SocketException">An error is encountered when resolving <paramref name="hostNameOrAddress"/>.</exception>
public static Task<IPAddress[]> GetHostAddressesAsync(string hostNameOrAddress)
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
{
return Dns.GetHostAddressesAsync(hostNameOrAddress);
}
#endif

}
}
17 changes: 17 additions & 0 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
using System.Net;
using System.Net.Sockets;
using System.Threading;
#if FEATURE_TAP
using System.Threading.Tasks;
#endif
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

Expand Down Expand Up @@ -59,6 +62,13 @@ public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan co
ConnectCore(socket, remoteEndpoint, connectTimeout, false);
}

#if FEATURE_TAP
public static Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
{
return socket.ConnectAsync(remoteEndpoint, cancellationToken);
}
#endif

private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
{
#if FEATURE_SOCKET_EAP
Expand Down Expand Up @@ -317,6 +327,13 @@ public static byte[] Read(Socket socket, int size, TimeSpan timeout)
return buffer;
}

#if FEATURE_TAP
public static Task<int> ReadAsync(Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
return socket.ReceiveAsync(buffer, offset, length, cancellationToken);
}
#endif

/// <summary>
/// Receives data from a bound <see cref="Socket"/> into a receive buffer.
/// </summary>
Expand Down
119 changes: 119 additions & 0 deletions src/Renci.SshNet/Abstractions/SocketExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#if FEATURE_TAP
using System;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace Renci.SshNet.Abstractions
{
// Async helpers based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/

internal static class SocketExtensions
{
sealed class SocketAsyncEventArgsAwaitable : SocketAsyncEventArgs, INotifyCompletion
{
private readonly static Action SENTINEL = () => { };

private bool isCancelled;
private Action continuationAction;

public SocketAsyncEventArgsAwaitable()
{
Completed += delegate { SetCompleted(); };
}

public SocketAsyncEventArgsAwaitable ExecuteAsync(Func<SocketAsyncEventArgs, bool> func)
{
if (!func(this))
{
SetCompleted();
}
return this;
}

public void SetCompleted()
{
IsCompleted = true;
var continuation = continuationAction ?? Interlocked.CompareExchange(ref continuationAction, SENTINEL, null);
if (continuation != null)
{
continuation();
}
}

public void SetCancelled()
{
isCancelled = true;
SetCompleted();
}

public SocketAsyncEventArgsAwaitable GetAwaiter() { return this; }

public bool IsCompleted { get; private set; }

void INotifyCompletion.OnCompleted(Action continuation)
{
if (continuationAction == SENTINEL || Interlocked.CompareExchange(ref continuationAction, continuation, null) == SENTINEL)
{
// We have already completed; run continuation asynchronously
Task.Run(continuation);
}
}

public void GetResult()
{
if (isCancelled)
{
throw new TaskCanceledException();
}
else if (IsCompleted)
{
if (SocketError != SocketError.Success)
{
throw new SocketException((int)SocketError);
}
}
else
{
// We don't support sync/async
throw new InvalidOperationException("The asynchronous operation has not yet completed.");
}
}
}

public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

using (var args = new SocketAsyncEventArgsAwaitable())
{
args.RemoteEndPoint = remoteEndpoint;

using (cancellationToken.Register(o => ((SocketAsyncEventArgsAwaitable)o).SetCancelled(), args, false))
{
await args.ExecuteAsync(socket.ConnectAsync);
}
}
}

public static async Task<int> ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

using (var args = new SocketAsyncEventArgsAwaitable())
{
args.SetBuffer(buffer, offset, length);

using (cancellationToken.Register(o => ((SocketAsyncEventArgsAwaitable)o).SetCancelled(), args, false))
{
await args.ExecuteAsync(socket.ReceiveAsync);
}

return args.BytesTransferred;
}
}
}
}
#endif
80 changes: 80 additions & 0 deletions src/Renci.SshNet/BaseClient.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System;
using System.Net.Sockets;
using System.Threading;
#if FEATURE_TAP
using System.Threading.Tasks;
#endif
using Renci.SshNet.Abstractions;
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
Expand Down Expand Up @@ -239,6 +242,63 @@ public void Connect()
StartKeepAliveTimer();
}

#if FEATURE_TAP
/// <summary>
/// Asynchronously connects client to the server.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous connect operation.
/// </returns>
/// <exception cref="InvalidOperationException">The client is already connected.</exception>
/// <exception cref="ObjectDisposedException">The method was called after the client was disposed.</exception>
/// <exception cref="SocketException">Socket connection to the SSH server or proxy server could not be established, or an error occurred while resolving the hostname.</exception>
/// <exception cref="SshConnectionException">SSH session could not be established.</exception>
/// <exception cref="SshAuthenticationException">Authentication of SSH session failed.</exception>
/// <exception cref="ProxyException">Failed to establish proxy connection.</exception>
public async Task ConnectAsync(CancellationToken cancellationToken)
{
CheckDisposed();
cancellationToken.ThrowIfCancellationRequested();

// TODO (see issue #1758):
// we're not stopping the keep-alive timer and disposing the session here
//
// we could do this but there would still be side effects as concrete
// implementations may still hang on to the original session
//
// therefore it would be better to actually invoke the Disconnect method
// (and then the Dispose on the session) but even that would have side effects
// eg. it would remove all forwarded ports from SshClient
//
// I think we should modify our concrete clients to better deal with a
// disconnect. In case of SshClient this would mean not removing the
// forwarded ports on disconnect (but only on dispose ?) and link a
// forwarded port with a client instead of with a session
//
// To be discussed with Oleg (or whoever is interested)
if (IsSessionConnected())
throw new InvalidOperationException("The client is already connected.");

OnConnecting();

Session = await CreateAndConnectSessionAsync(cancellationToken).ConfigureAwait(false);
try
{
// Even though the method we invoke makes you believe otherwise, at this point only
// the SSH session itself is connected.
OnConnected();
}
catch
{
// Only dispose the session as Disconnect() would have side-effects (such as remove forwarded
// ports in SshClient).
DisposeSession();
throw;
}
StartKeepAliveTimer();
}
#endif

/// <summary>
/// Disconnects client from the server.
/// </summary>
Expand Down Expand Up @@ -473,6 +533,26 @@ private ISession CreateAndConnectSession()
}
}

#if FEATURE_TAP
private async Task<ISession> CreateAndConnectSessionAsync(CancellationToken cancellationToken)
{
var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
session.HostKeyReceived += Session_HostKeyReceived;
session.ErrorOccured += Session_ErrorOccured;

try
{
await session.ConnectAsync(cancellationToken).ConfigureAwait(false);
return session;
}
catch
{
DisposeSession(session);
throw;
}
}
#endif

private void DisposeSession(ISession session)
{
session.ErrorOccured -= Session_ErrorOccured;
Expand Down
45 changes: 45 additions & 0 deletions src/Renci.SshNet/Connection/ConnectorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;

#if FEATURE_TAP
using System.Threading.Tasks;
#endif

namespace Renci.SshNet.Connection
{
Expand All @@ -21,6 +26,10 @@ protected ConnectorBase(ISocketFactory socketFactory)

public abstract Socket Connect(IConnectionInfo connectionInfo);

#if FEATURE_TAP
public abstract Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken);
#endif

/// <summary>
/// Establishes a socket connection to the specified host and port.
/// </summary>
Expand Down Expand Up @@ -54,6 +63,42 @@ protected Socket SocketConnect(string host, int port, TimeSpan timeout)
}
}

#if FEATURE_TAP
/// <summary>
/// Establishes a socket connection to the specified host and port.
/// </summary>
/// <param name="host">The host name of the server to connect to.</param>
/// <param name="port">The port to connect to.</param>
/// <param name="cancellationToken">The cancellation token to observe.</param>
/// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
/// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
protected async Task<Socket> SocketConnectAsync(string host, int port, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

var ipAddress = (await DnsAbstraction.GetHostAddressesAsync(host).ConfigureAwait(false))[0];
var ep = new IPEndPoint(ipAddress, port);

DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}:{1}'.", host, port));

var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try
{
await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false);

const int socketBufferSize = 2 * Session.MaximumSshPacketSize;
socket.SendBufferSize = socketBufferSize;
socket.ReceiveBufferSize = socketBufferSize;
return socket;
}
catch (Exception)
{
socket.Dispose();
throw;
}
}
#endif

protected static byte SocketReadByte(Socket socket)
{
var buffer = new byte[1];
Expand Down
10 changes: 9 additions & 1 deletion src/Renci.SshNet/Connection/DirectConnector.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using System.Net.Sockets;
using System.Threading;

namespace Renci.SshNet.Connection
{
internal class DirectConnector : ConnectorBase
internal sealed class DirectConnector : ConnectorBase
{
public DirectConnector(ISocketFactory socketFactory) : base(socketFactory)
{
Expand All @@ -12,5 +13,12 @@ public override Socket Connect(IConnectionInfo connectionInfo)
{
return SocketConnect(connectionInfo.Host, connectionInfo.Port, connectionInfo.Timeout);
}

#if FEATURE_TAP
public override System.Threading.Tasks.Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken)
{
return SocketConnectAsync(connectionInfo.Host, connectionInfo.Port, cancellationToken);
}
#endif
}
}
Loading