Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async support to SftpClient and SftpFileStream #819

Merged
merged 11 commits into from
Dec 14, 2021
25 changes: 22 additions & 3 deletions src/Renci.SshNet/Abstractions/DnsAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
using System.Net;
using System.Net.Sockets;

#if FEATURE_TAP
using System.Threading.Tasks;
#endif

#if FEATURE_DNS_SYNC
#elif FEATURE_DNS_APM
using Renci.SshNet.Common;
#elif FEATURE_DNS_TAP
#elif FEATURE_DEVICEINFORMATION_APM
using System.Collections.Generic;
using System.Linq;
Expand Down Expand Up @@ -42,8 +45,6 @@ public static IPAddress[] GetHostAddresses(string hostNameOrAddress)
if (!asyncResult.AsyncWaitHandle.WaitOne(Session.InfiniteTimeSpan))
throw new SshOperationTimeoutException("Timeout resolving host name.");
return Dns.EndGetHostAddresses(asyncResult);
#elif FEATURE_DNS_TAP
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be reverted, if not we'll use the IPAddress.TryParse(hostNameOrAddress, out address) implementation for .NET Standard 1.3 and UAP 10. It's not like we're always using the DnsAbstrations.GetHostAddressesAsync() method for those target frameworks that support TAP.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ups, I was too quick to remove and missed that the fallback was just IPAddress.TryParse. I have reverted the change, but I propose to remove FEATURE_DNS_TAP in the future from this method as we're using Sync/Async, which could cause problems.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. These conditional compilation symbols are a necessary evil. The less we have to maintain, the better.

return Dns.GetHostAddressesAsync(hostNameOrAddress).GetAwaiter().GetResult();
#else
IPAddress address;
if (IPAddress.TryParse(hostNameOrAddress, out address))
Expand Down Expand Up @@ -87,5 +88,23 @@ public static IPAddress[] GetHostAddresses(string hostNameOrAddress)
#endif // FEATURE_DEVICEINFORMATION_APM
#endif
}

#if FEATURE_TAP
/// <summary>
/// Returns the Internet Protocol (IP) addresses for the specified host.
/// </summary>
/// <param name="hostNameOrAddress">The host name or IP address to resolve</param>
/// <returns>
/// A task with result of an array of type <see cref="IPAddress"/> that holds the IP addresses for the host that
/// is specified by the <paramref name="hostNameOrAddress"/> parameter.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="hostNameOrAddress"/> is <c>null</c>.</exception>
/// <exception cref="SocketException">An error is encountered when resolving <paramref name="hostNameOrAddress"/>.</exception>
public static Task<IPAddress[]> GetHostAddressesAsync(string hostNameOrAddress)
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
{
return Dns.GetHostAddressesAsync(hostNameOrAddress);
}
#endif

}
}
43 changes: 30 additions & 13 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
using System.Net;
using System.Net.Sockets;
using System.Threading;
#if FEATURE_TAP
using System.Threading.Tasks;
#endif
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

Expand Down Expand Up @@ -59,15 +62,22 @@ public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan co
ConnectCore(socket, remoteEndpoint, connectTimeout, false);
}

#if FEATURE_TAP
public static Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
{
return socket.ConnectAsync(remoteEndpoint, cancellationToken);
}
#endif

private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
{
#if FEATURE_SOCKET_EAP
var connectCompleted = new ManualResetEvent(false);
var args = new SocketAsyncEventArgs
{
UserToken = connectCompleted,
RemoteEndPoint = remoteEndpoint
};
{
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
UserToken = connectCompleted,
RemoteEndPoint = remoteEndpoint
};
args.Completed += ConnectCompleted;

if (socket.ConnectAsync(args))
Expand Down Expand Up @@ -97,7 +107,7 @@ private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSp

if (args.SocketError != SocketError.Success)
{
var socketError = (int) args.SocketError;
var socketError = (int)args.SocketError;
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved

if (ownsSocket)
{
Expand All @@ -124,7 +134,7 @@ private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSp
throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
"Connection failed to establish within {0:F0} milliseconds.", connectTimeout.TotalMilliseconds));
#else
#error Connecting to a remote endpoint is not implemented.
#error Connecting to a remote endpoint is not implemented.
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
#endif
}

Expand All @@ -144,7 +154,7 @@ public static void ClearReadBuffer(Socket socket)
public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
{
#if FEATURE_SOCKET_SYNC
socket.ReceiveTimeout = (int) timeout.TotalMilliseconds;
socket.ReceiveTimeout = (int)timeout.TotalMilliseconds;
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved

try
{
Expand Down Expand Up @@ -197,7 +207,7 @@ public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size
receiveCompleted.Dispose();
}
#else
#error Receiving data from a Socket is not implemented.
#error Receiving data from a Socket is not implemented.
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
#endif
}

Expand Down Expand Up @@ -259,7 +269,7 @@ public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int
if (readToken.Exception != null)
throw readToken.Exception;
#else
#error Receiving data from a Socket is not implemented.
#error Receiving data from a Socket is not implemented.
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
#endif
}

Expand Down Expand Up @@ -290,7 +300,7 @@ public static int ReadByte(Socket socket, TimeSpan timeout)
/// <exception cref="SocketException">The write failed.</exception>
public static void SendByte(Socket socket, byte value)
{
var buffer = new[] {value};
var buffer = new[] { value };
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
Send(socket, buffer, 0, 1);
}

Expand All @@ -317,6 +327,13 @@ public static byte[] Read(Socket socket, int size, TimeSpan timeout)
return buffer;
}

#if FEATURE_TAP
public static Task<int> ReadAsync(Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
return socket.ReceiveAsync(buffer, offset, length, cancellationToken);
}
#endif

/// <summary>
/// Receives data from a bound <see cref="Socket"/> into a receive buffer.
/// </summary>
Expand Down Expand Up @@ -369,7 +386,7 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS
throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
"Socket read operation has timed out after {0:F0} milliseconds.", readTimeout.TotalMilliseconds));

throw;
throw;
}
}
while (totalBytesRead < totalBytesToRead);
Expand Down Expand Up @@ -488,7 +505,7 @@ public static void Send(Socket socket, byte[] data, int offset, int size)
sendCompleted.Dispose();
}
#else
#error Sending data to a Socket is not implemented.
#error Sending data to a Socket is not implemented.
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
#endif
}

Expand All @@ -508,7 +525,7 @@ public static bool IsErrorResumable(SocketError socketError)
#if FEATURE_SOCKET_EAP
private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
{
var eventWaitHandle = (ManualResetEvent) e.UserToken;
var eventWaitHandle = (ManualResetEvent)e.UserToken;
IgorMilavec marked this conversation as resolved.
Show resolved Hide resolved
if (eventWaitHandle != null)
eventWaitHandle.Set();
}
Expand Down
119 changes: 119 additions & 0 deletions src/Renci.SshNet/Abstractions/SocketExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#if FEATURE_TAP
using System;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace Renci.SshNet.Abstractions
{
// Async helpers based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/

internal static class SocketExtensions
{
sealed class SocketAsyncEventArgsAwaitable : SocketAsyncEventArgs, INotifyCompletion
{
private readonly static Action SENTINEL = () => { };

private bool isCancelled;
private Action continuationAction;

public SocketAsyncEventArgsAwaitable()
{
Completed += delegate { SetCompleted(); };
}

public SocketAsyncEventArgsAwaitable ExecuteAsync(Func<SocketAsyncEventArgs, bool> func)
{
if (!func(this))
{
SetCompleted();
}
return this;
}

public void SetCompleted()
{
IsCompleted = true;
var continuation = continuationAction ?? Interlocked.CompareExchange(ref continuationAction, SENTINEL, null);
if (continuation != null)
{
continuation();
}
}

public void SetCancelled()
{
isCancelled = true;
SetCompleted();
}

public SocketAsyncEventArgsAwaitable GetAwaiter() { return this; }

public bool IsCompleted { get; private set; }

void INotifyCompletion.OnCompleted(Action continuation)
{
if (continuationAction == SENTINEL || Interlocked.CompareExchange(ref continuationAction, continuation, null) == SENTINEL)
{
// We have already completed; run continuation asynchronously
Task.Run(continuation);
}
}

public void GetResult()
{
if (isCancelled)
{
throw new TaskCanceledException();
}
else if (IsCompleted)
{
if (SocketError != SocketError.Success)
{
throw new SocketException((int)SocketError);
}
}
else
{
// We don't support sync/async
throw new InvalidOperationException("The asynchronous operation has not yet completed.");
}
}
}

public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

using (var args = new SocketAsyncEventArgsAwaitable())
{
args.RemoteEndPoint = remoteEndpoint;

using (cancellationToken.Register(o => ((SocketAsyncEventArgsAwaitable)o).SetCancelled(), args, false))
{
await args.ExecuteAsync(socket.ConnectAsync);
}
}
}

public static async Task<int> ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

using (var args = new SocketAsyncEventArgsAwaitable())
{
args.SetBuffer(buffer, offset, length);

using (cancellationToken.Register(o => ((SocketAsyncEventArgsAwaitable)o).SetCancelled(), args, false))
{
await args.ExecuteAsync(socket.ReceiveAsync);
}

return args.BytesTransferred;
}
}
}
}
#endif
Loading