Skip to content

Commit

Permalink
buffer: wake tasks waiting for channel capacity when terminating (#480)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkw committed Oct 28, 2020
1 parent 069c908 commit 43c4492
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tower/src/buffer/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ where
Request: Send + 'static,
{
let (tx, rx) = mpsc::unbounded_channel();
let (handle, worker) = Worker::new(service, rx);
let semaphore = Semaphore::new(bound);
let (semaphore, wake_waiters) = Semaphore::new_with_close(bound);
let (handle, worker) = Worker::new(service, rx, wake_waiters);
(
Buffer {
tx,
Expand Down
23 changes: 22 additions & 1 deletion tower/src/buffer/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tower_service::Service;
/// as part of the public API. This is the "sealed" pattern to include "private"
/// types in public traits that are not meant for consumers of the library to
/// implement (only call).
#[pin_project]
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Worker<T, Request>
where
Expand All @@ -33,6 +33,7 @@ where
finish: bool,
failed: Option<ServiceError>,
handle: Handle,
close: Option<crate::semaphore::Close>,
}

/// Get the error out
Expand All @@ -49,6 +50,7 @@ where
pub(crate) fn new(
service: T,
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
close: crate::semaphore::Close,
) -> (Handle, Worker<T, Request>) {
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
Expand All @@ -61,6 +63,7 @@ where
rx,
service,
handle: handle.clone(),
close: Some(close),
};

(handle, worker)
Expand Down Expand Up @@ -195,6 +198,11 @@ where
.as_ref()
.expect("Worker::failed did not set self.failed?")
.clone()));
// Wake any tasks waiting on channel capacity.
if let Some(close) = self.close.take() {
tracing::debug!("waking pending tasks");
close.close();
}
}
}
}
Expand All @@ -208,6 +216,19 @@ where
}
}

// When the worker is dropped before it has terminated normally (e.g. the task
// driving it is cancelled), close the semaphore so that any tasks still
// waiting for channel capacity are woken, rather than waiting forever.
#[pin_project::pinned_drop]
impl<T, Request> PinnedDrop for Worker<T, Request>
where
    T: Service<Request>,
    T::Error: Into<crate::BoxError>,
{
    fn drop(mut self: Pin<&mut Self>) {
        // `close` is an `Option` so it can be consumed by value here; it may
        // already have been taken if the worker terminated via its poll loop.
        if let Some(close) = self.as_mut().close.take() {
            close.close();
        }
    }
}

impl Handle {
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
self.inner
Expand Down
41 changes: 40 additions & 1 deletion tower/src/semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
future::Future,
mem,
pin::Pin,
sync::Arc,
sync::{Arc, Weak},
task::{Context, Poll},
};
use tokio::sync;
Expand All @@ -16,13 +16,32 @@ pub(crate) struct Semaphore {
state: State,
}

/// A handle used to close the semaphore, waking any tasks waiting to acquire
/// a permit (see [`Close::close`]).
#[derive(Debug)]
pub(crate) struct Close {
    // Held as a `Weak` so that this handle does not keep the semaphore alive
    // after all other references to it have been dropped.
    semaphore: Weak<sync::Semaphore>,
    // The number of permits the semaphore was created with.
    permits: usize,
}

/// The state of a permit acquisition on the semaphore.
enum State {
    /// An acquisition is in flight; the boxed future resolves to the permit.
    Waiting(Pin<Box<dyn Future<Output = Permit> + Send + 'static>>),
    /// A permit has been acquired.
    Ready(Permit),
    /// No permit is held and no acquisition is in progress.
    Empty,
}

impl Semaphore {
/// Creates a semaphore with `permits` permits, plus a [`Close`] handle that
/// can later wake any tasks still waiting on the semaphore.
pub(crate) fn new_with_close(permits: usize) -> (Self, Close) {
    let inner = Arc::new(sync::Semaphore::new(permits));
    // The close handle only holds a weak reference, so it does not keep the
    // semaphore alive on its own.
    let close = Close {
        semaphore: Arc::downgrade(&inner),
        permits,
    };
    (
        Self {
            semaphore: inner,
            state: State::Empty,
        },
        close,
    )
}

pub(crate) fn new(permits: usize) -> Self {
Self {
semaphore: Arc::new(sync::Semaphore::new(permits)),
Expand Down Expand Up @@ -72,3 +91,23 @@ impl fmt::Debug for State {
}
}
}

impl Close {
    /// Close the semaphore, waking any remaining tasks currently awaiting a permit.
    ///
    /// This works by flooding the semaphore with permits so that every pending
    /// acquisition completes immediately; `tokio::sync::Semaphore` does not
    /// currently expose a public `close` operation.
    pub(crate) fn close(self) {
        // The maximum number of permits that a `tokio::sync::Semaphore`
        // can hold is usize::MAX >> 3. If we attempt to add more than that
        // number of permits, the semaphore will panic.
        // XXX(eliza): another shift is kinda janky but if we add (usize::MAX
        // >> 3 - initial permits) the semaphore impl panics (I think due to a
        // bug in tokio?).
        // TODO(eliza): Tokio should _really_ just expose `Semaphore::close`
        // publicly so we don't have to do this nonsense...
        const MAX: usize = std::usize::MAX >> 4;
        // If the semaphore has already been dropped, there is nothing to wake.
        if let Some(semaphore) = self.semaphore.upgrade() {
            // Subtract the *initial* permit count rather than the currently
            // available permits: if we added `MAX - available_permits`, any
            // tasks that are currently holding permits could drop them,
            // overflowing the max.
            semaphore.add_permits(MAX - self.permits);
        }
    }
}
119 changes: 119 additions & 0 deletions tower/tests/buffer/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod support;
use std::thread;
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task};
use tower::buffer::{error, Buffer};
use tower::{util::ServiceExt, Service};
use tower_test::{assert_request_eq, mock};

fn let_worker_work() {
Expand Down Expand Up @@ -227,6 +228,124 @@ async fn waits_for_channel_capacity() {
assert_ready_ok!(response4.poll());
}

/// Dropping the buffer's worker task must wake every task that is waiting for
/// channel capacity, and each waiter must observe an `error::Closed`.
#[tokio::test(flavor = "current_thread")]
async fn wakes_pending_waiters_on_close() {
    let _t = support::trace_init();

    let (service, mut handle) = mock::pair::<_, ()>();

    let (mut service, worker) = Buffer::pair(service, 1);
    let mut worker = task::spawn(worker);

    // keep the request in the worker
    handle.allow(0);
    let service1 = service.ready_and().await.unwrap();
    assert_pending!(worker.poll());
    let mut response = task::spawn(service1.call("hello"));

    // The channel is now at capacity, so both of these waiters must park.
    let mut service1 = service.clone();
    let mut ready_and1 = task::spawn(service1.ready_and());
    assert_pending!(worker.poll());
    assert_pending!(ready_and1.poll(), "no capacity");

    let mut service2 = service.clone();
    let mut ready_and2 = task::spawn(service2.ready_and());
    assert_pending!(worker.poll());
    assert_pending!(ready_and2.poll(), "no capacity");

    // kill the worker task
    drop(worker);

    // The in-flight request fails because the worker was dropped before
    // processing it.
    let err = assert_ready_err!(response.poll());
    assert!(
        err.is::<error::Closed>(),
        "response should fail with a Closed, got: {:?}",
        err
    );

    assert!(
        ready_and1.is_woken(),
        "dropping worker should wake ready_and task 1"
    );
    let err = assert_ready_err!(ready_and1.poll());
    assert!(
        err.is::<error::Closed>(),
        "ready_and 1 should fail with a Closed, got: {:?}",
        err
    );

    assert!(
        ready_and2.is_woken(),
        "dropping worker should wake ready_and task 2"
    );
    // Fixed: this previously re-polled `ready_and1`, so task 2's error was
    // never actually checked.
    let err = assert_ready_err!(ready_and2.poll());
    assert!(
        err.is::<error::Closed>(),
        "ready_and 2 should fail with a Closed, got: {:?}",
        err
    );
}

/// When the inner service fails, the terminating worker must wake every task
/// waiting for channel capacity, and each waiter must observe the
/// `error::ServiceError`.
#[tokio::test(flavor = "current_thread")]
async fn wakes_pending_waiters_on_failure() {
    let _t = support::trace_init();

    let (service, mut handle) = mock::pair::<_, ()>();

    let (mut service, worker) = Buffer::pair(service, 1);
    let mut worker = task::spawn(worker);

    // keep the request in the worker
    handle.allow(0);
    let service1 = service.ready_and().await.unwrap();
    assert_pending!(worker.poll());
    let mut response = task::spawn(service1.call("hello"));

    // The channel is now at capacity, so both of these waiters must park.
    let mut service1 = service.clone();
    let mut ready_and1 = task::spawn(service1.ready_and());
    assert_pending!(worker.poll());
    assert_pending!(ready_and1.poll(), "no capacity");

    let mut service2 = service.clone();
    let mut ready_and2 = task::spawn(service2.ready_and());
    assert_pending!(worker.poll());
    assert_pending!(ready_and2.poll(), "no capacity");

    // fail the inner service
    handle.send_error("foobar");
    // worker task terminates
    assert_ready!(worker.poll());

    let err = assert_ready_err!(response.poll());
    assert!(
        err.is::<error::ServiceError>(),
        "response should fail with a ServiceError, got: {:?}",
        err
    );

    assert!(
        ready_and1.is_woken(),
        "dropping worker should wake ready_and task 1"
    );
    let err = assert_ready_err!(ready_and1.poll());
    assert!(
        err.is::<error::ServiceError>(),
        "ready_and 1 should fail with a ServiceError, got: {:?}",
        err
    );

    assert!(
        ready_and2.is_woken(),
        "dropping worker should wake ready_and task 2"
    );
    // Fixed: this previously re-polled `ready_and1`, so task 2's error was
    // never actually checked.
    let err = assert_ready_err!(ready_and2.poll());
    assert!(
        err.is::<error::ServiceError>(),
        "ready_and 2 should fail with a ServiceError, got: {:?}",
        err
    );
}

type Mock = mock::Mock<&'static str, &'static str>;
type Handle = mock::Handle<&'static str, &'static str>;

Expand Down

0 comments on commit 43c4492

Please sign in to comment.