Skip to content

Commit 7f08d95

Browse files
Gelbpunktvilgotf
andauthored
fix(gateway): Add a connection timeout of 10s (#2448)
Fixes an issue where shards become stuck awaiting a connection attempt that never completes. Signed-off-by: Jens Reidel <adrian@travitia.xyz> Co-authored-by: Tim Vilgot Mikael Fredenberg <26655508+vilgotf@users.noreply.github.com>
1 parent 4aa06ca commit 7f08d95

File tree

1 file changed

+48
-11
lines changed

1 file changed

+48
-11
lines changed

twilight-gateway/src/shard.rs

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use futures_sink::Sink;
2929
use serde::{de::DeserializeOwned, Deserialize};
3030
use std::{
3131
env::consts::OS,
32+
error::Error,
3233
fmt,
3334
future::Future,
3435
io,
@@ -39,7 +40,7 @@ use std::{
3940
use tokio::{
4041
net::TcpStream,
4142
sync::oneshot,
42-
time::{self, Duration, Instant, Interval, MissedTickBehavior},
43+
time::{self, error::Elapsed, timeout, Duration, Instant, Interval, MissedTickBehavior},
4344
};
4445
use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
4546
use twilight_model::gateway::{
@@ -66,11 +67,44 @@ const COMPRESSION_FEATURES: &str = if cfg!(feature = "zstd") {
6667
""
6768
};
6869

70+
/// Timeout for connecting to the gateway.
71+
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
72+
6973
/// [`tokio_websockets`] library Websocket connection.
7074
type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
7175

76+
/// Wrapper enum around [`WebsocketError`] with a timeout case.
77+
enum ConnectionError {
78+
/// Connection attempt timed out.
79+
Timeout(Elapsed),
80+
/// Error from the websocket library, [`tokio_websockets`].
81+
Websocket(WebsocketError),
82+
}
83+
84+
impl ConnectionError {
85+
/// Returns the boxed wrapped error.
86+
fn into_boxed_error(self) -> Box<dyn Error + Send + Sync> {
87+
match self {
88+
Self::Websocket(e) => Box::new(e),
89+
Self::Timeout(e) => Box::new(e),
90+
}
91+
}
92+
}
93+
94+
impl From<WebsocketError> for ConnectionError {
95+
fn from(value: WebsocketError) -> Self {
96+
Self::Websocket(value)
97+
}
98+
}
99+
100+
impl From<Elapsed> for ConnectionError {
101+
fn from(value: Elapsed) -> Self {
102+
Self::Timeout(value)
103+
}
104+
}
105+
72106
/// Wrapper struct around an `async fn` with a `Debug` implementation.
73-
struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, WebsocketError>> + Send>>);
107+
struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, ConnectionError>> + Send>>);
74108

75109
impl fmt::Debug for ConnectionFuture {
76110
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -858,14 +892,17 @@ impl<Q: Queue + Unpin> Stream for Shard<Q> {
858892
let secs = 2u8.saturating_pow(reconnect_attempts.into());
859893
time::sleep(Duration::from_secs(secs.into())).await;
860894

861-
Ok(ClientBuilder::new()
862-
.uri(&uri)
863-
.expect("URL should be valid")
864-
.limits(Limits::unlimited())
865-
.connector(&tls)
866-
.connect()
867-
.await?
868-
.0)
895+
Ok(timeout(
896+
CONNECT_TIMEOUT,
897+
ClientBuilder::new()
898+
.uri(&uri)
899+
.expect("URL should be valid")
900+
.limits(Limits::unlimited())
901+
.connector(&tls)
902+
.connect(),
903+
)
904+
.await??
905+
.0)
869906
})));
870907
}
871908

@@ -893,7 +930,7 @@ impl<Q: Queue + Unpin> Stream for Shard<Q> {
893930

894931
return Poll::Ready(Some(Err(ReceiveMessageError {
895932
kind: ReceiveMessageErrorType::Reconnect,
896-
source: Some(Box::new(source)),
933+
source: Some(source.into_boxed_error()),
897934
})));
898935
}
899936
}

0 commit comments

Comments
 (0)