From fd294049c784cb52680a423616fadc29d57fa25b Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 19 Dec 2023 14:32:11 -0800 Subject: [PATCH] Update to hyper 1. Enable custom + unix listeners. This commit completely rewrites Rocket's HTTP serving. In addition to significant internal cleanup, this commit introduces the following major features: * Support for custom, external listeners in the `listener` module. The new `listener` module contains new `Bindable`, `Listener`, and `Connection` traits which enable composable, external implementations of connection listeners. Rocket can launch on any `Listener`, or anything that can be used to create a listener (`Bindable`), via a new `launch_on()` method. * Support for Unix domain socket listeners out of the box. The default listener backwards compatibly supports listening on Unix domain sockets. To do so, configure an `address` of `unix:path/to/socket` and optional set `reuse` to `true` (the default) or `false` which controls whether Rocket will handle creating and deleting the unix domain socket. In addition to these new features, this commit makes the following major improvements: * Rocket now depends on hyper 1. * Rocket no longer depends on hyper to handle connections. This allows us to handle more connection failure conditions which results in an overall more robust server with fewer dependencies. * Logic to work around hyper's inability to reference incoming request data in the response results in a 15% performance improvement. * `Client`s can be marked secure with `Client::{un}tracked_secure()`, allowing Rocket to treat local connections as running under TLS. * The `macros` feature of `tokio` is no longer used by Rocket itself. Dependencies can take advantage of this reduction in compile-time cost by disabling the new default feature `tokio-macros`. * A new `TlsConfig::validate()` method allows checking a TLS config. * New `TlsConfig::{certs,key}_reader()`, `MtlsConfig::ca_certs_reader()` methods return `BufReader`s, which allow reading the configured certs and key directly. * A new `NamedFile::open_with()` constructor allows specifying `OpenOptions`. These improvements resulted in the following breaking changes: * The MSRV is now 1.74. * `hyper` is no longer exported from `rocket::http`. * `IoHandler::io` takes `Box` instead of `Pin>`. - Use `Box::into_pin(self)` to recover the previous type. * `Response::upgrade()` now returns an `&mut dyn IoHandler`, not `Pin<& mut _>`. * `Config::{address,port,tls,mtls}` methods have been removed. - Use methods on `Rocket::endpoint()` instead. * `TlsConfig` was moved to `tls::TlsConfig`. * `MutualTls` was renamed and moved to `mtls::MtlsConfig`. * `ErrorKind::TlsBind` was removed. * The second field of `ErrorKind::Shutdown` was removed. * `{Local}Request::{set_}remote()` methods take/return an `Endpoint`. * `Client::new()` was removed; it was previously deprecated. Internally, the following major changes were made: * A new `async_bound` attribute macro was introduced to allow setting bounds on futures returned by `async fn`s in traits while maintaining good docs. * All utility functionality was moved to a new `util` module. Resolves #2671. Resolves #1070. --- contrib/ws/src/duplex.rs | 1 - contrib/ws/src/websocket.rs | 19 +- core/codegen/src/attribute/async_bound/mod.rs | 61 ++ core/codegen/src/attribute/mod.rs | 1 + core/codegen/src/attribute/route/mod.rs | 2 +- core/codegen/src/http_codegen.rs | 22 +- core/codegen/src/lib.rs | 7 + core/http/Cargo.toml | 23 +- core/http/src/header/header.rs | 3 +- core/http/src/hyper.rs | 35 - core/http/src/lib.rs | 13 +- core/http/src/listener.rs | 257 ------- core/http/src/method.rs | 20 +- core/http/src/tls/listener.rs | 235 ------ core/http/src/tls/mod.rs | 11 - core/lib/Cargo.toml | 66 +- core/lib/src/config/config.rs | 95 +-- core/lib/src/config/mod.rs | 273 +------ core/lib/src/config/secret_key.rs | 2 +- core/lib/src/config/shutdown.rs | 10 +- core/lib/src/data/data_stream.rs | 42 +- core/lib/src/data/io_stream.rs | 19 +- core/lib/src/data/transform.rs | 24 +- core/lib/src/erased.rs | 193 +++++ core/lib/src/error.rs | 98 ++- core/lib/src/ext.rs | 404 ----------- core/lib/src/form/mod.rs | 5 +- core/lib/src/fs/named_file.rs | 9 +- core/lib/src/{ => http}/cookies.rs | 17 +- core/lib/src/http/mod.rs | 12 + core/lib/src/lib.rs | 114 +-- core/lib/src/lifecycle.rs | 272 +++++++ core/lib/src/listener/bindable.rs | 40 + core/lib/src/listener/bounced.rs | 58 ++ core/lib/src/listener/cancellable.rs | 273 +++++++ core/lib/src/listener/connection.rs | 93 +++ core/lib/src/listener/default.rs | 61 ++ core/lib/src/listener/endpoint.rs | 281 +++++++ core/lib/src/listener/listener.rs | 65 ++ core/lib/src/listener/mod.rs | 24 + core/lib/src/listener/tcp.rs | 43 ++ core/lib/src/listener/tls.rs | 116 +++ core/lib/src/listener/unix.rs | 107 +++ core/lib/src/local/asynchronous/client.rs | 13 +- core/lib/src/local/asynchronous/request.rs | 6 +- core/lib/src/local/asynchronous/response.rs | 12 +- core/lib/src/local/blocking/client.rs | 6 +- core/lib/src/local/blocking/request.rs | 2 +- core/lib/src/local/client.rs | 22 +- core/lib/src/local/request.rs | 32 +- core/lib/src/mtls.rs | 25 - .../mtls.rs => lib/src/mtls/certificate.rs} | 307 +------- core/lib/src/mtls/config.rs | 212 ++++++ core/lib/src/mtls/error.rs | 74 ++ core/lib/src/mtls/mod.rs | 56 ++ core/lib/src/mtls/name.rs | 146 ++++ core/lib/src/phase.rs | 2 + core/lib/src/request/atomic_method.rs | 43 ++ core/lib/src/request/from_request.rs | 21 +- core/lib/src/request/mod.rs | 2 + core/lib/src/request/request.rs | 181 +++-- core/lib/src/request/tests.rs | 11 +- core/lib/src/response/response.rs | 43 +- core/lib/src/response/stream/sse.rs | 33 +- core/lib/src/rocket.rs | 86 ++- core/lib/src/route/handler.rs | 1 - core/lib/src/server.rs | 684 ++++-------------- core/lib/src/shield/shield.rs | 2 +- core/lib/src/shutdown.rs | 2 +- core/lib/src/{config/tls.rs => tls/config.rs} | 592 +++++++-------- core/{http => lib}/src/tls/error.rs | 3 + core/lib/src/tls/mod.rs | 7 + core/{http => lib}/src/tls/util.rs | 0 core/lib/src/util/chain.rs | 52 ++ core/lib/src/util/join.rs | 77 ++ core/lib/src/util/mod.rs | 12 + core/lib/src/util/reader_stream.rs | 124 ++++ .../src/{trip_wire.rs => util/tripwire.rs} | 0 core/lib/src/util/unix.rs | 25 + core/lib/tests/can-launch-tls.rs | 8 +- .../on_launch_fairing_can_inspect_port.rs | 8 +- core/lib/tests/sentinel.rs | 2 +- core/lib/tests/tls-config-from-source-1503.rs | 13 +- examples/config/src/tests.rs | 4 - examples/hello/src/main.rs | 11 - examples/tls/src/redirector.rs | 67 +- examples/tls/src/tests.rs | 71 +- examples/upgrade/static/index.html | 2 +- scripts/mk-docs.sh | 6 +- scripts/test.sh | 3 +- 90 files changed, 3630 insertions(+), 3007 deletions(-) create mode 100644 core/codegen/src/attribute/async_bound/mod.rs delete mode 100644 core/http/src/hyper.rs delete mode 100644 core/http/src/listener.rs delete mode 100644 core/http/src/tls/listener.rs delete mode 100644 core/http/src/tls/mod.rs create mode 100644 core/lib/src/erased.rs delete mode 100644 core/lib/src/ext.rs rename core/lib/src/{ => http}/cookies.rs (97%) create mode 100644 core/lib/src/http/mod.rs create mode 100644 core/lib/src/lifecycle.rs create mode 100644 core/lib/src/listener/bindable.rs create mode 100644 core/lib/src/listener/bounced.rs create mode 100644 core/lib/src/listener/cancellable.rs create mode 100644 core/lib/src/listener/connection.rs create mode 100644 core/lib/src/listener/default.rs create mode 100644 core/lib/src/listener/endpoint.rs create mode 100644 core/lib/src/listener/listener.rs create mode 100644 core/lib/src/listener/mod.rs create mode 100644 core/lib/src/listener/tcp.rs create mode 100644 core/lib/src/listener/tls.rs create mode 100644 core/lib/src/listener/unix.rs delete mode 100644 core/lib/src/mtls.rs rename core/{http/src/tls/mtls.rs => lib/src/mtls/certificate.rs} (50%) create mode 100644 core/lib/src/mtls/config.rs create mode 100644 core/lib/src/mtls/error.rs create mode 100644 core/lib/src/mtls/mod.rs create mode 100644 core/lib/src/mtls/name.rs create mode 100644 core/lib/src/request/atomic_method.rs rename core/lib/src/{config/tls.rs => tls/config.rs} (56%) rename core/{http => lib}/src/tls/error.rs (94%) create mode 100644 core/lib/src/tls/mod.rs rename core/{http => lib}/src/tls/util.rs (100%) create mode 100644 core/lib/src/util/chain.rs create mode 100644 core/lib/src/util/join.rs create mode 100644 core/lib/src/util/mod.rs create mode 100644 core/lib/src/util/reader_stream.rs rename core/lib/src/{trip_wire.rs => util/tripwire.rs} (100%) create mode 100644 core/lib/src/util/unix.rs diff --git a/contrib/ws/src/duplex.rs b/contrib/ws/src/duplex.rs index 76b5eac289..04da6160ec 100644 --- a/contrib/ws/src/duplex.rs +++ b/contrib/ws/src/duplex.rs @@ -33,7 +33,6 @@ use crate::result::{Result, Error}; /// /// [`StreamExt`]: rocket::futures::StreamExt /// [`SinkExt`]: rocket::futures::SinkExt - pub struct DuplexStream(tokio_tungstenite::WebSocketStream); impl DuplexStream { diff --git a/contrib/ws/src/websocket.rs b/contrib/ws/src/websocket.rs index 63414a111c..662cbe6574 100644 --- a/contrib/ws/src/websocket.rs +++ b/contrib/ws/src/websocket.rs @@ -1,5 +1,4 @@ use std::io; -use std::pin::Pin; use rocket::data::{IoHandler, IoStream}; use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream}; @@ -37,10 +36,6 @@ pub struct WebSocket { } impl WebSocket { - fn new(key: String) -> WebSocket { - WebSocket { config: Config::default(), key } - } - /// Change the default connection configuration to `config`. /// /// # Example @@ -202,7 +197,9 @@ impl<'r> FromRequest<'r> for WebSocket { let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13"); let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes())); match key { - Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket::new(key)), + Some(key) if is_upgrade && is_ws && is_13 => { + Outcome::Success(WebSocket { key, config: Config::default() }) + }, Some(_) | None => Outcome::Forward(Status::BadRequest) } } @@ -232,9 +229,9 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S> #[rocket::async_trait] impl IoHandler for Channel<'_> { - async fn io(self: Pin>, io: IoStream) -> io::Result<()> { - let channel = Pin::into_inner(self); - let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await; + async fn io(self: Box, io: IoStream) -> io::Result<()> { + let stream = DuplexStream::new(io, self.ws.config).await; + let result = (self.handler)(stream).await; handle_result(result).map(|_| ()) } } @@ -243,9 +240,9 @@ impl IoHandler for Channel<'_> { impl<'r, S> IoHandler for MessageStream<'r, S> where S: futures::Stream> + Send + 'r { - async fn io(self: Pin>, io: IoStream) -> io::Result<()> { + async fn io(self: Box, io: IoStream) -> io::Result<()> { let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split(); - let stream = (Pin::into_inner(self).handler)(source); + let stream = (self.handler)(source); rocket::tokio::pin!(stream); while let Some(msg) = stream.next().await { let result = match msg { diff --git a/core/codegen/src/attribute/async_bound/mod.rs b/core/codegen/src/attribute/async_bound/mod.rs new file mode 100644 index 0000000000..d58f0afe9f --- /dev/null +++ b/core/codegen/src/attribute/async_bound/mod.rs @@ -0,0 +1,61 @@ +use proc_macro2::{TokenStream, Span}; +use devise::{Spanned, Result, ext::SpanDiagnosticExt}; +use syn::{Token, parse_quote, parse_quote_spanned}; +use syn::{TraitItemFn, TypeParamBound, ReturnType, Attribute}; +use syn::punctuated::Punctuated; +use syn::parse::Parser; + +fn _async_bound( + args: proc_macro::TokenStream, + input: proc_macro::TokenStream +) -> Result { + let bounds = >::parse_terminated.parse(args)?; + if bounds.is_empty() { + return Ok(input.into()); + } + + let mut func: TraitItemFn = syn::parse(input)?; + let original: TraitItemFn = func.clone(); + if !func.sig.asyncness.is_some() { + let diag = Span::call_site() + .error("attribute can only be applied to async fns") + .span_help(func.sig.span(), "this fn declaration must be `async`"); + + return Err(diag); + } + + let doc: Attribute = parse_quote! { + #[doc = concat!( + "# Future Bounds", + "\n", + "**The `Future` generated by this `async fn` must be `", stringify!(#bounds), "`**." + )] + }; + + func.sig.asyncness = None; + func.sig.output = match func.sig.output { + ReturnType::Type(arrow, ty) => parse_quote_spanned!(ty.span() => + #arrow impl ::core::future::Future + #bounds + ), + default@ReturnType::Default => default + }; + + Ok(quote! { + #[cfg(all(not(doc), rust_analyzer))] + #original + + #[cfg(all(doc, not(rust_analyzer)))] + #doc + #original + + #[cfg(not(any(doc, rust_analyzer)))] + #func + }) +} + +pub fn async_bound( + args: proc_macro::TokenStream, + input: proc_macro::TokenStream +) -> TokenStream { + _async_bound(args, input).unwrap_or_else(|d| d.emit_as_item_tokens()) +} diff --git a/core/codegen/src/attribute/mod.rs b/core/codegen/src/attribute/mod.rs index 4d06591df9..c851bebcd2 100644 --- a/core/codegen/src/attribute/mod.rs +++ b/core/codegen/src/attribute/mod.rs @@ -2,3 +2,4 @@ pub mod entry; pub mod catch; pub mod route; pub mod param; +pub mod async_bound; diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index 6e29a401c0..dbf28a5180 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -331,7 +331,7 @@ fn codegen_route(route: Route) -> Result { let internal_uri_macro = internal_uri_macro_decl(&route); let responder_outcome = responder_outcome_expr(&route); - let method = route.attr.method; + let method = &route.attr.method; let uri = route.attr.uri.to_string(); let rank = Optional(route.attr.rank); let format = Optional(route.attr.format.as_ref()); diff --git a/core/codegen/src/http_codegen.rs b/core/codegen/src/http_codegen.rs index 8d021a38d6..35a5cb9a8d 100644 --- a/core/codegen/src/http_codegen.rs +++ b/core/codegen/src/http_codegen.rs @@ -13,7 +13,7 @@ pub struct Status(pub http::Status); #[derive(Debug)] pub struct MediaType(pub http::MediaType); -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct Method(pub http::Method); #[derive(Clone, Debug)] @@ -108,7 +108,7 @@ const VALID_METHODS: &[http::Method] = &[ impl FromMeta for Method { fn from_meta(meta: &MetaItem) -> Result { let span = meta.value_span(); - let help_text = format!("method must be one of: {}", VALID_METHODS_STR); + let help_text = format!("method must be one of: {VALID_METHODS_STR}"); if let MetaItem::Path(path) = meta { if let Some(ident) = path.last_ident() { @@ -131,19 +131,13 @@ impl FromMeta for Method { impl ToTokens for Method { fn to_tokens(&self, tokens: &mut TokenStream) { - let method_tokens = match self.0 { - http::Method::Get => quote!(::rocket::http::Method::Get), - http::Method::Put => quote!(::rocket::http::Method::Put), - http::Method::Post => quote!(::rocket::http::Method::Post), - http::Method::Delete => quote!(::rocket::http::Method::Delete), - http::Method::Options => quote!(::rocket::http::Method::Options), - http::Method::Head => quote!(::rocket::http::Method::Head), - http::Method::Trace => quote!(::rocket::http::Method::Trace), - http::Method::Connect => quote!(::rocket::http::Method::Connect), - http::Method::Patch => quote!(::rocket::http::Method::Patch), - }; + let mut chars = self.0.as_str().chars(); + let variant_str = chars.next() + .map(|c| c.to_ascii_uppercase().to_string() + &chars.as_str().to_lowercase()) + .unwrap_or_default(); - tokens.extend(method_tokens); + let variant = syn::Ident::new(&variant_str, Span::call_site()); + tokens.extend(quote!(::rocket::http::Method::#variant)); } } diff --git a/core/codegen/src/lib.rs b/core/codegen/src/lib.rs index c375351f1c..aab19f3049 100644 --- a/core/codegen/src/lib.rs +++ b/core/codegen/src/lib.rs @@ -1497,3 +1497,10 @@ pub fn internal_guide_tests(input: TokenStream) -> TokenStream { pub fn export(input: TokenStream) -> TokenStream { emit!(bang::export_internal(input)) } + +/// Private Rocket attribute: `async_bound(Bounds + On + Returned + Future)`. +#[doc(hidden)] +#[proc_macro_attribute] +pub fn async_bound(args: TokenStream, input: TokenStream) -> TokenStream { + emit!(attribute::async_bound::async_bound(args, input)) +} diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index 6c29fa1762..c5f6a309d5 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -17,43 +17,22 @@ rust-version = "1.64" [features] default = [] -tls = ["rustls", "tokio-rustls", "rustls-pemfile"] -mtls = ["tls", "x509-parser"] -http2 = ["hyper/http2"] -private-cookies = ["cookie/private", "cookie/key-expansion"] serde = ["uncased/with-serde-alloc", "serde_"] uuid = ["uuid_"] [dependencies] smallvec = { version = "1.11", features = ["const_generics", "const_new"] } percent-encoding = "2" -http = "0.2" time = { version = "0.3", features = ["formatting", "macros"] } indexmap = "2" -rustls = { version = "0.22", optional = true } -tokio-rustls = { version = "0.25", optional = true } -rustls-pemfile = { version = "2.0.0", optional = true } -tokio = { version = "1.6.1", features = ["net", "sync", "time"] } -log = "0.4" ref-cast = "1.0" -uncased = "0.9.6" +uncased = "0.9.10" either = "1" pear = "0.2.8" -pin-project-lite = "0.2" memchr = "2" stable-pattern = "0.1" cookie = { version = "0.18", features = ["percent-encode"] } state = "0.6" -futures = { version = "0.3", default-features = false } - -[dependencies.x509-parser] -version = "0.13" -optional = true - -[dependencies.hyper] -version = "0.14.9" -default-features = false -features = ["http1", "runtime", "server", "stream"] [dependencies.serde_] package = "serde" diff --git a/core/http/src/header/header.rs b/core/http/src/header/header.rs index e51be2097d..8a76b6a3f3 100644 --- a/core/http/src/header/header.rs +++ b/core/http/src/header/header.rs @@ -745,8 +745,7 @@ impl<'h> HeaderMap<'h> { /// WARNING: This is unstable! Do not use this method outside of Rocket! #[doc(hidden)] #[inline] - pub fn into_iter_raw(self) - -> impl Iterator, Vec>)> { + pub fn into_iter_raw(self) -> impl Iterator, Vec>)> { self.headers.into_iter() } } diff --git a/core/http/src/hyper.rs b/core/http/src/hyper.rs deleted file mode 100644 index 2e98e1f01f..0000000000 --- a/core/http/src/hyper.rs +++ /dev/null @@ -1,35 +0,0 @@ -//! Re-exported hyper HTTP library types. -//! -//! All types that are re-exported from Hyper reside inside of this module. -//! These types will, with certainty, be removed with time, but they reside here -//! while necessary. - -pub use hyper::{Method, Error, Body, Uri, Version, Request, Response}; -pub use hyper::{body, server, service, upgrade}; -pub use http::{HeaderValue, request, uri}; - -/// Reexported Hyper HTTP header types. -pub mod header { - macro_rules! import_http_headers { - ($($name:ident),*) => ($( - pub use hyper::header::$name as $name; - )*) - } - - import_http_headers! { - ACCEPT, ACCEPT_CHARSET, ACCEPT_ENCODING, ACCEPT_LANGUAGE, ACCEPT_RANGES, - ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, - ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, - ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, - ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, ALLOW, - AUTHORIZATION, CACHE_CONTROL, CONNECTION, CONTENT_DISPOSITION, - CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_LOCATION, - CONTENT_RANGE, CONTENT_SECURITY_POLICY, - CONTENT_SECURITY_POLICY_REPORT_ONLY, CONTENT_TYPE, DATE, ETAG, EXPECT, - EXPIRES, FORWARDED, FROM, HOST, IF_MATCH, IF_MODIFIED_SINCE, - IF_NONE_MATCH, IF_RANGE, IF_UNMODIFIED_SINCE, LAST_MODIFIED, LINK, - LOCATION, ORIGIN, PRAGMA, RANGE, REFERER, REFERRER_POLICY, REFRESH, - STRICT_TRANSPORT_SECURITY, TE, TRANSFER_ENCODING, UPGRADE, USER_AGENT, - VARY - } -} diff --git a/core/http/src/lib.rs b/core/http/src/lib.rs index 7ab89758d6..86935cce74 100644 --- a/core/http/src/lib.rs +++ b/core/http/src/lib.rs @@ -4,15 +4,11 @@ //! Types that map to concepts in HTTP. //! //! This module exports types that map to HTTP concepts or to the underlying -//! HTTP library when needed. Because the underlying HTTP library is likely to -//! change (see [#17]), types in [`hyper`] should be considered unstable. -//! -//! [#17]: https://github.com/rwf2/Rocket/issues/17 +//! HTTP library when needed. #[macro_use] extern crate pear; -pub mod hyper; pub mod uri; pub mod ext; @@ -22,7 +18,6 @@ mod method; mod status; mod raw_str; mod parse; -mod listener; /// Case-preserving, ASCII case-insensitive string types. /// @@ -39,14 +34,8 @@ pub mod uncased { pub mod private { pub use crate::parse::Indexed; pub use smallvec::{SmallVec, Array}; - pub use crate::listener::{TcpListener, Incoming, Listener, Connection, Certificates}; - pub use cookie; } -#[doc(hidden)] -#[cfg(feature = "tls")] -pub mod tls; - pub use crate::method::Method; pub use crate::status::{Status, StatusClass}; pub use crate::raw_str::{RawStr, RawStrBuf}; diff --git a/core/http/src/listener.rs b/core/http/src/listener.rs deleted file mode 100644 index 956c8ec4a2..0000000000 --- a/core/http/src/listener.rs +++ /dev/null @@ -1,257 +0,0 @@ -use std::fmt; -use std::future::Future; -use std::io; -use std::net::SocketAddr; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; -use std::sync::Arc; - -use log::warn; -use tokio::time::Sleep; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpStream; -use hyper::server::accept::Accept; -use state::InitCell; - -pub use tokio::net::TcpListener; - -/// A thin wrapper over raw, DER-encoded X.509 client certificate data. -#[cfg(not(feature = "tls"))] -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct CertificateDer(pub(crate) Vec); - -/// A thin wrapper over raw, DER-encoded X.509 client certificate data. -#[cfg(feature = "tls")] -#[derive(Debug, Clone, Eq, PartialEq)] -#[repr(transparent)] -pub struct CertificateDer(pub(crate) rustls::pki_types::CertificateDer<'static>); - -/// A collection of raw certificate data. -#[derive(Clone, Default)] -pub struct Certificates(Arc>>); - -impl From> for Certificates { - fn from(value: Vec) -> Self { - Certificates(Arc::new(value.into())) - } -} - -#[cfg(feature = "tls")] -impl From>> for Certificates { - fn from(value: Vec>) -> Self { - let value: Vec<_> = value.into_iter().map(CertificateDer).collect(); - Certificates(Arc::new(value.into())) - } -} - -#[doc(hidden)] -impl Certificates { - /// Set the the raw certificate chain data. Only the first call actually - /// sets the data; the remaining do nothing. - #[cfg(feature = "tls")] - pub(crate) fn set(&self, data: Vec) { - self.0.set(data); - } - - /// Returns the raw certificate chain data, if any is available. - pub fn chain_data(&self) -> Option<&[CertificateDer]> { - self.0.try_get().map(|v| v.as_slice()) - } -} - -// TODO.async: 'Listener' and 'Connection' provide common enough functionality -// that they could be introduced in upstream libraries. -/// A 'Listener' yields incoming connections -pub trait Listener { - /// The connection type returned by this listener. - type Connection: Connection; - - /// Return the actual address this listener bound to. - fn local_addr(&self) -> Option; - - /// Try to accept an incoming Connection if ready. This should only return - /// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`. - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll>; -} - -/// A 'Connection' represents an open connection to a client -pub trait Connection: AsyncRead + AsyncWrite { - /// The remote address, i.e. the client's socket address, if it is known. - fn peer_address(&self) -> Option; - - /// Requests that the connection not delay reading or writing data as much - /// as possible. For connections backed by TCP, this corresponds to setting - /// `TCP_NODELAY`. - fn enable_nodelay(&self) -> io::Result<()>; - - /// DER-encoded X.509 certificate chain presented by the client, if any. - /// - /// The certificate order must be as it appears in the TLS protocol: the - /// first certificate relates to the peer, the second certifies the first, - /// the third certifies the second, and so on. - /// - /// Defaults to an empty vector to indicate that no certificates were - /// presented. - fn peer_certificates(&self) -> Option { None } -} - -pin_project_lite::pin_project! { - /// This is a generic version of hyper's AddrIncoming that is intended to be - /// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix - /// sockets. It does so by bridging the `Listener` trait to what hyper wants (an - /// Accept). This type is internal to Rocket. - #[must_use = "streams do nothing unless polled"] - pub struct Incoming { - sleep_on_errors: Option, - nodelay: bool, - #[pin] - pending_error_delay: Option, - #[pin] - listener: L, - } -} - -impl Incoming { - /// Construct an `Incoming` from an existing `Listener`. - pub fn new(listener: L) -> Self { - Self { - listener, - sleep_on_errors: Some(Duration::from_millis(250)), - pending_error_delay: None, - nodelay: false, - } - } - - /// Set whether and how long to sleep on accept errors. - /// - /// A possible scenario is that the process has hit the max open files - /// allowed, and so trying to accept a new connection will fail with - /// `EMFILE`. In some cases, it's preferable to just wait for some time, if - /// the application will likely close some files (or connections), and try - /// to accept the connection again. If this option is `true`, the error - /// will be logged at the `error` level, since it is still a big deal, - /// and then the listener will sleep for 1 second. - /// - /// In other cases, hitting the max open files should be treat similarly - /// to being out-of-memory, and simply error (and shutdown). Setting - /// this option to `None` will allow that. - /// - /// Default is 1 second. - pub fn sleep_on_errors(mut self, val: Option) -> Self { - self.sleep_on_errors = val; - self - } - - /// Set whether to request no delay on all incoming connections. The default - /// is `false`. See [`Connection::enable_nodelay()`] for details. - pub fn nodelay(mut self, nodelay: bool) -> Self { - self.nodelay = nodelay; - self - } - - fn poll_accept_next( - self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { - /// This function defines per-connection errors: errors that affect only - /// a single connection's accept() and don't imply anything about the - /// success probability of the next accept(). Thus, we can attempt to - /// `accept()` another connection immediately. All other errors will - /// incur a delay before the next `accept()` is performed. The delay is - /// useful to handle resource exhaustion errors like ENFILE and EMFILE. - /// Otherwise, could enter into tight loop. - fn is_connection_error(e: &io::Error) -> bool { - matches!(e.kind(), - | io::ErrorKind::ConnectionRefused - | io::ErrorKind::ConnectionAborted - | io::ErrorKind::ConnectionReset) - } - - let mut this = self.project(); - loop { - // Check if a previous sleep timer is active, set on I/O errors. - if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() { - futures::ready!(delay.poll(cx)); - } - - this.pending_error_delay.set(None); - - match futures::ready!(this.listener.as_mut().poll_accept(cx)) { - Ok(stream) => { - if *this.nodelay { - if let Err(e) = stream.enable_nodelay() { - warn!("failed to enable NODELAY: {}", e); - } - } - - return Poll::Ready(Ok(stream)); - }, - Err(e) => { - if is_connection_error(&e) { - warn!("single connection accept error {}; accepting next now", e); - } else if let Some(duration) = this.sleep_on_errors { - // We might be able to recover. Try again in a bit. - warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis()); - this.pending_error_delay.set(Some(tokio::time::sleep(*duration))); - } else { - return Poll::Ready(Err(e)); - } - }, - } - } - } -} - -impl Accept for Incoming { - type Conn = L::Connection; - type Error = io::Error; - - #[inline] - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll>> { - self.poll_accept_next(cx).map(Some) - } -} - -impl fmt::Debug for Incoming { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Incoming") - .field("listener", &self.listener) - .finish() - } -} - -impl Listener for TcpListener { - type Connection = TcpStream; - - #[inline] - fn local_addr(&self) -> Option { - self.local_addr().ok() - } - - #[inline] - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { - (*self).poll_accept(cx).map_ok(|(stream, _addr)| stream) - } -} - -impl Connection for TcpStream { - #[inline] - fn peer_address(&self) -> Option { - self.peer_addr().ok() - } - - #[inline] - fn enable_nodelay(&self) -> io::Result<()> { - self.set_nodelay(true) - } -} diff --git a/core/http/src/method.rs b/core/http/src/method.rs index 734d7681f7..a959fbc203 100644 --- a/core/http/src/method.rs +++ b/core/http/src/method.rs @@ -3,8 +3,6 @@ use std::str::FromStr; use self::Method::*; -use crate::hyper; - // TODO: Support non-standard methods, here and in codegen? /// Representation of HTTP methods. @@ -29,6 +27,7 @@ use crate::hyper; /// } /// # } /// ``` +#[repr(u8)] #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub enum Method { /// The `GET` variant. @@ -52,23 +51,6 @@ pub enum Method { } impl Method { - /// WARNING: This is unstable! Do not use this method outside of Rocket! - #[doc(hidden)] - pub fn from_hyp(method: &hyper::Method) -> Option { - match *method { - hyper::Method::GET => Some(Get), - hyper::Method::PUT => Some(Put), - hyper::Method::POST => Some(Post), - hyper::Method::DELETE => Some(Delete), - hyper::Method::OPTIONS => Some(Options), - hyper::Method::HEAD => Some(Head), - hyper::Method::TRACE => Some(Trace), - hyper::Method::CONNECT => Some(Connect), - hyper::Method::PATCH => Some(Patch), - _ => None, - } - } - /// Returns `true` if an HTTP request with the method represented by `self` /// always supports a payload. /// diff --git a/core/http/src/tls/listener.rs b/core/http/src/tls/listener.rs deleted file mode 100644 index 7ef76ebd8d..0000000000 --- a/core/http/src/tls/listener.rs +++ /dev/null @@ -1,235 +0,0 @@ -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::future::Future; -use std::net::SocketAddr; - -use tokio::net::{TcpListener, TcpStream}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream}; -use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier}; - -use crate::tls::util::{load_cert_chain, load_key, load_ca_certs}; -use crate::listener::{Connection, Listener, Certificates, CertificateDer}; - -/// A TLS listener over TCP. -pub struct TlsListener { - listener: TcpListener, - acceptor: TlsAcceptor, -} - -/// This implementation exists so that ROCKET_WORKERS=1 can make progress while -/// a TLS handshake is being completed. It does this by returning `Ready` from -/// `poll_accept()` as soon as we have a TCP connection and performing the -/// handshake in the `AsyncRead` and `AsyncWrite` implementations. -/// -/// A straight-forward implementation of this strategy results in none of the -/// TLS information being available at the time the connection is "established", -/// that is, when `poll_accept()` returns, since the handshake has yet to occur. -/// Importantly, certificate information isn't available at the time that we -/// request it. -/// -/// The underlying problem is hyper's "Accept" trait. Were we to manage -/// connections ourselves, we'd likely want to: -/// -/// 1. Stop blocking the worker as soon as we have a TCP connection. -/// 2. Perform the handshake in the background. -/// 3. Give the connection to Rocket when/if the handshake is done. -/// -/// See hyperium/hyper/issues/2321 for more details. -/// -/// To work around this, we "lie" when `peer_certificates()` are requested and -/// always return `Some(Certificates)`. Internally, `Certificates` is an -/// `Arc>>`, effectively a shared, thread-safe, -/// `OnceCell`. The cell is initially empty and is filled as soon as the -/// handshake is complete. If the certificate data were to be requested prior to -/// this point, it would be empty. However, in Rocket, we only request -/// certificate data when we have a `Request` object, which implies we're -/// receiving payload data, which implies the TLS handshake has finished, so the -/// certificate data as seen by a Rocket application will always be "fresh". -pub struct TlsStream { - remote: SocketAddr, - state: TlsState, - certs: Certificates, -} - -/// State of `TlsStream`. -pub enum TlsState { - /// The TLS handshake is taking place. We don't have a full connection yet. - Handshaking(Accept), - /// TLS handshake completed successfully; we're getting payload data. - Streaming(BareTlsStream), -} - -/// TLS as ~configured by `TlsConfig` in `rocket` core. -pub struct Config { - pub cert_chain: R, - pub private_key: R, - pub ciphersuites: Vec, - pub prefer_server_order: bool, - pub ca_certs: Option, - pub mandatory_mtls: bool, -} - -impl TlsListener { - pub async fn bind(addr: SocketAddr, mut c: Config) -> crate::tls::Result - where R: io::BufRead - { - let provider = rustls::crypto::CryptoProvider { - cipher_suites: c.ciphersuites, - ..rustls::crypto::ring::default_provider() - }; - - let verifier = match c.ca_certs { - Some(ref mut ca_certs) => { - let ca_roots = Arc::new(load_ca_certs(ca_certs)?); - let verifier = WebPkiClientVerifier::builder(ca_roots); - match c.mandatory_mtls { - true => verifier.build()?, - false => verifier.allow_unauthenticated().build()?, - } - }, - None => WebPkiClientVerifier::no_client_auth(), - }; - - let key = load_key(&mut c.private_key)?; - let cert_chain = load_cert_chain(&mut c.cert_chain)?; - let mut config = ServerConfig::builder_with_provider(Arc::new(provider)) - .with_safe_default_protocol_versions()? - .with_client_cert_verifier(verifier) - .with_single_cert(cert_chain, key)?; - - config.ignore_client_order = c.prefer_server_order; - config.session_storage = ServerSessionMemoryCache::new(1024); - config.ticketer = rustls::crypto::ring::Ticketer::new()?; - config.alpn_protocols = vec![b"http/1.1".to_vec()]; - if cfg!(feature = "http2") { - config.alpn_protocols.insert(0, b"h2".to_vec()); - } - - let listener = TcpListener::bind(addr).await?; - let acceptor = TlsAcceptor::from(Arc::new(config)); - Ok(TlsListener { listener, acceptor }) - } -} - -impl Listener for TlsListener { - type Connection = TlsStream; - - fn local_addr(&self) -> Option { - self.listener.local_addr().ok() - } - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { - match futures::ready!(self.listener.poll_accept(cx)) { - Ok((io, addr)) => Poll::Ready(Ok(TlsStream { - remote: addr, - state: TlsState::Handshaking(self.acceptor.accept(io)), - // These are empty and filled in after handshake is complete. - certs: Certificates::default(), - })), - Err(e) => Poll::Ready(Err(e)), - } - } -} - -impl Connection for TlsStream { - fn peer_address(&self) -> Option { - Some(self.remote) - } - - fn enable_nodelay(&self) -> io::Result<()> { - // If `Handshaking` is `None`, it either failed, so we returned an `Err` - // from `poll_accept()` and there's no connection to enable `NODELAY` - // on, or it succeeded, so we're in the `Streaming` stage and we have - // infallible access to the connection. - match &self.state { - TlsState::Handshaking(accept) => match accept.get_ref() { - None => Ok(()), - Some(s) => s.enable_nodelay(), - }, - TlsState::Streaming(stream) => stream.get_ref().0.enable_nodelay() - } - } - - fn peer_certificates(&self) -> Option { - Some(self.certs.clone()) - } -} - -impl TlsStream { - fn poll_accept_then( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut f: F - ) -> Poll> - where F: FnMut(&mut BareTlsStream, &mut Context<'_>) -> Poll> - { - loop { - match self.state { - TlsState::Handshaking(ref mut accept) => { - match futures::ready!(Pin::new(accept).poll(cx)) { - Ok(stream) => { - if let Some(peer_certs) = stream.get_ref().1.peer_certificates() { - self.certs.set(peer_certs.into_iter() - .map(|v| CertificateDer(v.clone().into_owned())) - .collect()); - } - - self.state = TlsState::Streaming(stream); - } - Err(e) => { - log::warn!("tls handshake with {} failed: {}", self.remote, e); - return Poll::Ready(Err(e)); - } - } - }, - TlsState::Streaming(ref mut stream) => return f(stream, cx), - } - } - } -} - -impl AsyncRead for TlsStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_read(cx, buf)) - } -} - -impl AsyncWrite for TlsStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_write(cx, buf)) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut self.state { - TlsState::Handshaking(accept) => match accept.get_mut() { - Some(io) => Pin::new(io).poll_flush(cx), - None => Poll::Ready(Ok(())), - } - TlsState::Streaming(stream) => Pin::new(stream).poll_flush(cx), - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut self.state { - TlsState::Handshaking(accept) => match accept.get_mut() { - Some(io) => Pin::new(io).poll_shutdown(cx), - None => Poll::Ready(Ok(())), - } - TlsState::Streaming(stream) => Pin::new(stream).poll_shutdown(cx), - } - } -} diff --git a/core/http/src/tls/mod.rs b/core/http/src/tls/mod.rs deleted file mode 100644 index 8d3bcb3d67..0000000000 --- a/core/http/src/tls/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod listener; - -#[cfg(feature = "mtls")] -pub mod mtls; - -pub use rustls; -pub use listener::{TlsListener, Config}; -pub mod util; -pub mod error; - -pub use error::Result; diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 01b2b370b1..6724b21d32 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -20,23 +20,36 @@ rust-version = "1.64" all-features = true [features] -default = ["http2"] -tls = ["rocket_http/tls"] -mtls = ["rocket_http/mtls", "tls"] -http2 = ["rocket_http/http2"] -secrets = ["rocket_http/private-cookies"] -json = ["serde_json", "tokio/io-util"] -msgpack = ["rmp-serde", "tokio/io-util"] +default = ["http2", "tokio-macros"] +http2 = ["hyper/http2", "hyper-util/http2"] +secrets = ["cookie/private", "cookie/key-expansion"] +json = ["serde_json"] +msgpack = ["rmp-serde"] uuid = ["uuid_", "rocket_http/uuid"] +tls = ["rustls", "tokio-rustls", "rustls-pemfile"] +mtls = ["tls", "x509-parser"] +tokio-macros = ["tokio/macros"] [dependencies] -# Serialization dependencies. +# Optional serialization dependencies. serde_json = { version = "1.0.26", optional = true } rmp-serde = { version = "1", optional = true } uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] } +# Optional TLS dependencies +rustls = { version = "0.22", optional = true } +tokio-rustls = { version = "0.25", optional = true } +rustls-pemfile = { version = "2.0.0", optional = true } + +# Optional MTLS dependencies +x509-parser = { version = "0.13", optional = true } + +# Hyper dependencies +http = "1" +bytes = "1.4" +hyper = { version = "1.1", default-features = false, features = ["http1", "server"] } + # Non-optional, core dependencies from here on out. -futures = { version = "0.3.0", default-features = false, features = ["std"] } yansi = { version = "1.0.0-rc", features = ["detect-tty"] } log = { version = "0.4", features = ["std"] } num_cpus = "1.0" @@ -44,11 +57,11 @@ time = { version = "0.3", features = ["macros", "parsing"] } memchr = "2" # TODO: Use pear instead. binascii = "0.1" ref-cast = "1.0" -atomic = "0.5" +ref-swap = "0.1.2" parking_lot = "0.12" ubyte = {version = "0.10.2", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } -figment = { version = "0.10.6", features = ["toml", "env"] } +figment = { version = "0.10.13", features = ["toml", "env"] } rand = "0.8" either = "1" pin-project-lite = "0.2" @@ -58,8 +71,25 @@ async-trait = "0.1.43" async-stream = "0.3.2" multer = { version = "3.0.0", features = ["tokio-io"] } tokio-stream = { version = "0.1.6", features = ["signal", "time"] } +cookie = { version = "0.18", features = ["percent-encode"] } +futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" +[dependencies.hyper-util] +git = "https://github.com/SergioBenitez/hyper-util.git" +branch = "fix-readversion" +default-features = false +features = ["http1", "server", "tokio"] + +[dependencies.tokio] +version = "1.35.1" +features = ["rt-multi-thread", "net", "io-util", "fs", "time", "sync", "signal", "parking_lot"] + +[dependencies.tokio-util] +version = "0.7" +default-features = false +features = ["io"] + [dependencies.rocket_codegen] version = "0.6.0-dev" path = "../codegen" @@ -69,21 +99,13 @@ version = "0.6.0-dev" path = "../http" features = ["serde"] -[dependencies.tokio] -version = "1.6.1" -features = ["fs", "io-std", "io-util", "rt-multi-thread", "sync", "signal", "macros"] - -[dependencies.tokio-util] -version = "0.7" -default-features = false -features = ["io"] - -[dependencies.bytes] -version = "1.0" +[target.'cfg(unix)'.dependencies] +libc = "0.2.149" [build-dependencies] version_check = "0.9.1" [dev-dependencies] +tokio = { version = "1", features = ["macros", "io-std"] } figment = { version = "0.10", features = ["test"] } pretty_assertions = "1" diff --git a/core/lib/src/config/config.rs b/core/lib/src/config/config.rs index 197b6a2f6c..e208944c76 100644 --- a/core/lib/src/config/config.rs +++ b/core/lib/src/config/config.rs @@ -1,5 +1,3 @@ -use std::net::{IpAddr, Ipv4Addr}; - use figment::{Figment, Profile, Provider, Metadata, error::Result}; use figment::providers::{Serialized, Env, Toml, Format}; use figment::value::{Map, Dict, magic::RelativePathBuf}; @@ -12,9 +10,6 @@ use crate::request::{self, Request, FromRequest}; use crate::http::uncased::Uncased; use crate::data::Limits; -#[cfg(feature = "tls")] -use crate::config::TlsConfig; - #[cfg(feature = "secrets")] use crate::config::SecretKey; @@ -66,10 +61,6 @@ pub struct Config { /// set to the extracting Figment's selected `Profile`._ #[serde(skip)] pub profile: Profile, - /// IP address to serve on. **(default: `127.0.0.1`)** - pub address: IpAddr, - /// Port to serve on. **(default: `8000`)** - pub port: u16, /// Number of threads to use for executing futures. **(default: `num_cores`)** /// /// _**Note:** Rocket only reads this value from sources in the [default @@ -121,10 +112,6 @@ pub struct Config { pub temp_dir: RelativePathBuf, /// Keep-alive timeout in seconds; disabled when `0`. **(default: `5`)** pub keep_alive: u32, - /// The TLS configuration, if any. **(default: `None`)** - #[cfg(feature = "tls")] - #[cfg_attr(nightly, doc(cfg(feature = "tls")))] - pub tls: Option, /// The secret key for signing and encrypting. **(default: `0`)** /// /// _**Note:** This field _always_ serializes as a 256-bit array of `0`s to @@ -148,7 +135,6 @@ pub struct Config { /// use rocket::Config; /// /// let config = Config { - /// port: 1024, /// keep_alive: 10, /// ..Default::default() /// }; @@ -204,8 +190,6 @@ impl Config { pub fn debug_default() -> Config { Config { profile: Self::DEBUG_PROFILE, - address: Ipv4Addr::new(127, 0, 0, 1).into(), - port: 8000, workers: num_cpus::get(), max_blocking: 512, ident: Ident::default(), @@ -214,8 +198,6 @@ impl Config { limits: Limits::default(), temp_dir: std::env::temp_dir().into(), keep_alive: 5, - #[cfg(feature = "tls")] - tls: None, #[cfg(feature = "secrets")] secret_key: SecretKey::zero(), shutdown: Shutdown::default(), @@ -331,59 +313,6 @@ impl Config { Self::try_from(provider).unwrap_or_else(bail_with_config_error) } - /// Returns `true` if TLS is enabled. - /// - /// TLS is enabled when the `tls` feature is enabled and TLS has been - /// configured with at least one ciphersuite. Note that without changing - /// defaults, all supported ciphersuites are enabled in the recommended - /// configuration. - /// - /// # Example - /// - /// ```rust - /// let config = rocket::Config::default(); - /// if config.tls_enabled() { - /// println!("TLS is enabled!"); - /// } else { - /// println!("TLS is disabled."); - /// } - /// ``` - pub fn tls_enabled(&self) -> bool { - #[cfg(feature = "tls")] { - self.tls.as_ref().map_or(false, |tls| !tls.ciphers.is_empty()) - } - - #[cfg(not(feature = "tls"))] { false } - } - - /// Returns `true` if mTLS is enabled. - /// - /// mTLS is enabled when TLS is enabled ([`Config::tls_enabled()`]) _and_ - /// the `mtls` feature is enabled _and_ mTLS has been configured with a CA - /// certificate chain. - /// - /// # Example - /// - /// ```rust - /// let config = rocket::Config::default(); - /// if config.mtls_enabled() { - /// println!("mTLS is enabled!"); - /// } else { - /// println!("mTLS is disabled."); - /// } - /// ``` - pub fn mtls_enabled(&self) -> bool { - if !self.tls_enabled() { - return false; - } - - #[cfg(feature = "mtls")] { - self.tls.as_ref().map_or(false, |tls| tls.mutual.is_some()) - } - - #[cfg(not(feature = "mtls"))] { false } - } - #[cfg(feature = "secrets")] pub(crate) fn known_secret_key_used(&self) -> bool { const KNOWN_SECRET_KEYS: &'static [&'static str] = &[ @@ -420,8 +349,6 @@ impl Config { self.trace_print(figment); launch_meta!("{}Configured for {}.", "🔧 ".emoji(), self.profile.underline()); - launch_meta_!("address: {}", self.address.paint(VAL)); - launch_meta_!("port: {}", self.port.paint(VAL)); launch_meta_!("workers: {}", self.workers.paint(VAL)); launch_meta_!("max blocking threads: {}", self.max_blocking.paint(VAL)); launch_meta_!("ident: {}", self.ident.paint(VAL)); @@ -445,12 +372,6 @@ impl Config { ka => launch_meta_!("keep-alive: {}{}", ka.paint(VAL), "s".paint(VAL)), } - match (self.tls_enabled(), self.mtls_enabled()) { - (true, true) => launch_meta_!("tls: {}", "enabled w/mtls".paint(VAL)), - (true, false) => launch_meta_!("tls: {} w/o mtls", "enabled".paint(VAL)), - (false, _) => launch_meta_!("tls: {}", "disabled".paint(VAL)), - } - launch_meta_!("shutdown: {}", self.shutdown.paint(VAL)); launch_meta_!("log level: {}", self.log_level.paint(VAL)); launch_meta_!("cli colors: {}", self.cli_colors.paint(VAL)); @@ -519,12 +440,6 @@ impl Config { /// This isn't `pub` because setting it directly does nothing. const PROFILE: &'static str = "profile"; - /// The stringy parameter name for setting/extracting [`Config::address`]. - pub const ADDRESS: &'static str = "address"; - - /// The stringy parameter name for setting/extracting [`Config::port`]. - pub const PORT: &'static str = "port"; - /// The stringy parameter name for setting/extracting [`Config::workers`]. pub const WORKERS: &'static str = "workers"; @@ -546,9 +461,6 @@ impl Config { /// The stringy parameter name for setting/extracting [`Config::limits`]. pub const LIMITS: &'static str = "limits"; - /// The stringy parameter name for setting/extracting [`Config::tls`]. - pub const TLS: &'static str = "tls"; - /// The stringy parameter name for setting/extracting [`Config::secret_key`]. pub const SECRET_KEY: &'static str = "secret_key"; @@ -566,9 +478,10 @@ impl Config { /// An array of all of the stringy parameter names. pub const PARAMETERS: &'static [&'static str] = &[ - Self::ADDRESS, Self::PORT, Self::WORKERS, Self::MAX_BLOCKING, Self::KEEP_ALIVE, - Self::IDENT, Self::IP_HEADER, Self::PROXY_PROTO_HEADER, Self::LIMITS, Self::TLS, - Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN, Self::CLI_COLORS, + Self::WORKERS, Self::MAX_BLOCKING, Self::KEEP_ALIVE, Self::IDENT, + Self::IP_HEADER, Self::PROXY_PROTO_HEADER, Self::LIMITS, + Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN, + Self::CLI_COLORS, ]; } diff --git a/core/lib/src/config/mod.rs b/core/lib/src/config/mod.rs index 2bf83cd264..86481af1fe 100644 --- a/core/lib/src/config/mod.rs +++ b/core/lib/src/config/mod.rs @@ -117,9 +117,6 @@ mod shutdown; mod cli_colors; mod http_header; -#[cfg(feature = "tls")] -mod tls; - #[cfg(feature = "secrets")] mod secret_key; @@ -132,12 +129,6 @@ pub use shutdown::Shutdown; pub use ident::Ident; pub use cli_colors::CliColors; -#[cfg(feature = "tls")] -pub use tls::{TlsConfig, CipherSuite}; - -#[cfg(feature = "mtls")] -pub use tls::MutualTls; - #[cfg(feature = "secrets")] pub use secret_key::SecretKey; @@ -146,7 +137,6 @@ pub use shutdown::Sig; #[cfg(test)] mod tests { - use std::net::Ipv4Addr; use figment::{Figment, Profile}; use pretty_assertions::assert_eq; @@ -202,9 +192,7 @@ mod tests { figment::Jail::expect_with(|jail| { jail.create_file("Rocket.toml", r#" [default] - address = "1.2.3.4" ident = "Something Cool" - port = 1234 workers = 20 keep_alive = 10 log_level = "off" @@ -213,8 +201,6 @@ mod tests { let config = Config::from(Config::figment()); assert_eq!(config, Config { - address: Ipv4Addr::new(1, 2, 3, 4).into(), - port: 1234, workers: 20, ident: ident!("Something Cool"), keep_alive: 10, @@ -225,9 +211,7 @@ mod tests { jail.create_file("Rocket.toml", r#" [global] - address = "1.2.3.4" ident = "Something Else Cool" - port = 1234 workers = 20 keep_alive = 10 log_level = "off" @@ -236,8 +220,6 @@ mod tests { let config = Config::from(Config::figment()); assert_eq!(config, Config { - address: Ipv4Addr::new(1, 2, 3, 4).into(), - port: 1234, workers: 20, ident: ident!("Something Else Cool"), keep_alive: 10, @@ -249,8 +231,6 @@ mod tests { jail.set_env("ROCKET_CONFIG", "Other.toml"); jail.create_file("Other.toml", r#" [default] - address = "1.2.3.4" - port = 1234 workers = 20 keep_alive = 10 log_level = "off" @@ -259,8 +239,6 @@ mod tests { let config = Config::from(Config::figment()); assert_eq!(config, Config { - address: Ipv4Addr::new(1, 2, 3, 4).into(), - port: 1234, workers: 20, keep_alive: 10, log_level: LogLevel::Off, @@ -367,228 +345,6 @@ mod tests { }) } - #[test] - #[cfg(feature = "tls")] - fn test_tls_config_from_file() { - use crate::config::{TlsConfig, CipherSuite, Ident, Shutdown}; - - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [global] - shutdown.ctrlc = 0 - ident = false - - [global.tls] - certs = "/ssl/cert.pem" - key = "/ssl/key.pem" - - [global.limits] - forms = "1mib" - json = "10mib" - stream = "50kib" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - shutdown: Shutdown { ctrlc: false, ..Default::default() }, - ident: Ident::none(), - tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")), - limits: Limits::default() - .limit("forms", 1.mebibytes()) - .limit("json", 10.mebibytes()) - .limit("stream", 50.kibibytes()), - ..Config::default() - }); - - jail.create_file("Rocket.toml", r#" - [global.tls] - certs = "cert.pem" - key = "key.pem" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - tls: Some(TlsConfig::from_paths( - jail.directory().join("cert.pem"), - jail.directory().join("key.pem") - )), - ..Config::default() - }); - - jail.create_file("Rocket.toml", r#" - [global.tls] - certs = "cert.pem" - key = "key.pem" - prefer_server_cipher_order = true - ciphers = [ - "TLS_CHACHA20_POLY1305_SHA256", - "TLS_AES_256_GCM_SHA384", - "TLS_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", - ] - "#)?; - - let config = Config::from(Config::figment()); - let cert_path = jail.directory().join("cert.pem"); - let key_path = jail.directory().join("key.pem"); - assert_eq!(config, Config { - tls: Some(TlsConfig::from_paths(cert_path, key_path) - .with_preferred_server_cipher_order(true) - .with_ciphers([ - CipherSuite::TLS_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_AES_256_GCM_SHA384, - CipherSuite::TLS_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - ])), - ..Config::default() - }); - - jail.create_file("Rocket.toml", r#" - [global] - shutdown.ctrlc = 0 - ident = false - - [global.tls] - certs = "/ssl/cert.pem" - key = "/ssl/key.pem" - - [global.limits] - forms = "1mib" - json = "10mib" - stream = "50kib" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - shutdown: Shutdown { ctrlc: false, ..Default::default() }, - ident: Ident::none(), - tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")), - limits: Limits::default() - .limit("forms", 1.mebibytes()) - .limit("json", 10.mebibytes()) - .limit("stream", 50.kibibytes()), - ..Config::default() - }); - - jail.create_file("Rocket.toml", r#" - [global.tls] - certs = "cert.pem" - key = "key.pem" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - tls: Some(TlsConfig::from_paths( - jail.directory().join("cert.pem"), - jail.directory().join("key.pem") - )), - ..Config::default() - }); - - jail.create_file("Rocket.toml", r#" - [global.tls] - certs = "cert.pem" - key = "key.pem" - prefer_server_cipher_order = true - ciphers = [ - "TLS_CHACHA20_POLY1305_SHA256", - "TLS_AES_256_GCM_SHA384", - "TLS_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", - ] - "#)?; - - let config = Config::from(Config::figment()); - let cert_path = jail.directory().join("cert.pem"); - let key_path = jail.directory().join("key.pem"); - assert_eq!(config, Config { - tls: Some(TlsConfig::from_paths(cert_path, key_path) - .with_preferred_server_cipher_order(true) - .with_ciphers([ - CipherSuite::TLS_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_AES_256_GCM_SHA384, - CipherSuite::TLS_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - ])), - ..Config::default() - }); - - Ok(()) - }); - } - - #[test] - #[cfg(feature = "mtls")] - fn test_mtls_config() { - use std::path::Path; - - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [default.tls] - certs = "/ssl/cert.pem" - key = "/ssl/key.pem" - "#)?; - - let config = Config::from(Config::figment()); - assert!(config.tls.is_some()); - assert!(config.tls.as_ref().unwrap().mutual.is_none()); - assert!(config.tls_enabled()); - assert!(!config.mtls_enabled()); - - jail.create_file("Rocket.toml", r#" - [default.tls] - certs = "/ssl/cert.pem" - key = "/ssl/key.pem" - mutual = { ca_certs = "/ssl/ca.pem" } - "#)?; - - let config = Config::from(Config::figment()); - assert!(config.tls_enabled()); - assert!(config.mtls_enabled()); - - let mtls = config.tls.as_ref().unwrap().mutual.as_ref().unwrap(); - assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem")); - assert!(!mtls.mandatory); - - jail.create_file("Rocket.toml", r#" - [default.tls] - certs = "/ssl/cert.pem" - key = "/ssl/key.pem" - - [default.tls.mutual] - ca_certs = "/ssl/ca.pem" - mandatory = true - "#)?; - - let config = Config::from(Config::figment()); - let mtls = config.tls.as_ref().unwrap().mutual.as_ref().unwrap(); - assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem")); - assert!(mtls.mandatory); - - jail.create_file("Rocket.toml", r#" - [default.tls] - certs = "/ssl/cert.pem" - key = "/ssl/key.pem" - mutual = { ca_certs = "relative/ca.pem" } - "#)?; - - let config = Config::from(Config::figment()); - let mtls = config.tls.as_ref().unwrap().mutual().unwrap(); - assert_eq!(mtls.ca_certs().unwrap_left(), - jail.directory().join("relative/ca.pem")); - - Ok(()) - }); - } - #[test] fn test_profiles_merge() { figment::Jail::expect_with(|jail| { @@ -629,42 +385,41 @@ mod tests { } #[test] - #[cfg(feature = "tls")] fn test_env_vars_merge() { - use crate::config::{TlsConfig, Ident}; + use crate::config::{Ident, Shutdown}; figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PORT", 9999); + jail.set_env("ROCKET_KEEP_ALIVE", 9999); let config = Config::from(Config::figment()); assert_eq!(config, Config { - port: 9999, + keep_alive: 9999, ..Config::default() }); - jail.set_env("ROCKET_TLS", r#"{certs="certs.pem"}"#); + jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#); let first_figment = Config::figment(); - jail.set_env("ROCKET_TLS", r#"{key="key.pem"}"#); + jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#); let prev_figment = Config::figment().join(&first_figment); let config = Config::from(&prev_figment); assert_eq!(config, Config { - port: 9999, - tls: Some(TlsConfig::from_paths("certs.pem", "key.pem")), + keep_alive: 9999, + shutdown: Shutdown { grace: 7, mercy: 10, ..Default::default() }, ..Config::default() }); - jail.set_env("ROCKET_TLS", r#"{certs="new.pem"}"#); + jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#); let config = Config::from(Config::figment().join(&prev_figment)); assert_eq!(config, Config { - port: 9999, - tls: Some(TlsConfig::from_paths("new.pem", "key.pem")), + keep_alive: 9999, + shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, ..Config::default() }); jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#); let config = Config::from(Config::figment().join(&prev_figment)); assert_eq!(config, Config { - port: 9999, - tls: Some(TlsConfig::from_paths("new.pem", "key.pem")), + keep_alive: 9999, + shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, limits: Limits::default().limit("stream", 100.kibibytes()), ..Config::default() }); @@ -672,8 +427,8 @@ mod tests { jail.set_env("ROCKET_IDENT", false); let config = Config::from(Config::figment().join(&prev_figment)); assert_eq!(config, Config { - port: 9999, - tls: Some(TlsConfig::from_paths("new.pem", "key.pem")), + keep_alive: 9999, + shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, limits: Limits::default().limit("stream", 100.kibibytes()), ident: Ident::none(), ..Config::default() diff --git a/core/lib/src/config/secret_key.rs b/core/lib/src/config/secret_key.rs index 07d804a323..46818c4f51 100644 --- a/core/lib/src/config/secret_key.rs +++ b/core/lib/src/config/secret_key.rs @@ -1,8 +1,8 @@ use std::fmt; +use cookie::Key; use serde::{de, ser, Deserialize, Serialize}; -use crate::http::private::cookie::Key; use crate::request::{Outcome, Request, FromRequest}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/core/lib/src/config/shutdown.rs b/core/lib/src/config/shutdown.rs index e445ca2404..2353a4fbae 100644 --- a/core/lib/src/config/shutdown.rs +++ b/core/lib/src/config/shutdown.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::{fmt, time::Duration}; #[cfg(unix)] use std::collections::HashSet; @@ -291,6 +291,14 @@ impl Default for Shutdown { } impl Shutdown { + pub(crate) fn grace(&self) -> Duration { + Duration::from_secs(self.grace as u64) + } + + pub(crate) fn mercy(&self) -> Duration { + Duration::from_secs(self.mercy as u64) + } + #[cfg(unix)] pub(crate) fn signal_stream(&self) -> Option> { use tokio_stream::{StreamExt, StreamMap, wrappers::SignalStream}; diff --git a/core/lib/src/data/data_stream.rs b/core/lib/src/data/data_stream.rs index f30f046e3b..77d033284a 100644 --- a/core/lib/src/data/data_stream.rs +++ b/core/lib/src/data/data_stream.rs @@ -3,16 +3,16 @@ use std::task::{Context, Poll}; use std::path::Path; use std::io::{self, Cursor}; +use futures::ready; +use futures::stream::Stream; use tokio::fs::File; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take}; use tokio_util::io::StreamReader; -use futures::{ready, stream::Stream}; +use hyper::body::{Body, Bytes, Incoming as HyperBody}; -use crate::http::hyper; -use crate::ext::{PollExt, Chain}; use crate::data::{Capped, N}; -use crate::http::hyper::body::Bytes; use crate::data::transform::Transform; +use crate::util::Chain; use super::peekable::Peekable; use super::transform::TransformBuf; @@ -68,7 +68,7 @@ pub type RawReader<'r> = StreamReader, Bytes>; /// Raw underlying data stream. pub enum RawStream<'r> { Empty, - Body(&'r mut hyper::Body), + Body(&'r mut HyperBody), Multipart(multer::Field<'r>), } @@ -154,8 +154,14 @@ impl<'r> DataStream<'r> { /// ``` pub fn hint(&self) -> usize { let base = self.base(); - let buf_len = base.get_ref().get_ref().0.get_ref().len(); - std::cmp::min(buf_len, base.limit() as usize) + if let (Some(cursor), _) = base.get_ref().get_ref() { + let len = cursor.get_ref().len() as u64; + let position = cursor.position().min(len); + let remaining = len - position; + remaining.min(base.limit()) as usize + } else { + 0 + } } /// A helper method to write the body of the request to any `AsyncWrite` @@ -331,17 +337,25 @@ impl Stream for RawStream<'_> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - RawStream::Body(body) => Pin::new(body).poll_next(cx) - .map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)), - RawStream::Multipart(mp) => Pin::new(mp).poll_next(cx) - .map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)), + // TODO: Expose trailer headers, somehow. + RawStream::Body(body) => { + Pin::new(body) + .poll_frame(cx) + .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new())) + .map_err(io::Error::other) + } + RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other), RawStream::Empty => Poll::Ready(None), } } fn size_hint(&self) -> (usize, Option) { match self { - RawStream::Body(body) => body.size_hint(), + RawStream::Body(body) => { + let hint = body.size_hint(); + let (lower, upper) = (hint.lower(), hint.upper()); + (lower as usize, upper.map(|x| x as usize)) + }, RawStream::Multipart(mp) => mp.size_hint(), RawStream::Empty => (0, Some(0)), } @@ -358,8 +372,8 @@ impl std::fmt::Display for RawStream<'_> { } } -impl<'r> From<&'r mut hyper::Body> for RawStream<'r> { - fn from(value: &'r mut hyper::Body) -> Self { +impl<'r> From<&'r mut HyperBody> for RawStream<'r> { + fn from(value: &'r mut HyperBody) -> Self { Self::Body(value) } } diff --git a/core/lib/src/data/io_stream.rs b/core/lib/src/data/io_stream.rs index 0945c5c0f1..595138431a 100644 --- a/core/lib/src/data/io_stream.rs +++ b/core/lib/src/data/io_stream.rs @@ -3,8 +3,8 @@ use std::task::{Context, Poll}; use std::pin::Pin; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -use crate::http::hyper::upgrade::Upgraded; +use hyper::upgrade::Upgraded; +use hyper_util::rt::TokioIo; /// A bidirectional, raw stream to the client. /// @@ -28,7 +28,7 @@ pub struct IoStream { /// Just in case we want to add stream kinds in the future. enum IoStreamKind { - Upgraded(Upgraded) + Upgraded(TokioIo) } /// An upgraded connection I/O handler. @@ -51,7 +51,7 @@ enum IoStreamKind { /// /// #[rocket::async_trait] /// impl IoHandler for EchoHandler { -/// async fn io(self: Pin>, io: IoStream) -> io::Result<()> { +/// async fn io(self: Box, io: IoStream) -> io::Result<()> { /// let (mut reader, mut writer) = io::split(io); /// io::copy(&mut reader, &mut writer).await?; /// Ok(()) @@ -68,13 +68,20 @@ enum IoStreamKind { #[crate::async_trait] pub trait IoHandler: Send { /// Performs the raw I/O. - async fn io(self: Pin>, io: IoStream) -> io::Result<()>; + async fn io(self: Box, io: IoStream) -> io::Result<()>; +} + +#[crate::async_trait] +impl IoHandler for () { + async fn io(self: Box, _: IoStream) -> io::Result<()> { + Ok(()) + } } #[doc(hidden)] impl From for IoStream { fn from(io: Upgraded) -> Self { - IoStream { kind: IoStreamKind::Upgraded(io) } + IoStream { kind: IoStreamKind::Upgraded(TokioIo::new(io)) } } } diff --git a/core/lib/src/data/transform.rs b/core/lib/src/data/transform.rs index e3be992c76..f52478e6c1 100644 --- a/core/lib/src/data/transform.rs +++ b/core/lib/src/data/transform.rs @@ -178,7 +178,7 @@ impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> { #[allow(deprecated)] mod tests { use std::hash::SipHasher; - use std::sync::{Arc, atomic::{AtomicU64, AtomicU8}}; + use std::sync::{Arc, atomic::{AtomicU8, AtomicU64, Ordering}}; use parking_lot::Mutex; use ubyte::ToByteUnit; @@ -264,41 +264,41 @@ mod tests { assert_eq!(bytes.len(), 8); let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]"); let value = u64::from_be_bytes(bytes); - hash1.store(value, atomic::Ordering::Release); + hash1.store(value, Ordering::Release); }) .chain_inspect(move |bytes| { assert_eq!(bytes.len(), 8); let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]"); let value = u64::from_be_bytes(bytes); - let prev = hash2.load(atomic::Ordering::Acquire); + let prev = hash2.load(Ordering::Acquire); assert_eq!(prev, value); - inspect2.fetch_add(1, atomic::Ordering::Release); + inspect2.fetch_add(1, Ordering::Release); }); }))); // Make sure nothing has happened yet. assert!(raw_data.lock().is_empty()); - assert_eq!(hash.load(atomic::Ordering::Acquire), 0); - assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0); + assert_eq!(hash.load(Ordering::Acquire), 0); + assert_eq!(inspect2.load(Ordering::Acquire), 0); // Check that nothing happens if the data isn't read. let client = Client::debug(rocket).unwrap(); client.get("/").body("Hello, world!").dispatch(); assert!(raw_data.lock().is_empty()); - assert_eq!(hash.load(atomic::Ordering::Acquire), 0); - assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0); + assert_eq!(hash.load(Ordering::Acquire), 0); + assert_eq!(inspect2.load(Ordering::Acquire), 0); // Check inspect + hash + inspect + inspect. client.post("/").body("Hello, world!").dispatch(); assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes()); - assert_eq!(hash.load(atomic::Ordering::Acquire), 0xae5020d7cf49d14f); - assert_eq!(inspect2.load(atomic::Ordering::Acquire), 1); + assert_eq!(hash.load(Ordering::Acquire), 0xae5020d7cf49d14f); + assert_eq!(inspect2.load(Ordering::Acquire), 1); // Check inspect + hash + inspect + inspect, round 2. let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!"; client.post("/").body(string).dispatch(); assert_eq!(raw_data.lock().as_slice(), string.as_bytes()); - assert_eq!(hash.load(atomic::Ordering::Acquire), 0x323f9aa98f907faf); - assert_eq!(inspect2.load(atomic::Ordering::Acquire), 2); + assert_eq!(hash.load(Ordering::Acquire), 0x323f9aa98f907faf); + assert_eq!(inspect2.load(Ordering::Acquire), 2); } } diff --git a/core/lib/src/erased.rs b/core/lib/src/erased.rs new file mode 100644 index 0000000000..7b62522c55 --- /dev/null +++ b/core/lib/src/erased.rs @@ -0,0 +1,193 @@ +use std::io; +use std::mem::transmute; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Poll, Context}; + +use futures::future::BoxFuture; +use http::request::Parts; +use hyper::body::Incoming; +use tokio::io::{AsyncRead, ReadBuf}; + +use crate::data::{Data, IoHandler}; +use crate::{Request, Response, Rocket, Orbit}; + +// TODO: Magic with trait async fn to get rid of the box pin. +// TODO: Write safety proofs. + +macro_rules! static_assert_covariance { + ($T:tt) => ( + const _: () = { + fn _assert_covariance<'x: 'y, 'y>(x: &'y $T<'x>) -> &'y $T<'y> { x } + }; + ) +} + +#[derive(Debug)] +pub struct ErasedRequest { + // XXX: SAFETY: This (dependent) field must come first due to drop order! + request: Request<'static>, + _rocket: Arc>, + _parts: Box, +} + +impl Drop for ErasedRequest { + fn drop(&mut self) { } +} + +#[derive(Debug)] +pub struct ErasedResponse { + // XXX: SAFETY: This (dependent) field must come first due to drop order! + response: Response<'static>, + _request: Arc, + _incoming: Box, +} + +impl Drop for ErasedResponse { + fn drop(&mut self) { } +} + +pub struct ErasedIoHandler { + // XXX: SAFETY: This (dependent) field must come first due to drop order! + io: Box, + _request: Arc, +} + +impl Drop for ErasedIoHandler { + fn drop(&mut self) { } +} + +impl ErasedRequest { + pub fn new( + rocket: Arc>, + parts: Parts, + constructor: impl for<'r> FnOnce( + &'r Rocket, + &'r Parts + ) -> Request<'r>, + ) -> ErasedRequest { + let rocket: Arc> = rocket; + let parts: Box = Box::new(parts); + let request: Request<'_> = { + let rocket: &Rocket = &*rocket; + let rocket: &'static Rocket = unsafe { transmute(rocket) }; + let parts: &Parts = &*parts; + let parts: &'static Parts = unsafe { transmute(parts) }; + constructor(&rocket, &parts) + }; + + ErasedRequest { _rocket: rocket, _parts: parts, request, } + } + + pub async fn into_response( + self, + incoming: Incoming, + data_builder: impl for<'r> FnOnce(&'r mut Incoming) -> Data<'r>, + preprocess: impl for<'r, 'x> FnOnce( + &'r Rocket, + &'r mut Request<'x>, + &'r mut Data<'x> + ) -> BoxFuture<'r, T>, + dispatch: impl for<'r> FnOnce( + T, + &'r Rocket, + &'r Request<'r>, + Data<'r> + ) -> BoxFuture<'r, Response<'r>>, + ) -> ErasedResponse { + let mut incoming = Box::new(incoming); + let mut data: Data<'_> = { + let incoming: &mut Incoming = &mut *incoming; + let incoming: &'static mut Incoming = unsafe { transmute(incoming) }; + data_builder(incoming) + }; + + let mut parent = Arc::new(self); + let token: T = { + let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap(); + let rocket: &Rocket = &*parent._rocket; + let request: &mut Request<'_> = &mut parent.request; + let data: &mut Data<'_> = &mut data; + preprocess(rocket, request, data).await + }; + + let parent = parent; + let response: Response<'_> = { + let parent: &ErasedRequest = &*parent; + let parent: &'static ErasedRequest = unsafe { transmute(parent) }; + let rocket: &Rocket = &*parent._rocket; + let request: &Request<'_> = &parent.request; + dispatch(token, rocket, request, data).await + }; + + ErasedResponse { + _request: parent, + _incoming: incoming, + response: response, + } + } +} + +impl ErasedResponse { + pub fn inner<'a>(&'a self) -> &'a Response<'a> { + static_assert_covariance!(Response); + &self.response + } + + pub fn with_inner_mut<'a, T>( + &'a mut self, + f: impl for<'r> FnOnce(&'a mut Response<'r>) -> T + ) -> T { + static_assert_covariance!(Response); + f(&mut self.response) + } + + pub fn to_io_handler<'a>( + &'a mut self, + constructor: impl for<'r> FnOnce( + &'r Request<'r>, + &'a mut Response<'r>, + ) -> Option> + ) -> Option { + let parent: Arc = self._request.clone(); + let io: Option> = { + let parent: &ErasedRequest = &*parent; + let parent: &'static ErasedRequest = unsafe { transmute(parent) }; + let request: &Request<'_> = &parent.request; + constructor(request, &mut self.response) + }; + + io.map(|io| ErasedIoHandler { _request: parent, io }) + } +} + +impl ErasedIoHandler { + pub fn with_inner_mut<'a, T: 'a>( + &'a mut self, + f: impl for<'r> FnOnce(&'a mut Box) -> T + ) -> T { + fn _assert_covariance<'x: 'y, 'y>( + x: &'y Box + ) -> &'y Box { x } + + f(&mut self.io) + } + + pub fn take<'a>(&'a mut self) -> Box { + fn _assert_covariance<'x: 'y, 'y>( + x: &'y Box + ) -> &'y Box { x } + + self.with_inner_mut(|handler| std::mem::replace(handler, Box::new(()))) + } +} + +impl AsyncRead for ErasedResponse { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.get_mut().with_inner_mut(|r| Pin::new(r.body_mut()).poll_read(cx, buf)) + } +} diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index ff3ef79f61..21753b1f1b 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -74,11 +74,8 @@ pub struct Error { #[derive(Debug)] #[non_exhaustive] pub enum ErrorKind { - /// Binding to the provided address/port failed. - Bind(io::Error), - /// Binding via TLS to the provided address/port failed. - #[cfg(feature = "tls")] - TlsBind(crate::http::tls::error::Error), + /// Binding to the network interface failed. + Bind(Box), /// An I/O error occurred during launch. Io(io::Error), /// A valid [`Config`](crate::Config) could not be extracted from the @@ -90,15 +87,10 @@ pub enum ErrorKind { FailedFairings(Vec), /// Sentinels requested abort. SentinelAborts(Vec), - /// The configuration profile is not debug but not secret key is configured. + /// The configuration profile is not debug but no secret key is configured. InsecureSecretKey(Profile), - /// Shutdown failed. - Shutdown( - /// The instance of Rocket that failed to shutdown. - Arc>, - /// The error that occurred during shutdown, if any. - Option> - ), + /// Shutdown failed. Contains the Rocket instance that failed to shutdown. + Shutdown(Arc>), } /// An error that occurs when a value was unexpectedly empty. @@ -111,20 +103,24 @@ impl From for Error { } } +impl From for Error { + fn from(e: figment::Error) -> Self { + Error::new(ErrorKind::Config(e)) + } +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Error::new(ErrorKind::Io(e)) + } +} + impl Error { #[inline(always)] pub(crate) fn new(kind: ErrorKind) -> Error { Error { handled: AtomicBool::new(false), kind } } - #[inline(always)] - pub(crate) fn shutdown(rocket: Arc>, error: E) -> Error - where E: Into> - { - let error = error.into().map(|e| Box::new(e) as Box); - Error::new(ErrorKind::Shutdown(rocket, error)) - } - #[inline(always)] fn was_handled(&self) -> bool { self.handled.load(Ordering::Acquire) @@ -176,9 +172,9 @@ impl Error { self.mark_handled(); match self.kind() { ErrorKind::Bind(ref e) => { - error!("Rocket failed to bind network socket to given address/port."); + error!("Binding to the network interface failed."); info_!("{}", e); - "aborting due to socket bind error" + "aborting due to bind error" } ErrorKind::Io(ref e) => { error!("Rocket failed to launch due to an I/O error."); @@ -229,20 +225,10 @@ impl Error { "aborting due to sentinel-triggered abort(s)" } - ErrorKind::Shutdown(_, error) => { + ErrorKind::Shutdown(_) => { error!("Rocket failed to shutdown gracefully."); - if let Some(e) = error { - info_!("{}", e); - } - "aborting due to failed shutdown" } - #[cfg(feature = "tls")] - ErrorKind::TlsBind(e) => { - error!("Rocket failed to bind via TLS to network socket."); - info_!("{}", e); - "aborting due to TLS bind error" - } } } } @@ -260,10 +246,7 @@ impl fmt::Display for ErrorKind { ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f), ErrorKind::Config(_) => "failed to extract configuration".fmt(f), ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f), - ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {e}"), - ErrorKind::Shutdown(_, None) => "shutdown failed".fmt(f), - #[cfg(feature = "tls")] - ErrorKind::TlsBind(e) => write!(f, "TLS bind failed: {e}"), + ErrorKind::Shutdown(_) => "shutdown failed".fmt(f), } } } @@ -308,3 +291,42 @@ impl fmt::Display for Empty { } impl StdError for Empty { } + +/// Log an error that occurs during request processing +pub(crate) fn log_server_error(error: &Box) { + struct ServerError<'a>(&'a (dyn StdError + 'static)); + + impl fmt::Display for ServerError<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let error = &self.0; + if let Some(e) = error.downcast_ref::() { + write!(f, "request processing failed: {e}")?; + } else if let Some(e) = error.downcast_ref::() { + write!(f, "connection I/O error: ")?; + + match e.kind() { + io::ErrorKind::NotConnected => write!(f, "remote disconnected")?, + io::ErrorKind::UnexpectedEof => write!(f, "remote sent early eof")?, + io::ErrorKind::ConnectionReset + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::BrokenPipe => write!(f, "terminated by remote")?, + _ => write!(f, "{e}")?, + } + } else { + write!(f, "http server error: {error}")?; + } + + if let Some(e) = error.source() { + write!(f, " ({})", ServerError(e))?; + } + + Ok(()) + } + } + + if error.downcast_ref::().is_some() { + warn!("{}", ServerError(&**error)) + } else { + error!("{}", ServerError(&**error)) + } +} diff --git a/core/lib/src/ext.rs b/core/lib/src/ext.rs deleted file mode 100644 index 03922184df..0000000000 --- a/core/lib/src/ext.rs +++ /dev/null @@ -1,404 +0,0 @@ -use std::{io, time::Duration}; -use std::task::{Poll, Context}; -use std::pin::Pin; - -use bytes::{Bytes, BytesMut}; -use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio::time::{sleep, Sleep}; - -use futures::stream::Stream; -use futures::future::{self, Future, FutureExt}; - -pin_project! { - pub struct ReaderStream { - #[pin] - reader: Option, - buf: BytesMut, - cap: usize, - } -} - -impl Stream for ReaderStream { - type Item = std::io::Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - use tokio_util::io::poll_read_buf; - - let mut this = self.as_mut().project(); - - let reader = match this.reader.as_pin_mut() { - Some(r) => r, - None => return Poll::Ready(None), - }; - - if this.buf.capacity() == 0 { - this.buf.reserve(*this.cap); - } - - match poll_read_buf(reader, cx, &mut this.buf) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(err)) => { - self.project().reader.set(None); - Poll::Ready(Some(Err(err))) - } - Poll::Ready(Ok(0)) => { - self.project().reader.set(None); - Poll::Ready(None) - } - Poll::Ready(Ok(_)) => { - let chunk = this.buf.split(); - Poll::Ready(Some(Ok(chunk.freeze()))) - } - } - } -} - -pub trait AsyncReadExt: AsyncRead + Sized { - fn into_bytes_stream(self, cap: usize) -> ReaderStream { - ReaderStream { reader: Some(self), cap, buf: BytesMut::with_capacity(cap) } - } -} - -impl AsyncReadExt for T { } - -pub trait PollExt { - fn map_err_ext(self, f: F) -> Poll>> - where F: FnOnce(E) -> U; -} - -impl PollExt for Poll>> { - /// Changes the error value of this `Poll` with the closure provided. - fn map_err_ext(self, f: F) -> Poll>> - where F: FnOnce(E) -> U - { - match self { - Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(t))), - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(f(e)))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -pin_project! { - /// Stream for the [`chain`](super::AsyncReadExt::chain) method. - #[must_use = "streams do nothing unless polled"] - pub struct Chain { - #[pin] - first: T, - #[pin] - second: U, - done_first: bool, - } -} - -impl Chain { - pub(crate) fn new(first: T, second: U) -> Self { - Self { first, second, done_first: false } - } -} - -impl Chain { - /// Gets references to the underlying readers in this `Chain`. - pub fn get_ref(&self) -> (&T, &U) { - (&self.first, &self.second) - } -} - -impl AsyncRead for Chain { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let me = self.project(); - - if !*me.done_first { - let init_rem = buf.remaining(); - futures::ready!(me.first.poll_read(cx, buf))?; - if buf.remaining() == init_rem { - *me.done_first = true; - } else { - return Poll::Ready(Ok(())); - } - } - me.second.poll_read(cx, buf) - } -} - -enum State { - /// I/O has not been cancelled. Proceed as normal. - Active, - /// I/O has been cancelled. See if we can finish before the timer expires. - Grace(Pin>), - /// Grace period elapsed. Shutdown the connection, waiting for the timer - /// until we force close. - Mercy(Pin>), -} - -pin_project! { - /// I/O that can be cancelled when a future `F` resolves. - #[must_use = "futures do nothing unless polled"] - pub struct CancellableIo { - #[pin] - io: Option, - #[pin] - trigger: future::Fuse, - state: State, - grace: Duration, - mercy: Duration, - } -} - -impl CancellableIo { - pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self { - CancellableIo { - grace, mercy, - io: Some(io), - trigger: trigger.fuse(), - state: State::Active, - } - } - - pub fn io(&self) -> Option<&I> { - self.io.as_ref() - } - - /// Run `do_io` while connection processing should continue. - fn poll_trigger_then( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll>, - ) -> Poll> { - let mut me = self.as_mut().project(); - let io = match me.io.as_pin_mut() { - Some(io) => io, - None => return Poll::Ready(Err(gone())), - }; - - loop { - match me.state { - State::Active => { - if me.trigger.as_mut().poll(cx).is_ready() { - *me.state = State::Grace(Box::pin(sleep(*me.grace))); - } else { - return do_io(io, cx); - } - } - State::Grace(timer) => { - if timer.as_mut().poll(cx).is_ready() { - *me.state = State::Mercy(Box::pin(sleep(*me.mercy))); - } else { - return do_io(io, cx); - } - } - State::Mercy(timer) => { - if timer.as_mut().poll(cx).is_ready() { - self.project().io.set(None); - return Poll::Ready(Err(time_out())); - } else { - let result = futures::ready!(io.poll_shutdown(cx)); - self.project().io.set(None); - return match result { - Err(e) => Poll::Ready(Err(e)), - Ok(()) => Poll::Ready(Err(gone())) - }; - } - }, - } - } - } -} - -fn time_out() -> io::Error { - io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out") -} - -fn gone() -> io::Error { - io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated") -} - -impl AsyncRead for CancellableIo { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf)) - } -} - -impl AsyncWrite for CancellableIo { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf)) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx)) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx)) - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs)) - } - - fn is_write_vectored(&self) -> bool { - self.io().map(|io| io.is_write_vectored()).unwrap_or(false) - } -} - -use crate::http::private::{Listener, Connection, Certificates}; - -impl Connection for CancellableIo { - fn peer_address(&self) -> Option { - self.io().and_then(|io| io.peer_address()) - } - - fn peer_certificates(&self) -> Option { - self.io().and_then(|io| io.peer_certificates()) - } - - fn enable_nodelay(&self) -> io::Result<()> { - match self.io() { - Some(io) => io.enable_nodelay(), - None => Ok(()) - } - } -} - -pin_project! { - pub struct CancellableListener { - pub trigger: F, - #[pin] - pub listener: L, - pub grace: Duration, - pub mercy: Duration, - } -} - -impl CancellableListener { - pub fn new(trigger: F, listener: L, grace: u64, mercy: u64) -> Self { - let (grace, mercy) = (Duration::from_secs(grace), Duration::from_secs(mercy)); - CancellableListener { trigger, listener, grace, mercy } - } -} - -impl Listener for CancellableListener { - type Connection = CancellableIo; - - fn local_addr(&self) -> Option { - self.listener.local_addr() - } - - fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { - self.as_mut().project().listener - .poll_accept(cx) - .map(|res| res.map(|conn| { - CancellableIo::new(self.trigger.clone(), conn, self.grace, self.mercy) - })) - } -} - -pub trait StreamExt: Sized + Stream { - fn join(self, other: U) -> Join - where U: Stream; -} - -impl StreamExt for S { - fn join(self, other: U) -> Join - where U: Stream - { - Join::new(self, other) - } -} - -pin_project! { - /// Stream returned by the [`join`](super::StreamExt::join) method. - pub struct Join { - #[pin] - a: T, - #[pin] - b: U, - // When `true`, poll `a` first, otherwise, `poll` b`. - toggle: bool, - // Set when either `a` or `b` return `None`. - done: bool, - } -} - -impl Join { - pub(super) fn new(a: T, b: U) -> Join - where T: Stream, U: Stream, - { - Join { a, b, toggle: false, done: false, } - } - - fn poll_next>( - first: Pin<&mut A>, - second: Pin<&mut B>, - done: &mut bool, - cx: &mut Context<'_>, - ) -> Poll> { - match first.poll_next(cx) { - Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) } - Poll::Pending => match second.poll_next(cx) { - Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) } - Poll::Pending => Poll::Pending - } - } - } -} - -impl Stream for Join - where T: Stream, - U: Stream, -{ - type Item = T::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.done { - return Poll::Ready(None); - } - - let me = self.project(); - *me.toggle = !*me.toggle; - match *me.toggle { - true => Self::poll_next(me.a, me.b, me.done, cx), - false => Self::poll_next(me.b, me.a, me.done, cx), - } - } - - fn size_hint(&self) -> (usize, Option) { - let (left_low, left_high) = self.a.size_hint(); - let (right_low, right_high) = self.b.size_hint(); - - let low = left_low.saturating_add(right_low); - let high = match (left_high, right_high) { - (Some(h1), Some(h2)) => h1.checked_add(h2), - _ => None, - }; - - (low, high) - } -} diff --git a/core/lib/src/form/mod.rs b/core/lib/src/form/mod.rs index aa772a39a4..915a64f1b1 100644 --- a/core/lib/src/form/mod.rs +++ b/core/lib/src/form/mod.rs @@ -341,6 +341,7 @@ // `key_contexts: Vec`, a vector of `value_contexts: // Vec`, a `mapping` from a string index to an integer index // into the `contexts`, and a vector of `errors`. +// // 2. **Push.** An index is required; an error is emitted and `push` returns // if they field's first key does not contain an index. If the first key // contains _one_ index, a new `K::Context` and `V::Context` are created. @@ -356,9 +357,9 @@ // to `second` in `mapping`. If the first index is `k`, the field, // stripped of the first key, is pushed to the key's context; the same is // done for the value's context is the first index is `v`. +// // 3. **Finalization.** Every context is finalized; errors and `Ok` values -// are collected. TODO: FINISH. Split this into two: one for single-index, -// another for two-indices. +// are collected. mod field; mod options; diff --git a/core/lib/src/fs/named_file.rs b/core/lib/src/fs/named_file.rs index 8fd165f209..d4eed82a92 100644 --- a/core/lib/src/fs/named_file.rs +++ b/core/lib/src/fs/named_file.rs @@ -2,7 +2,7 @@ use std::io; use std::path::{Path, PathBuf}; use std::ops::{Deref, DerefMut}; -use tokio::fs::File; +use tokio::fs::{File, OpenOptions}; use crate::request::Request; use crate::response::{self, Responder}; @@ -60,7 +60,7 @@ impl NamedFile { /// } /// ``` pub async fn open>(path: P) -> io::Result { - // FIXME: Grab the file size here and prohibit `seek`ing later (or else + // TODO: Grab the file size here and prohibit `seek`ing later (or else // the file's effective size may change), to save on the cost of doing // all of those `seek`s to determine the file size. But, what happens if // the file gets changed between now and then? @@ -68,6 +68,11 @@ impl NamedFile { Ok(NamedFile(path.as_ref().to_path_buf(), file)) } + pub async fn open_with>(path: P, opts: &OpenOptions) -> io::Result { + let file = opts.open(path.as_ref()).await?; + Ok(NamedFile(path.as_ref().to_path_buf(), file)) + } + /// Retrieve the underlying `File`. /// /// # Example diff --git a/core/lib/src/cookies.rs b/core/lib/src/http/cookies.rs similarity index 97% rename from core/lib/src/cookies.rs rename to core/lib/src/http/cookies.rs index b64441f03d..1a17949ab9 100644 --- a/core/lib/src/cookies.rs +++ b/core/lib/src/http/cookies.rs @@ -2,11 +2,10 @@ use std::fmt; use parking_lot::Mutex; -use crate::http::private::cookie; use crate::{Rocket, Orbit}; #[doc(inline)] -pub use self::cookie::{Cookie, SameSite, Iter}; +pub use cookie::{Cookie, SameSite, Iter}; /// Collection of one or more HTTP cookies. /// @@ -167,7 +166,7 @@ pub(crate) struct CookieState<'a> { #[derive(Clone)] enum Op { Add(Cookie<'static>, bool), - Remove(Cookie<'static>, bool), + Remove(Cookie<'static>), } impl<'a> CookieJar<'a> { @@ -177,7 +176,7 @@ impl<'a> CookieJar<'a> { ops: Mutex::new(Vec::new()), state: CookieState { // This is updated dynamically when headers are received. - secure: rocket.config().tls_enabled(), + secure: rocket.endpoint().is_tls(), config: rocket.config(), } } @@ -256,7 +255,7 @@ impl<'a> CookieJar<'a> { for op in ops.iter().rev().filter(|op| op.cookie().name() == name) { match op { Op::Add(c, _) => return Some(c.clone()), - Op::Remove(_, _) => return None, + Op::Remove(_) => return None, } } @@ -389,7 +388,7 @@ impl<'a> CookieJar<'a> { pub fn remove>>(&self, cookie: C) { let mut cookie = cookie.into(); Self::set_removal_defaults(&mut cookie); - self.ops.lock().push(Op::Remove(cookie, false)); + self.ops.lock().push(Op::Remove(cookie)); } /// Removes the private `cookie` from the collection. @@ -432,7 +431,7 @@ impl<'a> CookieJar<'a> { pub fn remove_private>>(&self, cookie: C) { let mut cookie = cookie.into(); Self::set_removal_defaults(&mut cookie); - self.ops.lock().push(Op::Remove(cookie, true)); + self.ops.lock().push(Op::Remove(cookie)); } /// Returns an iterator over all of the _original_ cookies present in this @@ -477,7 +476,7 @@ impl<'a> CookieJar<'a> { Op::Add(c, true) => { jar.private_mut(&self.state.config.secret_key.key).add(c); } - Op::Remove(mut c, _) => { + Op::Remove(mut c) => { if self.jar.get(c.name()).is_some() { c.make_removal(); jar.add(c); @@ -595,7 +594,7 @@ impl<'a> Clone for CookieJar<'a> { impl Op { fn cookie(&self) -> &Cookie<'static> { match self { - Op::Add(c, _) | Op::Remove(c, _) => c + Op::Add(c, _) | Op::Remove(c) => c } } } diff --git a/core/lib/src/http/mod.rs b/core/lib/src/http/mod.rs new file mode 100644 index 0000000000..ac38395c1b --- /dev/null +++ b/core/lib/src/http/mod.rs @@ -0,0 +1,12 @@ +//! Types that map to concepts in HTTP. +//! +//! This module exports types that map to HTTP concepts or to the underlying +//! HTTP library when needed. + +mod cookies; + +#[doc(inline)] +pub use rocket_http::*; + +#[doc(inline)] +pub use cookies::*; diff --git a/core/lib/src/lib.rs b/core/lib/src/lib.rs index 232a981482..5ffee01a8f 100644 --- a/core/lib/src/lib.rs +++ b/core/lib/src/lib.rs @@ -7,7 +7,9 @@ #![cfg_attr(nightly, feature(decl_macro))] #![warn(rust_2018_idioms)] -#![warn(missing_docs)] +// #![warn(missing_docs)] +#![allow(async_fn_in_trait)] +#![allow(refining_impl_trait)] //! # Rocket - Core API Documentation //! @@ -109,18 +111,24 @@ /// These are public dependencies! Update docs if these are changed, especially /// figment's version number in docs. -#[doc(hidden)] pub use yansi; -#[doc(hidden)] pub use async_stream; +#[doc(hidden)] +pub use yansi; +#[doc(hidden)] +pub use async_stream; pub use futures; pub use tokio; pub use figment; pub use time; #[doc(hidden)] -#[macro_use] pub mod log; -#[macro_use] pub mod outcome; -#[macro_use] pub mod data; -#[doc(hidden)] pub mod sentinel; +#[macro_use] +pub mod log; +#[macro_use] +pub mod outcome; +#[macro_use] +pub mod data; +#[doc(hidden)] +pub mod sentinel; pub mod local; pub mod request; pub mod response; @@ -133,74 +141,41 @@ pub mod route; pub mod serde; pub mod shield; pub mod fs; - -// Reexport of HTTP everything. -pub mod http { - //! Types that map to concepts in HTTP. - //! - //! This module exports types that map to HTTP concepts or to the underlying - //! HTTP library when needed. - - #[doc(inline)] - pub use rocket_http::*; - - /// Re-exported hyper HTTP library types. - /// - /// All types that are re-exported from Hyper reside inside of this module. - /// These types will, with certainty, be removed with time, but they reside here - /// while necessary. - pub mod hyper { - #[doc(hidden)] - pub use rocket_http::hyper::*; - - pub use rocket_http::hyper::header; - } - - #[doc(inline)] - pub use crate::cookies::*; -} - +pub mod http; +pub mod listener; +#[cfg(feature = "tls")] +#[cfg_attr(nightly, doc(cfg(feature = "tls")))] +pub mod tls; #[cfg(feature = "mtls")] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] pub mod mtls; -/// TODO: We need a futures mod or something. -mod trip_wire; +mod util; mod shutdown; mod server; -mod ext; +mod lifecycle; mod state; -mod cookies; mod rocket; mod router; mod phase; +mod erased; + +#[doc(hidden)] pub use either::Either; + +#[doc(inline)] pub use rocket_codegen::*; #[doc(inline)] pub use crate::response::Response; #[doc(inline)] pub use crate::data::Data; #[doc(inline)] pub use crate::config::Config; #[doc(inline)] pub use crate::catcher::Catcher; #[doc(inline)] pub use crate::route::Route; -#[doc(hidden)] pub use either::Either; -#[doc(inline)] pub use phase::{Phase, Build, Ignite, Orbit}; -#[doc(inline)] pub use error::Error; -#[doc(inline)] pub use sentinel::Sentinel; +#[doc(inline)] pub use crate::phase::{Phase, Build, Ignite, Orbit}; +#[doc(inline)] pub use crate::error::Error; +#[doc(inline)] pub use crate::sentinel::Sentinel; #[doc(inline)] pub use crate::request::Request; #[doc(inline)] pub use crate::rocket::Rocket; #[doc(inline)] pub use crate::shutdown::Shutdown; #[doc(inline)] pub use crate::state::State; -#[doc(inline)] pub use rocket_codegen::*; - -/// Creates a [`Rocket`] instance with the default config provider: aliases -/// [`Rocket::build()`]. -pub fn build() -> Rocket { - Rocket::build() -} - -/// Creates a [`Rocket`] instance with a custom config provider: aliases -/// [`Rocket::custom()`]. -pub fn custom(provider: T) -> Rocket { - Rocket::custom(provider) -} /// Retrofits support for `async fn` in trait impls and declarations. /// @@ -231,6 +206,20 @@ pub fn custom(provider: T) -> Rocket { #[doc(inline)] pub use async_trait::async_trait; +const WORKER_PREFIX: &'static str = "rocket-worker"; + +/// Creates a [`Rocket`] instance with the default config provider: aliases +/// [`Rocket::build()`]. +pub fn build() -> Rocket { + Rocket::build() +} + +/// Creates a [`Rocket`] instance with a custom config provider: aliases +/// [`Rocket::custom()`]. +pub fn custom(provider: T) -> Rocket { + Rocket::custom(provider) +} + /// WARNING: This is unstable! Do not use this method outside of Rocket! #[doc(hidden)] pub fn async_run(fut: F, workers: usize, sync: usize, force_end: bool, name: &str) -> R @@ -255,7 +244,7 @@ pub fn async_run(fut: F, workers: usize, sync: usize, force_end: bool, nam /// WARNING: This is unstable! Do not use this method outside of Rocket! #[doc(hidden)] pub fn async_test(fut: impl std::future::Future) -> R { - async_run(fut, 1, 32, true, "rocket-worker-test-thread") + async_run(fut, 1, 32, true, &format!("{WORKER_PREFIX}-test-thread")) } /// WARNING: This is unstable! Do not use this method outside of Rocket! @@ -276,7 +265,7 @@ pub fn async_main(fut: impl std::future::Future + Send) -> R { let workers = fig.extract_inner(Config::WORKERS).unwrap_or_else(bail); let max_blocking = fig.extract_inner(Config::MAX_BLOCKING).unwrap_or_else(bail); let force = fig.focus(Config::SHUTDOWN).extract_inner("force").unwrap_or_else(bail); - async_run(fut, workers, max_blocking, force, "rocket-worker-thread") + async_run(fut, workers, max_blocking, force, &format!("{WORKER_PREFIX}-thread")) } /// Executes a `future` to completion on a new tokio-based Rocket async runtime. @@ -359,3 +348,14 @@ pub fn execute(future: F) -> R { async_main(future) } + +/// Returns a future that evalutes to `true` exactly when there is a presently +/// running tokio async runtime that was likely started by Rocket. +fn running_within_rocket_async_rt() -> impl std::future::Future { + use futures::FutureExt; + + tokio::task::spawn_blocking(|| { + let this = std::thread::current(); + this.name().map_or(false, |s| s.starts_with(WORKER_PREFIX)) + }).map(|r| r.unwrap_or(false)) +} diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs new file mode 100644 index 0000000000..1759af5c32 --- /dev/null +++ b/core/lib/src/lifecycle.rs @@ -0,0 +1,272 @@ +use yansi::Paint; +use futures::future::{FutureExt, Future}; + +use crate::{route, Rocket, Orbit, Request, Response, Data}; +use crate::data::IoHandler; +use crate::http::{Method, Status, Header}; +use crate::outcome::Outcome; +use crate::form::Form; + +// A token returned to force the execution of one method before another. +pub(crate) struct RequestToken; + +async fn catch_handle(name: Option<&str>, run: F) -> Option + where F: FnOnce() -> Fut, Fut: Future, +{ + macro_rules! panic_info { + ($name:expr, $e:expr) => {{ + match $name { + Some(name) => error_!("Handler {} panicked.", name.primary()), + None => error_!("A handler panicked.") + }; + + info_!("This is an application bug."); + info_!("A panic in Rust must be treated as an exceptional event."); + info_!("Panicking is not a suitable error handling mechanism."); + info_!("Unwinding, the result of a panic, is an expensive operation."); + info_!("Panics will degrade application performance."); + info_!("Instead of panicking, return `Option` and/or `Result`."); + info_!("Values of either type can be returned directly from handlers."); + warn_!("A panic is treated as an internal server error."); + $e + }} + } + + let run = std::panic::AssertUnwindSafe(run); + let fut = std::panic::catch_unwind(move || run()) + .map_err(|e| panic_info!(name, e)) + .ok()?; + + std::panic::AssertUnwindSafe(fut) + .catch_unwind() + .await + .map_err(|e| panic_info!(name, e)) + .ok() +} + +impl Rocket { + /// Preprocess the request for Rocket things. Currently, this means: + /// + /// * Rewriting the method in the request if _method form field exists. + /// * Run the request fairings. + /// + /// This is the only place during lifecycle processing that `Request` is + /// mutable. Keep this in-sync with the `FromForm` derive. + pub(crate) async fn preprocess( + &self, + req: &mut Request<'_>, + data: &mut Data<'_> + ) -> RequestToken { + // Check if this is a form and if the form contains the special _method + // field which we use to reinterpret the request's method. + let (min_len, max_len) = ("_method=get".len(), "_method=delete".len()); + let peek_buffer = data.peek(max_len).await; + let is_form = req.content_type().map_or(false, |ct| ct.is_form()); + + if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len { + let method = std::str::from_utf8(peek_buffer).ok() + .and_then(|raw_form| Form::values(raw_form).next()) + .filter(|field| field.name == "_method") + .and_then(|field| field.value.parse().ok()); + + if let Some(method) = method { + req.set_method(method); + } + } + + // Run request fairings. + self.fairings.handle_request(req, data).await; + + RequestToken + } + + /// Dispatches the request to the router and processes the outcome to + /// produce a response. If the initial outcome is a *forward* and the + /// request was a HEAD request, the request is rewritten and rerouted as a + /// GET. This is automatic HEAD handling. + /// + /// After performing the above, if the outcome is a forward or error, the + /// appropriate error catcher is invoked to produce the response. Otherwise, + /// the successful response is used directly. + /// + /// Finally, new cookies in the cookie jar are added to the response, + /// Rocket-specific headers are written, and response fairings are run. Note + /// that error responses have special cookie handling. See `handle_error`. + pub(crate) async fn dispatch<'r, 's: 'r>( + &'s self, + _token: RequestToken, + request: &'r Request<'s>, + data: Data<'r>, + // io_stream: impl Future> + Send, + ) -> Response<'r> { + info!("{}:", request); + + // Remember if the request is `HEAD` for later body stripping. + let was_head_request = request.method() == Method::Head; + + // Route the request and run the user's handlers. + let mut response = match self.route(request, data).await { + Outcome::Success(response) => response, + Outcome::Forward((data, _)) if request.method() == Method::Head => { + info_!("Autohandling {} request.", "HEAD".primary().bold()); + + // Dispatch the request again with Method `GET`. + request._set_method(Method::Get); + match self.route(request, data).await { + Outcome::Success(response) => response, + Outcome::Error(status) => self.dispatch_error(status, request).await, + Outcome::Forward((_, status)) => self.dispatch_error(status, request).await, + } + } + Outcome::Forward((_, status)) => self.dispatch_error(status, request).await, + Outcome::Error(status) => self.dispatch_error(status, request).await, + }; + + // Set the cookies. Note that error responses will only include cookies + // set by the error handler. See `handle_error` for more. + let delta_jar = request.cookies().take_delta_jar(); + for cookie in delta_jar.delta() { + response.adjoin_header(cookie); + } + + // Add a default 'Server' header if it isn't already there. + // TODO: If removing Hyper, write out `Date` header too. + if let Some(ident) = request.rocket().config.ident.as_str() { + if !response.headers().contains("Server") { + response.set_header(Header::new("Server", ident)); + } + } + + // Run the response fairings. + self.fairings.handle_response(request, &mut response).await; + + // Strip the body if this is a `HEAD` request. + if was_head_request { + response.strip_body(); + } + + // TODO: Should upgrades be handled here? We miss them on local clients. + response + } + + pub(crate) fn extract_io_handler<'r>( + request: &Request<'_>, + response: &mut Response<'r>, + // io_stream: impl Future> + Send, + ) -> Option> { + let upgrades = request.headers().get("upgrade"); + let Ok(upgrade) = response.search_upgrades(upgrades) else { + warn_!("Request wants upgrade but no I/O handler matched."); + info_!("Request is not being upgraded."); + return None; + }; + + if let Some((proto, io_handler)) = upgrade { + info_!("Attemping upgrade with {proto} I/O handler."); + response.set_status(Status::SwitchingProtocols); + response.set_raw_header("Connection", "Upgrade"); + response.set_raw_header("Upgrade", proto.to_string()); + return Some(io_handler); + } + + None + } + + /// Calls the handler for each matching route until one of the handlers + /// returns success or error, or there are no additional routes to try, in + /// which case a `Forward` with the last forwarding state is returned. + #[inline] + async fn route<'s, 'r: 's>( + &'s self, + request: &'r Request<'s>, + mut data: Data<'r>, + ) -> route::Outcome<'r> { + // Go through all matching routes until we fail or succeed or run out of + // routes to try, in which case we forward with the last status. + let mut status = Status::NotFound; + for route in self.router.route(request) { + // Retrieve and set the requests parameters. + info_!("Matched: {}", route); + request.set_route(route); + + let name = route.name.as_deref(); + let outcome = catch_handle(name, || route.handler.handle(request, data)).await + .unwrap_or(Outcome::Error(Status::InternalServerError)); + + // Check if the request processing completed (Some) or if the + // request needs to be forwarded. If it does, continue the loop + // (None) to try again. + info_!("{}", outcome.log_display()); + match outcome { + o@Outcome::Success(_) | o@Outcome::Error(_) => return o, + Outcome::Forward(forwarded) => (data, status) = forwarded, + } + } + + error_!("No matching routes for {}.", request); + Outcome::Forward((data, status)) + } + + // Invokes the catcher for `status`. Returns the response on success. + // + // Resets the cookie jar delta state to prevent any modifications from + // earlier unsuccessful paths from being reflected in the error response. + // + // On catcher error, the 500 error catcher is attempted. If _that_ errors, + // the (infallible) default 500 error cather is used. + pub(crate) async fn dispatch_error<'r, 's: 'r>( + &'s self, + mut status: Status, + req: &'r Request<'s> + ) -> Response<'r> { + // We may wish to relax this in the future. + req.cookies().reset_delta(); + + // Dispatch to the `status` catcher. + if let Ok(r) = self.invoke_catcher(status, req).await { + return r; + } + + // If it fails and it's not a 500, try the 500 catcher. + if status != Status::InternalServerError { + error_!("Catcher failed. Attempting 500 error catcher."); + status = Status::InternalServerError; + if let Ok(r) = self.invoke_catcher(status, req).await { + return r; + } + } + + // If it failed again or if it was already a 500, use Rocket's default. + error_!("{} catcher failed. Using Rocket default 500.", status.code); + crate::catcher::default_handler(Status::InternalServerError, req) + } + + /// Invokes the handler with `req` for catcher with status `status`. + /// + /// In order of preference, invoked handler is: + /// * the user's registered handler for `status` + /// * the user's registered `default` handler + /// * Rocket's default handler for `status` + /// + /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))` + /// if the handler ran to completion but failed. Returns `Ok(None)` if the + /// handler panicked while executing. + async fn invoke_catcher<'s, 'r: 's>( + &'s self, + status: Status, + req: &'r Request<'s> + ) -> Result, Option> { + if let Some(catcher) = self.router.catch(status, req) { + warn_!("Responding with registered {} catcher.", catcher); + let name = catcher.name.as_deref(); + catch_handle(name, || catcher.handler.handle(status, req)).await + .map(|result| result.map_err(Some)) + .unwrap_or_else(|| Err(None)) + } else { + let code = status.code.blue().bold(); + warn_!("No {} catcher registered. Using Rocket default.", code); + Ok(crate::catcher::default_handler(status, req)) + } + } + +} diff --git a/core/lib/src/listener/bindable.rs b/core/lib/src/listener/bindable.rs new file mode 100644 index 0000000000..6702138239 --- /dev/null +++ b/core/lib/src/listener/bindable.rs @@ -0,0 +1,40 @@ +use futures::TryFutureExt; + +use crate::listener::Listener; + +pub trait Bindable: Sized { + type Listener: Listener + 'static; + + type Error: std::error::Error + Send + 'static; + + async fn bind(self) -> Result; +} + +impl Bindable for L { + type Listener = L; + + type Error = std::convert::Infallible; + + async fn bind(self) -> Result { + Ok(self) + } +} + +impl Bindable for either::Either { + type Listener = tokio_util::either::Either; + + type Error = either::Either; + + async fn bind(self) -> Result { + match self { + either::Either::Left(a) => a.bind() + .map_ok(tokio_util::either::Either::Left) + .map_err(either::Either::Left) + .await, + either::Either::Right(b) => b.bind() + .map_ok(tokio_util::either::Either::Right) + .map_err(either::Either::Right) + .await, + } + } +} diff --git a/core/lib/src/listener/bounced.rs b/core/lib/src/listener/bounced.rs new file mode 100644 index 0000000000..c8e4203bec --- /dev/null +++ b/core/lib/src/listener/bounced.rs @@ -0,0 +1,58 @@ +use std::{io, time::Duration}; + +use crate::listener::{Listener, Endpoint}; + +static DURATION: Duration = Duration::from_millis(250); + +pub struct Bounced { + listener: L, +} + +pub trait BouncedExt: Sized { + fn bounced(self) -> Bounced { + Bounced { listener: self } + } +} + +impl BouncedExt for L { } + +fn is_recoverable(e: &io::Error) -> bool { + matches!(e.kind(), + | io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset) +} + +impl Bounced { + #[inline] + pub async fn accept_next(&self) -> ::Accept { + loop { + match self.listener.accept().await { + Ok(accept) => return accept, + Err(e) if is_recoverable(&e) => warn!("recoverable connection error: {e}"), + Err(e) => { + warn!("accept error: {e} [retrying in {}ms]", DURATION.as_millis()); + tokio::time::sleep(DURATION).await; + } + }; + } + } +} + +impl Listener for Bounced { + type Accept = L::Accept; + + type Connection = L::Connection; + + async fn accept(&self) -> io::Result { + Ok(self.accept_next().await) + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + self.listener.connect(accept).await + } + + fn socket_addr(&self) -> io::Result { + self.listener.socket_addr() + } +} diff --git a/core/lib/src/listener/cancellable.rs b/core/lib/src/listener/cancellable.rs new file mode 100644 index 0000000000..fbabfb2c6d --- /dev/null +++ b/core/lib/src/listener/cancellable.rs @@ -0,0 +1,273 @@ +use std::io; +use std::time::Duration; +use std::task::{Poll, Context}; +use std::pin::Pin; + +use tokio::time::{sleep, Sleep}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use futures::{StreamExt, future::{select, Either, Fuse, Future, FutureExt}}; +use pin_project_lite::pin_project; + +use crate::{config, Shutdown}; +use crate::listener::{Listener, Connection, Certificates, Bounced, Endpoint}; + +// Rocket wraps all connections in a `CancellableIo` struct, an internal +// structure that gracefully closes I/O when it receives a signal. That signal +// is the `shutdown` future. When the future resolves, `CancellableIo` begins to +// terminate in grace, mercy, and finally force close phases. Since all +// connections are wrapped in `CancellableIo`, this eventually ends all I/O. +// +// At that point, unless a user spawned an infinite, stand-alone task that isn't +// monitoring `Shutdown`, all tasks should resolve. This means that all +// instances of the shared `Arc` are dropped and we can return the owned +// instance of `Rocket`. +// +// Unfortunately, the Hyper `server` future resolves as soon as it has finished +// processing requests without respect for ongoing responses. That is, `server` +// resolves even when there are running tasks that are generating a response. +// So, `server` resolving implies little to nothing about the state of +// connections. As a result, we depend on the timing of grace + mercy + some +// buffer to determine when all connections should be closed, thus all tasks +// should be complete, thus all references to `Arc` should be dropped +// and we can get a unique reference. +pin_project! { + pub struct CancellableListener { + pub trigger: F, + #[pin] + pub listener: L, + pub grace: Duration, + pub mercy: Duration, + } +} + +pin_project! { + /// I/O that can be cancelled when a future `F` resolves. + #[must_use = "futures do nothing unless polled"] + pub struct CancellableIo { + #[pin] + io: Option, + #[pin] + trigger: Fuse, + state: State, + grace: Duration, + mercy: Duration, + } +} + +enum State { + /// I/O has not been cancelled. Proceed as normal. + Active, + /// I/O has been cancelled. See if we can finish before the timer expires. + Grace(Pin>), + /// Grace period elapsed. Shutdown the connection, waiting for the timer + /// until we force close. + Mercy(Pin>), +} + +pub trait CancellableExt: Sized { + fn cancellable( + self, + trigger: Shutdown, + config: &config::Shutdown + ) -> CancellableListener { + if let Some(mut stream) = config.signal_stream() { + let trigger = trigger.clone(); + tokio::spawn(async move { + while let Some(sig) = stream.next().await { + if trigger.0.tripped() { + warn!("Received {}. Shutdown already in progress.", sig); + } else { + warn!("Received {}. Requesting shutdown.", sig); + } + + trigger.0.trip(); + } + }); + }; + + CancellableListener { + trigger, + listener: self, + grace: config.grace(), + mercy: config.mercy(), + } + } +} + +impl CancellableExt for L { } + +fn time_out() -> io::Error { + io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out") +} + +fn gone() -> io::Error { + io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated") +} + +impl CancellableListener> + where L: Listener + Sync, + F: Future + Unpin + Clone + Send + Sync + 'static +{ + pub async fn accept_next(&self) -> Option<::Accept> { + let next = std::pin::pin!(self.listener.accept_next()); + match select(next, self.trigger.clone()).await { + Either::Left((next, _)) => Some(next), + Either::Right(_) => None, + } + } +} + +impl CancellableListener + where L: Listener + Sync, + F: Future + Clone + Send + Sync + 'static +{ + fn io(&self, conn: C) -> CancellableIo { + CancellableIo { + io: Some(conn), + trigger: self.trigger.clone().fuse(), + state: State::Active, + grace: self.grace, + mercy: self.mercy, + } + } +} + +impl Listener for CancellableListener + where L: Listener + Sync, + F: Future + Clone + Send + Sync + Unpin + 'static +{ + type Accept = L::Accept; + + type Connection = CancellableIo; + + async fn accept(&self) -> io::Result { + let accept = std::pin::pin!(self.listener.accept()); + match select(accept, self.trigger.clone()).await { + Either::Left((result, _)) => result, + Either::Right(_) => Err(gone()), + } + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + let conn = std::pin::pin!(self.listener.connect(accept)); + match select(conn, self.trigger.clone()).await { + Either::Left((conn, _)) => Ok(self.io(conn?)), + Either::Right(_) => Err(gone()), + } + } + + fn socket_addr(&self) -> io::Result { + self.listener.socket_addr() + } +} + +impl CancellableIo { + fn inner(&self) -> Option<&I> { + self.io.as_ref() + } + + /// Run `do_io` while connection processing should continue. + fn poll_trigger_then( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll>, + ) -> Poll> { + let mut me = self.as_mut().project(); + let io = match me.io.as_pin_mut() { + Some(io) => io, + None => return Poll::Ready(Err(gone())), + }; + + loop { + match me.state { + State::Active => { + if me.trigger.as_mut().poll(cx).is_ready() { + *me.state = State::Grace(Box::pin(sleep(*me.grace))); + } else { + return do_io(io, cx); + } + } + State::Grace(timer) => { + if timer.as_mut().poll(cx).is_ready() { + *me.state = State::Mercy(Box::pin(sleep(*me.mercy))); + } else { + return do_io(io, cx); + } + } + State::Mercy(timer) => { + if timer.as_mut().poll(cx).is_ready() { + self.project().io.set(None); + return Poll::Ready(Err(time_out())); + } else { + let result = futures::ready!(io.poll_shutdown(cx)); + self.project().io.set(None); + return match result { + Err(e) => Poll::Ready(Err(e)), + Ok(()) => Poll::Ready(Err(gone())) + }; + } + }, + } + } + } +} + +impl AsyncRead for CancellableIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf)) + } +} + +impl AsyncWrite for CancellableIo { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf)) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_> + ) -> Poll> { + self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx)) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_> + ) -> Poll> { + self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx)) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs)) + } + + fn is_write_vectored(&self) -> bool { + self.inner().map(|io| io.is_write_vectored()).unwrap_or(false) + } +} + +impl Connection for CancellableIo + where F: Unpin + Send + 'static +{ + fn peer_address(&self) -> io::Result { + self.inner() + .ok_or_else(|| gone()) + .and_then(|io| io.peer_address()) + } + + fn peer_certificates(&self) -> Option> { + self.inner().and_then(|io| io.peer_certificates()) + } +} diff --git a/core/lib/src/listener/connection.rs b/core/lib/src/listener/connection.rs new file mode 100644 index 0000000000..68541109e0 --- /dev/null +++ b/core/lib/src/listener/connection.rs @@ -0,0 +1,93 @@ +use std::io; +use std::borrow::Cow; + +use tokio_util::either::Either; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::Endpoint; + +/// A collection of raw certificate data. +#[derive(Clone)] +pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>); + +pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin { + fn peer_address(&self) -> io::Result; + + /// DER-encoded X.509 certificate chain presented by the client, if any. + /// + /// The certificate order must be as it appears in the TLS protocol: the + /// first certificate relates to the peer, the second certifies the first, + /// the third certifies the second, and so on. + /// + /// Defaults to an empty vector to indicate that no certificates were + /// presented. + fn peer_certificates(&self) -> Option> { None } +} + +impl Connection for Either { + fn peer_address(&self) -> io::Result { + match self { + Either::Left(c) => c.peer_address(), + Either::Right(c) => c.peer_address(), + } + } + + fn peer_certificates(&self) -> Option> { + match self { + Either::Left(c) => c.peer_certificates(), + Either::Right(c) => c.peer_certificates(), + } + } +} + +impl Certificates<'_> { + pub fn into_owned(self) -> Certificates<'static> { + let cow = self.0.into_iter() + .map(|der| der.clone().into_owned()) + .collect::>() + .into(); + + Certificates(cow) + } +} + +#[cfg(feature = "mtls")] +#[cfg_attr(nightly, doc(cfg(feature = "mtls")))] +mod der { + use super::*; + + pub use crate::mtls::CertificateDer; + + impl<'r> Certificates<'r> { + pub(crate) fn inner(&self) -> &[CertificateDer<'r>] { + &self.0 + } + } + + impl<'r> From<&'r [CertificateDer<'r>]> for Certificates<'r> { + fn from(value: &'r [CertificateDer<'r>]) -> Self { + Certificates(value.into()) + } + } + + impl From>> for Certificates<'static> { + fn from(value: Vec>) -> Self { + Certificates(value.into()) + } + } +} + +#[cfg(not(feature = "mtls"))] +mod der { + use std::marker::PhantomData; + + /// A thin wrapper over raw, DER-encoded X.509 client certificate data. + #[derive(Clone)] + pub struct CertificateDer<'r>(PhantomData<&'r [u8]>); + + impl CertificateDer<'_> { + pub fn into_owned(self) -> CertificateDer<'static> { + CertificateDer(PhantomData) + } + } +} diff --git a/core/lib/src/listener/default.rs b/core/lib/src/listener/default.rs new file mode 100644 index 0000000000..32f4a650f0 --- /dev/null +++ b/core/lib/src/listener/default.rs @@ -0,0 +1,61 @@ +use either::Either; + +use crate::listener::{Bindable, Endpoint}; +use crate::error::{Error, ErrorKind}; + +#[derive(serde::Deserialize)] +pub struct DefaultListener { + #[serde(default)] + pub address: Endpoint, + pub port: Option, + pub reuse: Option, + #[cfg(feature = "tls")] + pub tls: Option, +} + +#[cfg(not(unix))] type BaseBindable = Either; +#[cfg(unix)] type BaseBindable = Either; + +#[cfg(not(feature = "tls"))] type TlsBindable = Either; +#[cfg(feature = "tls")] type TlsBindable = Either, T>; + +impl DefaultListener { + pub(crate) fn base_bindable(&self) -> Result { + match &self.address { + Endpoint::Tcp(mut address) => { + self.port.map(|port| address.set_port(port)); + Ok(BaseBindable::Left(address)) + }, + #[cfg(unix)] + Endpoint::Unix(path) => { + let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, }; + Ok(BaseBindable::Right(uds)) + }, + #[cfg(not(unix))] + Endpoint::Unix(_) => { + let msg = "Unix domain sockets unavailable on non-unix platforms."; + let boxed = Box::::from(msg); + Err(Error::new(ErrorKind::Bind(boxed))) + }, + other => { + let msg = format!("unsupported default listener address: {other}"); + let boxed = Box::::from(msg); + Err(Error::new(ErrorKind::Bind(boxed))) + } + } + } + + pub(crate) fn tls_bindable(&self, inner: T) -> TlsBindable { + #[cfg(feature = "tls")] + if let Some(tls) = self.tls.clone() { + return TlsBindable::Left(super::tls::TlsBindable { inner, tls }); + } + + TlsBindable::Right(inner) + } + + pub fn bindable(&self) -> Result { + self.base_bindable() + .map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b))) + } +} diff --git a/core/lib/src/listener/endpoint.rs b/core/lib/src/listener/endpoint.rs new file mode 100644 index 0000000000..26640d1d1c --- /dev/null +++ b/core/lib/src/listener/endpoint.rs @@ -0,0 +1,281 @@ +use std::fmt; +use std::path::{Path, PathBuf}; +use std::any::Any; +use std::net::{SocketAddr as TcpAddr, Ipv4Addr, AddrParseError}; +use std::str::FromStr; +use std::sync::Arc; + +use serde::de; + +use crate::http::uncased::AsUncased; + +pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { } + +impl EndpointAddr for T {} + +#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>; +#[cfg(feature = "tls")] type TlsInfo = Option; + +/// # Conversions +/// +/// * [`&str`] - parse with [`FromStr`] +/// * [`tokio::net::unix::SocketAddr`] - must be path: [`ListenerAddr::Unix`] +/// * [`std::net::SocketAddr`] - infallibly as [ListenerAddr::Tcp] +/// * [`PathBuf`] - infallibly as [`ListenerAddr::Unix`] +// TODO: Rename to something better. `Endpoint`? +#[derive(Debug)] +pub enum Endpoint { + Tcp(TcpAddr), + Unix(PathBuf), + Tls(Arc, TlsInfo), + Custom(Arc), +} + +impl Endpoint { + pub fn new(value: T) -> Endpoint { + Endpoint::Custom(Arc::new(value)) + } + + pub fn tcp(&self) -> Option { + match self { + Endpoint::Tcp(addr) => Some(*addr), + _ => None, + } + } + + pub fn unix(&self) -> Option<&Path> { + match self { + Endpoint::Unix(addr) => Some(addr), + _ => None, + } + } + + pub fn tls(&self) -> Option<&Endpoint> { + match self { + Endpoint::Tls(addr, _) => Some(addr), + _ => None, + } + } + + #[cfg(feature = "tls")] + pub fn tls_config(&self) -> Option<&crate::tls::TlsConfig> { + match self { + Endpoint::Tls(_, Some(ref config)) => Some(config), + _ => None, + } + } + + #[cfg(feature = "mtls")] + pub fn mtls_config(&self) -> Option<&crate::mtls::MtlsConfig> { + match self { + Endpoint::Tls(_, Some(config)) => config.mutual(), + _ => None, + } + } + + pub fn downcast(&self) -> Option<&T> { + match self { + Endpoint::Tcp(addr) => (&*addr as &dyn Any).downcast_ref(), + Endpoint::Unix(addr) => (&*addr as &dyn Any).downcast_ref(), + Endpoint::Custom(addr) => (&*addr as &dyn Any).downcast_ref(), + Endpoint::Tls(inner, ..) => inner.downcast(), + } + } + + pub fn is_tcp(&self) -> bool { + self.tcp().is_some() + } + + pub fn is_unix(&self) -> bool { + self.unix().is_some() + } + + pub fn is_tls(&self) -> bool { + self.tls().is_some() + } + + #[cfg(feature = "tls")] + pub fn with_tls(self, config: crate::tls::TlsConfig) -> Endpoint { + if self.is_tls() { + return self; + } + + Self::Tls(Arc::new(self), Some(config)) + } + + pub fn assume_tls(self) -> Endpoint { + if self.is_tls() { + return self; + } + + Self::Tls(Arc::new(self), None) + } +} + +impl fmt::Display for Endpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use Endpoint::*; + + match self { + Tcp(addr) => write!(f, "http://{addr}"), + Unix(addr) => write!(f, "unix:{}", addr.display()), + Custom(inner) => inner.fmt(f), + Tls(inner, c) => match (&**inner, c.as_ref()) { + #[cfg(feature = "mtls")] + (Tcp(i), Some(c)) if c.mutual().is_some() => write!(f, "https://{i} (TLS + MTLS)"), + (Tcp(i), _) => write!(f, "https://{i} (TLS)"), + #[cfg(feature = "mtls")] + (i, Some(c)) if c.mutual().is_some() => write!(f, "{i} (TLS + MTLS)"), + (inner, _) => write!(f, "{inner} (TLS)"), + }, + } + } +} + +impl From for Endpoint { + fn from(value: std::net::SocketAddr) -> Self { + Self::Tcp(value) + } +} + +impl From for Endpoint { + fn from(value: std::net::SocketAddrV4) -> Self { + Self::Tcp(value.into()) + } +} + +impl From for Endpoint { + fn from(value: std::net::SocketAddrV6) -> Self { + Self::Tcp(value.into()) + } +} + +impl From for Endpoint { + fn from(value: PathBuf) -> Self { + Self::Unix(value) + } +} + +#[cfg(unix)] +impl TryFrom for Endpoint { + type Error = std::io::Error; + + fn try_from(v: tokio::net::unix::SocketAddr) -> Result { + v.as_pathname() + .ok_or_else(|| std::io::Error::other("unix socket is not path")) + .map(|path| Endpoint::Unix(path.to_path_buf())) + } +} + +impl TryFrom<&str> for Endpoint { + type Error = AddrParseError; + + fn try_from(value: &str) -> Result { + value.parse() + } +} + +impl Default for Endpoint { + fn default() -> Self { + Endpoint::Tcp(TcpAddr::new(Ipv4Addr::LOCALHOST.into(), 8000)) + } +} + +/// Parses an address into a `ListenerAddr`. +/// +/// The syntax is: +/// +/// ```text +/// listener_addr = 'tcp' ':' tcp_addr | 'unix' ':' unix_addr | tcp_addr +/// tcp_addr := IP_ADDR | SOCKET_ADDR +/// unix_addr := PATH +/// +/// IP_ADDR := `std::net::IpAddr` string as defined by Rust +/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust +/// PATH := `PathBuf` (any UTF-8) string as defined by Rust +/// ``` +/// +/// If `IP_ADDR` is specified, the port defaults to `8000`. +impl FromStr for Endpoint { + type Err = AddrParseError; + + fn from_str(string: &str) -> Result { + fn parse_tcp(string: &str, def_port: u16) -> Result { + string.parse().or_else(|_| string.parse().map(|ip| TcpAddr::new(ip, def_port))) + } + + if let Some((proto, string)) = string.split_once(':') { + if proto.trim().as_uncased() == "tcp" { + return parse_tcp(string.trim(), 8000).map(Self::Tcp); + } else if proto.trim().as_uncased() == "unix" { + return Ok(Self::Unix(PathBuf::from(string.trim()))); + } + } + + parse_tcp(string.trim(), 8000).map(Self::Tcp) + } +} + +impl<'de> de::Deserialize<'de> for Endpoint { + fn deserialize>(de: D) -> Result { + struct Visitor; + + impl<'de> de::Visitor<'de> for Visitor { + type Value = Endpoint; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("TCP or Unix address") + } + + fn visit_str(self, v: &str) -> Result { + v.parse::().map_err(|e| E::custom(e.to_string())) + } + } + + de.deserialize_any(Visitor) + } +} + +impl Eq for Endpoint { } + +impl PartialEq for Endpoint { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0, + (Self::Unix(l0), Self::Unix(r0)) => l0 == r0, + (Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0, + (Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(), + _ => false, + } + } +} + +impl PartialEq for Endpoint { + fn eq(&self, other: &std::net::SocketAddr) -> bool { + self.tcp() == Some(*other) + } +} + +impl PartialEq for Endpoint { + fn eq(&self, other: &std::net::SocketAddrV4) -> bool { + self.tcp() == Some((*other).into()) + } +} + +impl PartialEq for Endpoint { + fn eq(&self, other: &std::net::SocketAddrV6) -> bool { + self.tcp() == Some((*other).into()) + } +} + +impl PartialEq for Endpoint { + fn eq(&self, other: &PathBuf) -> bool { + self.unix() == Some(other.as_path()) + } +} + +impl PartialEq for Endpoint { + fn eq(&self, other: &Path) -> bool { + self.unix() == Some(other) + } +} diff --git a/core/lib/src/listener/listener.rs b/core/lib/src/listener/listener.rs new file mode 100644 index 0000000000..8bdbc08c2b --- /dev/null +++ b/core/lib/src/listener/listener.rs @@ -0,0 +1,65 @@ +use std::io; + +use futures::TryFutureExt; +use tokio_util::either::Either; + +use crate::listener::{Connection, Endpoint}; + +pub trait Listener: Send + Sync { + type Accept: Send; + + type Connection: Connection; + + async fn accept(&self) -> io::Result; + + #[crate::async_bound(Send)] + async fn connect(&self, accept: Self::Accept) -> io::Result; + + fn socket_addr(&self) -> io::Result; +} + +impl Listener for &L { + type Accept = L::Accept; + + type Connection = L::Connection; + + async fn accept(&self) -> io::Result { + ::accept(self).await + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + ::connect(self, accept).await + } + + fn socket_addr(&self) -> io::Result { + ::socket_addr(self) + } +} + +impl Listener for Either { + type Accept = Either; + + type Connection = Either; + + async fn accept(&self) -> io::Result { + match self { + Either::Left(l) => l.accept().map_ok(Either::Left).await, + Either::Right(l) => l.accept().map_ok(Either::Right).await, + } + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + match (self, accept) { + (Either::Left(l), Either::Left(a)) => l.connect(a).map_ok(Either::Left).await, + (Either::Right(l), Either::Right(a)) => l.connect(a).map_ok(Either::Right).await, + _ => unreachable!() + } + } + + fn socket_addr(&self) -> io::Result { + match self { + Either::Left(l) => l.socket_addr(), + Either::Right(l) => l.socket_addr(), + } + } +} diff --git a/core/lib/src/listener/mod.rs b/core/lib/src/listener/mod.rs new file mode 100644 index 0000000000..244c36c604 --- /dev/null +++ b/core/lib/src/listener/mod.rs @@ -0,0 +1,24 @@ +mod cancellable; +mod bounced; +mod listener; +mod endpoint; +mod connection; +mod bindable; +mod default; + +#[cfg(unix)] +#[cfg_attr(nightly, doc(cfg(unix)))] +pub mod unix; +#[cfg(feature = "tls")] +#[cfg_attr(nightly, doc(cfg(feature = "tls")))] +pub mod tls; +pub mod tcp; + +pub use endpoint::*; +pub use listener::*; +pub use connection::*; +pub use bindable::*; +pub use default::*; + +pub(crate) use cancellable::*; +pub(crate) use bounced::*; diff --git a/core/lib/src/listener/tcp.rs b/core/lib/src/listener/tcp.rs new file mode 100644 index 0000000000..c2e3fd9f3f --- /dev/null +++ b/core/lib/src/listener/tcp.rs @@ -0,0 +1,43 @@ +use std::io; + +#[doc(inline)] +pub use tokio::net::{TcpListener, TcpStream}; + +use crate::listener::{Listener, Bindable, Connection, Endpoint}; + +impl Bindable for std::net::SocketAddr { + type Listener = TcpListener; + + type Error = io::Error; + + async fn bind(self) -> Result { + TcpListener::bind(self).await + } +} + +impl Listener for TcpListener { + type Accept = Self::Connection; + + type Connection = TcpStream; + + async fn accept(&self) -> io::Result { + let conn = self.accept().await?.0; + let _ = conn.set_nodelay(true); + let _ = conn.set_linger(None); + Ok(conn) + } + + async fn connect(&self, conn: Self::Connection) -> io::Result { + Ok(conn) + } + + fn socket_addr(&self) -> io::Result { + self.local_addr().map(Endpoint::Tcp) + } +} + +impl Connection for TcpStream { + fn peer_address(&self) -> io::Result { + self.peer_addr().map(Endpoint::Tcp) + } +} diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs new file mode 100644 index 0000000000..ce2b53ffaf --- /dev/null +++ b/core/lib/src/listener/tls.rs @@ -0,0 +1,116 @@ +use std::io; +use std::sync::Arc; + +use serde::Deserialize; +use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier}; +use tokio_rustls::TlsAcceptor; + +use crate::tls::{TlsConfig, Error}; +use crate::tls::util::{load_cert_chain, load_key, load_ca_certs}; +use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint}; + +#[doc(inline)] +pub use tokio_rustls::server::TlsStream; + +/// A TLS listener over some listener interface L. +pub struct TlsListener { + listener: L, + acceptor: TlsAcceptor, + config: TlsConfig, +} + +#[derive(Clone, Deserialize)] +pub struct TlsBindable { + #[serde(flatten)] + pub inner: I, + pub tls: TlsConfig, +} + +impl TlsConfig { + pub(crate) fn acceptor(&self) -> Result { + let provider = rustls::crypto::CryptoProvider { + cipher_suites: self.ciphers().map(|c| c.into()).collect(), + ..rustls::crypto::ring::default_provider() + }; + + #[cfg(feature = "mtls")] + let verifier = match self.mutual { + Some(ref mtls) => { + let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?; + let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs)); + match mtls.mandatory { + true => verifier.build()?, + false => verifier.allow_unauthenticated().build()?, + } + }, + None => WebPkiClientVerifier::no_client_auth(), + }; + + #[cfg(not(feature = "mtls"))] + let verifier = WebPkiClientVerifier::no_client_auth(); + + let key = load_key(&mut self.key_reader()?)?; + let cert_chain = load_cert_chain(&mut self.certs_reader()?)?; + let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider)) + .with_safe_default_protocol_versions()? + .with_client_cert_verifier(verifier) + .with_single_cert(cert_chain, key)?; + + tls_config.ignore_client_order = self.prefer_server_cipher_order; + tls_config.session_storage = ServerSessionMemoryCache::new(1024); + tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?; + tls_config.alpn_protocols = vec![b"http/1.1".to_vec()]; + if cfg!(feature = "http2") { + tls_config.alpn_protocols.insert(0, b"h2".to_vec()); + } + + Ok(TlsAcceptor::from(Arc::new(tls_config))) + } +} + +impl Bindable for TlsBindable { + type Listener = TlsListener; + + type Error = Error; + + async fn bind(self) -> Result { + Ok(TlsListener { + acceptor: self.tls.acceptor()?, + listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?, + config: self.tls, + }) + } +} + +impl Listener for TlsListener + where L::Connection: Unpin +{ + type Accept = L::Accept; + + type Connection = TlsStream; + + async fn accept(&self) -> io::Result { + self.listener.accept().await + } + + async fn connect(&self, accept: L::Accept) -> io::Result { + let conn = self.listener.connect(accept).await?; + self.acceptor.accept(conn).await + } + + fn socket_addr(&self) -> io::Result { + Ok(self.listener.socket_addr()?.with_tls(self.config.clone())) + } +} + +impl Connection for TlsStream { + fn peer_address(&self) -> io::Result { + Ok(self.get_ref().0.peer_address()?.assume_tls()) + } + + #[cfg(feature = "mtls")] + fn peer_certificates(&self) -> Option> { + let cert_chain = self.get_ref().1.peer_certificates()?; + Some(Certificates::from(cert_chain)) + } +} diff --git a/core/lib/src/listener/unix.rs b/core/lib/src/listener/unix.rs new file mode 100644 index 0000000000..b6dea5870f --- /dev/null +++ b/core/lib/src/listener/unix.rs @@ -0,0 +1,107 @@ +use std::io; +use std::path::PathBuf; + +use tokio::time::{sleep, Duration}; + +use crate::fs::NamedFile; +use crate::listener::{Listener, Bindable, Connection, Endpoint}; +use crate::util::unix; + +pub use tokio::net::UnixStream; + +#[derive(Debug, Clone)] +pub struct UdsConfig { + /// Socket address. + pub path: PathBuf, + /// Recreate a socket that already exists. + pub reuse: Option, +} + +pub struct UdsListener { + path: PathBuf, + lock: Option, + listener: tokio::net::UnixListener, +} + +impl Bindable for UdsConfig { + type Listener = UdsListener; + + type Error = io::Error; + + async fn bind(self) -> Result { + let lock = if self.reuse.unwrap_or(true) { + let lock_ext = match self.path.extension().and_then(|s| s.to_str()) { + Some(ext) if !ext.is_empty() => format!("{}.lock", ext), + _ => "lock".to_string() + }; + + let mut opts = tokio::fs::File::options(); + opts.create(true).write(true); + let lock_path = self.path.with_extension(lock_ext); + let lock_file = NamedFile::open_with(lock_path, &opts).await?; + + unix::lock_exlusive_nonblocking(lock_file.file())?; + if self.path.exists() { + tokio::fs::remove_file(&self.path).await?; + } + + Some(lock_file) + } else { + None + }; + + // Sometimes, we get `AddrInUse`, even though we've tried deleting the + // socket. If all is well, eventually the socket will _really_ be gone, + // and this will succeed. So let's try a few times. + let mut retries = 5; + let listener = loop { + match tokio::net::UnixListener::bind(&self.path) { + Ok(listener) => break listener, + Err(e) if self.path.exists() && lock.is_none() => return Err(e), + Err(_) if retries > 0 => { + retries -= 1; + sleep(Duration::from_millis(100)).await; + }, + Err(e) => return Err(e), + } + }; + + Ok(UdsListener { lock, listener, path: self.path, }) + } +} + +impl Listener for UdsListener { + type Accept = UnixStream; + + type Connection = Self::Accept; + + async fn accept(&self) -> io::Result { + Ok(self.listener.accept().await?.0) + } + + async fn connect(&self, accept:Self::Accept) -> io::Result { + Ok(accept) + } + + fn socket_addr(&self) -> io::Result { + self.listener.local_addr()?.try_into() + } +} + +impl Connection for UnixStream { + fn peer_address(&self) -> io::Result { + self.local_addr()?.try_into() + } +} + +impl Drop for UdsListener { + fn drop(&mut self) { + if let Some(lock) = &self.lock { + let _ = std::fs::remove_file(&self.path); + let _ = std::fs::remove_file(lock.path()); + let _ = unix::unlock_nonblocking(lock.file()); + } else { + let _ = std::fs::remove_file(&self.path); + } + } +} diff --git a/core/lib/src/local/asynchronous/client.rs b/core/lib/src/local/asynchronous/client.rs index ecec4527cc..2a45a33194 100644 --- a/core/lib/src/local/asynchronous/client.rs +++ b/core/lib/src/local/asynchronous/client.rs @@ -4,7 +4,8 @@ use parking_lot::RwLock; use crate::{Rocket, Phase, Orbit, Ignite, Error}; use crate::local::asynchronous::{LocalRequest, LocalResponse}; -use crate::http::{Method, uri::Origin, private::cookie}; +use crate::http::{Method, uri::Origin}; +use crate::listener::Endpoint; /// An `async` client to construct and dispatch local requests. /// @@ -55,9 +56,15 @@ pub struct Client { impl Client { pub(crate) async fn _new( rocket: Rocket

, - tracked: bool + tracked: bool, + secure: bool, ) -> Result { - let rocket = rocket.local_launch().await?; + let mut listener = Endpoint::new("local client"); + if secure { + listener = listener.assume_tls(); + } + + let rocket = rocket.local_launch(listener).await?; let cookies = RwLock::new(cookie::CookieJar::new()); Ok(Client { rocket, cookies, tracked }) } diff --git a/core/lib/src/local/asynchronous/request.rs b/core/lib/src/local/asynchronous/request.rs index 76ed3f3707..4c85c02024 100644 --- a/core/lib/src/local/asynchronous/request.rs +++ b/core/lib/src/local/asynchronous/request.rs @@ -23,7 +23,7 @@ use super::{Client, LocalResponse}; /// let client = Client::tracked(rocket::build()).await.expect("valid rocket"); /// let req = client.post("/") /// .header(ContentType::JSON) -/// .remote("127.0.0.1:8000".parse().unwrap()) +/// .remote("127.0.0.1:8000") /// .cookie(("name", "value")) /// .body(r#"{ "value": 42 }"#); /// @@ -86,14 +86,14 @@ impl<'c> LocalRequest<'c> { if self.inner().uri() == invalid { error!("invalid request URI: {:?}", invalid.path()); return LocalResponse::new(self.request, move |req| { - rocket.handle_error(Status::BadRequest, req) + rocket.dispatch_error(Status::BadRequest, req) }).await } } // Actually dispatch the request. let mut data = Data::local(self.data); - let token = rocket.preprocess_request(&mut self.request, &mut data).await; + let token = rocket.preprocess(&mut self.request, &mut data).await; let response = LocalResponse::new(self.request, move |req| { rocket.dispatch(token, req, data) }).await; diff --git a/core/lib/src/local/asynchronous/response.rs b/core/lib/src/local/asynchronous/response.rs index f91afeb25c..0c3350be8e 100644 --- a/core/lib/src/local/asynchronous/response.rs +++ b/core/lib/src/local/asynchronous/response.rs @@ -53,9 +53,14 @@ use crate::{Request, Response}; /// /// For more, see [the top-level documentation](../index.html#localresponse). pub struct LocalResponse<'c> { - _request: Box>, + // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'c>, cookies: CookieJar<'c>, + _request: Box>, +} + +impl Drop for LocalResponse<'_> { + fn drop(&mut self) { } } impl<'c> LocalResponse<'c> { @@ -64,7 +69,8 @@ impl<'c> LocalResponse<'c> { O: Future> + Send { // `LocalResponse` is a self-referential structure. In particular, - // `inner` can refer to `_request` and its contents. As such, we must + // `response` and `cookies` can refer to `_request` and its contents. As + // such, we must // 1) Ensure `Request` has a stable address. // // This is done by `Box`ing the `Request`, using only the stable @@ -97,7 +103,7 @@ impl<'c> LocalResponse<'c> { cookies.add_original(cookie.into_owned()); } - LocalResponse { cookies, _request: boxed_req, response, } + LocalResponse { _request: boxed_req, cookies, response, } } } } diff --git a/core/lib/src/local/blocking/client.rs b/core/lib/src/local/blocking/client.rs index d3a8b0ef94..f87df009f2 100644 --- a/core/lib/src/local/blocking/client.rs +++ b/core/lib/src/local/blocking/client.rs @@ -30,7 +30,7 @@ pub struct Client { } impl Client { - fn _new(rocket: Rocket

, tracked: bool) -> Result { + fn _new(rocket: Rocket

, tracked: bool, secure: bool) -> Result { let runtime = tokio::runtime::Builder::new_multi_thread() .thread_name("rocket-local-client-worker-thread") .worker_threads(1) @@ -39,7 +39,7 @@ impl Client { .expect("create tokio runtime"); // Initialize the Rocket instance - let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked))?); + let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked, secure))?); Ok(Self { inner, runtime: RefCell::new(runtime) }) } @@ -73,7 +73,7 @@ impl Client { #[inline(always)] pub(crate) fn _with_raw_cookies(&self, f: F) -> T - where F: FnOnce(&crate::http::private::cookie::CookieJar) -> T + where F: FnOnce(&cookie::CookieJar) -> T { self.inner()._with_raw_cookies(f) } diff --git a/core/lib/src/local/blocking/request.rs b/core/lib/src/local/blocking/request.rs index f094c60e44..4d8e35373e 100644 --- a/core/lib/src/local/blocking/request.rs +++ b/core/lib/src/local/blocking/request.rs @@ -21,7 +21,7 @@ use super::{Client, LocalResponse}; /// let client = Client::tracked(rocket::build()).expect("valid rocket"); /// let req = client.post("/") /// .header(ContentType::JSON) -/// .remote("127.0.0.1:8000".parse().unwrap()) +/// .remote("127.0.0.1:8000") /// .cookie(("name", "value")) /// .body(r#"{ "value": 42 }"#); /// diff --git a/core/lib/src/local/client.rs b/core/lib/src/local/client.rs index f2b3b922d0..9b09ae74ff 100644 --- a/core/lib/src/local/client.rs +++ b/core/lib/src/local/client.rs @@ -68,7 +68,12 @@ macro_rules! pub_client_impl { /// ``` #[inline(always)] pub $($prefix)? fn tracked(rocket: Rocket

) -> Result { - Self::_new(rocket, true) $(.$suffix)? + Self::_new(rocket, true, false) $(.$suffix)? + } + + #[inline(always)] + pub $($prefix)? fn tracked_secure(rocket: Rocket

) -> Result { + Self::_new(rocket, true, true) $(.$suffix)? } /// Construct a new `Client` from an instance of `Rocket` _without_ @@ -92,7 +97,11 @@ macro_rules! pub_client_impl { /// let client = Client::untracked(rocket); /// ``` pub $($prefix)? fn untracked(rocket: Rocket

) -> Result { - Self::_new(rocket, false) $(.$suffix)? + Self::_new(rocket, false, false) $(.$suffix)? + } + + pub $($prefix)? fn untracked_secure(rocket: Rocket

) -> Result { + Self::_new(rocket, false, true) $(.$suffix)? } /// Terminates `Client` by initiating a graceful shutdown via @@ -135,15 +144,6 @@ macro_rules! pub_client_impl { Self::tracked(rocket.configure(figment)) $(.$suffix)? } - /// Deprecated alias to [`Client::tracked()`]. - #[deprecated( - since = "0.6.0-dev", - note = "choose between `Client::untracked()` and `Client::tracked()`" - )] - pub $($prefix)? fn new(rocket: Rocket

) -> Result { - Self::tracked(rocket) $(.$suffix)? - } - /// Returns a reference to the `Rocket` this client is creating requests /// for. /// diff --git a/core/lib/src/local/request.rs b/core/lib/src/local/request.rs index 78e975957c..1ec0740050 100644 --- a/core/lib/src/local/request.rs +++ b/core/lib/src/local/request.rs @@ -97,24 +97,40 @@ macro_rules! pub_request_impl { self._request_mut().add_header(header.into()); } - /// Set the remote address of this request. + /// Set the remote address of this request to `address`. + /// + /// `address` may be any type that [can be converted into a `ListenerAddr`]. + /// If `address` fails to convert, the remote is left unchanged. + /// + /// [can be converted into a `ListenerAddr`]: crate::listener::ListenerAddr#conversions /// /// # Examples /// /// Set the remote address to "8.8.8.8:80": /// /// ```rust + /// use std::net::{SocketAddrV4, Ipv4Addr}; + /// #[doc = $import] /// /// # Client::_test(|_, request, _| { /// let request: LocalRequest = request; - /// let address = "8.8.8.8:80".parse().unwrap(); - /// let req = request.remote(address); + /// let req = request.remote("8.8.8.8:80"); + /// + /// let addr = SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8).into(), 80); + /// assert_eq!(req.inner().remote().unwrap(), &addr); /// # }); /// ``` #[inline] - pub fn remote(mut self, address: std::net::SocketAddr) -> Self { - self.set_remote(address); + pub fn remote(mut self, endpoint: T) -> Self + where T: TryInto + { + if let Ok(endpoint) = endpoint.try_into() { + self.set_remote(endpoint); + } else { + warn!("remote failed to convert"); + } + self } @@ -228,11 +244,13 @@ macro_rules! pub_request_impl { #[cfg(feature = "mtls")] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] pub fn identity(mut self, reader: C) -> Self { - use crate::http::{tls::util::load_cert_chain, private::Certificates}; + use std::sync::Arc; + use crate::tls::util::load_cert_chain; + use crate::listener::Certificates; let mut reader = std::io::BufReader::new(reader); let certs = load_cert_chain(&mut reader).map(Certificates::from); - self._request_mut().connection.client_certificates = certs.ok(); + self._request_mut().connection.peer_certs = certs.ok().map(Arc::new); self } diff --git a/core/lib/src/mtls.rs b/core/lib/src/mtls.rs deleted file mode 100644 index 441fffb6f7..0000000000 --- a/core/lib/src/mtls.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Support for mutual TLS client certificates. -//! -//! For details on how to configure mutual TLS, see -//! [`MutualTls`](crate::config::MutualTls) and the [TLS -//! guide](https://rocket.rs/master/guide/configuration/#tls). See -//! [`Certificate`] for a request guard that validated, verifies, and retrieves -//! client certificates. - -#[doc(inline)] -pub use crate::http::tls::mtls::*; - -use crate::request::{Request, FromRequest, Outcome}; -use crate::outcome::{try_outcome, IntoOutcome}; -use crate::http::Status; - -#[crate::async_trait] -impl<'r> FromRequest<'r> for Certificate<'r> { - type Error = Error; - - async fn from_request(req: &'r Request<'_>) -> Outcome { - let certs = req.connection.client_certificates.as_ref().or_forward(Status::Unauthorized); - let data = try_outcome!(try_outcome!(certs).chain_data().or_forward(Status::Unauthorized)); - Certificate::parse(data).or_error(Status::Unauthorized) - } -} diff --git a/core/http/src/tls/mtls.rs b/core/lib/src/mtls/certificate.rs similarity index 50% rename from core/http/src/tls/mtls.rs rename to core/lib/src/mtls/certificate.rs index 417db2f87d..a430b79fb8 100644 --- a/core/http/src/tls/mtls.rs +++ b/core/lib/src/mtls/certificate.rs @@ -1,51 +1,8 @@ -pub mod oid { - //! Lower-level OID types re-exported from - //! [`oid_registry`](https://docs.rs/oid-registry/0.4) and - //! [`der-parser`](https://docs.rs/der-parser/7). - - pub use x509_parser::oid_registry::*; - pub use x509_parser::objects::*; -} - -pub mod bigint { - //! Signed and unsigned big integer types re-exported from - //! [`num_bigint`](https://docs.rs/num-bigint/0.4). - pub use x509_parser::der_parser::num_bigint::*; -} - -pub mod x509 { - //! Lower-level X.509 types re-exported from - //! [`x509_parser`](https://docs.rs/x509-parser/0.13). - //! - //! Lack of documentation is directly inherited from the source crate. - //! Prefer to use Rocket's wrappers when possible. - - pub use x509_parser::certificate::*; - pub use x509_parser::cri_attributes::*; - pub use x509_parser::error::*; - pub use x509_parser::extensions::*; - pub use x509_parser::revocation_list::*; - pub use x509_parser::time::*; - pub use x509_parser::x509::*; - pub use x509_parser::der_parser::der; - pub use x509_parser::der_parser::ber; - pub use x509_parser::traits::*; -} - -use std::fmt; -use std::ops::Deref; -use std::num::NonZeroUsize; - use ref_cast::RefCast; -use x509_parser::nom; -use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer}; -use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME; - -use crate::listener::CertificateDer; -/// A type alias for [`Result`](std::result::Result) with the error type set to -/// [`Error`]. -pub type Result = std::result::Result; +use crate::mtls::{x509, oid, bigint, Name, Result, Error}; +use crate::request::{Request, FromRequest, Outcome}; +use crate::http::Status; /// A request guard for validated, verified client certificates. /// @@ -143,60 +100,42 @@ pub type Result = std::result::Result; /// ``` #[derive(Debug, PartialEq)] pub struct Certificate<'a> { - x509: X509Certificate<'a>, - data: &'a CertificateDer, + x509: x509::X509Certificate<'a>, + data: &'a CertificateDer<'a>, } -/// An X.509 Distinguished Name (DN) found in a [`Certificate`]. -/// -/// This type is a wrapper over [`x509::X509Name`] with convenient methods and -/// complete documentation. Should the data exposed by the inherent methods not -/// suffice, this type derefs to [`x509::X509Name`]. -#[repr(transparent)] -#[derive(Debug, PartialEq, RefCast)] -pub struct Name<'a>(X509Name<'a>); +pub use rustls::pki_types::CertificateDer; -/// An error returned by the [`Certificate`] request guard. -/// -/// To retrieve this error in a handler, use an `mtls::Result` -/// guard type: -/// -/// ```rust -/// # extern crate rocket; -/// # use rocket::get; -/// use rocket::mtls::{self, Certificate}; -/// -/// #[get("/auth")] -/// fn auth(cert: mtls::Result>) { -/// match cert { -/// Ok(cert) => { /* do something with the client cert */ }, -/// Err(e) => { /* do something with the error */ }, -/// } -/// } -/// ``` -#[derive(Debug, Clone)] -#[non_exhaustive] -pub enum Error { - /// The certificate chain presented by the client had no certificates. - Empty, - /// The certificate contained neither a subject nor a subjectAlt extension. - NoSubject, - /// There is no subject and the subjectAlt is not marked as critical. - NonCriticalSubjectAlt, - /// An error occurred while parsing the certificate. - Parse(X509Error), - /// The certificate parsed partially but is incomplete. - /// - /// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number - /// of expected bytes is unknown. - Incomplete(Option), - /// The certificate contained `.0` bytes of trailing data. - Trailing(usize), +#[crate::async_trait] +impl<'r> FromRequest<'r> for Certificate<'r> { + type Error = Error; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + use crate::outcome::{try_outcome, IntoOutcome}; + + let certs = req.connection + .peer_certs + .as_ref() + .or_forward(Status::Unauthorized); + + let chain = try_outcome!(certs); + Certificate::parse(chain.inner()).or_error(Status::Unauthorized) + } } impl<'a> Certificate<'a> { - fn parse_one(raw: &[u8]) -> Result> { - let (left, x509) = X509Certificate::from_der(raw)?; + /// PRIVATE: For internal Rocket use only! + fn parse<'r>(chain: &'r [CertificateDer<'r>]) -> Result> { + let data = chain.first().ok_or_else(|| Error::Empty)?; + let x509 = Certificate::parse_one(&*data)?; + Ok(Certificate { x509, data }) + } + + fn parse_one(raw: &[u8]) -> Result> { + use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME; + use x509_parser::traits::FromDer; + + let (left, x509) = x509::X509Certificate::from_der(raw)?; if !left.is_empty() { return Err(Error::Trailing(left.len())); } @@ -204,7 +143,7 @@ impl<'a> Certificate<'a> { // Ensure we have a subject or a subjectAlt. if x509.subject().as_raw().is_empty() { if let Some(ext) = x509.extensions().iter().find(|e| e.oid == SUBJECT_ALT_NAME) { - if !matches!(ext.parsed_extension(), ParsedExtension::SubjectAlternativeName(..)) { + if let x509::ParsedExtension::SubjectAlternativeName(..) = ext.parsed_extension() { return Err(Error::NoSubject); } else if !ext.critical { return Err(Error::NonCriticalSubjectAlt); @@ -218,18 +157,10 @@ impl<'a> Certificate<'a> { } #[inline(always)] - fn inner(&self) -> &TbsCertificate<'a> { + fn inner(&self) -> &x509::TbsCertificate<'a> { &self.x509.tbs_certificate } - /// PRIVATE: For internal Rocket use only! - #[doc(hidden)] - pub fn parse(chain: &[CertificateDer]) -> Result> { - let data = chain.first().ok_or_else(|| Error::Empty)?; - let x509 = Certificate::parse_one(&data.0)?; - Ok(Certificate { x509, data }) - } - /// Returns the serial number of the X.509 certificate. /// /// # Example @@ -387,176 +318,14 @@ impl<'a> Certificate<'a> { /// } /// ``` pub fn as_bytes(&self) -> &'a [u8] { - &self.data.0 + &*self.data } } -impl<'a> Deref for Certificate<'a> { - type Target = TbsCertificate<'a>; +impl<'a> std::ops::Deref for Certificate<'a> { + type Target = x509::TbsCertificate<'a>; fn deref(&self) -> &Self::Target { self.inner() } } - -impl<'a> Name<'a> { - /// Returns the _first_ UTF-8 _string_ common name, if any. - /// - /// Note that common names need not be UTF-8 strings, or strings at all. - /// This method returns the first common name attribute that is. - /// - /// # Example - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// use rocket::mtls::Certificate; - /// - /// #[get("/auth")] - /// fn auth(cert: Certificate<'_>) { - /// if let Some(name) = cert.subject().common_name() { - /// println!("Hello, {}!", name); - /// } - /// } - /// ``` - pub fn common_name(&self) -> Option<&'a str> { - self.common_names().next() - } - - /// Returns an iterator over all of the UTF-8 _string_ common names in - /// `self`. - /// - /// Note that common names need not be UTF-8 strings, or strings at all. - /// This method filters the common names in `self` to those that are. Use - /// the raw [`iter_common_name()`](#method.iter_common_name) to iterate over - /// all value types. - /// - /// # Example - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// use rocket::mtls::Certificate; - /// - /// #[get("/auth")] - /// fn auth(cert: Certificate<'_>) { - /// for name in cert.issuer().common_names() { - /// println!("Issued by {}.", name); - /// } - /// } - /// ``` - pub fn common_names(&self) -> impl Iterator + '_ { - self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok()) - } - - /// Returns the _first_ UTF-8 _string_ email address, if any. - /// - /// Note that email addresses need not be UTF-8 strings, or strings at all. - /// This method returns the first email address attribute that is. - /// - /// # Example - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// use rocket::mtls::Certificate; - /// - /// #[get("/auth")] - /// fn auth(cert: Certificate<'_>) { - /// if let Some(email) = cert.subject().email() { - /// println!("Hello, {}!", email); - /// } - /// } - /// ``` - pub fn email(&self) -> Option<&'a str> { - self.emails().next() - } - - /// Returns an iterator over all of the UTF-8 _string_ email addresses in - /// `self`. - /// - /// Note that email addresses need not be UTF-8 strings, or strings at all. - /// This method filters the email address in `self` to those that are. Use - /// the raw [`iter_email()`](#method.iter_email) to iterate over all value - /// types. - /// - /// # Example - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// use rocket::mtls::Certificate; - /// - /// #[get("/auth")] - /// fn auth(cert: Certificate<'_>) { - /// for email in cert.subject().emails() { - /// println!("Reach me at: {}", email); - /// } - /// } - /// ``` - pub fn emails(&self) -> impl Iterator + '_ { - self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok()) - } - - /// Returns `true` if `self` has no data. - /// - /// When this is the case for a `subject()`, the subject data can be found - /// in the `subjectAlt` [`extension()`](Certificate::extensions()). - /// - /// # Example - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// use rocket::mtls::Certificate; - /// - /// #[get("/auth")] - /// fn auth(cert: Certificate<'_>) { - /// let no_data = cert.subject().is_empty(); - /// } - /// ``` - pub fn is_empty(&self) -> bool { - self.0.as_raw().is_empty() - } -} - -impl<'a> Deref for Name<'a> { - type Target = X509Name<'a>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl fmt::Display for Name<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Error::Parse(e) => write!(f, "parse error: {}", e), - Error::Incomplete(_) => write!(f, "incomplete certificate data"), - Error::Trailing(n) => write!(f, "found {} trailing bytes", n), - Error::Empty => write!(f, "empty certificate chain"), - Error::NoSubject => write!(f, "empty subject without subjectAlt"), - Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"), - } - } -} - -impl From> for Error { - fn from(e: nom::Err) -> Self { - match e { - nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None), - nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)), - nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e), - } - } -} - -impl std::error::Error for Error { - // fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - // match self { - // Error::Parse(e) => Some(e), - // _ => None - // } - // } -} diff --git a/core/lib/src/mtls/config.rs b/core/lib/src/mtls/config.rs new file mode 100644 index 0000000000..96fc12c9d1 --- /dev/null +++ b/core/lib/src/mtls/config.rs @@ -0,0 +1,212 @@ +use std::io; + +use figment::value::magic::{RelativePathBuf, Either}; +use serde::{Serialize, Deserialize}; + +/// Mutual TLS configuration. +/// +/// Configuration works in concert with the [`mtls`](crate::mtls) module, which +/// provides a request guard to validate, verify, and retrieve client +/// certificates in routes. +/// +/// By default, mutual TLS is disabled and client certificates are not required, +/// validated or verified. To enable mutual TLS, the `mtls` feature must be +/// enabled and support configured via two `tls.mutual` parameters: +/// +/// * `ca_certs` +/// +/// A required path to a PEM file or raw bytes to a DER-encoded X.509 TLS +/// certificate chain for the certificate authority to verify client +/// certificates against. When a path is configured in a file, such as +/// `Rocket.toml`, relative paths are interpreted as relative to the source +/// file's directory. +/// +/// * `mandatory` +/// +/// An optional boolean that control whether client authentication is +/// required. +/// +/// When `true`, client authentication is required. TLS connections where +/// the client does not present a certificate are immediately terminated. +/// When `false`, the client is not required to present a certificate. In +/// either case, if a certificate _is_ presented, it must be valid or the +/// connection is terminated. +/// +/// In a `Rocket.toml`, configuration might look like: +/// +/// ```toml +/// [default.tls.mutual] +/// ca_certs = "/ssl/ca_cert.pem" +/// mandatory = true # when absent, defaults to false +/// ``` +/// +/// Programmatically, configuration might look like: +/// +/// ```rust +/// # #[macro_use] extern crate rocket; +/// use rocket::mtls::MtlsConfig; +/// use rocket::figment::providers::Serialized; +/// +/// #[launch] +/// fn rocket() -> _ { +/// let mtls = MtlsConfig::from_path("/ssl/ca_cert.pem"); +/// rocket::custom(rocket::Config::figment().merge(("tls.mutual", mtls))) +/// } +/// ``` +/// +/// Once mTLS is configured, the [`mtls::Certificate`](crate::mtls::Certificate) +/// request guard can be used to retrieve client certificates in routes. +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)] +pub struct MtlsConfig { + /// Path to a PEM file with, or raw bytes for, DER-encoded Certificate + /// Authority certificates which will be used to verify client-presented + /// certificates. + // TODO: Support more than one CA root. + pub(crate) ca_certs: Either>, + /// Whether the client is required to present a certificate. + /// + /// When `true`, the client is required to present a valid certificate to + /// proceed with TLS. When `false`, the client is not required to present a + /// certificate. In either case, if a certificate _is_ presented, it must be + /// valid or the connection is terminated. + #[serde(default)] + #[serde(deserialize_with = "figment::util::bool_from_str_or_int")] + pub mandatory: bool, +} + +impl MtlsConfig { + /// Constructs a `MtlsConfig` from a path to a PEM file with a certificate + /// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This + /// method does no validation; it simply creates a structure suitable for + /// passing into a [`TlsConfig`]. + /// + /// These certificates will be used to verify client-presented certificates + /// in TLS connections. + /// + /// # Example + /// + /// ```rust + /// use rocket::mtls::MtlsConfig; + /// + /// let tls_config = MtlsConfig::from_path("/ssl/ca_certs.pem"); + /// ``` + pub fn from_path>(ca_certs: C) -> Self { + MtlsConfig { + ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()), + mandatory: Default::default() + } + } + + /// Constructs a `MtlsConfig` from a byte buffer to a certificate authority + /// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no + /// validation; it simply creates a structure suitable for passing into a + /// [`TlsConfig`]. + /// + /// These certificates will be used to verify client-presented certificates + /// in TLS connections. + /// + /// # Example + /// + /// ```rust + /// use rocket::mtls::MtlsConfig; + /// + /// # let ca_certs_buf = &[]; + /// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf); + /// ``` + pub fn from_bytes(ca_certs: &[u8]) -> Self { + MtlsConfig { + ca_certs: Either::Right(ca_certs.to_vec()), + mandatory: Default::default() + } + } + + /// Sets whether client authentication is required. Disabled by default. + /// + /// When `true`, client authentication will be required. TLS connections + /// where the client does not present a certificate will be immediately + /// terminated. When `false`, the client is not required to present a + /// certificate. In either case, if a certificate _is_ presented, it must be + /// valid or the connection is terminated. + /// + /// # Example + /// + /// ```rust + /// use rocket::mtls::MtlsConfig; + /// + /// # let ca_certs_buf = &[]; + /// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf).mandatory(true); + /// ``` + pub fn mandatory(mut self, mandatory: bool) -> Self { + self.mandatory = mandatory; + self + } + + /// Returns the value of the `ca_certs` parameter. + /// # Example + /// + /// ```rust + /// use rocket::mtls::MtlsConfig; + /// + /// # let ca_certs_buf = &[]; + /// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf).mandatory(true); + /// assert_eq!(mtls_config.ca_certs().unwrap_right(), ca_certs_buf); + /// ``` + pub fn ca_certs(&self) -> either::Either { + match &self.ca_certs { + Either::Left(path) => either::Either::Left(path.relative()), + Either::Right(bytes) => either::Either::Right(&bytes), + } + } + + #[inline(always)] + pub fn ca_certs_reader(&self) -> io::Result> { + crate::tls::config::to_reader(&self.ca_certs) + } +} + +#[cfg(test)] +mod tests { + use std::path::Path; + use figment::{Figment, providers::{Toml, Format}}; + + use crate::mtls::MtlsConfig; + + #[test] + fn test_mtls_config() { + figment::Jail::expect_with(|jail| { + jail.create_file("MTLS.toml", r#" + certs = "/ssl/cert.pem" + key = "/ssl/key.pem" + "#)?; + + let figment = || Figment::from(Toml::file("MTLS.toml")); + figment().extract::().expect_err("no ca"); + + jail.create_file("MTLS.toml", r#" + ca_certs = "/ssl/ca.pem" + "#)?; + + let mtls: MtlsConfig = figment().extract()?; + assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem")); + assert!(!mtls.mandatory); + + jail.create_file("MTLS.toml", r#" + ca_certs = "/ssl/ca.pem" + mandatory = true + "#)?; + + let mtls: MtlsConfig = figment().extract()?; + assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem")); + assert!(mtls.mandatory); + + jail.create_file("MTLS.toml", r#" + ca_certs = "relative/ca.pem" + "#)?; + + let mtls: MtlsConfig = figment().extract()?; + assert_eq!(mtls.ca_certs().unwrap_left(), jail.directory().join("relative/ca.pem")); + + Ok(()) + }); + } +} diff --git a/core/lib/src/mtls/error.rs b/core/lib/src/mtls/error.rs new file mode 100644 index 0000000000..56b8d01ee1 --- /dev/null +++ b/core/lib/src/mtls/error.rs @@ -0,0 +1,74 @@ +use std::fmt; +use std::num::NonZeroUsize; + +use crate::mtls::x509::{self, nom}; + +/// An error returned by the [`Certificate`] request guard. +/// +/// To retrieve this error in a handler, use an `mtls::Result` +/// guard type: +/// +/// ```rust +/// # extern crate rocket; +/// # use rocket::get; +/// use rocket::mtls::{self, Certificate}; +/// +/// #[get("/auth")] +/// fn auth(cert: mtls::Result>) { +/// match cert { +/// Ok(cert) => { /* do something with the client cert */ }, +/// Err(e) => { /* do something with the error */ }, +/// } +/// } +/// ``` +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum Error { + /// The certificate chain presented by the client had no certificates. + Empty, + /// The certificate contained neither a subject nor a subjectAlt extension. + NoSubject, + /// There is no subject and the subjectAlt is not marked as critical. + NonCriticalSubjectAlt, + /// An error occurred while parsing the certificate. + Parse(x509::X509Error), + /// The certificate parsed partially but is incomplete. + /// + /// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number + /// of expected bytes is unknown. + Incomplete(Option), + /// The certificate contained `.0` bytes of trailing data. + Trailing(usize), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Parse(e) => write!(f, "parse error: {}", e), + Error::Incomplete(_) => write!(f, "incomplete certificate data"), + Error::Trailing(n) => write!(f, "found {} trailing bytes", n), + Error::Empty => write!(f, "empty certificate chain"), + Error::NoSubject => write!(f, "empty subject without subjectAlt"), + Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"), + } + } +} + +impl From> for Error { + fn from(e: nom::Err) -> Self { + match e { + nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None), + nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)), + nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e), + } + } +} + +impl std::error::Error for Error { + // fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + // match self { + // Error::Parse(e) => Some(e), + // _ => None + // } + // } +} diff --git a/core/lib/src/mtls/mod.rs b/core/lib/src/mtls/mod.rs new file mode 100644 index 0000000000..10ce6dc055 --- /dev/null +++ b/core/lib/src/mtls/mod.rs @@ -0,0 +1,56 @@ +//! Support for mutual TLS client certificates. +//! +//! For details on how to configure mutual TLS, see +//! [`MutualTls`](crate::config::MutualTls) and the [TLS +//! guide](https://rocket.rs/master/guide/configuration/#tls). See +//! [`Certificate`] for a request guard that validated, verifies, and retrieves +//! client certificates. + +pub mod oid { + //! Lower-level OID types re-exported from + //! [`oid_registry`](https://docs.rs/oid-registry/0.4) and + //! [`der-parser`](https://docs.rs/der-parser/7). + + pub use x509_parser::oid_registry::*; + pub use x509_parser::objects::*; +} + +pub mod bigint { + //! Signed and unsigned big integer types re-exported from + //! [`num_bigint`](https://docs.rs/num-bigint/0.4). + pub use x509_parser::der_parser::num_bigint::*; +} + +pub mod x509 { + //! Lower-level X.509 types re-exported from + //! [`x509_parser`](https://docs.rs/x509-parser/0.13). + //! + //! Lack of documentation is directly inherited from the source crate. + //! Prefer to use Rocket's wrappers when possible. + + pub(crate) use x509_parser::nom; + pub use x509_parser::certificate::*; + pub use x509_parser::cri_attributes::*; + pub use x509_parser::error::*; + pub use x509_parser::extensions::*; + pub use x509_parser::revocation_list::*; + pub use x509_parser::time::*; + pub use x509_parser::x509::*; + pub use x509_parser::der_parser::der; + pub use x509_parser::der_parser::ber; + pub use x509_parser::traits::*; +} + +mod certificate; +mod error; +mod name; +mod config; + +pub use error::Error; +pub use name::Name; +pub use config::MtlsConfig; +pub use certificate::{Certificate, CertificateDer}; + +/// A type alias for [`Result`](std::result::Result) with the error type set to +/// [`Error`]. +pub type Result = std::result::Result; diff --git a/core/lib/src/mtls/name.rs b/core/lib/src/mtls/name.rs new file mode 100644 index 0000000000..c6198ace36 --- /dev/null +++ b/core/lib/src/mtls/name.rs @@ -0,0 +1,146 @@ +use std::fmt; +use std::ops::Deref; + +use ref_cast::RefCast; + +use crate::mtls::x509::X509Name; +use crate::mtls::oid; + +/// An X.509 Distinguished Name (DN) found in a [`Certificate`]. +/// +/// This type is a wrapper over [`x509::X509Name`] with convenient methods and +/// complete documentation. Should the data exposed by the inherent methods not +/// suffice, this type derefs to [`x509::X509Name`]. +#[repr(transparent)] +#[derive(Debug, PartialEq, RefCast)] +pub struct Name<'a>(X509Name<'a>); + +impl<'a> Name<'a> { + /// Returns the _first_ UTF-8 _string_ common name, if any. + /// + /// Note that common names need not be UTF-8 strings, or strings at all. + /// This method returns the first common name attribute that is. + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::mtls::Certificate; + /// + /// #[get("/auth")] + /// fn auth(cert: Certificate<'_>) { + /// if let Some(name) = cert.subject().common_name() { + /// println!("Hello, {}!", name); + /// } + /// } + /// ``` + pub fn common_name(&self) -> Option<&'a str> { + self.common_names().next() + } + + /// Returns an iterator over all of the UTF-8 _string_ common names in + /// `self`. + /// + /// Note that common names need not be UTF-8 strings, or strings at all. + /// This method filters the common names in `self` to those that are. Use + /// the raw [`iter_common_name()`](#method.iter_common_name) to iterate over + /// all value types. + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::mtls::Certificate; + /// + /// #[get("/auth")] + /// fn auth(cert: Certificate<'_>) { + /// for name in cert.issuer().common_names() { + /// println!("Issued by {}.", name); + /// } + /// } + /// ``` + pub fn common_names(&self) -> impl Iterator + '_ { + self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok()) + } + + /// Returns the _first_ UTF-8 _string_ email address, if any. + /// + /// Note that email addresses need not be UTF-8 strings, or strings at all. + /// This method returns the first email address attribute that is. + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::mtls::Certificate; + /// + /// #[get("/auth")] + /// fn auth(cert: Certificate<'_>) { + /// if let Some(email) = cert.subject().email() { + /// println!("Hello, {}!", email); + /// } + /// } + /// ``` + pub fn email(&self) -> Option<&'a str> { + self.emails().next() + } + + /// Returns an iterator over all of the UTF-8 _string_ email addresses in + /// `self`. + /// + /// Note that email addresses need not be UTF-8 strings, or strings at all. + /// This method filters the email address in `self` to those that are. Use + /// the raw [`iter_email()`](#method.iter_email) to iterate over all value + /// types. + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::mtls::Certificate; + /// + /// #[get("/auth")] + /// fn auth(cert: Certificate<'_>) { + /// for email in cert.subject().emails() { + /// println!("Reach me at: {}", email); + /// } + /// } + /// ``` + pub fn emails(&self) -> impl Iterator + '_ { + self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok()) + } + + /// Returns `true` if `self` has no data. + /// + /// When this is the case for a `subject()`, the subject data can be found + /// in the `subjectAlt` [`extension()`](Certificate::extensions()). + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::mtls::Certificate; + /// + /// #[get("/auth")] + /// fn auth(cert: Certificate<'_>) { + /// let no_data = cert.subject().is_empty(); + /// } + /// ``` + pub fn is_empty(&self) -> bool { + self.0.as_raw().is_empty() + } +} + +impl<'a> Deref for Name<'a> { + type Target = X509Name<'a>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for Name<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/core/lib/src/phase.rs b/core/lib/src/phase.rs index 3b6ca870c3..b38deeeaf6 100644 --- a/core/lib/src/phase.rs +++ b/core/lib/src/phase.rs @@ -1,6 +1,7 @@ use state::TypeMap; use figment::Figment; +use crate::listener::Endpoint; use crate::{Catcher, Config, Rocket, Route, Shutdown}; use crate::router::Router; use crate::fairing::Fairings; @@ -113,5 +114,6 @@ phases! { pub(crate) config: Config, pub(crate) state: TypeMap![Send + Sync], pub(crate) shutdown: Shutdown, + pub(crate) endpoint: Endpoint, } } diff --git a/core/lib/src/request/atomic_method.rs b/core/lib/src/request/atomic_method.rs new file mode 100644 index 0000000000..6d49f603d5 --- /dev/null +++ b/core/lib/src/request/atomic_method.rs @@ -0,0 +1,43 @@ +use crate::http::Method; + +pub struct AtomicMethod(ref_swap::RefSwap<'static, Method>); + +#[inline(always)] +const fn makeref(method: Method) -> &'static Method { + match method { + Method::Get => &Method::Get, + Method::Put => &Method::Put, + Method::Post => &Method::Post, + Method::Delete => &Method::Delete, + Method::Options => &Method::Options, + Method::Head => &Method::Head, + Method::Trace => &Method::Trace, + Method::Connect => &Method::Connect, + Method::Patch => &Method::Patch, + } +} + +impl AtomicMethod { + pub fn new(value: Method) -> Self { + Self(ref_swap::RefSwap::new(makeref(value))) + } + + pub fn load(&self) -> Method { + *self.0.load(std::sync::atomic::Ordering::Acquire) + } + + pub fn set(&mut self, new: Method) { + *self = Self::new(new); + } + + pub fn store(&self, new: Method) { + self.0.store(makeref(new), std::sync::atomic::Ordering::Release) + } +} + +impl Clone for AtomicMethod { + fn clone(&self) -> Self { + let inner = self.0.load(std::sync::atomic::Ordering::Acquire); + Self(ref_swap::RefSwap::new(inner)) + } +} diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index c95a427f09..279b35cd00 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -1,12 +1,13 @@ use std::convert::Infallible; use std::fmt::Debug; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use crate::{Request, Route}; use crate::outcome::{self, IntoOutcome, Outcome::*}; use crate::http::uri::{Host, Origin}; use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar}; +use crate::listener::Endpoint; /// Type alias for the `Outcome` of a `FromRequest` conversion. pub type Outcome = outcome::Outcome; @@ -486,14 +487,22 @@ impl<'r> FromRequest<'r> for ProxyProto<'r> { } #[crate::async_trait] -impl<'r> FromRequest<'r> for SocketAddr { +impl<'r> FromRequest<'r> for &'r Endpoint { type Error = Infallible; async fn from_request(request: &'r Request<'_>) -> Outcome { - match request.remote() { - Some(addr) => Success(addr), - None => Forward(Status::InternalServerError) - } + request.remote().or_forward(Status::InternalServerError) + } +} + +#[crate::async_trait] +impl<'r> FromRequest<'r> for std::net::SocketAddr { + type Error = Infallible; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + request.remote() + .and_then(|r| r.tcp()) + .or_forward(Status::InternalServerError) } } diff --git a/core/lib/src/request/mod.rs b/core/lib/src/request/mod.rs index fe565b2e2a..0393f96b51 100644 --- a/core/lib/src/request/mod.rs +++ b/core/lib/src/request/mod.rs @@ -3,6 +3,7 @@ mod request; mod from_param; mod from_request; +mod atomic_method; #[cfg(test)] mod tests; @@ -15,6 +16,7 @@ pub use self::from_param::{FromParam, FromSegments}; pub use crate::response::flash::FlashMessage; pub(crate) use self::request::ConnectionMeta; +pub(crate) use self::atomic_method::AtomicMethod; crate::export! { /// Store and immediately retrieve a vector-like value `$v` (`String` or diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 1d380a9767..a02dff3a4e 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -1,22 +1,24 @@ use std::fmt; use std::ops::RangeFrom; -use std::{future::Future, borrow::Cow, sync::Arc}; -use std::net::{IpAddr, SocketAddr}; +use std::sync::{Arc, atomic::Ordering}; +use std::borrow::Cow; +use std::future::Future; +use std::net::IpAddr; use yansi::Paint; use state::{TypeMap, InitCell}; use futures::future::BoxFuture; -use atomic::{Atomic, Ordering}; +use ref_swap::OptionRefSwap; use crate::{Rocket, Route, Orbit}; -use crate::request::{FromParam, FromSegments, FromRequest, Outcome}; +use crate::request::{FromParam, FromSegments, FromRequest, Outcome, AtomicMethod}; use crate::form::{self, ValueField, FromForm}; use crate::data::Limits; -use crate::http::{hyper, Method, Header, HeaderMap, ProxyProto}; -use crate::http::{ContentType, Accept, MediaType, CookieJar, Cookie}; -use crate::http::private::Certificates; +use crate::http::ProxyProto; +use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie}; use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority}; +use crate::listener::{Certificates, Endpoint, Connection}; /// The type of an incoming web request. /// @@ -24,26 +26,37 @@ use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority}; /// should likely only be used when writing [`FromRequest`] implementations. It /// contains all of the information for a given web request except for the body /// data. This includes the HTTP method, URI, cookies, headers, and more. +#[derive(Clone)] pub struct Request<'r> { - method: Atomic, + method: AtomicMethod, uri: Origin<'r>, headers: HeaderMap<'r>, + pub(crate) errors: Vec, pub(crate) connection: ConnectionMeta, pub(crate) state: RequestState<'r>, } /// Information derived from an incoming connection, if any. -#[derive(Clone)] +#[derive(Clone, Default)] pub(crate) struct ConnectionMeta { - pub remote: Option, + pub peer_address: Option>, #[cfg_attr(not(feature = "mtls"), allow(dead_code))] - pub client_certificates: Option, + pub peer_certs: Option>>, +} + +impl From<&C> for ConnectionMeta { + fn from(conn: &C) -> Self { + ConnectionMeta { + peer_address: conn.peer_address().ok().map(Arc::new), + peer_certs: conn.peer_certificates().map(|c| c.into_owned()).map(Arc::new), + } + } } /// Information derived from the request. pub(crate) struct RequestState<'r> { pub rocket: &'r Rocket, - pub route: Atomic>, + pub route: OptionRefSwap<'r, Route>, pub cookies: CookieJar<'r>, pub accept: InitCell>, pub content_type: InitCell>, @@ -51,23 +64,11 @@ pub(crate) struct RequestState<'r> { pub host: Option>, } -impl Request<'_> { - pub(crate) fn clone(&self) -> Self { - Request { - method: Atomic::new(self.method()), - uri: self.uri.clone(), - headers: self.headers.clone(), - connection: self.connection.clone(), - state: self.state.clone(), - } - } -} - -impl RequestState<'_> { +impl Clone for RequestState<'_> { fn clone(&self) -> Self { RequestState { rocket: self.rocket, - route: Atomic::new(self.route.load(Ordering::Acquire)), + route: OptionRefSwap::new(self.route.load(Ordering::Acquire)), cookies: self.cookies.clone(), accept: self.accept.clone(), content_type: self.content_type.clone(), @@ -87,15 +88,13 @@ impl<'r> Request<'r> { ) -> Request<'r> { Request { uri, - method: Atomic::new(method), + method: AtomicMethod::new(method), headers: HeaderMap::new(), - connection: ConnectionMeta { - remote: None, - client_certificates: None, - }, + errors: Vec::new(), + connection: ConnectionMeta::default(), state: RequestState { rocket, - route: Atomic::new(None), + route: OptionRefSwap::new(None), cookies: CookieJar::new(None, rocket), accept: InitCell::new(), content_type: InitCell::new(), @@ -120,7 +119,7 @@ impl<'r> Request<'r> { /// ``` #[inline(always)] pub fn method(&self) -> Method { - self.method.load(Ordering::Acquire) + self.method.load() } /// Set the method of `self` to `method`. @@ -140,7 +139,7 @@ impl<'r> Request<'r> { /// ``` #[inline(always)] pub fn set_method(&mut self, method: Method) { - self._set_method(method); + self.method.set(method); } /// Borrow the [`Origin`] URI from `self`. @@ -324,20 +323,20 @@ impl<'r> Request<'r> { /// /// assert_eq!(request.remote(), None); /// - /// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000).into(); + /// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000); /// request.set_remote(localhost); - /// assert_eq!(request.remote(), Some(localhost)); + /// assert_eq!(request.remote().unwrap(), &localhost); /// ``` #[inline(always)] - pub fn remote(&self) -> Option { - self.connection.remote + pub fn remote(&self) -> Option<&Endpoint> { + self.connection.peer_address.as_deref() } /// Sets the remote address of `self` to `address`. /// /// # Example /// - /// Set the remote address to be 127.0.0.1:8000: + /// Set the remote address to be 127.0.0.1:8111: /// /// ```rust /// use std::net::{SocketAddrV4, Ipv4Addr}; @@ -347,13 +346,13 @@ impl<'r> Request<'r> { /// /// assert_eq!(request.remote(), None); /// - /// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000).into(); + /// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8111); /// request.set_remote(localhost); - /// assert_eq!(request.remote(), Some(localhost)); + /// assert_eq!(request.remote().unwrap(), &localhost); /// ``` #[inline(always)] - pub fn set_remote(&mut self, address: SocketAddr) { - self.connection.remote = Some(address); + pub fn set_remote>(&mut self, address: A) { + self.connection.peer_address = Some(Arc::new(address.into())); } /// Returns the IP address of the configured @@ -489,25 +488,26 @@ impl<'r> Request<'r> { /// /// ```rust /// # use rocket::http::Header; - /// # use std::net::{SocketAddr, IpAddr, Ipv4Addr}; /// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap(); /// # let mut req = c.get("/"); /// # let request = req.inner_mut(); + /// # use std::net::{SocketAddrV4, Ipv4Addr}; /// /// // starting without an "X-Real-IP" header or remote address /// assert!(request.client_ip().is_none()); /// /// // add a remote address; this is done by Rocket automatically - /// request.set_remote("127.0.0.1:8000".parse().unwrap()); - /// assert_eq!(request.client_ip(), Some("127.0.0.1".parse().unwrap())); + /// let localhost_9190 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9190); + /// request.set_remote(localhost_9190); + /// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::LOCALHOST); /// /// // now with an X-Real-IP header, the default value for `ip_header`. /// request.add_header(Header::new("X-Real-IP", "8.8.8.8")); - /// assert_eq!(request.client_ip(), Some("8.8.8.8".parse().unwrap())); + /// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::new(8, 8, 8, 8)); /// ``` #[inline] pub fn client_ip(&self) -> Option { - self.real_ip().or_else(|| self.remote().map(|r| r.ip())) + self.real_ip().or_else(|| Some(self.remote()?.tcp()?.ip())) } /// Returns a wrapped borrow to the cookies in `self`. @@ -691,7 +691,7 @@ impl<'r> Request<'r> { if self.method().supports_payload() { self.content_type().map(|ct| ct.media_type()) } else { - // FIXME: Should we be using `accept_first` or `preferred`? Or + // TODO: Should we be using `accept_first` or `preferred`? Or // should we be checking neither and instead pass things through // where the client accepts the thing at all? self.accept() @@ -1056,11 +1056,9 @@ impl<'r> Request<'r> { self.state.route.store(Some(route), Ordering::Release) } - /// Set the method of `self`, even when `self` is a shared reference. Used - /// during routing to override methods for re-routing. #[inline(always)] pub(crate) fn _set_method(&self, method: Method) { - self.method.store(method, Ordering::Release) + self.method.store(method) } pub(crate) fn cookies_mut(&mut self) -> &mut CookieJar<'r> { @@ -1070,18 +1068,28 @@ impl<'r> Request<'r> { /// Convert from Hyper types into a Rocket Request. pub(crate) fn from_hyp( rocket: &'r Rocket, - hyper: &'r hyper::request::Parts, - connection: Option, - ) -> Result, BadRequest<'r>> { + hyper: &'r hyper::http::request::Parts, + connection: ConnectionMeta, + ) -> Result, Request<'r>> { // Keep track of parsing errors; emit a `BadRequest` if any exist. let mut errors = vec![]; // Ensure that the method is known. TODO: Allow made-up methods? - let method = Method::from_hyp(&hyper.method) - .unwrap_or_else(|| { - errors.push(Kind::BadMethod(&hyper.method)); + let method = match hyper.method { + hyper::Method::GET => Method::Get, + hyper::Method::PUT => Method::Put, + hyper::Method::POST => Method::Post, + hyper::Method::DELETE => Method::Delete, + hyper::Method::OPTIONS => Method::Options, + hyper::Method::HEAD => Method::Head, + hyper::Method::TRACE => Method::Trace, + hyper::Method::CONNECT => Method::Connect, + hyper::Method::PATCH => Method::Patch, + _ => { + errors.push(RequestError::BadMethod(hyper.method.clone())); Method::Get - }); + } + }; // TODO: Keep around not just the path/query, but the rest, if there? let uri = hyper.uri.path_and_query() @@ -1100,20 +1108,20 @@ impl<'r> Request<'r> { Origin::new(uri.path(), uri.query().map(Cow::Borrowed)) }) .unwrap_or_else(|| { - errors.push(Kind::InvalidUri(&hyper.uri)); + errors.push(RequestError::InvalidUri(hyper.uri.clone())); Origin::ROOT }); // Construct the request object; fill in metadata and headers next. let mut request = Request::new(rocket, method, uri); + request.errors = errors; // Set the passed in connection metadata. - if let Some(connection) = connection { - request.connection = connection; - } + request.connection = connection; // Determine + set host. On HTTP < 2, use the `HOST` header. Otherwise, // use the `:authority` pseudo-header which hyper makes part of the URI. + // TODO: Use an `InitCell` to compute this later. request.state.host = if hyper.version < hyper::Version::HTTP_2 { hyper.headers.get("host").and_then(|h| Host::parse_bytes(h.as_bytes()).ok()) } else { @@ -1122,9 +1130,8 @@ impl<'r> Request<'r> { // Set the request cookies, if they exist. for header in hyper.headers.get_all("Cookie") { - let raw_str = match std::str::from_utf8(header.as_bytes()) { - Ok(string) => string, - Err(_) => continue + let Ok(raw_str) = std::str::from_utf8(header.as_bytes()) else { + continue }; for cookie_str in raw_str.split(';').map(|s| s.trim()) { @@ -1137,43 +1144,33 @@ impl<'r> Request<'r> { // Set the rest of the headers. This is rather unfortunate and slow. for (name, value) in hyper.headers.iter() { // FIXME: This is rather unfortunate. Header values needn't be UTF8. - let value = match std::str::from_utf8(value.as_bytes()) { - Ok(value) => value, - Err(_) => { - warn!("Header '{}' contains invalid UTF-8", name); - warn_!("Rocket only supports UTF-8 header values. Dropping header."); - continue; - } + let Ok(value) = std::str::from_utf8(value.as_bytes()) else { + warn!("Header '{}' contains invalid UTF-8", name); + warn_!("Rocket only supports UTF-8 header values. Dropping header."); + continue; }; request.add_header(Header::new(name.as_str(), value)); } - if errors.is_empty() { - Ok(request) - } else { - Err(BadRequest { request, errors }) + match request.errors.is_empty() { + true => Ok(request), + false => Err(request), } } } -#[derive(Debug)] -pub(crate) struct BadRequest<'r> { - pub request: Request<'r>, - pub errors: Vec>, -} - -#[derive(Debug)] -pub(crate) enum Kind<'r> { - InvalidUri(&'r hyper::Uri), - BadMethod(&'r hyper::Method), +#[derive(Debug, Clone)] +pub(crate) enum RequestError { + InvalidUri(hyper::Uri), + BadMethod(hyper::Method), } -impl fmt::Display for Kind<'_> { +impl fmt::Display for RequestError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Kind::InvalidUri(u) => write!(f, "invalid origin URI: {}", u), - Kind::BadMethod(m) => write!(f, "invalid or unrecognized method: {}", m), + RequestError::InvalidUri(u) => write!(f, "invalid origin URI: {}", u), + RequestError::BadMethod(m) => write!(f, "invalid or unrecognized method: {}", m), } } } @@ -1181,8 +1178,8 @@ impl fmt::Display for Kind<'_> { impl fmt::Debug for Request<'_> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Request") - .field("method", &self.method) - .field("uri", &self.uri) + .field("method", &self.method()) + .field("uri", &self.uri()) .field("headers", &self.headers()) .field("remote", &self.remote()) .field("cookies", &self.cookies()) diff --git a/core/lib/src/request/tests.rs b/core/lib/src/request/tests.rs index a349aeda39..5af4b8543c 100644 --- a/core/lib/src/request/tests.rs +++ b/core/lib/src/request/tests.rs @@ -1,14 +1,16 @@ use std::collections::HashMap; -use crate::Request; +use crate::request::{Request, ConnectionMeta}; use crate::local::blocking::Client; -use crate::http::hyper; macro_rules! assert_headers { ($($key:expr => [$($value:expr),+]),+) => ({ // Create a new Hyper request. Add all of the passed in headers. let mut req = hyper::Request::get("/test").body(()).unwrap(); - $($(req.headers_mut().append($key, hyper::HeaderValue::from_str($value).unwrap());)+)+ + $($( + req.headers_mut() + .append($key, hyper::header::HeaderValue::from_str($value).unwrap()); + )+)+ // Build up what we expect the headers to actually be. let mut expected = HashMap::new(); @@ -17,7 +19,8 @@ macro_rules! assert_headers { // Create a valid `Rocket` and convert the hyper req to a Rocket one. let client = Client::debug_with(vec![]).unwrap(); let hyper = req.into_parts().0; - let req = Request::from_hyp(client.rocket(), &hyper, None).unwrap(); + let meta = ConnectionMeta::default(); + let req = Request::from_hyp(client.rocket(), &hyper, meta).unwrap(); // Dispatch the request and check that the headers match. let actual_headers = req.headers(); diff --git a/core/lib/src/response/response.rs b/core/lib/src/response/response.rs index 588497e1c2..4d399e9174 100644 --- a/core/lib/src/response/response.rs +++ b/core/lib/src/response/response.rs @@ -1,7 +1,6 @@ use std::{fmt, str}; use std::borrow::Cow; use std::collections::HashMap; -use std::pin::Pin; use tokio::io::{AsyncRead, AsyncSeek}; @@ -146,19 +145,18 @@ impl<'r> Builder<'r> { /// potentially different values to be present in the `Response`. /// /// The type of `header` can be any type that implements `Into

`. - /// This includes `Header` itself, [`ContentType`](crate::http::ContentType) and - /// [hyper::header types](crate::http::hyper::header). + /// This includes `Header` itself, [`ContentType`](crate::http::ContentType) + /// and [`Accept`](crate::http::Accept). /// /// # Example /// /// ```rust /// use rocket::Response; - /// use rocket::http::Header; - /// use rocket::http::hyper::header::ACCEPT; + /// use rocket::http::{Header, Accept}; /// /// let response = Response::build() - /// .header_adjoin(Header::new(ACCEPT.as_str(), "application/json")) - /// .header_adjoin(Header::new(ACCEPT.as_str(), "text/plain")) + /// .header_adjoin(Header::new("Accept", "application/json")) + /// .header_adjoin(Accept::XML) /// .finalize(); /// /// assert_eq!(response.headers().get("Accept").count(), 2); @@ -287,7 +285,7 @@ impl<'r> Builder<'r> { /// /// #[rocket::async_trait] /// impl IoHandler for EchoHandler { - /// async fn io(self: Pin>, io: IoStream) -> io::Result<()> { + /// async fn io(self: Box, io: IoStream) -> io::Result<()> { /// let (mut reader, mut writer) = io::split(io); /// io::copy(&mut reader, &mut writer).await?; /// Ok(()) @@ -488,7 +486,7 @@ pub struct Response<'r> { status: Option, headers: HeaderMap<'r>, body: Body<'r>, - upgrade: HashMap, Pin>>, + upgrade: HashMap, Box>, } impl<'r> Response<'r> { @@ -700,23 +698,22 @@ impl<'r> Response<'r> { /// name `header.name`, another header with the same name and value /// `header.value` is added. The type of `header` can be any type that /// implements `Into
`. This includes `Header` itself, - /// [`ContentType`](crate::http::ContentType) and [`hyper::header` - /// types](crate::http::hyper::header). + /// [`ContentType`](crate::http::ContentType), + /// [`Accept`](crate::http::Accept). /// /// # Example /// /// ```rust /// use rocket::Response; - /// use rocket::http::Header; - /// use rocket::http::hyper::header::ACCEPT; + /// use rocket::http::{Header, Accept}; /// /// let mut response = Response::new(); - /// response.adjoin_header(Header::new(ACCEPT.as_str(), "application/json")); - /// response.adjoin_header(Header::new(ACCEPT.as_str(), "text/plain")); + /// response.adjoin_header(Accept::JSON); + /// response.adjoin_header(Header::new("Accept", "text/plain")); /// /// let mut accept_headers = response.headers().iter(); - /// assert_eq!(accept_headers.next(), Some(Header::new(ACCEPT.as_str(), "application/json"))); - /// assert_eq!(accept_headers.next(), Some(Header::new(ACCEPT.as_str(), "text/plain"))); + /// assert_eq!(accept_headers.next(), Some(Header::new("Accept", "application/json"))); + /// assert_eq!(accept_headers.next(), Some(Header::new("Accept", "text/plain"))); /// assert_eq!(accept_headers.next(), None); /// ``` #[inline(always)] @@ -801,10 +798,10 @@ impl<'r> Response<'r> { /// the comma-separated protocols any of the strings in `I`. Returns /// `Ok(None)` if `self` doesn't support any kind of upgrade. Returns /// `Err(_)` if `protocols` is non-empty but no match was found in `self`. - pub(crate) fn take_upgrade>( + pub(crate) fn search_upgrades<'a, I: Iterator>( &mut self, protocols: I - ) -> Result, Pin>)>, ()> { + ) -> Result, Box)>, ()> { if self.upgrade.is_empty() { return Ok(None); } @@ -839,7 +836,7 @@ impl<'r> Response<'r> { /// /// #[rocket::async_trait] /// impl IoHandler for EchoHandler { - /// async fn io(self: Pin>, io: IoStream) -> io::Result<()> { + /// async fn io(self: Box, io: IoStream) -> io::Result<()> { /// let (mut reader, mut writer) = io::split(io); /// io::copy(&mut reader, &mut writer).await?; /// Ok(()) @@ -854,7 +851,7 @@ impl<'r> Response<'r> { /// assert!(response.upgrade("raw-echo").is_some()); /// # }) /// ``` - pub fn upgrade(&mut self, proto: &str) -> Option> { + pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> { self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut()) } @@ -972,7 +969,7 @@ impl<'r> Response<'r> { /// /// #[rocket::async_trait] /// impl IoHandler for EchoHandler { - /// async fn io(self: Pin>, io: IoStream) -> io::Result<()> { + /// async fn io(self: Box, io: IoStream) -> io::Result<()> { /// let (mut reader, mut writer) = io::split(io); /// io::copy(&mut reader, &mut writer).await?; /// Ok(()) @@ -990,7 +987,7 @@ impl<'r> Response<'r> { pub fn add_upgrade(&mut self, protocol: N, handler: H) where N: Into>, H: IoHandler + 'r { - self.upgrade.insert(protocol.into(), Box::pin(handler)); + self.upgrade.insert(protocol.into(), Box::new(handler)); } /// Sets the body's maximum chunk size to `size` bytes. diff --git a/core/lib/src/response/stream/sse.rs b/core/lib/src/response/stream/sse.rs index 27044e27af..398596ccbf 100644 --- a/core/lib/src/response/stream/sse.rs +++ b/core/lib/src/response/stream/sse.rs @@ -1,9 +1,9 @@ use std::borrow::Cow; use tokio::io::AsyncRead; -use tokio::time::Duration; -use futures::stream::{self, Stream, StreamExt}; -use futures::future::ready; +use tokio::time::{interval, Duration}; +use futures::{stream::{self, Stream}, future::Either}; +use tokio_stream::{StreamExt, wrappers::IntervalStream}; use crate::request::Request; use crate::response::{self, Response, Responder, stream::{ReaderStream, RawLinedEvent}}; @@ -336,7 +336,7 @@ impl Event { Some(RawLinedEvent::raw("")), ]; - stream::iter(events).filter_map(ready) + stream::iter(events).filter_map(|x| x) } } @@ -528,25 +528,19 @@ impl> EventStream { self } - fn heartbeat_stream(&self) -> Option> { - use tokio::time::interval; - use tokio_stream::wrappers::IntervalStream; - + fn heartbeat_stream(&self) -> impl Stream { self.heartbeat .map(|beat| IntervalStream::new(interval(beat))) .map(|stream| stream.map(|_| RawLinedEvent::raw(":"))) + .map_or_else(|| Either::Right(stream::empty()), Either::Left) } fn into_stream(self) -> impl Stream { - use futures::future::Either; - use crate::ext::StreamExt; - - let heartbeat_stream = self.heartbeat_stream(); - let raw_events = self.stream.map(|e| e.into_stream()).flatten(); - match heartbeat_stream { - Some(heartbeat) => Either::Left(raw_events.join(heartbeat)), - None => Either::Right(raw_events) - } + use futures::StreamExt; + + let heartbeats = self.heartbeat_stream(); + let events = StreamExt::map(self.stream, |e| e.into_stream()).flatten(); + crate::util::join(events, heartbeats) } fn into_reader(self) -> impl AsyncRead { @@ -621,10 +615,11 @@ mod sse_tests { impl> EventStream { fn into_string(self) -> String { + use std::pin::pin; + crate::async_test(async move { let mut string = String::new(); - let reader = self.into_reader(); - tokio::pin!(reader); + let mut reader = pin!(self.into_reader()); reader.read_to_string(&mut string).await.expect("event stream -> string"); string }) diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index ef9fdfecfa..40570e7b45 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -1,14 +1,14 @@ use std::fmt; use std::ops::{Deref, DerefMut}; -use std::net::SocketAddr; use yansi::Paint; use either::Either; use figment::{Figment, Provider}; use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield}; +use crate::listener::{Endpoint, Bindable, DefaultListener}; use crate::router::Router; -use crate::trip_wire::TripWire; +use crate::util::TripWire; use crate::fairing::{Fairing, Fairings}; use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting}; use crate::phase::{Stateful, StateRef, State}; @@ -203,35 +203,31 @@ impl Rocket { /// # Example /// /// ```rust - /// use rocket::Config; + /// use rocket::config::{Config, Ident}; /// # use std::net::Ipv4Addr; /// # use std::path::{Path, PathBuf}; /// # type Result = std::result::Result<(), rocket::Error>; /// /// let config = Config { - /// port: 7777, - /// address: Ipv4Addr::new(18, 127, 0, 1).into(), + /// ident: Ident::try_new("MyServer").expect("valid ident"), /// temp_dir: "/tmp/config-example".into(), /// ..Config::debug_default() /// }; /// /// # let _: Result = rocket::async_test(async move { /// let rocket = rocket::custom(&config).ignite().await?; - /// assert_eq!(rocket.config().port, 7777); - /// assert_eq!(rocket.config().address, Ipv4Addr::new(18, 127, 0, 1)); + /// assert_eq!(rocket.config().ident.as_str(), Some("MyServer")); /// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example")); /// /// // Create a new figment which modifies _some_ keys the existing figment: /// let figment = rocket.figment().clone() - /// .merge((Config::PORT, 8888)) - /// .merge((Config::ADDRESS, "171.64.200.10")); + /// .merge((Config::IDENT, "Example")); /// /// let rocket = rocket::custom(&config) /// .configure(figment) /// .ignite().await?; /// - /// assert_eq!(rocket.config().port, 8888); - /// assert_eq!(rocket.config().address, Ipv4Addr::new(171, 64, 200, 10)); + /// assert_eq!(rocket.config().ident.as_str(), Some("Example")); /// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example")); /// # Ok(()) /// # }); @@ -664,8 +660,9 @@ impl Rocket { self.shutdown.clone() } - fn into_orbit(self) -> Rocket { + pub(crate) fn into_orbit(self, address: Endpoint) -> Rocket { Rocket(Orbiting { + endpoint: address, router: self.0.router, fairings: self.0.fairings, figment: self.0.figment, @@ -675,28 +672,24 @@ impl Rocket { }) } - async fn _local_launch(self) -> Rocket { - let rocket = self.into_orbit(); - rocket.fairings.handle_liftoff(&rocket).await; - launch_info!("{}{}", "🚀 ".emoji(), "Rocket has launched locally".primary().bold()); + async fn _local_launch(self, addr: Endpoint) -> Rocket { + let rocket = self.into_orbit(addr); + Rocket::liftoff(&rocket).await; rocket } async fn _launch(self) -> Result, Error> { - self.into_orbit() - .default_tcp_http_server(|rkt| Box::pin(async move { - rkt.fairings.handle_liftoff(&rkt).await; - - let proto = rkt.config.tls_enabled().then(|| "https").unwrap_or("http"); - let socket_addr = SocketAddr::new(rkt.config.address, rkt.config.port); - let addr = format!("{}://{}", proto, socket_addr); - launch_info!("{}{} {}", - "🚀 ".emoji(), - "Rocket has launched from".bold().primary().linger(), - addr.underline()); - })) - .await - .map(|rocket| rocket.into_ignite()) + let config = self.figment().extract::()?; + either::for_both!(config.base_bindable()?, base => { + either::for_both!(config.tls_bindable(base), bindable => { + self._launch_on(bindable).await + }) + }) + } + + async fn _launch_on(self, bindable: B) -> Result, Error> { + let listener = bindable.bind().await.map_err(|e| ErrorKind::Bind(Box::new(e)))?; + self.serve(listener).await } } @@ -712,6 +705,21 @@ impl Rocket { }) } + pub(crate) async fn liftoff>(rocket: R) { + let rocket = rocket.deref(); + rocket.fairings.handle_liftoff(rocket).await; + + if !crate::running_within_rocket_async_rt().await { + warn!("Rocket is executing inside of a custom runtime."); + info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`."); + info_!("Forced shutdown is disabled. Runtime settings may be suboptimal."); + } + + launch_info!("{}{} {}", "🚀 ".emoji(), + "Rocket has launched on".bold().primary().linger(), + rocket.endpoint().underline()); + } + /// Returns the finalized, active configuration. This is guaranteed to /// remain stable after [`Rocket::ignite()`], through ignition and into /// orbit. @@ -734,6 +742,10 @@ impl Rocket { &self.config } + pub fn endpoint(&self) -> &Endpoint { + &self.endpoint + } + /// Returns a handle which can be used to trigger a shutdown and detect a /// triggered shutdown. /// @@ -867,10 +879,10 @@ impl Rocket

{ } } - pub(crate) async fn local_launch(self) -> Result, Error> { + pub(crate) async fn local_launch(self, l: Endpoint) -> Result, Error> { let rocket = match self.0.into_state() { - State::Build(s) => Rocket::from(s).ignite().await?._local_launch().await, - State::Ignite(s) => Rocket::from(s)._local_launch().await, + State::Build(s) => Rocket::from(s).ignite().await?._local_launch(l).await, + State::Ignite(s) => Rocket::from(s)._local_launch(l).await, State::Orbit(s) => Rocket::from(s) }; @@ -928,6 +940,14 @@ impl Rocket

{ State::Orbit(s) => Ok(Rocket::from(s).into_ignite()) } } + + pub async fn launch_on(self, bindable: B) -> Result, Error> { + match self.0.into_state() { + State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await, + State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await, + State::Orbit(s) => Ok(Rocket::from(s).into_ignite()) + } + } } #[doc(hidden)] diff --git a/core/lib/src/route/handler.rs b/core/lib/src/route/handler.rs index e29be6d570..b42d81e0fc 100644 --- a/core/lib/src/route/handler.rs +++ b/core/lib/src/route/handler.rs @@ -167,7 +167,6 @@ impl Handler for F } } -// FIXME! impl<'r, 'o: 'r> Outcome<'o> { /// Return the `Outcome` of response to `req` from `responder`. /// diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index da2626e327..3fbe2ae702 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -1,540 +1,142 @@ use std::io; +use std::pin::pin; use std::sync::Arc; use std::time::Duration; -use std::pin::Pin; -use yansi::Paint; -use tokio::sync::oneshot; +use hyper::service::service_fn; +use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; +use hyper_util::server::conn::auto::Builder; +use futures::{Future, TryFutureExt, future::{select, Either::*}}; use tokio::time::sleep; -use futures::stream::StreamExt; -use futures::future::{FutureExt, Future, BoxFuture}; -use crate::{route, Rocket, Orbit, Request, Response, Data, Config}; -use crate::form::Form; -use crate::outcome::Outcome; -use crate::error::{Error, ErrorKind}; -use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo}; +use crate::{Request, Rocket, Orbit, Data, Ignite}; use crate::request::ConnectionMeta; -use crate::data::IoHandler; - -use crate::http::{hyper, uncased, Method, Status, Header}; -use crate::http::private::{TcpListener, Listener, Connection, Incoming}; - -// A token returned to force the execution of one method before another. -pub(crate) struct RequestToken; - -async fn handle(name: Option<&str>, run: F) -> Option - where F: FnOnce() -> Fut, Fut: Future, -{ - use std::panic::AssertUnwindSafe; - - macro_rules! panic_info { - ($name:expr, $e:expr) => {{ - match $name { - Some(name) => error_!("Handler {} panicked.", name.primary()), - None => error_!("A handler panicked.") - }; - - info_!("This is an application bug."); - info_!("A panic in Rust must be treated as an exceptional event."); - info_!("Panicking is not a suitable error handling mechanism."); - info_!("Unwinding, the result of a panic, is an expensive operation."); - info_!("Panics will degrade application performance."); - info_!("Instead of panicking, return `Option` and/or `Result`."); - info_!("Values of either type can be returned directly from handlers."); - warn_!("A panic is treated as an internal server error."); - $e - }} - } - - let run = AssertUnwindSafe(run); - let fut = std::panic::catch_unwind(move || run()) - .map_err(|e| panic_info!(name, e)) - .ok()?; - - AssertUnwindSafe(fut) - .catch_unwind() - .await - .map_err(|e| panic_info!(name, e)) - .ok() -} - -// This function tries to hide all of the Hyper-ness from Rocket. It essentially -// converts Hyper types into Rocket types, then calls the `dispatch` function, -// which knows nothing about Hyper. Because responding depends on the -// `HyperResponse` type, this function does the actual response processing. -async fn hyper_service_fn( - rocket: Arc>, - conn: ConnectionMeta, - mut hyp_req: hyper::Request, -) -> Result, io::Error> { - // This future must return a hyper::Response, but the response body might - // borrow from the request. Instead, write the body in another future that - // sends the response metadata (and a body channel) prior. - let (tx, rx) = oneshot::channel(); - - #[cfg(not(broken_fmt))] - debug!("received request: {:#?}", hyp_req); - - tokio::spawn(async move { - // We move the request next, so get the upgrade future now. - let pending_upgrade = hyper::upgrade::on(&mut hyp_req); - - // Convert a Hyper request into a Rocket request. - let (h_parts, mut h_body) = hyp_req.into_parts(); - match Request::from_hyp(&rocket, &h_parts, Some(conn)) { - Ok(mut req) => { - // Convert into Rocket `Data`, dispatch request, write response. - let mut data = Data::from(&mut h_body); - let token = rocket.preprocess_request(&mut req, &mut data).await; - let mut response = rocket.dispatch(token, &req, data).await; - let upgrade = response.take_upgrade(req.headers().get("upgrade")); - if let Ok(Some((proto, handler))) = upgrade { - rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await; - } else { - if upgrade.is_err() { - warn_!("Request wants upgrade but no I/O handler matched."); - info_!("Request is not being upgraded."); - } - - rocket.send_response(response, tx).await; - } - }, - Err(e) => { - warn!("Bad incoming HTTP request."); - e.errors.iter().for_each(|e| warn_!("Error: {}.", e)); - warn_!("Dispatching salvaged request to catcher: {}.", e.request); - - let response = rocket.handle_error(Status::BadRequest, &e.request).await; - rocket.send_response(response, tx).await; - } - } - }); - - // Receive the response written to `tx` by the task above. - rx.await.map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e)) -} +use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler}; +use crate::listener::{Listener, CancellableExt, BouncedExt}; +use crate::error::{Error, ErrorKind}; +use crate::data::IoStream; +use crate::util::ReaderStream; +use crate::http::Status; impl Rocket { - /// Wrapper around `_send_response` to log a success or error. - #[inline] - async fn send_response( - &self, - response: Response<'_>, - tx: oneshot::Sender>, - ) { - let remote_hungup = |e: &io::Error| match e.kind() { - | io::ErrorKind::BrokenPipe - | io::ErrorKind::ConnectionReset - | io::ErrorKind::ConnectionAborted => true, - _ => false, - }; - - match self._send_response(response, tx).await { - Ok(()) => info_!("{}", "Response succeeded.".green()), - Err(e) if remote_hungup(&e) => warn_!("Remote left: {}.", e), - Err(e) => warn_!("Failed to write response: {}.", e), - } - } - - /// Attempts to create a hyper response from `response` and send it to `tx`. - #[inline] - async fn _send_response( - &self, - mut response: Response<'_>, - tx: oneshot::Sender>, - ) -> io::Result<()> { - let mut hyp_res = hyper::Response::builder(); - - hyp_res = hyp_res.status(response.status().code); - for header in response.headers().iter() { - let name = header.name.as_str(); - let value = header.value.as_bytes(); - hyp_res = hyp_res.header(name, value); - } - - let body = response.body_mut(); - if let Some(n) = body.size().await { - hyp_res = hyp_res.header(hyper::header::CONTENT_LENGTH, n); - } - - let (mut sender, hyp_body) = hyper::Body::channel(); - let hyp_response = hyp_res.body(hyp_body) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - - #[cfg(not(broken_fmt))] - debug!("sending response: {:#?}", hyp_response); - - tx.send(hyp_response).map_err(|_| { - let msg = "client disconnect before response started"; - io::Error::new(io::ErrorKind::BrokenPipe, msg) - })?; - - let max_chunk_size = body.max_chunk_size(); - let mut stream = body.into_bytes_stream(max_chunk_size); - while let Some(next) = stream.next().await { - sender.send_data(next?).await - .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?; - } - - Ok(()) - } - - async fn handle_upgrade<'r>( - &self, - mut response: Response<'r>, - proto: uncased::Uncased<'r>, - io_handler: Pin>, - pending_upgrade: hyper::upgrade::OnUpgrade, - tx: oneshot::Sender>, - ) { - info_!("Upgrading connection to {}.", Paint::white(&proto).bold()); - response.set_status(Status::SwitchingProtocols); - response.set_raw_header("Connection", "Upgrade"); - response.set_raw_header("Upgrade", proto.clone().into_cow()); - self.send_response(response, tx).await; + async fn service( + self: Arc, + mut req: hyper::Request, + connection: ConnectionMeta, + ) -> Result>, http::Error> { + let upgrade = hyper::upgrade::on(&mut req); + let (parts, incoming) = req.into_parts(); + let request = ErasedRequest::new(self, parts, |rocket, parts| { + Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e) + }); - match pending_upgrade.await { - Ok(io_stream) => { - info_!("Upgrade successful."); - if let Err(e) = io_handler.io(io_stream.into()).await { - if e.kind() == io::ErrorKind::BrokenPipe { - warn!("Upgraded {} I/O handler was closed.", proto); - } else { - error!("Upgraded {} I/O handler failed: {}", proto, e); - } + let mut response = request.into_response( + incoming, + |incoming| Data::from(incoming), + |rocket, request, data| Box::pin(rocket.preprocess(request, data)), + |token, rocket, request, data| Box::pin(async move { + if !request.errors.is_empty() { + return rocket.dispatch_error(Status::BadRequest, request).await; } - }, - Err(e) => { - warn!("Response indicated upgrade, but upgrade failed."); - warn_!("Upgrade error: {}", e); - } - } - } - /// Preprocess the request for Rocket things. Currently, this means: - /// - /// * Rewriting the method in the request if _method form field exists. - /// * Run the request fairings. - /// - /// Keep this in-sync with derive_form when preprocessing form fields. - pub(crate) async fn preprocess_request( - &self, - req: &mut Request<'_>, - data: &mut Data<'_> - ) -> RequestToken { - // Check if this is a form and if the form contains the special _method - // field which we use to reinterpret the request's method. - let (min_len, max_len) = ("_method=get".len(), "_method=delete".len()); - let peek_buffer = data.peek(max_len).await; - let is_form = req.content_type().map_or(false, |ct| ct.is_form()); + let mut response = rocket.dispatch(token, request, data).await; + response.body_mut().size().await; + response + }) + ).await; - if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len { - let method = std::str::from_utf8(peek_buffer).ok() - .and_then(|raw_form| Form::values(raw_form).next()) - .filter(|field| field.name == "_method") - .and_then(|field| field.value.parse().ok()); - - if let Some(method) = method { - req._set_method(method); - } + let io_handler = response.to_io_handler(Rocket::extract_io_handler); + if let Some(handler) = io_handler { + let upgrade = upgrade.map_ok(IoStream::from).map_err(io::Error::other); + tokio::task::spawn(io_handler_task(upgrade, handler)); } - // Run request fairings. - self.fairings.handle_request(req, data).await; - - RequestToken - } - - #[inline] - pub(crate) async fn dispatch<'s, 'r: 's>( - &'s self, - _token: RequestToken, - request: &'r Request<'s>, - data: Data<'r> - ) -> Response<'r> { - info!("{}:", request); - - // Remember if the request is `HEAD` for later body stripping. - let was_head_request = request.method() == Method::Head; - - // Route the request and run the user's handlers. - let mut response = self.route_and_process(request, data).await; - - // Add a default 'Server' header if it isn't already there. - // TODO: If removing Hyper, write out `Date` header too. - if let Some(ident) = request.rocket().config.ident.as_str() { - if !response.headers().contains("Server") { - response.set_header(Header::new("Server", ident)); - } + let mut builder = hyper::Response::builder(); + builder = builder.status(response.inner().status().code); + for header in response.inner().headers().iter() { + builder = builder.header(header.name().as_str(), header.value()); } - // Run the response fairings. - self.fairings.handle_response(request, &mut response).await; - - // Strip the body if this is a `HEAD` request. - if was_head_request { - response.strip_body(); + if let Some(size) = response.inner().body().preset_size() { + builder = builder.header("Content-Length", size); } - response - } - - async fn route_and_process<'s, 'r: 's>( - &'s self, - request: &'r Request<'s>, - data: Data<'r> - ) -> Response<'r> { - let mut response = match self.route(request, data).await { - Outcome::Success(response) => response, - Outcome::Forward((data, _)) if request.method() == Method::Head => { - info_!("Autohandling {} request.", "HEAD".primary().bold()); - - // Dispatch the request again with Method `GET`. - request._set_method(Method::Get); - match self.route(request, data).await { - Outcome::Success(response) => response, - Outcome::Error(status) => self.handle_error(status, request).await, - Outcome::Forward((_, status)) => self.handle_error(status, request).await, - } - } - Outcome::Forward((_, status)) => self.handle_error(status, request).await, - Outcome::Error(status) => self.handle_error(status, request).await, - }; - - // Set the cookies. Note that error responses will only include cookies - // set by the error handler. See `handle_error` for more. - let delta_jar = request.cookies().take_delta_jar(); - for cookie in delta_jar.delta() { - response.adjoin_header(cookie); - } - - response - } - - /// Tries to find a `Responder` for a given `request`. It does this by - /// routing the request and calling the handler for each matching route - /// until one of the handlers returns success or error, or there are no - /// additional routes to try (forward). The corresponding outcome for each - /// condition is returned. - #[inline] - async fn route<'s, 'r: 's>( - &'s self, - request: &'r Request<'s>, - mut data: Data<'r>, - ) -> route::Outcome<'r> { - // Go through all matching routes until we fail or succeed or run out of - // routes to try, in which case we forward with the last status. - let mut status = Status::NotFound; - for route in self.router.route(request) { - // Retrieve and set the requests parameters. - info_!("Matched: {}", route); - request.set_route(route); - - let name = route.name.as_deref(); - let outcome = handle(name, || route.handler.handle(request, data)).await - .unwrap_or(Outcome::Error(Status::InternalServerError)); - - // Check if the request processing completed (Some) or if the - // request needs to be forwarded. If it does, continue the loop - // (None) to try again. - info_!("{}", outcome.log_display()); - match outcome { - o@Outcome::Success(_) | o@Outcome::Error(_) => return o, - Outcome::Forward(forwarded) => (data, status) = forwarded, - } - } - - error_!("No matching routes for {}.", request); - Outcome::Forward((data, status)) - } - - /// Invokes the handler with `req` for catcher with status `status`. - /// - /// In order of preference, invoked handler is: - /// * the user's registered handler for `status` - /// * the user's registered `default` handler - /// * Rocket's default handler for `status` - /// - /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))` - /// if the handler ran to completion but failed. Returns `Ok(None)` if the - /// handler panicked while executing. - async fn invoke_catcher<'s, 'r: 's>( - &'s self, - status: Status, - req: &'r Request<'s> - ) -> Result, Option> { - // For now, we reset the delta state to prevent any modifications - // from earlier, unsuccessful paths from being reflected in error - // response. We may wish to relax this in the future. - req.cookies().reset_delta(); - - if let Some(catcher) = self.router.catch(status, req) { - warn_!("Responding with registered {} catcher.", catcher); - let name = catcher.name.as_deref(); - handle(name, || catcher.handler.handle(status, req)).await - .map(|result| result.map_err(Some)) - .unwrap_or_else(|| Err(None)) - } else { - let code = status.code.blue().bold(); - warn_!("No {} catcher registered. Using Rocket default.", code); - Ok(crate::catcher::default_handler(status, req)) - } + let chunk_size = response.inner().body().max_chunk_size(); + builder.body(ReaderStream::with_capacity(response, chunk_size)) } +} - // Invokes the catcher for `status`. Returns the response on success. - // - // On catcher error, the 500 error catcher is attempted. If _that_ errors, - // the (infallible) default 500 error cather is used. - pub(crate) async fn handle_error<'s, 'r: 's>( - &'s self, - mut status: Status, - req: &'r Request<'s> - ) -> Response<'r> { - // Dispatch to the `status` catcher. - if let Ok(r) = self.invoke_catcher(status, req).await { - return r; - } +async fn io_handler_task(stream: S, mut handler: ErasedIoHandler) + where S: Future> +{ + let stream = match stream.await { + Ok(stream) => stream, + Err(e) => return warn_!("Upgrade failed: {e}"), + }; - // If it fails and it's not a 500, try the 500 catcher. - if status != Status::InternalServerError { - error_!("Catcher failed. Attempting 500 error catcher."); - status = Status::InternalServerError; - if let Ok(r) = self.invoke_catcher(status, req).await { - return r; - } + info_!("Upgrade succeeded."); + if let Err(e) = handler.take().io(stream).await { + match e.kind() { + io::ErrorKind::BrokenPipe => warn!("Upgrade I/O handler was closed."), + e => error!("Upgrade I/O handler failed: {e}"), } - - // If it failed again or if it was already a 500, use Rocket's default. - error_!("{} catcher failed. Using Rocket default 500.", status.code); - crate::catcher::default_handler(Status::InternalServerError, req) } +} - pub(crate) async fn default_tcp_http_server(mut self, ready: C) -> Result - where C: for<'a> Fn(&'a Self) -> BoxFuture<'a, ()> +impl Rocket { + pub(crate) async fn serve(self, listener: L) -> Result + where L: Listener + 'static { - use std::net::ToSocketAddrs; - - // Determine the address we're going to serve on. - let addr = format!("{}:{}", self.config.address, self.config.port); - let mut addr = addr.to_socket_addrs() - .map(|mut addrs| addrs.next().expect(">= 1 socket addr")) - .map_err(|e| Error::new(ErrorKind::Io(e)))?; - - #[cfg(feature = "tls")] - if self.config.tls_enabled() { - if let Some(ref config) = self.config.tls { - use crate::http::tls::TlsListener; - - let conf = config.to_native_config().map_err(ErrorKind::Io)?; - let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::TlsBind)?; - addr = l.local_addr().unwrap_or(addr); - self.config.address = addr.ip(); - self.config.port = addr.port(); - ready(&mut self).await; - return self.http_server(l).await; + let mut builder = Builder::new(TokioExecutor::new()); + let keep_alive = Duration::from_secs(self.config.keep_alive.into()); + builder.http1() + .half_close(true) + .timer(TokioTimer::new()) + .keep_alive(keep_alive > Duration::ZERO) + .preserve_header_case(true) + .header_read_timeout(Duration::from_secs(15)); + + #[cfg(feature = "http2")] { + builder.http2().timer(TokioTimer::new()); + if keep_alive > Duration::ZERO { + builder.http2() + .timer(TokioTimer::new()) + .keep_alive_interval(keep_alive / 4) + .keep_alive_timeout(keep_alive); } } - let l = TcpListener::bind(addr).await.map_err(ErrorKind::Bind)?; - addr = l.local_addr().unwrap_or(addr); - self.config.address = addr.ip(); - self.config.port = addr.port(); - ready(&mut self).await; - self.http_server(l).await - } - - // TODO.async: Solidify the Listener APIs and make this function public - pub(crate) async fn http_server(self, listener: L) -> Result - where L: Listener + Send, ::Connection: Send + Unpin + 'static - { - // Emit a warning if we're not running inside of Rocket's async runtime. - if self.config.profile == Config::DEBUG_PROFILE { - tokio::task::spawn_blocking(|| { - let this = std::thread::current(); - if !this.name().map_or(false, |s| s.starts_with("rocket-worker")) { - warn!("Rocket is executing inside of a custom runtime."); - info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`."); - info_!("Forced shutdown is disabled. Runtime settings may be suboptimal."); - } - }); - } - - // Set up cancellable I/O from the given listener. Shutdown occurs when - // `Shutdown` (`TripWire`) resolves. This can occur directly through a - // notification or indirectly through an external signal which, when - // received, results in triggering the notify. - let shutdown = self.shutdown(); - let sig_stream = self.config.shutdown.signal_stream(); - let grace = self.config.shutdown.grace as u64; - let mercy = self.config.shutdown.mercy as u64; - - // Start a task that listens for external signals and notifies shutdown. - if let Some(mut stream) = sig_stream { - let shutdown = shutdown.clone(); - tokio::spawn(async move { - while let Some(sig) = stream.next().await { - if shutdown.0.tripped() { - warn!("Received {}. Shutdown already in progress.", sig); - } else { - warn!("Received {}. Requesting shutdown.", sig); + let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown); + let rocket = Arc::new(self.into_orbit(listener.socket_addr()?)); + let _ = tokio::spawn(Rocket::liftoff(rocket.clone())).await; + + let (server, listener) = (Arc::new(builder), Arc::new(listener)); + while let Some(accept) = listener.accept_next().await { + let (listener, rocket, server) = (listener.clone(), rocket.clone(), server.clone()); + tokio::spawn({ + let result = async move { + let conn = TokioIo::new(listener.connect(accept).await?); + let meta = ConnectionMeta::from(conn.inner()); + let service = service_fn(|req| rocket.clone().service(req, meta.clone())); + let serve = pin!(server.serve_connection_with_upgrades(conn, service)); + match select(serve, rocket.shutdown()).await { + Left((result, _)) => result, + Right((_, mut conn)) => { + conn.as_mut().graceful_shutdown(); + conn.await + } } + }; - shutdown.0.trip(); - } + result.inspect_err(crate::error::log_server_error) }); } - // Save the keep-alive value for later use; we're about to move `self`. - let keep_alive = self.config.keep_alive; - - // Create the Hyper `Service`. - let rocket = Arc::new(self); - let service_fn = |conn: &CancellableIo<_, L::Connection>| { - let rocket = rocket.clone(); - let connection = ConnectionMeta { - remote: conn.peer_address(), - client_certificates: conn.peer_certificates(), - }; - - async move { - Ok::<_, std::convert::Infallible>(hyper::service::service_fn(move |req| { - hyper_service_fn(rocket.clone(), connection.clone(), req) - })) - } - }; - - // NOTE: `hyper` uses `tokio::spawn()` as the default executor. - let listener = CancellableListener::new(shutdown.clone(), listener, grace, mercy); - let builder = hyper::server::Server::builder(Incoming::new(listener).nodelay(true)); - - #[cfg(feature = "http2")] - let builder = builder.http2_keep_alive_interval(match keep_alive { - 0 => None, - n => Some(Duration::from_secs(n as u64)) - }); - - let server = builder - .http1_keepalive(keep_alive != 0) - .http1_preserve_header_case(true) - .serve(hyper::service::make_service_fn(service_fn)) - .with_graceful_shutdown(shutdown.clone()); - - // This deserves some explanation. - // - // This is largely to deal with Hyper's dreadful and largely nonexistent - // handling of shutdown, in general, nevermind graceful. - // - // When Hyper receives a "graceful shutdown" request, it stops accepting - // new requests. That's it. It continues to process existing requests - // and outgoing responses forever and never cancels them. As a result, - // Rocket must take it upon itself to cancel any existing I/O. - // - // To do so, Rocket wraps all connections in a `CancellableIo` struct, - // an internal structure that gracefully closes I/O when it receives a - // signal. That signal is the `shutdown` future. When the future - // resolves, `CancellableIo` begins to terminate in grace, mercy, and - // finally force close phases. Since all connections are wrapped in + // Rocket wraps all connections in a `CancellableIo` struct, an internal + // structure that gracefully closes I/O when it receives a signal. That + // signal is the `shutdown` future. When the future resolves, + // `CancellableIo` begins to terminate in grace, mercy, and finally + // force close phases. Since all connections are wrapped in // `CancellableIo`, this eventually ends all I/O. // // At that point, unless a user spawned an infinite, stand-alone task @@ -543,69 +145,35 @@ impl Rocket { // we can return the owned instance of `Rocket`. // // Unfortunately, the Hyper `server` future resolves as soon as it has - // finishes processing requests without respect for ongoing responses. + // finished processing requests without respect for ongoing responses. // That is, `server` resolves even when there are running tasks that are // generating a response. So, `server` resolving implies little to // nothing about the state of connections. As a result, we depend on the // timing of grace + mercy + some buffer to determine when all // connections should be closed, thus all tasks should be complete, thus - // all references to `Arc` should be dropped and we can get a - // unique reference. - tokio::pin!(server); - tokio::select! { - biased; - - _ = shutdown => { - // Run shutdown fairings. We compute `sleep()` for grace periods - // beforehand to ensure we don't add shutdown fairing completion - // time, which is arbitrary, to these periods. - info!("Shutdown requested. Waiting for pending I/O..."); - let grace_timer = sleep(Duration::from_secs(grace)); - let mercy_timer = sleep(Duration::from_secs(grace + mercy)); - let shutdown_timer = sleep(Duration::from_secs(grace + mercy + 1)); - rocket.fairings.handle_shutdown(&*rocket).await; - - tokio::pin!(grace_timer, mercy_timer, shutdown_timer); - tokio::select! { - biased; + // all references to `Arc` should be dropped and we can get back + // a unique reference. + info!("Shutting down. Waiting for shutdown fairings and pending I/O..."); + tokio::spawn({ + let rocket = rocket.clone(); + async move { rocket.fairings.handle_shutdown(&*rocket).await } + }); - result = &mut server => { - if let Err(e) = result { - warn!("Server failed while shutting down: {}", e); - return Err(Error::shutdown(rocket.clone(), e)); - } + let config = &rocket.config.shutdown; + let wait = Duration::from_micros(250); + for period in [wait, config.grace(), wait, config.mercy(), wait * 4] { + if Arc::strong_count(&rocket) == 1 { break } + sleep(period).await; + } - if Arc::strong_count(&rocket) != 1 { grace_timer.await; } - if Arc::strong_count(&rocket) != 1 { mercy_timer.await; } - if Arc::strong_count(&rocket) != 1 { shutdown_timer.await; } - match Arc::try_unwrap(rocket) { - Ok(rocket) => { - info!("Graceful shutdown completed successfully."); - Ok(rocket) - } - Err(rocket) => { - warn!("Shutdown failed: outstanding background I/O."); - Err(Error::shutdown(rocket, None)) - } - } - } - _ = &mut shutdown_timer => { - warn!("Shutdown failed: server executing after timeouts."); - return Err(Error::shutdown(rocket.clone(), None)); - }, - } + match Arc::try_unwrap(rocket) { + Ok(rocket) => { + info!("Graceful shutdown completed successfully."); + Ok(rocket.into_ignite()) } - result = &mut server => { - match result { - Ok(()) => { - info!("Server shutdown nominally."); - Ok(Arc::try_unwrap(rocket).map_err(|r| Error::shutdown(r, None))?) - } - Err(e) => { - info!("Server failed prior to shutdown: {}:", e); - Err(Error::shutdown(rocket.clone(), e)) - } - } + Err(rocket) => { + warn!("Shutdown failed: outstanding background I/O."); + Err(Error::new(ErrorKind::Shutdown(rocket))) } } } diff --git a/core/lib/src/shield/shield.rs b/core/lib/src/shield/shield.rs index ea44814484..f3a3aeb241 100644 --- a/core/lib/src/shield/shield.rs +++ b/core/lib/src/shield/shield.rs @@ -198,7 +198,7 @@ impl Fairing for Shield { } async fn on_liftoff(&self, rocket: &Rocket) { - let force_hsts = rocket.config().tls_enabled() + let force_hsts = rocket.endpoint().is_tls() && rocket.figment().profile() != Config::DEBUG_PROFILE && !self.is_enabled::(); diff --git a/core/lib/src/shutdown.rs b/core/lib/src/shutdown.rs index 490114f51c..43a667af0a 100644 --- a/core/lib/src/shutdown.rs +++ b/core/lib/src/shutdown.rs @@ -5,7 +5,7 @@ use std::pin::Pin; use futures::FutureExt; use crate::request::{FromRequest, Outcome, Request}; -use crate::trip_wire::TripWire; +use crate::util::TripWire; /// A request guard and future for graceful shutdown. /// diff --git a/core/lib/src/config/tls.rs b/core/lib/src/tls/config.rs similarity index 56% rename from core/lib/src/config/tls.rs rename to core/lib/src/tls/config.rs index 12face0015..3131e16d5c 100644 --- a/core/lib/src/config/tls.rs +++ b/core/lib/src/tls/config.rs @@ -1,3 +1,5 @@ +use std::io; + use figment::value::magic::{Either, RelativePathBuf}; use serde::{Deserialize, Serialize}; use indexmap::IndexSet; @@ -29,7 +31,7 @@ use indexmap::IndexSet; /// /// Additionally, the `mutual` parameter controls if and how the server /// authenticates clients via mutual TLS. It works in concert with the -/// [`mtls`](crate::mtls) module. See [`MutualTls`] for configuration details. +/// [`mtls`](crate::mtls) module. See [`MtlsConfig`] for configuration details. /// /// In `Rocket.toml`, configuration might look like: /// @@ -43,41 +45,36 @@ use indexmap::IndexSet; /// /// ```rust /// # #[macro_use] extern crate rocket; -/// use rocket::config::{Config, TlsConfig, CipherSuite}; +/// use rocket::tls::{TlsConfig, CipherSuite}; +/// use rocket::figment::providers::Serialized; /// /// #[launch] /// fn rocket() -> _ { -/// let tls_config = TlsConfig::from_paths("/ssl/certs.pem", "/ssl/key.pem") +/// let tls = TlsConfig::from_paths("/ssl/certs.pem", "/ssl/key.pem") /// .with_ciphers(CipherSuite::TLS_V13_SET) /// .with_preferred_server_cipher_order(true); /// -/// let config = Config { -/// tls: Some(tls_config), -/// ..Default::default() -/// }; -/// -/// rocket::custom(config) +/// rocket::custom(rocket::Config::figment().merge(("tls", tls))) /// } /// ``` /// /// Or by creating a custom figment: /// /// ```rust -/// use rocket::config::Config; +/// use rocket::figment::Figment; +/// use rocket::tls::TlsConfig; /// -/// let figment = Config::figment() -/// .merge(("tls.certs", "path/to/certs.pem")) -/// .merge(("tls.key", vec![0; 32])); +/// let figment = Figment::new() +/// .merge(("certs", "path/to/certs.pem")) +/// .merge(("key", vec![0; 32])); /// # -/// # let config = rocket::Config::from(figment); -/// # let tls_config = config.tls.as_ref().unwrap(); +/// # let tls_config: TlsConfig = figment.extract().unwrap(); /// # assert!(tls_config.certs().is_left()); /// # assert!(tls_config.key().is_right()); /// # assert_eq!(tls_config.ciphers().count(), 9); /// # assert!(!tls_config.prefer_server_cipher_order()); /// ``` #[derive(PartialEq, Debug, Clone, Deserialize, Serialize)] -#[cfg_attr(nightly, doc(cfg(feature = "tls")))] pub struct TlsConfig { /// Path to a PEM file with, or raw bytes for, a DER-encoded X.509 TLS /// certificate chain. @@ -95,92 +92,12 @@ pub struct TlsConfig { #[serde(default)] #[cfg(feature = "mtls")] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] - pub(crate) mutual: Option, -} - -/// Mutual TLS configuration. -/// -/// Configuration works in concert with the [`mtls`](crate::mtls) module, which -/// provides a request guard to validate, verify, and retrieve client -/// certificates in routes. -/// -/// By default, mutual TLS is disabled and client certificates are not required, -/// validated or verified. To enable mutual TLS, the `mtls` feature must be -/// enabled and support configured via two `tls.mutual` parameters: -/// -/// * `ca_certs` -/// -/// A required path to a PEM file or raw bytes to a DER-encoded X.509 TLS -/// certificate chain for the certificate authority to verify client -/// certificates against. When a path is configured in a file, such as -/// `Rocket.toml`, relative paths are interpreted as relative to the source -/// file's directory. -/// -/// * `mandatory` -/// -/// An optional boolean that control whether client authentication is -/// required. -/// -/// When `true`, client authentication is required. TLS connections where -/// the client does not present a certificate are immediately terminated. -/// When `false`, the client is not required to present a certificate. In -/// either case, if a certificate _is_ presented, it must be valid or the -/// connection is terminated. -/// -/// In a `Rocket.toml`, configuration might look like: -/// -/// ```toml -/// [default.tls.mutual] -/// ca_certs = "/ssl/ca_cert.pem" -/// mandatory = true # when absent, defaults to false -/// ``` -/// -/// Programmatically, configuration might look like: -/// -/// ```rust -/// # #[macro_use] extern crate rocket; -/// use rocket::config::{Config, TlsConfig, MutualTls}; -/// -/// #[launch] -/// fn rocket() -> _ { -/// let tls_config = TlsConfig::from_paths("/ssl/certs.pem", "/ssl/key.pem") -/// .with_mutual(MutualTls::from_path("/ssl/ca_cert.pem")); -/// -/// let config = Config { -/// tls: Some(tls_config), -/// ..Default::default() -/// }; -/// -/// rocket::custom(config) -/// } -/// ``` -/// -/// Once mTLS is configured, the [`mtls::Certificate`](crate::mtls::Certificate) -/// request guard can be used to retrieve client certificates in routes. -#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)] -#[cfg(feature = "mtls")] -#[cfg_attr(nightly, doc(cfg(feature = "mtls")))] -pub struct MutualTls { - /// Path to a PEM file with, or raw bytes for, DER-encoded Certificate - /// Authority certificates which will be used to verify client-presented - /// certificates. - // TODO: We should support more than one root. - pub(crate) ca_certs: Either>, - /// Whether the client is required to present a certificate. - /// - /// When `true`, the client is required to present a valid certificate to - /// proceed with TLS. When `false`, the client is not required to present a - /// certificate. In either case, if a certificate _is_ presented, it must be - /// valid or the connection is terminated. - #[serde(default)] - #[serde(deserialize_with = "figment::util::bool_from_str_or_int")] - pub mandatory: bool, + pub(crate) mutual: Option, } /// A supported TLS cipher suite. #[allow(non_camel_case_types)] #[derive(PartialEq, Eq, Debug, Copy, Clone, Hash, Deserialize, Serialize)] -#[cfg_attr(nightly, doc(cfg(feature = "tls")))] #[non_exhaustive] pub enum CipherSuite { /// The TLS 1.3 `TLS_CHACHA20_POLY1305_SHA256` cipher suite. @@ -204,50 +121,7 @@ pub enum CipherSuite { TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, } -impl CipherSuite { - /// The default set and order of cipher suites. These are all of the - /// variants in [`CipherSuite`] in their declaration order. - pub const DEFAULT_SET: [CipherSuite; 9] = [ - // TLS v1.3 suites... - CipherSuite::TLS_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_AES_256_GCM_SHA384, - CipherSuite::TLS_AES_128_GCM_SHA256, - - // TLS v1.2 suites... - CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - ]; - - /// The default set and order of cipher suites. These are the TLS 1.3 - /// variants in [`CipherSuite`] in their declaration order. - pub const TLS_V13_SET: [CipherSuite; 3] = [ - CipherSuite::TLS_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_AES_256_GCM_SHA384, - CipherSuite::TLS_AES_128_GCM_SHA256, - ]; - - /// The default set and order of cipher suites. These are the TLS 1.2 - /// variants in [`CipherSuite`] in their declaration order. - pub const TLS_V12_SET: [CipherSuite; 6] = [ - CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - ]; - - /// Used as the `serde` default for `ciphers`. - fn default_set() -> IndexSet { - Self::DEFAULT_SET.iter().copied().collect() - } -} - -impl TlsConfig { +impl Default for TlsConfig { fn default() -> Self { TlsConfig { certs: Either::Right(vec![]), @@ -258,7 +132,9 @@ impl TlsConfig { mutual: None, } } +} +impl TlsConfig { /// Constructs a `TlsConfig` from paths to a `certs` certificate chain /// a `key` private-key. This method does no validation; it simply creates a /// structure suitable for passing into a [`Config`](crate::Config). @@ -266,12 +142,12 @@ impl TlsConfig { /// # Example /// /// ```rust - /// use rocket::config::TlsConfig; + /// use rocket::tls::TlsConfig; /// /// let tls_config = TlsConfig::from_paths("/ssl/certs.pem", "/ssl/key.pem"); /// ``` pub fn from_paths(certs: C, key: K) -> Self - where C: AsRef, K: AsRef + where C: AsRef, K: AsRef, { TlsConfig { certs: Either::Left(certs.as_ref().to_path_buf().into()), @@ -288,7 +164,7 @@ impl TlsConfig { /// # Example /// /// ```rust - /// use rocket::config::TlsConfig; + /// use rocket::tls::TlsConfig; /// /// # let certs_buf = &[]; /// # let key_buf = &[]; @@ -315,7 +191,7 @@ impl TlsConfig { /// Disable TLS v1.2 by selecting only TLS v1.3 cipher suites: /// /// ```rust - /// use rocket::config::{TlsConfig, CipherSuite}; + /// use rocket::tls::{TlsConfig, CipherSuite}; /// /// # let certs_buf = &[]; /// # let key_buf = &[]; @@ -326,7 +202,7 @@ impl TlsConfig { /// Enable only ChaCha20-Poly1305 based TLS v1.2 and TLS v1.3 cipher suites: /// /// ```rust - /// use rocket::config::{TlsConfig, CipherSuite}; + /// use rocket::tls::{TlsConfig, CipherSuite}; /// /// # let certs_buf = &[]; /// # let key_buf = &[]; @@ -341,7 +217,7 @@ impl TlsConfig { /// Later duplicates are ignored. /// /// ```rust - /// use rocket::config::{TlsConfig, CipherSuite}; + /// use rocket::tls::{TlsConfig, CipherSuite}; /// /// # let certs_buf = &[]; /// # let key_buf = &[]; @@ -361,8 +237,8 @@ impl TlsConfig { /// CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, /// ]); /// ``` - pub fn with_ciphers(mut self, ciphers: I) -> Self - where I: IntoIterator + pub fn with_ciphers(mut self, ciphers: C) -> Self + where C: IntoIterator { self.ciphers = ciphers.into_iter().collect(); self @@ -385,7 +261,7 @@ impl TlsConfig { /// # Example /// /// ```rust - /// use rocket::config::{TlsConfig, CipherSuite}; + /// use rocket::tls::{TlsConfig, CipherSuite}; /// /// # let certs_buf = &[]; /// # let key_buf = &[]; @@ -398,22 +274,24 @@ impl TlsConfig { self } - /// Configures mutual TLS. See [`MutualTls`] for details. + /// Set mutual TLS configuration. See + /// [`MtlsConfig`](crate::mtls::MtlsConfig) for details. /// /// # Example /// /// ```rust - /// use rocket::config::{TlsConfig, MutualTls}; + /// use rocket::tls::TlsConfig; + /// use rocket::mtls::MtlsConfig; /// /// # let certs = &[]; /// # let key = &[]; - /// let mtls_config = MutualTls::from_path("path/to/cert.pem").mandatory(true); + /// let mtls_config = MtlsConfig::from_path("path/to/cert.pem").mandatory(true); /// let tls_config = TlsConfig::from_bytes(certs, key).with_mutual(mtls_config); /// assert!(tls_config.mutual().is_some()); /// ``` #[cfg(feature = "mtls")] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] - pub fn with_mutual(mut self, config: MutualTls) -> Self { + pub fn with_mutual(mut self, config: crate::mtls::MtlsConfig) -> Self { self.mutual = Some(config); self } @@ -423,16 +301,17 @@ impl TlsConfig { /// # Example /// /// ```rust - /// use rocket::Config; + /// # use std::path::Path; + /// use rocket::tls::TlsConfig; + /// use rocket::figment::Figment; /// - /// let figment = Config::figment() - /// .merge(("tls.certs", vec![0; 32])) - /// .merge(("tls.key", "/etc/ssl/key.pem")); + /// let figment = Figment::new() + /// .merge(("certs", "/path/to/certs.pem")) + /// .merge(("key", vec![0; 32])); /// - /// let config = rocket::Config::from(figment); - /// let tls_config = config.tls.as_ref().unwrap(); - /// let cert_bytes = tls_config.certs().right().unwrap(); - /// assert!(cert_bytes.iter().all(|&b| b == 0)); + /// let tls_config: TlsConfig = figment.extract().unwrap(); + /// let cert_path = tls_config.certs().left().unwrap(); + /// assert_eq!(cert_path, Path::new("/path/to/certs.pem")); /// ``` pub fn certs(&self) -> either::Either { match &self.certs { @@ -441,20 +320,24 @@ impl TlsConfig { } } + pub fn certs_reader(&self) -> io::Result> { + to_reader(&self.certs) + } + /// Returns the value of the `key` parameter. /// /// # Example /// /// ```rust - /// use std::path::Path; - /// use rocket::Config; + /// # use std::path::Path; + /// use rocket::tls::TlsConfig; + /// use rocket::figment::Figment; /// - /// let figment = Config::figment() - /// .merge(("tls.certs", vec![0; 32])) - /// .merge(("tls.key", "/etc/ssl/key.pem")); + /// let figment = Figment::new() + /// .merge(("certs", vec![0; 32])) + /// .merge(("key", "/etc/ssl/key.pem")); /// - /// let config = rocket::Config::from(figment); - /// let tls_config = config.tls.as_ref().unwrap(); + /// let tls_config: TlsConfig = figment.extract().unwrap(); /// let key_path = tls_config.key().left().unwrap(); /// assert_eq!(key_path, Path::new("/etc/ssl/key.pem")); /// ``` @@ -465,13 +348,17 @@ impl TlsConfig { } } + pub fn key_reader(&self) -> io::Result> { + to_reader(&self.key) + } + /// Returns an iterator over the enabled cipher suites in their order of /// preference from most to least preferred. /// /// # Example /// /// ```rust - /// use rocket::config::{TlsConfig, CipherSuite}; + /// use rocket::tls::{TlsConfig, CipherSuite}; /// /// # let certs_buf = &[]; /// # let key_buf = &[]; @@ -496,7 +383,7 @@ impl TlsConfig { /// # Example /// /// ```rust - /// use rocket::config::TlsConfig; + /// use rocket::tls::TlsConfig; /// /// # let certs_buf = &[]; /// # let key_buf = &[]; @@ -520,11 +407,13 @@ impl TlsConfig { /// /// ```rust /// use std::path::Path; - /// use rocket::config::{TlsConfig, MutualTls}; + /// + /// use rocket::tls::TlsConfig; + /// use rocket::mtls::MtlsConfig; /// /// # let certs = &[]; /// # let key = &[]; - /// let mtls_config = MutualTls::from_path("path/to/cert.pem").mandatory(true); + /// let mtls_config = MtlsConfig::from_path("path/to/cert.pem").mandatory(true); /// let tls_config = TlsConfig::from_bytes(certs, key).with_mutual(mtls_config); /// /// let mtls = tls_config.mutual().unwrap(); @@ -533,171 +422,232 @@ impl TlsConfig { /// ``` #[cfg(feature = "mtls")] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] - pub fn mutual(&self) -> Option<&MutualTls> { + pub fn mutual(&self) -> Option<&crate::mtls::MtlsConfig> { self.mutual.as_ref() } -} -#[cfg(feature = "mtls")] -impl MutualTls { - /// Constructs a `MutualTls` from a path to a PEM file with a certificate - /// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This - /// method does no validation; it simply creates a structure suitable for - /// passing into a [`TlsConfig`]. - /// - /// These certificates will be used to verify client-presented certificates - /// in TLS connections. - /// - /// # Example - /// - /// ```rust - /// use rocket::config::MutualTls; - /// - /// let tls_config = MutualTls::from_path("/ssl/ca_certs.pem"); - /// ``` - pub fn from_path>(ca_certs: C) -> Self { - MutualTls { - ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()), - mandatory: Default::default() - } + pub fn validate(&self) -> Result<(), crate::tls::Error> { + self.acceptor().map(|_| ()) } +} - /// Constructs a `MutualTls` from a byte buffer to a certificate authority - /// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no - /// validation; it simply creates a structure suitable for passing into a - /// [`TlsConfig`]. - /// - /// These certificates will be used to verify client-presented certificates - /// in TLS connections. - /// - /// # Example - /// - /// ```rust - /// use rocket::config::MutualTls; - /// - /// # let ca_certs_buf = &[]; - /// let mtls_config = MutualTls::from_bytes(ca_certs_buf); - /// ``` - pub fn from_bytes(ca_certs: &[u8]) -> Self { - MutualTls { - ca_certs: Either::Right(ca_certs.to_vec()), - mandatory: Default::default() - } - } +impl CipherSuite { + /// The default set and order of cipher suites. These are all of the + /// variants in [`CipherSuite`] in their declaration order. + pub const DEFAULT_SET: [CipherSuite; 9] = [ + // TLS v1.3 suites... + CipherSuite::TLS_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_AES_256_GCM_SHA384, + CipherSuite::TLS_AES_128_GCM_SHA256, - /// Sets whether client authentication is required. Disabled by default. - /// - /// When `true`, client authentication will be required. TLS connections - /// where the client does not present a certificate will be immediately - /// terminated. When `false`, the client is not required to present a - /// certificate. In either case, if a certificate _is_ presented, it must be - /// valid or the connection is terminated. - /// - /// # Example - /// - /// ```rust - /// use rocket::config::MutualTls; - /// - /// # let ca_certs_buf = &[]; - /// let mtls_config = MutualTls::from_bytes(ca_certs_buf).mandatory(true); - /// ``` - pub fn mandatory(mut self, mandatory: bool) -> Self { - self.mandatory = mandatory; - self + // TLS v1.2 suites... + CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + ]; + + /// The default set and order of cipher suites. These are the TLS 1.3 + /// variants in [`CipherSuite`] in their declaration order. + pub const TLS_V13_SET: [CipherSuite; 3] = [ + CipherSuite::TLS_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_AES_256_GCM_SHA384, + CipherSuite::TLS_AES_128_GCM_SHA256, + ]; + + /// The default set and order of cipher suites. These are the TLS 1.2 + /// variants in [`CipherSuite`] in their declaration order. + pub const TLS_V12_SET: [CipherSuite; 6] = [ + CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + ]; + + /// Used as the `serde` default for `ciphers`. + fn default_set() -> IndexSet { + Self::DEFAULT_SET.iter().copied().collect() } +} - /// Returns the value of the `ca_certs` parameter. - /// # Example - /// - /// ```rust - /// use rocket::config::MutualTls; - /// - /// # let ca_certs_buf = &[]; - /// let mtls_config = MutualTls::from_bytes(ca_certs_buf).mandatory(true); - /// assert_eq!(mtls_config.ca_certs().unwrap_right(), ca_certs_buf); - /// ``` - pub fn ca_certs(&self) -> either::Either { - match &self.ca_certs { - Either::Left(path) => either::Either::Left(path.relative()), - Either::Right(bytes) => either::Either::Right(&bytes), +impl From for rustls::SupportedCipherSuite { + fn from(cipher: CipherSuite) -> Self { + use rustls::crypto::ring::cipher_suite; + + match cipher { + CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => + cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_AES_256_GCM_SHA384 => + cipher_suite::TLS13_AES_256_GCM_SHA384, + CipherSuite::TLS_AES_128_GCM_SHA256 => + cipher_suite::TLS13_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => + cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => + cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => + cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 => + cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => + cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 => + cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, } } } -#[cfg(feature = "tls")] -mod with_tls_feature { - use std::fs; - use std::io::{self, Error}; - - use crate::http::tls::Config; - use crate::http::tls::rustls::SupportedCipherSuite as RustlsCipher; - use crate::http::tls::rustls::crypto::ring::cipher_suite; - - use yansi::Paint; - - use super::{Either, RelativePathBuf, TlsConfig, CipherSuite}; - - type Reader = Box; - - fn to_reader(value: &Either>) -> io::Result { - match value { - Either::Left(path) => { - let path = path.relative(); - let file = fs::File::open(&path).map_err(move |e| { +pub(crate) fn to_reader( + value: &Either> +) -> io::Result> { + match value { + Either::Left(path) => { + let path = path.relative(); + let file = std::fs::File::open(&path) + .map_err(move |e| { let source = figment::Source::File(path); - let msg = format!("error reading TLS file `{}`: {}", source.primary(), e); - Error::new(e.kind(), msg) + let msg = format!("error reading TLS file `{source}`: {e}"); + io::Error::new(e.kind(), msg) })?; - Ok(Box::new(io::BufReader::new(file))) - } - Either::Right(vec) => Ok(Box::new(io::Cursor::new(vec.clone()))), + Ok(Box::new(io::BufReader::new(file))) } + Either::Right(vec) => Ok(Box::new(io::Cursor::new(vec.clone()))), } +} - impl TlsConfig { - /// This is only called when TLS is enabled. - pub(crate) fn to_native_config(&self) -> io::Result> { - Ok(Config { - cert_chain: to_reader(&self.certs)?, - private_key: to_reader(&self.key)?, - ciphersuites: self.rustls_ciphers().collect(), - prefer_server_order: self.prefer_server_cipher_order, - #[cfg(not(feature = "mtls"))] - mandatory_mtls: false, - #[cfg(not(feature = "mtls"))] - ca_certs: None, - #[cfg(feature = "mtls")] - mandatory_mtls: self.mutual.as_ref().map_or(false, |m| m.mandatory), - #[cfg(feature = "mtls")] - ca_certs: match self.mutual { - Some(ref mtls) => Some(to_reader(&mtls.ca_certs)?), - None => None - }, - }) - } - - fn rustls_ciphers(&self) -> impl Iterator + '_ { - self.ciphers().map(|ciphersuite| match ciphersuite { - CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_AES_256_GCM_SHA384 => - cipher_suite::TLS13_AES_256_GCM_SHA384, - CipherSuite::TLS_AES_128_GCM_SHA256 => - cipher_suite::TLS13_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => - cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 => - cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - }) - } +#[cfg(test)] +mod tests { + use figment::{Figment, providers::{Toml, Format}}; + + #[test] + fn test_tls_config_from_file() { + use crate::tls::{TlsConfig, CipherSuite}; + use pretty_assertions::assert_eq; + + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [global] + shutdown.ctrlc = 0 + ident = false + + [global.tls] + certs = "/ssl/cert.pem" + key = "/ssl/key.pem" + + [global.limits] + forms = "1mib" + json = "10mib" + stream = "50kib" + "#)?; + + let config: TlsConfig = crate::Config::figment().extract_inner("tls")?; + assert_eq!(config, TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")); + + jail.create_file("Rocket.toml", r#" + [global.tls] + certs = "cert.pem" + key = "key.pem" + "#)?; + + let config: TlsConfig = crate::Config::figment().extract_inner("tls")?; + assert_eq!(config, TlsConfig::from_paths( + jail.directory().join("cert.pem"), + jail.directory().join("key.pem") + )); + + jail.create_file("TLS.toml", r#" + certs = "cert.pem" + key = "key.pem" + prefer_server_cipher_order = true + ciphers = [ + "TLS_CHACHA20_POLY1305_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + ] + "#)?; + + let config: TlsConfig = Figment::from(Toml::file("TLS.toml")).extract()?; + let cert_path = jail.directory().join("cert.pem"); + let key_path = jail.directory().join("key.pem"); + assert_eq!(config, TlsConfig::from_paths(cert_path, key_path) + .with_preferred_server_cipher_order(true) + .with_ciphers([ + CipherSuite::TLS_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_AES_256_GCM_SHA384, + CipherSuite::TLS_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + ])); + + jail.create_file("Rocket.toml", r#" + [global] + shutdown.ctrlc = 0 + ident = false + + [global.tls] + certs = "/ssl/cert.pem" + key = "/ssl/key.pem" + + [global.limits] + forms = "1mib" + json = "10mib" + stream = "50kib" + "#)?; + + let config: TlsConfig = crate::Config::figment().extract_inner("tls")?; + assert_eq!(config, TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")); + + jail.create_file("Rocket.toml", r#" + [global.tls] + certs = "cert.pem" + key = "key.pem" + "#)?; + + let config: TlsConfig = crate::Config::figment().extract_inner("tls")?; + assert_eq!(config, TlsConfig::from_paths( + jail.directory().join("cert.pem"), + jail.directory().join("key.pem") + )); + + jail.create_file("Rocket.toml", r#" + [global.tls] + certs = "cert.pem" + key = "key.pem" + prefer_server_cipher_order = true + ciphers = [ + "TLS_CHACHA20_POLY1305_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + ] + "#)?; + + let config: TlsConfig = crate::Config::figment().extract_inner("tls")?; + let cert_path = jail.directory().join("cert.pem"); + let key_path = jail.directory().join("key.pem"); + assert_eq!(config, TlsConfig::from_paths(cert_path, key_path) + .with_preferred_server_cipher_order(true) + .with_ciphers([ + CipherSuite::TLS_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_AES_256_GCM_SHA384, + CipherSuite::TLS_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + ])); + + Ok(()) + }); } } diff --git a/core/http/src/tls/error.rs b/core/lib/src/tls/error.rs similarity index 94% rename from core/http/src/tls/error.rs rename to core/lib/src/tls/error.rs index 429f4a9d1f..ab2faf3e77 100644 --- a/core/http/src/tls/error.rs +++ b/core/lib/src/tls/error.rs @@ -11,6 +11,7 @@ pub enum KeyError { #[derive(Debug)] pub enum Error { Io(std::io::Error), + Bind(Box), Tls(rustls::Error), Mtls(rustls::server::VerifierBuilderError), CertChain(std::io::Error), @@ -29,6 +30,7 @@ impl std::fmt::Display for Error { CertChain(e) => write!(f, "failed to process certificate chain: {e}"), PrivKey(e) => write!(f, "failed to process private key: {e}"), CertAuth(e) => write!(f, "failed to process certificate authority: {e}"), + Bind(e) => write!(f, "failed to bind to network interface: {e}"), } } } @@ -66,6 +68,7 @@ impl std::error::Error for Error { Error::CertChain(e) => Some(e), Error::PrivKey(e) => Some(e), Error::CertAuth(e) => Some(e), + Error::Bind(e) => Some(&**e), } } } diff --git a/core/lib/src/tls/mod.rs b/core/lib/src/tls/mod.rs new file mode 100644 index 0000000000..d6128e3b71 --- /dev/null +++ b/core/lib/src/tls/mod.rs @@ -0,0 +1,7 @@ +mod error; +pub(crate) mod config; +pub(crate) mod util; + +pub use error::Result; +pub use config::{TlsConfig, CipherSuite}; +pub use error::Error; diff --git a/core/http/src/tls/util.rs b/core/lib/src/tls/util.rs similarity index 100% rename from core/http/src/tls/util.rs rename to core/lib/src/tls/util.rs diff --git a/core/lib/src/util/chain.rs b/core/lib/src/util/chain.rs new file mode 100644 index 0000000000..c60a193b0c --- /dev/null +++ b/core/lib/src/util/chain.rs @@ -0,0 +1,52 @@ +use std::io; +use std::task::{Poll, Context}; +use std::pin::Pin; + +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, ReadBuf}; + +pin_project! { + /// Stream for the [`chain`](super::AsyncReadExt::chain) method. + #[must_use = "streams do nothing unless polled"] + pub struct Chain { + #[pin] + first: Option, + #[pin] + second: U, + } +} + +impl Chain { + pub(crate) fn new(first: T, second: U) -> Self { + Self { first: Some(first), second } + } +} + +impl Chain { + /// Gets references to the underlying readers in this `Chain`. + pub fn get_ref(&self) -> (Option<&T>, &U) { + (self.first.as_ref(), &self.second) + } +} + +impl AsyncRead for Chain { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let me = self.as_mut().project(); + if let Some(first) = me.first.as_pin_mut() { + let init_rem = buf.remaining(); + futures::ready!(first.poll_read(cx, buf))?; + if buf.remaining() == init_rem { + self.as_mut().project().first.set(None); + } else { + return Poll::Ready(Ok(())); + } + } + + let me = self.as_mut().project(); + me.second.poll_read(cx, buf) + } +} diff --git a/core/lib/src/util/join.rs b/core/lib/src/util/join.rs new file mode 100644 index 0000000000..d8ffc9d2d3 --- /dev/null +++ b/core/lib/src/util/join.rs @@ -0,0 +1,77 @@ +use std::pin::Pin; +use std::task::{Poll, Context}; + +use pin_project_lite::pin_project; + +use futures::stream::Stream; +use futures::ready; + +/// Join two streams, `a` and `b`, into a new `Join` stream that returns items +/// from both streams, biased to `a`, until `a` finishes. The joined stream +/// completes when `a` completes, irrespective of `b`. If `b` stops producing +/// values, then the joined stream acts exactly like a fused `a`. +/// +/// Values are biased to those of `a`: if `a` provides a value, it is always +/// emitted before a value provided by `b`. In other words, values from `b` are +/// emitted when and only when `a` is not producing a value. +pub fn join(a: A, b: B) -> Join { + Join { a, b: Some(b), done: false, } +} + +pin_project! { + /// Stream returned by [`join`]. + pub struct Join { + #[pin] + a: T, + #[pin] + b: Option, + // Set when `a` returns `None`. + done: bool, + } +} + +impl Stream for Join + where T: Stream, + U: Stream, +{ + type Item = T::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.done { + return Poll::Ready(None); + } + + let me = self.as_mut().project(); + match me.a.poll_next(cx) { + Poll::Ready(opt) => { + *me.done = opt.is_none(); + Poll::Ready(opt) + }, + Poll::Pending => match me.b.as_pin_mut() { + None => Poll::Pending, + Some(b) => match ready!(b.poll_next(cx)) { + Some(value) => Poll::Ready(Some(value)), + None => { + self.as_mut().project().b.set(None); + Poll::Pending + } + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let (left_low, left_high) = self.a.size_hint(); + let (right_low, right_high) = self.b.as_ref() + .map(|b| b.size_hint()) + .unwrap_or_default(); + + let low = left_low.saturating_add(right_low); + let high = match (left_high, right_high) { + (Some(h1), Some(h2)) => h1.checked_add(h2), + _ => None, + }; + + (low, high) + } +} diff --git a/core/lib/src/util/mod.rs b/core/lib/src/util/mod.rs new file mode 100644 index 0000000000..d3055f36ce --- /dev/null +++ b/core/lib/src/util/mod.rs @@ -0,0 +1,12 @@ +mod chain; +mod tripwire; +mod reader_stream; +mod join; + +#[cfg(unix)] +pub mod unix; + +pub use chain::Chain; +pub use tripwire::TripWire; +pub use reader_stream::ReaderStream; +pub use join::join; diff --git a/core/lib/src/util/reader_stream.rs b/core/lib/src/util/reader_stream.rs new file mode 100644 index 0000000000..da0b1ab049 --- /dev/null +++ b/core/lib/src/util/reader_stream.rs @@ -0,0 +1,124 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::{Bytes, BytesMut}; +use futures::stream::Stream; +use pin_project_lite::pin_project; +use tokio::io::AsyncRead; + +pin_project! { + /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks. + /// + /// This stream is fused. It performs the inverse operation of + /// [`StreamReader`]. + /// + /// # Example + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// use tokio_stream::StreamExt; + /// use tokio_util::io::ReaderStream; + /// + /// // Create a stream of data. + /// let data = b"hello, world!"; + /// let mut stream = ReaderStream::new(&data[..]); + /// + /// // Read all of the chunks into a vector. + /// let mut stream_contents = Vec::new(); + /// while let Some(chunk) = stream.next().await { + /// stream_contents.extend_from_slice(&chunk?); + /// } + /// + /// // Once the chunks are concatenated, we should have the + /// // original data. + /// assert_eq!(stream_contents, data); + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`StreamReader`]: crate::io::StreamReader + /// [`Stream`]: futures_core::Stream + #[derive(Debug)] + pub struct ReaderStream { + // Reader itself. + // + // This value is `None` if the stream has terminated. + #[pin] + reader: R, + // Working buffer, used to optimize allocations. + buf: BytesMut, + capacity: usize, + done: bool, + } +} + +impl ReaderStream { + /// Convert an [`AsyncRead`] into a [`Stream`] with item type + /// `Result`, + /// with a specific read buffer initial capacity. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + pub fn with_capacity(reader: R, capacity: usize) -> Self { + ReaderStream { + reader: reader, + buf: BytesMut::with_capacity(capacity), + capacity, + done: false, + } + } +} + +impl Stream for ReaderStream { + type Item = std::io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use tokio_util::io::poll_read_buf; + + let mut this = self.as_mut().project(); + + if *this.done { + return Poll::Ready(None); + } + + if this.buf.capacity() == 0 { + this.buf.reserve(*this.capacity); + } + + let reader = this.reader; + match poll_read_buf(reader, cx, &mut this.buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + *this.done = true; + Poll::Ready(Some(Err(err))) + } + Poll::Ready(Ok(0)) => { + *this.done = true; + Poll::Ready(None) + } + Poll::Ready(Ok(_)) => { + let chunk = this.buf.split(); + Poll::Ready(Some(Ok(chunk.freeze()))) + } + } + } + + // fn size_hint(&self) -> (usize, Option) { + // self.reader.size_hint() + // } +} + +impl hyper::body::Body for ReaderStream { + type Data = bytes::Bytes; + + type Error = std::io::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + self.poll_next(cx).map_ok(hyper::body::Frame::data) + } +} diff --git a/core/lib/src/trip_wire.rs b/core/lib/src/util/tripwire.rs similarity index 100% rename from core/lib/src/trip_wire.rs rename to core/lib/src/util/tripwire.rs diff --git a/core/lib/src/util/unix.rs b/core/lib/src/util/unix.rs new file mode 100644 index 0000000000..8889b78d14 --- /dev/null +++ b/core/lib/src/util/unix.rs @@ -0,0 +1,25 @@ +use std::io; +use std::os::fd::AsRawFd; + +pub fn lock_exlusive_nonblocking(file: &T) -> io::Result<()> { + let raw_fd = file.as_raw_fd(); + let res = unsafe { + libc::flock(raw_fd, libc::LOCK_EX | libc::LOCK_NB) + }; + + match res { + 0 => Ok(()), + _ => Err(io::Error::last_os_error()), + } +} + +pub fn unlock_nonblocking(file: &T) -> io::Result<()> { + let res = unsafe { + libc::flock(file.as_raw_fd(), libc::LOCK_UN | libc::LOCK_NB) + }; + + match res { + 0 => Ok(()), + _ => Err(io::Error::last_os_error()), + } +} diff --git a/core/lib/tests/can-launch-tls.rs b/core/lib/tests/can-launch-tls.rs index fdd864b192..1fc4cfbef0 100644 --- a/core/lib/tests/can-launch-tls.rs +++ b/core/lib/tests/can-launch-tls.rs @@ -1,8 +1,9 @@ #![cfg(feature = "tls")] use rocket::fs::relative; -use rocket::config::{Config, TlsConfig, CipherSuite}; use rocket::local::asynchronous::Client; +use rocket::tls::{TlsConfig, CipherSuite}; +use rocket::figment::providers::Serialized; #[rocket::async_test] async fn can_launch_tls() { @@ -15,9 +16,8 @@ async fn can_launch_tls() { CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, ]); - let rocket = rocket::custom(Config { tls: Some(tls), ..Config::debug_default() }); - let client = Client::debug(rocket).await.unwrap(); - + let config = rocket::Config::figment().merge(Serialized::defaults(tls)); + let client = Client::debug(rocket::custom(config)).await.unwrap(); client.rocket().shutdown().notify(); client.rocket().shutdown().await; } diff --git a/core/lib/tests/on_launch_fairing_can_inspect_port.rs b/core/lib/tests/on_launch_fairing_can_inspect_port.rs index a0e35572f2..8060128027 100644 --- a/core/lib/tests/on_launch_fairing_can_inspect_port.rs +++ b/core/lib/tests/on_launch_fairing_can_inspect_port.rs @@ -1,3 +1,5 @@ +use std::net::{SocketAddr, Ipv4Addr}; + use rocket::config::Config; use rocket::fairing::AdHoc; use rocket::futures::channel::oneshot; @@ -5,13 +7,13 @@ use rocket::futures::channel::oneshot; #[rocket::async_test] async fn on_ignite_fairing_can_inspect_port() { let (tx, rx) = oneshot::channel(); - let rocket = rocket::custom(Config { port: 0, ..Config::debug_default() }) + let rocket = rocket::custom(Config::debug_default()) .attach(AdHoc::on_liftoff("Send Port -> Channel", move |rocket| { Box::pin(async move { - tx.send(rocket.config().port).unwrap(); + tx.send(rocket.endpoint().tcp().unwrap().port()).unwrap(); }) })); - rocket::tokio::spawn(rocket.launch()); + rocket::tokio::spawn(rocket.launch_on(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))); assert_ne!(rx.await.unwrap(), 0); } diff --git a/core/lib/tests/sentinel.rs b/core/lib/tests/sentinel.rs index f29f829d77..89f7cb5467 100644 --- a/core/lib/tests/sentinel.rs +++ b/core/lib/tests/sentinel.rs @@ -155,7 +155,7 @@ fn inner_sentinels_detected() { impl<'r, 'o: 'r> response::Responder<'r, 'o> for ResponderSentinel { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { - todo!() + unimplemented!() } } diff --git a/core/lib/tests/tls-config-from-source-1503.rs b/core/lib/tests/tls-config-from-source-1503.rs index 92085f6861..3db9073140 100644 --- a/core/lib/tests/tls-config-from-source-1503.rs +++ b/core/lib/tests/tls-config-from-source-1503.rs @@ -8,19 +8,14 @@ macro_rules! relative { #[test] fn tls_config_from_source() { - use rocket::config::{Config, TlsConfig}; - use rocket::figment::Figment; + use rocket::tls::TlsConfig; + use rocket::figment::{Figment, providers::Serialized}; let cert_path = relative!("examples/tls/private/cert.pem"); let key_path = relative!("examples/tls/private/key.pem"); + let config = TlsConfig::from_paths(cert_path, key_path); - let rocket_config = Config { - tls: Some(TlsConfig::from_paths(cert_path, key_path)), - ..Default::default() - }; - - let config: Config = Figment::from(rocket_config).extract().unwrap(); - let tls = config.tls.expect("have TLS config"); + let tls: TlsConfig = Figment::from(Serialized::globals(config)).extract().unwrap(); assert_eq!(tls.certs().unwrap_left(), cert_path); assert_eq!(tls.key().unwrap_left(), key_path); } diff --git a/examples/config/src/tests.rs b/examples/config/src/tests.rs index 7cabb9dc51..e774f7ec05 100644 --- a/examples/config/src/tests.rs +++ b/examples/config/src/tests.rs @@ -6,15 +6,11 @@ async fn test_config(profile: &str) { let config = rocket.config(); match &*profile { "debug" => { - assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST); - assert_eq!(config.port, 8000); assert_eq!(config.workers, 1); assert_eq!(config.keep_alive, 0); assert_eq!(config.log_level, LogLevel::Normal); } "release" => { - assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST); - assert_eq!(config.port, 8000); assert_eq!(config.workers, 12); assert_eq!(config.keep_alive, 5); assert_eq!(config.log_level, LogLevel::Critical); diff --git a/examples/hello/src/main.rs b/examples/hello/src/main.rs index 79b4588aca..0f8c55cb1a 100644 --- a/examples/hello/src/main.rs +++ b/examples/hello/src/main.rs @@ -74,19 +74,8 @@ fn hello(lang: Option, opt: Options<'_>) -> String { #[launch] fn rocket() -> _ { - use rocket::fairing::AdHoc; - rocket::build() .mount("/", routes![hello]) .mount("/hello", routes![world, mir]) .mount("/wave", routes![wave]) - .attach(AdHoc::on_request("Compatibility Normalizer", |req, _| Box::pin(async move { - if !req.uri().is_normalized_nontrailing() { - let normal = req.uri().clone().into_normalized_nontrailing(); - warn!("Incoming request URI was normalized for compatibility."); - info_!("{} -> {}", req.uri(), normal); - req.set_uri(normal); - } - }))) - } diff --git a/examples/tls/src/redirector.rs b/examples/tls/src/redirector.rs index aeffe9ad3a..e490ee1b1e 100644 --- a/examples/tls/src/redirector.rs +++ b/examples/tls/src/redirector.rs @@ -1,33 +1,38 @@ //! Redirect all HTTP requests to HTTPs. -use std::sync::OnceLock; +use std::net::SocketAddr; use rocket::http::Status; use rocket::log::LogLevel; -use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite, Config}; +use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite}; use rocket::fairing::{Fairing, Info, Kind}; use rocket::response::Redirect; +use yansi::Paint; + +#[derive(Debug, Clone, Copy, Default)] +pub struct Redirector(u16); + #[derive(Debug, Clone)] -pub struct Redirector { - pub listen_port: u16, - pub tls_port: OnceLock, +pub struct Config { + server: rocket::Config, + tls_addr: SocketAddr, } impl Redirector { pub fn on(port: u16) -> Self { - Redirector { listen_port: port, tls_port: OnceLock::new() } + Redirector(port) } // Route function that gets called on every single request. fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { // FIXME: Check the host against a whitelist! - let redirector = req.rocket().state::().expect("managed Self"); + let config = req.rocket().state::().expect("managed Self"); if let Some(host) = req.host() { let domain = host.domain(); - let https_uri = match redirector.tls_port.get() { - Some(443) | None => format!("https://{domain}{}", req.uri()), - Some(port) => format!("https://{domain}:{port}{}", req.uri()), + let https_uri = match config.tls_addr.port() { + 443 => format!("https://{domain}{}", req.uri()), + port => format!("https://{domain}:{port}{}", req.uri()), }; route::Outcome::from(req, Redirect::permanent(https_uri)).pin() @@ -37,21 +42,12 @@ impl Redirector { } // Launch an instance of Rocket than handles redirection on `self.port`. - pub async fn try_launch(self, mut config: Config) -> Result, Error> { - use yansi::Paint; + pub async fn try_launch(self, config: Config) -> Result, Error> { use rocket::http::Method::*; - // Determine the port TLS is being served on. - let tls_port = self.tls_port.get_or_init(|| config.port); - - // Adjust config for redirector: disable TLS, set port, disable logging. - config.tls = None; - config.port = self.listen_port; - config.log_level = LogLevel::Critical; - info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta()); - info_!("redirecting on insecure port {} to TLS port {}", - self.listen_port.yellow(), tls_port.green()); + info_!("redirecting insecure port {} to TLS port {}", + self.0.yellow(), config.tls_addr.port().green()); // Build a vector of routes to `redirect` on `` for each method. let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch] @@ -59,10 +55,11 @@ impl Redirector { .map(|m| Route::new(m, "/", Self::redirect)) .collect::>(); - rocket::custom(config) - .manage(self) + let addr = SocketAddr::new(config.tls_addr.ip(), self.0); + rocket::custom(&config.server) + .manage(config) .mount("/", redirects) - .launch() + .launch_on(addr) .await } } @@ -76,8 +73,24 @@ impl Fairing for Redirector { } } - async fn on_liftoff(&self, rkt: &Rocket) { - let (this, shutdown, config) = (self.clone(), rkt.shutdown(), rkt.config().clone()); + async fn on_liftoff(&self, rocket: &Rocket) { + let Some(tls_addr) = rocket.endpoint().tls().and_then(|tls| tls.tcp()) else { + info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta()); + warn_!("Main instance is not being served over TLS/TCP."); + warn_!("Redirector refusing to start."); + return; + }; + + let config = Config { + tls_addr, + server: rocket::Config { + log_level: LogLevel::Critical, + ..rocket.config().clone() + }, + }; + + let this = *self; + let shutdown = rocket.shutdown(); let _ = rocket::tokio::spawn(async move { if let Err(e) = this.try_launch(config).await { error!("Failed to start HTTP -> HTTPS redirector."); diff --git a/examples/tls/src/tests.rs b/examples/tls/src/tests.rs index 2629e3c487..61efbec9ff 100644 --- a/examples/tls/src/tests.rs +++ b/examples/tls/src/tests.rs @@ -1,11 +1,21 @@ use std::fs::{self, File}; +use rocket::http::{CookieJar, Cookie}; use rocket::local::blocking::Client; use rocket::fs::relative; +#[get("/cookie")] +fn cookie(jar: &CookieJar<'_>) { + jar.add(("k1", "v1")); + jar.add_private(("k2", "v2")); + + jar.add(Cookie::build(("k1u", "v1u")).secure(false)); + jar.add_private(Cookie::build(("k2u", "v2u")).secure(false)); +} + #[test] fn hello_mutual() { - let client = Client::tracked(super::rocket()).unwrap(); + let client = Client::tracked_secure(super::rocket()).unwrap(); let cert_paths = fs::read_dir(relative!("private")).unwrap() .map(|entry| entry.unwrap().path().to_string_lossy().into_owned()) .filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem")); @@ -23,35 +33,43 @@ fn hello_mutual() { #[test] fn secure_cookies() { - use rocket::http::{CookieJar, Cookie}; - - #[get("/cookie")] - fn cookie(jar: &CookieJar<'_>) { - jar.add(("k1", "v1")); - jar.add_private(("k2", "v2")); - - jar.add(Cookie::build(("k1u", "v1u")).secure(false)); - jar.add_private(Cookie::build(("k2u", "v2u")).secure(false)); - } + let rocket = super::rocket().mount("/", routes![cookie]); + let client = Client::tracked_secure(rocket).unwrap(); - let client = Client::tracked(super::rocket().mount("/", routes![cookie])).unwrap(); let response = client.get("/cookie").dispatch(); - let c1 = response.cookies().get("k1").unwrap(); - assert_eq!(c1.secure(), Some(true)); - let c2 = response.cookies().get_private("k2").unwrap(); + let c3 = response.cookies().get("k1u").unwrap(); + let c4 = response.cookies().get_private("k2u").unwrap(); + + assert_eq!(c1.secure(), Some(true)); assert_eq!(c2.secure(), Some(true)); + assert_ne!(c3.secure(), Some(true)); + assert_ne!(c4.secure(), Some(true)); +} - let c1 = response.cookies().get("k1u").unwrap(); - assert_ne!(c1.secure(), Some(true)); +#[test] +fn insecure_cookies() { + let rocket = super::rocket().mount("/", routes![cookie]); + let client = Client::tracked(rocket).unwrap(); + + let response = client.get("/cookie").dispatch(); + let c1 = response.cookies().get("k1").unwrap(); + let c2 = response.cookies().get_private("k2").unwrap(); + let c3 = response.cookies().get("k1u").unwrap(); + let c4 = response.cookies().get_private("k2u").unwrap(); - let c2 = response.cookies().get_private("k2u").unwrap(); - assert_ne!(c2.secure(), Some(true)); + assert_eq!(c1.secure(), None); + assert_eq!(c2.secure(), None); + assert_eq!(c3.secure(), None); + assert_eq!(c4.secure(), None); } #[test] fn hello_world() { + use rocket::listener::DefaultListener; + use rocket::config::{Config, SecretKey}; + let profiles = [ "rsa_sha256", "ecdsa_nistp256_sha256_pkcs8", @@ -61,11 +79,20 @@ fn hello_world() { "ed25519", ]; - // TODO: Testing doesn't actually read keys since we don't do TLS locally. for profile in profiles { - let config = rocket::Config::figment().select(profile); - let client = Client::tracked(super::rocket().configure(config)).unwrap(); + let config = Config { + secret_key: SecretKey::generate().unwrap(), + ..Config::debug_default() + }; + + let figment = Config::figment().merge(config).select(profile); + let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap(); let response = client.get("/").dispatch(); assert_eq!(response.into_string().unwrap(), "Hello, world!"); + + let figment = client.rocket().figment(); + let listener: DefaultListener = figment.extract().unwrap(); + assert_eq!(figment.profile(), profile); + listener.tls.as_ref().unwrap().validate().expect("valid TLS config"); } } diff --git a/examples/upgrade/static/index.html b/examples/upgrade/static/index.html index f4bf469445..78c1add1b4 100644 --- a/examples/upgrade/static/index.html +++ b/examples/upgrade/static/index.html @@ -14,7 +14,7 @@

WebSocket Client Test