Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Add a new error type for handshake timeouts #37

Merged
merged 3 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ hyper-h2 = ["hyper", "hyper/http2"]
[dependencies]
futures-util = "0.3.8"
hyper = { version = "0.14.1", features = ["server", "tcp"], optional = true }
pin-project-lite = "0.2.8"
pin-project-lite = "0.2.13"
thiserror = "1.0.30"
tokio = { version = "1.0", features = ["time"] }
tokio-native-tls = { version = "0.3.0", optional = true }
Expand Down
6 changes: 3 additions & 3 deletions examples/echo-threads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod tls_config;
use tls_config::tls_acceptor;

#[inline]
async fn handle_stream(stream: TlsStream<TcpStream>) {
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
let (mut reader, mut writer) = split(stream);
match copy(&mut reader, &mut writer).await {
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
Expand All @@ -32,8 +32,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
TlsListener::new(SpawningHandshakes(tls_acceptor()), listener)
.for_each_concurrent(None, |s| async {
match s {
Ok(stream) => {
handle_stream(stream).await;
Ok((stream, remote_addr)) => {
handle_stream(stream, remote_addr).await;
}
Err(e) => {
eprintln!("Error: {:?}", e);
Expand Down
10 changes: 7 additions & 3 deletions examples/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mod tls_config;
use tls_config::tls_acceptor;

#[inline]
async fn handle_stream(stream: TlsStream<TcpStream>) {
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
let (mut reader, mut writer) = split(stream);
match copy(&mut reader, &mut writer).await {
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
Expand All @@ -41,10 +41,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
TlsListener::new(tls_acceptor(), listener)
.for_each_concurrent(None, |s| async {
match s {
Ok(stream) => {
handle_stream(stream).await;
Ok((stream, remote_addr)) => {
handle_stream(stream, remote_addr).await;
}
Err(e) => {
if let Some(remote_addr) = e.peer_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Error accepting connection: {:?}", e);
}
}
Expand Down
8 changes: 6 additions & 2 deletions examples/http-change-certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,22 @@ async fn main() {
tokio::select! {
conn = listener.accept() => {
match conn.expect("Tls listener stream should be infinite") {
Ok(conn) => {
Ok((conn, remote_addr)) => {
let http = http.clone();
let tx = tx.clone();
let counter = counter.clone();
tokio::spawn(async move {
let svc = service_fn(move |request| handle_request(tx.clone(), counter.clone(), request));
if let Err(err) = http.serve_connection(conn, svc).await {
eprintln!("Application error: {}", err);
eprintln!("Application error (client address: {remote_addr}): {err}");
}
});
},
Err(e) => {
if let Some(remote_addr) = e.peer_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Bad connection: {}", e);
}
}
Expand Down
8 changes: 6 additions & 2 deletions examples/http-low-level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ async fn main() {
listener
.for_each(|r| async {
match r {
Ok(conn) => {
Ok((conn, remote_addr)) => {
let http = http.clone();
tokio::spawn(async move {
if let Err(err) = http.serve_connection(conn, svc).await {
eprintln!("Application error: {}", err);
eprintln!("[client {remote_addr}] Application error: {}", err);
}
});
}
Err(err) => {
if let Some(remote_addr) = err.peer_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Error accepting connection: {}", err);
}
}
Expand Down
18 changes: 10 additions & 8 deletions examples/http-stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});

// This uses a filter to handle errors with connecting
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?).filter(|conn| {
if let Err(err) = conn {
eprintln!("Error: {:?}", err);
ready(false)
} else {
ready(true)
}
});
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?)
.connections()
.filter(|conn| {
if let Err(err) = conn {
eprintln!("Error: {:?}", err);
ready(false)
} else {
ready(true)
}
});

let server = Server::builder(accept::from_stream(incoming)).serve(new_svc);
server.await?;
Expand Down
3 changes: 3 additions & 0 deletions examples/tls_config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ mod config {

const CERT: &[u8] = include_bytes!("local.cert");
const PKEY: &[u8] = include_bytes!("local.key");
#[allow(dead_code)]
const CERT2: &[u8] = include_bytes!("local2.cert");
#[allow(dead_code)]
const PKEY2: &[u8] = include_bytes!("local2.key");

pub type Acceptor = tokio_rustls::TlsAcceptor;
Expand All @@ -27,6 +29,7 @@ mod config {
tls_acceptor_impl(PKEY, CERT)
}

#[allow(dead_code)]
pub fn tls_acceptor2() -> Acceptor {
tls_acceptor_impl(PKEY2, CERT2)
}
Expand Down
27 changes: 21 additions & 6 deletions src/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ use std::ops::{Deref, DerefMut};
impl AsyncAccept for AddrIncoming {
type Connection = AddrStream;
type Error = std::io::Error;
type Address = std::net::SocketAddr;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
<AddrIncoming as HyperAccept>::poll_accept(self, cx)
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
<AddrIncoming as HyperAccept>::poll_accept(self, cx).map_ok(|conn| {
let peer_addr = conn.remote_addr();
(conn, peer_addr)
})
}
}

Expand All @@ -22,6 +26,11 @@ pin_project! {
/// Unfortunately, it isn't possible to use a blanket impl, due to coherence rules.
/// At least until [RFC 1210](https://rust-lang.github.io/rfcs/1210-impl-specialization.html)
/// (specialization) is stabilized.
///
/// Note that, because `hyper::server::accept::Accept` does not expose the
/// remote address, the implementation of `AsyncAccept` for `WrappedAccept`
/// doesn't expose it either. That is, [`AsyncAccept::Address`] is `()` in
/// this case.
//#[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))]
pub struct WrappedAccept<A> {
// sadly, pin-project-lite doesn't suport tuple structs :(
Expand All @@ -43,15 +52,20 @@ pub fn wrap<A: HyperAccept>(acceptor: A) -> WrappedAccept<A> {
impl<A: HyperAccept> AsyncAccept for WrappedAccept<A>
where
A::Conn: AsyncRead + AsyncWrite,
A::Error: std::error::Error,
{
type Connection = A::Conn;
type Error = A::Error;
type Address = ();

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
self.project().inner.poll_accept(cx)
) -> Poll<Option<Result<(Self::Connection, ()), Self::Error>>> {
self.project()
.inner
.poll_accept(cx)
.map_ok(|conn| (conn, ()))
}
}

Expand All @@ -78,6 +92,7 @@ impl<A: HyperAccept> WrappedAccept<A> {
impl<A: HyperAccept, T> TlsListener<WrappedAccept<A>, T>
where
A::Conn: AsyncWrite + AsyncRead,
A::Error: std::error::Error,
T: AsyncTls<A::Conn>,
{
/// Create a `TlsListener` from a hyper [`Accept`](::hyper::server::accept::Accept) and tls
Expand All @@ -95,12 +110,12 @@ where
T: AsyncTls<A::Connection>,
{
type Conn = T::Stream;
type Error = Error<A::Error, T::Error>;
type Error = Error<A::Error, T::Error, A::Address>;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
self.poll_next(cx)
self.poll_next(cx).map_ok(|(conn, _)| conn)
}
}
Loading