diff --git a/components/tikv_util/Cargo.toml b/components/tikv_util/Cargo.toml index d8964cf0301..5b508a4a4d4 100644 --- a/components/tikv_util/Cargo.toml +++ b/components/tikv_util/Cargo.toml @@ -73,3 +73,8 @@ regex = "1.0" tempfile = "3.0" toml = "0.5" utime = "0.2" + +[[bench]] +name = "channel" +path = "benches/channel/mod.rs" +test = true diff --git a/tests/benches/channel/bench_channel.rs b/components/tikv_util/benches/channel/bench_channel.rs similarity index 87% rename from tests/benches/channel/bench_channel.rs rename to components/tikv_util/benches/channel/bench_channel.rs index eb69412046d..6867aab0f56 100644 --- a/tests/benches/channel/bench_channel.rs +++ b/components/tikv_util/benches/channel/bench_channel.rs @@ -113,8 +113,8 @@ fn bench_crossbeam_channel(b: &mut Bencher) { } #[bench] -fn bench_receiver_stream_batch(b: &mut Bencher) { - let (tx, rx) = mpsc::batch::bounded::(128, 8); +fn bench_receiver_stream_unbounded_batch(b: &mut Bencher) { + let (tx, rx) = mpsc::future::unbounded::(mpsc::future::WakePolicy::TillReach(8)); for _ in 0..1 { let tx1 = tx.clone(); thread::spawn(move || { @@ -124,12 +124,9 @@ fn bench_receiver_stream_batch(b: &mut Bencher) { }); } - let mut rx = Some(mpsc::batch::BatchReceiver::new( - rx, - 32, - Vec::new, - mpsc::batch::VecCollector, - )); + let rx = mpsc::future::BatchReceiver::new(rx, 32, Vec::new, Vec::push); + + let mut rx = Some(block_on(rx.into_future()).1); b.iter(|| { let mut count = 0; @@ -150,8 +147,8 @@ fn bench_receiver_stream_batch(b: &mut Bencher) { } #[bench] -fn bench_receiver_stream(b: &mut Bencher) { - let (tx, rx) = mpsc::batch::bounded::(128, 1); +fn bench_receiver_stream_unbounded_nobatch(b: &mut Bencher) { + let (tx, rx) = mpsc::future::unbounded::(mpsc::future::WakePolicy::Immediately); for _ in 0..1 { let tx1 = tx.clone(); thread::spawn(move || { @@ -161,7 +158,7 @@ fn bench_receiver_stream(b: &mut Bencher) { }); } - let mut rx = Some(rx); + let mut rx = Some(block_on(rx.into_future()).1); b.iter(|| { let mut count = 0; let mut rx1 = rx.take().unwrap(); diff --git a/tests/benches/channel/mod.rs b/components/tikv_util/benches/channel/mod.rs similarity index 100% rename from tests/benches/channel/mod.rs rename to components/tikv_util/benches/channel/mod.rs diff --git a/components/tikv_util/src/mpsc/batch.rs b/components/tikv_util/src/mpsc/batch.rs deleted file mode 100644 index 0415f9376af..00000000000 --- a/components/tikv_util/src/mpsc/batch.rs +++ /dev/null @@ -1,509 +0,0 @@ -// Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0. - -use std::{ - pin::Pin, - ptr::null_mut, - sync::{ - atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}, - Arc, - }, - time::Duration, -}; - -use crossbeam::channel::{ - self, RecvError, RecvTimeoutError, SendError, TryRecvError, TrySendError, -}; -use futures::{ - stream::Stream, - task::{Context, Poll, Waker}, -}; - -struct State { - // If the receiver can't get any messages temporarily in `poll` context, it will put its - // current task here. - recv_task: AtomicPtr, - notify_size: usize, - // How many messages are sent without notify. - pending: AtomicUsize, - notifier_registered: AtomicBool, -} - -impl State { - fn new(notify_size: usize) -> State { - State { - // Any pointer that is put into `recv_task` must be a valid and owned - // pointer (it must not be dropped). When a pointer is retrieved from - // `recv_task`, the user is responsible for its proper destruction. - recv_task: AtomicPtr::new(null_mut()), - notify_size, - pending: AtomicUsize::new(0), - notifier_registered: AtomicBool::new(false), - } - } - - #[inline] - fn try_notify_post_send(&self) { - let old_pending = self.pending.fetch_add(1, Ordering::AcqRel); - if old_pending >= self.notify_size - 1 { - self.notify(); - } - } - - #[inline] - fn notify(&self) { - let t = self.recv_task.swap(null_mut(), Ordering::AcqRel); - if !t.is_null() { - self.pending.store(0, Ordering::Release); - // Safety: see comment on `recv_task`. - let t = unsafe { Box::from_raw(t) }; - t.wake(); - } - } - - /// When the `Receiver` that holds the `State` is running on an `Executor`, - /// the `Receiver` calls this to yield from the current `poll` context, - /// and puts the current task handle to `recv_task`, so that the `Sender` - /// respectively can notify it after sending some messages into the channel. - #[inline] - fn yield_poll(&self, waker: Waker) -> bool { - let t = Box::into_raw(Box::new(waker)); - let origin = self.recv_task.swap(t, Ordering::AcqRel); - if !origin.is_null() { - // Safety: see comment on `recv_task`. - unsafe { drop(Box::from_raw(origin)) }; - return true; - } - false - } -} - -impl Drop for State { - fn drop(&mut self) { - let t = self.recv_task.swap(null_mut(), Ordering::AcqRel); - if !t.is_null() { - // Safety: see comment on `recv_task`. - unsafe { drop(Box::from_raw(t)) }; - } - } -} - -/// `Notifier` is used to notify receiver whenever you want. -pub struct Notifier(Arc); -impl Notifier { - #[inline] - pub fn notify(self) { - drop(self); - } -} - -impl Drop for Notifier { - #[inline] - fn drop(&mut self) { - let notifier_registered = &self.0.notifier_registered; - if notifier_registered - .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire) - .is_err() - { - unreachable!("notifier_registered must be true"); - } - self.0.notify(); - } -} - -pub struct Sender { - sender: Option>, - state: Arc, -} - -impl Clone for Sender { - #[inline] - fn clone(&self) -> Sender { - Sender { - sender: self.sender.clone(), - state: Arc::clone(&self.state), - } - } -} - -impl Drop for Sender { - #[inline] - fn drop(&mut self) { - drop(self.sender.take()); - self.state.notify(); - } -} - -pub struct Receiver { - receiver: channel::Receiver, - state: Arc, -} - -impl Sender { - pub fn is_empty(&self) -> bool { - // When there is no sender references, it can't be known whether - // it's empty or not. - self.sender.as_ref().map_or(false, |s| s.is_empty()) - } - - #[inline] - pub fn send(&self, t: T) -> Result<(), SendError> { - self.sender.as_ref().unwrap().send(t)?; - self.state.try_notify_post_send(); - Ok(()) - } - - #[inline] - pub fn send_and_notify(&self, t: T) -> Result<(), SendError> { - self.sender.as_ref().unwrap().send(t)?; - self.state.notify(); - Ok(()) - } - - #[inline] - pub fn try_send(&self, t: T) -> Result<(), TrySendError> { - self.sender.as_ref().unwrap().try_send(t)?; - self.state.try_notify_post_send(); - Ok(()) - } - - #[inline] - pub fn get_notifier(&self) -> Option { - let notifier_registered = &self.state.notifier_registered; - if notifier_registered - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) - .is_ok() - { - return Some(Notifier(Arc::clone(&self.state))); - } - None - } -} - -impl Receiver { - #[inline] - pub fn recv(&self) -> Result { - self.receiver.recv() - } - - #[inline] - pub fn try_recv(&self) -> Result { - self.receiver.try_recv() - } - - #[inline] - pub fn recv_timeout(&self, timeout: Duration) -> Result { - self.receiver.recv_timeout(timeout) - } -} - -/// Creates a unbounded channel with a given `notify_size`, which means if there -/// are more pending messages in the channel than `notify_size`, the `Sender` -/// will auto notify the `Receiver`. -/// -/// # Panics -/// if `notify_size` equals to 0. -#[inline] -pub fn unbounded(notify_size: usize) -> (Sender, Receiver) { - assert!(notify_size > 0); - let state = Arc::new(State::new(notify_size)); - let (sender, receiver) = channel::unbounded(); - ( - Sender { - sender: Some(sender), - state: state.clone(), - }, - Receiver { receiver, state }, - ) -} - -/// Creates a bounded channel with a given `notify_size`, which means if there -/// are more pending messages in the channel than `notify_size`, the `Sender` -/// will auto notify the `Receiver`. -/// -/// # Panics -/// if `notify_size` equals to 0. -#[inline] -pub fn bounded(cap: usize, notify_size: usize) -> (Sender, Receiver) { - assert!(notify_size > 0); - let state = Arc::new(State::new(notify_size)); - let (sender, receiver) = channel::bounded(cap); - ( - Sender { - sender: Some(sender), - state: state.clone(), - }, - Receiver { receiver, state }, - ) -} - -impl Stream for Receiver { - type Item = T; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.try_recv() { - Ok(m) => Poll::Ready(Some(m)), - Err(TryRecvError::Empty) => { - if self.state.yield_poll(cx.waker().clone()) { - Poll::Pending - } else { - // For the case that all senders are dropped before the current task is saved. - self.poll_next(cx) - } - } - Err(TryRecvError::Disconnected) => Poll::Ready(None), - } - } -} - -/// A Collector Used in `BatchReceiver`. -pub trait BatchCollector { - /// If `elem` is collected into `collection` successfully, return `None`. - /// Otherwise return `elem` back, and `collection` should be spilled out. - fn collect(&mut self, collection: &mut Collection, elem: Elem) -> Option; -} - -pub struct VecCollector; - -impl BatchCollector, E> for VecCollector { - fn collect(&mut self, v: &mut Vec, e: E) -> Option { - v.push(e); - None - } -} - -/// `BatchReceiver` is a `futures::Stream`, which returns a batched type. -pub struct BatchReceiver { - rx: Receiver, - max_batch_size: usize, - elem: Option, - initializer: I, - collector: C, -} - -impl BatchReceiver -where - T: Unpin, - E: Unpin, - I: Fn() -> E + Unpin, - C: BatchCollector + Unpin, -{ - /// Creates a new `BatchReceiver` with given `initializer` and `collector`. - /// `initializer` is used to generate a initial value, and `collector` - /// will collect every (at most `max_batch_size`) raw items into the - /// batched value. - pub fn new(rx: Receiver, max_batch_size: usize, initializer: I, collector: C) -> Self { - BatchReceiver { - rx, - max_batch_size, - elem: None, - initializer, - collector, - } - } -} - -impl Stream for BatchReceiver -where - T: Unpin, - E: Unpin, - I: Fn() -> E + Unpin, - C: BatchCollector + Unpin, -{ - type Item = E; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let ctx = self.get_mut(); - let (mut count, mut received) = (0, None); - let finished = loop { - match ctx.rx.try_recv() { - Ok(m) => { - let collection = ctx.elem.get_or_insert_with(&ctx.initializer); - if let Some(m) = ctx.collector.collect(collection, m) { - received = Some(m); - break false; - } - count += 1; - if count >= ctx.max_batch_size { - break false; - } - } - Err(TryRecvError::Disconnected) => break true, - Err(TryRecvError::Empty) => { - if ctx.rx.state.yield_poll(cx.waker().clone()) { - break false; - } - } - } - }; - - if ctx.elem.is_none() && finished { - return Poll::Ready(None); - } else if ctx.elem.is_none() { - return Poll::Pending; - } - let elem = ctx.elem.take(); - if let Some(m) = received { - let collection = ctx.elem.get_or_insert_with(&ctx.initializer); - let _received = ctx.collector.collect(collection, m); - debug_assert!(_received.is_none()); - } - Poll::Ready(elem) - } -} - -#[cfg(test)] -mod tests { - use std::{ - sync::{mpsc, Mutex}, - thread, time, - }; - - use futures::{ - future::{self, BoxFuture, FutureExt}, - stream::{self, StreamExt}, - task::{self, ArcWake, Poll}, - }; - use tokio::runtime::Builder; - - use super::*; - - #[test] - fn test_receiver() { - let (tx, rx) = unbounded::(4); - - let msg_counter = Arc::new(AtomicUsize::new(0)); - let msg_counter1 = Arc::clone(&msg_counter); - let pool = Builder::new_multi_thread() - .worker_threads(1) - .build() - .unwrap(); - let _res = pool.spawn(rx.for_each(move |_| { - msg_counter1.fetch_add(1, Ordering::AcqRel); - future::ready(()) - })); - - // Wait until the receiver is suspended. - loop { - thread::sleep(time::Duration::from_millis(10)); - if !tx.state.recv_task.load(Ordering::SeqCst).is_null() { - break; - } - } - - // Send without notify, the receiver can't get batched messages. - tx.send(0).unwrap(); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 0); - - // Send with notify. - let notifier = tx.get_notifier().unwrap(); - assert!(tx.get_notifier().is_none()); - notifier.notify(); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 1); - - // Auto notify with more sendings. - for _ in 0..4 { - tx.send(0).unwrap(); - } - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 5); - } - - #[test] - fn test_batch_receiver() { - let (tx, rx) = unbounded::(4); - - let rx = BatchReceiver::new(rx, 8, || Vec::with_capacity(4), VecCollector); - let msg_counter = Arc::new(AtomicUsize::new(0)); - let msg_counter_spawned = Arc::clone(&msg_counter); - let (nty, polled) = mpsc::sync_channel(1); - let pool = Builder::new_multi_thread() - .worker_threads(1) - .build() - .unwrap(); - let _res = pool.spawn( - stream::select( - rx, - stream::poll_fn(move |_| -> Poll>> { - nty.send(()).unwrap(); - Poll::Ready(None) - }), - ) - .for_each(move |v| { - let len = v.len(); - assert!(len <= 8); - msg_counter_spawned.fetch_add(len, Ordering::AcqRel); - future::ready(()) - }), - ); - - // Wait until the receiver has been polled in the spawned thread. - polled.recv().unwrap(); - - // Send without notify, the receiver can't get batched messages. - tx.send(0).unwrap(); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 0); - - // Send with notify. - let notifier = tx.get_notifier().unwrap(); - assert!(tx.get_notifier().is_none()); - notifier.notify(); - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 1); - - // Auto notify with more sendings. - for _ in 0..16 { - tx.send(0).unwrap(); - } - thread::sleep(time::Duration::from_millis(10)); - assert_eq!(msg_counter.load(Ordering::Acquire), 17); - } - - #[test] - fn test_switch_between_sender_and_receiver() { - let (tx, mut rx) = unbounded::(4); - let future = async move { rx.next().await }; - let task = Task { - future: Arc::new(Mutex::new(Some(future.boxed()))), - }; - // Receiver has not received any messages, so the future is not be finished - // in this tick. - task.tick(); - assert!(task.future.lock().unwrap().is_some()); - // After sender is dropped, the task will be waked and then it tick self - // again to advance the progress. - drop(tx); - assert!(task.future.lock().unwrap().is_none()); - } - - #[derive(Clone)] - struct Task { - future: Arc>>>>, - } - - impl Task { - fn tick(&self) { - let task = Arc::new(self.clone()); - let mut future_slot = self.future.lock().unwrap(); - if let Some(mut future) = future_slot.take() { - let waker = task::waker_ref(&task); - let cx = &mut Context::from_waker(&waker); - match future.as_mut().poll(cx) { - Poll::Pending => { - *future_slot = Some(future); - } - Poll::Ready(None) => {} - _ => unimplemented!(), - } - } - } - } - - impl ArcWake for Task { - fn wake_by_ref(arc_self: &Arc) { - arc_self.tick(); - } - } -} diff --git a/components/tikv_util/src/mpsc/future.rs b/components/tikv_util/src/mpsc/future.rs new file mode 100644 index 00000000000..c38dc8c1492 --- /dev/null +++ b/components/tikv_util/src/mpsc/future.rs @@ -0,0 +1,431 @@ +// Copyright 2022 TiKV Project Authors. Licensed under Apache-2.0. + +//! A module provides the implementation of receiver that supports async/await. + +use std::{ + pin::Pin, + sync::atomic::{self, AtomicUsize, Ordering}, + task::{Context, Poll}, +}; + +use crossbeam::{ + channel::{SendError, TryRecvError}, + queue::SegQueue, +}; +use futures::{task::AtomicWaker, Stream, StreamExt}; + +#[derive(Clone, Copy)] +pub enum WakePolicy { + Immediately, + TillReach(usize), +} + +struct Queue { + queue: SegQueue, + waker: AtomicWaker, + liveness: AtomicUsize, + policy: WakePolicy, +} + +impl Queue { + #[inline] + fn wake(&self, policy: WakePolicy) { + match policy { + WakePolicy::Immediately => self.waker.wake(), + WakePolicy::TillReach(n) => { + if self.queue.len() < n { + return; + } + self.waker.wake(); + } + } + } +} + +const SENDER_COUNT_BASE: usize = 1 << 1; +const RECEIVER_COUNT_BASE: usize = 1; + +pub struct Sender { + queue: *mut Queue, +} + +impl Sender { + /// Sends the message with predefined wake policy. + #[inline] + pub fn send(&self, t: T) -> Result<(), SendError> { + let policy = unsafe { (*self.queue).policy }; + self.send_with(t, policy) + } + + /// Sends the message with the specified wake policy. + #[inline] + pub fn send_with(&self, t: T, policy: WakePolicy) -> Result<(), SendError> { + let queue = unsafe { &*self.queue }; + if queue.liveness.load(Ordering::Acquire) & RECEIVER_COUNT_BASE != 0 { + queue.queue.push(t); + queue.wake(policy); + return Ok(()); + } + Err(SendError(t)) + } +} + +impl Clone for Sender { + fn clone(&self) -> Self { + let queue = unsafe { &*self.queue }; + queue + .liveness + .fetch_add(SENDER_COUNT_BASE, Ordering::Relaxed); + Self { queue: self.queue } + } +} + +impl Drop for Sender { + #[inline] + fn drop(&mut self) { + let queue = unsafe { &*self.queue }; + let previous = queue + .liveness + .fetch_sub(SENDER_COUNT_BASE, Ordering::Release); + if previous == SENDER_COUNT_BASE | RECEIVER_COUNT_BASE { + // The last sender is dropped, we need to wake up the receiver. + queue.waker.wake(); + } else if previous == SENDER_COUNT_BASE { + atomic::fence(Ordering::Acquire); + drop(unsafe { Box::from_raw(self.queue) }); + } + } +} + +unsafe impl Send for Sender {} +unsafe impl Sync for Sender {} + +pub struct Receiver { + queue: *mut Queue, +} + +impl Stream for Receiver { + type Item = T; + + #[inline] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let queue = unsafe { &*self.queue }; + if let Some(t) = queue.queue.pop() { + return Poll::Ready(Some(t)); + } + queue.waker.register(cx.waker()); + // In case the message is pushed right before registering waker. + if let Some(t) = queue.queue.pop() { + return Poll::Ready(Some(t)); + } + if queue.liveness.load(Ordering::Acquire) & !RECEIVER_COUNT_BASE != 0 { + return Poll::Pending; + } + Poll::Ready(None) + } +} + +impl Receiver { + #[inline] + pub fn try_recv(&mut self) -> Result { + let queue = unsafe { &*self.queue }; + if let Some(t) = queue.queue.pop() { + return Ok(t); + } + if queue.liveness.load(Ordering::Acquire) & !RECEIVER_COUNT_BASE != 0 { + return Err(TryRecvError::Empty); + } + Err(TryRecvError::Disconnected) + } +} + +impl Drop for Receiver { + #[inline] + fn drop(&mut self) { + let queue = unsafe { &*self.queue }; + if RECEIVER_COUNT_BASE + == queue + .liveness + .fetch_sub(RECEIVER_COUNT_BASE, Ordering::Release) + { + atomic::fence(Ordering::Acquire); + drop(unsafe { Box::from_raw(self.queue) }); + } + } +} + +unsafe impl Send for Receiver {} + +pub fn unbounded(policy: WakePolicy) -> (Sender, Receiver) { + let queue = Box::into_raw(Box::new(Queue { + queue: SegQueue::new(), + waker: AtomicWaker::new(), + liveness: AtomicUsize::new(SENDER_COUNT_BASE | RECEIVER_COUNT_BASE), + policy, + })); + (Sender { queue }, Receiver { queue }) +} + +/// `BatchReceiver` is a `futures::Stream`, which returns a batched type. +pub struct BatchReceiver { + rx: Receiver, + max_batch_size: usize, + initializer: I, + collector: C, +} + +impl BatchReceiver { + /// Creates a new `BatchReceiver` with given `initializer` and `collector`. + /// `initializer` is used to generate a initial value, and `collector` + /// will collect every (at most `max_batch_size`) raw items into the + /// batched value. + pub fn new(rx: Receiver, max_batch_size: usize, initializer: I, collector: C) -> Self { + BatchReceiver { + rx, + max_batch_size, + initializer, + collector, + } + } +} + +impl Stream for BatchReceiver +where + T: Send + Unpin, + E: Unpin, + I: Fn() -> E + Unpin, + C: FnMut(&mut E, T) + Unpin, +{ + type Item = E; + + #[inline] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let ctx = self.get_mut(); + let mut collector = match ctx.rx.poll_next_unpin(cx) { + Poll::Ready(Some(m)) => { + let mut c = (ctx.initializer)(); + (ctx.collector)(&mut c, m); + c + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + }; + for _ in 1..ctx.max_batch_size { + if let Poll::Ready(Some(m)) = ctx.rx.poll_next_unpin(cx) { + (ctx.collector)(&mut collector, m); + } + } + Poll::Ready(Some(collector)) + } +} + +#[cfg(test)] +mod tests { + use std::{ + sync::{ + atomic::{AtomicBool, AtomicUsize}, + mpsc, Arc, Mutex, + }, + thread, time, + }; + + use futures::{ + future::{self, BoxFuture, FutureExt}, + stream::{self, StreamExt}, + task::{self, ArcWake, Poll}, + }; + use tokio::runtime::{Builder, Runtime}; + + use super::*; + + fn spawn_and_wait( + rx_builder: impl FnOnce() -> S, + ) -> (Runtime, Arc) { + let msg_counter = Arc::new(AtomicUsize::new(0)); + let msg_counter1 = msg_counter.clone(); + let pool = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + let (nty, polled) = mpsc::sync_channel(1); + _ = pool.spawn( + stream::select( + rx_builder(), + stream::poll_fn(move |_| -> Poll> { + nty.send(()).unwrap(); + Poll::Ready(None) + }), + ) + .for_each(move |_| { + msg_counter1.fetch_add(1, Ordering::AcqRel); + future::ready(()) + }), + ); + + // Wait until the receiver has been polled in the spawned thread. + polled.recv().unwrap(); + (pool, msg_counter) + } + + #[test] + fn test_till_reach_wake() { + let (tx, rx) = unbounded::(WakePolicy::TillReach(4)); + + let (_pool, msg_counter) = spawn_and_wait(move || rx); + + // Receiver should not be woken up until its length reach specified value. + for _ in 0..3 { + tx.send(0).unwrap(); + } + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 0); + + tx.send(0).unwrap(); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 4); + + // Should start new batch. + tx.send(0).unwrap(); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 4); + + let tx1 = tx.clone(); + drop(tx); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 4); + // If all senders are dropped, receiver should be woken up. + drop(tx1); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 5); + } + + #[test] + fn test_immediately_wake() { + let (tx, rx) = unbounded::(WakePolicy::Immediately); + + let (_pool, msg_counter) = spawn_and_wait(move || rx); + + // Receiver should be woken up immediately. + for _ in 0..3 { + tx.send(0).unwrap(); + } + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 3); + + tx.send(0).unwrap(); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::Acquire), 4); + } + + #[test] + fn test_batch_receiver() { + let (tx, rx) = unbounded::(WakePolicy::TillReach(4)); + + let len = Arc::new(AtomicUsize::new(0)); + let l = len.clone(); + let rx = BatchReceiver::new(rx, 8, || Vec::with_capacity(4), Vec::push); + let (_pool, msg_counter) = spawn_and_wait(move || { + stream::unfold((rx, l), |(mut rx, l)| async move { + rx.next().await.map(|i| { + l.fetch_add(i.len(), Ordering::SeqCst); + (i, (rx, l)) + }) + }) + }); + + tx.send(0).unwrap(); + thread::sleep(time::Duration::from_millis(10)); + assert_eq!(msg_counter.load(Ordering::SeqCst), 0); + + // Auto notify with more messages. + for _ in 0..16 { + tx.send(0).unwrap(); + } + thread::sleep(time::Duration::from_millis(10)); + let batch_count = msg_counter.load(Ordering::SeqCst); + assert!(batch_count < 17, "{}", batch_count); + assert_eq!(len.load(Ordering::SeqCst), 17); + } + + #[test] + fn test_switch_between_sender_and_receiver() { + let (tx, mut rx) = unbounded::(WakePolicy::TillReach(4)); + let future = async move { rx.next().await }; + let task = Task { + future: Arc::new(Mutex::new(Some(future.boxed()))), + }; + // Receiver has not received any messages, so the future is not be finished + // in this tick. + task.tick(); + assert!(task.future.lock().unwrap().is_some()); + // After sender is dropped, the task will be waked and then it tick self + // again to advance the progress. + drop(tx); + assert!(task.future.lock().unwrap().is_none()); + } + + #[derive(Clone)] + struct Task { + future: Arc>>>>, + } + + impl Task { + fn tick(&self) { + let task = Arc::new(self.clone()); + let mut future_slot = self.future.lock().unwrap(); + if let Some(mut future) = future_slot.take() { + let waker = task::waker_ref(&task); + let cx = &mut Context::from_waker(&waker); + match future.as_mut().poll(cx) { + Poll::Pending => { + *future_slot = Some(future); + } + Poll::Ready(None) => {} + _ => unimplemented!(), + } + } + } + } + + impl ArcWake for Task { + fn wake_by_ref(arc_self: &Arc) { + arc_self.tick(); + } + } + + #[derive(Default)] + struct SetOnDrop(Arc); + + impl Drop for SetOnDrop { + fn drop(&mut self) { + self.0.store(true, Ordering::Release); + } + } + + #[test] + fn test_drop() { + let dropped = Arc::new(AtomicBool::new(false)); + let (tx, rx) = super::unbounded(WakePolicy::Immediately); + tx.send(SetOnDrop(dropped.clone())).unwrap(); + drop(tx); + assert!(!dropped.load(Ordering::SeqCst)); + + drop(rx); + assert!(dropped.load(Ordering::SeqCst)); + + let dropped = Arc::new(AtomicBool::new(false)); + let (tx, rx) = super::unbounded(WakePolicy::Immediately); + tx.send(SetOnDrop(dropped.clone())).unwrap(); + drop(rx); + assert!(!dropped.load(Ordering::SeqCst)); + + tx.send(SetOnDrop::default()).unwrap_err(); + let tx1 = tx.clone(); + drop(tx); + assert!(!dropped.load(Ordering::SeqCst)); + + tx1.send(SetOnDrop::default()).unwrap_err(); + drop(tx1); + assert!(dropped.load(Ordering::SeqCst)); + } +} diff --git a/components/tikv_util/src/mpsc/mod.rs b/components/tikv_util/src/mpsc/mod.rs index ccec5448d0b..45249fed9bc 100644 --- a/components/tikv_util/src/mpsc/mod.rs +++ b/components/tikv_util/src/mpsc/mod.rs @@ -3,7 +3,7 @@ //! This module provides an implementation of mpsc channel based on //! crossbeam_channel. Comparing to the crossbeam_channel, this implementation //! supports closed detection and try operations. -pub mod batch; +pub mod future; use std::{ cell::Cell, diff --git a/src/server/service/batch.rs b/src/server/service/batch.rs index 15a755c3468..ba377bed4d2 100644 --- a/src/server/service/batch.rs +++ b/src/server/service/batch.rs @@ -3,7 +3,11 @@ // #[PerformanceCriticalPath] use api_version::KvFormat; use kvproto::kvrpcpb::*; -use tikv_util::{future::poll_future_notify, mpsc::batch::Sender, time::Instant}; +use tikv_util::{ + future::poll_future_notify, + mpsc::future::{Sender, WakePolicy}, + time::Instant, +}; use tracker::{with_tls_tracker, RequestInfo, RequestType, Tracker, TrackerToken, GLOBAL_TRACKERS}; use crate::{ @@ -184,7 +188,7 @@ impl ResponseBatchConsumer<(Option>, Statistics)> for GetCommandResponse let mesure = GrpcRequestDuration::new(begin, GrpcTypeKind::kv_batch_get_command, request_source); let task = MeasuredSingleResponse::new(id, res, mesure); - if self.tx.send_and_notify(task).is_err() { + if self.tx.send_with(task, WakePolicy::Immediately).is_err() { error!("KvService response batch commands fail"); } } @@ -215,7 +219,7 @@ impl ResponseBatchConsumer>> for GetCommandResponseConsumer { let mesure = GrpcRequestDuration::new(begin, GrpcTypeKind::raw_batch_get_command, request_source); let task = MeasuredSingleResponse::new(id, res, mesure); - if self.tx.send_and_notify(task).is_err() { + if self.tx.send_with(task, WakePolicy::Immediately).is_err() { error!("KvService response batch commands fail"); } } @@ -264,7 +268,7 @@ fn future_batch_get_command( source, ); let task = MeasuredSingleResponse::new(id, res, measure); - if tx.send_and_notify(task).is_err() { + if tx.send_with(task, WakePolicy::Immediately).is_err() { error!("KvService response batch commands fail"); } } @@ -310,7 +314,7 @@ fn future_batch_raw_get_command( source, ); let task = MeasuredSingleResponse::new(id, res, measure); - if tx.send_and_notify(task).is_err() { + if tx.send_with(task, WakePolicy::Immediately).is_err() { error!("KvService response batch commands fail"); } } diff --git a/src/server/service/kv.rs b/src/server/service/kv.rs index ab2fc41c47c..35deb7e4107 100644 --- a/src/server/service/kv.rs +++ b/src/server/service/kv.rs @@ -39,7 +39,7 @@ use raftstore::{ use tikv_alloc::trace::MemoryTraceGuard; use tikv_util::{ future::{paired_future_callback, poll_future_notify}, - mpsc::batch::{unbounded, BatchCollector, BatchReceiver, Sender}, + mpsc::future::{unbounded, BatchReceiver, Sender, WakePolicy}, sys::memory_usage_reaches_high_water, time::{duration_to_ms, duration_to_sec, Instant}, worker::Scheduler, @@ -1049,7 +1049,7 @@ impl + 'static, E: Engine, L: LockManager, F: KvFor mut sink: DuplexSink, ) { forward_duplex!(self.proxy, batch_commands, ctx, stream, sink); - let (tx, rx) = unbounded(GRPC_MSG_NOTIFY_SIZE); + let (tx, rx) = unbounded(WakePolicy::TillReach(GRPC_MSG_NOTIFY_SIZE)); let ctx = Arc::new(ctx); let peer = ctx.peer(); @@ -1093,7 +1093,7 @@ impl + 'static, E: Engine, L: LockManager, F: KvFor rx, GRPC_MSG_MAX_BATCH_SIZE, MeasuredBatchResponse::default, - BatchRespCollector, + collect_batch_resp, ); let mut response_retriever = response_retriever.map(move |mut item| { @@ -1268,7 +1268,7 @@ fn response_batch_commands_request( source, }; let task = MeasuredSingleResponse::new(id, resp, measure); - if let Err(e) = tx.send_and_notify(task) { + if let Err(e) = tx.send_with(task, WakePolicy::Immediately) { error!("KvService response batch commands fail"; "err" => ?e); } } @@ -2354,18 +2354,10 @@ impl Default for MeasuredBatchResponse { } } -struct BatchRespCollector; -impl BatchCollector for BatchRespCollector { - fn collect( - &mut self, - v: &mut MeasuredBatchResponse, - mut e: MeasuredSingleResponse, - ) -> Option { - v.batch_resp.mut_request_ids().push(e.id); - v.batch_resp.mut_responses().push(e.resp.consume()); - v.measures.push(e.measure); - None - } +fn collect_batch_resp(v: &mut MeasuredBatchResponse, mut e: MeasuredSingleResponse) { + v.batch_resp.mut_request_ids().push(e.id); + v.batch_resp.mut_responses().push(e.resp.consume()); + v.measures.push(e.measure); } fn raftstore_error_to_region_error(e: RaftStoreError, region_id: u64) -> RegionError { diff --git a/tests/Cargo.toml b/tests/Cargo.toml index b155ae4ab87..5c573b6e809 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -39,11 +39,6 @@ name = "deadlock_detector" harness = false path = "benches/deadlock_detector/mod.rs" -[[bench]] -name = "channel" -path = "benches/channel/mod.rs" -test = true - [features] default = ["failpoints", "testexport", "test-engine-kv-rocksdb", "test-engine-raft-raft-engine", "cloud-aws", "cloud-gcp", "cloud-azure"] failpoints = ["fail/failpoints", "tikv/failpoints"]