From 9e779d5440a431a619d4156a1f39e243fe1099ff Mon Sep 17 00:00:00 2001 From: Robert Hague Date: Fri, 3 Oct 2025 13:24:25 +0100 Subject: [PATCH] Override WriteAsync in ShellStream ShellStream does not currently override the Read/Write async variants. They fall back to the base class implementations which run the sync variants on a thread pool thread, only allowing one call of either at a time in order to protect implementations that would break if Read/Write were called simultaneously. In ShellStream, reads and writes are independent so mutually excluding their use is unnecessary and can lead to effective deadlocks. We therefore override WriteAsync to get around this restriction. We do not override ReadAsync because the sync implementation does not lend itself well to async given the use of Monitor.Wait/Pulse. Note that while reading and writing simultaneously is allowed, it is not intended that ShellStream is used with multiple simultaneous reads or multiple simultaneous writes, so it is fine to keep the base one-at-a-time implementation on ReadAsync. Another note is that the new WriteAsync will be simple (synchronous) buffer copying in most cases, with a call to FlushAsync in others. We also do not override FlushAsync, so that will go onto a thread pool thread and potentially acquire some locks. But given that the current base implementation of WriteAsync does that unconditionally, it makes the new WriteAsync slightly better and certainly no worse than the current version. --- src/Renci.SshNet/ShellStream.cs | 51 +++++++++++++++++++ .../Classes/ShellStreamTest_ReadExpect.cs | 15 ++++++ 2 files changed, 66 insertions(+) diff --git a/src/Renci.SshNet/ShellStream.cs b/src/Renci.SshNet/ShellStream.cs index e3ef7ff07..3d41d00fd 100644 --- a/src/Renci.SshNet/ShellStream.cs +++ b/src/Renci.SshNet/ShellStream.cs @@ -891,6 +891,57 @@ public override void WriteByte(byte value) Write([value]); } + /// + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { +#if !NET + ThrowHelper. +#endif + ValidateBufferArguments(buffer, offset, count); + + return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + +#if NET + /// + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) +#else + private async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) +#endif + { + ThrowHelper.ThrowObjectDisposedIf(_disposed, this); + + while (!buffer.IsEmpty) + { + if (_writeBuffer.AvailableLength == 0) + { + await FlushAsync(cancellationToken).ConfigureAwait(false); + } + + var bytesToCopy = Math.Min(buffer.Length, _writeBuffer.AvailableLength); + + Debug.Assert(bytesToCopy > 0); + + buffer.Slice(0, bytesToCopy).CopyTo(_writeBuffer.AvailableMemory); + + _writeBuffer.Commit(bytesToCopy); + + buffer = buffer.Slice(bytesToCopy); + } + } + + /// + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + return TaskToAsyncResult.Begin(WriteAsync(buffer, offset, count), callback, state); + } + + /// + public override void EndWrite(IAsyncResult asyncResult) + { + TaskToAsyncResult.End(asyncResult); + } + /// /// Writes the line to the shell. /// diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs index b4b071707..c36169241 100644 --- a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs @@ -236,6 +236,21 @@ public void Read_AfterDispose_StillWorks() Assert.IsNull(_shellStream.ReadLine()); } + [TestMethod] + public async Task ReadAsyncDoesNotBlockWriteAsync() + { + byte[] buffer = new byte[16]; + Task readTask = _shellStream.ReadAsync(buffer, 0, buffer.Length); + + await _shellStream.WriteAsync("ls\n"u8.ToArray(), 0, 3); + + Assert.IsFalse(readTask.IsCompleted); + + _channelSessionStub.Receive("Directory.Build.props"u8.ToArray()); + + await readTask; + } + [TestMethod] public void Read_MultiByte() {