Skip to content

Commit a5c4adc

Browse files
EugenyEpicEric
andauthored
An attempt at #401 - removing TX busywait (#408)
Co-authored-by: Eric Rodrigues Pires <eric@eric.dev.br>
1 parent ac441a6 commit a5c4adc

File tree

9 files changed

+228
-52
lines changed

9 files changed

+228
-52
lines changed

russh/src/channels/channel_ref.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,25 @@
1-
use std::sync::Arc;
2-
31
use tokio::sync::mpsc::UnboundedSender;
4-
use tokio::sync::Mutex;
52

3+
use super::WindowSizeRef;
64
use crate::ChannelMsg;
75

86
/// A handle to the [`super::Channel`]'s to be able to transmit messages
97
/// to it and update it's `window_size`.
108
#[derive(Debug)]
119
pub struct ChannelRef {
1210
pub(super) sender: UnboundedSender<ChannelMsg>,
13-
pub(super) window_size: Arc<Mutex<u32>>,
11+
pub(super) window_size: WindowSizeRef,
1412
}
1513

1614
impl ChannelRef {
1715
pub fn new(sender: UnboundedSender<ChannelMsg>) -> Self {
1816
Self {
1917
sender,
20-
window_size: Default::default(),
18+
window_size: WindowSizeRef::new(0),
2119
}
2220
}
2321

24-
pub fn window_size(&self) -> &Arc<Mutex<u32>> {
22+
pub(crate) fn window_size(&self) -> &WindowSizeRef {
2523
&self.window_size
2624
}
2725
}

russh/src/channels/io/tx.rs

Lines changed: 74 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
use std::convert::TryFrom;
2+
use std::future::Future;
13
use std::io;
4+
use std::num::NonZero;
5+
use std::ops::DerefMut;
26
use std::pin::Pin;
37
use std::sync::Arc;
48
use std::task::{ready, Context, Poll};
@@ -7,7 +11,7 @@ use futures::FutureExt;
711
use tokio::io::AsyncWrite;
812
use tokio::sync::mpsc::error::SendError;
913
use tokio::sync::mpsc::{self, OwnedPermit};
10-
use tokio::sync::{Mutex, OwnedMutexGuard};
14+
use tokio::sync::{Mutex, Notify, OwnedMutexGuard};
1115

1216
use super::ChannelMsg;
1317
use crate::{ChannelId, CryptoVec};
@@ -16,13 +20,34 @@ type BoxedThreadsafeFuture<T> = Pin<Box<dyn Sync + Send + std::future::Future<Ou
1620
type OwnedPermitFuture<S> =
1721
BoxedThreadsafeFuture<Result<(OwnedPermit<S>, ChannelMsg, usize), SendError<()>>>;
1822

23+
struct WatchNotification(Pin<Box<dyn Sync + Send + Future<Output = ()>>>);
24+
25+
/// A single future that becomes ready once the window size
26+
/// changes to a positive value
27+
impl WatchNotification {
28+
fn new(n: Arc<Notify>) -> Self {
29+
Self(Box::pin(async move { n.notified().await }))
30+
}
31+
}
32+
33+
impl Future for WatchNotification {
34+
type Output = ();
35+
36+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
37+
let inner = self.deref_mut().0.as_mut();
38+
ready!(inner.poll(cx));
39+
Poll::Ready(())
40+
}
41+
}
42+
1943
pub struct ChannelTx<S> {
2044
sender: mpsc::Sender<S>,
2145
send_fut: Option<OwnedPermitFuture<S>>,
2246
id: ChannelId,
23-
2447
window_size_fut: Option<BoxedThreadsafeFuture<OwnedMutexGuard<u32>>>,
2548
window_size: Arc<Mutex<u32>>,
49+
notify: Arc<Notify>,
50+
window_size_notication: WatchNotification,
2651
max_packet_size: u32,
2752
ext: Option<u32>,
2853
}
@@ -35,43 +60,62 @@ where
3560
sender: mpsc::Sender<S>,
3661
id: ChannelId,
3762
window_size: Arc<Mutex<u32>>,
63+
window_size_notification: Arc<Notify>,
3864
max_packet_size: u32,
3965
ext: Option<u32>,
4066
) -> Self {
4167
Self {
4268
sender,
4369
send_fut: None,
4470
id,
71+
notify: Arc::clone(&window_size_notification),
72+
window_size_notication: WatchNotification::new(window_size_notification),
4573
window_size,
4674
window_size_fut: None,
4775
max_packet_size,
4876
ext,
4977
}
5078
}
5179

52-
fn poll_mk_msg(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<(ChannelMsg, usize)> {
80+
fn poll_writable(&mut self, cx: &mut Context<'_>, buf_len: usize) -> Poll<NonZero<usize>> {
5381
let window_size = self.window_size.clone();
5482
let window_size_fut = self
5583
.window_size_fut
5684
.get_or_insert_with(|| Box::pin(window_size.lock_owned()));
5785
let mut window_size = ready!(window_size_fut.poll_unpin(cx));
5886
self.window_size_fut.take();
5987

60-
let writable = (self.max_packet_size)
61-
.min(*window_size)
62-
.min(buf.len() as u32) as usize;
63-
if writable == 0 {
64-
// TODO fix this busywait
65-
cx.waker().wake_by_ref();
66-
return Poll::Pending;
88+
let writable = (self.max_packet_size).min(*window_size).min(buf_len as u32) as usize;
89+
90+
match NonZero::try_from(writable) {
91+
Ok(w) => {
92+
*window_size -= writable as u32;
93+
if *window_size > 0 {
94+
self.notify.notify_one();
95+
}
96+
Poll::Ready(w)
97+
}
98+
Err(_) => {
99+
drop(window_size);
100+
ready!(self.window_size_notication.poll_unpin(cx));
101+
self.window_size_notication = WatchNotification::new(Arc::clone(&self.notify));
102+
cx.waker().wake_by_ref();
103+
Poll::Pending
104+
}
67105
}
68-
let mut data = CryptoVec::new_zeroed(writable);
69-
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min`
70-
data.copy_from_slice(&buf[..writable]);
71-
data.resize(writable);
106+
}
107+
108+
fn poll_mk_msg(
109+
&mut self,
110+
cx: &mut Context<'_>,
111+
buf: &[u8],
112+
) -> Poll<(ChannelMsg, NonZero<usize>)> {
113+
let writable = ready!(self.poll_writable(cx, buf.len()));
72114

73-
*window_size -= writable as u32;
74-
drop(window_size);
115+
let mut data = CryptoVec::new_zeroed(writable.into());
116+
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.poll_writable`
117+
data.copy_from_slice(&buf[..writable.into()]);
118+
data.resize(writable.into());
75119

76120
let msg = match self.ext {
77121
None => ChannelMsg::Data { data },
@@ -116,11 +160,17 @@ where
116160
cx: &mut Context<'_>,
117161
buf: &[u8],
118162
) -> Poll<Result<usize, io::Error>> {
163+
if buf.is_empty() {
164+
return Poll::Ready(Err(io::Error::new(
165+
io::ErrorKind::WriteZero,
166+
"cannot send empty buffer",
167+
)));
168+
}
119169
let send_fut = if let Some(x) = self.send_fut.as_mut() {
120170
x
121171
} else {
122172
let (msg, writable) = ready!(self.poll_mk_msg(cx, buf));
123-
self.activate(msg, writable)
173+
self.activate(msg, writable.into())
124174
};
125175
let r = ready!(send_fut.as_mut().poll_unpin(cx));
126176
Poll::Ready(self.handle_write_result(r))
@@ -143,3 +193,10 @@ where
143193
Poll::Ready(self.handle_write_result(r).map(drop))
144194
}
145195
}
196+
197+
impl<S> Drop for ChannelTx<S> {
198+
fn drop(&mut self) {
199+
// Allow other writers to make progress
200+
self.notify.notify_one();
201+
}
202+
}

russh/src/channels/mod.rs

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::sync::Arc;
22

33
use tokio::io::{AsyncRead, AsyncWrite};
44
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
5-
use tokio::sync::Mutex;
5+
use tokio::sync::{Mutex, Notify};
66

77
use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig};
88

@@ -112,6 +112,31 @@ pub enum ChannelMsg {
112112
OpenFailure(ChannelOpenFailure),
113113
}
114114

115+
#[derive(Clone, Debug)]
116+
pub(crate) struct WindowSizeRef {
117+
value: Arc<Mutex<u32>>,
118+
notifier: Arc<Notify>,
119+
}
120+
121+
impl WindowSizeRef {
122+
pub(crate) fn new(initial: u32) -> Self {
123+
let notifier = Arc::new(Notify::new());
124+
Self {
125+
value: Arc::new(Mutex::new(initial)),
126+
notifier,
127+
}
128+
}
129+
130+
pub(crate) async fn update(&self, value: u32) {
131+
*self.value.lock().await = value;
132+
self.notifier.notify_one();
133+
}
134+
135+
pub(crate) fn subscribe(&self) -> Arc<Notify> {
136+
Arc::clone(&self.notifier)
137+
}
138+
}
139+
115140
/// A handle to a session channel.
116141
///
117142
/// Allows you to read and write from a channel without borrowing the session
@@ -120,7 +145,7 @@ pub struct Channel<Send: From<(ChannelId, ChannelMsg)>> {
120145
pub(crate) sender: Sender<Send>,
121146
pub(crate) receiver: UnboundedReceiver<ChannelMsg>,
122147
pub(crate) max_packet_size: u32,
123-
pub(crate) window_size: Arc<Mutex<u32>>,
148+
pub(crate) window_size: WindowSizeRef,
124149
}
125150

126151
impl<T: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for Channel<T> {
@@ -137,7 +162,7 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
137162
window_size: u32,
138163
) -> (Self, ChannelRef) {
139164
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
140-
let window_size = Arc::new(Mutex::new(window_size));
165+
let window_size = WindowSizeRef::new(window_size);
141166

142167
(
143168
Self {
@@ -157,7 +182,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
157182
/// Returns the min between the maximum packet size and the
158183
/// remaining window size in the channel.
159184
pub async fn writable_packet_size(&self) -> usize {
160-
self.max_packet_size.min(*self.window_size.lock().await) as usize
185+
self.max_packet_size
186+
.min(*self.window_size.value.lock().await) as usize
161187
}
162188

163189
pub fn id(&self) -> ChannelId {
@@ -337,7 +363,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
337363
io::ChannelTx::new(
338364
self.sender.clone(),
339365
self.id,
340-
self.window_size.clone(),
366+
self.window_size.value.clone(),
367+
self.window_size.subscribe(),
341368
self.max_packet_size,
342369
None,
343370
),
@@ -369,7 +396,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
369396
io::ChannelTx::new(
370397
self.sender.clone(),
371398
self.id,
372-
self.window_size.clone(),
399+
self.window_size.value.clone(),
400+
self.window_size.subscribe(),
373401
self.max_packet_size,
374402
ext,
375403
)

russh/src/client/encrypted.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ impl Session {
630630
new_size -= enc.flush_pending(channel_num)? as u32;
631631
}
632632
if let Some(chan) = self.channels.get(&channel_num) {
633-
*chan.window_size().lock().await = new_size;
633+
chan.window_size().update(new_size).await;
634634

635635
let _ = chan.send(ChannelMsg::WindowAdjusted { new_size });
636636
}

russh/src/client/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ use tokio::pin;
5454
use tokio::sync::mpsc::{
5555
channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender,
5656
};
57-
use tokio::sync::{oneshot, Mutex};
57+
use tokio::sync::oneshot;
5858

59-
use crate::channels::{Channel, ChannelMsg, ChannelRef};
59+
use crate::channels::{Channel, ChannelMsg, ChannelRef, WindowSizeRef};
6060
use crate::cipher::{self, clear, CipherPair, OpeningKey};
6161
use crate::keys::key::parse_public_key;
6262
use crate::session::{
@@ -428,7 +428,7 @@ impl<H: Handler> Handle<H> {
428428
async fn wait_channel_confirmation(
429429
&self,
430430
mut receiver: UnboundedReceiver<ChannelMsg>,
431-
window_size_ref: Arc<Mutex<u32>>,
431+
window_size_ref: WindowSizeRef,
432432
) -> Result<Channel<Msg>, crate::Error> {
433433
loop {
434434
match receiver.recv().await {
@@ -437,7 +437,7 @@ impl<H: Handler> Handle<H> {
437437
max_packet_size,
438438
window_size,
439439
}) => {
440-
*window_size_ref.lock().await = window_size;
440+
window_size_ref.update(window_size).await;
441441

442442
return Ok(Channel {
443443
id,

russh/src/server/encrypted.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,7 @@ impl Session {
763763
enc.flush_pending(channel_num)?;
764764
}
765765
if let Some(chan) = self.channels.get(&channel_num) {
766-
*chan.window_size().lock().await = new_size;
766+
chan.window_size().update(new_size).await;
767767

768768
chan.send(ChannelMsg::WindowAdjusted { new_size })
769769
.unwrap_or(())

russh/src/server/session.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use std::collections::{HashMap, VecDeque};
22
use std::sync::Arc;
33

4+
use channels::WindowSizeRef;
45
use log::debug;
56
use negotiation::parse_kex_algo_list;
67
use russh_keys::helpers::NameList;
78
use russh_keys::map_err;
89
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
910
use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver};
10-
use tokio::sync::{oneshot, Mutex};
11+
use tokio::sync::oneshot;
1112

1213
use super::*;
1314
use crate::channels::{Channel, ChannelMsg, ChannelRef};
@@ -346,7 +347,7 @@ impl Handle {
346347
async fn wait_channel_confirmation(
347348
&self,
348349
mut receiver: UnboundedReceiver<ChannelMsg>,
349-
window_size_ref: Arc<Mutex<u32>>,
350+
window_size_ref: WindowSizeRef,
350351
) -> Result<Channel<Msg>, Error> {
351352
loop {
352353
match receiver.recv().await {
@@ -355,7 +356,7 @@ impl Handle {
355356
max_packet_size,
356357
window_size,
357358
}) => {
358-
*window_size_ref.lock().await = window_size;
359+
window_size_ref.update(window_size).await;
359360

360361
return Ok(Channel {
361362
id,

russh/src/session.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,6 @@ impl Encrypted {
192192
Ok(())
193193
}
194194

195-
/*
196-
pub fn authenticated(&mut self) {
197-
self.server_compression.init_compress(&mut self.compress);
198-
self.state = EncryptedState::Authenticated;
199-
}
200-
*/
201-
202195
pub fn eof(&mut self, channel: ChannelId) -> Result<(), crate::Error> {
203196
if let Some(channel) = self.has_pending_data_mut(channel) {
204197
channel.pending_eof = true;

0 commit comments

Comments
 (0)