Skip to content

Commit

Permalink
io: wake pending writers on DuplexStream close
Browse files Browse the repository at this point in the history
Performing a write on a closed DuplexStream (i.e. other end is dropped)
results in an Err(BrokenPipe). However, if there is a writer waiting to be
awoken from a buffer-full condition, it would previously be ignored, and
thus stuck in suspended state, as no further reads could ever be made.

Split the Pipe::close routine into close_read and close_write, and perform
both in case one side of the DuplexStream is dropped. close_read will
notify any writers to wake up, which will then cause them to see the
updated is_closed flag and return an Err(BrokenPipe) immediately.

Test case 'disconnect_reader' is added to test the fixed behaviour, it
would previously get stuck indefinitely.

The 'max_write_size' test needs to be updated with a notify barrier, as
otherwise the read side is dropped immediately after performing the
read_exact, which now resulted in the second write being awoken and failing
with a BrokenPipe error.
  • Loading branch information
PiMaker committed May 6, 2021
1 parent 177522c commit 48c98f0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
24 changes: 21 additions & 3 deletions tokio/src/io/util/mem.rs
Expand Up @@ -16,6 +16,14 @@ use std::{
/// that can be used as in-memory IO types. Writing to one of the pairs will
/// allow that data to be read from the other, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer) will
/// return `Err(BrokenPipe)` immediately.
///
/// # Example
///
/// ```
Expand Down Expand Up @@ -134,7 +142,8 @@ impl AsyncWrite for DuplexStream {
impl Drop for DuplexStream {
fn drop(&mut self) {
// notify the other side of the closure
self.write.lock().close();
self.write.lock().close_write();
self.read.lock().close_read();
}
}

Expand All @@ -151,12 +160,21 @@ impl Pipe {
}
}

fn close(&mut self) {
fn close_write(&mut self) {
self.is_closed = true;
// needs to notify any readers that no more data will come
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
}

fn close_read(&mut self) {
self.is_closed = true;
// needs to notify any writers that they have to abort
if let Some(waker) = self.write_waker.take() {
waker.wake();
}
}
}

impl AsyncRead for Pipe {
Expand Down Expand Up @@ -217,7 +235,7 @@ impl AsyncWrite for Pipe {
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.close();
self.close_write();
Poll::Ready(Ok(()))
}
}
29 changes: 29 additions & 0 deletions tokio/tests/io_mem_stream.rs
@@ -1,7 +1,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::sync::Notify;

#[tokio::test]
async fn ping_pong() {
Expand Down Expand Up @@ -62,20 +64,47 @@ async fn disconnect() {
t2.await.unwrap();
}

#[tokio::test]
async fn disconnect_reader() {
let (a, mut b) = duplex(2);

let t1 = tokio::spawn(async move {
// this will block, as not all data fits into duplex
b.write_all(b"ping").await.unwrap_err();
});

let t2 = tokio::spawn(async move {
// here we drop the reader side, and we expect the writer in the other
// task to exit with an error
let _moved = a;
});

// wait for drop first
t2.await.unwrap();
// then try and join writer task
t1.await.unwrap();
}

#[tokio::test]
async fn max_write_size() {
let (mut a, mut b) = duplex(32);

// needs a barrier to avoid droping b before we can perform the second write
let r_barrier = Arc::new(Notify::new());
let w_barrier = r_barrier.clone();

let t1 = tokio::spawn(async move {
let n = a.write(&[0u8; 64]).await.unwrap();
assert_eq!(n, 32);
let n = a.write(&[0u8; 64]).await.unwrap();
assert_eq!(n, 4);
w_barrier.notify_waiters();
});

let t2 = tokio::spawn(async move {
let mut buf = [0u8; 4];
b.read_exact(&mut buf).await.unwrap();
r_barrier.notified().await;
});

t1.await.unwrap();
Expand Down

0 comments on commit 48c98f0

Please sign in to comment.