diff --git a/Cargo.toml b/Cargo.toml index cd283bf..fa4a672 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,8 @@ repository = "https://github.com/ctz/hyper-rustls" [dependencies] log = "0.4.4" -hyper = { version = "0.14", default-features = false, features = ["client", "http1"] } +http = "0.2" +hyper = { version = "0.14", default-features = false, features = ["client"] } rustls = "0.20" rustls-native-certs = { version = "0.6", optional = true } tokio = "1.0" diff --git a/examples/client.rs b/examples/client.rs index fee5eef..aa37eac 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -51,23 +51,23 @@ async fn run_client() -> io::Result<()> { .map_err(|_| error("failed to load custom CA store".into()))?; let mut roots = RootCertStore::empty(); roots.add_parsable_certificates(&certs); - // Build a TLS client, using the custom CA store for lookups. + // TLS client config using the custom CA store for lookups rustls::ClientConfig::builder() .with_safe_defaults() .with_root_certificates(roots) .with_no_client_auth() } + // Default TLS client config with native roots None => rustls::ClientConfig::builder() .with_safe_defaults() .with_native_roots(), }; - - // Build an HTTP connector which supports HTTPS too. - let mut http = client::HttpConnector::new(); - http.enforce_http(false); - - // Join the above parts into an HTTPS connector. - let https = hyper_rustls::HttpsConnector::from((http, tls)); + // Prepare the HTTPS connector + let https = hyper_rustls::HttpsConnectorBuilder::new() + .with_tls_config(tls) + .https_or_http() + .enable_http1() + .build(); // Build the hyper client from the HTTPS connector. let client: client::Client<_, hyper::Body> = client::Client::builder().build(https); diff --git a/src/connector.rs b/src/connector.rs index 11eb3c3..e78d55d 100644 --- a/src/connector.rs +++ b/src/connector.rs @@ -5,15 +5,14 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::{fmt, io}; -#[cfg(feature = "tokio-runtime")] -use hyper::client::connect::HttpConnector; use hyper::{client::connect::Connection, service::Service, Uri}; -use rustls::ClientConfig; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsConnector; use crate::stream::MaybeHttpsStream; +pub mod builder; + type BoxError = Box; /// A Connector for the `https` scheme. @@ -21,20 +20,7 @@ type BoxError = Box; pub struct HttpsConnector { force_https: bool, http: T, - tls_config: Arc, -} - -#[cfg(all( - any(feature = "rustls-native-certs", feature = "webpki-roots"), - feature = "tokio-runtime" -))] -impl HttpsConnector { - /// Force the use of HTTPS when connecting. - /// - /// If a URL is not `https` when connecting, an error is returned. Disabled by default. - pub fn https_only(&mut self, enable: bool) { - self.force_https = enable; - } + tls_config: Arc, } impl fmt::Debug for HttpsConnector { @@ -47,7 +33,7 @@ impl fmt::Debug for HttpsConnector { impl From<(H, C)> for HttpsConnector where - C: Into>, + C: Into>, { fn from((http, cfg): (H, C)) -> Self { HttpsConnector { @@ -81,38 +67,43 @@ where } fn call(&mut self, dst: Uri) -> Self::Future { - let is_https = dst.scheme_str() == Some("https"); + // dst.scheme() would need to derive Eq to be matchable; + // use an if cascade instead + if let Some(sch) = dst.scheme() { + if sch == &http::uri::Scheme::HTTP && !self.force_https { + let connecting_future = self.http.call(dst); - if !is_https && self.force_https { - // Early abort if HTTPS is forced but can't be used - let err = io::Error::new(io::ErrorKind::Other, "https required but URI was not https"); - Box::pin(async move { Err(err.into()) }) - } else if !is_https { - let connecting_future = self.http.call(dst); + let f = async move { + let tcp = connecting_future.await.map_err(Into::into)?; - let f = async move { - let tcp = connecting_future.await.map_err(Into::into)?; + Ok(MaybeHttpsStream::Http(tcp)) + }; + Box::pin(f) + } else if sch == &http::uri::Scheme::HTTPS { + let cfg = self.tls_config.clone(); + let hostname = dst.host().unwrap_or_default().to_string(); + let connecting_future = self.http.call(dst); - Ok(MaybeHttpsStream::Http(tcp)) - }; - Box::pin(f) + let f = async move { + let tcp = connecting_future.await.map_err(Into::into)?; + let connector = TlsConnector::from(cfg); + let dnsname = rustls::ServerName::try_from(hostname.as_str()) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid dnsname"))?; + let tls = connector + .connect(dnsname, tcp) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + Ok(MaybeHttpsStream::Https(tls)) + }; + Box::pin(f) + } else { + let err = + io::Error::new(io::ErrorKind::Other, format!("Unsupported scheme {}", sch)); + Box::pin(async move { Err(err.into()) }) + } } else { - let cfg = self.tls_config.clone(); - let hostname = dst.host().unwrap_or_default().to_string(); - let connecting_future = self.http.call(dst); - - let f = async move { - let tcp = connecting_future.await.map_err(Into::into)?; - let connector = TlsConnector::from(cfg); - let dnsname = rustls::ServerName::try_from(hostname.as_str()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid dnsname"))?; - let tls = connector - .connect(dnsname, tcp) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - Ok(MaybeHttpsStream::Https(tls)) - }; - Box::pin(f) + let err = io::Error::new(io::ErrorKind::Other, "Missing scheme"); + Box::pin(async move { Err(err.into()) }) } } } diff --git a/src/connector/builder.rs b/src/connector/builder.rs new file mode 100644 index 0000000..35eaba0 --- /dev/null +++ b/src/connector/builder.rs @@ -0,0 +1,234 @@ +use rustls::ClientConfig; + +use super::HttpsConnector; +#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))] +use crate::config::ConfigBuilderExt; + +#[cfg(feature = "tokio-runtime")] +use hyper::client::HttpConnector; + +/// A builder for an [HttpsConnector] +/// +/// This makes configuration flexible and explicit and ensures connector +/// features match crate features +/// +/// # Examples +/// +/// ``` +/// use hyper_rustls::HttpsConnectorBuilder; +/// +/// # #[cfg(all(feature = "webpki-roots", feature = "tokio-runtime", feature = "http1"))] +/// let https = HttpsConnectorBuilder::new() +/// .with_webpki_roots() +/// .https_only() +/// .enable_http1() +/// .build(); +/// ``` +pub struct ConnectorBuilder(State); + +/// State of a builder that needs a TLS client config next +pub struct WantsTlsConfig(()); + +/// State of a builder that needs schemes (https:// and http://) to be +/// configured next +pub struct WantsSchemes { + tls_config: ClientConfig, +} + +/// State of a builder that needs to have some protocols (HTTP1 or later) +/// enabled next +/// +/// No protocol has been enabled at this point. +pub struct WantsProtocols1 { + tls_config: ClientConfig, + https_only: bool, +} + +/// State of a builder with HTTP1 enabled, that may have some other +/// protocols (HTTP2 or later) enabled next +/// +/// At this point a connector can be built, see +/// [build](ConnectorBuilder::build) and +/// [wrap_connector](ConnectorBuilder::wrap_connector). +pub struct WantsProtocols2 { + inner: WantsProtocols1, +} + +/// State of a builder with HTTP2 (and possibly HTTP1) enabled +/// +/// At this point a connector can be built, see +/// [build](ConnectorBuilder::build) and +/// [wrap_connector](ConnectorBuilder::wrap_connector). +#[cfg(feature = "http2")] +pub struct WantsProtocols3 { + inner: WantsProtocols1, + // ALPN is built piecemeal without the need to read back this field + #[allow(dead_code)] + enable_http1: bool, +} + +impl ConnectorBuilder { + /// Creates a new [ConnectorBuilder] + pub fn new() -> Self { + Self(WantsTlsConfig(())) + } + + /// Passes a rustls [ClientConfig] to configure the TLS connection + /// + /// The [alpn_protocols](ClientConfig::alpn_protocols) field will be rewritten to + /// match the enabled schemes (see + /// [enable_http1](ConnectorBuilder::enable_http1), + /// [enable_http2](ConnectorBuilder::enable_http2)) before the + /// connector is built. + pub fn with_tls_config(self, config: ClientConfig) -> ConnectorBuilder { + ConnectorBuilder(WantsSchemes { tls_config: config }) + } + + /// Shorthand for using rustls' [safe defaults][with_safe_defaults] + /// and native roots + /// + /// See [ConfigBuilderExt::with_native_roots] + /// + /// [with_safe_defaults]: rustls::ConfigBuilder::with_safe_defaults + #[cfg(feature = "rustls-native-certs")] + #[cfg_attr(docsrs, doc(cfg(feature = "rustls-native-certs")))] + pub fn with_native_roots(self) -> ConnectorBuilder { + self.with_tls_config( + ClientConfig::builder() + .with_safe_defaults() + .with_native_roots(), + ) + } + + /// Shorthand for using rustls' [safe defaults][with_safe_defaults] + /// and Mozilla roots + /// + /// See [ConfigBuilderExt::with_webpki_roots] + /// + /// [with_safe_defaults]: rustls::ConfigBuilder::with_safe_defaults + #[cfg(feature = "webpki-roots")] + #[cfg_attr(docsrs, doc(cfg(feature = "webpki-roots")))] + pub fn with_webpki_roots(self) -> ConnectorBuilder { + self.with_tls_config( + ClientConfig::builder() + .with_safe_defaults() + .with_webpki_roots(), + ) + } +} + +impl Default for ConnectorBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ConnectorBuilder { + /// Enforce the use of HTTPS when connecting + /// + /// Only URLs using the HTTPS scheme will be connectable. + pub fn https_only(self) -> ConnectorBuilder { + ConnectorBuilder(WantsProtocols1 { + tls_config: self.0.tls_config, + https_only: true, + }) + } + + /// Allow both HTTPS and HTTP when connecting + /// + /// HTTPS URLs will be handled through rustls, + /// HTTP URLs will be handled by the lower-level connector. + pub fn https_or_http(self) -> ConnectorBuilder { + ConnectorBuilder(WantsProtocols1 { + tls_config: self.0.tls_config, + https_only: false, + }) + } +} + +impl WantsProtocols1 { + fn wrap_connector(mut self, conn: H) -> HttpsConnector { + self.tls_config.alpn_protocols.clear(); + HttpsConnector { + force_https: self.https_only, + http: conn, + tls_config: std::sync::Arc::new(self.tls_config), + } + } + + #[cfg(feature = "tokio-runtime")] + fn build(self) -> HttpsConnector { + let mut http = HttpConnector::new(); + // HttpConnector won't enforce scheme, but HttpsConnector will + http.enforce_http(false); + self.wrap_connector(http) + } +} + +impl ConnectorBuilder { + /// Enable HTTP1 + /// + /// This needs to be called explicitly, no protocol is enabled by default + #[cfg(feature = "http1")] + pub fn enable_http1(self) -> ConnectorBuilder { + ConnectorBuilder(WantsProtocols2 { inner: self.0 }) + } + + /// Enable HTTP2 + /// + /// This needs to be called explicitly, no protocol is enabled by default + #[cfg(feature = "http2")] + pub fn enable_http2(mut self) -> ConnectorBuilder { + self.0.tls_config.alpn_protocols = vec![b"h2".to_vec()]; + ConnectorBuilder(WantsProtocols3 { + inner: self.0, + enable_http1: false, + }) + } +} + +impl ConnectorBuilder { + /// Enable HTTP2 + /// + /// This needs to be called explicitly, no protocol is enabled by default + #[cfg(feature = "http2")] + pub fn enable_http2(mut self) -> ConnectorBuilder { + self.0.inner.tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + ConnectorBuilder(WantsProtocols3 { + inner: self.0.inner, + enable_http1: true, + }) + } + + /// This builds an [HttpsConnector] built on hyper's default [HttpConnector] + #[cfg(feature = "tokio-runtime")] + pub fn build(self) -> HttpsConnector { + self.0.inner.build() + } + + /// This wraps an arbitrary low-level connector into an [HttpsConnector] + pub fn wrap_connector(self, conn: H) -> HttpsConnector { + // HTTP1-only, alpn_protocols stays empty + // HttpConnector doesn't have a way to say http1-only; + // its connection pool may still support HTTP2 + // though it won't be used + self.0.inner.wrap_connector(conn) + } +} + +#[cfg(feature = "http2")] +impl ConnectorBuilder { + /// This builds an [HttpsConnector] built on hyper's default [HttpConnector] + #[cfg(feature = "tokio-runtime")] + pub fn build(self) -> HttpsConnector { + self.0.inner.build() + } + + /// This wraps an arbitrary low-level connector into an [HttpsConnector] + pub fn wrap_connector(self, conn: H) -> HttpsConnector { + // If HTTP1 is disabled, we can set http2_only + // on the Client (a higher-level object that uses the connector) + // client.http2_only(!self.0.enable_http1); + self.0.inner.wrap_connector(conn) + } +} diff --git a/src/lib.rs b/src/lib.rs index 2223d22..ad56a04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,27 +1,33 @@ //! # hyper-rustls //! -//! A pure-Rust HTTPS connector for [hyper](https://hyper.rs), based on [Rustls](https://github.com/ctz/rustls). +//! A pure-Rust HTTPS connector for [hyper](https://hyper.rs), based on +//! [Rustls](https://github.com/ctz/rustls). //! //! ## Example //! //! ```no_run -//! # #[cfg(all(feature = "rustls-native-certs", feature = "tokio-runtime"))] +//! # #[cfg(all(feature = "rustls-native-certs", feature = "tokio-runtime", feature = "http1"))] //! # fn main() { //! use hyper::{Body, Client, StatusCode, Uri}; //! //! let mut rt = tokio::runtime::Runtime::new().unwrap(); //! let url = ("https://hyper.rs").parse().unwrap(); -//! let https = hyper_rustls::HttpsConnector::with_native_roots(); +//! let https = hyper_rustls::HttpsConnectorBuilder::new() +//! .with_native_roots() +//! .https_only() +//! .enable_http1() +//! .build(); //! //! let client: Client<_, hyper::Body> = Client::builder().build(https); //! //! let res = rt.block_on(client.get(url)).unwrap(); //! assert_eq!(res.status(), StatusCode::OK); //! # } -//! # #[cfg(not(all(feature = "rustls-native-certs", feature = "tokio-runtime")))] +//! # #[cfg(not(all(feature = "rustls-native-certs", feature = "tokio-runtime", feature = "http1")))] //! # fn main() {} //! ``` +#![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] mod config; @@ -29,5 +35,15 @@ mod connector; mod stream; pub use crate::config::ConfigBuilderExt; +pub use crate::connector::builder::ConnectorBuilder as HttpsConnectorBuilder; pub use crate::connector::HttpsConnector; pub use crate::stream::MaybeHttpsStream; + +/// The various states of the [HttpsConnectorBuilder] +pub mod builderstates { + #[cfg(feature = "http2")] + pub use crate::connector::builder::WantsProtocols3; + pub use crate::connector::builder::{ + WantsProtocols1, WantsProtocols2, WantsSchemes, WantsTlsConfig, + }; +}