Skip to content

Commit

Permalink
Merge pull request #41 from tmccombs/remove-until
Browse files Browse the repository at this point in the history
feat!: Remove until & remove option from accept
  • Loading branch information
tmccombs committed Dec 1, 2023
2 parents 0f3a7ea + 2ed0ebe commit 9c5d44f
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 120 deletions.
5 changes: 3 additions & 2 deletions examples/echo-threads.rs
@@ -1,6 +1,6 @@
use futures_util::StreamExt;
use std::net::SocketAddr;
use tls_listener::{AsyncAccept, SpawningHandshakes, TlsListener};
use tls_listener::{SpawningHandshakes, TlsListener};
use tokio::io::{copy, split};
use tokio::net::{TcpListener, TcpStream};
use tokio::signal::ctrl_c;
Expand All @@ -27,9 +27,10 @@ async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();

let listener = TcpListener::bind(&addr).await?.until(ctrl_c());
let listener = TcpListener::bind(&addr).await?;

TlsListener::new(SpawningHandshakes(tls_acceptor()), listener)
.take_until(ctrl_c())
.for_each_concurrent(None, |s| async {
match s {
Ok((stream, remote_addr)) => {
Expand Down
5 changes: 3 additions & 2 deletions examples/echo.rs
@@ -1,6 +1,6 @@
use futures_util::StreamExt;
use std::net::SocketAddr;
use tls_listener::{AsyncAccept, TlsListener};
use tls_listener::TlsListener;
use tokio::io::{copy, split};
use tokio::net::{TcpListener, TcpStream};
use tokio::signal::ctrl_c;
Expand Down Expand Up @@ -36,9 +36,10 @@ async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();

let listener = TcpListener::bind(&addr).await?.until(ctrl_c());
let listener = TcpListener::bind(&addr).await?;

TlsListener::new(tls_acceptor(), listener)
.take_until(ctrl_c())
.for_each_concurrent(None, |s| async {
match s {
Ok((stream, remote_addr)) => {
Expand Down
2 changes: 1 addition & 1 deletion examples/http-change-certificate.rs
Expand Up @@ -31,7 +31,7 @@ async fn main() {
loop {
tokio::select! {
conn = listener.accept() => {
match conn.expect("Tls listener stream should be infinite") {
match conn {
Ok((conn, remote_addr)) => {
let http = http.clone();
let tx = tx.clone();
Expand Down
2 changes: 1 addition & 1 deletion examples/http.rs
Expand Up @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// We start a loop to continuously accept incoming connections
loop {
match listener.accept().await.unwrap() {
match listener.accept().await {
Ok((stream, _)) => {
let io = TokioIo::new(stream);

Expand Down
116 changes: 41 additions & 75 deletions src/lib.rs
Expand Up @@ -18,7 +18,7 @@ use pin_project_lite::pin_project;
#[cfg(feature = "rt")]
pub use spawning_handshake::SpawningHandshakes;
use std::fmt::Debug;
use std::future::Future;
use std::future::{poll_fn, Future};
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::time::Duration;
Expand Down Expand Up @@ -77,24 +77,7 @@ pub trait AsyncAccept {
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>>;

/// Return a new `AsyncAccept` that stops accepting connections after
/// `ender` completes.
///
/// Useful for graceful shutdown.
///
/// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs)
/// for example of how to use.
fn until<F: Future>(self, ender: F) -> Until<Self, F>
where
Self: Sized,
{
Until {
acceptor: self,
ender,
}
}
) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>>;
}

pin_project! {
Expand Down Expand Up @@ -192,14 +175,49 @@ where
A: AsyncAccept,
T: AsyncTls<A::Connection>,
{
/// Poll accepting a connection.
///
/// This will return ready once the TLS handshake has completed on an incoming
/// connection and return the connection and the source address.
pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
let mut this = self.project();

while this.waiting.len() < *this.max_handshakes {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Ok((conn, addr))) => {
this.waiting.push(Waiting {
inner: timeout(*this.timeout, this.tls.accept(conn)),
peer_addr: Some(addr),
});
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(Error::ListenerError(e)));
}
}
}

match this.waiting.poll_next_unpin(cx) {
Poll::Ready(Some(result)) => Poll::Ready(result),
// If we don't have anything waiting yet,
// then we are still pending,
Poll::Ready(None) | Poll::Pending => Poll::Pending,
}
}

/// Accept the next connection
///
/// This is essentially an alias to `self.next()` with a more domain-appropriate name.
pub fn accept(&mut self) -> impl Future<Output = Option<<Self as Stream>::Item>> + '_
/// This is similar to `self.next()`, but doesn't return an `Option` because
/// there isn't an end condition on accepting connections,
/// and has a more domain-appropriate name.
///
/// The future returned is "cancellation safe".
pub fn accept(&mut self) -> impl Future<Output = <Self as Stream>::Item> + '_
where
Self: Unpin,
{
self.next()
let mut pinned = Pin::new(self);
poll_fn(move |cx| pinned.as_mut().poll_accept(cx))
}

/// Replaces the Tls Acceptor configuration, which will be used for new connections.
Expand Down Expand Up @@ -237,31 +255,7 @@ where
type Item = Result<(T::Stream, A::Address), TlsListenerError<A, T>>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();

while this.waiting.len() < *this.max_handshakes {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Some(Ok((conn, addr)))) => {
this.waiting.push(Waiting {
inner: timeout(*this.timeout, this.tls.accept(conn)),
peer_addr: Some(addr),
});
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(Error::ListenerError(e))));
}
Poll::Ready(None) => return Poll::Ready(None),
}
}

match this.waiting.poll_next_unpin(cx) {
// If we don't have anything waiting yet,
// then we are still pending,
Poll::Ready(None) => Poll::Pending,
// Otherwise the result is already what we want
result => result,
}
self.poll_accept(cx).map(Some)
}
}

Expand Down Expand Up @@ -432,31 +426,3 @@ where
}
}
}

pin_project! {
/// See [`AsyncAccept::until`]
pub struct Until<A, E> {
#[pin]
acceptor: A,
#[pin]
ender: E,
}
}

impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
type Connection = A::Connection;
type Error = A::Error;
type Address = A::Address;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
let this = self.project();

match this.ender.poll(cx) {
Poll::Pending => this.acceptor.poll_accept(cx),
Poll::Ready(_) => Poll::Ready(None),
}
}
}
12 changes: 6 additions & 6 deletions src/net.rs
Expand Up @@ -15,10 +15,10 @@ impl AsyncAccept for TcpListener {
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>> {
match (*self).poll_accept(cx) {
Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Ready(Ok(conn)) => Poll::Ready(Ok(conn)),
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
Expand All @@ -34,10 +34,10 @@ impl AsyncAccept for UnixListener {
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>> {
match (*self).poll_accept(cx) {
Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Ready(Ok(conn)) => Poll::Ready(Ok(conn)),
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
Expand Down
31 changes: 7 additions & 24 deletions tests/basic.rs
Expand Up @@ -7,6 +7,7 @@ use helper::*;
use helper::{assert_ascii_eq, assert_err};
use tokio::io::{AsyncWriteExt, DuplexStream};
use tokio::spawn;
use tokio::sync::oneshot;

use futures_util::StreamExt;

Expand Down Expand Up @@ -46,7 +47,7 @@ async fn stream_error() {
connecter
.send_error(Error::new(ErrorKind::ConnectionReset, "test"))
.await;
assert_err!(listener.accept().await.unwrap(), ListenerError(_));
assert_err!(listener.accept().await, ListenerError(_));
}

#[tokio::test]
Expand All @@ -67,38 +68,20 @@ async fn tls_error() {
let mut listener = TlsListener::new(ErrTls, accept);

assert_err!(
listener.accept().await.unwrap(),
listener.accept().await,
TlsAcceptError {
peer_addr: MockAddress(42),
..
}
);
}

#[tokio::test]
async fn accept_ended() {
let (connector, mut listener) = setup();

spawn(async move {
assert_ascii_eq!(connector.send_data(b"hello").await.unwrap(), b"abc");
});

let res = listener.accept().await;
if let Some(Ok((mut stream, MockAddress(stream_id)))) = res {
assert_eq!(stream_id, 42);
stream.write_all(b"ABC").await.unwrap();
} else {
panic!("Failed to accept stream. Got {:?}", res);
}

assert!(listener.accept().await.is_none());
}

static LONG_TEXT: &'static [u8] = include_bytes!("long_text.txt");

#[tokio::test]
async fn echo() {
let (connector, listener) = setup_echo();
let (ender, ended) = oneshot::channel();
let (connector, listener) = setup_echo(ended);

async fn check_message(c: &MockConnect, msg: &[u8]) -> () {
let resp = c.send_data(msg).await;
Expand All @@ -117,7 +100,7 @@ async fn echo() {
check_message(c, LONG_TEXT),
check_message(c, LONG_TEXT),
);
drop(connector);
ender.send(()).unwrap();

if let Err(e) = listener.await {
std::panic::resume_unwind(e.into_panic());
Expand All @@ -135,6 +118,6 @@ async fn addr() {
});

for i in 42..44 {
assert_eq!(listener.accept().await.unwrap().unwrap().1, MockAddress(i));
assert_eq!(listener.accept().await.unwrap().1, MockAddress(i));
}
}
4 changes: 2 additions & 2 deletions tests/helper/mocks.rs
Expand Up @@ -70,8 +70,8 @@ impl AsyncAccept for MockAccept {
type Error = io::Error;
type Address = MockAddress;

fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<ConnResult>> {
Pin::into_inner(self).chan.poll_recv(cx)
fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<ConnResult> {
Pin::into_inner(self).chan.poll_recv(cx).map(|c| c.unwrap())
}
}

Expand Down
19 changes: 12 additions & 7 deletions tests/helper/mod.rs
@@ -1,6 +1,7 @@
use futures_util::StreamExt;
use tls_listener::TlsListener;
use tokio::io::{copy, split};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;

mod asserts;
Expand All @@ -14,14 +15,18 @@ pub fn setup() -> (MockConnect, TlsListener<MockAccept, MockTls>) {
(connect, TlsListener::new(MockTls, accept))
}

pub fn setup_echo() -> (MockConnect, JoinHandle<()>) {
pub fn setup_echo(end: oneshot::Receiver<()>) -> (MockConnect, JoinHandle<()>) {
let (connector, listener) = setup();

let handle = tokio::spawn(listener.for_each_concurrent(None, |s| async {
let (mut reader, mut writer) = split(s.expect("Unexpected error").0);
copy(&mut reader, &mut writer)
.await
.expect("Failed to copy");
}));
let handle = tokio::spawn(
listener
.take_until(end)
.for_each_concurrent(None, |s| async {
let (mut reader, mut writer) = split(s.expect("Unexpected error").0);
copy(&mut reader, &mut writer)
.await
.expect("Failed to copy");
}),
);
(connector, handle)
}

0 comments on commit 9c5d44f

Please sign in to comment.