Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

io: wake pending writers on DuplexStream close #3756

Merged
merged 1 commit into from May 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 21 additions & 3 deletions tokio/src/io/util/mem.rs
Expand Up @@ -16,6 +16,14 @@ use std::{
/// that can be used as in-memory IO types. Writing to one of the pairs will
/// allow that data to be read from the other, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer) will
/// return `Err(BrokenPipe)` immediately.
///
/// # Example
///
/// ```
Expand Down Expand Up @@ -134,7 +142,8 @@ impl AsyncWrite for DuplexStream {
impl Drop for DuplexStream {
fn drop(&mut self) {
// notify the other side of the closure
self.write.lock().close();
self.write.lock().close_write();
self.read.lock().close_read();
}
}

Expand All @@ -151,12 +160,21 @@ impl Pipe {
}
}

fn close(&mut self) {
fn close_write(&mut self) {
self.is_closed = true;
// needs to notify any readers that no more data will come
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
}

fn close_read(&mut self) {
self.is_closed = true;
// needs to notify any writers that they have to abort
if let Some(waker) = self.write_waker.take() {
waker.wake();
}
}
}

impl AsyncRead for Pipe {
Expand Down Expand Up @@ -217,7 +235,7 @@ impl AsyncWrite for Pipe {
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.close();
self.close_write();
Poll::Ready(Ok(()))
}
}
29 changes: 24 additions & 5 deletions tokio/tests/io_mem_stream.rs
Expand Up @@ -62,6 +62,25 @@ async fn disconnect() {
t2.await.unwrap();
}

#[tokio::test]
async fn disconnect_reader() {
let (a, mut b) = duplex(2);

let t1 = tokio::spawn(async move {
// this will block, as not all data fits into duplex
b.write_all(b"ping").await.unwrap_err();
});

let t2 = tokio::spawn(async move {
// here we drop the reader side, and we expect the writer in the other
// task to exit with an error
drop(a);
});

t2.await.unwrap();
t1.await.unwrap();
}

#[tokio::test]
async fn max_write_size() {
let (mut a, mut b) = duplex(32);
Expand All @@ -73,11 +92,11 @@ async fn max_write_size() {
assert_eq!(n, 4);
});

let t2 = tokio::spawn(async move {
let mut buf = [0u8; 4];
b.read_exact(&mut buf).await.unwrap();
});
let mut buf = [0u8; 4];
b.read_exact(&mut buf).await.unwrap();

t1.await.unwrap();
t2.await.unwrap();

// drop b only after task t1 finishes writing
drop(b);
}