Skip to content

Commit

Permalink
fixed #100 - allow overriding Handler methods without losing Channel …
Browse files Browse the repository at this point in the history
…functionality
  • Loading branch information
Eugeny committed Jun 4, 2023
1 parent 30c401e commit 359fa3c
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 200 deletions.
99 changes: 99 additions & 0 deletions russh/examples/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use async_trait::async_trait;
use env_logger;
use log::debug;
use russh::server::{Auth, Msg, Session};
use russh::*;
use russh_keys::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init();
let mut config = russh::server::Config::default();
config.auth_rejection_time = std::time::Duration::from_secs(3);
config
.keys
.push(russh_keys::key::KeyPair::generate_ed25519().unwrap());
let config = Arc::new(config);
let sh = Server {
clients: Arc::new(Mutex::new(HashMap::new())),
id: 0,
};
tokio::time::timeout(
std::time::Duration::from_secs(60),
russh::server::run(config, ("0.0.0.0", 2222), sh),
)
.await
.unwrap_or(Ok(()))?;

Ok(())
}

#[derive(Clone)]
struct Server {
clients: Arc<Mutex<HashMap<(usize, ChannelId), Channel<Msg>>>>,
id: usize,
}

impl server::Server for Server {
type Handler = Self;
fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
debug!("new client");
let s = self.clone();
self.id += 1;
s
}
}

#[async_trait]
impl server::Handler for Server {
type Error = anyhow::Error;

async fn channel_open_session(
self,
channel: Channel<Msg>,
session: Session,
) -> Result<(Self, bool, Session), Self::Error> {
{
debug!("channel open session");
let mut clients = self.clients.lock().unwrap();
clients.insert((self.id, channel.id()), channel);
}
Ok((self, true, session))
}

/// The client requests a shell.
#[allow(unused_variables)]
async fn shell_request(
self,
channel: ChannelId,
mut session: Session,
) -> Result<(Self, Session), Self::Error> {
session.request_success();
Ok((self, session))
}

async fn auth_publickey(
self,
_: &str,
_: &key::PublicKey,
) -> Result<(Self, Auth), Self::Error> {
Ok((self, server::Auth::Accept))
}
async fn data(
self,
_channel: ChannelId,
data: &[u8],
mut session: Session,
) -> Result<(Self, Session), Self::Error> {
debug!("data: {data:?}");
{
let mut clients = self.clients.lock().unwrap();
for ((_, _channel_id), ref mut channel) in clients.iter_mut() {
session.data(channel.id(), CryptoVec::from(data.to_vec()));
}
}
Ok((self, session))
}
}
97 changes: 80 additions & 17 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@
use std::cell::RefCell;
use std::convert::TryInto;

use log::{debug, error, info, trace, warn};
use russh_cryptovec::CryptoVec;
use russh_keys::encoding::{Encoding, Reader};
use russh_keys::key::parse_public_key;
use tokio::sync::mpsc::unbounded_channel;
use log::{debug, error, info, trace, warn};

use crate::client::{Handler, Msg, Prompt, Reply, Session};
use crate::key::PubKey;
use crate::negotiation::{Named, Select};
use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage};
use crate::session::{Encrypted, EncryptedState, Kex, KexInit};
use crate::{auth, msg, negotiation, Channel, ChannelId, ChannelOpenFailure, ChannelParams, Sig};
use crate::{
auth, msg, negotiation, Channel, ChannelId, ChannelMsg, ChannelOpenFailure, ChannelParams, Sig,
};

thread_local! {
static SIGNATURE_BUFFER: RefCell<CryptoVec> = RefCell::new(CryptoVec::new());
Expand Down Expand Up @@ -184,7 +186,6 @@ impl Session {
current: None,
rejection_count: 0,
},

};
let len = enc.write.len();
#[allow(clippy::indexing_slicing)] // length checked
Expand Down Expand Up @@ -246,15 +247,14 @@ impl Session {
if no_more_methods {
return Err(crate::Error::NoAuthMethod.into());
}

} else if buf.first() == Some(&msg::USERAUTH_INFO_REQUEST_OR_USERAUTH_PK_OK) {
if let Some(auth::CurrentRequest::PublicKey {
ref mut sent_pk_ok, ..
}) = auth_request.current
{
debug!("userauth_pk_ok");
*sent_pk_ok = true;
} else if let Some(auth::CurrentRequest::KeyboardInteractive { .. }) =
} else if let Some(auth::CurrentRequest::KeyboardInteractive { .. }) =
auth_request.current
{
debug!("keyboard_interactive");
Expand Down Expand Up @@ -307,7 +307,8 @@ impl Session {
// write responses
enc.client_send_auth_response(&responses)?;
return Ok((client, self));
} else {}
} else {
}

// continue with userauth_pk_ok
match self.common.auth_method.take() {
Expand Down Expand Up @@ -396,6 +397,18 @@ impl Session {
return Err(crate::Error::Inconsistent.into());
};

if let Some(channel) = self.channels.get(&local_id) {
channel
.send(ChannelMsg::Open {
id: local_id,
max_packet_size: msg.maximum_packet_size,
window_size: msg.initial_window_size,
})
.unwrap_or(());
} else {
error!("no channel for id {local_id:?}");
}

client
.channel_open_confirmation(
local_id,
Expand All @@ -414,12 +427,16 @@ impl Session {
// will not be released.
enc.close(channel_num);
}
self.channels.remove(&channel_num);
client.channel_close(channel_num, self).await
}
Some(&msg::CHANNEL_EOF) => {
debug!("channel_eof");
let mut r = buf.reader(1);
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Eof);
}
client.channel_eof(channel_num, self).await
}
Some(&msg::CHANNEL_OPEN_FAILURE) => {
Expand All @@ -436,6 +453,13 @@ impl Session {
if let Some(ref mut enc) = self.common.encrypted {
enc.channels.remove(&channel_num);
}

if let Some(sender) = self.channels.remove(&channel_num) {
let _ = sender.send(ChannelMsg::OpenFailure(reason_code));
}

let _ = self.sender.send(Reply::ChannelOpenFailure);

client
.channel_open_failure(channel_num, reason_code, descr, language, self)
.await
Expand All @@ -455,6 +479,13 @@ impl Session {
}
}
}

if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Data {
data: CryptoVec::from_slice(data),
});
}

client.data(channel_num, data, self).await
}
Some(&msg::CHANNEL_EXTENDED_DATA) => {
Expand All @@ -473,6 +504,14 @@ impl Session {
}
}
}

if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExtendedData {
ext: extended_code,
data: CryptoVec::from_slice(data),
});
}

client
.extended_data(channel_num, extended_code, data, self)
.await
Expand All @@ -489,30 +528,44 @@ impl Session {
match req {
b"xon-xoff" => {
r.read_byte().map_err(crate::Error::from)?; // should be 0.
let client_can_do = r.read_byte().map_err(crate::Error::from)?;
client.xon_xoff(channel_num, client_can_do != 0, self).await
let client_can_do = r.read_byte().map_err(crate::Error::from)? != 0;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::XonXoff { client_can_do });
}
client.xon_xoff(channel_num, client_can_do, self).await
}
b"exit-status" => {
r.read_byte().map_err(crate::Error::from)?; // should be 0.
let exit_status = r.read_u32().map_err(crate::Error::from)?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExitStatus { exit_status });
}
client.exit_status(channel_num, exit_status, self).await
}
b"exit-signal" => {
r.read_byte().map_err(crate::Error::from)?; // should be 0.
let signal_name =
Sig::from_name(r.read_string().map_err(crate::Error::from)?)?;
let core_dumped = r.read_byte().map_err(crate::Error::from)?;
let core_dumped = r.read_byte().map_err(crate::Error::from)? != 0;
let error_message =
std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
.map_err(crate::Error::from)?;
let lang_tag =
std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
.map_err(crate::Error::from)?;
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::ExitSignal {
signal_name: signal_name.clone(),
core_dumped,
error_message: error_message.to_string(),
lang_tag: lang_tag.to_string(),
});
}
client
.exit_signal(
channel_num,
signal_name,
core_dumped != 0,
core_dumped,
error_message,
lang_tag,
self,
Expand Down Expand Up @@ -563,17 +616,24 @@ impl Session {
let mut r = buf.reader(1);
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
let amount = r.read_u32().map_err(crate::Error::from)?;
let mut new_value = 0;
let mut new_size = 0;
debug!("amount: {:?}", amount);
if let Some(ref mut enc) = self.common.encrypted {
if let Some(ref mut channel) = enc.channels.get_mut(&channel_num) {
channel.recipient_window_size += amount;
new_value = channel.recipient_window_size;
new_size = channel.recipient_window_size;
} else {
return Err(crate::Error::WrongChannel.into());
}
}
client.window_adjusted(channel_num, new_value, self).await

if let Some(ref mut enc) = self.common.encrypted {
new_size -= enc.flush_pending(channel_num) as u32;
}
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::WindowAdjusted { new_size });
}
client.window_adjusted(channel_num, new_size, self).await
}
Some(&msg::GLOBAL_REQUEST) => {
let mut r = buf.reader(1);
Expand Down Expand Up @@ -634,11 +694,17 @@ impl Session {
Some(&msg::CHANNEL_SUCCESS) => {
let mut r = buf.reader(1);
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Success);
}
client.channel_success(channel_num, self).await
}
Some(&msg::CHANNEL_FAILURE) => {
let mut r = buf.reader(1);
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::Failure);
}
client.channel_failure(channel_num, self).await
}
Some(&msg::CHANNEL_OPEN) => {
Expand Down Expand Up @@ -891,10 +957,7 @@ impl Encrypted {
Ok(())
}

fn client_send_auth_response(
&mut self,
responses: &[String]
) -> Result<(), crate::Error> {
fn client_send_auth_response(&mut self, responses: &[String]) -> Result<(), crate::Error> {
push_packet!(self.write, {
self.write.push(msg::USERAUTH_INFO_RESPONSE);
self.write
Expand Down

0 comments on commit 359fa3c

Please sign in to comment.