Skip to content

Commit

Permalink
Merge pull request from GHSA-2qph-qpvm-2qf7
Browse files Browse the repository at this point in the history
Continue accepting incoming connections if no TLS connection is ready.
  • Loading branch information
tmccombs committed Mar 15, 2024
2 parents 6c57dea + d1769ec commit d5a7655
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "tls-listener"
description = "wrap incoming Stream of connections in TLS"
version = "0.9.1"
version = "0.10.0"
authors = ["Thayne McCombs <astrothayne@gmail.com>"]
repository = "https://github.com/tmccombs/tls-listener"
edition = "2018"
Expand Down
3 changes: 2 additions & 1 deletion examples/http-change-certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use hyper::service::service_fn;
use hyper::{body::Body, Request, Response};
use hyper_util::rt::tokio::TokioIo;
use std::convert::Infallible;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::net::TcpListener;
Expand All @@ -22,7 +23,7 @@ async fn main() {
let counter = Arc::new(AtomicU64::new(0));

let mut listener = tls_listener::builder(tls_acceptor())
.max_handshakes(10)
.accept_batch_size(NonZeroUsize::new(10).unwrap())
.listen(TcpListener::bind(addr).await.expect("Failed to bind port"));

let (tx, mut rx) = mpsc::channel::<Acceptor>(1);
Expand Down
79 changes: 49 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use pin_project_lite::pin_project;
pub use spawning_handshake::SpawningHandshakes;
use std::fmt::Debug;
use std::future::{poll_fn, Future};
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::time::Duration;
Expand All @@ -38,8 +39,8 @@ mod spawning_handshake;
#[cfg(feature = "tokio-net")]
mod net;

/// Default number of concurrent handshakes
pub const DEFAULT_MAX_HANDSHAKES: usize = 64;
/// Default number of connections to accept in a batch before trying to
pub const DEFAULT_ACCEPT_BATCH_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(64) };
/// Default timeout for the TLS handshake.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

Expand Down Expand Up @@ -112,7 +113,7 @@ pin_project! {
listener: A,
tls: T,
waiting: FuturesUnordered<Waiting<A, T>>,
max_handshakes: usize,
accept_batch_size: NonZeroUsize,
timeout: Duration,
}
}
Expand All @@ -121,7 +122,7 @@ pin_project! {
#[derive(Clone)]
pub struct Builder<T> {
tls: T,
max_handshakes: usize,
accept_batch_size: NonZeroUsize,
handshake_timeout: Duration,
}

Expand Down Expand Up @@ -182,26 +183,36 @@ where
pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
let mut this = self.project();

while this.waiting.len() < *this.max_handshakes {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Ok((conn, addr))) => {
this.waiting.push(Waiting {
inner: timeout(*this.timeout, this.tls.accept(conn)),
peer_addr: Some(addr),
});
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(Error::ListenerError(e)));
loop {
let mut empty_listener = false;
for _ in 0..this.accept_batch_size.get() {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => {
empty_listener = true;
break;
}
Poll::Ready(Ok((conn, addr))) => {
this.waiting.push(Waiting {
inner: timeout(*this.timeout, this.tls.accept(conn)),
peer_addr: Some(addr),
});
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(Error::ListenerError(e)));
}
}
}
}

match this.waiting.poll_next_unpin(cx) {
Poll::Ready(Some(result)) => Poll::Ready(result),
// If we don't have anything waiting yet,
// then we are still pending,
Poll::Ready(None) | Poll::Pending => Poll::Pending,
match this.waiting.poll_next_unpin(cx) {
Poll::Ready(Some(result)) => return Poll::Ready(result),
// If we don't have anything waiting yet,
// then we are still pending,
Poll::Ready(None) | Poll::Pending => {
if empty_listener {
return Poll::Pending;
}
}
}
}
}

Expand Down Expand Up @@ -318,15 +329,19 @@ where
}

impl<T> Builder<T> {
/// Set the maximum number of concurrent handshakes.
/// Set the size of batches of incoming connections to accept at once
///
/// When polling for a new connection, the `TlsListener` will first check
/// for incomming connections on the listener that need to start a TLS handshake.
/// This specifies the maximum number of connections it will accept before seeing if any
/// TLS connections are ready.
///
/// At most `max` handshakes will be concurrently processed. If that limit is
/// reached, the `TlsListener` will stop polling the underlying listener until a
/// handshake completes and the encrypted stream has been returned.
/// Having a limit for this ensures that ready TLS conections aren't starved if there are a
/// large number of incoming connections.
///
/// Defaults to `DEFAULT_MAX_HANDSHAKES`.
pub fn max_handshakes(&mut self, max: usize) -> &mut Self {
self.max_handshakes = max;
/// Defaults to `DEFAULT_ACCEPT_BATCH_SIZE`.
pub fn accept_batch_size(&mut self, size: NonZeroUsize) -> &mut Self {
self.accept_batch_size = size;
self
}

Expand All @@ -335,6 +350,10 @@ impl<T> Builder<T> {
/// If a timeout takes longer than `timeout`, then the handshake will be
/// aborted and the underlying connection will be dropped.
///
/// The default is fairly conservative, to avoid dropping connections. It is
/// recommended that you adjust this to meet the specific needs of your use case
/// in production deployments.
///
/// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
self.handshake_timeout = timeout;
Expand All @@ -354,7 +373,7 @@ impl<T> Builder<T> {
listener,
tls: self.tls.clone(),
waiting: FuturesUnordered::new(),
max_handshakes: self.max_handshakes,
accept_batch_size: self.accept_batch_size,
timeout: self.handshake_timeout,
}
}
Expand Down Expand Up @@ -382,7 +401,7 @@ impl<LE: std::error::Error, TE: std::error::Error, A> Error<LE, TE, A> {
pub fn builder<T>(tls: T) -> Builder<T> {
Builder {
tls,
max_handshakes: DEFAULT_MAX_HANDSHAKES,
accept_batch_size: DEFAULT_ACCEPT_BATCH_SIZE,
handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
}
}
Expand Down

0 comments on commit d5a7655

Please sign in to comment.