Skip to content

Commit 1d7dab8

Browse files
amtelekomEugeny
authored andcommitted
Additional Disconnect Handling
- Allow Server Handle to trigger Disconnect - Allow Client Handler to read Disconnect reason Rename Handler disconnect event Include all types of disconnect in disconnected callback Remove no longer valid comment
1 parent 1c43473 commit 1d7dab8

File tree

3 files changed

+151
-15
lines changed

3 files changed

+151
-15
lines changed

russh/src/client/mod.rs

Lines changed: 97 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
3737
use std::cell::RefCell;
3838
use std::collections::{HashMap, VecDeque};
39+
use std::convert::TryInto;
3940
use std::num::Wrapping;
4041
use std::pin::Pin;
4142
use std::sync::Arc;
@@ -49,7 +50,7 @@ use russh_keys::encoding::Reader;
4950
#[cfg(feature = "openssl")]
5051
use russh_keys::key::SignatureHash;
5152
use russh_keys::key::{self, parse_public_key, PublicKey};
52-
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
53+
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
5354
use tokio::net::{TcpStream, ToSocketAddrs};
5455
use tokio::pin;
5556
use tokio::sync::mpsc::{
@@ -195,6 +196,19 @@ pub struct Prompt {
195196
pub echo: bool,
196197
}
197198

199+
#[derive(Debug)]
200+
pub struct RemoteDisconnectInfo {
201+
pub reason_code: crate::Disconnect,
202+
pub message: String,
203+
pub lang_tag: String,
204+
}
205+
206+
#[derive(Debug)]
207+
pub enum DisconnectReason<E: From<crate::Error> + Send> {
208+
ReceivedDisconnect(RemoteDisconnectInfo),
209+
Error(E),
210+
}
211+
198212
/// Handle to a session, used to send messages to a client outside of
199213
/// the request/response cycle.
200214
pub struct Handle<H: Handler> {
@@ -737,23 +751,59 @@ impl Session {
737751

738752
async fn run<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
739753
mut self,
740-
mut stream: SshRead<R>,
754+
stream: SshRead<R>,
741755
mut handler: H,
742-
mut encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
756+
encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
743757
) -> Result<(), H::Error> {
758+
let (stream_read, mut stream_write) = stream.split();
759+
let result = self
760+
.run_inner(
761+
stream_read,
762+
&mut stream_write,
763+
&mut handler,
764+
encrypted_signal,
765+
)
766+
.await;
767+
trace!("disconnected");
768+
self.receiver.close();
769+
self.inbound_channel_receiver.close();
770+
stream_write.shutdown().await.map_err(crate::Error::from)?;
771+
match result {
772+
Ok(v) => {
773+
handler
774+
.disconnected(DisconnectReason::ReceivedDisconnect(v))
775+
.await?;
776+
Ok(())
777+
}
778+
Err(e) => {
779+
handler.disconnected(DisconnectReason::Error(e)).await?;
780+
//Err(e)
781+
Ok(())
782+
}
783+
}
784+
}
785+
786+
async fn run_inner<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
787+
&mut self,
788+
stream_read: SshRead<ReadHalf<R>>,
789+
stream_write: &mut WriteHalf<R>,
790+
handler: &mut H,
791+
mut encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
792+
) -> Result<RemoteDisconnectInfo, H::Error> {
793+
let mut result: Result<RemoteDisconnectInfo, H::Error> =
794+
Err(crate::Error::Disconnect.into());
744795
self.flush()?;
745796
if !self.common.write_buffer.buffer.is_empty() {
746797
debug!("writing {:?} bytes", self.common.write_buffer.buffer.len());
747-
stream
798+
stream_write
748799
.write_all(&self.common.write_buffer.buffer)
749800
.await
750801
.map_err(crate::Error::from)?;
751-
stream.flush().await.map_err(crate::Error::from)?;
802+
stream_write.flush().await.map_err(crate::Error::from)?;
752803
}
753804
self.common.write_buffer.buffer.clear();
754805
let mut decomp = CryptoVec::new();
755806

756-
let (stream_read, mut stream_write) = stream.split();
757807
let buffer = SSHBuffer::new();
758808

759809
// Allow handing out references to the cipher
@@ -805,10 +855,10 @@ impl Session {
805855
if !buf.is_empty() {
806856
#[allow(clippy::indexing_slicing)] // length checked
807857
if buf[0] == crate::msg::DISCONNECT {
808-
break;
858+
result = self.process_disconnect(buf);
809859
} else {
810860
self.common.received_data = true;
811-
reply( &mut self,&mut handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?;
861+
reply( self,handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?;
812862
}
813863
}
814864

@@ -906,12 +956,30 @@ impl Session {
906956
}
907957
}
908958
}
909-
debug!("disconnected");
910-
self.receiver.close();
911-
self.inbound_channel_receiver.close();
912-
stream_write.shutdown().await.map_err(crate::Error::from)?;
913959

914-
Ok(())
960+
result
961+
}
962+
963+
fn process_disconnect<E: From<crate::Error> + Send>(
964+
&mut self,
965+
buf: &[u8],
966+
) -> Result<RemoteDisconnectInfo, E> {
967+
self.common.disconnected = true;
968+
let mut reader = buf.reader(1);
969+
970+
let reason_code = reader.read_u32().map_err(crate::Error::from)?.try_into()?;
971+
let message = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?)
972+
.map_err(crate::Error::from)?
973+
.to_owned();
974+
let lang_tag = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?)
975+
.map_err(crate::Error::from)?
976+
.to_owned();
977+
978+
Ok(RemoteDisconnectInfo {
979+
reason_code,
980+
message,
981+
lang_tag,
982+
})
915983
}
916984

917985
fn handle_msg(&mut self, msg: Msg) -> Result<(), crate::Error> {
@@ -1360,7 +1428,7 @@ impl Default for Config {
13601428
13611429
#[async_trait]
13621430
pub trait Handler: Sized + Send {
1363-
type Error: From<crate::Error> + Send;
1431+
type Error: From<crate::Error> + Send + core::fmt::Debug;
13641432

13651433
/// Called when the server sends us an authentication banner. This
13661434
/// is usually meant to be shown to the user, see
@@ -1620,4 +1688,19 @@ pub trait Handler: Sized + Send {
16201688
debug!("openssh_ext_hostkeys_announced: {:?}", keys);
16211689
Ok(())
16221690
}
1691+
1692+
/// Called when the server sent a disconnect message
1693+
///
1694+
/// If reason is an Error, this function should re-return the error so the join can also evaluate it
1695+
#[allow(unused_variables)]
1696+
async fn disconnected(
1697+
&mut self,
1698+
reason: DisconnectReason<Self::Error>,
1699+
) -> Result<(), Self::Error> {
1700+
debug!("disconnected: {:?}", reason);
1701+
match reason {
1702+
DisconnectReason::ReceivedDisconnect(_) => Ok(()),
1703+
DisconnectReason::Error(e) => Err(e),
1704+
}
1705+
}
16231706
}

russh/src/lib.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@
9494
//! messages sent through a `server::Handle` are processed when there
9595
//! is no incoming packet to read.
9696
97-
use std::fmt::{Debug, Display, Formatter};
97+
use std::{
98+
convert::TryFrom,
99+
fmt::{Debug, Display, Formatter},
100+
};
98101

99102
use log::debug;
100103
use parsing::ChannelOpenConfirmation;
@@ -379,6 +382,31 @@ pub enum Disconnect {
379382
IllegalUserName = 15,
380383
}
381384

385+
impl TryFrom<u32> for Disconnect {
386+
type Error = crate::Error;
387+
388+
fn try_from(value: u32) -> Result<Self, Self::Error> {
389+
Ok(match value {
390+
1 => Self::HostNotAllowedToConnect,
391+
2 => Self::ProtocolError,
392+
3 => Self::KeyExchangeFailed,
393+
4 => Self::Reserved,
394+
5 => Self::MACError,
395+
6 => Self::CompressionError,
396+
7 => Self::ServiceNotAvailable,
397+
8 => Self::ProtocolVersionNotSupported,
398+
9 => Self::HostKeyNotVerifiable,
399+
10 => Self::ConnectionLost,
400+
11 => Self::ByApplication,
401+
12 => Self::TooManyConnections,
402+
13 => Self::AuthCancelledByUser,
403+
14 => Self::NoMoreAuthMethodsAvailable,
404+
15 => Self::IllegalUserName,
405+
_ => return Err(crate::Error::Inconsistent),
406+
})
407+
}
408+
}
409+
382410
/// The type of signals that can be sent to a remote process. If you
383411
/// plan to use custom signals, read [the
384412
/// RFC](https://tools.ietf.org/html/rfc4254#section-6.10) to

russh/src/server/session.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ pub enum Msg {
5959
address: String,
6060
port: u32,
6161
},
62+
Disconnect {
63+
reason: crate::Disconnect,
64+
description: String,
65+
language_tag: String,
66+
},
6267
Channel(ChannelId, ChannelMsg),
6368
}
6469

@@ -348,6 +353,23 @@ impl Handle {
348353
.await
349354
.map_err(|_| ())
350355
}
356+
357+
/// Allows a server to disconnect a client session
358+
pub async fn disconnect(
359+
&self,
360+
reason: Disconnect,
361+
description: String,
362+
language_tag: String,
363+
) -> Result<(), Error> {
364+
self.sender
365+
.send(Msg::Disconnect {
366+
reason,
367+
description,
368+
language_tag,
369+
})
370+
.await
371+
.map_err(|_| Error::SendError)
372+
}
351373
}
352374

353375
impl Session {
@@ -511,6 +533,9 @@ impl Session {
511533
Some(Msg::CancelTcpIpForward { address, port, reply_channel }) => {
512534
self.cancel_tcpip_forward(&address, port, reply_channel);
513535
}
536+
Some(Msg::Disconnect {reason, description, language_tag}) => {
537+
self.common.disconnect(reason, &description, &language_tag);
538+
}
514539
Some(_) => {
515540
// should be unreachable, since the receiver only gets
516541
// messages from methods implemented within russh

0 commit comments

Comments
 (0)