diff --git a/src/acceptor.rs b/src/acceptor.rs index 4cf816d..037b64e 100644 --- a/src/acceptor.rs +++ b/src/acceptor.rs @@ -17,9 +17,9 @@ pub use builder::AcceptorBuilder; use builder::WantsTlsConfig; /// A TLS acceptor that can be used with hyper servers. -pub struct TlsAcceptor { +pub struct TlsAcceptor { config: Arc, - incoming: AddrIncoming, + acceptor: A, } /// An Acceptor for the `https` scheme. @@ -31,12 +31,19 @@ impl TlsAcceptor { /// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`. pub fn new(config: Arc, incoming: AddrIncoming) -> Self { - Self { config, incoming } + Self { + config, + acceptor: incoming, + } } } -impl Accept for TlsAcceptor { - type Conn = TlsStream; +impl Accept for TlsAcceptor +where + A: Accept + Unpin, + A::Conn: AsyncRead + AsyncWrite + Unpin, +{ + type Conn = TlsStream; type Error = io::Error; fn poll_accept( @@ -44,7 +51,7 @@ impl Accept for TlsAcceptor { cx: &mut Context<'_>, ) -> Poll>> { let pin = self.get_mut(); - Poll::Ready(match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { + Poll::Ready(match ready!(Pin::new(&mut pin.acceptor).poll_accept(cx)) { Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))), Some(Err(e)) => Some(Err(e)), None => None, @@ -66,22 +73,21 @@ where // tokio_rustls::server::TlsStream doesn't expose constructor methods, // so we have to TlsAcceptor::accept and handshake to have access to it // TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first -pub struct TlsStream { - state: State, +pub struct TlsStream { + state: State, } -impl TlsStream { - fn new(stream: AddrStream, config: Arc) -> Self { +impl TlsStream { + fn new(stream: C, config: Arc) -> Self { let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); Self { state: State::Handshaking(accept), } } - /// Returns a reference to the underlying IO stream. /// /// This should always return `Some`, except if an error has already been yielded. - pub fn io(&self) -> Option<&AddrStream> { + pub fn io(&self) -> Option<&C> { match &self.state { State::Handshaking(accept) => accept.get_ref(), State::Streaming(stream) => Some(stream.get_ref().0), @@ -99,7 +105,7 @@ impl TlsStream { } } -impl AsyncRead for TlsStream { +impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, @@ -122,7 +128,7 @@ impl AsyncRead for TlsStream { } } -impl AsyncWrite for TlsStream { +impl AsyncWrite for TlsStream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -159,7 +165,7 @@ impl AsyncWrite for TlsStream { } } -enum State { - Handshaking(tokio_rustls::Accept), - Streaming(tokio_rustls::server::TlsStream), +enum State { + Handshaking(tokio_rustls::Accept), + Streaming(tokio_rustls::server::TlsStream), } diff --git a/src/acceptor/builder.rs b/src/acceptor/builder.rs index 70c0ca7..9b95a82 100644 --- a/src/acceptor/builder.rs +++ b/src/acceptor/builder.rs @@ -90,9 +90,17 @@ impl AcceptorBuilder { /// Passes a [`AddrIncoming`] to configure the TLS connection and /// creates the [`TlsAcceptor`] pub fn with_incoming(self, incoming: impl Into) -> TlsAcceptor { + self.with_acceptor(incoming.into()) + } + + /// Passes an acceptor implementing [`Accept`] to configure the TLS connection and + /// creates the [`TlsAcceptor`] + /// + /// [`Accept`]: hyper::server::accept::Accept + pub fn with_acceptor(self, acceptor: A) -> TlsAcceptor { TlsAcceptor { config: Arc::new(self.0 .0), - incoming: incoming.into(), + acceptor, } } }