@@ -68,6 +68,23 @@ pub(crate) struct CommonSession<Config> {
6868 pub received_data : bool ,
6969}
7070
71+ #[ derive( Debug , Clone , Copy ) ]
72+ pub ( crate ) enum ChannelFlushResult {
73+ Incomplete { wrote : usize , } ,
74+ Complete { wrote : usize , pending_eof : bool , pending_close : bool , }
75+ }
76+ impl ChannelFlushResult {
77+ pub ( crate ) fn wrote ( & self ) -> usize {
78+ match self {
79+ ChannelFlushResult :: Incomplete { wrote } => * wrote,
80+ ChannelFlushResult :: Complete { wrote, .. } => * wrote,
81+ }
82+ }
83+ pub ( crate ) fn complete ( wrote : usize , channel : & ChannelParams ) -> Self {
84+ ChannelFlushResult :: Complete { wrote, pending_eof : channel. pending_eof , pending_close : channel. pending_close }
85+ }
86+ }
87+
7188impl < C > CommonSession < C > {
7289 pub fn newkeys ( & mut self , newkeys : NewKeys ) {
7390 if let Some ( ref mut enc) = self . encrypted {
@@ -158,12 +175,20 @@ impl Encrypted {
158175 */
159176
160177 pub fn eof ( & mut self , channel : ChannelId ) {
161- self . byte ( channel, msg:: CHANNEL_EOF ) ;
178+ if let Some ( channel) = self . has_pending_data_mut ( channel) {
179+ channel. pending_eof = true ;
180+ } else {
181+ self . byte ( channel, msg:: CHANNEL_EOF ) ;
182+ }
162183 }
163184
164185 pub fn close ( & mut self , channel : ChannelId ) {
165- self . byte ( channel, msg:: CHANNEL_CLOSE ) ;
166- self . channels . remove ( & channel) ;
186+ if let Some ( channel) = self . has_pending_data_mut ( channel) {
187+ channel. pending_close = true ;
188+ } else {
189+ self . byte ( channel, msg:: CHANNEL_CLOSE ) ;
190+ self . channels . remove ( & channel) ;
191+ }
167192 }
168193
169194 pub fn sender_window_size ( & self , channel : ChannelId ) -> usize {
@@ -203,33 +228,55 @@ impl Encrypted {
203228 false
204229 }
205230
231+ fn flush_channel ( write : & mut CryptoVec , channel : & mut ChannelParams ) -> ChannelFlushResult {
232+ let mut pending_size = 0 ;
233+ while let Some ( ( buf, a, from) ) = channel. pending_data . pop_front ( ) {
234+ let size = Self :: data_noqueue ( write, channel, & buf, a, from) ;
235+ pending_size += size;
236+ if from + size < buf. len ( ) {
237+ channel. pending_data . push_front ( ( buf, a, from + size) ) ;
238+ return ChannelFlushResult :: Incomplete { wrote : pending_size } ;
239+ }
240+ }
241+ ChannelFlushResult :: complete ( pending_size, channel)
242+ }
243+
244+ fn handle_flushed_channel ( & mut self , channel : ChannelId , flush_result : ChannelFlushResult ) {
245+ if let ChannelFlushResult :: Complete { wrote : _, pending_eof, pending_close } = flush_result {
246+ if pending_eof {
247+ self . eof ( channel) ;
248+ }
249+ if pending_close {
250+ self . close ( channel) ;
251+ }
252+ }
253+ }
254+
206255 pub fn flush_pending ( & mut self , channel : ChannelId ) -> usize {
207256 let mut pending_size = 0 ;
257+ let mut maybe_flush_result = Option :: < ChannelFlushResult > :: None ;
258+
208259 if let Some ( channel) = self . channels . get_mut ( & channel) {
209- while let Some ( ( buf, a, from) ) = channel. pending_data . pop_front ( ) {
210- let size = Self :: data_noqueue ( & mut self . write , channel, & buf, from) ;
211- pending_size += size;
212- if from + size < buf. len ( ) {
213- channel. pending_data . push_front ( ( buf, a, from + size) ) ;
214- break ;
215- }
216- }
260+ let flush_result = Self :: flush_channel ( & mut self . write , channel) ;
261+ pending_size += flush_result. wrote ( ) ;
262+ maybe_flush_result = Some ( flush_result) ;
263+ }
264+ if let Some ( flush_result) = maybe_flush_result {
265+ self . handle_flushed_channel ( channel, flush_result)
217266 }
218267 pending_size
219268 }
220269
221270 pub fn flush_all_pending ( & mut self ) {
222- for ( _, channel) in self . channels . iter_mut ( ) {
223- while let Some ( ( buf, a, from) ) = channel. pending_data . pop_front ( ) {
224- let size = Self :: data_noqueue ( & mut self . write , channel, & buf, from) ;
225- if from + size < buf. len ( ) {
226- channel. pending_data . push_front ( ( buf, a, from + size) ) ;
227- break ;
228- }
229- }
271+ for channel in self . channels . values_mut ( ) {
272+ Self :: flush_channel ( & mut self . write , channel) ;
230273 }
231274 }
232275
276+ fn has_pending_data_mut ( & mut self , channel : ChannelId ) -> Option < & mut ChannelParams > {
277+ self . channels . get_mut ( & channel) . filter ( |c| !c. pending_data . is_empty ( ) )
278+ }
279+
233280 pub fn has_pending_data ( & self , channel : ChannelId ) -> bool {
234281 if let Some ( channel) = self . channels . get ( & channel) {
235282 !channel. pending_data . is_empty ( )
@@ -245,6 +292,7 @@ impl Encrypted {
245292 write : & mut CryptoVec ,
246293 channel : & mut ChannelParams ,
247294 buf0 : & [ u8 ] ,
295+ a : Option < u32 > ,
248296 from : usize ,
249297 ) -> usize {
250298 if from >= buf0. len ( ) {
@@ -262,12 +310,21 @@ impl Encrypted {
262310 while !buf. is_empty ( ) {
263311 // Compute the length we're allowed to send.
264312 let off = std:: cmp:: min ( buf. len ( ) , channel. recipient_maximum_packet_size as usize ) ;
265- push_packet ! ( write, {
266- write. push( msg:: CHANNEL_DATA ) ;
267- write. push_u32_be( channel. recipient_channel) ;
268- #[ allow( clippy:: indexing_slicing) ] // length checked
269- write. extend_ssh_string( & buf[ ..off] ) ;
270- } ) ;
313+ match a {
314+ None => push_packet ! ( write, {
315+ write. push( msg:: CHANNEL_DATA ) ;
316+ write. push_u32_be( channel. recipient_channel) ;
317+ #[ allow( clippy:: indexing_slicing) ] // length checked
318+ write. extend_ssh_string( & buf[ ..off] ) ;
319+ } ) ,
320+ Some ( ext) => push_packet ! ( write, {
321+ write. push( msg:: CHANNEL_EXTENDED_DATA ) ;
322+ write. push_u32_be( channel. recipient_channel) ;
323+ write. push_u32_be( ext) ;
324+ #[ allow( clippy:: indexing_slicing) ] // length checked
325+ write. extend_ssh_string( & buf[ ..off] ) ;
326+ } ) ,
327+ }
271328 trace ! (
272329 "buffer: {:?} {:?}" ,
273330 write. len( ) ,
@@ -290,7 +347,7 @@ impl Encrypted {
290347 channel. pending_data . push_back ( ( buf0, None , 0 ) ) ;
291348 return ;
292349 }
293- let buf_len = Self :: data_noqueue ( & mut self . write , channel, & buf0, 0 ) ;
350+ let buf_len = Self :: data_noqueue ( & mut self . write , channel, & buf0, None , 0 ) ;
294351 if buf_len < buf0. len ( ) {
295352 channel. pending_data . push_back ( ( buf0, None , buf_len) )
296353 }
@@ -300,39 +357,13 @@ impl Encrypted {
300357 }
301358
302359 pub fn extended_data ( & mut self , channel : ChannelId , ext : u32 , buf0 : CryptoVec ) {
303- use std:: ops:: Deref ;
304360 if let Some ( channel) = self . channels . get_mut ( & channel) {
305361 assert ! ( channel. confirmed) ;
306362 if !channel. pending_data . is_empty ( ) {
307363 channel. pending_data . push_back ( ( buf0, Some ( ext) , 0 ) ) ;
308364 return ;
309365 }
310- let mut buf = if buf0. len ( ) as u32 > channel. recipient_window_size {
311- #[ allow( clippy:: indexing_slicing) ] // length checked
312- & buf0[ 0 ..channel. recipient_window_size as usize ]
313- } else {
314- & buf0
315- } ;
316- let buf_len = buf. len ( ) ;
317-
318- while !buf. is_empty ( ) {
319- // Compute the length we're allowed to send.
320- let off = std:: cmp:: min ( buf. len ( ) , channel. recipient_maximum_packet_size as usize ) ;
321- push_packet ! ( self . write, {
322- self . write. push( msg:: CHANNEL_EXTENDED_DATA ) ;
323- self . write. push_u32_be( channel. recipient_channel) ;
324- self . write. push_u32_be( ext) ;
325- #[ allow( clippy:: indexing_slicing) ] // length checked
326- self . write. extend_ssh_string( & buf[ ..off] ) ;
327- } ) ;
328- trace ! ( "buffer: {:?}" , self . write. deref( ) . len( ) ) ;
329- channel. recipient_window_size -= off as u32 ;
330- #[ allow( clippy:: indexing_slicing) ] // length checked
331- {
332- buf = & buf[ off..]
333- }
334- }
335- trace ! ( "buf.len() = {:?}, buf_len = {:?}" , buf. len( ) , buf_len) ;
366+ let buf_len = Self :: data_noqueue ( & mut self . write , channel, & buf0, Some ( ext) , 0 ) ;
336367 if buf_len < buf0. len ( ) {
337368 channel. pending_data . push_back ( ( buf0, Some ( ext) , buf_len) )
338369 }
@@ -402,6 +433,8 @@ impl Encrypted {
402433 confirmed : false ,
403434 wants_reply : false ,
404435 pending_data : std:: collections:: VecDeque :: new ( ) ,
436+ pending_eof : false ,
437+ pending_close : false ,
405438 } ) ;
406439 return ChannelId ( self . last_channel_id . 0 ) ;
407440 }
0 commit comments