Skip to content

Commit

Permalink
futures: fully port to futures-locks, drop lock_transport macro
Browse files Browse the repository at this point in the history
Signed-off-by: Marc-Antoine Perennou <Marc-Antoine@Perennou.com>
  • Loading branch information
Keruspe committed Jun 21, 2018
1 parent b924110 commit d77ca8c
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 88 deletions.
103 changes: 47 additions & 56 deletions futures/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use message::BasicGetMessage;
use types::FieldTable;
use consumer::Consumer;
use queue::Queue;
use util::MutexExt;

/// `Channel` provides methods to act on a channel, such as managing queues
//#[derive(Clone)]
Expand Down Expand Up @@ -149,35 +150,32 @@ impl<T: AsyncRead+AsyncWrite+Send+Sync+'static> Channel<T> {
pub fn create(transport: Mutex<AMQPTransport<T>>) -> impl Future<Item = Self, Error = io::Error> + Send + 'static {
let channel_transport = transport.clone();

future::poll_fn(move || {
let mut transport = lock_transport!(channel_transport);
future::lazy(move || channel_transport.clone().run(move |mut transport| {
if let Some(id) = transport.conn.create_channel() {
return Ok(Async::Ready(Channel {
Ok(Channel {
id,
transport: channel_transport.clone(),
}))
transport: channel_transport,
})
} else {
return Err(io::Error::new(io::ErrorKind::ConnectionAborted, "The maximum number of channels for this connection has been reached"));
Err(io::Error::new(io::ErrorKind::ConnectionAborted, "The maximum number of channels for this connection has been reached"))
}
}).and_then(|channel| {
})).and_then(|channel| {
let channel_id = channel.id;
channel.run_on_locked_transport("create", "Could not create channel", move |transport| {
transport.conn.channel_open(channel_id, "".to_string()).map(Some)
}).and_then(move |_| {
future::poll_fn(move || {
let transport = lock_transport!(transport);

}).and_then(move |_| future::loop_fn((), move |()| {
transport.run(move |transport| {
match transport.conn.get_state(channel_id) {
Some(ChannelState::Connected) => return Ok(Async::Ready(())),
Some(ChannelState::Error) => return Err(io::Error::new(io::ErrorKind::Other, format!("Failed to open channel"))),
Some(ChannelState::Closed) => return Err(io::Error::new(io::ErrorKind::Other, format!("Failed to open channel"))),
Some(ChannelState::Connected) => Ok(future::Loop::Break(())),
Some(ChannelState::Error) => Err(io::Error::new(io::ErrorKind::Other, format!("Failed to open channel"))),
Some(ChannelState::Closed) => Err(io::Error::new(io::ErrorKind::Other, format!("Failed to open channel"))),
_ => {
task::current().notify();
return Ok(Async::NotReady);
Ok(future::Loop::Continue(()))
}
}
})
}).map(move |_| {
})).map(move |_| {
channel
})
})
Expand Down Expand Up @@ -267,22 +265,21 @@ impl<T: AsyncRead+AsyncWrite+Send+Sync+'static> Channel<T> {
self.run_on_locked_transport("queue_declare", "Could not declare queue", move |transport| {
transport.conn.queue_declare(channel_id, options.ticket, name,
options.passive, options.durable, options.exclusive, options.auto_delete, options.nowait, arguments).map(Some)
}).and_then(move |request_id| {
future::poll_fn(move || {
let mut transport = lock_transport!(transport);
}).and_then(move |request_id| future::loop_fn(request_id, move |request_id| {
transport.run(move |mut transport| {
if let Some(queue) = transport.conn.get_generated_name(request_id.expect("expected request_id")) {
let (consumer_count, message_count) = if let Some(async_queue) = transport.conn.channels.get(&channel_id).and_then(|channel| channel.queues.get(&queue)) {
(async_queue.consumer_count, async_queue.message_count)
} else {
(0, 0)
};
return Ok(Async::Ready(Queue::new(queue, consumer_count, message_count)))
return Ok(future::Loop::Break(Queue::new(queue, consumer_count, message_count)))
} else {
task::current().notify();
return Ok(Async::NotReady)
return Ok(future::Loop::Continue(request_id))
}
})
})
}))
}

/// binds a queue to an exchange
Expand Down Expand Up @@ -381,17 +378,16 @@ impl<T: AsyncRead+AsyncWrite+Send+Sync+'static> Channel<T> {
self.run_on_locked_transport("basic_consume", "Could not start consumer", move |transport| {
transport.conn.basic_consume(channel_id, options.ticket, queue_name, consumer_tag,
options.no_local, options.no_ack, options.exclusive, options.no_wait, arguments, Box::new(subscriber)).map(Some)
}).and_then(move |request_id| {
future::poll_fn(move || {
let mut transport = lock_transport!(transport);
}).and_then(move |request_id| future::loop_fn(request_id, move |request_id| {
transport.run(move |mut transport| {
if let Some(consumer_tag) = transport.conn.get_generated_name(request_id.expect("expected request_id")) {
return Ok(Async::Ready(consumer_tag))
Ok(future::Loop::Break(consumer_tag))
} else {
task::current().notify();
return Ok(Async::NotReady)
Ok(future::Loop::Continue(request_id))
}
})
}).map(|consumer_tag| {
})).map(|consumer_tag| {
trace!("basic_consume received response, returning consumer");
consumer.update_consumer_tag(consumer_tag);
consumer
Expand Down Expand Up @@ -428,16 +424,17 @@ impl<T: AsyncRead+AsyncWrite+Send+Sync+'static> Channel<T> {
/// gets a message
pub fn basic_get(&self, queue: &str, options: BasicGetOptions) -> impl Future<Item = BasicGetMessage, Error = io::Error> + Send + 'static {
let channel_id = self.id;
let _queue = queue.to_string();
let queue = queue.to_string();
let queue = queue.to_owned();
let receive_transport = self.transport.clone();
let receive_future = future::poll_fn(move || {
let mut transport = lock_transport!(receive_transport);
let receive_future = future::loop_fn((queue.clone(), channel_id), move |(queue, channel_id)| {
receive_transport.run(move |mut transport| {
transport.poll()?;
if let Some(message) = transport.conn.next_basic_get_message(channel_id, &_queue) {
return Ok(Async::Ready(message));
if let Some(message) = transport.conn.next_basic_get_message(channel_id, &queue) {
Ok(future::Loop::Break(message))
} else {
Ok(future::Loop::Continue((queue, channel_id)))
}
Ok(Async::NotReady)
})
});

self.run_on_locked_transport_full("basic_get", "Could not get message", move |transport| {
Expand Down Expand Up @@ -530,23 +527,14 @@ impl<T: AsyncRead+AsyncWrite+Send+Sync+'static> Channel<T> {
trace!("run on locked transport; method={:?}", method);
let channel_id = self.id;
let transport = self.transport.clone();
let _transport = self.transport.clone();
let _method = method.to_string();
let method = method.to_string();
let error = error.to_string();
// Tweak to make the borrow checker happy, see below for more explaination
let mut action = Some(action);
let mut payload = Some(payload);

future::poll_fn(move || {
let mut transport = lock_transport!(transport);
// The poll_fn here is only there for the lock_transport call above.
// Once the lock_transport yields a Async::Ready transport, the rest of the function is
// ran only once as we either return an error or an Async::Ready, it's thus safe to .take().unwrap()
// the action, which is always Some() the first time, and never called twice.
// This is needed because we're in an FnMut and thus cannot transfer ownership as an
// FnMut can be called several time and action which is an FnOnce can only be called
// once (which is implemented as a ownership transfer).
self.transport.run(move |mut transport| {
match action.take().unwrap()(&mut transport) {
Err(e) => Err(Error::new(ErrorKind::Other, format!("{}: {:?}", error, e))),
Ok(request_id) => {
Expand All @@ -556,27 +544,30 @@ impl<T: AsyncRead+AsyncWrite+Send+Sync+'static> Channel<T> {
transport.send_content_frames(channel_id, payload.as_slice(), properties);
}

Ok(Async::Ready(request_id))
Ok(request_id)
},
}
}).and_then(move |request_id| {
if request_id.is_some() {
trace!("{} returning closure", method);
}

future::poll_fn(move || {
let mut transport = lock_transport!(_transport);

if let Some(request_id) = request_id {
Self::wait_for_answer(&mut transport, request_id, &finished)
} else {
transport.poll().map(|r| r.map(|_| None))
}
future::loop_fn((transport, finished), move |(tr, finished)| {
tr.clone().run(move |mut transport| {
if let Some(request_id) = request_id {
Self::wait_for_answer(&mut transport, request_id, &finished)
} else {
transport.poll().map(|r| r.map(|_| None))
}.map(|res| match res {
Async::Ready(r) => future::Loop::Break(r),
Async::NotReady => future::Loop::Continue((tr, finished)),
})
})
})
})
}

fn run_on_lock_transport_basic_finished(conn: &mut Connection, request_id: RequestId) -> Poll<Option<RequestId>, io::Error> {
fn run_on_locked_transport_basic_finished(conn: &mut Connection, request_id: RequestId) -> Poll<Option<RequestId>, io::Error> {
match conn.is_finished(request_id) {
Some(answer) if answer => Ok(Async::Ready(Some(request_id))),
_ => {
Expand All @@ -588,7 +579,7 @@ impl<T: AsyncRead+AsyncWrite+Send+Sync+'static> Channel<T> {

fn run_on_locked_transport<Action>(&self, method: &str, error: &str, action: Action) -> impl Future<Item = Option<RequestId>, Error = io::Error> + Send + 'static
where Action: 'static + Send + FnOnce(&mut AMQPTransport<T>) -> Result<Option<RequestId>, lapin_async::error::Error> {
self.run_on_locked_transport_full(method, error, action, Self::run_on_lock_transport_basic_finished, None)
self.run_on_locked_transport_full(method, error, action, &Self::run_on_locked_transport_basic_finished, None)
}

/// internal method to wait until a request succeeds
Expand Down
15 changes: 7 additions & 8 deletions futures/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use lapin_async::format::frame::Frame;
use std::default::Default;
use std::io;
use std::str::FromStr;
use futures::{future,Async,Future,Poll,Stream};
use futures::{future,Future,Poll,Stream};
use futures::sync::oneshot;
use futures_locks::Mutex;
use tokio_io::{AsyncRead,AsyncWrite};
Expand All @@ -13,6 +13,7 @@ use std::time::{Duration,Instant};

use transport::*;
use channel::{Channel, ConfirmSelectOptions};
use util::MutexExt;

/// the Client structures connects to a server and creates channels
//#[derive(Clone)]
Expand Down Expand Up @@ -91,15 +92,13 @@ fn heartbeat_pulse<T: AsyncRead+AsyncWrite+Send+'static>(transport: Mutex<AMQPTr
let send_transport = transport.clone();
let poll_transport = transport.clone();

future::poll_fn(move || {
let mut transport = lock_transport!(send_transport);
send_transport.run(|mut transport| {
debug!("Sending heartbeat");
transport.send_frame(Frame::Heartbeat(0));
Ok(Async::Ready(()))
}).and_then(move |_| future::poll_fn(move || {
let mut transport = lock_transport!(poll_transport);
transport.poll()
})).map(|_| ()).map_err(|err| {
Ok(())
}).and_then(move |_| {
poll_transport.run(|mut transport| transport.poll())
}).map(|_| ()).map_err(|err| {
error!("Error occured in heartbeat interval: {}", err);
err
})
Expand Down
25 changes: 14 additions & 11 deletions futures/src/consumer.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
use futures::{Async,Future,Poll,Stream,task};
use futures::{Async,Poll,Stream,task};
use futures_locks::Mutex;
use lapin_async::consumer::ConsumerSubscriber;
use tokio_io::{AsyncRead,AsyncWrite};
use std::collections::VecDeque;
use std::io;
use std::sync::{Arc,Mutex as SMutex};

use message::Delivery;
use transport::*;

#[derive(Clone,Debug)]
pub struct ConsumerSub {
inner: Arc<SMutex<ConsumerInner>>,
inner: Mutex<ConsumerInner>,
}

impl ConsumerSubscriber for ConsumerSub {
fn new_delivery(&mut self, delivery: Delivery) {
trace!("new_delivery;");
if let Ok(mut inner) = self.inner.lock() {
if let Ok(mut inner) = self.inner.try_lock() {
inner.deliveries.push_back(delivery);
if let Some(task) = inner.task.as_ref() {
task.notify();
Expand All @@ -32,7 +31,7 @@ impl ConsumerSubscriber for ConsumerSub {
#[derive(Clone)]
pub struct Consumer<T> {
transport: Mutex<AMQPTransport<T>>,
inner: Arc<SMutex<ConsumerInner>>,
inner: Mutex<ConsumerInner>,
channel_id: u16,
queue: String,
consumer_tag: String,
Expand All @@ -57,7 +56,7 @@ impl<T: AsyncRead+AsyncWrite+Sync+Send+'static> Consumer<T> {
pub fn new(transport: Mutex<AMQPTransport<T>>, channel_id: u16, queue: String, consumer_tag: String) -> Consumer<T> {
Consumer {
transport,
inner: Arc::new(SMutex::new(ConsumerInner::default())),
inner: Mutex::new(ConsumerInner::default()),
channel_id,
queue,
consumer_tag,
Expand All @@ -81,13 +80,17 @@ impl<T: AsyncRead+AsyncWrite+Sync+Send+'static> Stream for Consumer<T> {

fn poll(&mut self) -> Poll<Option<Delivery>, io::Error> {
trace!("consumer poll; consumer_tag={:?} polling transport", self.consumer_tag);
let mut transport = lock_transport!(self.transport);
let mut transport = match self.transport.try_lock() {
Ok(transport) => transport,
Err(_) => {
task::current().notify();
return Ok(Async::NotReady);
}
};
transport.poll()?;
let mut inner = match self.inner.lock() {
let mut inner = match self.inner.try_lock() {
Ok(inner) => inner,
Err(_) => if self.inner.is_poisoned() {
return Err(io::Error::new(io::ErrorKind::Other, "Consumer mutex is poisoned"))
} else {
Err(_) => {
task::current().notify();
return Ok(Async::NotReady)
},
Expand Down
7 changes: 0 additions & 7 deletions futures/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,6 @@ impl<T> Future for AMQPTransportConnector<T>
}
}

#[macro_export]
macro_rules! lock_transport (
($t: expr) => ({
$t.lock().wait().unwrap()
});
);

#[cfg(test)]
mod tests {
extern crate env_logger;
Expand Down
14 changes: 8 additions & 6 deletions futures/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@ use std::io;
use futures::future::{Future, IntoFuture};
use futures_locks::{Mutex, MutexGuard};

pub(crate) trait MutexExt<T> {
pub(crate) trait MutexExt<T: Send + 'static> {
fn run<F, B, R>(&self, f: F) -> Box<Future<Item = R, Error = io::Error> + Send + 'static>
where F: FnOnce(MutexGuard<T>) -> B + 'static,
B: IntoFuture<Item = R, Error = io::Error> + 'static,
where F: FnOnce(MutexGuard<T>) -> B + Send + 'static,
B: IntoFuture<Item = R, Error = io::Error> + Send + 'static,
<B as IntoFuture>::Future: Send + 'static,
R: Send + 'static;
}

impl<T: 'static> MutexExt<T> for Mutex<T> {
impl<T: Send + 'static> MutexExt<T> for Mutex<T> {
fn run<F, B, R>(&self, f: F) -> Box<Future<Item = R, Error = io::Error> + Send + 'static>
where F: FnOnce(MutexGuard<T>) -> B + 'static,
B: IntoFuture<Item = R, Error = io::Error> + 'static,
where F: FnOnce(MutexGuard<T>) -> B + Send + 'static,
B: IntoFuture<Item = R, Error = io::Error> + Send + 'static,
<B as IntoFuture>::Future: Send + 'static,
R: Send + 'static {
Box::new(self.with(f).then(|res| match res {
Ok(res) => res,
Expand Down

0 comments on commit d77ca8c

Please sign in to comment.