Skip to content

Commit

Permalink
Avert the race between sending data and sending EOF
Browse files Browse the repository at this point in the history
  • Loading branch information
mmirate authored and Eugeny committed Jan 18, 2024
1 parent 2a4b5a0 commit bd13e95
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 53 deletions.
2 changes: 2 additions & 0 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,8 @@ impl Session {
confirmed: true,
wants_reply: false,
pending_data: std::collections::VecDeque::new(),
pending_eof: false,
pending_close: false,
};

let confirm = || {
Expand Down
2 changes: 2 additions & 0 deletions russh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ pub(crate) struct ChannelParams {
pub confirmed: bool,
wants_reply: bool,
pending_data: std::collections::VecDeque<(CryptoVec, Option<u32>, usize)>,
pending_eof: bool,
pending_close: bool,
}

impl ChannelParams {
Expand Down
2 changes: 2 additions & 0 deletions russh/src/server/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,8 @@ impl Session {
confirmed: true,
wants_reply: false,
pending_data: std::collections::VecDeque::new(),
pending_eof: false,
pending_close: false,
};

let (channel, reference) = Channel::new(
Expand Down
139 changes: 86 additions & 53 deletions russh/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@ pub(crate) struct CommonSession<Config> {
pub received_data: bool,
}

#[derive(Debug, Clone, Copy)]
pub(crate) enum ChannelFlushResult {
Incomplete { wrote: usize, },
Complete { wrote: usize, pending_eof: bool, pending_close: bool, }
}
impl ChannelFlushResult {
pub(crate) fn wrote(&self) -> usize {
match self {
ChannelFlushResult::Incomplete { wrote } => *wrote,
ChannelFlushResult::Complete { wrote, .. } => *wrote,
}
}
pub(crate) fn complete(wrote: usize, channel: &ChannelParams) -> Self {
ChannelFlushResult::Complete { wrote, pending_eof: channel.pending_eof, pending_close: channel.pending_close }
}
}

impl<C> CommonSession<C> {
pub fn newkeys(&mut self, newkeys: NewKeys) {
if let Some(ref mut enc) = self.encrypted {
Expand Down Expand Up @@ -158,12 +175,20 @@ impl Encrypted {
*/

pub fn eof(&mut self, channel: ChannelId) {
self.byte(channel, msg::CHANNEL_EOF);
if let Some(channel) = self.has_pending_data_mut(channel) {
channel.pending_eof = true;
} else {
self.byte(channel, msg::CHANNEL_EOF);
}
}

pub fn close(&mut self, channel: ChannelId) {
self.byte(channel, msg::CHANNEL_CLOSE);
self.channels.remove(&channel);
if let Some(channel) = self.has_pending_data_mut(channel) {
channel.pending_close = true;
} else {
self.byte(channel, msg::CHANNEL_CLOSE);
self.channels.remove(&channel);
}
}

pub fn sender_window_size(&self, channel: ChannelId) -> usize {
Expand Down Expand Up @@ -203,33 +228,55 @@ impl Encrypted {
false
}

fn flush_channel(write: &mut CryptoVec, channel: &mut ChannelParams) -> ChannelFlushResult {
let mut pending_size = 0;
while let Some((buf, a, from)) = channel.pending_data.pop_front() {
let size = Self::data_noqueue(write, channel, &buf, a, from);
pending_size += size;
if from + size < buf.len() {
channel.pending_data.push_front((buf, a, from + size));
return ChannelFlushResult::Incomplete { wrote: pending_size };
}
}
ChannelFlushResult::complete(pending_size, channel)
}

fn handle_flushed_channel(&mut self, channel: ChannelId, flush_result: ChannelFlushResult) {
if let ChannelFlushResult::Complete { wrote: _, pending_eof, pending_close } = flush_result {
if pending_eof {
self.eof(channel);
}
if pending_close {
self.close(channel);
}
}
}

pub fn flush_pending(&mut self, channel: ChannelId) -> usize {
let mut pending_size = 0;
let mut maybe_flush_result = Option::<ChannelFlushResult>::None;

if let Some(channel) = self.channels.get_mut(&channel) {
while let Some((buf, a, from)) = channel.pending_data.pop_front() {
let size = Self::data_noqueue(&mut self.write, channel, &buf, from);
pending_size += size;
if from + size < buf.len() {
channel.pending_data.push_front((buf, a, from + size));
break;
}
}
let flush_result = Self::flush_channel(&mut self.write, channel);
pending_size += flush_result.wrote();
maybe_flush_result = Some(flush_result);
}
if let Some(flush_result) = maybe_flush_result {
self.handle_flushed_channel(channel, flush_result)
}
pending_size
}

pub fn flush_all_pending(&mut self) {
for (_, channel) in self.channels.iter_mut() {
while let Some((buf, a, from)) = channel.pending_data.pop_front() {
let size = Self::data_noqueue(&mut self.write, channel, &buf, from);
if from + size < buf.len() {
channel.pending_data.push_front((buf, a, from + size));
break;
}
}
for channel in self.channels.values_mut() {
Self::flush_channel(&mut self.write, channel);
}
}

fn has_pending_data_mut(&mut self, channel: ChannelId) -> Option<&mut ChannelParams> {
self.channels.get_mut(&channel).filter(|c| !c.pending_data.is_empty())
}

pub fn has_pending_data(&self, channel: ChannelId) -> bool {
if let Some(channel) = self.channels.get(&channel) {
!channel.pending_data.is_empty()
Expand All @@ -245,6 +292,7 @@ impl Encrypted {
write: &mut CryptoVec,
channel: &mut ChannelParams,
buf0: &[u8],
a: Option<u32>,
from: usize,
) -> usize {
if from >= buf0.len() {
Expand All @@ -262,12 +310,21 @@ impl Encrypted {
while !buf.is_empty() {
// Compute the length we're allowed to send.
let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize);
push_packet!(write, {
write.push(msg::CHANNEL_DATA);
write.push_u32_be(channel.recipient_channel);
#[allow(clippy::indexing_slicing)] // length checked
write.extend_ssh_string(&buf[..off]);
});
match a {
None => push_packet!(write, {
write.push(msg::CHANNEL_DATA);
write.push_u32_be(channel.recipient_channel);
#[allow(clippy::indexing_slicing)] // length checked
write.extend_ssh_string(&buf[..off]);
}),
Some(ext) => push_packet!(write, {
write.push(msg::CHANNEL_EXTENDED_DATA);
write.push_u32_be(channel.recipient_channel);
write.push_u32_be(ext);
#[allow(clippy::indexing_slicing)] // length checked
write.extend_ssh_string(&buf[..off]);
}),
}
trace!(
"buffer: {:?} {:?}",
write.len(),
Expand All @@ -290,7 +347,7 @@ impl Encrypted {
channel.pending_data.push_back((buf0, None, 0));
return;
}
let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, 0);
let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, None, 0);
if buf_len < buf0.len() {
channel.pending_data.push_back((buf0, None, buf_len))
}
Expand All @@ -300,39 +357,13 @@ impl Encrypted {
}

pub fn extended_data(&mut self, channel: ChannelId, ext: u32, buf0: CryptoVec) {
use std::ops::Deref;
if let Some(channel) = self.channels.get_mut(&channel) {
assert!(channel.confirmed);
if !channel.pending_data.is_empty() {
channel.pending_data.push_back((buf0, Some(ext), 0));
return;
}
let mut buf = if buf0.len() as u32 > channel.recipient_window_size {
#[allow(clippy::indexing_slicing)] // length checked
&buf0[0..channel.recipient_window_size as usize]
} else {
&buf0
};
let buf_len = buf.len();

while !buf.is_empty() {
// Compute the length we're allowed to send.
let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize);
push_packet!(self.write, {
self.write.push(msg::CHANNEL_EXTENDED_DATA);
self.write.push_u32_be(channel.recipient_channel);
self.write.push_u32_be(ext);
#[allow(clippy::indexing_slicing)] // length checked
self.write.extend_ssh_string(&buf[..off]);
});
trace!("buffer: {:?}", self.write.deref().len());
channel.recipient_window_size -= off as u32;
#[allow(clippy::indexing_slicing)] // length checked
{
buf = &buf[off..]
}
}
trace!("buf.len() = {:?}, buf_len = {:?}", buf.len(), buf_len);
let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, Some(ext), 0);
if buf_len < buf0.len() {
channel.pending_data.push_back((buf0, Some(ext), buf_len))
}
Expand Down Expand Up @@ -402,6 +433,8 @@ impl Encrypted {
confirmed: false,
wants_reply: false,
pending_data: std::collections::VecDeque::new(),
pending_eof: false,
pending_close: false,
});
return ChannelId(self.last_channel_id.0);
}
Expand Down

0 comments on commit bd13e95

Please sign in to comment.