Skip to content

Commit

Permalink
Add channels::io::{ChannelTx, ChannelRx} and implement Channel::into_…
Browse files Browse the repository at this point in the history
…io_parts
  • Loading branch information
lowlevl authored and Eugeny committed Sep 22, 2023
1 parent fc77c53 commit acd744a
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 0 deletions.
7 changes: 7 additions & 0 deletions russh/src/channels/io/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use super::ChannelMsg;

mod rx;
pub use rx::ChannelRx;

mod tx;
pub use tx::ChannelTx;
100 changes: 100 additions & 0 deletions russh/src/channels/io/rx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use std::{
io,
pin::Pin,
sync::{Arc, Mutex, TryLockError},
task::{Context, Poll},
};

use tokio::{
io::AsyncRead,
sync::mpsc::{self, error::TryRecvError},
};

use super::ChannelMsg;

pub struct ChannelRx {
receiver: mpsc::UnboundedReceiver<ChannelMsg>,
buffer: Option<ChannelMsg>,

window_size: Arc<Mutex<u32>>,
}

impl ChannelRx {
pub fn new(
receiver: mpsc::UnboundedReceiver<ChannelMsg>,
window_size: Arc<Mutex<u32>>,
) -> Self {
Self {
receiver,
buffer: None,
window_size,
}
}
}

impl AsyncRead for ChannelRx {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let msg = match self.buffer.take() {
Some(msg) => msg,
None => match self.receiver.try_recv() {
Ok(msg) => msg,
Err(TryRecvError::Empty) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
Err(TryRecvError::Disconnected) => {
return Poll::Ready(Ok(()));
}
},
};

match &msg {
ChannelMsg::Data { data } => {
if buf.remaining() >= data.len() {
buf.put_slice(data);

Poll::Ready(Ok(()))
} else {
self.buffer = Some(msg);

cx.waker().wake_by_ref();
Poll::Pending
}
}
ChannelMsg::WindowAdjusted { new_size } => {
let buffer = match self.window_size.try_lock() {
Ok(mut window_size) => {
*window_size = *new_size;

None
}
Err(TryLockError::WouldBlock) => Some(msg),
Err(TryLockError::Poisoned(err)) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
err.to_string(),
)))
}
};

self.buffer = buffer;

cx.waker().wake_by_ref();
Poll::Pending
}
ChannelMsg::Eof => {
self.receiver.close();

Poll::Ready(Ok(()))
}
_ => {
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
}
96 changes: 96 additions & 0 deletions russh/src/channels/io/tx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::{
io,
pin::Pin,
sync::{Arc, Mutex, TryLockError},
task::{Context, Poll},
};

use tokio::{
io::AsyncWrite,
sync::mpsc::{self, error::TrySendError},
};

use russh_cryptovec::CryptoVec;

use super::ChannelMsg;
use crate::ChannelId;

pub struct ChannelTx<S> {
sender: mpsc::Sender<S>,
id: ChannelId,

window_size: Arc<Mutex<u32>>,
max_packet_size: u32,
}

impl<S> ChannelTx<S> {
pub fn new(
sender: mpsc::Sender<S>,
id: ChannelId,
window_size: Arc<Mutex<u32>>,
max_packet_size: u32,
) -> Self {
Self {
sender,
id,
window_size,
max_packet_size,
}
}
}

impl<S> AsyncWrite for ChannelTx<S>
where
S: From<(ChannelId, ChannelMsg)> + 'static,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let mut window_size = match self.window_size.try_lock() {
Ok(window_size) => window_size,
Err(TryLockError::WouldBlock) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
Err(TryLockError::Poisoned(err)) => {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err.to_string())))
}
};

let writable = self.max_packet_size.min(*window_size).min(buf.len() as u32) as usize;
if writable == 0 {
cx.waker().wake_by_ref();
return Poll::Pending;
}

let mut data = CryptoVec::new_zeroed(writable);
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min`
data.copy_from_slice(&buf[..writable]);
data.resize(writable);

*window_size -= writable as u32;
drop(window_size);

match self
.sender
.try_send((self.id, ChannelMsg::Data { data }).into())
{
Ok(_) => Poll::Ready(Ok(writable)),
Err(TrySendError::Closed(_)) => Poll::Ready(Ok(0)),
Err(TrySendError::Full(_)) => {
cx.waker().wake_by_ref();
Poll::Pending
}
}
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.poll_flush(cx)
}
}
20 changes: 20 additions & 0 deletions russh/src/channels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use log::debug;

use crate::{ChannelId, ChannelOpenFailure, ChannelStream, Error, Pty, Sig};

pub mod io;

#[derive(Debug)]
#[non_exhaustive]
/// Possible messages that [Channel::wait] can receive.
Expand Down Expand Up @@ -410,4 +412,22 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
});
stream
}

/// Setup the [`Channel`] to be able to send messages through [`io::ChannelTx`],
/// and receiving them through [`io::ChannelRx`].
pub fn into_io_parts(self) -> (io::ChannelTx<S>, io::ChannelRx) {
use std::sync::{Arc, Mutex};

let window_size = Arc::new(Mutex::new(self.window_size));

(
io::ChannelTx::new(
self.sender,
self.id,
window_size.clone(),
self.max_packet_size,
),
io::ChannelRx::new(self.receiver, window_size),
)
}
}

0 comments on commit acd744a

Please sign in to comment.