Skip to content

Commit

Permalink
Add ws_proxy option for connecting gateway through a proxy.
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzzRuby committed Jan 20, 2024
1 parent d931786 commit 15d9f25
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 55 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ arrayvec = { version = "0.7.4", features = ["serde"] }
small-fixed-array = { version = "0.2", features = ["serde"] }
bool_to_bitflags = { version = "0.1.0" }
nonmax = { version = "0.5.5", features = ["serde"] }
async-http-proxy = { version = "1.2.5", features = ["runtime-tokio"] }
# Optional dependencies
fxhash = { version = "0.2.1", optional = true }
simd-json = { version = "0.13.4", optional = true }
Expand Down
22 changes: 16 additions & 6 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ use futures::channel::mpsc::UnboundedReceiver as Receiver;
use futures::future::BoxFuture;
use futures::StreamExt as _;
use tracing::debug;
<<<<<<< HEAD
=======
use url::Url;
>>>>>>> 782a2539cf (Add ws_proxy option for connecting gateway through a proxy.)

pub use self::context::Context;
pub use self::error::Error as ClientError;
Expand Down Expand Up @@ -283,12 +287,6 @@ impl ClientBuilder {
self
}

/// Remove websocket proxy.
pub fn no_ws_proxy(mut self) -> Self {
self.ws_proxy = None;
self
}

/// Gets the websocket proxy. See [`Self::ws_proxy`] for more info.
pub fn get_ws_proxy(&self) -> Option<&str> {
self.ws_proxy.as_deref()
Expand Down Expand Up @@ -346,6 +344,18 @@ impl IntoFuture for ClientBuilder {
},
};

let ws_proxy = match ws_proxy {
Some(proxy) => {
let parsed_proxy = Url::parse(&proxy).map_err(|why| {
tracing::warn!("Error building proxy URL with base `{}`: {:?}", proxy, why);

Error::Gateway(GatewayError::BuildingUrl)
})?;
Some(Arc::new(parsed_proxy))
}
None => None
};

#[cfg(feature = "framework")]
let framework_cell = Arc::new(OnceLock::new());
let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions {
Expand Down
3 changes: 2 additions & 1 deletion src/gateway/bridge/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use futures::{SinkExt, StreamExt};
use tokio::sync::Mutex;
use tokio::time::timeout;
use tracing::{info, warn};
use url::Url;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
Expand Down Expand Up @@ -371,7 +372,7 @@ pub struct ShardManagerOptions {
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
pub ws_url: Arc<str>,
pub ws_proxy: Option<String>,
pub ws_proxy: Option<Arc<Url>>,
pub shard_total: NonZeroU16,
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
Expand Down
5 changes: 3 additions & 2 deletions src/gateway/bridge/shard_queuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use futures::StreamExt;
use tokio::sync::Mutex;
use tokio::time::{sleep, timeout, Duration, Instant};
use tracing::{debug, info, warn};
use url::Url;

#[cfg(feature = "voice")]
use super::VoiceGatewayManager;
Expand Down Expand Up @@ -71,7 +72,7 @@ pub struct ShardQueuer {
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
/// A copy of the URL to use to connect to the gateway.
pub ws_url: Arc<str>,
pub ws_proxy: Option<String>,
pub ws_proxy: Option<Arc<Url>>,
/// The total amount of shards to start.
pub shard_total: NonZeroU16,
#[cfg(feature = "cache")]
Expand Down Expand Up @@ -215,7 +216,7 @@ impl ShardQueuer {
async fn start(&mut self, shard_id: ShardId) -> Result<()> {
let mut shard = Shard::new(
Arc::clone(&self.ws_url),
self.ws_proxy.as_deref(),
self.ws_proxy.clone(),
self.http.token(),
ShardInfo::new(shard_id, self.shard_total),
self.intents,
Expand Down
21 changes: 4 additions & 17 deletions src/gateway/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub struct Shard {
pub started: Instant,
pub token: String,
ws_url: Arc<str>,
ws_proxy: Option<String>,
ws_proxy: Option<Arc<Url>>,
pub intents: GatewayIntents,
}

Expand Down Expand Up @@ -121,13 +121,13 @@ impl Shard {
/// TLS error.
pub async fn new(
ws_url: Arc<str>,
ws_proxy: Option<&str>,
ws_proxy: Option<Arc<Url>>,
token: &str,
shard_info: ShardInfo,
intents: GatewayIntents,
presence: Option<PresenceData>,
) -> Result<Shard> {
let client = connect(&ws_url, ws_proxy).await?;
let client = connect(&ws_url, ws_proxy.as_deref()).await?;

let presence = presence.unwrap_or_default();
let last_heartbeat_sent = None;
Expand Down Expand Up @@ -745,26 +745,13 @@ impl Shard {
}
}

async fn connect(base_url: &str, proxy: Option<&str>) -> Result<WsClient> {
async fn connect(base_url: &str, proxy: Option<&Url>) -> Result<WsClient> {
let url =
Url::parse(&format!("{base_url}?v={}", constants::GATEWAY_VERSION)).map_err(|why| {
warn!("Error building gateway URL with base `{}`: {:?}", base_url, why);

Error::Gateway(GatewayError::BuildingUrl)
})?;

let proxy_url = proxy.map(|proxy| {
Url::parse(proxy).map_err(|why| {
warn!("Error building proxy URL with base `{}`: {:?}", base_url, why);

Error::Gateway(GatewayError::BuildingUrl)
})
});

let proxy = match proxy_url {
Some(result) => Some(result?),
None => None,
};

WsClient::connect(url, proxy).await
}
35 changes: 6 additions & 29 deletions src/gateway/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,37 +130,14 @@ impl WsClient {
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target scheme"))?;
let mut tcp_stream = TcpStream::connect(proxy_addr).await?;

let buf =
format!("CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n\r\n").into_bytes();

tcp_stream.write_all(&buf).await?;

let mut all_buf = Vec::new();

loop {
let mut buf = [0; 1024];
let n = tcp_stream.read(&mut buf).await?;
if n == 0 {
return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "no bytes in tunnel"));
}
all_buf.extend_from_slice(&buf[..n]);

if !all_buf.starts_with(b"HTTP/1.1 200") && !all_buf.starts_with(b"HTTP/1.0 200") {
return Err(std::io::Error::new(ErrorKind::Other, "tunnel error"));
}
if all_buf.ends_with(b"\r\n\r\n") {
return Ok(tcp_stream);
}
if all_buf.len() > 4096 {
return Err(std::io::Error::new(
ErrorKind::UnexpectedEof,
"too many bytes in tunnel",
));
}
}
async_http_proxy::http_connect_tokio(&mut tcp_stream, host, port)
.await
.map_err(|_| std::io::Error::new(ErrorKind::Unsupported, "unsupported proxy"))?;

Ok(tcp_stream)
}

pub(crate) async fn connect(url: Url, proxy: Option<Url>) -> Result<Self> {
pub(crate) async fn connect(url: Url, proxy: Option<&Url>) -> Result<Self> {
let config = WebSocketConfig {
max_message_size: None,
max_frame_size: None,
Expand Down

0 comments on commit 15d9f25

Please sign in to comment.