@@ -29,6 +29,7 @@ use futures_sink::Sink;
29
29
use serde:: { de:: DeserializeOwned , Deserialize } ;
30
30
use std:: {
31
31
env:: consts:: OS ,
32
+ error:: Error ,
32
33
fmt,
33
34
future:: Future ,
34
35
io,
@@ -39,7 +40,7 @@ use std::{
39
40
use tokio:: {
40
41
net:: TcpStream ,
41
42
sync:: oneshot,
42
- time:: { self , Duration , Instant , Interval , MissedTickBehavior } ,
43
+ time:: { self , error :: Elapsed , timeout , Duration , Instant , Interval , MissedTickBehavior } ,
43
44
} ;
44
45
use tokio_websockets:: { ClientBuilder , Error as WebsocketError , Limits , MaybeTlsStream } ;
45
46
use twilight_model:: gateway:: {
@@ -66,11 +67,44 @@ const COMPRESSION_FEATURES: &str = if cfg!(feature = "zstd") {
66
67
""
67
68
} ;
68
69
70
+ /// Timeout for connecting to the gateway.
71
+ const CONNECT_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
72
+
69
73
/// [`tokio_websockets`] library Websocket connection.
70
74
type Connection = tokio_websockets:: WebSocketStream < MaybeTlsStream < TcpStream > > ;
71
75
76
+ /// Wrapper enum around [`WebsocketError`] with a timeout case.
77
+ enum ConnectionError {
78
+ /// Connection attempt timed out.
79
+ Timeout ( Elapsed ) ,
80
+ /// Error from the websocket library, [`tokio_websockets`].
81
+ Websocket ( WebsocketError ) ,
82
+ }
83
+
84
+ impl ConnectionError {
85
+ /// Returns the boxed wrapped error.
86
+ fn into_boxed_error ( self ) -> Box < dyn Error + Send + Sync > {
87
+ match self {
88
+ Self :: Websocket ( e) => Box :: new ( e) ,
89
+ Self :: Timeout ( e) => Box :: new ( e) ,
90
+ }
91
+ }
92
+ }
93
+
94
+ impl From < WebsocketError > for ConnectionError {
95
+ fn from ( value : WebsocketError ) -> Self {
96
+ Self :: Websocket ( value)
97
+ }
98
+ }
99
+
100
+ impl From < Elapsed > for ConnectionError {
101
+ fn from ( value : Elapsed ) -> Self {
102
+ Self :: Timeout ( value)
103
+ }
104
+ }
105
+
72
106
/// Wrapper struct around an `async fn` with a `Debug` implementation.
73
- struct ConnectionFuture ( Pin < Box < dyn Future < Output = Result < Connection , WebsocketError > > + Send > > ) ;
107
+ struct ConnectionFuture ( Pin < Box < dyn Future < Output = Result < Connection , ConnectionError > > + Send > > ) ;
74
108
75
109
impl fmt:: Debug for ConnectionFuture {
76
110
fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
@@ -858,14 +892,17 @@ impl<Q: Queue + Unpin> Stream for Shard<Q> {
858
892
let secs = 2u8 . saturating_pow ( reconnect_attempts. into ( ) ) ;
859
893
time:: sleep ( Duration :: from_secs ( secs. into ( ) ) ) . await ;
860
894
861
- Ok ( ClientBuilder :: new ( )
862
- . uri ( & uri)
863
- . expect ( "URL should be valid" )
864
- . limits ( Limits :: unlimited ( ) )
865
- . connector ( & tls)
866
- . connect ( )
867
- . await ?
868
- . 0 )
895
+ Ok ( timeout (
896
+ CONNECT_TIMEOUT ,
897
+ ClientBuilder :: new ( )
898
+ . uri ( & uri)
899
+ . expect ( "URL should be valid" )
900
+ . limits ( Limits :: unlimited ( ) )
901
+ . connector ( & tls)
902
+ . connect ( ) ,
903
+ )
904
+ . await ??
905
+ . 0 )
869
906
} ) ) ) ;
870
907
}
871
908
@@ -893,7 +930,7 @@ impl<Q: Queue + Unpin> Stream for Shard<Q> {
893
930
894
931
return Poll :: Ready ( Some ( Err ( ReceiveMessageError {
895
932
kind : ReceiveMessageErrorType :: Reconnect ,
896
- source : Some ( Box :: new ( source) ) ,
933
+ source : Some ( source. into_boxed_error ( ) ) ,
897
934
} ) ) ) ;
898
935
}
899
936
}
0 commit comments