diff --git a/src/Renci.SshNet/ShellStream.cs b/src/Renci.SshNet/ShellStream.cs index e3ef7ff07..3d41d00fd 100644 --- a/src/Renci.SshNet/ShellStream.cs +++ b/src/Renci.SshNet/ShellStream.cs @@ -891,6 +891,57 @@ public override void WriteByte(byte value) Write([value]); } + /// + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { +#if !NET + ThrowHelper. +#endif + ValidateBufferArguments(buffer, offset, count); + + return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + +#if NET + /// + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) +#else + private async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) +#endif + { + ThrowHelper.ThrowObjectDisposedIf(_disposed, this); + + while (!buffer.IsEmpty) + { + if (_writeBuffer.AvailableLength == 0) + { + await FlushAsync(cancellationToken).ConfigureAwait(false); + } + + var bytesToCopy = Math.Min(buffer.Length, _writeBuffer.AvailableLength); + + Debug.Assert(bytesToCopy > 0); + + buffer.Slice(0, bytesToCopy).CopyTo(_writeBuffer.AvailableMemory); + + _writeBuffer.Commit(bytesToCopy); + + buffer = buffer.Slice(bytesToCopy); + } + } + + /// + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + return TaskToAsyncResult.Begin(WriteAsync(buffer, offset, count), callback, state); + } + + /// + public override void EndWrite(IAsyncResult asyncResult) + { + TaskToAsyncResult.End(asyncResult); + } + /// /// Writes the line to the shell. /// diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs index b4b071707..c36169241 100644 --- a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs @@ -236,6 +236,21 @@ public void Read_AfterDispose_StillWorks() Assert.IsNull(_shellStream.ReadLine()); } + [TestMethod] + public async Task ReadAsyncDoesNotBlockWriteAsync() + { + byte[] buffer = new byte[16]; + Task readTask = _shellStream.ReadAsync(buffer, 0, buffer.Length); + + await _shellStream.WriteAsync("ls\n"u8.ToArray(), 0, 3); + + Assert.IsFalse(readTask.IsCompleted); + + _channelSessionStub.Receive("Directory.Build.props"u8.ToArray()); + + await readTask; + } + [TestMethod] public void Read_MultiByte() {