Skip to content

Commit c99f49c

Browse files
committed
fixed Error::Disconnect getting returned from connect instead of the more specific error type when connection fails during kex phase
1 parent 8b88465 commit c99f49c

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

russh/src/client/mod.rs

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -778,10 +778,12 @@ where
778778
session_sender,
779779
);
780780
session.read_ssh_id(sshid)?;
781-
let (encrypted_signal, encrypted_recv) = tokio::sync::oneshot::channel();
782-
let join = tokio::spawn(session.run(stream, handler, Some(encrypted_signal)));
781+
let (kex_done_signal, kex_done_signal_rx) = oneshot::channel();
782+
let join = tokio::spawn(session.run(stream, handler, Some(kex_done_signal)));
783783

784-
if encrypted_recv.await.is_err() {
784+
if kex_done_signal_rx.await.is_err() {
785+
// kex_done_signal Sender is dropped when the session
786+
// fails before a succesful key exchange
785787
join.await.map_err(crate::Error::Join)??;
786788
return Err(H::Error::from(crate::Error::Disconnect));
787789
}
@@ -829,15 +831,15 @@ impl Session {
829831
mut self,
830832
stream: SshRead<R>,
831833
mut handler: H,
832-
encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
834+
mut kex_done_signal: Option<oneshot::Sender<()>>,
833835
) -> Result<(), H::Error> {
834836
let (stream_read, mut stream_write) = stream.split();
835837
let result = self
836838
.run_inner(
837839
stream_read,
838840
&mut stream_write,
839841
&mut handler,
840-
encrypted_signal,
842+
&mut kex_done_signal,
841843
)
842844
.await;
843845
trace!("disconnected");
@@ -852,9 +854,18 @@ impl Session {
852854
Ok(())
853855
}
854856
Err(e) => {
855-
handler.disconnected(DisconnectReason::Error(e)).await?;
856-
//Err(e)
857-
Ok(())
857+
if kex_done_signal.is_some() {
858+
// The kex signal has not been consumed yet,
859+
// so we can send return the concrete error to be propagated
860+
// into the JoinHandle and returned from `connect_stream`
861+
Err(e)
862+
} else {
863+
// The kex signal has been consumed, so no one is
864+
// awaiting the result of this coroutine
865+
// We're better off passing the error into the Handler
866+
handler.disconnected(DisconnectReason::Error(e)).await?;
867+
Err(H::Error::from(crate::Error::Disconnect))
868+
}
858869
}
859870
}
860871
}
@@ -864,7 +875,7 @@ impl Session {
864875
stream_read: SshRead<ReadHalf<R>>,
865876
stream_write: &mut WriteHalf<R>,
866877
handler: &mut H,
867-
mut encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
878+
kex_done_signal: &mut Option<tokio::sync::oneshot::Sender<()>>,
868879
) -> Result<RemoteDisconnectInfo, H::Error> {
869880
let mut result: Result<RemoteDisconnectInfo, H::Error> =
870881
Err(crate::Error::Disconnect.into());
@@ -934,7 +945,7 @@ impl Session {
934945
result = self.process_disconnect(buf);
935946
} else {
936947
self.common.received_data = true;
937-
reply( self,handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?;
948+
reply(self, handler, kex_done_signal, &mut buffer.seqn, buf).await?;
938949
}
939950
}
940951

@@ -1349,7 +1360,7 @@ impl KexDhDone {
13491360
async fn reply<H: Handler>(
13501361
session: &mut Session,
13511362
handler: &mut H,
1352-
sender: &mut Option<tokio::sync::oneshot::Sender<()>>,
1363+
kex_done_signal: &mut Option<tokio::sync::oneshot::Sender<()>>,
13531364
seqn: &mut Wrapping<u32>,
13541365
buf: &[u8],
13551366
) -> Result<(), H::Error> {
@@ -1392,7 +1403,7 @@ async fn reply<H: Handler>(
13921403
done.compute_keys(CryptoVec::new(), false)?,
13931404
);
13941405

1395-
if let Some(sender) = sender.take() {
1406+
if let Some(sender) = kex_done_signal.take() {
13961407
sender.send(()).unwrap_or(());
13971408
}
13981409
} else {
@@ -1430,7 +1441,7 @@ async fn reply<H: Handler>(
14301441
if buf.first() != Some(&msg::NEWKEYS) {
14311442
return Err(crate::Error::Kex.into());
14321443
}
1433-
if let Some(sender) = sender.take() {
1444+
if let Some(sender) = kex_done_signal.take() {
14341445
sender.send(()).unwrap_or(());
14351446
}
14361447
session

0 commit comments

Comments
 (0)