Unify connection and endpoint drivers #1219

Closed
wants to merge 5 commits into from
2 changes: 2 additions & 0 deletions quinn/Cargo.toml
@@ -42,6 +42,7 @@ rustc-hash = "1.1"
pin-project-lite = "0.2"
proto = { package = "quinn-proto", path = "../quinn-proto", version = "0.9", default-features = false }
rustls = { version = "0.20.3", default-features = false, features = ["quic"], optional = true }
slab = "0.4"
thiserror = "1.0.21"
tracing = "0.1.10"
tokio = { version = "1.13.0", features = ["sync"] }
@@ -53,6 +54,7 @@ anyhow = "1.0.22"
crc = "3"
bencher = "0.1.5"
directories-next = "2"
proptest = "=1.0.0" # Pinned for MSRV
rand = "0.8"
rcgen = "0.10.0"
rustls-pemfile = "1.0.0"
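Aside: `slab` enters the dependency list above. A `Slab` is a reusable arena keyed by `usize`; a plausible use consistent with this PR's direction (an assumption, not shown in this diff) is the endpoint holding per-connection state under its handle so that handles arriving on the dirty queue can be resolved back to connections. A minimal self-contained sketch:

use slab::Slab;

// Hypothetical sketch: per-connection state stored in a Slab and looked up
// by the usize handle that arrives on the dirty queue. Names are illustrative.
struct ConnState;

fn main() {
    let mut connections: Slab<ConnState> = Slab::new();
    let handle = connections.insert(ConnState); // returns a stable usize key
    if let Some(_conn) = connections.get_mut(handle) {
        // drive this connection's I/O here
    }
    connections.remove(handle); // frees the slot for reuse by later connections
}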
243 changes: 42 additions & 201 deletions quinn/src/connection.rs
@@ -1,5 +1,6 @@
use std::{
any::Any,
collections::VecDeque,
fmt,
future::Future,
net::{IpAddr, SocketAddr},
@@ -9,21 +10,20 @@ use std::{
time::{Duration, Instant},
};

use crate::runtime::{AsyncTimer, Runtime};
use bytes::Bytes;
use pin_project_lite::pin_project;
use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
use rustc_hash::FxHashMap;
use thiserror::Error;
use tokio::sync::{futures::Notified, mpsc, oneshot, Notify};
use tracing::debug_span;
use udp::UdpState;

use crate::{
delay_queue::Timer,
mutex::Mutex,
recv_stream::RecvStream,
send_stream::{SendStream, WriteError},
ConnectionEvent, EndpointEvent, VarInt,
VarInt,
};
use proto::congestion::Controller;

@@ -38,33 +38,28 @@ pub struct Connecting {

impl Connecting {
pub(crate) fn new(
dirty: mpsc::UnboundedSender<ConnectionHandle>,
handle: ConnectionHandle,
conn: proto::Connection,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
udp_state: Arc<UdpState>,
runtime: Arc<dyn Runtime>,
) -> Connecting {
) -> (Connecting, ConnectionRef) {
let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
let (on_connected_send, on_connected_recv) = oneshot::channel();
let conn = ConnectionRef::new(
handle,
conn,
endpoint_events,
conn_events,
dirty,
on_handshake_data_send,
on_connected_send,
udp_state,
runtime.clone(),
);

runtime.spawn(Box::pin(ConnectionDriver(conn.clone())));

Connecting {
conn: Some(conn),
connected: on_connected_recv,
handshake_data_ready: Some(on_handshake_data_recv),
}
(
Connecting {
conn: Some(conn.clone()),
connected: on_connected_recv,
handshake_data_ready: Some(on_handshake_data_recv),
},
conn,
)
}
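Aside: the `on_handshake_data` / `on_connected` pair above follows the usual oneshot-signaling shape: the driver side holds the senders and fires each milestone once, while `Connecting` awaits the receivers. A minimal self-contained sketch of that shape (illustrative, not this PR's code; the `bool` merely models the flag `Connecting` receives):

use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    // Driver side holds the senders and fires each milestone exactly once.
    let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel::<()>();
    let (on_connected_send, on_connected_recv) = oneshot::channel::<bool>();

    let _ = on_handshake_data_send.send(());
    let _ = on_connected_send.send(true);

    // Connecting side awaits the receivers.
    on_handshake_data_recv.await.unwrap();
    assert!(on_connected_recv.await.unwrap());
}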

/// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened security
@@ -202,57 +197,6 @@ impl Future for ZeroRttAccepted {
}
}

/// A future that drives protocol logic for a connection
///
/// This future handles the protocol logic for a single connection, routing events from the
/// `Connection` API object to the `Endpoint` task and to the related stream interfaces.
/// It also keeps track of outstanding timeouts for the `Connection`.
///
/// If the connection encounters an error condition, this future will yield an error. It will
/// terminate (yielding `Ok(())`) if the connection was closed without error. Unlike other
/// connection-related futures, this waits for the draining period to complete to ensure that
/// packets still in flight from the peer are handled gracefully.
#[must_use = "connection drivers must be spawned for their connections to function"]
#[derive(Debug)]
struct ConnectionDriver(ConnectionRef);

impl Future for ConnectionDriver {
type Output = ();

#[allow(unused_mut)] // MSRV
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let conn = &mut *self.0.state.lock("poll");

let span = debug_span!("drive", id = conn.handle.0);
let _guard = span.enter();

if let Err(e) = conn.process_conn_events(&self.0.shared, cx) {
conn.terminate(e, &self.0.shared);
return Poll::Ready(());
}
let mut keep_going = conn.drive_transmit();
// If a timer expires, there might be more to transmit. When we transmit something, we
// might need to reset a timer. Hence, we must loop until neither happens.
keep_going |= conn.drive_timer(cx);
conn.forward_endpoint_events();
conn.forward_app_events(&self.0.shared);

if !conn.inner.is_drained() {
if keep_going {
// If the connection hasn't processed all tasks, schedule it again
cx.waker().wake_by_ref();
} else {
conn.driver = Some(cx.waker().clone());
}
return Poll::Pending;
}
if conn.error.is_none() {
unreachable!("drained connections always have an error");
}
Poll::Ready(())
}
}
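Aside: the removed driver future follows a common manual-`Future` pattern: do some work, then either self-wake (more work pending) or park the waker for an external `wake()`. A stripped-down, self-contained sketch of that pattern (illustrative, not the removed code itself):

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};

struct Driver {
    work_left: u32,        // stand-in for "transmits/timers still pending"
    parked: Option<Waker>, // corresponds to `conn.driver` above
}

impl Future for Driver {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
        if self.work_left == 0 {
            return Poll::Ready(()); // analogous to the drained-connection exit
        }
        self.work_left -= 1;
        if self.work_left > 0 {
            cx.waker().wake_by_ref(); // keep_going: reschedule immediately
        } else {
            self.parked = Some(cx.waker().clone()); // wait for an external wake
        }
        Poll::Pending
    }
}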

/// A QUIC connection.
///
/// If all references to a connection (including every clone of the `Connection` handle, streams of
@@ -745,33 +689,29 @@ impl ConnectionRef {
fn new(
handle: ConnectionHandle,
conn: proto::Connection,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
dirty: mpsc::UnboundedSender<ConnectionHandle>,
on_handshake_data: oneshot::Sender<()>,
on_connected: oneshot::Sender<bool>,
udp_state: Arc<UdpState>,
runtime: Arc<dyn Runtime>,
) -> Self {
let _ = dirty.send(handle);
Self(Arc::new(ConnectionInner {
state: Mutex::new(State {
inner: conn,
driver: None,
handle,
span: debug_span!("connection", id = handle.0),
is_dirty: true,
dirty,
on_handshake_data: Some(on_handshake_data),
on_connected: Some(on_connected),
connected: false,
timer: None,
timer_handle: None,
timer_deadline: None,
conn_events,
endpoint_events,
blocked_writers: FxHashMap::default(),
blocked_readers: FxHashMap::default(),
finishing: FxHashMap::default(),
stopped: FxHashMap::default(),
error: None,
ref_count: 0,
udp_state,
runtime,
}),
shared: Shared::default(),
}))
@@ -831,15 +771,18 @@ pub(crate) struct Shared {

pub(crate) struct State {
pub(crate) inner: proto::Connection,
driver: Option<Waker>,
handle: ConnectionHandle,
pub(crate) span: tracing::Span,
/// Whether `handle` has been sent to `dirty` since the last time this connection was driven by
/// the endpoint. Ensures `dirty`'s size stays bounded regardless of activity.
pub(crate) is_dirty: bool,
/// `handle` is sent here when `wake` is called, prompting the endpoint to drive the connection
dirty: mpsc::UnboundedSender<ConnectionHandle>,
on_handshake_data: Option<oneshot::Sender<()>>,
on_connected: Option<oneshot::Sender<bool>>,
connected: bool,
timer: Option<Pin<Box<dyn AsyncTimer>>>,
timer_deadline: Option<Instant>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
pub(crate) timer_handle: Option<Timer>,
pub(crate) timer_deadline: Option<Instant>,
pub(crate) blocked_writers: FxHashMap<StreamId, Waker>,
pub(crate) blocked_readers: FxHashMap<StreamId, Waker>,
pub(crate) finishing: FxHashMap<StreamId, oneshot::Sender<Option<WriteError>>>,
@@ -848,26 +791,23 @@ pub(crate) struct State {
pub(crate) error: Option<ConnectionError>,
/// Number of live handles that can be used to initiate or handle I/O; excludes the driver
ref_count: usize,
udp_state: Arc<UdpState>,
runtime: Arc<dyn Runtime>,
}

impl State {
fn drive_transmit(&mut self) -> bool {
pub(crate) fn drive_transmit(
&mut self,
out: &mut VecDeque<proto::Transmit>,
max_datagrams: usize,
) -> bool {
let now = Instant::now();
let mut transmits = 0;

let max_datagrams = self.udp_state.max_gso_segments();

while let Some(t) = self.inner.poll_transmit(now, max_datagrams) {
transmits += match t.segment_size {
None => 1,
Some(s) => (t.contents.len() + s - 1) / s, // round up
};
// If the endpoint driver is gone, noop.
let _ = self
.endpoint_events
.send((self.handle, EndpointEvent::Transmit(t)));
out.push_back(t);

if transmits >= MAX_TRANSMIT_DATAGRAMS {
// TODO: What isn't ideal here yet is that if we don't poll all
@@ -881,47 +821,7 @@ impl State {
false
}
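Aside: the `(t.contents.len() + s - 1) / s` expression above counts datagrams per GSO transmit by ceiling division: one `Transmit` can carry several datagrams of `segment_size` bytes. A worked, self-contained version of that arithmetic (values illustrative):

fn datagram_count(contents_len: usize, segment_size: Option<usize>) -> usize {
    match segment_size {
        None => 1,                             // a single un-segmented datagram
        Some(s) => (contents_len + s - 1) / s, // round up to whole segments
    }
}

fn main() {
    assert_eq!(datagram_count(2500, Some(1200)), 3); // 2 full segments + 100 bytes
    assert_eq!(datagram_count(2400, Some(1200)), 2); // exactly 2 segments
    assert_eq!(datagram_count(800, None), 1);
}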

fn forward_endpoint_events(&mut self) {
while let Some(event) = self.inner.poll_endpoint_events() {
// If the endpoint driver is gone, noop.
let _ = self
.endpoint_events
.send((self.handle, EndpointEvent::Proto(event)));
}
}

/// If this returns `Err`, the endpoint is dead, so the driver should exit immediately.
fn process_conn_events(
&mut self,
shared: &Shared,
cx: &mut Context,
) -> Result<(), ConnectionError> {
loop {
match self.conn_events.poll_recv(cx) {
Poll::Ready(Some(ConnectionEvent::Ping)) => {
self.inner.ping();
}
Poll::Ready(Some(ConnectionEvent::Proto(event))) => {
self.inner.handle_event(event);
}
Poll::Ready(Some(ConnectionEvent::Close { reason, error_code })) => {
self.close(error_code, reason, shared);
}
Poll::Ready(None) => {
return Err(ConnectionError::TransportError(proto::TransportError {
code: proto::TransportErrorCode::INTERNAL_ERROR,
frame: None,
reason: "endpoint driver future was dropped".to_string(),
}));
}
Poll::Pending => {
return Ok(());
}
}
}
}

fn forward_app_events(&mut self, shared: &Shared) {
pub(crate) fn forward_app_events(&mut self, shared: &Shared) {
while let Some(event) = self.inner.poll() {
use proto::Event::*;
match event {
Expand Down Expand Up @@ -987,61 +887,14 @@ impl State {
}
}

fn drive_timer(&mut self, cx: &mut Context) -> bool {
// Check whether we need to (re)set the timer. If so, we must poll again to ensure the
// timer is registered with the runtime (and check whether it's already
// expired).
match self.inner.poll_timeout() {
Some(deadline) => {
if let Some(delay) = &mut self.timer {
// There is no need to reset the tokio timer if the deadline
// did not change
if self
.timer_deadline
.map(|current_deadline| current_deadline != deadline)
.unwrap_or(true)
{
delay.as_mut().reset(deadline);
}
} else {
self.timer = Some(self.runtime.new_timer(deadline));
}
// Store the actual expiration time of the timer
self.timer_deadline = Some(deadline);
}
None => {
self.timer_deadline = None;
return false;
}
}

if self.timer_deadline.is_none() {
return false;
}

let delay = self
.timer
.as_mut()
.expect("timer must exist in this state")
.as_mut();
if delay.poll(cx).is_pending() {
// Since there wasn't a timeout event, there is nothing new
// for the connection to do
return false;
}

// A timer expired, so the caller needs to check for
// new transmits, which might cause new timers to be set.
self.inner.handle_timeout(Instant::now());
self.timer_deadline = None;
true
}
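Aside: the removed `drive_timer` re-arms the timer only when the deadline actually changes, avoiding needless re-registration with the runtime. A minimal sketch of that guard using tokio's timer (illustrative; assumes a tokio runtime rather than the crate's `AsyncTimer` abstraction):

use std::pin::Pin;
use tokio::time::{Instant, Sleep};

// Re-arm `timer` only when `deadline` differs from the last one we set.
fn maybe_reset(timer: Pin<&mut Sleep>, current: &mut Option<Instant>, deadline: Instant) {
    if *current != Some(deadline) {
        timer.reset(deadline); // touching the timer is the expensive part
        *current = Some(deadline);
    }
}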

/// Wake up a blocked `Driver` task to process I/O
/// Wake up the endpoint to process this connection's I/O by marking the connection as "dirty"
pub(crate) fn wake(&mut self) {
if let Some(x) = self.driver.take() {
x.wake();
if self.is_dirty {
return;
}
self.is_dirty = true;
// Take no action if the endpoint is gone
let _ = self.dirty.send(self.handle);
}
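Aside: this `wake` is the heart of the unification: instead of waking a per-connection task, it enqueues the connection's handle for the endpoint, and `is_dirty` guarantees each handle is queued at most once between endpoint passes, so the queue length never exceeds the connection count. A self-contained sketch of that dedup pattern (names illustrative):

use tokio::sync::mpsc;

struct Conn {
    handle: usize,
    is_dirty: bool,
    dirty: mpsc::UnboundedSender<usize>,
}

impl Conn {
    fn wake(&mut self) {
        if self.is_dirty {
            return; // already queued: queue length stays <= connection count
        }
        self.is_dirty = true;
        let _ = self.dirty.send(self.handle); // no-op if the endpoint is gone
    }
}

fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();
    let mut conn = Conn { handle: 7, is_dirty: false, dirty: tx };
    conn.wake();
    conn.wake(); // deduplicated: only one entry is queued
    assert_eq!(rx.try_recv().ok(), Some(7));
    assert!(rx.try_recv().is_err());
    conn.is_dirty = false; // endpoint clears the flag when it drives the conn
}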

/// Used to wake up all blocked futures when the connection becomes closed for any reason
@@ -1073,7 +926,7 @@ impl State {
shared.closed.notify_waiters();
}

fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) {
pub fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) {
self.inner.close(Instant::now(), error_code, reason);
self.terminate(ConnectionError::LocallyClosed, shared);
self.wake();
@@ -1096,18 +949,6 @@
}
}

impl Drop for State {
fn drop(&mut self) {
if !self.inner.is_drained() {
// Ensure the endpoint can tidy up
let _ = self.endpoint_events.send((
self.handle,
EndpointEvent::Proto(proto::EndpointEvent::drained()),
));
}
}
}

impl fmt::Debug for State {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("State").field("inner", &self.inner).finish()