Skip to content

Commit

Permalink
Optimize the codes
Browse files Browse the repository at this point in the history
  • Loading branch information
rmqtt committed Mar 10, 2024
1 parent 59e2bb9 commit 2a9bb6a
Showing 1 changed file with 194 additions and 23 deletions.
217 changes: 194 additions & 23 deletions rmqtt-plugins/rmqtt-retainer/src/storage.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::borrow::Cow;
use std::convert::From as _;
use std::future::Future;
use std::ops::Deref;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicIsize, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, AtomicIsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, Instant};

use futures_time::future::FutureExt;

Expand Down Expand Up @@ -71,16 +72,16 @@ impl Retainer {
retain_enable: Arc<AtomicBool>,
) -> Result<Retainer> {
let (msg_tx, msg_queue_count) = Self::serve(cfg.clone())?;
let retain_count = Arc::new(AtomicUsize::new(storage_db.len().await?));
let retain_count_utime = Arc::new(AtomicI64::new(timestamp_millis()));
let storage_messages_count = ValueCached::new(Duration::from_millis(3000));
let storage_messages_max = ValueCached::new(Duration::from_millis(3000));
let inner = Arc::new(RetainerInner {
cfg,
storage_db,
msg_tx,
msg_queue_count,
retain_enable,
retain_count,
retain_count_utime,
storage_messages_count,
storage_messages_max,
});
Ok(Self { inner })
}
Expand Down Expand Up @@ -151,8 +152,10 @@ pub struct RetainerInner {
msg_tx: mpsc::Sender<Msg>,
pub(crate) msg_queue_count: Arc<AtomicIsize>,
retain_enable: Arc<AtomicBool>,
retain_count: Arc<AtomicUsize>,
retain_count_utime: Arc<AtomicI64>,
// retain_count: Arc<AtomicUsize>,
// retain_count_utime: Arc<AtomicI64>,
storage_messages_count: ValueCached<usize>,
storage_messages_max: ValueCached<isize>,
}

impl RetainerInner {
Expand Down Expand Up @@ -260,8 +263,7 @@ impl RetainerInner {
return Ok(false);
}

//@TODO ...
if max_retained_messages > 0 && self.get_retain_count().await? >= max_retained_messages {
if max_retained_messages > 0 && self.get_retain_count().await >= max_retained_messages {
log::warn!(
"The retained message has exceeded the maximum limit of: {}, topic: {:?}, retain: {:?}",
max_retained_messages,
Expand All @@ -275,17 +277,20 @@ impl RetainerInner {
}

#[inline]
async fn get_retain_count(&self) -> Result<usize> {
let retain_count = if (timestamp_millis() - self.retain_count_utime.load(Ordering::SeqCst)) < 3000 {
self.retain_count.load(Ordering::SeqCst)
async fn get_retain_count(&self) -> usize {
let db = self.storage_db.clone();
let count = self
.storage_messages_count
.call_timeout(async move { db.len().await.map_err(MqttError::from) }, Duration::from_millis(3000))
.await
.get()
.map(|v| *v)
.unwrap_or_default();
if count > 0 {
count - 1
} else {
let retain_count = self.storage_db.len().await?;
self.retain_count.store(retain_count, Ordering::SeqCst);
self.retain_count_utime.store(timestamp_millis(), Ordering::SeqCst);
retain_count
};
let retain_count = if retain_count > 0 { retain_count - 1 } else { retain_count };
Ok(retain_count)
count
}
}

#[inline]
Expand Down Expand Up @@ -320,7 +325,14 @@ impl RetainerInner {
let topic_filter_pattern = Self::topic_filter_to_pattern(topic_filter);
let mut matched_topics = Vec::new();
let mut db = self.storage_db.clone();
let mut iter = db.scan([RETAIN_MESSAGES_PREFIX, topic_filter_pattern.as_bytes()].concat()).await?;
let mut iter = match db.scan([RETAIN_MESSAGES_PREFIX, topic_filter_pattern.as_bytes()].concat()).await
{
Err(e) => {
log::error!("{:?}", e);
return Ok(Vec::new());
}
Ok(iter) => iter,
};
while let Some(key) = iter.next().await {
match key {
Ok(key) => {
Expand Down Expand Up @@ -406,11 +418,170 @@ impl RetainStorage for &'static Retainer {

#[inline]
async fn count(&self) -> isize {
self.get_retain_count().await.unwrap_or_default() as isize
self.get_retain_count().await as isize
}

#[inline]
async fn max(&self) -> isize {
self.storage_messages_max_get().await.unwrap_or_default()
self.storage_messages_max
.call_timeout(self.storage_messages_max_get(), Duration::from_millis(3000))
.await
.get()
.map(|v| *v)
.unwrap_or(-1)
}
}

#[derive(Clone)]
pub struct ValueCached<T> {
inner: Arc<RwLock<ValueCachedInner<T>>>,
guard: Arc<RwLock<()>>,
}

pub struct ValueCachedInner<T> {
cached_val: Option<Result<T>>,
expire_interval: Duration,
instant: Instant,
}

impl<T> ValueCached<T> {
#[inline]
pub fn new(expire_interval: Duration) -> Self {
Self {
inner: Arc::new(RwLock::new(ValueCachedInner {
cached_val: None,
expire_interval,
instant: Instant::now(),
})),
guard: Arc::new(RwLock::new(())),
}
}

#[inline]
#[allow(unused)]
pub async fn call<F>(&self, f: F) -> ValueRef<'_, T>
where
F: Future<Output = Result<T>> + Send + 'static,
{
self._call_timeout(f, None).await
}

#[inline]
pub async fn call_timeout<F>(&self, f: F, timeout: Duration) -> ValueRef<'_, T>
where
F: Future<Output = Result<T>> + Send + 'static,
{
self._call_timeout(f, Some(timeout)).await
}

#[inline]
async fn _call_timeout<F>(&self, f: F, timeout: Option<Duration>) -> ValueRef<'_, T>
where
F: Future<Output = Result<T>> + Send + 'static,
{
let inst = std::time::Instant::now();
let (call_enable, updating) = {
let mut enable = false;
#[allow(unused_assignments)]
let mut updating = false;
loop {
if let Ok(_guard) = self.guard.try_read() {
let inner_rl = self.inner.read().await;
enable = inner_rl.cached_val.is_none() || inner_rl.is_expired();
updating = false;
break;
}
updating = true;

if self.inner.read().await.cached_val.is_some() {
break;
}

tokio::time::sleep(Duration::from_millis(10)).await;
if let Some(t) = timeout.as_ref() {
if inst.elapsed() > *t {
break;
}
}
}
(enable, updating)
};

let (cached, updating) = if call_enable {
if let Ok(_guard) = self.guard.try_write() {
let val = if let Some(t) = timeout {
match tokio::time::timeout(t, f).await {
Ok(Ok(v)) => Ok(v),
Ok(Err(e)) => Err(e),
Err(e) => Err(MqttError::from(anyhow!(e))),
}
} else {
f.await
};
let mut inner_wl = self.inner.write().await;
inner_wl.cached_val = Some(val);
inner_wl.instant = Instant::now();
(false, false)
} else {
#[allow(unused_assignments)]
let mut updating = false;
loop {
if let Ok(_guard) = self.guard.try_read() {
updating = false;
break;
}
updating = true;
if self.inner.read().await.cached_val.is_some() {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
if let Some(t) = timeout.as_ref() {
if inst.elapsed() > *t {
break;
}
}
}
(true, updating)
}
} else {
(true, updating)
};
ValueRef { val_guard: self.inner.read().await, cached, updating }
}
}

impl<T> ValueCachedInner<T> {
#[inline]
fn is_expired(&self) -> bool {
self.instant.elapsed() > self.expire_interval
}
}

pub struct ValueRef<'a, T> {
val_guard: tokio::sync::RwLockReadGuard<'a, ValueCachedInner<T>>,
cached: bool,
updating: bool,
}

impl<'a, T> ValueRef<'a, T> {
#[inline]
pub fn get(&self) -> Result<&T> {
if let Some(val) = self.val_guard.cached_val.as_ref() {
Ok(val.as_ref().map_err(|e| anyhow!(e.to_string()))?)
} else {
Err(MqttError::from("Timeout"))
}
}

#[inline]
#[allow(unused)]
pub fn is_cached(&self) -> bool {
self.cached
}

#[inline]
#[allow(unused)]
pub fn is_updating(&self) -> bool {
self.updating
}
}

0 comments on commit 2a9bb6a

Please sign in to comment.