Skip to content

Commit

Permalink
add AcceptorBuilder::with_acceptor method
Browse files Browse the repository at this point in the history
  • Loading branch information
MarinPostma committed Sep 15, 2023
1 parent 6e6df04 commit 4a6220a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 18 deletions.
43 changes: 26 additions & 17 deletions src/acceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ pub use builder::AcceptorBuilder;
use builder::WantsTlsConfig;

/// A TLS acceptor that can be used with hyper servers.
pub struct TlsAcceptor {
pub struct TlsAcceptor<A = AddrIncoming> {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
acceptor: A,
}

/// An Acceptor for the `https` scheme.
Expand All @@ -31,20 +31,23 @@ impl TlsAcceptor {

/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> Self {
Self { config, incoming }
Self { config, acceptor: incoming }
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
impl<A> Accept for TlsAcceptor<A>
where A: Accept<Error = io::Error> + Unpin,
A::Conn: AsyncRead + AsyncWrite + Unpin,
{
type Conn = TlsStream<A::Conn>;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
Poll::Ready(match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Poll::Ready(match ready!(Pin::new(&mut pin.acceptor).poll_accept(cx)) {
Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))),
Some(Err(e)) => Some(Err(e)),
None => None,
Expand All @@ -66,22 +69,24 @@ where
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first
pub struct TlsStream {
state: State,
pub struct TlsStream<C = AddrStream> {
state: State<C>,
}

impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> Self {

impl<C> TlsStream<C>
where C: AsyncRead + AsyncWrite + Unpin
{
fn new(stream: C, config: Arc<ServerConfig>) -> Self {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
Self {
state: State::Handshaking(accept),
}
}

/// Returns a reference to the underlying IO stream.
///
/// This should always return `Some`, except if an error has already been yielded.
pub fn io(&self) -> Option<&AddrStream> {
pub fn io(&self) -> Option<&C> {
match &self.state {
State::Handshaking(accept) => accept.get_ref(),
State::Streaming(stream) => Some(stream.get_ref().0),
Expand All @@ -99,7 +104,9 @@ impl TlsStream {
}
}

impl AsyncRead for TlsStream {
impl<C> AsyncRead for TlsStream<C>
where C: AsyncRead + AsyncWrite + Unpin
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
Expand All @@ -122,7 +129,9 @@ impl AsyncRead for TlsStream {
}
}

impl AsyncWrite for TlsStream {
impl<C> AsyncWrite for TlsStream<C>
where C: AsyncRead + AsyncWrite + Unpin
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand Down Expand Up @@ -159,7 +168,7 @@ impl AsyncWrite for TlsStream {
}
}

enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
enum State<C> {
Handshaking(tokio_rustls::Accept<C>),
Streaming(tokio_rustls::server::TlsStream<C>),
}
10 changes: 9 additions & 1 deletion src/acceptor/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,17 @@ impl AcceptorBuilder<WantsIncoming> {
/// Passes a [`AddrIncoming`] to configure the TLS connection and
/// creates the [`TlsAcceptor`]
pub fn with_incoming(self, incoming: impl Into<AddrIncoming>) -> TlsAcceptor {
self.with_acceptor(incoming.into())
}

/// Passes an acceptor implementing [Accept] to configure the TLS connection and
/// creates the [`TlsAcceptor`]
///
/// [Accept]: hyper::server::accept::Accept
pub fn with_acceptor<A>(self, acceptor: A) -> TlsAcceptor<A> {
TlsAcceptor {
config: Arc::new(self.0 .0),
incoming: incoming.into(),
acceptor,
}
}
}

0 comments on commit 4a6220a

Please sign in to comment.