diff --git a/bench/src/lib.rs b/bench/src/lib.rs index d30ac235e..081ec239b 100644 --- a/bench/src/lib.rs +++ b/bench/src/lib.rs @@ -138,7 +138,9 @@ pub async fn send_data_on_stream(stream: &mut quinn::SendStream, stream_size: u6 .context("failed sending data")?; } - stream.finish().await.context("failed finishing stream")?; + stream.finish().unwrap(); + // Wait for stream to close + _ = stream.stopped().await; Ok(()) } diff --git a/perf/src/bin/perf_client.rs b/perf/src/bin/perf_client.rs index 474d815ec..d081ae179 100644 --- a/perf/src/bin/perf_client.rs +++ b/perf/src/bin/perf_client.rs @@ -302,7 +302,7 @@ async fn request( let upload_start = Instant::now(); send.write_all(&download.to_be_bytes()).await?; if upload == 0 { - send.finish().await?; + send.finish().unwrap(); return Ok(()); } @@ -317,7 +317,9 @@ async fn request( send_stream_stats.on_bytes(chunk_len as usize); upload -= chunk_len; } - send.finish().await?; + send.finish().unwrap(); + // Wait for stream to close + _ = send.stopped().await; send_stream_stats.finish(upload_start.elapsed()); debug!("upload finished on {}", send.id()); diff --git a/quinn/benches/bench.rs b/quinn/benches/bench.rs index cabf33524..2a37ddc1e 100644 --- a/quinn/benches/bench.rs +++ b/quinn/benches/bench.rs @@ -54,7 +54,9 @@ fn send_data(bench: &mut Bencher, data: &'static [u8], concurrent_streams: usize handles.push(runtime.spawn(async move { let mut stream = client.open_uni().await.unwrap(); stream.write_all(data).await.unwrap(); - stream.finish().await.unwrap(); + stream.finish().unwrap(); + // Wait for stream to close + _ = stream.stopped().await; })); } diff --git a/quinn/examples/client.rs b/quinn/examples/client.rs index 729b0c9df..c037833a7 100644 --- a/quinn/examples/client.rs +++ b/quinn/examples/client.rs @@ -131,9 +131,7 @@ async fn run(options: Opt) -> Result<()> { send.write_all(request.as_bytes()) .await .map_err(|e| anyhow!("failed to send request: {}", e))?; - send.finish() - .await - .map_err(|e| anyhow!("failed to shutdown stream: {}", e))?; + send.finish().unwrap(); let response_start = Instant::now(); eprintln!("request sent at {:?}", response_start - start); let resp = recv diff --git a/quinn/examples/server.rs b/quinn/examples/server.rs index 82aaea9de..5f9e6f35d 100644 --- a/quinn/examples/server.rs +++ b/quinn/examples/server.rs @@ -232,9 +232,7 @@ async fn handle_request( .await .map_err(|e| anyhow!("failed to send response: {}", e))?; // Gracefully terminate the stream - send.finish() - .await - .map_err(|e| anyhow!("failed to shutdown stream: {}", e))?; + send.finish().unwrap(); info!("complete"); Ok(()) } diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 3df0159b2..89c1733c5 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -21,7 +21,7 @@ use crate::{ mutex::Mutex, recv_stream::RecvStream, runtime::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller}, - send_stream::{SendStream, WriteError}, + send_stream::SendStream, udp_transmit, ConnectionEvent, VarInt, }; use proto::{ @@ -856,7 +856,6 @@ impl ConnectionRef { endpoint_events, blocked_writers: FxHashMap::default(), blocked_readers: FxHashMap::default(), - finishing: FxHashMap::default(), stopped: FxHashMap::default(), error: None, ref_count: 0, @@ -936,7 +935,6 @@ pub(crate) struct State { endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, pub(crate) blocked_writers: FxHashMap, pub(crate) blocked_readers: FxHashMap, - pub(crate) finishing: FxHashMap>>, pub(crate) stopped: FxHashMap, /// Always set to Some before the connection becomes drained pub(crate) error: Option, @@ -1102,21 +1100,14 @@ impl State { shared.stream_budget_available[dir as usize].notify_waiters(); } Stream(StreamEvent::Finished { id }) => { - if let Some(finishing) = self.finishing.remove(&id) { - // If the finishing stream was already dropped, there's nothing more to do. - let _ = finishing.send(None); - } if let Some(stopped) = self.stopped.remove(&id) { stopped.wake(); } } - Stream(StreamEvent::Stopped { id, error_code }) => { + Stream(StreamEvent::Stopped { id, .. }) => { if let Some(stopped) = self.stopped.remove(&id) { stopped.wake(); } - if let Some(finishing) = self.finishing.remove(&id) { - let _ = finishing.send(Some(WriteError::Stopped(error_code))); - } if let Some(writer) = self.blocked_writers.remove(&id) { writer.wake(); } @@ -1200,9 +1191,6 @@ impl State { shared.stream_incoming[Dir::Bi as usize].notify_waiters(); shared.datagram_received.notify_waiters(); shared.datagrams_unblocked.notify_waiters(); - for (_, x) in self.finishing.drain() { - let _ = x.send(Some(WriteError::ConnectionLost(reason.clone()))); - } if let Some(x) = self.on_connected.take() { let _ = x.send(false); } @@ -1285,8 +1273,20 @@ pub struct UnknownStream { _private: (), } +impl UnknownStream { + pub(crate) fn new() -> Self { + Self { _private: () } + } +} + impl From for UnknownStream { fn from(_: proto::UnknownStream) -> Self { Self { _private: () } } } + +impl From for io::Error { + fn from(x: UnknownStream) -> Self { + Self::new(io::ErrorKind::NotConnected, x) + } +} diff --git a/quinn/src/recv_stream.rs b/quinn/src/recv_stream.rs index bd9ff84b7..ace12035d 100644 --- a/quinn/src/recv_stream.rs +++ b/quinn/src/recv_stream.rs @@ -45,50 +45,6 @@ use crate::{ /// bidirectional stream 1, the first stream yielded by [`Connection::accept_bi`] on the receiver /// will be bidirectional stream 0. /// -/// ## Unexpected [`WriteError::Stopped`] in sender -/// -/// When a stream is expected to be closed gracefully the sender should call -/// [`SendStream::finish`]. However there is no guarantee the connected [`RecvStream`] will -/// receive the "finished" notification in the same QUIC frame as the last frame which -/// carried data. -/// -/// Even if the application layer logic already knows it read all the data because it does -/// its own framing, it should still read until it reaches the end of the [`RecvStream`]. -/// Otherwise it risks inadvertently calling [`RecvStream::stop`] if it drops the stream. -/// And calling [`RecvStream::stop`] could result in the connected [`SendStream::finish`] -/// call failing with a [`WriteError::Stopped`] error. -/// -/// For example if exactly 10 bytes are to be read, you still need to explicitly read the -/// end of the stream: -/// -/// ```no_run -/// # use quinn::{SendStream, RecvStream}; -/// # async fn func( -/// # mut send_stream: SendStream, -/// # mut recv_stream: RecvStream, -/// # ) -> anyhow::Result<()> -/// # { -/// // In the sending task -/// send_stream.write(&b"0123456789"[..]).await?; -/// send_stream.finish().await?; -/// -/// // In the receiving task -/// let mut buf = [0u8; 10]; -/// let data = recv_stream.read_exact(&mut buf).await?; -/// if recv_stream.read_to_end(0).await.is_err() { -/// // Discard unexpected data and notify the peer to stop sending it -/// let _ = recv_stream.stop(0u8.into()); -/// } -/// # Ok(()) -/// # } -/// ``` -/// -/// An alternative approach, used in HTTP/3, is to specify a particular error code used with `stop` -/// that indicates graceful receiver-initiated stream shutdown, rather than a true error condition. -/// -/// [`RecvStream::read_chunk`] could be used instead which does not take ownership and -/// allows using an explicit call to [`RecvStream::stop`] with a custom error code. -/// /// [`ReadError`]: crate::ReadError /// [`stop()`]: RecvStream::stop /// [`SendStream::finish`]: crate::SendStream::finish diff --git a/quinn/src/send_stream.rs b/quinn/src/send_stream.rs index b953678c0..96a200fbe 100644 --- a/quinn/src/send_stream.rs +++ b/quinn/src/send_stream.rs @@ -8,7 +8,6 @@ use std::{ use bytes::Bytes; use proto::{ConnectionError, FinishError, StreamId, Written}; use thiserror::Error; -use tokio::sync::oneshot; use crate::{ connection::{ConnectionRef, UnknownStream}, @@ -17,8 +16,9 @@ use crate::{ /// A stream that can only be used to send data /// -/// If dropped, streams that haven't been explicitly [`reset()`] will continue to (re)transmit -/// previously written data until it has been fully acknowledged or the connection is closed. +/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed, +/// continuing to (re)transmit previously written data until it has been fully acknowledged or the +/// connection is closed. /// /// # Cancellation /// @@ -29,12 +29,12 @@ use crate::{ /// cancel-safe. /// /// [`reset()`]: SendStream::reset +/// [`finish()`]: SendStream::finish #[derive(Debug)] pub struct SendStream { conn: ConnectionRef, stream: StreamId, is_0rtt: bool, - finishing: Option>>, } impl SendStream { @@ -43,7 +43,6 @@ impl SendStream { conn, stream, is_0rtt, - finishing: None, } } @@ -130,56 +129,28 @@ impl SendStream { Poll::Ready(Ok(result)) } - /// Shut down the send stream gracefully. + /// Notify the peer that no more data will ever be written to this stream /// - /// No new data may be written after calling this method. Completes when the peer has - /// acknowledged all sent data, retransmitting data as needed. - pub async fn finish(&mut self) -> Result<(), WriteError> { - Finish { stream: self }.await - } - - /// Attempt to shut down the send stream gracefully. + /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset) + /// may still be called after `finish` to abandon transmission of any stream data that might + /// still be buffered. /// - /// No new data may be written after calling this method. Completes when the peer has - /// acknowledged all sent data, retransmitting data as needed. - pub fn poll_finish(&mut self, cx: &mut Context) -> Poll> { - let mut conn = self.conn.state.lock("poll_finish"); - if self.is_0rtt { - conn.check_0rtt() - .map_err(|()| WriteError::ZeroRttRejected)?; - } - if self.finishing.is_none() { - conn.inner - .send_stream(self.stream) - .finish() - .map_err(|e| match e { - FinishError::UnknownStream => WriteError::UnknownStream, - FinishError::Stopped(error_code) => WriteError::Stopped(error_code), - })?; - let (send, recv) = oneshot::channel(); - self.finishing = Some(recv); - conn.finishing.insert(self.stream, send); - conn.wake(); - } - match Pin::new(self.finishing.as_mut().unwrap()) - .poll(cx) - .map(|x| x.unwrap()) - { - Poll::Ready(x) => { - self.finishing = None; - Poll::Ready(x.map_or(Ok(()), Err)) - } - Poll::Pending => { - // To ensure that finished streams can be detected even after the connection is - // closed, we must only check for connection errors after determining that the - // stream has not yet been finished. Note that this relies on holding the connection - // lock so that it is impossible for the stream to become finished between the above - // poll call and this check. - if let Some(ref x) = conn.error { - return Poll::Ready(Err(WriteError::ConnectionLost(x.clone()))); - } - Poll::Pending + /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped). + /// + /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously + /// called. This error is harmless and serves only to indicate that the caller may have + /// incorrect assumptions about the stream's state. + pub fn finish(&mut self) -> Result<(), UnknownStream> { + let mut conn = self.conn.state.lock("finish"); + match conn.inner.send_stream(self.stream).finish() { + Ok(()) => { + conn.wake(); + Ok(()) } + Err(FinishError::UnknownStream) => Err(UnknownStream::new()), + // Harmless. If the application needs to know about stopped streams at this point, it + // should call `stopped`. + Err(FinishError::Stopped(_)) => Ok(()), } } @@ -188,6 +159,10 @@ impl SendStream { /// No new data can be written after calling this method. Locally buffered data is dropped, and /// previously transmitted data will no longer be retransmitted if lost. If an attempt has /// already been made to finish the stream, the peer may still receive all written data. + /// + /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously + /// called. This error is harmless and serves only to indicate that the caller may have + /// incorrect assumptions about the stream's state. pub fn reset(&mut self, error_code: VarInt) -> Result<(), UnknownStream> { let mut conn = self.conn.state.lock("SendStream::reset"); if self.is_0rtt && conn.check_0rtt().is_err() { @@ -272,8 +247,8 @@ impl futures_io::AsyncWrite for SendStream { Poll::Ready(Ok(())) } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.get_mut().poll_finish(cx).map_err(Into::into) + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + Poll::Ready(self.get_mut().finish().map_err(Into::into)) } } @@ -291,8 +266,8 @@ impl tokio::io::AsyncWrite for SendStream { Poll::Ready(Ok(())) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.get_mut().poll_finish(cx).map_err(Into::into) + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + Poll::Ready(self.get_mut().finish().map_err(Into::into)) } } @@ -301,42 +276,25 @@ impl Drop for SendStream { let mut conn = self.conn.state.lock("SendStream::drop"); // clean up any previously registered wakers - conn.finishing.remove(&self.stream); conn.stopped.remove(&self.stream); conn.blocked_writers.remove(&self.stream); if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) { return; } - if self.finishing.is_none() { - match conn.inner.send_stream(self.stream).finish() { - Ok(()) => conn.wake(), - Err(FinishError::Stopped(reason)) => { - if conn.inner.send_stream(self.stream).reset(reason).is_ok() { - conn.wake(); - } + match conn.inner.send_stream(self.stream).finish() { + Ok(()) => conn.wake(), + Err(FinishError::Stopped(reason)) => { + if conn.inner.send_stream(self.stream).reset(reason).is_ok() { + conn.wake(); } - // Already finished or reset, which is fine. - Err(FinishError::UnknownStream) => {} } + // Already finished or reset, which is fine. + Err(FinishError::UnknownStream) => {} } } } -/// Future produced by `SendStream::finish` -#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] -struct Finish<'a> { - stream: &'a mut SendStream, -} - -impl Future for Finish<'_> { - type Output = Result<(), WriteError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - self.get_mut().stream.poll_finish(cx) - } -} - /// Future produced by `SendStream::stopped` #[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] struct Stopped<'a> { @@ -483,6 +441,22 @@ pub enum WriteError { ZeroRttRejected, } +impl From for WriteError { + #[inline] + fn from(_: UnknownStream) -> Self { + Self::UnknownStream + } +} + +impl From for WriteError { + fn from(x: StoppedError) -> Self { + match x { + StoppedError::ConnectionLost(e) => Self::ConnectionLost(e), + StoppedError::ZeroRttRejected => Self::ZeroRttRejected, + } + } +} + /// Errors that arise while monitoring for a send stream stop from the peer #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum StoppedError { diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index a268d5599..ee242e7f8 100755 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -145,7 +145,9 @@ fn read_after_close() { .expect("connection"); let mut s = new_conn.open_uni().await.unwrap(); s.write_all(MSG).await.unwrap(); - s.finish().await.unwrap(); + s.finish().unwrap(); + // Wait for the stream to be closed, one way or another. + _ = s.stopped().await; }); runtime.block_on(async move { let new_conn = endpoint @@ -317,13 +319,13 @@ async fn zero_rtt() { info!("sending 0.5-RTT"); let mut s = connection.open_uni().await.expect("open_uni"); s.write_all(MSG0).await.expect("write"); - s.finish().await.expect("finish"); + s.finish().unwrap(); established.await; info!("sending 1-RTT"); let mut s = connection.open_uni().await.expect("open_uni"); s.write_all(MSG1).await.expect("write"); // The peer might close the connection before ACKing - let _ = s.finish().await; + let _ = s.finish(); } }); @@ -367,7 +369,7 @@ async fn zero_rtt() { let mut s = c.open_uni().await.expect("0-RTT open uni"); info!("sending 0-RTT"); s.write_all(MSG0).await.expect("0-RTT write"); - s.finish().await.expect("0-RTT finish"); + s.finish().unwrap(); }); let mut stream = connection.accept_uni().await.expect("incoming streams"); @@ -567,7 +569,7 @@ fn run_echo(args: EchoArgs) { let send_task = async { send.write_all(&msg).await.expect("write"); - send.finish().await.expect("finish"); + send.finish().unwrap(); }; let recv_task = async { recv.read_to_end(usize::max_value()).await.expect("read") }; @@ -620,7 +622,7 @@ async fn echo((mut send, mut recv): (SendStream, RecvStream)) { } } - let _ = send.finish().await; + let _ = send.finish(); } fn gen_data(size: usize, seed: u64) -> Vec { @@ -706,7 +708,9 @@ async fn rebind_recv() { write_recv.notified().await; let mut stream = connection.open_uni().await.unwrap(); stream.write_all(MSG).await.unwrap(); - stream.finish().await.unwrap(); + stream.finish().unwrap(); + // Wait for the stream to be closed, one way or another. + _ = stream.stopped().await; }); let connection = { diff --git a/quinn/tests/many_connections.rs b/quinn/tests/many_connections.rs index ff7da42d2..7afae35ff 100644 --- a/quinn/tests/many_connections.rs +++ b/quinn/tests/many_connections.rs @@ -6,7 +6,7 @@ use std::{ }; use crc::Crc; -use quinn::{ConnectionError, ReadError, TransportConfig, WriteError}; +use quinn::{ConnectionError, ReadError, StoppedError, TransportConfig, WriteError}; use rand::{self, RngCore}; use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; use tokio::runtime::Builder; @@ -117,11 +117,12 @@ async fn read_from_peer(mut stream: quinn::RecvStream) -> Result<(), quinn::Conn async fn write_to_peer(conn: quinn::Connection, data: Vec) -> Result<(), WriteError> { let mut s = conn.open_uni().await.map_err(WriteError::ConnectionLost)?; s.write_all(&data).await?; - // Suppress finish errors, since the peer may close before ACKing - match s.finish().await { - Ok(()) => Ok(()), - Err(WriteError::ConnectionLost(ConnectionError::ApplicationClosed { .. })) => Ok(()), - Err(e) => Err(e), + s.finish().unwrap(); + // Wait for the stream to be fully received + match s.stopped().await { + Ok(_) => Ok(()), + Err(StoppedError::ConnectionLost(ConnectionError::ApplicationClosed { .. })) => Ok(()), + Err(e) => Err(e.into()), } }