Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#37, but also report the remote address #38

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/echo-threads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod tls_config;
use tls_config::tls_acceptor;

#[inline]
async fn handle_stream(stream: TlsStream<TcpStream>) {
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
let (mut reader, mut writer) = split(stream);
match copy(&mut reader, &mut writer).await {
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
Expand All @@ -32,8 +32,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
TlsListener::new(SpawningHandshakes(tls_acceptor()), listener)
.for_each_concurrent(None, |s| async {
match s {
Ok(stream) => {
handle_stream(stream).await;
Ok((stream, remote_addr)) => {
handle_stream(stream, remote_addr).await;
}
Err(e) => {
eprintln!("Error: {:?}", e);
Expand Down
10 changes: 7 additions & 3 deletions examples/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mod tls_config;
use tls_config::tls_acceptor;

#[inline]
async fn handle_stream(stream: TlsStream<TcpStream>) {
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
let (mut reader, mut writer) = split(stream);
match copy(&mut reader, &mut writer).await {
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
Expand All @@ -41,10 +41,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
TlsListener::new(tls_acceptor(), listener)
.for_each_concurrent(None, |s| async {
match s {
Ok(stream) => {
handle_stream(stream).await;
Ok((stream, remote_addr)) => {
handle_stream(stream, remote_addr).await;
}
Err(e) => {
if let Some(remote_addr) = e.remote_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Error accepting connection: {:?}", e);
}
}
Expand Down
8 changes: 6 additions & 2 deletions examples/http-change-certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,22 @@ async fn main() {
tokio::select! {
conn = listener.accept() => {
match conn.expect("Tls listener stream should be infinite") {
Ok(conn) => {
Ok((conn, remote_addr)) => {
let http = http.clone();
let tx = tx.clone();
let counter = counter.clone();
tokio::spawn(async move {
let svc = service_fn(move |request| handle_request(tx.clone(), counter.clone(), request));
if let Err(err) = http.serve_connection(conn, svc).await {
eprintln!("Application error: {}", err);
eprintln!("Application error (client address: {remote_addr}): {}", err);
}
});
},
Err(e) => {
if let Some(remote_addr) = e.remote_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Bad connection: {}", e);
}
}
Expand Down
8 changes: 6 additions & 2 deletions examples/http-low-level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ async fn main() {
listener
.for_each(|r| async {
match r {
Ok(conn) => {
Ok((conn, remote_addr)) => {
let http = http.clone();
tokio::spawn(async move {
if let Err(err) = http.serve_connection(conn, svc).await {
eprintln!("Application error: {}", err);
eprintln!("[client {remote_addr}] Application error: {}", err);
}
});
}
Err(err) => {
if let Some(remote_addr) = err.remote_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Error accepting connection: {}", err);
}
}
Expand Down
20 changes: 11 additions & 9 deletions examples/http-stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures_util::stream::StreamExt;
use futures_util::stream::{StreamExt, TryStreamExt};
use hyper::server::accept;
use hyper::server::conn::AddrIncoming;
use hyper::service::{make_service_fn, service_fn};
Expand All @@ -22,14 +22,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});

// This uses a filter to handle errors with connecting
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?).filter(|conn| {
if let Err(err) = conn {
eprintln!("Error: {:?}", err);
ready(false)
} else {
ready(true)
}
});
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?)
.filter(|conn| {
if let Err(err) = conn {
eprintln!("Error: {:?}", err);
ready(false)
} else {
ready(true)
}
})
.map_ok(|(conn, _remote_addr)| conn);

let server = Server::builder(accept::from_stream(incoming)).serve(new_svc);
server.await?;
Expand Down
25 changes: 19 additions & 6 deletions src/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ use std::ops::{Deref, DerefMut};
impl AsyncAccept for AddrIncoming {
type Connection = AddrStream;
type Error = std::io::Error;
type Address = std::net::SocketAddr;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
<AddrIncoming as HyperAccept>::poll_accept(self, cx)
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
<AddrIncoming as HyperAccept>::poll_accept(self, cx).map_ok(|conn| {
let remote_addr = conn.remote_addr();
(conn, remote_addr)
})
}
}

Expand All @@ -22,6 +26,11 @@ pin_project! {
/// Unfortunately, it isn't possible to use a blanket impl, due to coherence rules.
/// At least until [RFC 1210](https://rust-lang.github.io/rfcs/1210-impl-specialization.html)
/// (specialization) is stabilized.
///
/// Note that, because `hyper::server::accept::Accept` does not expose the
/// remote address, the implementation of `AsyncAccept` for `WrappedAccept`
/// doesn't expose it either. That is, [`AsyncAccept::Address`] is `()` in
/// this case.
//#[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))]
pub struct WrappedAccept<A> {
// sadly, pin-project-lite doesn't suport tuple structs :(
Expand All @@ -46,12 +55,16 @@ where
{
type Connection = A::Conn;
type Error = A::Error;
type Address = ();

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
self.project().inner.poll_accept(cx)
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
self.project()
.inner
.poll_accept(cx)
.map_ok(|conn| (conn, ()))
}
}

Expand Down Expand Up @@ -95,12 +108,12 @@ where
T: AsyncTls<A::Connection>,
{
type Conn = T::Stream;
type Error = Error<A::Error, T::Error>;
type Error = Error<A::Error, T::Error, A::Address>;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
self.poll_next(cx)
self.poll_next(cx).map_ok(|(conn, _)| conn)
}
}
123 changes: 103 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use futures_util::stream::{FuturesUnordered, Stream, StreamExt};
use pin_project_lite::pin_project;
#[cfg(feature = "rt")]
pub use spawning_handshake::SpawningHandshakes;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -67,14 +68,19 @@ pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
pub trait AsyncAccept {
/// The type of the connection that is accepted.
type Connection: AsyncRead + AsyncWrite;
/// The type of the remote address, such as [`std::net::SocketAddr`].
///
/// If no remote address can be determined (such as for mock connections),
/// `()` or a similar dummy type can be used.
type Address: Debug;
/// The type of error that may be returned.
type Error;

/// Poll to accept the next connection.
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>>;
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>>;

/// Return a new `AsyncAccept` that stops accepting connections after
/// `ender` completes.
Expand Down Expand Up @@ -126,7 +132,7 @@ pin_project! {
#[pin]
listener: A,
tls: T,
waiting: FuturesUnordered<Timeout<T::AcceptFuture>>,
waiting: FuturesUnordered<FutureWithExtraData<Timeout<T::AcceptFuture>, A::Address>>,
max_handshakes: usize,
timeout: Duration,
}
Expand All @@ -142,13 +148,36 @@ pub struct Builder<T> {

/// Wraps errors from either the listener or the TLS Acceptor
#[derive(Debug, Error)]
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
#[non_exhaustive]
// TODO: It would probably be more simple and more future-proof to use the
// `AsyncAccept` and `AsyncTls` implementations as the type parameters here, so
// that their associated types can be used in the fields
// (i.e. `error: A::Error, remote_addr: A::Address`), but that would require us
// to either hand-write `impl Debug` or use a proc-macro crate like
// `impl-tools` to derive `Debug` with custom bounds,
// due to https://github.com/rust-lang/rust/issues/26925
pub enum Error<LE: std::error::Error, TE: std::error::Error, A> {
/// An error that arose from the listener ([AsyncAccept::Error])
#[error("{0}")]
ListenerError(#[source] LE),
/// An error that occurred during the TLS accept handshake
#[error("{0}")]
TlsAcceptError(#[source] TE),
#[error("{error}")]
#[non_exhaustive]
TlsAcceptError {
/// The error that occurred.
#[source]
error: TE,

/// The client's address and port.
remote_addr: A,
},
/// The TLS handshake timed out
#[error("Timeout during TLS handshake")]
#[non_exhaustive]
HandshakeTimeout {
/// The client's address and port.
remote_addr: A,
},
}

impl<A: AsyncAccept, T> TlsListener<A, T>
Expand Down Expand Up @@ -200,17 +229,19 @@ where
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
type Item = Result<T::Stream, Error<A::Error, T::Error>>;
type Item = Result<(T::Stream, A::Address), Error<A::Error, T::Error, A::Address>>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();

while this.waiting.len() < *this.max_handshakes {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Some(Ok(conn))) => {
this.waiting
.push(timeout(*this.timeout, this.tls.accept(conn)));
Poll::Ready(Some(Ok((conn, address)))) => {
this.waiting.push(FutureWithExtraData::new(
timeout(*this.timeout, this.tls.accept(conn)),
address,
));
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(Error::ListenerError(e))));
Expand All @@ -219,16 +250,17 @@ where
}
}

loop {
return match this.waiting.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(conn))) => {
Poll::Ready(Some(conn.map_err(Error::TlsAcceptError)))
}
// The handshake timed out, try getting another connection from the
// queue
Poll::Ready(Some(Err(_))) => continue,
_ => Poll::Pending,
};
match this.waiting.poll_next_unpin(cx) {
Poll::Ready(Some((Ok(result), remote_addr))) => Poll::Ready(Some(match result {
Ok(conn) => Ok((conn, remote_addr)),
Err(error) => Err(Error::TlsAcceptError { error, remote_addr }),
})),
// The handshake timed out, try getting another connection from the
// queue
Poll::Ready(Some((Err(_), remote_addr))) => {
Poll::Ready(Some(Err(Error::HandshakeTimeout { remote_addr })))
}
_ => Poll::Pending,
}
}
}
Expand Down Expand Up @@ -334,6 +366,19 @@ impl<T> Builder<T> {
}
}

impl<LE: std::error::Error, TE: std::error::Error, A> Error<LE, TE, A> {
/// Returns the client's address and port, if known.
pub fn remote_addr(&self) -> Option<&A> {
match self {
Self::ListenerError(_) => None,

Self::TlsAcceptError { remote_addr, .. } | Self::HandshakeTimeout { remote_addr } => {
Some(remote_addr)
}
}
}
}

/// Create a new Builder for a TlsListener
///
/// `server_config` will be used to configure the TLS sessions.
Expand All @@ -358,11 +403,12 @@ pin_project! {
impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
type Connection = A::Connection;
type Error = A::Error;
type Address = A::Address;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
let this = self.project();

match this.ender.poll(cx) {
Expand All @@ -371,3 +417,40 @@ impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
}
}
}

pin_project! {
struct FutureWithExtraData<Fut, X> {
#[pin]
future: Fut,
extra: Option<X>,
}
}

impl<Fut, X> FutureWithExtraData<Fut, X> {
fn new(future: Fut, extra: X) -> Self {
Self {
future,
extra: Some(extra),
}
}
}

impl<Fut, X> Future for FutureWithExtraData<Fut, X>
where
Fut: Future,
{
type Output = (Fut::Output, X);

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let extra = this.extra;

this.future.poll(cx).map(|output| {
let extra = extra
.take()
.expect("this future has already been polled to completion");

(output, extra)
})
}
}
Loading