diff --git a/taxy/src/proxy/http/error.rs b/taxy/src/proxy/http/error.rs index 537ba4e..97d1f43 100644 --- a/taxy/src/proxy/http/error.rs +++ b/taxy/src/proxy/http/error.rs @@ -8,9 +8,6 @@ pub enum ProxyError { #[error("domain fronting detected")] DomainFrontingDetected, - #[error("dns lookup failed")] - DnsLookupFailed, - #[error("no route found")] NoRouteFound, } @@ -20,7 +17,6 @@ impl ProxyError { match self { Self::DomainFrontingDetected => StatusCode::MISDIRECTED_REQUEST, Self::NoRouteFound => StatusCode::BAD_GATEWAY, - Self::DnsLookupFailed => StatusCode::from_u16(523).unwrap(), } } } diff --git a/taxy/src/proxy/http/mod.rs b/taxy/src/proxy/http/mod.rs index 3df68d5..c51e7be 100644 --- a/taxy/src/proxy/http/mod.rs +++ b/taxy/src/proxy/http/mod.rs @@ -42,7 +42,6 @@ mod header; mod hyper_tls; mod pool; mod route; -mod upgrade; const MAX_BUFFER_SIZE: usize = 4096; const HTTP2_MAX_FRAME_SIZE: usize = 16384; diff --git a/taxy/src/proxy/http/pool.rs b/taxy/src/proxy/http/pool.rs index f022a88..966b177 100644 --- a/taxy/src/proxy/http/pool.rs +++ b/taxy/src/proxy/http/pool.rs @@ -2,23 +2,15 @@ use super::{ compression::{is_compressed, CompressionStream}, error::map_response, }; -use crate::proxy::http::{ - error::ProxyError, hyper_tls::client::HttpsConnector, upgrade, IoStream, HTTP2_MAX_FRAME_SIZE, -}; +use crate::proxy::http::{hyper_tls::client::HttpsConnector, HTTP2_MAX_FRAME_SIZE}; use hyper::{ - client::HttpConnector, - header::UPGRADE, - http::{uri::Scheme, HeaderValue}, - Body, Client, Request, Response, + client::HttpConnector, header::UPGRADE, http::HeaderValue, Body, Client, Request, Response, }; use std::sync::Arc; -use tokio::net::{self, TcpSocket}; -use tokio_rustls::{rustls::ClientConfig, TlsConnector}; -use tracing::{debug, error}; -use warp::host::Authority; +use tokio_rustls::rustls::ClientConfig; +use tracing::error; pub struct ConnectionPool { - tls_client_config: Arc, client: Client>, } @@ -28,23 +20,20 @@ impl ConnectionPool { let client = Client::builder() .http2_max_frame_size(Some(HTTP2_MAX_FRAME_SIZE as u32)) .build::<_, hyper::Body>(https); - - Self { - tls_client_config, - client, - } + Self { client } } pub async fn request(&self, mut req: Request) -> Result, anyhow::Error> { - let conn = Conn { - scheme: req.uri().scheme().unwrap().clone(), - authority: req.uri().authority().unwrap().clone(), + let upgrading_req = if req.headers().contains_key(UPGRADE) { + let mut cloned_req = Request::builder().uri(req.uri()).body(Body::empty())?; + cloned_req.headers_mut().clone_from(req.headers()); + let mut cloned_req = Some(cloned_req); + req = cloned_req.replace(req).unwrap(); + cloned_req + } else { + None }; - if req.headers().contains_key(UPGRADE) { - return start_upgrading_connection(conn, req, self.tls_client_config.clone()).await; - } - let accept_brotli = req .headers() .get(hyper::header::ACCEPT_ENCODING) @@ -53,9 +42,24 @@ impl ConnectionPool { *req.version_mut() = hyper::Version::HTTP_11; - let result: Result<_, anyhow::Error> = + let mut result: Result<_, anyhow::Error> = self.client.request(req).await.map_err(|err| err.into()); + match (&result, upgrading_req) { + (Ok(res), Some(upgrading_req)) + if res.status() == hyper::StatusCode::SWITCHING_PROTOCOLS => + { + let mut cloned_res = Response::builder().status(res.status()); + cloned_res.headers_mut().unwrap().clone_from(res.headers()); + let upgrading_res = + std::mem::replace(&mut result, Ok(cloned_res.body(Body::empty())?)).unwrap(); + tokio::spawn(async move { + upgrade_connection(upgrading_req, upgrading_res).await; + }); + } + _ => (), + } + let http2 = result .as_ref() .map(|res| res.version() == hyper::Version::HTTP_2) @@ -95,41 +99,15 @@ impl ConnectionPool { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -struct Conn { - scheme: Scheme, - authority: Authority, -} - -async fn start_upgrading_connection( - conn: Conn, - req: Request, - tls_client_config: Arc, -) -> Result, anyhow::Error> { - let resolved = net::lookup_host(conn.authority.as_str()) - .await - .map_err(|_| ProxyError::DnsLookupFailed)? - .next() - .ok_or(ProxyError::DnsLookupFailed)?; - debug!(authority = %conn.authority, %resolved); - - let sock = if resolved.is_ipv4() { - TcpSocket::new_v4() - } else { - TcpSocket::new_v6() - }?; - - let stream = sock.connect(resolved).await?; - debug!(%resolved, "connected"); - - let mut stream: Box = Box::new(stream); - if conn.scheme == Scheme::HTTPS { - debug!(%resolved, "client: tls handshake"); - let tls = TlsConnector::from(tls_client_config); - let host = conn.authority.host().to_string(); - let tls_stream = tls.connect(host.try_into().unwrap(), stream).await?; - stream = Box::new(tls_stream); - } - - upgrade::connect(req, stream).await +async fn upgrade_connection(req: Request, res: Response) { + match tokio::try_join!(hyper::upgrade::on(req), hyper::upgrade::on(res)) { + Ok((mut req, mut res)) => { + if let Err(err) = tokio::io::copy_bidirectional(&mut req, &mut res).await { + error!("upgraded io error: {}", err); + } + } + Err(err) => { + error!("upgrading io error: {}", err); + } + }; } diff --git a/taxy/src/proxy/http/upgrade.rs b/taxy/src/proxy/http/upgrade.rs deleted file mode 100644 index d7dbb30..0000000 --- a/taxy/src/proxy/http/upgrade.rs +++ /dev/null @@ -1,55 +0,0 @@ -use super::IoStream; -use hyper::{client, Body, Request, Response, StatusCode}; -use tokio::io::AsyncWriteExt; -use tracing::error; - -pub async fn connect( - req: Request, - stream: Box, -) -> anyhow::Result> { - let mut client_req = Request::builder().uri(req.uri()).body(Body::empty())?; - client_req.headers_mut().clone_from(req.headers()); - - let (mut sender, conn) = client::conn::Builder::new() - .handshake::<_, Body>(stream) - .await?; - - tokio::task::spawn(async move { - if let Err(err) = conn.await { - error!("Connection failed: {:?}", err); - } - }); - - let mut res = sender.send_request(client_req).await?; - if res.status() != StatusCode::SWITCHING_PROTOCOLS { - return Ok(res); - } - - let mut upgraded_client = match hyper::upgrade::on(&mut res).await { - Ok(upgraded) => upgraded, - Err(e) => { - error!("client upgrade error: {}", e); - return Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) - .unwrap()); - } - }; - - tokio::spawn(async move { - match hyper::upgrade::on(req).await { - Ok(mut upgraded) => { - if let Err(err) = - tokio::io::copy_bidirectional(&mut upgraded_client, &mut upgraded).await - { - error!("upgraded io error: {}", err); - } - let _ = upgraded.shutdown().await; - let _ = upgraded_client.shutdown().await; - } - Err(e) => error!("server upgrade error: {}", e), - } - }); - - Ok(res) -}