log_backup: make a more rusty CallbackWaitGroup #16740

Merged (4 commits) on Apr 3, 2024
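
This PR replaces the callback-based `CallbackWaitGroup` in `backup-stream` with a `FutureWaitGroup`: instead of parking boxed `FnOnce` callbacks (bridged to each waiter through a `oneshot` channel and a `BoxFuture`), `wait` now returns a named `WaitAll<'_>` future that registers its `Waker` when polled. In outline, the API change looks like this (a signature-level sketch only; the full implementation is in the `utils.rs` diff below):

```rust
// Before: each waiter allocates a oneshot channel and a boxed callback.
pub fn wait(&self) -> BoxFuture<()> { /* Box::pin(rx.map(|_| ())) */ }

// After: waiting is just constructing a small future; the Waker is
// registered lazily in `poll`, with no per-waiter channel allocation.
pub fn wait(&self) -> WaitAll<'_> {
    WaitAll(self)
}
```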
components/backup-stream/src/endpoint.rs (2 additions, 2 deletions)
@@ -63,7 +63,7 @@ use crate::{
     subscription_manager::{RegionSubscriptionManager, ResolvedRegions},
     subscription_track::{Ref, RefMut, ResolveResult, SubscriptionTracer},
     try_send,
-    utils::{self, CallbackWaitGroup, StopWatch, Work},
+    utils::{self, FutureWaitGroup, StopWatch, Work},
 };

 const SLOW_EVENT_THRESHOLD: f64 = 120.0;
@@ -1118,7 +1118,7 @@ where
     }

     pub fn do_backup(&self, events: Vec<CmdBatch>) {
-        let wg = CallbackWaitGroup::new();
+        let wg = FutureWaitGroup::new();
         for batch in events {
             self.backup_batch(batch, wg.clone().work());
         }
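
The call sites change only in name: `do_backup` keeps the guard pattern in which every `backup_batch` call receives a `Work` token and the group completes once all tokens are dropped. A minimal sketch of that pattern with the new type (the `tokio::spawn` driver and `run_all` are invented for illustration; `new`, `work`, and `wait` are the APIs from this PR):

```rust
use crate::utils::FutureWaitGroup;

// Hypothetical driver mirroring `do_backup`: one `Work` guard per task.
async fn run_all(n_tasks: usize) {
    let wg = FutureWaitGroup::new();
    for i in 0..n_tasks {
        let work = wg.clone().work(); // registers one unit of work
        tokio::spawn(async move {
            // ... do the real batch work for task `i` here ...
            let _ = i;
            drop(work); // dropping the guard marks this unit as done
        });
    }
    wg.wait().await; // resolves once every `Work` guard is dropped
}
```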
components/backup-stream/src/subscription_manager.rs (11 additions, 9 deletions)
@@ -33,7 +33,7 @@ use crate::{
     router::{Router, TaskSelector},
     subscription_track::{CheckpointType, Ref, RefMut, ResolveResult, SubscriptionTracer},
     try_send,
-    utils::{self, CallbackWaitGroup, Work},
+    utils::{self, FutureWaitGroup, Work},
     Task,
 };

@@ -322,7 +322,7 @@ pub struct RegionSubscriptionManager<S, R> {

     messenger: WeakSender<ObserveOp>,
     scan_pool_handle: ScanPoolHandle,
-    scans: Arc<CallbackWaitGroup>,
+    scans: Arc<FutureWaitGroup>,
 }

 /// Create a pool for doing initial scanning.
@@ -374,7 +374,7 @@ where
             subs: initial_loader.tracing,
             messenger: tx.downgrade(),
             scan_pool_handle,
-            scans: CallbackWaitGroup::new(),
+            scans: FutureWaitGroup::new(),
             failure_count: HashMap::new(),
             memory_manager: Arc::clone(&initial_loader.quota),
         };
@@ -383,8 +383,10 @@
     }

     /// wait initial scanning get finished.
-    pub fn wait(&self, timeout: Duration) -> future![bool] {
-        tokio::time::timeout(timeout, self.scans.wait()).map(|result| result.is_err())
+    pub async fn wait(&self, timeout: Duration) -> bool {
+        tokio::time::timeout(timeout, self.scans.wait())
+            .map(move |result| result.is_err())
+            .await
     }

     fn issue_fatal_of(&self, region: &Region, err: Error) {
@@ -859,7 +861,7 @@ mod test {
         router::{Router, RouterInner},
         subscription_manager::{OOM_BACKOFF_BASE, OOM_BACKOFF_JITTER_SECS},
         subscription_track::{CheckpointType, SubscriptionTracer},
-        utils::CallbackWaitGroup,
+        utils::FutureWaitGroup,
         BackupStreamResolver, ObserveOp, Task,
     };

@@ -903,7 +905,7 @@
     use futures::executor::block_on;

     use super::ScanCmd;
-    use crate::{subscription_manager::spawn_executors, utils::CallbackWaitGroup};
+    use crate::{subscription_manager::spawn_executors, utils::FutureWaitGroup};

     fn should_finish_in(f: impl FnOnce() + Send + 'static, d: std::time::Duration) {
         let (tx, rx) = futures::channel::oneshot::channel();
@@ -920,7 +922,7 @@
     }

         let pool = spawn_executors(FuncInitialScan(|_, _, _| Ok(Statistics::default())), 1);
-        let wg = CallbackWaitGroup::new();
+        let wg = FutureWaitGroup::new();
         let (tx, _) = tokio::sync::mpsc::channel(1);
         fail::cfg("execute_scan_command_sleep_100", "return").unwrap();
         for _ in 0..100 {
@@ -1073,7 +1075,7 @@
             memory_manager,
             messenger: tx.downgrade(),
             scan_pool_handle: spawn_executors_to(init, pool.handle()),
-            scans: CallbackWaitGroup::new(),
+            scans: FutureWaitGroup::new(),
         };
         let events = Arc::new(Mutex::new(vec![]));
         let ob_events = Arc::clone(&events);
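
One behavioral detail of the rewritten `wait` above: it returns `true` when the timeout fires (`tokio::time::timeout` yields `Err` on expiry), not when the scans complete in time. Stripped of the surrounding manager, the logic is equivalent to this sketch, where `pending` stands in for `self.scans.wait()`:

```rust
use std::{future::Future, time::Duration};

// true == timed out; false == all initial scans finished in time.
async fn timed_out(pending: impl Future<Output = ()>, timeout: Duration) -> bool {
    tokio::time::timeout(timeout, pending).await.is_err()
}
```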
components/backup-stream/src/utils.rs (54 additions, 40 deletions)
@@ -5,20 +5,21 @@ use std::{
     borrow::Borrow,
     cell::RefCell,
     collections::{hash_map::RandomState, BTreeMap, HashMap},
+    future::Future,
     ops::{Bound, RangeBounds},
     path::Path,
     sync::{
         atomic::{AtomicUsize, Ordering},
         Arc,
     },
-    task::Context,
+    task::{Context, Waker},
     time::Duration,
 };

 use async_compression::{tokio::write::ZstdEncoder, Level};
 use engine_rocks::ReadPerfInstant;
 use engine_traits::{CfName, CF_DEFAULT, CF_LOCK, CF_RAFT, CF_WRITE};
-use futures::{ready, task::Poll, FutureExt};
+use futures::{ready, task::Poll};
 use kvproto::{
     brpb::CompressionType,
     metapb::Region,
@@ -37,13 +38,12 @@ use tikv_util::{
 use tokio::{
     fs::File,
     io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter},
-    sync::{oneshot, Mutex, RwLock},
+    sync::{Mutex, RwLock},
 };
 use txn_types::{Key, Lock, LockType};

 use crate::{
     errors::{Error, Result},
-    metadata::store::BoxFuture,
     router::TaskSelector,
     Task,
 };
@@ -379,47 +379,65 @@ pub fn should_track_lock(l: &Lock) -> bool {
     }
 }

-pub struct CallbackWaitGroup {
+pub struct FutureWaitGroup {
     running: AtomicUsize,
-    on_finish_all: std::sync::Mutex<Vec<Box<dyn FnOnce() + Send + 'static>>>,
+    wakers: std::sync::Mutex<Vec<Waker>>,
 }

-impl CallbackWaitGroup {
+pub struct Work(Arc<FutureWaitGroup>);
+
+impl Drop for Work {
+    fn drop(&mut self) {
+        self.0.work_done();
+    }
+}
+
+pub struct WaitAll<'a>(&'a FutureWaitGroup);
+
+impl<'a> Future for WaitAll<'a> {
+    type Output = ();
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        // Fast path: nothing to wait.
+        let running = self.0.running.load(Ordering::SeqCst);
+        if running == 0 {
+            return Poll::Ready(());
+        }
+
+        // <1>
+        let mut callbacks = self.0.wakers.lock().unwrap();
+        callbacks.push(cx.waker().clone());
+        let running = self.0.running.load(Ordering::SeqCst);
+        // Unlikely path: if all background tasks finish at <1>, there will be a long
+        // period that nobody will wake the `wakers` even the condition is ready.
+        // We need to help ourselves here.
+        if running == 0 {
+            callbacks.drain(..).for_each(|w| w.wake());
+        }
+        Poll::Pending
+    }
+}
+
+impl FutureWaitGroup {
     pub fn new() -> Arc<Self> {
         Arc::new(Self {
             running: AtomicUsize::new(0),
-            on_finish_all: std::sync::Mutex::default(),
+            wakers: Default::default(),
         })
     }

     fn work_done(&self) {
         let last = self.running.fetch_sub(1, Ordering::SeqCst);
         if last == 1 {
-            self.on_finish_all
-                .lock()
-                .unwrap()
-                .drain(..)
-                .for_each(|x| x())
+            self.wakers.lock().unwrap().drain(..).for_each(|x| {
+                x.wake();
+            })
         }
     }

     /// wait until all running tasks done.
-    pub fn wait(&self) -> BoxFuture<()> {
-        // Fast path: no uploading.
-        if self.running.load(Ordering::SeqCst) == 0 {
-            return Box::pin(futures::future::ready(()));
-        }
-
-        let (tx, rx) = oneshot::channel();
-        self.on_finish_all.lock().unwrap().push(Box::new(move || {
-            // The waiter may timed out.
-            let _ = tx.send(());
-        }));
-        // try to acquire the lock again.
-        if self.running.load(Ordering::SeqCst) == 0 {
-            return Box::pin(futures::future::ready(()));
-        }
-        Box::pin(rx.map(|_| ()))
+    pub fn wait(&self) -> WaitAll<'_> {
+        WaitAll(self)
     }

     /// make a work, as long as the return value held, mark a work in the group
@@ -430,14 +448,6 @@ impl CallbackWaitGroup {
     }
 }

-pub struct Work(Arc<CallbackWaitGroup>);
-
-impl Drop for Work {
-    fn drop(&mut self) {
-        self.0.work_done();
-    }
-}
-
 struct ReadThroughputRecorder {
     // The system tool set.
     ins: Option<OsInspector>,
@@ -813,7 +823,7 @@ mod test {
     use kvproto::metapb::{Region, RegionEpoch};
     use tokio::io::{AsyncWriteExt, BufReader};

-    use crate::utils::{is_in_range, CallbackWaitGroup, SegmentMap};
+    use crate::utils::{is_in_range, FutureWaitGroup, SegmentMap};

     #[test]
     fn test_redact() {
@@ -922,8 +932,8 @@
     }

     fn run_case(c: Case) {
+        let wg = FutureWaitGroup::new();
         for i in 0..c.repeat {
-            let wg = CallbackWaitGroup::new();
             let cnt = Arc::new(AtomicUsize::new(c.bg_task));
             for _ in 0..c.bg_task {
                 let cnt = cnt.clone();
@@ -934,7 +944,7 @@
             });
         }
         block_on(tokio::time::timeout(Duration::from_secs(20), wg.wait())).unwrap();
-        assert_eq!(cnt.load(Ordering::SeqCst), 0, "{:?}@{}", c, i);
+        assert_eq!(cnt.load(Ordering::SeqCst), 0, "{:?}@{}", c, i,);
     }
 }

@@ -951,6 +961,10 @@
                 bg_task: 512,
                 repeat: 1,
             },
+            Case {
+                bg_task: 16,
+                repeat: 10000,
+            },
             Case {
                 bg_task: 2,
                 repeat: 100000,
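
The subtle part of the new `WaitAll::poll` is the second `running` load taken after the waker is parked: a `work_done` that runs between the fast-path check and `wakers.lock()` would otherwise find an empty waker list and wake nobody, leaving the waiter pending forever. A hedged test sketch of the property this buys (the module, test name, and `yield_now` pacing are invented; only the public `new`/`work`/`wait` API is exercised):

```rust
#[cfg(test)]
mod wait_group_sketch {
    use std::time::Duration;

    use crate::utils::FutureWaitGroup;

    #[tokio::test]
    async fn wakes_waiter_when_last_work_drops() {
        let wg = FutureWaitGroup::new();
        let work = wg.clone().work();

        // Park a waiter on the group while one unit of work is outstanding.
        let waiter = tokio::spawn({
            let wg = wg.clone();
            async move { wg.wait().await }
        });
        tokio::task::yield_now().await; // give the waiter a chance to register its Waker

        drop(work); // the last guard is gone: work_done must wake the parked waiter
        tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .expect("waiter should be woken once the last Work drops")
            .unwrap();
    }
}
```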