Skip to content

Commit

Permalink
Add ws_proxy option for connecting gateway through a proxy.
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzzRuby committed Mar 19, 2024
1 parent 3ab58dc commit b6226fb
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ nonmax = { version = "0.5.5", features = ["serde"] }
strum = { version = "0.26", features = ["derive"] }
to-arraystring = "0.1.0"
extract_map = { version = "0.1.0", features = ["serde", "iter_mut"] }
async-http-proxy = { version = "1.2.5", features = ["runtime-tokio"] }
# Optional dependencies
fxhash = { version = "0.2.1", optional = true }
chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"], optional = true }
Expand Down
29 changes: 29 additions & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use futures::channel::mpsc::UnboundedReceiver as Receiver;
use futures::future::BoxFuture;
use futures::StreamExt as _;
use tracing::debug;
use url::Url;

pub use self::context::Context;
pub use self::error::Error as ClientError;
Expand Down Expand Up @@ -76,6 +77,7 @@ pub struct ClientBuilder {
event_handlers: Vec<Arc<dyn EventHandler>>,
raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
presence: PresenceData,
ws_proxy: Option<String>,
}

#[cfg(feature = "gateway")]
Expand Down Expand Up @@ -110,6 +112,7 @@ impl ClientBuilder {
event_handlers: vec![],
raw_event_handlers: vec![],
presence: PresenceData::default(),
ws_proxy: None,
}
}

Expand Down Expand Up @@ -276,6 +279,18 @@ impl ClientBuilder {
pub fn get_presence(&self) -> &PresenceData {
&self.presence
}

/// Sets a http proxy for the websocket connection.
pub fn ws_proxy<T: Into<String>>(mut self, proxy: T) -> Self {
self.ws_proxy = Some(proxy.into());
self
}

/// Gets the websocket proxy. See [`Self::ws_proxy`] for more info.
#[must_use]
pub fn get_ws_proxy(&self) -> Option<&str> {
self.ws_proxy.as_deref()
}
}

#[cfg(feature = "gateway")]
Expand All @@ -294,6 +309,7 @@ impl IntoFuture for ClientBuilder {
let intents = self.intents;
let presence = self.presence;
let http = self.http;
let ws_proxy = self.ws_proxy;

if let Some(ratelimiter) = &http.ratelimiter {
let event_handlers_clone = event_handlers.clone();
Expand Down Expand Up @@ -325,6 +341,18 @@ impl IntoFuture for ClientBuilder {
},
};

let ws_proxy = match ws_proxy {
Some(proxy) => {
let parsed_proxy = Url::parse(&proxy).map_err(|why| {
tracing::warn!("Error building proxy URL with base `{}`: {:?}", proxy, why);

Error::Gateway(GatewayError::BuildingUrl)
})?;
Some(Arc::new(parsed_proxy))
},
None => None,
};

#[cfg(feature = "framework")]
let framework_cell = Arc::new(OnceLock::new());
let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions {
Expand All @@ -336,6 +364,7 @@ impl IntoFuture for ClientBuilder {
#[cfg(feature = "voice")]
voice_manager: voice_manager.clone(),
ws_url: Arc::clone(&ws_url),
ws_proxy,
shard_total,
#[cfg(feature = "cache")]
cache: Arc::clone(&cache),
Expand Down
4 changes: 4 additions & 0 deletions src/gateway/bridge/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use futures::{SinkExt, StreamExt};
use tokio::sync::Mutex;
use tokio::time::timeout;
use tracing::{info, warn};
use url::Url;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
Expand Down Expand Up @@ -78,6 +79,7 @@ use crate::model::gateway::GatewayIntents;
/// # #[cfg(feature = "voice")]
/// # voice_manager: None,
/// ws_url,
/// ws_proxy: None,
/// shard_total,
/// # #[cfg(feature = "cache")]
/// # cache: unimplemented!(),
Expand Down Expand Up @@ -141,6 +143,7 @@ impl ShardManager {
#[cfg(feature = "voice")]
voice_manager: opt.voice_manager,
ws_url: opt.ws_url,
ws_proxy: opt.ws_proxy,
shard_total: opt.shard_total,
#[cfg(feature = "cache")]
cache: opt.cache,
Expand Down Expand Up @@ -366,6 +369,7 @@ pub struct ShardManagerOptions {
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
pub ws_url: Arc<str>,
pub ws_proxy: Option<Arc<Url>>,
pub shard_total: NonZeroU16,
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
Expand Down
3 changes: 3 additions & 0 deletions src/gateway/bridge/shard_queuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use futures::StreamExt;
use tokio::sync::Mutex;
use tokio::time::{sleep, timeout, Duration, Instant};
use tracing::{debug, info, warn};
use url::Url;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
Expand Down Expand Up @@ -71,6 +72,7 @@ pub struct ShardQueuer {
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
/// A copy of the URL to use to connect to the gateway.
pub ws_url: Arc<str>,
pub ws_proxy: Option<Arc<Url>>,
/// The total amount of shards to start.
pub shard_total: NonZeroU16,
#[cfg(feature = "cache")]
Expand Down Expand Up @@ -214,6 +216,7 @@ impl ShardQueuer {
async fn start(&mut self, shard_id: ShardId) -> Result<()> {
let mut shard = Shard::new(
Arc::clone(&self.ws_url),
self.ws_proxy.clone(),
Arc::clone(self.http.token()),
ShardInfo::new(shard_id, self.shard_total),
self.intents,
Expand Down
14 changes: 9 additions & 5 deletions src/gateway/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub struct Shard {
pub started: Instant,
token: Secret<Token>,
ws_url: Arc<str>,
ws_proxy: Option<Arc<Url>>,
pub intents: GatewayIntents,
}

Expand Down Expand Up @@ -108,7 +109,7 @@ impl Shard {
///
/// // retrieve the gateway response, which contains the URL to connect to
/// let gateway = Arc::from(http.get_gateway().await?.url);
/// let shard = Shard::new(gateway, token, shard_info, GatewayIntents::all(), None).await?;
/// let shard = Shard::new(gateway, None, token, shard_info, GatewayIntents::all(), None).await?;
///
/// // at this point, you can create a `loop`, and receive events and match
/// // their variants
Expand All @@ -122,12 +123,13 @@ impl Shard {
/// TLS error.
pub async fn new(
ws_url: Arc<str>,
ws_proxy: Option<Arc<Url>>,
token: Arc<str>,
shard_info: ShardInfo,
intents: GatewayIntents,
presence: Option<PresenceData>,
) -> Result<Shard> {
let client = connect(&ws_url).await?;
let client = connect(&ws_url, ws_proxy.as_deref()).await?;

let presence = presence.unwrap_or_default();
let last_heartbeat_sent = None;
Expand All @@ -153,6 +155,7 @@ impl Shard {
session_id,
shard_info,
ws_url,
ws_proxy,
intents,
})
}
Expand Down Expand Up @@ -703,7 +706,8 @@ impl Shard {
// Hello is received.
self.stage = ConnectionStage::Connecting;
self.started = Instant::now();
let client = connect(&self.ws_url).await?;
let proxy = self.ws_proxy.as_deref();
let client = connect(&self.ws_url, proxy).await?;
self.stage = ConnectionStage::Handshake;

Ok(client)
Expand Down Expand Up @@ -762,13 +766,13 @@ impl Shard {
}
}

async fn connect(base_url: &str) -> Result<WsClient> {
async fn connect(base_url: &str, proxy: Option<&Url>) -> Result<WsClient> {
let url =
Url::parse(&format!("{base_url}?v={}", constants::GATEWAY_VERSION)).map_err(|why| {
warn!("Error building gateway URL with base `{}`: {:?}", base_url, why);

Error::Gateway(GatewayError::BuildingUrl)
})?;

WsClient::connect(url).await
WsClient::connect(url, proxy).await
}
50 changes: 46 additions & 4 deletions src/gateway/ws.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::env::consts;
use std::io::ErrorKind;
#[cfg(feature = "client")]
use std::io::Read;
use std::time::SystemTime;
Expand All @@ -19,11 +20,16 @@ use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
#[cfg(feature = "client")]
use tokio_tungstenite::tungstenite::Error as WsError;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
use tokio_tungstenite::{
client_async_tls_with_config,
connect_async_with_config,
MaybeTlsStream,
WebSocketStream,
};
#[cfg(feature = "client")]
use tracing::warn;
use tracing::{debug, trace};
use url::Url;
use url::{Position, Url};

use super::{ActivityData, ChunkGuildFilter, PresenceData};
use crate::constants::{self, Opcode};
Expand Down Expand Up @@ -100,13 +106,49 @@ const TIMEOUT: Duration = Duration::from_millis(500);
const DECOMPRESSION_MULTIPLIER: usize = 3;

impl WsClient {
pub(crate) async fn connect(url: Url) -> Result<Self> {
async fn connect_with_proxy_async(
target_url: &Url,
proxy_url: &Url,
) -> std::result::Result<TcpStream, std::io::Error> {
let proxy_addr = &proxy_url[Position::BeforeHost..Position::AfterPort];
if proxy_url.scheme() != "http" && proxy_url.scheme() != "https" {
return Err(std::io::Error::new(ErrorKind::Unsupported, "unknown proxy scheme"));
}

let host = target_url
.host_str()
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target host"))?;
let port = target_url
.port()
.or_else(|| match target_url.scheme() {
"wss" => Some(443),
"ws" => Some(80),
_ => None,
})
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target scheme"))?;
let mut tcp_stream = TcpStream::connect(proxy_addr).await?;

async_http_proxy::http_connect_tokio(&mut tcp_stream, host, port)
.await
.map_err(|_| std::io::Error::new(ErrorKind::Unsupported, "unsupported proxy"))?;

Ok(tcp_stream)
}

pub(crate) async fn connect(url: Url, proxy: Option<&Url>) -> Result<Self> {
let config = WebSocketConfig {
max_message_size: None,
max_frame_size: None,
..Default::default()
};
let (stream, _) = connect_async_with_config(url, Some(config), false).await?;
let (stream, _) = match proxy {
None => connect_async_with_config(url, Some(config), false).await?,
Some(proxy) => {
let tls_stream = Self::connect_with_proxy_async(&url, proxy).await?;
tls_stream.set_nodelay(true)?;
client_async_tls_with_config(url, tls_stream, Some(config), None).await?
},
};

Ok(Self(stream))
}
Expand Down

0 comments on commit b6226fb

Please sign in to comment.