Skip to content

Commit bd13e95

Browse files
mmirateEugeny
authored andcommitted
Avert the race between sending data and sending EOF
1 parent 2a4b5a0 commit bd13e95

File tree

4 files changed

+92
-53
lines changed

4 files changed

+92
-53
lines changed

russh/src/client/encrypted.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,8 @@ impl Session {
740740
confirmed: true,
741741
wants_reply: false,
742742
pending_data: std::collections::VecDeque::new(),
743+
pending_eof: false,
744+
pending_close: false,
743745
};
744746

745747
let confirm = || {

russh/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,8 @@ pub(crate) struct ChannelParams {
473473
pub confirmed: bool,
474474
wants_reply: bool,
475475
pending_data: std::collections::VecDeque<(CryptoVec, Option<u32>, usize)>,
476+
pending_eof: bool,
477+
pending_close: bool,
476478
}
477479

478480
impl ChannelParams {

russh/src/server/encrypted.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,8 @@ impl Session {
10931093
confirmed: true,
10941094
wants_reply: false,
10951095
pending_data: std::collections::VecDeque::new(),
1096+
pending_eof: false,
1097+
pending_close: false,
10961098
};
10971099

10981100
let (channel, reference) = Channel::new(

russh/src/session.rs

Lines changed: 86 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,23 @@ pub(crate) struct CommonSession<Config> {
6868
pub received_data: bool,
6969
}
7070

71+
#[derive(Debug, Clone, Copy)]
72+
pub(crate) enum ChannelFlushResult {
73+
Incomplete { wrote: usize, },
74+
Complete { wrote: usize, pending_eof: bool, pending_close: bool, }
75+
}
76+
impl ChannelFlushResult {
77+
pub(crate) fn wrote(&self) -> usize {
78+
match self {
79+
ChannelFlushResult::Incomplete { wrote } => *wrote,
80+
ChannelFlushResult::Complete { wrote, .. } => *wrote,
81+
}
82+
}
83+
pub(crate) fn complete(wrote: usize, channel: &ChannelParams) -> Self {
84+
ChannelFlushResult::Complete { wrote, pending_eof: channel.pending_eof, pending_close: channel.pending_close }
85+
}
86+
}
87+
7188
impl<C> CommonSession<C> {
7289
pub fn newkeys(&mut self, newkeys: NewKeys) {
7390
if let Some(ref mut enc) = self.encrypted {
@@ -158,12 +175,20 @@ impl Encrypted {
158175
*/
159176

160177
pub fn eof(&mut self, channel: ChannelId) {
161-
self.byte(channel, msg::CHANNEL_EOF);
178+
if let Some(channel) = self.has_pending_data_mut(channel) {
179+
channel.pending_eof = true;
180+
} else {
181+
self.byte(channel, msg::CHANNEL_EOF);
182+
}
162183
}
163184

164185
pub fn close(&mut self, channel: ChannelId) {
165-
self.byte(channel, msg::CHANNEL_CLOSE);
166-
self.channels.remove(&channel);
186+
if let Some(channel) = self.has_pending_data_mut(channel) {
187+
channel.pending_close = true;
188+
} else {
189+
self.byte(channel, msg::CHANNEL_CLOSE);
190+
self.channels.remove(&channel);
191+
}
167192
}
168193

169194
pub fn sender_window_size(&self, channel: ChannelId) -> usize {
@@ -203,33 +228,55 @@ impl Encrypted {
203228
false
204229
}
205230

231+
fn flush_channel(write: &mut CryptoVec, channel: &mut ChannelParams) -> ChannelFlushResult {
232+
let mut pending_size = 0;
233+
while let Some((buf, a, from)) = channel.pending_data.pop_front() {
234+
let size = Self::data_noqueue(write, channel, &buf, a, from);
235+
pending_size += size;
236+
if from + size < buf.len() {
237+
channel.pending_data.push_front((buf, a, from + size));
238+
return ChannelFlushResult::Incomplete { wrote: pending_size };
239+
}
240+
}
241+
ChannelFlushResult::complete(pending_size, channel)
242+
}
243+
244+
fn handle_flushed_channel(&mut self, channel: ChannelId, flush_result: ChannelFlushResult) {
245+
if let ChannelFlushResult::Complete { wrote: _, pending_eof, pending_close } = flush_result {
246+
if pending_eof {
247+
self.eof(channel);
248+
}
249+
if pending_close {
250+
self.close(channel);
251+
}
252+
}
253+
}
254+
206255
pub fn flush_pending(&mut self, channel: ChannelId) -> usize {
207256
let mut pending_size = 0;
257+
let mut maybe_flush_result = Option::<ChannelFlushResult>::None;
258+
208259
if let Some(channel) = self.channels.get_mut(&channel) {
209-
while let Some((buf, a, from)) = channel.pending_data.pop_front() {
210-
let size = Self::data_noqueue(&mut self.write, channel, &buf, from);
211-
pending_size += size;
212-
if from + size < buf.len() {
213-
channel.pending_data.push_front((buf, a, from + size));
214-
break;
215-
}
216-
}
260+
let flush_result = Self::flush_channel(&mut self.write, channel);
261+
pending_size += flush_result.wrote();
262+
maybe_flush_result = Some(flush_result);
263+
}
264+
if let Some(flush_result) = maybe_flush_result {
265+
self.handle_flushed_channel(channel, flush_result)
217266
}
218267
pending_size
219268
}
220269

221270
pub fn flush_all_pending(&mut self) {
222-
for (_, channel) in self.channels.iter_mut() {
223-
while let Some((buf, a, from)) = channel.pending_data.pop_front() {
224-
let size = Self::data_noqueue(&mut self.write, channel, &buf, from);
225-
if from + size < buf.len() {
226-
channel.pending_data.push_front((buf, a, from + size));
227-
break;
228-
}
229-
}
271+
for channel in self.channels.values_mut() {
272+
Self::flush_channel(&mut self.write, channel);
230273
}
231274
}
232275

276+
fn has_pending_data_mut(&mut self, channel: ChannelId) -> Option<&mut ChannelParams> {
277+
self.channels.get_mut(&channel).filter(|c| !c.pending_data.is_empty())
278+
}
279+
233280
pub fn has_pending_data(&self, channel: ChannelId) -> bool {
234281
if let Some(channel) = self.channels.get(&channel) {
235282
!channel.pending_data.is_empty()
@@ -245,6 +292,7 @@ impl Encrypted {
245292
write: &mut CryptoVec,
246293
channel: &mut ChannelParams,
247294
buf0: &[u8],
295+
a: Option<u32>,
248296
from: usize,
249297
) -> usize {
250298
if from >= buf0.len() {
@@ -262,12 +310,21 @@ impl Encrypted {
262310
while !buf.is_empty() {
263311
// Compute the length we're allowed to send.
264312
let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize);
265-
push_packet!(write, {
266-
write.push(msg::CHANNEL_DATA);
267-
write.push_u32_be(channel.recipient_channel);
268-
#[allow(clippy::indexing_slicing)] // length checked
269-
write.extend_ssh_string(&buf[..off]);
270-
});
313+
match a {
314+
None => push_packet!(write, {
315+
write.push(msg::CHANNEL_DATA);
316+
write.push_u32_be(channel.recipient_channel);
317+
#[allow(clippy::indexing_slicing)] // length checked
318+
write.extend_ssh_string(&buf[..off]);
319+
}),
320+
Some(ext) => push_packet!(write, {
321+
write.push(msg::CHANNEL_EXTENDED_DATA);
322+
write.push_u32_be(channel.recipient_channel);
323+
write.push_u32_be(ext);
324+
#[allow(clippy::indexing_slicing)] // length checked
325+
write.extend_ssh_string(&buf[..off]);
326+
}),
327+
}
271328
trace!(
272329
"buffer: {:?} {:?}",
273330
write.len(),
@@ -290,7 +347,7 @@ impl Encrypted {
290347
channel.pending_data.push_back((buf0, None, 0));
291348
return;
292349
}
293-
let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, 0);
350+
let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, None, 0);
294351
if buf_len < buf0.len() {
295352
channel.pending_data.push_back((buf0, None, buf_len))
296353
}
@@ -300,39 +357,13 @@ impl Encrypted {
300357
}
301358

302359
pub fn extended_data(&mut self, channel: ChannelId, ext: u32, buf0: CryptoVec) {
303-
use std::ops::Deref;
304360
if let Some(channel) = self.channels.get_mut(&channel) {
305361
assert!(channel.confirmed);
306362
if !channel.pending_data.is_empty() {
307363
channel.pending_data.push_back((buf0, Some(ext), 0));
308364
return;
309365
}
310-
let mut buf = if buf0.len() as u32 > channel.recipient_window_size {
311-
#[allow(clippy::indexing_slicing)] // length checked
312-
&buf0[0..channel.recipient_window_size as usize]
313-
} else {
314-
&buf0
315-
};
316-
let buf_len = buf.len();
317-
318-
while !buf.is_empty() {
319-
// Compute the length we're allowed to send.
320-
let off = std::cmp::min(buf.len(), channel.recipient_maximum_packet_size as usize);
321-
push_packet!(self.write, {
322-
self.write.push(msg::CHANNEL_EXTENDED_DATA);
323-
self.write.push_u32_be(channel.recipient_channel);
324-
self.write.push_u32_be(ext);
325-
#[allow(clippy::indexing_slicing)] // length checked
326-
self.write.extend_ssh_string(&buf[..off]);
327-
});
328-
trace!("buffer: {:?}", self.write.deref().len());
329-
channel.recipient_window_size -= off as u32;
330-
#[allow(clippy::indexing_slicing)] // length checked
331-
{
332-
buf = &buf[off..]
333-
}
334-
}
335-
trace!("buf.len() = {:?}, buf_len = {:?}", buf.len(), buf_len);
366+
let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, Some(ext), 0);
336367
if buf_len < buf0.len() {
337368
channel.pending_data.push_back((buf0, Some(ext), buf_len))
338369
}
@@ -402,6 +433,8 @@ impl Encrypted {
402433
confirmed: false,
403434
wants_reply: false,
404435
pending_data: std::collections::VecDeque::new(),
436+
pending_eof: false,
437+
pending_close: false,
405438
});
406439
return ChannelId(self.last_channel_id.0);
407440
}

0 commit comments

Comments
 (0)