Skip to content

Commit

Permalink
Additional Disconnect Handling
Browse files Browse the repository at this point in the history
- Allow Server Handle to trigger Disconnect
- Allow Client Handler to read Disconnect reason

Rename Handler disconnect event

Include all types of disconnect in disconnected callback

Remove no longer valid comment
  • Loading branch information
amtelekom authored and Eugeny committed Feb 27, 2024
1 parent 1c43473 commit 1d7dab8
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 15 deletions.
111 changes: 97 additions & 14 deletions russh/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::convert::TryInto;
use std::num::Wrapping;
use std::pin::Pin;
use std::sync::Arc;
Expand All @@ -49,7 +50,7 @@ use russh_keys::encoding::Reader;
#[cfg(feature = "openssl")]
use russh_keys::key::SignatureHash;
use russh_keys::key::{self, parse_public_key, PublicKey};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::pin;
use tokio::sync::mpsc::{
Expand Down Expand Up @@ -195,6 +196,19 @@ pub struct Prompt {
pub echo: bool,
}

#[derive(Debug)]
pub struct RemoteDisconnectInfo {
pub reason_code: crate::Disconnect,
pub message: String,
pub lang_tag: String,
}

#[derive(Debug)]
pub enum DisconnectReason<E: From<crate::Error> + Send> {
ReceivedDisconnect(RemoteDisconnectInfo),
Error(E),
}

/// Handle to a session, used to send messages to a client outside of
/// the request/response cycle.
pub struct Handle<H: Handler> {
Expand Down Expand Up @@ -737,23 +751,59 @@ impl Session {

async fn run<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
mut self,
mut stream: SshRead<R>,
stream: SshRead<R>,
mut handler: H,
mut encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
) -> Result<(), H::Error> {
let (stream_read, mut stream_write) = stream.split();
let result = self
.run_inner(
stream_read,
&mut stream_write,
&mut handler,
encrypted_signal,
)
.await;
trace!("disconnected");
self.receiver.close();
self.inbound_channel_receiver.close();
stream_write.shutdown().await.map_err(crate::Error::from)?;
match result {
Ok(v) => {
handler
.disconnected(DisconnectReason::ReceivedDisconnect(v))
.await?;
Ok(())
}
Err(e) => {
handler.disconnected(DisconnectReason::Error(e)).await?;
//Err(e)
Ok(())
}
}
}

async fn run_inner<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
&mut self,
stream_read: SshRead<ReadHalf<R>>,
stream_write: &mut WriteHalf<R>,
handler: &mut H,
mut encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
) -> Result<RemoteDisconnectInfo, H::Error> {
let mut result: Result<RemoteDisconnectInfo, H::Error> =
Err(crate::Error::Disconnect.into());
self.flush()?;
if !self.common.write_buffer.buffer.is_empty() {
debug!("writing {:?} bytes", self.common.write_buffer.buffer.len());
stream
stream_write
.write_all(&self.common.write_buffer.buffer)
.await
.map_err(crate::Error::from)?;
stream.flush().await.map_err(crate::Error::from)?;
stream_write.flush().await.map_err(crate::Error::from)?;
}
self.common.write_buffer.buffer.clear();
let mut decomp = CryptoVec::new();

let (stream_read, mut stream_write) = stream.split();
let buffer = SSHBuffer::new();

// Allow handing out references to the cipher
Expand Down Expand Up @@ -805,10 +855,10 @@ impl Session {
if !buf.is_empty() {
#[allow(clippy::indexing_slicing)] // length checked
if buf[0] == crate::msg::DISCONNECT {
break;
result = self.process_disconnect(buf);
} else {
self.common.received_data = true;
reply( &mut self,&mut handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?;
reply( self,handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?;
}
}

Expand Down Expand Up @@ -906,12 +956,30 @@ impl Session {
}
}
}
debug!("disconnected");
self.receiver.close();
self.inbound_channel_receiver.close();
stream_write.shutdown().await.map_err(crate::Error::from)?;

Ok(())
result
}

fn process_disconnect<E: From<crate::Error> + Send>(
&mut self,
buf: &[u8],
) -> Result<RemoteDisconnectInfo, E> {
self.common.disconnected = true;
let mut reader = buf.reader(1);

let reason_code = reader.read_u32().map_err(crate::Error::from)?.try_into()?;
let message = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?)
.map_err(crate::Error::from)?
.to_owned();
let lang_tag = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?)
.map_err(crate::Error::from)?
.to_owned();

Ok(RemoteDisconnectInfo {
reason_code,
message,
lang_tag,
})
}

fn handle_msg(&mut self, msg: Msg) -> Result<(), crate::Error> {
Expand Down Expand Up @@ -1360,7 +1428,7 @@ impl Default for Config {

#[async_trait]
pub trait Handler: Sized + Send {
type Error: From<crate::Error> + Send;
type Error: From<crate::Error> + Send + core::fmt::Debug;

/// Called when the server sends us an authentication banner. This
/// is usually meant to be shown to the user, see
Expand Down Expand Up @@ -1620,4 +1688,19 @@ pub trait Handler: Sized + Send {
debug!("openssh_ext_hostkeys_announced: {:?}", keys);
Ok(())
}

/// Called when the server sent a disconnect message
///
/// If reason is an Error, this function should re-return the error so the join can also evaluate it
#[allow(unused_variables)]
async fn disconnected(
&mut self,
reason: DisconnectReason<Self::Error>,
) -> Result<(), Self::Error> {
debug!("disconnected: {:?}", reason);
match reason {
DisconnectReason::ReceivedDisconnect(_) => Ok(()),
DisconnectReason::Error(e) => Err(e),
}
}
}
30 changes: 29 additions & 1 deletion russh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@
//! messages sent through a `server::Handle` are processed when there
//! is no incoming packet to read.

use std::fmt::{Debug, Display, Formatter};
use std::{
convert::TryFrom,
fmt::{Debug, Display, Formatter},
};

use log::debug;
use parsing::ChannelOpenConfirmation;
Expand Down Expand Up @@ -379,6 +382,31 @@ pub enum Disconnect {
IllegalUserName = 15,
}

impl TryFrom<u32> for Disconnect {
type Error = crate::Error;

fn try_from(value: u32) -> Result<Self, Self::Error> {
Ok(match value {
1 => Self::HostNotAllowedToConnect,
2 => Self::ProtocolError,
3 => Self::KeyExchangeFailed,
4 => Self::Reserved,
5 => Self::MACError,
6 => Self::CompressionError,
7 => Self::ServiceNotAvailable,
8 => Self::ProtocolVersionNotSupported,
9 => Self::HostKeyNotVerifiable,
10 => Self::ConnectionLost,
11 => Self::ByApplication,
12 => Self::TooManyConnections,
13 => Self::AuthCancelledByUser,
14 => Self::NoMoreAuthMethodsAvailable,
15 => Self::IllegalUserName,
_ => return Err(crate::Error::Inconsistent),
})
}
}

/// The type of signals that can be sent to a remote process. If you
/// plan to use custom signals, read [the
/// RFC](https://tools.ietf.org/html/rfc4254#section-6.10) to
Expand Down
25 changes: 25 additions & 0 deletions russh/src/server/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ pub enum Msg {
address: String,
port: u32,
},
Disconnect {
reason: crate::Disconnect,
description: String,
language_tag: String,
},
Channel(ChannelId, ChannelMsg),
}

Expand Down Expand Up @@ -348,6 +353,23 @@ impl Handle {
.await
.map_err(|_| ())
}

/// Allows a server to disconnect a client session
pub async fn disconnect(
&self,
reason: Disconnect,
description: String,
language_tag: String,
) -> Result<(), Error> {
self.sender
.send(Msg::Disconnect {
reason,
description,
language_tag,
})
.await
.map_err(|_| Error::SendError)
}
}

impl Session {
Expand Down Expand Up @@ -511,6 +533,9 @@ impl Session {
Some(Msg::CancelTcpIpForward { address, port, reply_channel }) => {
self.cancel_tcpip_forward(&address, port, reply_channel);
}
Some(Msg::Disconnect {reason, description, language_tag}) => {
self.common.disconnect(reason, &description, &language_tag);
}
Some(_) => {
// should be unreachable, since the receiver only gets
// messages from methods implemented within russh
Expand Down

0 comments on commit 1d7dab8

Please sign in to comment.