Skip to content

Commit

Permalink
Optimize the MQTT client connection handshake process
Browse files Browse the repository at this point in the history
  • Loading branch information
rmqtt committed Jul 15, 2023
1 parent 07eeaae commit a629b9a
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 51 deletions.
132 changes: 117 additions & 15 deletions rmqtt/src/broker/executor.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::atomic::{AtomicIsize, Ordering};
use std::task::{Context, Poll};
use std::thread::ThreadId;
use std::time::Duration;

use crossbeam::queue::SegQueue;
use futures::task::AtomicWaker;
use once_cell::sync::OnceCell;
use rust_box::task_exec_queue::{LocalBuilder, LocalTaskExecQueue};
use tokio::task::spawn_local;
use parking_lot::RwLock;
use update_rate::{DiscreteRateCounter, RateCounter};

use crate::broker::types::*;
use crate::settings::listener::Listener;
use crate::{MqttError, Result, Runtime};

pub type Port = u16;

std::thread_local! {
pub static HANDSHAKE_EXECUTORS: DashMap<Port, LocalTaskExecQueue> = DashMap::default();
pub static LOCAL_HANDSHAKE_EXECUTORS: DashMap<Port, Executor> = DashMap::default();
}

#[inline]
pub(crate) fn get_handshake_exec_queue(name: Port, listen_cfg: Listener) -> LocalTaskExecQueue {
let exec = HANDSHAKE_EXECUTORS.with(|m| {
pub(crate) fn get_handshake_exec(name: Port, listen_cfg: Listener) -> Executor {
let exec = LOCAL_HANDSHAKE_EXECUTORS.with(|m| {
m.entry(name)
.or_insert_with(|| {
let (exec, task_runner) = LocalBuilder::default()
.workers(listen_cfg.max_handshaking_limit / listen_cfg.workers)
.queue_max(listen_cfg.max_connections / listen_cfg.workers)
.build();

spawn_local(async move {
task_runner.await;
});

exec
Executor::new(
(listen_cfg.max_handshaking_limit / listen_cfg.workers) as isize,
listen_cfg.handshake_timeout,
)
})
.value()
.clone()
Expand Down Expand Up @@ -77,3 +80,102 @@ fn set_rate(name: Port, rate: f64) {
pub fn get_rate() -> f64 {
RATES.get().map(|m| m.iter().map(|entry| *entry.value()).sum::<f64>()).unwrap_or_default()
}

#[derive(Clone)]
pub struct Executor {
inner: Rc<ExecutorInner>,
}

struct ExecutorInner {
max_handshake_limit: isize,
handshake_timeout: Duration,
pending_wakers: SegQueue<Rc<AtomicWaker>>,
active_count: AtomicIsize,
rate_counter: RwLock<DiscreteRateCounter>,
}

impl Executor {
pub fn new(max_handshake_limit: isize, handshake_timeout: Duration) -> Self {
Self {
inner: Rc::new(ExecutorInner {
max_handshake_limit,
handshake_timeout,
pending_wakers: SegQueue::new(),
active_count: AtomicIsize::new(0),
rate_counter: RwLock::new(DiscreteRateCounter::new(100)),
}),
}
}

#[inline]
pub async fn spawn<T>(self, future: T) -> Result<T::Output>
where
T: Future + 'static,
T::Output: 'static,
{
if self.inner.active_count.load(Ordering::SeqCst) >= self.inner.max_handshake_limit {
let now = std::time::Instant::now();
let w = Rc::new(AtomicWaker::new());
self.inner.pending_wakers.push(w.clone());
let delay = tokio::time::sleep(self.inner.handshake_timeout);
tokio::pin!(delay);
tokio::select! {
_ = &mut delay => {
log::debug!("is timeout ... {:?} {:?}", now.elapsed(), self.inner.handshake_timeout);
Runtime::instance().metrics.client_handshaking_timeout_inc();
return Err(MqttError::from(format!(
"handshake timeout, acquire cost time: {:?}",
now.elapsed()
)));
},
_ = PendingOnce::new(w) => {
log::debug!("is waked ... {:?}", now.elapsed());
}
}
}

self.inner.active_count.fetch_add(1, Ordering::SeqCst);
let output = future.await;
self.inner.active_count.fetch_sub(1, Ordering::SeqCst);
if let Some(w) = self.inner.pending_wakers.pop() {
w.wake();
}
self.inner.rate_counter.write().update();
Ok(output)
}

#[inline]
pub fn active_count(&self) -> isize {
self.inner.active_count.load(Ordering::SeqCst)
}

#[inline]
pub fn rate(&self) -> f64 {
self.inner.rate_counter.read().rate()
}
}

struct PendingOnce {
w: Rc<AtomicWaker>,
is_ready: bool,
}

impl PendingOnce {
#[inline]
fn new(w: Rc<AtomicWaker>) -> Self {
Self { w, is_ready: false }
}
}

impl Future for PendingOnce {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.is_ready {
Poll::Ready(())
} else {
self.w.register(cx.waker());
self.is_ready = true;
Poll::Pending
}
}
}
24 changes: 6 additions & 18 deletions rmqtt/src/broker/v3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ use std::convert::From as _f;
use std::net::SocketAddr;

use ntex_mqtt::v3::{self};
use rust_box::task_exec_queue::LocalSpawnExt;

use crate::broker::executor::get_handshake_exec_queue;
use crate::broker::executor::get_handshake_exec;
use crate::broker::{inflight::MomentStatus, types::*};
use crate::runtime::Runtime;
use crate::settings::listener::Listener;
Expand Down Expand Up @@ -55,29 +54,18 @@ pub async fn handshake<Io: 'static>(
handshake.packet().client_id.clone(),
handshake.packet().username.clone(),
);
let id1 = id.clone();
Runtime::instance().stats.handshakings.max_max(handshake.handshakings());

let exec = get_handshake_exec_queue(local_addr.port(), listen_cfg.clone());

let start = chrono::Local::now().timestamp_millis();
let handshake_fut = async move {
if (chrono::Local::now().timestamp_millis() - start) > listen_cfg.handshake_timeout() as i64 {
Runtime::instance().metrics.client_handshaking_timeout_inc();
return Err(MqttError::from("execute handshake timeout"));
}
_handshake(id, listen_cfg, handshake).await
};
Runtime::instance().stats.handshakings.max_max(handshake.handshakings());

match handshake_fut.spawn(&exec).result().await {
let exec = get_handshake_exec(local_addr.port(), listen_cfg.clone());
match exec.spawn(_handshake(id.clone(), listen_cfg, handshake)).await {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => {
log::warn!("{:?} Connection Refused, handshake error, reason: {:?}", id1, e.to_string());
log::warn!("{:?} Connection Refused, handshake error, reason: {:?}", id, e);
Err(e)
}
Err(e) => {
Runtime::instance().metrics.client_handshaking_timeout_inc();
log::warn!("{:?} Connection Refused, handshake timeout, reason: {:?}", id1, e.to_string());
log::warn!("{:?} Connection Refused, handshake timeout, reason: {:?}", id, e);
Err(MqttError::from("Connection Refused, execute handshake timeout"))
}
}
Expand Down
24 changes: 6 additions & 18 deletions rmqtt/src/broker/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use std::net::SocketAddr;

use ntex_mqtt::v5;
use ntex_mqtt::v5::codec::{Auth, DisconnectReasonCode};
use rust_box::task_exec_queue::LocalSpawnExt;

use crate::broker::executor::get_handshake_exec_queue;
use crate::broker::executor::get_handshake_exec;
use crate::broker::{inflight::MomentStatus, types::*};
use crate::settings::listener::Listener;
use crate::{ClientInfo, MqttError, Result, Runtime, Session, SessionState};
Expand Down Expand Up @@ -55,29 +54,18 @@ pub async fn handshake<Io: 'static>(
handshake.packet().client_id.clone(),
handshake.packet().username.clone(),
);
let id1 = id.clone();
Runtime::instance().stats.handshakings.max_max(handshake.handshakings());

let exec = get_handshake_exec_queue(local_addr.port(), listen_cfg.clone());

let start = chrono::Local::now().timestamp_millis();
let handshake_fut = async move {
if (chrono::Local::now().timestamp_millis() - start) > listen_cfg.handshake_timeout() as i64 {
Runtime::instance().metrics.client_handshaking_timeout_inc();
return Err(MqttError::from("execute handshake timeout"));
}
_handshake(id, listen_cfg, handshake).await
};
Runtime::instance().stats.handshakings.max_max(handshake.handshakings());

match handshake_fut.spawn(&exec).result().await {
let exec = get_handshake_exec(local_addr.port(), listen_cfg.clone());
match exec.spawn(_handshake(id.clone(), listen_cfg, handshake)).await {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => {
log::warn!("{:?} Connection Refused, handshake error, reason: {:?}", id1, e.to_string());
log::warn!("{:?} Connection Refused, handshake error, reason: {:?}", id, e);
Err(e)
}
Err(e) => {
Runtime::instance().metrics.client_handshaking_timeout_inc();
log::warn!("{:?} Connection Refused, handshake timeout, reason: {:?}", id1, e.to_string());
log::warn!("{:?} Connection Refused, handshake timeout, reason: {:?}", id, e);
Err(MqttError::from("Connection Refused, execute handshake timeout"))
}
}
Expand Down

0 comments on commit a629b9a

Please sign in to comment.