Skip to content

Commit

Permalink
Move run and run_on_socket to Server trait
Browse files Browse the repository at this point in the history
  • Loading branch information
ricott1 committed Feb 8, 2024
1 parent 0763767 commit a592366
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 78 deletions.
5 changes: 2 additions & 3 deletions russh/examples/echoserver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use russh::server::Server as _;
use russh::server::{Msg, Session};
use russh::*;
use russh_keys::*;
Expand All @@ -25,9 +26,7 @@ async fn main() {
clients: Arc::new(Mutex::new(HashMap::new())),
id: 0,
};
russh::server::run(config, ("0.0.0.0", 2222), &mut sh)
.await
.unwrap();
sh.run_on_address(config, ("0.0.0.0", 2222)).await.unwrap();
}

#[derive(Clone)]
Expand Down
28 changes: 15 additions & 13 deletions russh/examples/sftp_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use std::time::Duration;

use async_trait::async_trait;
use log::{error, info, LevelFilter};
use russh::server::Server as _;
use russh::server::{Auth, Msg, Session};

use russh::{Channel, ChannelId};
use russh_keys::key::KeyPair;
use russh_sftp::protocol::{File, FileAttributes, Handle, Name, Status, StatusCode, Version};
Expand Down Expand Up @@ -195,17 +197,17 @@ async fn main() {

let mut server = Server;

russh::server::run(
Arc::new(config),
(
"0.0.0.0",
std::env::var("PORT")
.unwrap_or("22".to_string())
.parse()
.unwrap(),
),
&mut server,
)
.await
.unwrap();
server
.run_on_address(
Arc::new(config),
(
"0.0.0.0",
std::env::var("PORT")
.unwrap_or("22".to_string())
.parse()
.unwrap(),
),
)
.await
.unwrap();
}
3 changes: 2 additions & 1 deletion russh/examples/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use log::debug;
use russh::server::Server as _;
use russh::server::{Auth, Msg, Session};
use russh::*;
use russh_keys::*;
Expand All @@ -22,7 +23,7 @@ async fn main() -> anyhow::Result<()> {
};
tokio::time::timeout(
std::time::Duration::from_secs(60),
russh::server::run(config, ("0.0.0.0", 2222), &mut sh),
sh.run_on_address(config, ("0.0.0.0", 2222)),
)
.await
.unwrap_or(Ok(()))?;
Expand Down
120 changes: 60 additions & 60 deletions russh/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ use crate::*;

mod kex;
mod session;
pub use self::kex::*;
pub use self::session::*;
mod encrypted;

Expand Down Expand Up @@ -525,80 +524,81 @@ pub trait Handler: Sized {
}
}

#[async_trait]
/// Trait used to create new handlers when clients connect.
pub trait Server {
/// The type of handlers.
type Handler: Handler + Send;
type Handler: Handler + Send + 'static;
/// Called when a new client connects.
fn new_client(&mut self, peer_addr: Option<std::net::SocketAddr>) -> Self::Handler;
/// Called when an active connection fails.
fn handle_session_error(&mut self, _error: <Self::Handler as Handler>::Error) {}
}

/// Run a server on a specified `tokio::net::TcpListener`. Useful when dropping
/// privileges immediately after socket binding, for example.
pub async fn run_on_socket<H: Server + Send + 'static>(
config: Arc<Config>,
socket: &TcpListener,
server: &mut H,
) -> Result<(), std::io::Error> {
if config.maximum_packet_size > 65535 {
error!(
"Maximum packet size ({:?}) should not larger than a TCP packet (65535)",
config.maximum_packet_size
);
}

let (error_tx, mut error_rx) = tokio::sync::mpsc::unbounded_channel();

loop {
tokio::select! {
accept_result = socket.accept() => {
match accept_result {
Ok((socket, _)) => {
let config = config.clone();
let handler = server.new_client(socket.peer_addr().ok());
let error_tx = error_tx.clone();
tokio::spawn(async move {
let session = match run_stream(config, socket, handler).await {
Ok(s) => s,
Err(e) => {
debug!("Connection setup failed");
let _ = error_tx.send(e);
return
}
};
match session.await {
Ok(_) => debug!("Connection closed"),
Err(e) => {
debug!("Connection closed with error");
let _ = error_tx.send(e);
/// Run a server on a specified `tokio::net::TcpListener`. Useful when dropping
/// privileges immediately after socket binding, for example.
async fn run_on_socket(
&mut self,
config: Arc<Config>,
socket: &TcpListener,
) -> Result<(), std::io::Error> {
if config.maximum_packet_size > 65535 {
error!(
"Maximum packet size ({:?}) should not larger than a TCP packet (65535)",
config.maximum_packet_size
);
}

let (error_tx, mut error_rx) = tokio::sync::mpsc::unbounded_channel();

loop {
tokio::select! {
accept_result = socket.accept() => {
match accept_result {
Ok((socket, _)) => {
let config = config.clone();
let handler = self.new_client(socket.peer_addr().ok());
let error_tx = error_tx.clone();
tokio::spawn(async move {
let session = match run_stream(config, socket, handler).await {
Ok(s) => s,
Err(e) => {
debug!("Connection setup failed");
let _ = error_tx.send(e);
return
}
};
match session.await {
Ok(_) => debug!("Connection closed"),
Err(e) => {
debug!("Connection closed with error");
let _ = error_tx.send(e);
}
}
}
});
});
}
_ => break,
}
_ => break,
},
Some(error) = error_rx.recv() => {
self.handle_session_error(error);
}
},
Some(error) = error_rx.recv() => {
server.handle_session_error(error);
}
}
}

Ok(())
}
Ok(())
}

/// Run a server.
/// Create a new `Connection` from the server's configuration, a
/// stream and a [`Handler`](trait.Handler.html).
pub async fn run<H: Server + Send + 'static, A: ToSocketAddrs>(
config: Arc<Config>,
addrs: A,
server: &mut H,
) -> Result<(), std::io::Error> {
let socket = TcpListener::bind(addrs).await?;
run_on_socket(config, &socket, server).await
/// Run a server.
/// Create a new `Connection` from the server's configuration, a
/// stream and a [`Handler`](trait.Handler.html).
async fn run_on_address<A: ToSocketAddrs + Send>(
&mut self,
config: Arc<Config>,
addrs: A,
) -> Result<(), std::io::Error> {
let socket = TcpListener::bind(addrs).await?;
self.run_on_socket(config, &socket).await
}
}

use std::cell::RefCell;
Expand Down
3 changes: 2 additions & 1 deletion russh/tests/test_data_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;

use rand::RngCore;
use russh::server::Server as _;
use russh::server::{self, Auth, Msg, Session};
use russh::{client, Channel};
use russh_keys::key;
Expand Down Expand Up @@ -90,7 +91,7 @@ impl Server {
});
let mut sh = Server {};

russh::server::run(config, addr, &mut sh).await.unwrap();
sh.run_on_address(config, addr).await.unwrap();
}
}

Expand Down

0 comments on commit a592366

Please sign in to comment.