Skip to content

Commit

Permalink
websocket: upgrade protocol using the exising connection
Browse files Browse the repository at this point in the history
  • Loading branch information
picoHz committed Apr 20, 2024
1 parent e06c4d9 commit 8af9ccf
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 122 deletions.
4 changes: 0 additions & 4 deletions taxy/src/proxy/http/error.rs
Expand Up @@ -8,9 +8,6 @@ pub enum ProxyError {
#[error("domain fronting detected")]
DomainFrontingDetected,

#[error("dns lookup failed")]
DnsLookupFailed,

#[error("no route found")]
NoRouteFound,
}
Expand All @@ -20,7 +17,6 @@ impl ProxyError {
match self {
Self::DomainFrontingDetected => StatusCode::MISDIRECTED_REQUEST,
Self::NoRouteFound => StatusCode::BAD_GATEWAY,
Self::DnsLookupFailed => StatusCode::from_u16(523).unwrap(),
}
}
}
Expand Down
1 change: 0 additions & 1 deletion taxy/src/proxy/http/mod.rs
Expand Up @@ -42,7 +42,6 @@ mod header;
mod hyper_tls;
mod pool;
mod route;
mod upgrade;

const MAX_BUFFER_SIZE: usize = 4096;
const HTTP2_MAX_FRAME_SIZE: usize = 16384;
Expand Down
102 changes: 40 additions & 62 deletions taxy/src/proxy/http/pool.rs
Expand Up @@ -2,23 +2,15 @@ use super::{
compression::{is_compressed, CompressionStream},
error::map_response,
};
use crate::proxy::http::{
error::ProxyError, hyper_tls::client::HttpsConnector, upgrade, IoStream, HTTP2_MAX_FRAME_SIZE,
};
use crate::proxy::http::{hyper_tls::client::HttpsConnector, HTTP2_MAX_FRAME_SIZE};
use hyper::{
client::HttpConnector,
header::UPGRADE,
http::{uri::Scheme, HeaderValue},
Body, Client, Request, Response,
client::HttpConnector, header::UPGRADE, http::HeaderValue, Body, Client, Request, Response,
};
use std::sync::Arc;
use tokio::net::{self, TcpSocket};
use tokio_rustls::{rustls::ClientConfig, TlsConnector};
use tracing::{debug, error};
use warp::host::Authority;
use tokio_rustls::rustls::ClientConfig;
use tracing::error;

pub struct ConnectionPool {
tls_client_config: Arc<ClientConfig>,
client: Client<HttpsConnector<HttpConnector>>,
}

Expand All @@ -28,23 +20,20 @@ impl ConnectionPool {
let client = Client::builder()
.http2_max_frame_size(Some(HTTP2_MAX_FRAME_SIZE as u32))
.build::<_, hyper::Body>(https);

Self {
tls_client_config,
client,
}
Self { client }
}

pub async fn request(&self, mut req: Request<Body>) -> Result<Response<Body>, anyhow::Error> {
let conn = Conn {
scheme: req.uri().scheme().unwrap().clone(),
authority: req.uri().authority().unwrap().clone(),
let upgrading_req = if req.headers().contains_key(UPGRADE) {
let mut cloned_req = Request::builder().uri(req.uri()).body(Body::empty())?;
cloned_req.headers_mut().clone_from(req.headers());
let mut cloned_req = Some(cloned_req);
req = cloned_req.replace(req).unwrap();
cloned_req
} else {
None
};

if req.headers().contains_key(UPGRADE) {
return start_upgrading_connection(conn, req, self.tls_client_config.clone()).await;
}

let accept_brotli = req
.headers()
.get(hyper::header::ACCEPT_ENCODING)
Expand All @@ -53,9 +42,24 @@ impl ConnectionPool {

*req.version_mut() = hyper::Version::HTTP_11;

let result: Result<_, anyhow::Error> =
let mut result: Result<_, anyhow::Error> =
self.client.request(req).await.map_err(|err| err.into());

match (&result, upgrading_req) {
(Ok(res), Some(upgrading_req))
if res.status() == hyper::StatusCode::SWITCHING_PROTOCOLS =>
{
let mut cloned_res = Response::builder().status(res.status());
cloned_res.headers_mut().unwrap().clone_from(res.headers());
let upgrading_res =
std::mem::replace(&mut result, Ok(cloned_res.body(Body::empty())?)).unwrap();
tokio::spawn(async move {
upgrade_connection(upgrading_req, upgrading_res).await;
});
}
_ => (),
}

let http2 = result
.as_ref()
.map(|res| res.version() == hyper::Version::HTTP_2)
Expand Down Expand Up @@ -95,41 +99,15 @@ impl ConnectionPool {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Conn {
scheme: Scheme,
authority: Authority,
}

async fn start_upgrading_connection(
conn: Conn,
req: Request<Body>,
tls_client_config: Arc<ClientConfig>,
) -> Result<Response<Body>, anyhow::Error> {
let resolved = net::lookup_host(conn.authority.as_str())
.await
.map_err(|_| ProxyError::DnsLookupFailed)?
.next()
.ok_or(ProxyError::DnsLookupFailed)?;
debug!(authority = %conn.authority, %resolved);

let sock = if resolved.is_ipv4() {
TcpSocket::new_v4()
} else {
TcpSocket::new_v6()
}?;

let stream = sock.connect(resolved).await?;
debug!(%resolved, "connected");

let mut stream: Box<dyn IoStream> = Box::new(stream);
if conn.scheme == Scheme::HTTPS {
debug!(%resolved, "client: tls handshake");
let tls = TlsConnector::from(tls_client_config);
let host = conn.authority.host().to_string();
let tls_stream = tls.connect(host.try_into().unwrap(), stream).await?;
stream = Box::new(tls_stream);
}

upgrade::connect(req, stream).await
async fn upgrade_connection(req: Request<Body>, res: Response<Body>) {
match tokio::try_join!(hyper::upgrade::on(req), hyper::upgrade::on(res)) {
Ok((mut req, mut res)) => {
if let Err(err) = tokio::io::copy_bidirectional(&mut req, &mut res).await {
error!("upgraded io error: {}", err);
}
}
Err(err) => {
error!("upgrading io error: {}", err);
}
};
}
55 changes: 0 additions & 55 deletions taxy/src/proxy/http/upgrade.rs

This file was deleted.

0 comments on commit 8af9ccf

Please sign in to comment.