diff --git a/src/Renci.SshNet/Common/Extensions.cs b/src/Renci.SshNet/Common/Extensions.cs index 75d606fb7..3174f12a9 100644 --- a/src/Renci.SshNet/Common/Extensions.cs +++ b/src/Renci.SshNet/Common/Extensions.cs @@ -417,6 +417,14 @@ async Task WaitCore() return await completedTask.ConfigureAwait(false); } } + + extension(Task t) + { + internal bool IsCompletedSuccessfully + { + get { return t.Status == TaskStatus.RanToCompletion; } + } + } #endif } } diff --git a/src/Renci.SshNet/Common/ReadOnlyMemoryOwner.cs b/src/Renci.SshNet/Common/ReadOnlyMemoryOwner.cs new file mode 100644 index 000000000..b98f7288b --- /dev/null +++ b/src/Renci.SshNet/Common/ReadOnlyMemoryOwner.cs @@ -0,0 +1,87 @@ +#nullable enable +using System; +using System.Buffers; +using System.Diagnostics; +using System.Net; + +namespace Renci.SshNet.Common +{ + /// + /// A type representing ownership of a rented, read-only buffer. + /// + internal sealed class ReadOnlyMemoryOwner : IMemoryOwner + { + private ArrayBuffer _buffer; + + public ReadOnlyMemoryOwner(ArrayBuffer buffer) + { + _buffer = buffer; + + AssertValid(); + } + + [Conditional("DEBUG")] + private void AssertValid() + { + Debug.Assert( + _buffer.ActiveLength > 0 || _buffer.AvailableLength == 0, + "If the buffer is empty, then it should have been returned to the pool."); + } + + public int Length + { + get + { + AssertValid(); + return _buffer.ActiveLength; + } + } + + public bool IsEmpty + { + get + { + AssertValid(); + return _buffer.ActiveLength == 0; + } + } + + public ReadOnlySpan Span + { + get + { + AssertValid(); + return _buffer.ActiveReadOnlySpan; + } + } + + Memory IMemoryOwner.Memory + { + get + { + AssertValid(); + return _buffer.ActiveMemory; + } + } + + public void Slice(int start) + { + AssertValid(); + + _buffer.Discard(start); + + if (_buffer.ActiveLength == 0) + { + // Return the rented buffer as soon as it's no longer in use. + _buffer.ClearAndReturnBuffer(); + } + } + + public void Dispose() + { + AssertValid(); + + _buffer.ClearAndReturnBuffer(); + } + } +} diff --git a/src/Renci.SshNet/Sftp/ISftpSession.cs b/src/Renci.SshNet/Sftp/ISftpSession.cs index 53fcd86fb..4a804a35e 100644 --- a/src/Renci.SshNet/Sftp/ISftpSession.cs +++ b/src/Renci.SshNet/Sftp/ISftpSession.cs @@ -3,6 +3,7 @@ using System.Threading; using System.Threading.Tasks; +using Renci.SshNet.Common; using Renci.SshNet.Sftp.Responses; namespace Renci.SshNet.Sftp @@ -198,7 +199,7 @@ internal interface ISftpSession : ISubsystemSession /// its contains the data read from the file, or an empty /// array when the end of the file is reached. /// - Task RequestReadAsync(byte[] handle, ulong offset, uint length, CancellationToken cancellationToken); + Task RequestReadAsync(byte[] handle, ulong offset, uint length, CancellationToken cancellationToken); /// /// Performs a SSH_FXP_READDIR request. diff --git a/src/Renci.SshNet/Sftp/Responses/SftpDataResponse.cs b/src/Renci.SshNet/Sftp/Responses/SftpDataResponse.cs index 04a7d4d82..a824f7de1 100644 --- a/src/Renci.SshNet/Sftp/Responses/SftpDataResponse.cs +++ b/src/Renci.SshNet/Sftp/Responses/SftpDataResponse.cs @@ -1,4 +1,6 @@ -namespace Renci.SshNet.Sftp.Responses +using System; + +namespace Renci.SshNet.Sftp.Responses { internal sealed class SftpDataResponse : SftpResponse { @@ -7,7 +9,7 @@ public override SftpMessageTypes SftpMessageType get { return SftpMessageTypes.Data; } } - public byte[] Data { get; set; } + public ArraySegment Data { get; set; } public SftpDataResponse(uint protocolVersion) : base(protocolVersion) @@ -18,14 +20,14 @@ protected override void LoadData() { base.LoadData(); - Data = ReadBinary(); + Data = ReadBinarySegment(); } protected override void SaveData() { base.SaveData(); - WriteBinary(Data, 0, Data.Length); + WriteBinary(Data.Array, Data.Offset, Data.Count); } } } diff --git a/src/Renci.SshNet/Sftp/SftpFileReader.cs b/src/Renci.SshNet/Sftp/SftpFileReader.cs index 76e55849b..bdb309ecb 100644 --- a/src/Renci.SshNet/Sftp/SftpFileReader.cs +++ b/src/Renci.SshNet/Sftp/SftpFileReader.cs @@ -6,9 +6,7 @@ using System.Threading; using System.Threading.Tasks; -#if !NET using Renci.SshNet.Common; -#endif namespace Renci.SshNet.Sftp { @@ -58,7 +56,7 @@ public SftpFileReader(byte[] handle, ISftpSession sftpSession, int chunkSize, lo _cts = new CancellationTokenSource(); } - public async Task ReadAsync(CancellationToken cancellationToken) + public async Task ReadAsync(CancellationToken cancellationToken) { _exception?.Throw(); @@ -172,14 +170,21 @@ public void Dispose() if (_requests.Count > 0) { - // Cancel outstanding requests and observe the exception on them - // as an effort to prevent unhandled exceptions. - _cts.Cancel(); foreach (var request in _requests.Values) { - _ = request.Task.Exception; + // Return rented buffers to the pool, or observe exception on + // the task as an effort to prevent unhandled exceptions. + + if (request.Task.IsCompletedSuccessfully) + { + request.Task.GetAwaiter().GetResult().Dispose(); + } + else + { + _ = request.Task.Exception; + } } _requests.Clear(); @@ -190,7 +195,7 @@ public void Dispose() private sealed class Request { - public Request(ulong offset, uint count, Task task) + public Request(ulong offset, uint count, Task task) { Offset = offset; Count = count; @@ -199,7 +204,7 @@ public Request(ulong offset, uint count, Task task) public ulong Offset { get; } public uint Count { get; } - public Task Task { get; } + public Task Task { get; } } } } diff --git a/src/Renci.SshNet/Sftp/SftpFileStream.cs b/src/Renci.SshNet/Sftp/SftpFileStream.cs index d9552ebd8..d8c0d126f 100644 --- a/src/Renci.SshNet/Sftp/SftpFileStream.cs +++ b/src/Renci.SshNet/Sftp/SftpFileStream.cs @@ -26,7 +26,7 @@ public sealed partial class SftpFileStream : Stream private readonly int _readBufferSize; private SftpFileReader? _sftpFileReader; - private ReadOnlyMemory _readBuffer; + private ReadOnlyMemoryOwner _readBuffer; private System.Net.ArrayBuffer _writeBuffer; private long _position; @@ -153,6 +153,7 @@ private SftpFileStream( _readBufferSize = readBufferSize; _position = position; _writeBuffer = new System.Net.ArrayBuffer(writeBufferSize); + _readBuffer = new ReadOnlyMemoryOwner(new System.Net.ArrayBuffer(0, usePool: true)); _sftpFileReader = initialReader; } @@ -390,7 +391,7 @@ await _session.RequestWriteAsync( private void InvalidateReads() { - _readBuffer = ReadOnlyMemory.Empty; + _readBuffer.Dispose(); _sftpFileReader?.Dispose(); _sftpFileReader = null; } @@ -441,7 +442,7 @@ private int Read(Span buffer) var bytesRead = Math.Min(buffer.Length, _readBuffer.Length); _readBuffer.Span.Slice(0, bytesRead).CopyTo(buffer); - _readBuffer = _readBuffer.Slice(bytesRead); + _readBuffer.Slice(bytesRead); _position += bytesRead; @@ -494,8 +495,8 @@ private async ValueTask ReadAsync(Memory buffer, CancellationToken ca var bytesRead = Math.Min(buffer.Length, _readBuffer.Length); - _readBuffer.Slice(0, bytesRead).CopyTo(buffer); - _readBuffer = _readBuffer.Slice(bytesRead); + _readBuffer.Span.Slice(0, bytesRead).CopyTo(buffer.Span); + _readBuffer.Slice(bytesRead); _position += bytesRead; @@ -649,7 +650,7 @@ public override long Seek(long offset, SeekOrigin origin) if (readBufferStart <= newPosition && newPosition <= readBufferEnd) { - _readBuffer = _readBuffer.Slice((int)(newPosition - readBufferStart)); + _readBuffer.Slice((int)(newPosition - readBufferStart)); } else { diff --git a/src/Renci.SshNet/Sftp/SftpSession.cs b/src/Renci.SshNet/Sftp/SftpSession.cs index 17b0baff0..d5648bda4 100644 --- a/src/Renci.SshNet/Sftp/SftpSession.cs +++ b/src/Renci.SshNet/Sftp/SftpSession.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Globalization; +using System.Net; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -24,7 +25,7 @@ internal sealed class SftpSession : SubsystemSession, ISftpSession private readonly Dictionary _requests = new Dictionary(); private readonly ISftpResponseFactory _sftpResponseFactory; private readonly Encoding _encoding; - private System.Net.ArrayBuffer _buffer = new(32 * 1024); + private ArrayBuffer _buffer = new(32 * 1024); private EventWaitHandle _sftpVersionConfirmed = new AutoResetEvent(initialState: false); private IDictionary _supportedExtensions; @@ -495,7 +496,7 @@ public byte[] RequestRead(byte[] handle, ulong offset, uint length) length, response => { - data = response.Data; + data = response.Data.ToArray(); wait.SetIgnoringObjectDisposed(); }, response => @@ -526,28 +527,42 @@ public byte[] RequestRead(byte[] handle, ulong offset, uint length) } /// - public Task RequestReadAsync(byte[] handle, ulong offset, uint length, CancellationToken cancellationToken) + public Task RequestReadAsync(byte[] handle, ulong offset, uint length, CancellationToken cancellationToken) { Debug.Assert(length > 0, "This implementation cannot distinguish between EOF and zero-length reads"); if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + return Task.FromCanceled(cancellationToken); } - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); SendRequest(new SftpReadRequest(ProtocolVersion, NextRequestId, handle, offset, length, - response => tcs.TrySetResult(response.Data), + response => + { + ArrayBuffer buffer = new(response.Data.Count, usePool: true); + + response.Data.AsSpan().CopyTo(buffer.AvailableSpan); + + buffer.Commit(response.Data.Count); + + ReadOnlyMemoryOwner owner = new(buffer); + + if (!tcs.TrySetResult(owner)) + { + owner.Dispose(); + } + }, response => { if (response.StatusCode == StatusCode.Eof) { - _ = tcs.TrySetResult(Array.Empty()); + _ = tcs.TrySetResult(new(new(0, usePool: true))); } else { diff --git a/test/Renci.SshNet.Tests/Classes/Sftp/Responses/SftpDataResponseTest.cs b/test/Renci.SshNet.Tests/Classes/Sftp/Responses/SftpDataResponseTest.cs index 99431d647..5dc49603c 100644 --- a/test/Renci.SshNet.Tests/Classes/Sftp/Responses/SftpDataResponseTest.cs +++ b/test/Renci.SshNet.Tests/Classes/Sftp/Responses/SftpDataResponseTest.cs @@ -32,7 +32,7 @@ public void Constructor() { var target = new SftpDataResponse(_protocolVersion); - Assert.IsNull(target.Data); + Assert.AreEqual(default, target.Data); Assert.AreEqual(_protocolVersion, target.ProtocolVersion); Assert.AreEqual((uint)0, target.ResponseId); Assert.AreEqual(SftpMessageTypes.Data, target.SftpMessageType); @@ -52,7 +52,6 @@ public void Load() target.Load(sshData); - Assert.IsNotNull(target.Data); Assert.IsTrue(target.Data.SequenceEqual(_data)); Assert.AreEqual(_protocolVersion, target.ProtocolVersion); Assert.AreEqual(_responseId, target.ResponseId); diff --git a/test/Renci.SshNet.Tests/Classes/Sftp/SftpDataResponseBuilder.cs b/test/Renci.SshNet.Tests/Classes/Sftp/SftpDataResponseBuilder.cs index d3aac4299..5cc852c34 100644 --- a/test/Renci.SshNet.Tests/Classes/Sftp/SftpDataResponseBuilder.cs +++ b/test/Renci.SshNet.Tests/Classes/Sftp/SftpDataResponseBuilder.cs @@ -1,4 +1,6 @@ -using Renci.SshNet.Sftp.Responses; +using System; + +using Renci.SshNet.Sftp.Responses; namespace Renci.SshNet.Tests.Classes.Sftp { @@ -31,7 +33,7 @@ public SftpDataResponse Build() return new SftpDataResponse(_protocolVersion) { ResponseId = _responseId, - Data = _data + Data = new ArraySegment(_data) }; } } diff --git a/test/Renci.SshNet.Tests/Classes/Sftp/SftpFileStreamTest.cs b/test/Renci.SshNet.Tests/Classes/Sftp/SftpFileStreamTest.cs index 603aabc7d..2d497f245 100644 --- a/test/Renci.SshNet.Tests/Classes/Sftp/SftpFileStreamTest.cs +++ b/test/Renci.SshNet.Tests/Classes/Sftp/SftpFileStreamTest.cs @@ -217,6 +217,10 @@ private void TestSendsBufferedWrites(Action flushAction) sessionMock.Setup(s => s.CalculateOptimalReadLength(It.IsAny())).Returns(x => x); sessionMock.Setup(s => s.CalculateOptimalWriteLength(It.IsAny(), It.IsAny())).Returns((x, _) => x); sessionMock.Setup(s => s.IsOpen).Returns(true); + sessionMock + .Setup(s => s.RequestReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((_, _, _, _) => Task.FromResult(new ReadOnlyMemoryOwner(new(0, usePool: true)))); + SetupRemoteSize(sessionMock, 0); var s = SftpFileStream.Open(sessionMock.Object, "file.txt", FileMode.OpenOrCreate, FileAccess.ReadWrite, bufferSize: 1024); @@ -301,6 +305,9 @@ private void TestFstatFailure(Action s.CalculateOptimalWriteLength(It.IsAny(), It.IsAny())).Returns((x, _) => x); sessionMock.Setup(p => p.SessionLoggerFactory).Returns(NullLoggerFactory.Instance); sessionMock.Setup(s => s.IsOpen).Returns(true); + sessionMock + .Setup(s => s.RequestReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((_, _, _, _) => Task.FromResult(new ReadOnlyMemoryOwner(new(0, usePool: true)))); fstatSetup(sessionMock.Setup(s => s.RequestFStat(It.IsAny()))); diff --git a/test/Renci.SshNet.Tests/Classes/Sftp/SftpSessionTest_DataReceived_MultipleSftpMessagesInSingleSshDataMessage.cs b/test/Renci.SshNet.Tests/Classes/Sftp/SftpSessionTest_DataReceived_MultipleSftpMessagesInSingleSshDataMessage.cs index f8fa7512d..feada4d09 100644 --- a/test/Renci.SshNet.Tests/Classes/Sftp/SftpSessionTest_DataReceived_MultipleSftpMessagesInSingleSshDataMessage.cs +++ b/test/Renci.SshNet.Tests/Classes/Sftp/SftpSessionTest_DataReceived_MultipleSftpMessagesInSingleSshDataMessage.cs @@ -187,12 +187,14 @@ protected void Arrange() protected void Act() { Task openTask = _sftpSession.RequestOpenAsync(_path, Flags.Read, CancellationToken.None); - Task readTask = _sftpSession.RequestReadAsync(_handle, _offset, _length, CancellationToken.None); + Task readTask = _sftpSession.RequestReadAsync(_handle, _offset, _length, CancellationToken.None); Task.WaitAll(openTask, readTask); _actualHandle = openTask.Result; - _actualData = readTask.Result; + + using ReadOnlyMemoryOwner actualData = readTask.Result; + _actualData = actualData.Span.ToArray(); } [TestMethod]