Skip to content

Commit

Permalink
Use ChannelMsg::WindowAdjusted during data transfer
Browse files Browse the repository at this point in the history
OpenSSH server sends `CHANNEL_WINDOW_ADJUST` messages before window_size is 0.

Handle these message at each turn of the loop within `Channel.send_data`

Signed-off-by: Joe Grund <jgrund@whamcloud.io>
  • Loading branch information
jgrund committed Sep 12, 2023
1 parent 8fa6892 commit 52e5eaa
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 18 deletions.
40 changes: 29 additions & 11 deletions russh/src/channels.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use russh_cryptovec::CryptoVec;
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
use log::debug;
use russh_cryptovec::CryptoVec;
use tokio::sync::mpsc::{error::TryRecvError, Sender, UnboundedReceiver};

use crate::{ChannelId, ChannelOpenFailure, ChannelStream, Error, Pty, Sig};

Expand Down Expand Up @@ -290,23 +290,38 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
while self.window_size == 0 {
match self.receiver.recv().await {
Some(ChannelMsg::WindowAdjusted { new_size }) => {
debug!("window adjusted: {:?}", new_size);
debug!("channel {} => window adjusted: {new_size}", self.id);
self.window_size = new_size;
break;
}
Some(msg) => {
debug!("unexpected channel msg: {:?}", msg);
debug!("channel {} => unexpected channel msg: {msg:?}", self.id);
}
None => break,
}
}

// Some implementations send CHANNEL_WINDOW_ADJUST prior to
// window size being 0. Process those at each turn of the loop here.
match self.receiver.try_recv() {
Ok(ChannelMsg::WindowAdjusted { new_size }) => {
debug!("channel {} => window adjusted: {new_size}", self.id);
self.window_size = new_size;
}
Ok(msg) => {
debug!("channel {} => unexpected channel msg: {msg:?}", self.id);
}
Err(TryRecvError::Empty | TryRecvError::Disconnected) => {}
}

debug!(
"sending data, self.window_size = {:?}, self.max_packet_size = {:?}, total = {:?}",
self.window_size, self.max_packet_size, total
"channel {} => sending data, self.window_size = {}, self.max_packet_size = {}, total = {total}",
self.id, self.window_size, self.max_packet_size
);
let sendable = self.window_size.min(self.max_packet_size) as usize;

debug!("sendable {:?}", sendable);
let sendable = self.writable_packet_size();

debug!("channel {} => sendable {sendable}", self.id);

// If we can not send anymore, continue
// and wait for server window adjustment
Expand All @@ -318,14 +333,16 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
let n = data.read(&mut c[..]).await?;
total += n;
c.resize(n);
self.window_size -= n as u32;

self.window_size = self.window_size.saturating_sub(n as u32);

self.send_data_packet(ext, c).await?;

if n == 0 {
break;
} else if self.window_size > 0 {
continue;
}
}

Ok(())
}

Expand All @@ -349,6 +366,7 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
match self.receiver.recv().await {
Some(ChannelMsg::WindowAdjusted { new_size }) => {
self.window_size = new_size;

Some(ChannelMsg::WindowAdjusted { new_size })
}
Some(msg) => Some(msg),
Expand Down
15 changes: 10 additions & 5 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,26 +613,31 @@ impl Session {
}
Some(&msg::CHANNEL_WINDOW_ADJUST) => {
debug!("channel_window_adjust");

let mut r = buf.reader(1);
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
let amount = r.read_u32().map_err(crate::Error::from)?;
let mut new_size = 0;

debug!("amount: {:?}", amount);

if let Some(ref mut enc) = self.common.encrypted {
if let Some(ref mut channel) = enc.channels.get_mut(&channel_num) {
channel.recipient_window_size += amount;
channel.recipient_window_size =
channel.recipient_window_size.saturating_add(amount);

new_size = channel.recipient_window_size;
} else {
return Err(crate::Error::WrongChannel.into());
}
}

if let Some(ref mut enc) = self.common.encrypted {
new_size -= enc.flush_pending(channel_num) as u32;
new_size = new_size.saturating_sub(enc.flush_pending(channel_num) as u32);
}

if let Some(chan) = self.channels.get(&channel_num) {
let _ = chan.send(ChannelMsg::WindowAdjusted { new_size });
_ = chan.send(ChannelMsg::WindowAdjusted { new_size });
}

client.window_adjusted(channel_num, new_size, self).await
}
Some(&msg::GLOBAL_REQUEST) => {
Expand Down
4 changes: 2 additions & 2 deletions russh/src/server/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl Session {
// Either this packet is a KEXINIT, in which case we start a key re-exchange.

#[allow(clippy::unwrap_used)]
let mut enc = self.common.encrypted.as_mut().unwrap();
let enc = self.common.encrypted.as_mut().unwrap();
if buf.first() == Some(&msg::KEXINIT) {
debug!("Received rekeying request");
// If we're not currently rekeying, but `buf` is a rekey request
Expand Down Expand Up @@ -143,7 +143,7 @@ impl Session {
};

#[allow(clippy::unwrap_used)]
let mut enc = self.common.encrypted.as_mut().unwrap();
let enc = self.common.encrypted.as_mut().unwrap();
// If we've successfully read a packet.
match enc.state {
EncryptedState::WaitingAuthServiceRequest {
Expand Down

0 comments on commit 52e5eaa

Please sign in to comment.