Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement PAIR sockets #176

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions src/async_rt/block_on_read_till_set.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/// Implements a generic data structure that holds a value of type T and
/// can be instantiated without a value of type T.
/// If it is read from before it is set, it will block on read until it
/// is set.
use std::sync::Arc;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::{Mutex, MutexGuard};

pub struct BlockOnReadTillSet<T> {
    // Shared slot holding the value; `None` until `set` is called.
    value: Arc<Mutex<Option<T>>>,
    // Notification channel: `set` pushes a token, `get` waits on the
    // receiver and re-sends so other concurrent `get`s also wake up.
    sender: Arc<Mutex<Sender<()>>>,
    receiver: Arc<Mutex<Receiver<()>>>,
}

impl<T> Clone for BlockOnReadTillSet<T> {
fn clone(&self) -> Self {
Self {
value: self.value.clone(),
sender: self.sender.clone(),
receiver: self.receiver.clone(),
}
}
}

impl<T> BlockOnReadTillSet<T> {
    /// Creates an empty holder; readers will block in `get` until `set`
    /// is called.
    pub fn new() -> Self {
        let (sender, receiver) = channel(100);
        BlockOnReadTillSet {
            value: Arc::new(Mutex::new(None)),
            sender: Arc::new(Mutex::new(sender)),
            receiver: Arc::new(Mutex::new(receiver)),
        }
    }

    /// Stores `value` and sends a wake-up token for any task blocked in
    /// `get`.
    pub async fn set(&self, value: T) {
        // Hold the lock once for the whole update. The original code
        // released the lock and re-acquired it just to assert the value
        // was set, which was redundant and racy: another task could
        // call `unset` between the two lock acquisitions.
        *self.value.lock().await = Some(value);
        self.sender.lock().await.send(()).await.unwrap();
    }

    /// Returns true if a value is currently stored.
    pub async fn is_set(&self) -> bool {
        self.value.lock().await.is_some()
    }

    /// Clears the stored value; subsequent `get` calls block again.
    pub async fn unset(&self) {
        *self.value.lock().await = None;
    }

    /// Returns a guard over the stored value, blocking until it is set.
    ///
    /// # Panics
    /// Panics after 100 wake-ups without the value becoming set, as a
    /// guard against a livelock of waiters re-sending tokens to each
    /// other.
    pub async fn get(&self) -> MutexGuard<'_, Option<T>> {
        let mut attempts = 0;
        loop {
            {
                let value = self.value.lock().await;
                if value.is_some() {
                    return value;
                }
            }
            attempts += 1;
            if attempts > 100 {
                panic!("BlockOnReadTillSet::get() called more than 100 times without a value being set");
            }
            let mut receiver = self.receiver.lock().await;
            receiver.recv().await.unwrap();
            // We send a message to the receiver because we've
            // consumed the message that was sent when the value
            // was set. We need to send a new message so that
            // other get calls won't stay locked up if they happened
            // to be waiting for the value to be set at the same time
            // as we were.
            let sender = self.sender.lock().await;
            sender.send(()).await.unwrap();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::runtime::Runtime;

    /// Spawns several readers before any value exists, sets the value
    /// from a separate task after a delay, and checks that every reader
    /// unblocks and observes it.
    #[test]
    fn test_block_on_read_till_set() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let slot = BlockOnReadTillSet::new();

            // Readers start first; each blocks in `get` until `set` runs.
            let readers: Vec<_> = (0..10)
                .map(|_| {
                    let slot = slot.clone();
                    tokio::spawn(async move {
                        assert_eq!(slot.get().await.unwrap(), 1);
                    })
                })
                .collect();

            let mut writers = Vec::new();
            // Sleep a bit before setting the value to experience the blocking behavior
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            {
                let slot = slot.clone();
                writers.push(tokio::spawn(async move {
                    slot.set(1).await;
                }));
            }

            for writer in writers {
                writer.await.unwrap();
            }
            for reader in readers {
                reader.await.unwrap();
            }
        });
    }
}
1 change: 1 addition & 0 deletions src/async_rt/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! General purpose helpers for async runtime cross-compatibility

pub mod block_on_read_till_set;
pub mod task;

#[cfg(feature = "tokio-runtime")]
Expand Down
7 changes: 7 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod endpoint;
mod error;
mod fair_queue;
mod message;
mod pair;
mod r#pub;
mod pull;
mod push;
Expand All @@ -29,6 +30,7 @@ pub use crate::dealer::*;
pub use crate::endpoint::{Endpoint, Host, Transport, TryIntoEndpoint};
pub use crate::error::{ZmqError, ZmqResult};
pub use crate::message::*;
pub use crate::pair::*;
pub use crate::pull::*;
pub use crate::push::*;
pub use crate::r#pub::*;
Expand Down Expand Up @@ -227,6 +229,10 @@ pub trait Socket: Sized + Send {
/// Returns the endpoint resolved to the exact bound location if applicable
/// (port # resolved, for example).
async fn bind(&mut self, endpoint: &str) -> ZmqResult<Endpoint> {
self.bind_default(endpoint).await
}

async fn bind_default(&mut self, endpoint: &str) -> ZmqResult<Endpoint> {
let endpoint = endpoint.try_into()?;

let cloned_backend = self.backend();
Expand Down Expand Up @@ -309,6 +315,7 @@ pub trait Socket: Sized + Send {
},
Err(e) => Err(e),
};

match result {
Ok((endpoint, peer_id)) => {
if let Some(monitor) = self.backend().monitor().lock().as_mut() {
Expand Down
174 changes: 174 additions & 0 deletions src/pair.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/// Implements the ZMQ EXPAIR communication strategy.
///
/// See RFC 31 https://rfc.zeromq.org/spec/31/
///
/// The PAIR Socket Type
///
/// General behavior:
///
/// - MAY be connected to at most one PAIR peer, and MAY both send and receive messages.
/// - SHALL not filter or modify outgoing or incoming messages in any way.
/// - SHALL maintain a double queue for its peer, allowing outgoing and incoming messages to be queued independently.
/// - SHALL create a double queue when initiating an outgoing connection to a peer, and SHALL maintain the double queue whether or not the connection is established.
/// - SHALL create a double queue when a peer connects to it. If this peer disconnects, the PAIR socket SHALL destroy its double queue and SHALL discard any messages it contains.
/// - SHOULD constrain incoming and outgoing queue sizes to a runtime-configurable limit.
///
/// For processing outgoing messages:
///
/// - SHALL consider its peer as available only when it has an outgoing queue that is not full.
/// - SHALL block on sending, or return a suitable error, when it has no available peer.
/// - SHALL not accept further messages when it has no available peer.
/// - SHALL NOT discard messages that it cannot queue.
///
/// For processing incoming messages:
///
/// - SHALL receive incoming messages from its single peer if it has one.
/// - SHALL deliver these to its calling application.
///
/// Example usage can be found in ../tests/pair.rs
///
use crate::async_rt::block_on_read_till_set::BlockOnReadTillSet;
use crate::codec::*;
use crate::endpoint::Endpoint;
use crate::error::*;
use crate::transport::AcceptStopHandle;
use crate::util::{Peer, PeerIdentity};
use crate::*;
use crate::{SocketType, ZmqResult};

use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::broadcast;

struct PairSocketBackend {
    // The single connected peer. Reads via `get` block until a peer
    // connects; `unset` clears it on disconnect.
    pub(crate) peer: BlockOnReadTillSet<Peer>,
    // Optional monitor channel for emitting SocketEvents.
    socket_monitor: Mutex<Option<mpsc::Sender<SocketEvent>>>,
    socket_options: SocketOptions,
    // Broadcast used to signal shutdown/disconnect to background tasks
    // (e.g. the dropper task spawned in `PairSocket::bind`).
    drop_tx: broadcast::Sender<()>,
}

/// A ZMQ PAIR socket: exclusive one-to-one, bidirectional messaging
/// (see RFC 31 notes at the top of this file).
pub struct PairSocket {
    backend: Arc<PairSocketBackend>,
    // Active binds, keyed by endpoint, with their accept-loop stop handles.
    binds: HashMap<Endpoint, AcceptStopHandle>,
}


impl SocketBackend for PairSocketBackend {
    fn socket_type(&self) -> SocketType {
        SocketType::PAIR
    }

    fn socket_options(&self) -> &SocketOptions {
        &self.socket_options
    }

    /// Signals background tasks (e.g. the dropper task spawned in
    /// `PairSocket::bind`) that the socket is shutting down.
    fn shutdown(&self) {
        // `broadcast::Sender::send` returns Err when there are no
        // active receivers. That is a normal state for a socket that
        // never bound (no dropper task subscribed), so it must not
        // panic here — there is simply nobody to notify.
        let _ = self.drop_tx.send(());
    }

    fn monitor(&self) -> &Mutex<Option<mpsc::Sender<SocketEvent>>> {
        &self.socket_monitor
    }
}

#[async_trait]
impl SocketSend for PairSocket {
    /// Queues `msg` to the connected peer, blocking until a peer is
    /// available (`BlockOnReadTillSet::get` blocks until the peer slot
    /// is set).
    async fn send(&mut self, msg: ZmqMessage) -> ZmqResult<()> {
        match *self.backend.peer.get().await {
            Some(ref mut peer) => peer
                .send_queue
                .send(Message::Message(msg))
                .await
                .map_err(ZmqError::Codec),
            // `get` only returns once the slot is `Some`, so this arm
            // cannot be reached.
            None => unreachable!(),
        }
    }
}

#[async_trait]
impl SocketRecv for PairSocket {
    /// Receives the next message from the connected peer, blocking
    /// until a peer is available.
    async fn recv(&mut self) -> ZmqResult<ZmqMessage> {
        let message = match *self.backend.peer.get().await {
            Some(ref mut peer) => peer.recv_queue.next().await,
            // `get` only returns once the slot is `Some`, so this arm
            // cannot be reached.
            None => unreachable!(),
        };
        match message {
            Some(Ok(Message::Greeting(_))) => todo!(),
            Some(Ok(Message::Command(_))) => todo!(),
            Some(Ok(Message::Message(m))) => Ok(m),
            Some(Err(e)) => Err(ZmqError::Codec(e)),
            None => Err(ZmqError::NoMessage),
        }
    }
}

#[async_trait]
impl Socket for PairSocket {
    fn with_options(options: SocketOptions) -> Self {
        let (drop_tx, _) = broadcast::channel(100);
        Self {
            backend: Arc::new(PairSocketBackend {
                peer: BlockOnReadTillSet::new(),
                socket_monitor: Mutex::new(None),
                socket_options: options,
                drop_tx,
            }),
            binds: HashMap::new(),
        }
    }

    /// Bind to an endpoint & launch dropper task.
    /// Calls the default bind internally.
    async fn bind(&mut self, endpoint: &str) -> ZmqResult<Endpoint> {
        // Subscribe before binding so a shutdown signal can't be missed.
        let mut drop_rx = self.backend.drop_tx.subscribe();
        let peer = self.backend.peer.clone();
        tokio::spawn(async move {
            let _ = drop_rx.recv().await;
            // A peer disconnect/shutdown is an expected event: clear the
            // slot so subsequent send/recv calls block for a new peer.
            // The original panicked here, which aborted the dropper task
            // on every normal disconnect.
            peer.unset().await;
        });
        self.bind_default(endpoint).await
    }

    fn backend(&self) -> Arc<dyn MultiPeerBackend> {
        self.backend.clone()
    }

    fn binds(&mut self) -> &mut HashMap<Endpoint, AcceptStopHandle> {
        &mut self.binds
    }

    fn monitor(&mut self) -> mpsc::Receiver<SocketEvent> {
        let (sender, receiver) = mpsc::channel(1024);
        self.backend.socket_monitor.lock().replace(sender);
        receiver
    }
}

#[async_trait]
impl MultiPeerBackend for PairSocketBackend {
    /// Registers a newly connected peer. PAIR allows exactly one peer;
    /// a second connection is not yet handled (WIP).
    async fn peer_connected(self: Arc<Self>, peer_id: &PeerIdentity, io: FramedIo) {
        if self.peer.is_set().await {
            todo!("Refuse connection if already connected.");
        }
        let (recv_queue, send_queue) = io.into_parts();
        // Field-init shorthand (was `send_queue: send_queue`, etc.).
        let peer = Peer {
            _identity: peer_id.clone(),
            send_queue,
            recv_queue,
        };
        self.peer.set(peer).await;
    }

    /// On disconnect, broadcast shutdown so the dropper task clears the
    /// peer slot and discards its queues (per RFC 31).
    fn peer_disconnected(&self, _peer_id: &PeerIdentity) {
        self.shutdown();
    }
}
1 change: 1 addition & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ pub(crate) async fn peer_connected(

pub(crate) async fn connect_forever(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> {
let mut try_num: u64 = 0;

loop {
match transport::connect(&endpoint).await {
Ok(res) => return Ok(res),
Expand Down
Loading