Skip to content

Commit

Permalink
Optimize MQTT connection handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
rmqtt committed Oct 12, 2023
1 parent f3defc4 commit 8a580bc
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 151 deletions.
168 changes: 27 additions & 141 deletions rmqtt/src/broker/executor.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,58 @@
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::atomic::{AtomicIsize, Ordering};
use std::task::{Context, Poll};
use std::thread::ThreadId;
use std::time::Duration;

use crossbeam::queue::SegQueue;
use futures::task::AtomicWaker;
use once_cell::sync::OnceCell;
//use parking_lot::RwLock;
use tokio::sync::RwLock;
use update_rate::{DiscreteRateCounter, RateCounter};
use rust_box::task_exec_queue::{LocalBuilder, LocalTaskExecQueue};
use tokio::task::spawn_local;

use crate::broker::types::*;
use crate::settings::listener::Listener;
use crate::{MqttError, Result, Runtime};

pub type Port = u16;

std::thread_local! {
pub static LOCAL_HANDSHAKE_EXECUTORS: DashMap<Port, Executor> = DashMap::default();
pub static HANDSHAKE_EXECUTORS: DashMap<Port, LocalTaskExecQueue> = DashMap::default();
}

#[inline]
pub(crate) async fn get_handshake_exec(name: Port, listen_cfg: Listener) -> Executor {
let exec = LOCAL_HANDSHAKE_EXECUTORS.with(|m| {
pub(crate) fn get_handshake_exec(name: Port, listen_cfg: Listener) -> LocalTaskExecQueue {
HANDSHAKE_EXECUTORS.with(|m| {
m.entry(name)
.or_insert_with(|| {
Executor::new(
(listen_cfg.max_handshaking_limit / listen_cfg.workers) as isize,
listen_cfg.handshake_timeout,
)
let (exec, task_runner) = LocalBuilder::default()
.workers(listen_cfg.max_handshaking_limit / listen_cfg.workers)
.queue_max(listen_cfg.max_connections / listen_cfg.workers)
.build();

let exec1 = exec.clone();
spawn_local(async move {
futures::future::join(task_runner, async move {
loop {
set_active_count(name, exec1.active_count());
set_rate(name, exec1.rate());
tokio::time::sleep(Duration::from_secs(5)).await;
}
})
.await;
});

exec
})
.value()
.clone()
});

set_active_count(name, exec.active_count());
set_rate(name, exec.rate().await);
exec
})
}

static ACTIVE_COUNTS: OnceCell<DashMap<(Port, ThreadId), (isize, Timestamp)>> = OnceCell::new();
static ACTIVE_COUNTS: OnceCell<DashMap<(Port, ThreadId), isize>> = OnceCell::new();

fn set_active_count(name: Port, c: isize) {
let active_counts = ACTIVE_COUNTS.get_or_init(DashMap::default);
let mut entry = active_counts.entry((name, std::thread::current().id())).or_default();
let (count, t) = entry.value_mut();
*count = c;
*t = chrono::Local::now().timestamp();
*entry.value_mut() = c;
}

pub fn get_active_count() -> isize {
ACTIVE_COUNTS
.get()
.map(|m| {
m.iter()
.filter_map(|entry| {
let (c, t) = entry.value();
if *t + 5 > chrono::Local::now().timestamp() {
Some(*c)
} else {
None
}
})
.sum()
})
.unwrap_or_default()
ACTIVE_COUNTS.get().map(|m| m.iter().map(|item| *item.value()).sum()).unwrap_or_default()
}

static RATES: OnceCell<DashMap<(Port, ThreadId), f64>> = OnceCell::new();
Expand All @@ -81,102 +66,3 @@ fn set_rate(name: Port, rate: f64) {
pub fn get_rate() -> f64 {
RATES.get().map(|m| m.iter().map(|entry| *entry.value()).sum::<f64>()).unwrap_or_default()
}

#[derive(Clone)]
pub struct Executor {
inner: Rc<ExecutorInner>,
}

struct ExecutorInner {
max_handshake_limit: isize,
handshake_timeout: Duration,
pending_wakers: SegQueue<Rc<AtomicWaker>>,
active_count: AtomicIsize,
rate_counter: RwLock<DiscreteRateCounter>,
}

impl Executor {
pub fn new(max_handshake_limit: isize, handshake_timeout: Duration) -> Self {
Self {
inner: Rc::new(ExecutorInner {
max_handshake_limit,
handshake_timeout,
pending_wakers: SegQueue::new(),
active_count: AtomicIsize::new(0),
rate_counter: RwLock::new(DiscreteRateCounter::new(100)),
}),
}
}

#[inline]
pub async fn spawn<T>(self, future: T) -> Result<T::Output>
where
T: Future + 'static,
T::Output: 'static,
{
if self.inner.active_count.load(Ordering::SeqCst) >= self.inner.max_handshake_limit {
let now = std::time::Instant::now();
let w = Rc::new(AtomicWaker::new());
self.inner.pending_wakers.push(w.clone());
let delay = tokio::time::sleep(self.inner.handshake_timeout);
tokio::pin!(delay);
tokio::select! {
_ = &mut delay => {
log::debug!("is timeout ... {:?} {:?}", now.elapsed(), self.inner.handshake_timeout);
Runtime::instance().metrics.client_handshaking_timeout_inc();
return Err(MqttError::from(format!(
"handshake timeout, acquire cost time: {:?}",
now.elapsed()
)));
},
_ = PendingOnce::new(w) => {
log::debug!("is waked ... {:?}", now.elapsed());
}
}
}

self.inner.active_count.fetch_add(1, Ordering::SeqCst);
let output = future.await;
self.inner.active_count.fetch_sub(1, Ordering::SeqCst);
if let Some(w) = self.inner.pending_wakers.pop() {
w.wake();
}
self.inner.rate_counter.write().await.update();
Ok(output)
}

#[inline]
pub fn active_count(&self) -> isize {
self.inner.active_count.load(Ordering::SeqCst)
}

#[inline]
pub async fn rate(&self) -> f64 {
self.inner.rate_counter.read().await.rate()
}
}

struct PendingOnce {
w: Rc<AtomicWaker>,
is_ready: bool,
}

impl PendingOnce {
#[inline]
fn new(w: Rc<AtomicWaker>) -> Self {
Self { w, is_ready: false }
}
}

impl Future for PendingOnce {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.is_ready {
Poll::Ready(())
} else {
self.w.register(cx.waker());
self.is_ready = true;
Poll::Pending
}
}
}
13 changes: 8 additions & 5 deletions rmqtt/src/broker/v3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::convert::From as _f;
use std::net::SocketAddr;

use ntex_mqtt::v3::{self};
use rust_box::task_exec_queue::LocalSpawnExt;
use uuid::Uuid;

use crate::broker::executor::get_handshake_exec;
Expand Down Expand Up @@ -78,16 +79,18 @@ pub async fn handshake<Io: 'static>(

Runtime::instance().stats.handshakings.max_max(handshake.handshakings());

let exec = get_handshake_exec(local_addr.port(), listen_cfg.clone()).await;
match exec.spawn(_handshake(id.clone(), listen_cfg, handshake)).await {
let exec = get_handshake_exec(local_addr.port(), listen_cfg.clone());
match _handshake(id.clone(), listen_cfg, handshake).spawn(&exec).result().await {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => {
log::warn!("{:?} Connection Refused, handshake error, reason: {}", id, e);
log::warn!("{:?} Connection Refused, handshake error, reason: {:?}", id, e.to_string());
Err(e)
}
Err(e) => {
log::warn!("{:?} Connection Refused, handshake timeout, reason: {}", id, e);
Err(MqttError::from("Connection Refused, execute handshake timeout"))
Runtime::instance().metrics.client_handshaking_timeout_inc();
let err = MqttError::from("Connection Refused, execute handshake timeout");
log::warn!("{:?} {:?}, reason: {:?}", id, err, e.to_string());
Err(err)
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions rmqtt/src/broker/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::net::SocketAddr;

use ntex_mqtt::v5;
use ntex_mqtt::v5::codec::{Auth, DisconnectReasonCode};
use rust_box::task_exec_queue::LocalSpawnExt;
use uuid::Uuid;

use crate::broker::executor::get_handshake_exec;
Expand Down Expand Up @@ -67,16 +68,18 @@ pub async fn handshake<Io: 'static>(

Runtime::instance().stats.handshakings.max_max(handshake.handshakings());

let exec = get_handshake_exec(local_addr.port(), listen_cfg.clone()).await;
match exec.spawn(_handshake(id.clone(), listen_cfg, handshake, assigned_client_id)).await {
let exec = get_handshake_exec(local_addr.port(), listen_cfg.clone());
match _handshake(id.clone(), listen_cfg, handshake, assigned_client_id).spawn(&exec).result().await {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => {
log::warn!("{:?} Connection Refused, handshake error, reason: {}", id, e);
log::warn!("{:?} Connection Refused, handshake error, reason: {:?}", id, e.to_string());
Err(e)
}
Err(e) => {
log::warn!("{:?} Connection Refused, handshake timeout, reason: {}", id, e);
Err(MqttError::from("Connection Refused, execute handshake timeout"))
Runtime::instance().metrics.client_handshaking_timeout_inc();
let err = MqttError::from("Connection Refused, execute handshake timeout");
log::warn!("{:?} {:?}, reason: {:?}", id, err, e.to_string());
Err(err)
}
}
}
Expand Down

0 comments on commit 8a580bc

Please sign in to comment.