From 57a8a01e4ad3a5df9c7449a37491d5c947f577b8 Mon Sep 17 00:00:00 2001 From: Miguel Guarniz Date: Thu, 16 Mar 2023 16:31:47 -0400 Subject: [PATCH] Add Experimental HTTP/3 Support (#1599) This adds experimental HTTP/3 support, using `h3` and `h3-quinn` (though the internals are not to be depended on). --- .gitignore | 1 + Cargo.toml | 10 + examples/h3_simple.rs | 51 +++++ src/async_impl/client.rs | 296 ++++++++++++++++++++++++---- src/async_impl/h3_client/connect.rs | 87 ++++++++ src/async_impl/h3_client/dns.rs | 43 ++++ src/async_impl/h3_client/mod.rs | 88 +++++++++ src/async_impl/h3_client/pool.rs | 198 +++++++++++++++++++ src/async_impl/mod.rs | 1 + src/tls.rs | 9 +- 10 files changed, 742 insertions(+), 42 deletions(-) create mode 100644 examples/h3_simple.rs create mode 100644 src/async_impl/h3_client/connect.rs create mode 100644 src/async_impl/h3_client/dns.rs create mode 100644 src/async_impl/h3_client/mod.rs create mode 100644 src/async_impl/h3_client/pool.rs diff --git a/.gitignore b/.gitignore index d4f917d3d..a57891807 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ target Cargo.lock *.swp +.idea \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 6dd3f1639..a82c7c403 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,9 @@ stream = ["tokio/fs", "tokio-util", "wasm-streams"] socks = ["tokio-socks"] +# Experimental HTTP/3 client. +http3 = ["rustls-tls", "h3", "h3-quinn", "quinn", "futures-channel"] + # Internal (PRIVATE!) features used to aid testing. # Don't rely on these whatsoever. They may disappear at anytime. @@ -135,6 +138,13 @@ tokio-socks = { version = "0.5.1", optional = true } ## trust-dns trust-dns-resolver = { version = "0.22", optional = true } +# HTTP/3 experimental support +h3 = { version="0.0.1", optional = true } +h3-quinn = { version="0.0.1", optional = true } +quinn = { version = "0.8", default-features = false, features = ["tls-rustls", "ring"], optional = true } +futures-channel = { version="0.3", optional = true} + + [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.8" hyper = { version = "0.14", default-features = false, features = ["tcp", "stream", "http1", "http2", "client", "server", "runtime"] } diff --git a/examples/h3_simple.rs b/examples/h3_simple.rs new file mode 100644 index 000000000..ba9667f32 --- /dev/null +++ b/examples/h3_simple.rs @@ -0,0 +1,51 @@ +#![deny(warnings)] + +// This is using the `tokio` runtime. You'll need the following dependency: +// +// `tokio = { version = "1", features = ["full"] }` +#[cfg(feature = "http3")] +#[cfg(not(target_arch = "wasm32"))] +#[tokio::main] +async fn main() -> Result<(), reqwest::Error> { + use http::Version; + use reqwest::{Client, IntoUrl, Response}; + + async fn get(url: T) -> reqwest::Result { + Client::builder() + .http3_prior_knowledge() + .build()? + .get(url) + .version(Version::HTTP_3) + .send() + .await + } + + // Some simple CLI args requirements... + let url = match std::env::args().nth(1) { + Some(url) => url, + None => { + println!("No CLI URL provided, using default."); + "https://hyper.rs".into() + } + }; + + eprintln!("Fetching {:?}...", url); + + let res = get(url).await?; + + eprintln!("Response: {:?} {}", res.version(), res.status()); + eprintln!("Headers: {:#?}\n", res.headers()); + + let body = res.text().await?; + + println!("{}", body); + + Ok(()) +} + +// The [cfg(not(target_arch = "wasm32"))] above prevent building the tokio::main function +// for wasm32 target, because tokio isn't compatible with wasm32. +// If you aren't building for wasm32, you don't need that line. +// The two lines below avoid the "'main' function not found" error when building for wasm32 target. +#[cfg(any(target_arch = "wasm32", not(feature = "http3")))] +fn main() {} diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index b56c7fee7..559e5f365 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -13,7 +13,7 @@ use http::header::{ }; use http::uri::Scheme; use http::Uri; -use hyper::client::{HttpConnector, ResponseFuture}; +use hyper::client::{HttpConnector, ResponseFuture as HyperResponseFuture}; #[cfg(feature = "native-tls-crate")] use native_tls_crate::TlsConnector; use pin_project_lite::pin_project; @@ -22,12 +22,14 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio::time::Sleep; -use log::{debug, trace}; - use super::decoder::Accepts; use super::request::{Request, RequestBuilder}; use super::response::Response; use super::Body; +#[cfg(feature = "http3")] +use crate::async_impl::h3_client::connect::H3Connector; +#[cfg(feature = "http3")] +use crate::async_impl::h3_client::{H3Client, H3ResponseFuture}; use crate::connect::Connector; #[cfg(feature = "cookies")] use crate::cookie; @@ -44,6 +46,11 @@ use crate::Certificate; #[cfg(any(feature = "native-tls", feature = "__rustls"))] use crate::Identity; use crate::{IntoUrl, Method, Proxy, StatusCode, Url}; +use log::{debug, trace}; +#[cfg(feature = "http3")] +use quinn::TransportConfig; +#[cfg(feature = "http3")] +use quinn::VarInt; /// An asynchronous `Client` to make Requests with. /// @@ -72,6 +79,8 @@ pub struct ClientBuilder { enum HttpVersionPref { Http1, Http2, + #[cfg(feature = "http3")] + Http3, All, } @@ -125,6 +134,16 @@ struct Config { trust_dns: bool, error: Option, https_only: bool, + #[cfg(feature = "http3")] + tls_enable_early_data: bool, + #[cfg(feature = "http3")] + quic_max_idle_timeout: Option, + #[cfg(feature = "http3")] + quic_stream_receive_window: Option, + #[cfg(feature = "http3")] + quic_receive_window: Option, + #[cfg(feature = "http3")] + quic_send_window: Option, dns_overrides: HashMap>, dns_resolver: Option>, } @@ -196,6 +215,16 @@ impl ClientBuilder { cookie_store: None, https_only: false, dns_overrides: HashMap::new(), + #[cfg(feature = "http3")] + tls_enable_early_data: false, + #[cfg(feature = "http3")] + quic_max_idle_timeout: None, + #[cfg(feature = "http3")] + quic_stream_receive_window: None, + #[cfg(feature = "http3")] + quic_receive_window: None, + #[cfg(feature = "http3")] + quic_send_window: None, dns_resolver: None, }, } @@ -220,6 +249,10 @@ impl ClientBuilder { } let proxies = Arc::new(proxies); + #[allow(unused)] + #[cfg(feature = "http3")] + let mut h3_connector = None; + let mut connector = { #[cfg(feature = "__tls")] fn user_agent(headers: &HeaderMap) -> Option { @@ -242,7 +275,7 @@ impl ClientBuilder { config.dns_overrides, )); } - let http = HttpConnector::new_with_resolver(DynResolver::new(resolver)); + let http = HttpConnector::new_with_resolver(DynResolver::new(resolver.clone())); #[cfg(feature = "__tls")] match config.tls { @@ -250,7 +283,7 @@ impl ClientBuilder { TlsBackend::Default => { let mut tls = TlsConnector::builder(); - #[cfg(feature = "native-tls-alpn")] + #[cfg(all(feature = "native-tls-alpn", not(feature = "http3")))] { match config.http_version_pref { HttpVersionPref::Http1 => { @@ -445,11 +478,47 @@ impl ClientBuilder { HttpVersionPref::Http2 => { tls.alpn_protocols = vec!["h2".into()]; } + #[cfg(feature = "http3")] + HttpVersionPref::Http3 => { + tls.alpn_protocols = vec!["h3".into()]; + } HttpVersionPref::All => { tls.alpn_protocols = vec!["h2".into(), "http/1.1".into()]; } } + #[cfg(feature = "http3")] + { + tls.enable_early_data = config.tls_enable_early_data; + + let mut transport_config = TransportConfig::default(); + + if let Some(max_idle_timeout) = config.quic_max_idle_timeout { + transport_config.max_idle_timeout(Some( + max_idle_timeout.try_into().map_err(error::builder)?, + )); + } + + if let Some(stream_receive_window) = config.quic_stream_receive_window { + transport_config.stream_receive_window(stream_receive_window); + } + + if let Some(receive_window) = config.quic_receive_window { + transport_config.receive_window(receive_window); + } + + if let Some(send_window) = config.quic_send_window { + transport_config.send_window(send_window); + } + + h3_connector = Some(H3Connector::new( + DynResolver::new(resolver), + tls.clone(), + config.local_address, + transport_config, + )); + } + Connector::new_rustls_tls( http, tls, @@ -519,8 +588,6 @@ impl ClientBuilder { builder.http1_allow_obsolete_multiline_headers_in_responses(true); } - let hyper_client = builder.build(connector); - let proxies_maybe_http_auth = proxies.iter().any(|p| p.maybe_has_http_auth()); Ok(Client { @@ -528,7 +595,12 @@ impl ClientBuilder { accepts: config.accepts, #[cfg(feature = "cookies")] cookie_store: config.cookie_store, - hyper: hyper_client, + #[cfg(feature = "http3")] + h3_client: H3Client::new( + h3_connector.expect("missing HTTP/3 connector"), + config.pool_idle_timeout, + ), + hyper: builder.build(connector), headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, @@ -929,6 +1001,13 @@ impl ClientBuilder { self } + /// Only use HTTP/3. + #[cfg(feature = "http3")] + pub fn http3_prior_knowledge(mut self) -> ClientBuilder { + self.config.http_version_pref = HttpVersionPref::Http3; + self + } + /// Sets the `SETTINGS_INITIAL_WINDOW_SIZE` option for HTTP2 stream-level flow control. /// /// Default is currently 65,535 but may change internally to optimize for common uses. @@ -1387,6 +1466,62 @@ impl ClientBuilder { self.config.dns_resolver = Some(resolver as _); self } + + /// Whether to send data on the first flight ("early data") in TLS 1.3 handshakes + /// for HTTP/3 connections. + /// + /// The default is false. + #[cfg(feature = "http3")] + pub fn set_tls_enable_early_data(mut self, enabled: bool) -> ClientBuilder { + self.config.tls_enable_early_data = enabled; + self + } + + /// Maximum duration of inactivity to accept before timing out the QUIC connection. + /// + /// Please see docs in [`TransportConfig`] in [`quinn`]. + /// + /// [`TransportConfig`]: https://docs.rs/quinn/latest/quinn/struct.TransportConfig.html + #[cfg(feature = "http3")] + pub fn set_quic_max_idle_timeout(mut self, value: Duration) -> ClientBuilder { + self.config.quic_max_idle_timeout = Some(value); + self + } + + /// Maximum number of bytes the peer may transmit without acknowledgement on any one stream + /// before becoming blocked. + /// + /// Please see docs in [`TransportConfig`] in [`quinn`]. + /// + /// [`TransportConfig`]: https://docs.rs/quinn/latest/quinn/struct.TransportConfig.html + #[cfg(feature = "http3")] + pub fn set_quic_stream_receive_window(mut self, value: VarInt) -> ClientBuilder { + self.config.quic_stream_receive_window = Some(value); + self + } + + /// Maximum number of bytes the peer may transmit across all streams of a connection before + /// becoming blocked. + /// + /// Please see docs in [`TransportConfig`] in [`quinn`]. + /// + /// [`TransportConfig`]: https://docs.rs/quinn/latest/quinn/struct.TransportConfig.html + #[cfg(feature = "http3")] + pub fn set_quic_receive_window(mut self, value: VarInt) -> ClientBuilder { + self.config.quic_receive_window = Some(value); + self + } + + /// Maximum number of bytes to transmit to a peer without acknowledgment + /// + /// Please see docs in [`TransportConfig`] in [`quinn`]. + /// + /// [`TransportConfig`]: https://docs.rs/quinn/latest/quinn/struct.TransportConfig.html + #[cfg(feature = "http3")] + pub fn set_quic_send_window(mut self, value: u64) -> ClientBuilder { + self.config.quic_send_window = Some(value); + self + } } type HyperClient = hyper::Client; @@ -1553,22 +1688,32 @@ impl Client { self.proxy_auth(&uri, &mut headers); - let mut req = hyper::Request::builder() + let builder = hyper::Request::builder() .method(method.clone()) .uri(uri) - .version(version) - .body(body.into_stream()) - .expect("valid request parts"); + .version(version); + + let in_flight = match version { + #[cfg(feature = "http3")] + http::Version::HTTP_3 => { + let mut req = builder.body(body).expect("valid request parts"); + *req.headers_mut() = headers.clone(); + ResponseFuture::H3(self.inner.h3_client.request(req)) + } + _ => { + let mut req = builder + .body(body.into_stream()) + .expect("valid request parts"); + *req.headers_mut() = headers.clone(); + ResponseFuture::Default(self.inner.hyper.request(req)) + } + }; let timeout = timeout .or(self.inner.request_timeout) .map(tokio::time::sleep) .map(Box::pin); - *req.headers_mut() = headers.clone(); - - let in_flight = self.inner.hyper.request(req); - Pending { inner: PendingInner::Request(PendingRequest { method, @@ -1752,6 +1897,13 @@ impl Config { if !self.dns_overrides.is_empty() { f.field("dns_overrides", &self.dns_overrides); } + + #[cfg(feature = "http3")] + { + if self.tls_enable_early_data { + f.field("tls_enable_early_data", &true); + } + } } } @@ -1761,6 +1913,8 @@ struct ClientRef { cookie_store: Option>, headers: HeaderMap, hyper: HyperClient, + #[cfg(feature = "http3")] + h3_client: H3Client, redirect_policy: redirect::Policy, referer: bool, request_timeout: Option, @@ -1835,6 +1989,12 @@ pin_project! { } } +enum ResponseFuture { + Default(HyperResponseFuture), + #[cfg(feature = "http3")] + H3(H3ResponseFuture), +} + impl PendingRequest { fn in_flight(self: Pin<&mut Self>) -> Pin<&mut ResponseFuture> { self.project().in_flight @@ -1875,21 +2035,43 @@ impl PendingRequest { self.retry_count += 1; let uri = expect_uri(&self.url); - let mut req = hyper::Request::builder() - .method(self.method.clone()) - .uri(uri) - .body(body.into_stream()) - .expect("valid request parts"); - *req.headers_mut() = self.headers.clone(); - - *self.as_mut().in_flight().get_mut() = self.client.hyper.request(req); + *self.as_mut().in_flight().get_mut() = match *self.as_mut().in_flight().as_ref() { + #[cfg(feature = "http3")] + ResponseFuture::H3(_) => { + let mut req = hyper::Request::builder() + .method(self.method.clone()) + .uri(uri) + .body(body) + .expect("valid request parts"); + *req.headers_mut() = self.headers.clone(); + ResponseFuture::H3(self.client.h3_client.request(req)) + } + _ => { + let mut req = hyper::Request::builder() + .method(self.method.clone()) + .uri(uri) + .body(body.into_stream()) + .expect("valid request parts"); + *req.headers_mut() = self.headers.clone(); + ResponseFuture::Default(self.client.hyper.request(req)) + } + }; true } } fn is_retryable_error(err: &(dyn std::error::Error + 'static)) -> bool { + #[cfg(feature = "http3")] + if let Some(cause) = err.source() { + if let Some(err) = cause.downcast_ref::() { + debug!("determining if HTTP/3 error {} can be retried", err); + // TODO: Does h3 provide an API for checking the error? + return err.to_string().as_str() == "timeout"; + } + } + if let Some(cause) = err.source() { if let Some(err) = cause.downcast_ref::() { // They sent us a graceful shutdown, try with a new connection! @@ -1940,15 +2122,32 @@ impl Future for PendingRequest { } loop { - let res = match self.as_mut().in_flight().as_mut().poll(cx) { - Poll::Ready(Err(e)) => { - if self.as_mut().retry_error(&e) { - continue; + let res = match self.as_mut().in_flight().get_mut() { + ResponseFuture::Default(r) => match Pin::new(r).poll(cx) { + Poll::Ready(Err(e)) => { + if self.as_mut().retry_error(&e) { + continue; + } + return Poll::Ready(Err( + crate::error::request(e).with_url(self.url.clone()) + )); } - return Poll::Ready(Err(crate::error::request(e).with_url(self.url.clone()))); - } - Poll::Ready(Ok(res)) => res, - Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(res)) => res, + Poll::Pending => return Poll::Pending, + }, + #[cfg(feature = "http3")] + ResponseFuture::H3(r) => match Pin::new(r).poll(cx) { + Poll::Ready(Err(e)) => { + if self.as_mut().retry_error(&e) { + continue; + } + return Poll::Ready(Err( + crate::error::request(e).with_url(self.url.clone()) + )); + } + Poll::Ready(Ok(res)) => res, + Poll::Pending => return Poll::Pending, + }, }; #[cfg(feature = "cookies")] @@ -2048,11 +2247,6 @@ impl Future for PendingRequest { Some(Some(ref body)) => Body::reusable(body.clone()), _ => Body::empty(), }; - let mut req = hyper::Request::builder() - .method(self.method.clone()) - .uri(uri.clone()) - .body(body.into_stream()) - .expect("valid request parts"); // Add cookies from the cookie store. #[cfg(feature = "cookies")] @@ -2062,9 +2256,31 @@ impl Future for PendingRequest { } } - *req.headers_mut() = headers.clone(); - std::mem::swap(self.as_mut().headers(), &mut headers); - *self.as_mut().in_flight().get_mut() = self.client.hyper.request(req); + *self.as_mut().in_flight().get_mut() = + match *self.as_mut().in_flight().as_ref() { + #[cfg(feature = "http3")] + ResponseFuture::H3(_) => { + let mut req = hyper::Request::builder() + .method(self.method.clone()) + .uri(uri.clone()) + .body(body) + .expect("valid request parts"); + *req.headers_mut() = headers.clone(); + std::mem::swap(self.as_mut().headers(), &mut headers); + ResponseFuture::H3(self.client.h3_client.request(req)) + } + _ => { + let mut req = hyper::Request::builder() + .method(self.method.clone()) + .uri(uri.clone()) + .body(body.into_stream()) + .expect("valid request parts"); + *req.headers_mut() = headers.clone(); + std::mem::swap(self.as_mut().headers(), &mut headers); + ResponseFuture::Default(self.client.hyper.request(req)) + } + }; + continue; } redirect::ActionKind::Stop => { diff --git a/src/async_impl/h3_client/connect.rs b/src/async_impl/h3_client/connect.rs new file mode 100644 index 000000000..755864590 --- /dev/null +++ b/src/async_impl/h3_client/connect.rs @@ -0,0 +1,87 @@ +use crate::async_impl::h3_client::dns::resolve; +use crate::dns::DynResolver; +use crate::error::BoxError; +use bytes::Bytes; +use h3::client::SendRequest; +use h3_quinn::{Connection, OpenStreams}; +use http::Uri; +use hyper::client::connect::dns::Name; +use quinn::{ClientConfig, Endpoint, TransportConfig}; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use std::sync::Arc; + +type H3Connection = ( + h3::client::Connection, + SendRequest, +); + +#[derive(Clone)] +pub(crate) struct H3Connector { + resolver: DynResolver, + endpoint: Endpoint, +} + +impl H3Connector { + pub fn new( + resolver: DynResolver, + tls: rustls::ClientConfig, + local_addr: Option, + transport_config: TransportConfig, + ) -> H3Connector { + let mut config = ClientConfig::new(Arc::new(tls)); + // FIXME: Replace this when there is a setter. + config.transport = Arc::new(transport_config); + + let socket_addr = match local_addr { + Some(ip) => SocketAddr::new(ip, 0), + None => "[::]:0".parse::().unwrap(), + }; + + let mut endpoint = Endpoint::client(socket_addr).expect("unable to create QUIC endpoint"); + endpoint.set_default_client_config(config); + + Self { resolver, endpoint } + } + + pub async fn connect(&mut self, dest: Uri) -> Result { + let host = dest.host().ok_or("destination must have a host")?; + let port = dest.port_u16().unwrap_or(443); + + let addrs = if let Some(addr) = IpAddr::from_str(host).ok() { + // If the host is already an IP address, skip resolving. + vec![SocketAddr::new(addr, port)] + } else { + let addrs = resolve(&mut self.resolver, Name::from_str(host)?).await?; + let addrs = addrs.map(|mut addr| { + addr.set_port(port); + addr + }); + addrs.collect() + }; + + self.remote_connect(addrs, host).await + } + + async fn remote_connect( + &mut self, + addrs: Vec, + server_name: &str, + ) -> Result { + let mut err = None; + for addr in addrs { + match self.endpoint.connect(addr, server_name)?.await { + Ok(new_conn) => { + let quinn_conn = Connection::new(new_conn); + return Ok(h3::client::new(quinn_conn).await?); + } + Err(e) => err = Some(e), + } + } + + match err { + Some(e) => Err(Box::new(e) as BoxError), + None => Err("failed to establish connection for HTTP/3 request".into()), + } + } +} diff --git a/src/async_impl/h3_client/dns.rs b/src/async_impl/h3_client/dns.rs new file mode 100644 index 000000000..9cb50d1e3 --- /dev/null +++ b/src/async_impl/h3_client/dns.rs @@ -0,0 +1,43 @@ +use core::task; +use hyper::client::connect::dns::Name; +use std::future::Future; +use std::net::SocketAddr; +use std::task::Poll; +use tower_service::Service; + +// Trait from hyper to implement DNS resolution for HTTP/3 client. +pub trait Resolve { + type Addrs: Iterator; + type Error: Into>; + type Future: Future>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll>; + fn resolve(&mut self, name: Name) -> Self::Future; +} + +impl Resolve for S +where + S: Service, + S::Response: Iterator, + S::Error: Into>, +{ + type Addrs = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + Service::poll_ready(self, cx) + } + + fn resolve(&mut self, name: Name) -> Self::Future { + Service::call(self, name) + } +} + +pub(super) async fn resolve(resolver: &mut R, name: Name) -> Result +where + R: Resolve, +{ + futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?; + resolver.resolve(name).await +} diff --git a/src/async_impl/h3_client/mod.rs b/src/async_impl/h3_client/mod.rs new file mode 100644 index 000000000..919e13c0a --- /dev/null +++ b/src/async_impl/h3_client/mod.rs @@ -0,0 +1,88 @@ +#![cfg(feature = "http3")] + +pub(crate) mod connect; +pub(crate) mod dns; +mod pool; + +use crate::async_impl::h3_client::pool::{Key, Pool, PoolClient}; +use crate::error::{BoxError, Error, Kind}; +use crate::{error, Body}; +use connect::H3Connector; +use futures_util::future; +use http::{Request, Response}; +use hyper::Body as HyperBody; +use log::trace; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +#[derive(Clone)] +pub(crate) struct H3Client { + pool: Pool, + connector: H3Connector, +} + +impl H3Client { + pub fn new(connector: H3Connector, pool_timeout: Option) -> Self { + H3Client { + pool: Pool::new(pool_timeout), + connector, + } + } + + async fn get_pooled_client(&mut self, key: Key) -> Result { + if let Some(client) = self.pool.try_pool(&key) { + trace!("getting client from pool with key {:?}", key); + return Ok(client); + } + + trace!("did not find connection {:?} in pool so connecting...", key); + + let dest = pool::domain_as_uri(key.clone()); + self.pool.connecting(key.clone())?; + let (driver, tx) = self.connector.connect(dest).await?; + Ok(self.pool.new_connection(key, driver, tx)) + } + + async fn send_request( + mut self, + key: Key, + req: Request, + ) -> Result, Error> { + let mut pooled = match self.get_pooled_client(key).await { + Ok(client) => client, + Err(e) => return Err(error::request(e)), + }; + pooled + .send_request(req) + .await + .map_err(|e| Error::new(Kind::Request, Some(e))) + } + + pub fn request(&self, mut req: Request) -> H3ResponseFuture { + let pool_key = match pool::extract_domain(req.uri_mut()) { + Ok(s) => s, + Err(e) => { + return H3ResponseFuture { + inner: Box::pin(future::err(e)), + } + } + }; + H3ResponseFuture { + inner: Box::pin(self.clone().send_request(pool_key, req)), + } + } +} + +pub(crate) struct H3ResponseFuture { + inner: Pin, Error>> + Send>>, +} + +impl Future for H3ResponseFuture { + type Output = Result, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.inner.as_mut().poll(cx) + } +} diff --git a/src/async_impl/h3_client/pool.rs b/src/async_impl/h3_client/pool.rs new file mode 100644 index 000000000..6fcb8e719 --- /dev/null +++ b/src/async_impl/h3_client/pool.rs @@ -0,0 +1,198 @@ +use bytes::Bytes; +use std::collections::{HashMap, HashSet}; +use std::sync::mpsc::{Receiver, TryRecvError}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio::time::Instant; + +use crate::error::{BoxError, Error, Kind}; +use crate::Body; +use bytes::Buf; +use futures_util::future; +use h3::client::SendRequest; +use h3_quinn::{Connection, OpenStreams}; +use http::uri::{Authority, Scheme}; +use http::{Request, Response, Uri}; +use hyper::Body as HyperBody; +use log::trace; + +pub(super) type Key = (Scheme, Authority); + +#[derive(Clone)] +pub struct Pool { + inner: Arc>, +} + +impl Pool { + pub fn new(timeout: Option) -> Self { + Self { + inner: Arc::new(Mutex::new(PoolInner { + connecting: HashSet::new(), + idle_conns: HashMap::new(), + timeout, + })), + } + } + + pub fn connecting(&self, key: Key) -> Result<(), BoxError> { + let mut inner = self.inner.lock().unwrap(); + if !inner.connecting.insert(key.clone()) { + return Err(format!("HTTP/3 connecting already in progress for {:?}", key).into()); + } + return Ok(()); + } + + pub fn try_pool(&self, key: &Key) -> Option { + let mut inner = self.inner.lock().unwrap(); + let timeout = inner.timeout; + if let Some(conn) = inner.idle_conns.get(&key) { + // We check first if the connection still valid + // and if not, we remove it from the pool. + if conn.is_invalid() { + trace!("pooled HTTP/3 connection is invalid so removing it..."); + inner.idle_conns.remove(&key); + return None; + } + + if let Some(duration) = timeout { + if Instant::now().saturating_duration_since(conn.idle_timeout) > duration { + trace!("pooled connection expired"); + return None; + } + } + } + + inner + .idle_conns + .get_mut(&key) + .and_then(|conn| Some(conn.pool())) + } + + pub fn new_connection( + &mut self, + key: Key, + mut driver: h3::client::Connection, + tx: SendRequest, + ) -> PoolClient { + let (close_tx, close_rx) = std::sync::mpsc::channel(); + tokio::spawn(async move { + if let Err(e) = future::poll_fn(|cx| driver.poll_close(cx)).await { + trace!("poll_close returned error {:?}", e); + close_tx.send(e).ok(); + } + }); + + let mut inner = self.inner.lock().unwrap(); + + let client = PoolClient::new(tx); + let conn = PoolConnection::new(client.clone(), close_rx); + inner.insert(key.clone(), conn); + + // We clean up "connecting" here so we don't have to acquire the lock again. + let existed = inner.connecting.remove(&key); + debug_assert!(existed, "key not in connecting set"); + + client + } +} + +struct PoolInner { + connecting: HashSet, + idle_conns: HashMap, + timeout: Option, +} + +impl PoolInner { + fn insert(&mut self, key: Key, conn: PoolConnection) { + if self.idle_conns.contains_key(&key) { + trace!("connection already exists for key {:?}", key); + } + + self.idle_conns.insert(key, conn); + } +} + +#[derive(Clone)] +pub struct PoolClient { + inner: SendRequest, +} + +impl PoolClient { + pub fn new(tx: SendRequest) -> Self { + Self { inner: tx } + } + + pub async fn send_request( + &mut self, + req: Request, + ) -> Result, BoxError> { + let (head, req_body) = req.into_parts(); + let req = Request::from_parts(head, ()); + let mut stream = self.inner.send_request(req).await?; + + match req_body.as_bytes() { + Some(b) if !b.is_empty() => { + stream.send_data(Bytes::copy_from_slice(b)).await?; + } + _ => {} + } + + stream.finish().await?; + + let resp = stream.recv_response().await?; + + let mut resp_body = Vec::new(); + while let Some(chunk) = stream.recv_data().await? { + resp_body.extend(chunk.chunk()) + } + + Ok(resp.map(|_| HyperBody::from(resp_body))) + } +} + +pub struct PoolConnection { + // This receives errors from polling h3 driver. + close_rx: Receiver, + client: PoolClient, + idle_timeout: Instant, +} + +impl PoolConnection { + pub fn new(client: PoolClient, close_rx: Receiver) -> Self { + Self { + close_rx, + client, + idle_timeout: Instant::now(), + } + } + + pub fn pool(&mut self) -> PoolClient { + self.idle_timeout = Instant::now(); + self.client.clone() + } + + pub fn is_invalid(&self) -> bool { + match self.close_rx.try_recv() { + Err(TryRecvError::Empty) => false, + Err(TryRecvError::Disconnected) => true, + Ok(_) => true, + } + } +} + +pub(crate) fn extract_domain(uri: &mut Uri) -> Result { + let uri_clone = uri.clone(); + match (uri_clone.scheme(), uri_clone.authority()) { + (Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())), + _ => Err(Error::new(Kind::Request, None::)), + } +} + +pub(crate) fn domain_as_uri((scheme, auth): Key) -> Uri { + http::uri::Builder::new() + .scheme(scheme) + .authority(auth) + .path_and_query("/") + .build() + .expect("domain is valid Uri") +} diff --git a/src/async_impl/mod.rs b/src/async_impl/mod.rs index b69230c91..5d99ef027 100644 --- a/src/async_impl/mod.rs +++ b/src/async_impl/mod.rs @@ -10,6 +10,7 @@ pub(crate) use self::decoder::Decoder; pub mod body; pub mod client; pub mod decoder; +pub mod h3_client; #[cfg(feature = "multipart")] pub mod multipart; pub(crate) mod request; diff --git a/src/tls.rs b/src/tls.rs index db898e84b..b54ffa19e 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -379,6 +379,8 @@ impl Version { } pub(crate) enum TlsBackend { + // This is the default and HTTP/3 feature does not use it so suppress it. + #[allow(dead_code)] #[cfg(feature = "default-tls")] Default, #[cfg(feature = "native-tls")] @@ -410,12 +412,15 @@ impl fmt::Debug for TlsBackend { impl Default for TlsBackend { fn default() -> TlsBackend { - #[cfg(feature = "default-tls")] + #[cfg(all(feature = "default-tls", not(feature = "http3")))] { TlsBackend::Default } - #[cfg(all(feature = "__rustls", not(feature = "default-tls")))] + #[cfg(any( + all(feature = "__rustls", not(feature = "default-tls")), + feature = "http3" + ))] { TlsBackend::Rustls }