diff --git a/Cargo.toml b/Cargo.toml index 68799b5ad..69701f5f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,7 +68,7 @@ once_cell = "1.16.0" opentelemetry = "0.21.0" opentelemetry_sdk = { version = "0.21.0", features = ["rt-tokio", "logs"] } opentelemetry-http = "0.10.0" -opentelemetry-otlp = { version = "0.14.0", features = ["logs", "grpc-tonic"] } +opentelemetry-otlp = { version = "0.14.0", features = ["logs", "grpc-tonic"] } opentelemetry-proto = "0.4.0" opentelemetry-contrib = { version = "0.4.0", features = ["datadog"] } opentelemetry-appender-tracing = "0.2.0" @@ -104,7 +104,7 @@ tracing-core = { version = "0.1.32", default-features = false } tracing-opentelemetry = "0.22.0" tracing-subscriber = { version = "0.3.16", default-features = false, features = [ "registry", - "json" + "json", ] } ttl_cache = "0.5.1" ulid = "1.0.0" diff --git a/docker-compose.yml b/docker-compose.yml index 6c09021a4..41a5ceff5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -54,7 +54,7 @@ services: - "--stripe-secret-key=${STRIPE_SECRET_KEY}" - "--jwt-signing-private-key=${AUTH_JWTSIGNING_PRIVATE_KEY}" healthcheck: - test: curl --fail http://localhost:8000/ || exit 1 + test: curl -f -s http://localhost:8000 interval: 1m timeout: 10s retries: 5 @@ -150,7 +150,7 @@ services: - "--use-tls=${USE_TLS}" - "--admin-key=${GATEWAY_ADMIN_KEY}" healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8001"] + test: curl -f -s http://localhost:8001 interval: 1m timeout: 15s retries: 15 diff --git a/gateway/src/acme.rs b/gateway/src/acme.rs index 5e543c2f9..1587489e8 100644 --- a/gateway/src/acme.rs +++ b/gateway/src/acme.rs @@ -1,14 +1,8 @@ use std::collections::HashMap; use std::sync::Arc; -use std::task::{Context, Poll}; use std::time::Duration; -use axum::body::boxed; -use axum::response::Response; use fqdn::FQDN; -use futures::future::BoxFuture; -use hyper::server::conn::AddrStream; -use hyper::{Body, Request}; use instant_acme::{ Account, AccountCredentials, Authorization, AuthorizationStatus, Challenge, ChallengeType, Identifier, KeyAuthorization, LetsEncrypt, NewAccount, NewOrder, Order, OrderStatus, @@ -17,12 +11,8 @@ use rcgen::{Certificate, CertificateParams, DistinguishedName}; use shuttle_common::models::project::ProjectName; use tokio::sync::Mutex; use tokio::time::sleep; -use tower::{Layer, Service}; use tracing::{error, trace, warn}; -use crate::proxy::AsResponderTo; -use crate::Error; - const MAX_RETRIES: usize = 15; const MAX_RETRIES_CERTIFICATE_FETCHING: usize = 5; @@ -49,7 +39,7 @@ impl AcmeClient { self.0.lock().await.insert(token, key); } - async fn get_http01_challenge_authorization(&self, token: &str) -> Option { + pub async fn get_http01_challenge_authorization(&self, token: &str) -> Option { self.0 .lock() .await @@ -328,97 +318,3 @@ pub enum AcmeClientError { } impl std::error::Error for AcmeClientError {} - -pub struct ChallengeResponderLayer { - client: AcmeClient, -} - -impl ChallengeResponderLayer { - pub fn new(client: AcmeClient) -> Self { - Self { client } - } -} - -impl Layer for ChallengeResponderLayer { - type Service = ChallengeResponder; - - fn layer(&self, inner: S) -> Self::Service { - ChallengeResponder { - client: self.client.clone(), - inner, - } - } -} - -pub struct ChallengeResponder { - client: AcmeClient, - inner: S, -} - -impl<'r, S> AsResponderTo<&'r AddrStream> for ChallengeResponder -where - S: AsResponderTo<&'r AddrStream>, -{ - fn as_responder_to(&self, req: &'r AddrStream) -> Self { - Self { - client: self.client.clone(), - inner: self.inner.as_responder_to(req), - } - } -} - -impl Service> for ChallengeResponder -where - S: Service, Response = Response, Error = Error> + Send + 'static, - S::Future: Send + 'static, -{ - type Response = S::Response; - type Error = S::Error; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: Request) -> Self::Future { - if !req.uri().path().starts_with("/.well-known/acme-challenge/") { - let future = self.inner.call(req); - return Box::pin(async move { - let response: Response = future.await?; - Ok(response) - }); - } - - let token = match req - .uri() - .path() - .strip_prefix("/.well-known/acme-challenge/") - { - Some(token) => token.to_string(), - None => { - return Box::pin(async { - Ok(Response::builder() - .status(404) - .body(boxed(Body::empty())) - .unwrap()) - }) - } - }; - - trace!(token, "responding to certificate challenge"); - - let client = self.client.clone(); - - Box::pin(async move { - let (status, body) = match client.get_http01_challenge_authorization(&token).await { - Some(key) => (200, Body::from(key)), - None => (404, Body::empty()), - }; - - Ok(Response::builder() - .status(status) - .body(boxed(body)) - .unwrap()) - }) - } -} diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index ca6c02481..9785baba6 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -1,38 +1,36 @@ -use std::convert::Infallible; use std::future::Future; use std::io; -use std::net::SocketAddr; -use std::pin::Pin; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; -use std::task::{Context, Poll}; +use axum::extract::{ConnectInfo, Path, State}; use axum::headers::{HeaderMapExt, Host}; -use axum::response::{IntoResponse, Response}; +use axum::response::Response; +use axum::routing::any; use axum_server::accept::DefaultAcceptor; use axum_server::tls_rustls::RustlsAcceptor; use fqdn::{fqdn, FQDN}; -use futures::future::{ready, Ready}; use futures::prelude::*; use http::header::SERVER; -use http::HeaderValue; +use http::{HeaderValue, StatusCode}; use hyper::body::{Body, HttpBody}; use hyper::client::connect::dns::GaiResolver; use hyper::client::HttpConnector; -use hyper::server::conn::AddrStream; use hyper::{Client, Request}; use hyper_reverse_proxy::ReverseProxy; use once_cell::sync::Lazy; use opentelemetry::global; use opentelemetry_http::HeaderInjector; +use shuttle_common::backends::cache::{CacheManagement, CacheManager}; use shuttle_common::backends::headers::XShuttleProject; use shuttle_common::models::error::InvalidProjectName; +use shuttle_common::models::project::ProjectName; use tokio::sync::mpsc::Sender; -use tower::{Service, ServiceBuilder}; use tower_sanitize_path::SanitizePath; use tracing::{debug_span, error, field, trace}; use tracing_opentelemetry::OpenTelemetrySpanExt; -use crate::acme::{AcmeClient, ChallengeResponderLayer, CustomDomain}; +use crate::acme::AcmeClient; use crate::service::GatewayService; use crate::task::BoxedTask; use crate::{Error, ErrorKind}; @@ -41,159 +39,100 @@ static PROXY_CLIENT: Lazy>> = Lazy::new(|| ReverseProxy::new(Client::new())); static SERVER_HEADER: Lazy = Lazy::new(|| "shuttle.rs".parse().unwrap()); -pub trait AsResponderTo { - fn as_responder_to(&self, req: R) -> Self; - - fn into_make_service(self) -> ResponderMakeService - where - Self: Sized, - { - ResponderMakeService { inner: self } - } -} - -pub struct ResponderMakeService { - inner: S, -} - -impl<'r, S> Service<&'r AddrStream> for ResponderMakeService -where - S: AsResponderTo<&'r AddrStream>, -{ - type Response = S; - type Error = Infallible; - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: &'r AddrStream) -> Self::Future { - ready(Ok(self.inner.as_responder_to(req))) - } -} - -#[derive(Clone)] -pub struct UserProxy { +pub struct ProxyState { gateway: Arc, task_sender: Sender, - remote_addr: SocketAddr, public: FQDN, + project_cache: CacheManager, + domain_cache: CacheManager, } -impl<'r> AsResponderTo<&'r AddrStream> for UserProxy { - fn as_responder_to(&self, addr_stream: &'r AddrStream) -> Self { - let mut responder = self.clone(); - responder.remote_addr = addr_stream.remote_addr(); - responder - } -} +async fn proxy( + ConnectInfo(addr): ConnectInfo, + State(state): State>, + mut req: Request, +) -> Result { + let span = debug_span!("proxy", http.method = %req.method(), http.host = field::Empty, http.uri = %req.uri(), http.status_code = field::Empty, shuttle.project.name = field::Empty); + trace!(?req, "serving proxy request"); -impl AsResponderTo for SanitizePath -where - S: AsResponderTo + Clone, -{ - fn as_responder_to(&self, req: R) -> Self { - let responder = self.clone(); - responder.inner().as_responder_to(req); + let fqdn = req + .headers() + .typed_get::() + .map(|host| fqdn!(host.hostname())) + .ok_or_else(|| Error::from_kind(ErrorKind::BadHost))?; - responder - } -} + span.record("http.host", fqdn.to_string()); -impl UserProxy { - async fn proxy( - self, - task_sender: Sender, - mut req: Request, - ) -> Result { - let span = debug_span!("proxy", http.method = %req.method(), http.host = field::Empty, http.uri = %req.uri(), http.status_code = field::Empty, shuttle.project.name = field::Empty); - trace!(?req, "serving proxy request"); - - let fqdn = req - .headers() - .typed_get::() - .map(|host| fqdn!(host.hostname())) - .ok_or_else(|| Error::from_kind(ErrorKind::BadHost))?; - - span.record("http.host", fqdn.to_string()); - - let project_name = if fqdn.is_subdomain_of(&self.public) - && fqdn.depth() - self.public.depth() == 1 - { + let project_name = + if fqdn.is_subdomain_of(&state.public) && fqdn.depth() - state.public.depth() == 1 { fqdn.labels() .next() .unwrap() .to_owned() .parse() .map_err(|_| Error::from_kind(ErrorKind::InvalidProjectName(InvalidProjectName)))? - } else if let Ok(CustomDomain { project_name, .. }) = - self.gateway.project_details_for_custom_domain(&fqdn).await - { - project_name + } else if let Some(project) = { state.domain_cache.get(fqdn.to_string().as_str()) } { + project } else { - return Err(Error::from_kind(ErrorKind::CustomDomainNotFound)); + let project_name = state + .gateway + .project_details_for_custom_domain(&fqdn) + .await? + .project_name; + state.domain_cache.insert( + fqdn.to_string().as_str(), + project_name.clone(), + std::time::Duration::from_millis(5000), + ); + project_name }; - req.headers_mut() - .typed_insert(XShuttleProject(project_name.to_string())); + // Record current project for tracing purposes + span.record("shuttle.project.name", &project_name.to_string()); - let project = self - .gateway - .find_or_start_project(&project_name, task_sender) - .await?; + req.headers_mut() + .typed_insert(XShuttleProject(project_name.to_string())); - // Record current project for tracing purposes - span.record("shuttle.project.name", &project_name.to_string()); - - let target_ip = project + // cache project ip lookups to not overload the db during rapid requests + let target_ip = if let Some(ip) = { state.project_cache.get(project_name.as_str()) } { + ip + } else { + let ip = state + .gateway + .find_or_start_project(&project_name, state.task_sender.clone()) + .await? .state .target_ip()? .ok_or_else(|| Error::from_kind(ErrorKind::ProjectNotReady))?; - - let target_url = format!("http://{}:{}", target_ip, 8000); - - let cx = span.context(); - - global::get_text_map_propagator(|propagator| { - propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut())) - }); - - let mut res = PROXY_CLIENT - .call(self.remote_addr.ip(), &target_url, req) - .await - .map_err(|e| { - error!(error = ?e, "gateway proxy client error"); - Error::from_kind(ErrorKind::ProjectUnavailable) - })?; - - res.headers_mut().insert(SERVER, SERVER_HEADER.clone()); - let (parts, body) = res.into_parts(); - let body = ::map_err(body, axum::Error::new).boxed_unsync(); - - span.record("http.status_code", parts.status.as_u16()); - - Ok(Response::from_parts(parts, body)) - } -} - -impl Service> for UserProxy { - type Response = Response; - type Error = Error; - type Future = - Pin> + Send + 'static>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let task_sender = self.task_sender.clone(); - self.clone() - .proxy(task_sender, req) - .or_else(|err: Error| future::ready(Ok(err.into_response()))) - .boxed() - } + state.project_cache.insert( + project_name.as_str(), + ip, + std::time::Duration::from_millis(1000), + ); + ip + }; + let target_url = format!("http://{}:{}", target_ip, 8000); + + let cx = span.context(); + global::get_text_map_propagator(|propagator| { + propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut())) + }); + + let mut res = PROXY_CLIENT + .call(addr.ip(), &target_url, req) + .await + .map_err(|e| { + error!(error = ?e, "gateway proxy client error"); + Error::from_kind(ErrorKind::ProjectUnavailable) + })?; + + res.headers_mut().insert(SERVER, SERVER_HEADER.clone()); + let (parts, body) = res.into_parts(); + let body = ::map_err(body, axum::Error::new).boxed_unsync(); + + span.record("http.status_code", parts.status.as_u16()); + + Ok(Response::from_parts(parts, body)) } #[derive(Clone)] @@ -202,55 +141,32 @@ pub struct Bouncer { public: FQDN, } -impl<'r> AsResponderTo<&'r AddrStream> for Bouncer { - fn as_responder_to(&self, _req: &'r AddrStream) -> Self { - self.clone() - } -} - -impl Bouncer { - async fn bounce(self, req: Request) -> Result { - let mut resp = Response::builder(); +async fn bounce(State(state): State>, req: Request) -> Result { + let mut resp = Response::builder(); - let host = req.headers().typed_get::().unwrap(); - let hostname = host.hostname(); - let fqdn = fqdn!(hostname); + let host = req.headers().typed_get::().unwrap(); + let hostname = host.hostname(); + let fqdn = fqdn!(hostname); - let path = req.uri(); + let path = req.uri(); - if fqdn.is_subdomain_of(&self.public) - || self - .gateway - .project_details_for_custom_domain(&fqdn) - .await - .is_ok() - { - resp = resp - .status(301) - .header("Location", format!("https://{hostname}{path}")); - } else { - resp = resp.status(404); - } - - let body = ::map_err(Body::empty(), axum::Error::new).boxed_unsync(); - - Ok(resp.body(body).unwrap()) + if fqdn.is_subdomain_of(&state.public) + || state + .gateway + .project_details_for_custom_domain(&fqdn) + .await + .is_ok() + { + resp = resp + .status(301) + .header("Location", format!("https://{hostname}{path}")); + } else { + resp = resp.status(404); } -} -impl Service> for Bouncer { - type Response = Response; - type Error = Error; - type Future = - Pin> + Send + 'static>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } + let body = ::map_err(Body::empty(), axum::Error::new).boxed_unsync(); - fn call(&mut self, req: Request) -> Self::Future { - self.clone().bounce(req).boxed() - } + Ok(resp.body(body).unwrap()) } pub struct UserServiceBuilder { @@ -325,17 +241,26 @@ impl UserServiceBuilder { .user_binds_to .expect("a socket address to bind to is required"); - let user_proxy = SanitizePath::sanitize_paths(UserProxy { - gateway: service.clone(), - task_sender, - remote_addr: "127.0.0.1:80".parse().unwrap(), - public: public.clone(), - }) - .into_make_service(); - - let bouncer = self.bouncer_binds_to.as_ref().map(|_| Bouncer { - gateway: service.clone(), - public: public.clone(), + let san = SanitizePath::sanitize_paths( + axum::Router::new() + .fallback(proxy) // catch all routes + .with_state(Arc::new(ProxyState { + gateway: service.clone(), + task_sender, + public: public.clone(), + project_cache: CacheManager::new(1024), + domain_cache: CacheManager::new(256), + })), + ); + let user_proxy = axum::ServiceExt::into_make_service_with_connect_info::(san); + + let bouncer = self.bouncer_binds_to.as_ref().map(|_| { + axum::Router::new() + .fallback(bounce) // catch all routes + .with_state(Arc::new(Bouncer { + gateway: service.clone(), + public: public.clone(), + })) }); let mut futs = Vec::new(); @@ -348,9 +273,21 @@ impl UserServiceBuilder { .acme .expect("TLS cannot be enabled without an ACME client"); - let bouncer = ServiceBuilder::new() - .layer(ChallengeResponderLayer::new(acme)) - .service(bouncer); + let bouncer = axum::Router::new() + .route( + "/.well-known/acme-challenge/*rest", + any( + |Path(token): Path, State(client): State| async move { + trace!(token, "responding to certificate challenge"); + match client.get_http01_challenge_authorization(&token).await { + Some(key) => Ok(key), + None => Err(StatusCode::NOT_FOUND), + } + }, + ), + ) + .with_state(acme) + .merge(bouncer); let bouncer = axum_server::Server::bind(bouncer_binds_to) .serve(bouncer.into_make_service())