From b6226fb7f8d216bff3eefb7043a3c40db894551b Mon Sep 17 00:00:00 2001 From: zzzzRuby Date: Wed, 13 Dec 2023 15:33:20 +0800 Subject: [PATCH] Add ws_proxy option for connecting gateway through a proxy. --- Cargo.toml | 1 + src/client/mod.rs | 29 +++++++++++++++++ src/gateway/bridge/shard_manager.rs | 4 +++ src/gateway/bridge/shard_queuer.rs | 3 ++ src/gateway/shard.rs | 14 +++++--- src/gateway/ws.rs | 50 ++++++++++++++++++++++++++--- 6 files changed, 92 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bd8993ae953..4aa7297e7eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ nonmax = { version = "0.5.5", features = ["serde"] } strum = { version = "0.26", features = ["derive"] } to-arraystring = "0.1.0" extract_map = { version = "0.1.0", features = ["serde", "iter_mut"] } +async-http-proxy = { version = "1.2.5", features = ["runtime-tokio"] } # Optional dependencies fxhash = { version = "0.2.1", optional = true } chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"], optional = true } diff --git a/src/client/mod.rs b/src/client/mod.rs index 865cbfd7e12..4e740d12559 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -32,6 +32,7 @@ use futures::channel::mpsc::UnboundedReceiver as Receiver; use futures::future::BoxFuture; use futures::StreamExt as _; use tracing::debug; +use url::Url; pub use self::context::Context; pub use self::error::Error as ClientError; @@ -76,6 +77,7 @@ pub struct ClientBuilder { event_handlers: Vec>, raw_event_handlers: Vec>, presence: PresenceData, + ws_proxy: Option, } #[cfg(feature = "gateway")] @@ -110,6 +112,7 @@ impl ClientBuilder { event_handlers: vec![], raw_event_handlers: vec![], presence: PresenceData::default(), + ws_proxy: None, } } @@ -276,6 +279,18 @@ impl ClientBuilder { pub fn get_presence(&self) -> &PresenceData { &self.presence } + + /// Sets a http proxy for the websocket connection. + pub fn ws_proxy>(mut self, proxy: T) -> Self { + self.ws_proxy = Some(proxy.into()); + self + } + + /// Gets the websocket proxy. See [`Self::ws_proxy`] for more info. + #[must_use] + pub fn get_ws_proxy(&self) -> Option<&str> { + self.ws_proxy.as_deref() + } } #[cfg(feature = "gateway")] @@ -294,6 +309,7 @@ impl IntoFuture for ClientBuilder { let intents = self.intents; let presence = self.presence; let http = self.http; + let ws_proxy = self.ws_proxy; if let Some(ratelimiter) = &http.ratelimiter { let event_handlers_clone = event_handlers.clone(); @@ -325,6 +341,18 @@ impl IntoFuture for ClientBuilder { }, }; + let ws_proxy = match ws_proxy { + Some(proxy) => { + let parsed_proxy = Url::parse(&proxy).map_err(|why| { + tracing::warn!("Error building proxy URL with base `{}`: {:?}", proxy, why); + + Error::Gateway(GatewayError::BuildingUrl) + })?; + Some(Arc::new(parsed_proxy)) + }, + None => None, + }; + #[cfg(feature = "framework")] let framework_cell = Arc::new(OnceLock::new()); let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions { @@ -336,6 +364,7 @@ impl IntoFuture for ClientBuilder { #[cfg(feature = "voice")] voice_manager: voice_manager.clone(), ws_url: Arc::clone(&ws_url), + ws_proxy, shard_total, #[cfg(feature = "cache")] cache: Arc::clone(&cache), diff --git a/src/gateway/bridge/shard_manager.rs b/src/gateway/bridge/shard_manager.rs index 1d045b617a5..9966de02b51 100644 --- a/src/gateway/bridge/shard_manager.rs +++ b/src/gateway/bridge/shard_manager.rs @@ -10,6 +10,7 @@ use futures::{SinkExt, StreamExt}; use tokio::sync::Mutex; use tokio::time::timeout; use tracing::{info, warn}; +use url::Url; #[cfg(feature = "voice")] use super::VoiceGatewayManager; @@ -78,6 +79,7 @@ use crate::model::gateway::GatewayIntents; /// # #[cfg(feature = "voice")] /// # voice_manager: None, /// ws_url, +/// ws_proxy: None, /// shard_total, /// # #[cfg(feature = "cache")] /// # cache: unimplemented!(), @@ -141,6 +143,7 @@ impl ShardManager { #[cfg(feature = "voice")] voice_manager: opt.voice_manager, ws_url: opt.ws_url, + ws_proxy: opt.ws_proxy, shard_total: opt.shard_total, #[cfg(feature = "cache")] cache: opt.cache, @@ -366,6 +369,7 @@ pub struct ShardManagerOptions { #[cfg(feature = "voice")] pub voice_manager: Option>, pub ws_url: Arc, + pub ws_proxy: Option>, pub shard_total: NonZeroU16, #[cfg(feature = "cache")] pub cache: Arc, diff --git a/src/gateway/bridge/shard_queuer.rs b/src/gateway/bridge/shard_queuer.rs index 278bbb273c8..e7460f64205 100644 --- a/src/gateway/bridge/shard_queuer.rs +++ b/src/gateway/bridge/shard_queuer.rs @@ -9,6 +9,7 @@ use futures::StreamExt; use tokio::sync::Mutex; use tokio::time::{sleep, timeout, Duration, Instant}; use tracing::{debug, info, warn}; +use url::Url; #[cfg(feature = "voice")] use super::VoiceGatewayManager; @@ -71,6 +72,7 @@ pub struct ShardQueuer { pub voice_manager: Option>, /// A copy of the URL to use to connect to the gateway. pub ws_url: Arc, + pub ws_proxy: Option>, /// The total amount of shards to start. pub shard_total: NonZeroU16, #[cfg(feature = "cache")] @@ -214,6 +216,7 @@ impl ShardQueuer { async fn start(&mut self, shard_id: ShardId) -> Result<()> { let mut shard = Shard::new( Arc::clone(&self.ws_url), + self.ws_proxy.clone(), Arc::clone(self.http.token()), ShardInfo::new(shard_id, self.shard_total), self.intents, diff --git a/src/gateway/shard.rs b/src/gateway/shard.rs index e4c39339f13..6f17e554a1d 100644 --- a/src/gateway/shard.rs +++ b/src/gateway/shard.rs @@ -75,6 +75,7 @@ pub struct Shard { pub started: Instant, token: Secret, ws_url: Arc, + ws_proxy: Option>, pub intents: GatewayIntents, } @@ -108,7 +109,7 @@ impl Shard { /// /// // retrieve the gateway response, which contains the URL to connect to /// let gateway = Arc::from(http.get_gateway().await?.url); - /// let shard = Shard::new(gateway, token, shard_info, GatewayIntents::all(), None).await?; + /// let shard = Shard::new(gateway, None, token, shard_info, GatewayIntents::all(), None).await?; /// /// // at this point, you can create a `loop`, and receive events and match /// // their variants @@ -122,12 +123,13 @@ impl Shard { /// TLS error. pub async fn new( ws_url: Arc, + ws_proxy: Option>, token: Arc, shard_info: ShardInfo, intents: GatewayIntents, presence: Option, ) -> Result { - let client = connect(&ws_url).await?; + let client = connect(&ws_url, ws_proxy.as_deref()).await?; let presence = presence.unwrap_or_default(); let last_heartbeat_sent = None; @@ -153,6 +155,7 @@ impl Shard { session_id, shard_info, ws_url, + ws_proxy, intents, }) } @@ -703,7 +706,8 @@ impl Shard { // Hello is received. self.stage = ConnectionStage::Connecting; self.started = Instant::now(); - let client = connect(&self.ws_url).await?; + let proxy = self.ws_proxy.as_deref(); + let client = connect(&self.ws_url, proxy).await?; self.stage = ConnectionStage::Handshake; Ok(client) @@ -762,7 +766,7 @@ impl Shard { } } -async fn connect(base_url: &str) -> Result { +async fn connect(base_url: &str, proxy: Option<&Url>) -> Result { let url = Url::parse(&format!("{base_url}?v={}", constants::GATEWAY_VERSION)).map_err(|why| { warn!("Error building gateway URL with base `{}`: {:?}", base_url, why); @@ -770,5 +774,5 @@ async fn connect(base_url: &str) -> Result { Error::Gateway(GatewayError::BuildingUrl) })?; - WsClient::connect(url).await + WsClient::connect(url, proxy).await } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index d479d8b6962..eb9c8f92ba5 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,4 +1,5 @@ use std::env::consts; +use std::io::ErrorKind; #[cfg(feature = "client")] use std::io::Read; use std::time::SystemTime; @@ -19,11 +20,16 @@ use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; #[cfg(feature = "client")] use tokio_tungstenite::tungstenite::Error as WsError; use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream}; +use tokio_tungstenite::{ + client_async_tls_with_config, + connect_async_with_config, + MaybeTlsStream, + WebSocketStream, +}; #[cfg(feature = "client")] use tracing::warn; use tracing::{debug, trace}; -use url::Url; +use url::{Position, Url}; use super::{ActivityData, ChunkGuildFilter, PresenceData}; use crate::constants::{self, Opcode}; @@ -100,13 +106,49 @@ const TIMEOUT: Duration = Duration::from_millis(500); const DECOMPRESSION_MULTIPLIER: usize = 3; impl WsClient { - pub(crate) async fn connect(url: Url) -> Result { + async fn connect_with_proxy_async( + target_url: &Url, + proxy_url: &Url, + ) -> std::result::Result { + let proxy_addr = &proxy_url[Position::BeforeHost..Position::AfterPort]; + if proxy_url.scheme() != "http" && proxy_url.scheme() != "https" { + return Err(std::io::Error::new(ErrorKind::Unsupported, "unknown proxy scheme")); + } + + let host = target_url + .host_str() + .ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target host"))?; + let port = target_url + .port() + .or_else(|| match target_url.scheme() { + "wss" => Some(443), + "ws" => Some(80), + _ => None, + }) + .ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target scheme"))?; + let mut tcp_stream = TcpStream::connect(proxy_addr).await?; + + async_http_proxy::http_connect_tokio(&mut tcp_stream, host, port) + .await + .map_err(|_| std::io::Error::new(ErrorKind::Unsupported, "unsupported proxy"))?; + + Ok(tcp_stream) + } + + pub(crate) async fn connect(url: Url, proxy: Option<&Url>) -> Result { let config = WebSocketConfig { max_message_size: None, max_frame_size: None, ..Default::default() }; - let (stream, _) = connect_async_with_config(url, Some(config), false).await?; + let (stream, _) = match proxy { + None => connect_async_with_config(url, Some(config), false).await?, + Some(proxy) => { + let tls_stream = Self::connect_with_proxy_async(&url, proxy).await?; + tls_stream.set_nodelay(true)?; + client_async_tls_with_config(url, tls_stream, Some(config), None).await? + }, + }; Ok(Self(stream)) }