diff --git a/src/server.rs b/src/server.rs index 1d141c5..1d93c60 100644 --- a/src/server.rs +++ b/src/server.rs @@ -111,7 +111,7 @@ where /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap(); /// let (stream, _) = listener.accept().await.unwrap(); /// - /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream); + /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream).send_alert(false); /// tokio::pin!(acceptor); /// /// match acceptor.as_mut().await { @@ -146,6 +146,57 @@ where None => None, } } + + /// Writes a stored alert, consuming the alert (if any) and IO. + pub async fn write_alert(&mut self) -> io::Result<()> { + let Some(alert) = self.take_alert() else { + return Ok(()); + }; + let Some(io) = self.take_io() else { + return Ok(()); + }; + WritingAlert { + io, + alert: Some(alert), + } + .await + } +} + +struct WritingAlert { + io: IO, + alert: Option, +} + +impl Future for WritingAlert +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(), io::Error>; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let io = &mut this.io; + loop { + if let Some(mut alert) = this.alert.take() { + match alert.write(&mut SyncWriteAdapter { io, cx }) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + this.alert = Some(alert); + return Poll::Pending; + } + Err(e) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, e))); + } + Ok(0) => { + return Poll::Ready(Ok(())); + } + Ok(_) => { + this.alert = Some(alert); + continue; + } + }; + } + } + } } impl Future for LazyConfigAcceptor @@ -199,7 +250,10 @@ where Ok(None) => {} Err((err, alert)) => match this.send_alert { true => this.alert = Some(AlertState::Sending(err, alert)), - false => this.alert = Some(AlertState::Saved(alert)), + false => { + this.alert = Some(AlertState::Saved(alert)); + return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err))); + } }, } } diff --git a/tests/test.rs b/tests/test.rs index 762f300..b7b2f03 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -318,6 +318,77 @@ async fn lazy_config_acceptor_alert() { assert_eq!(received, fatal_alert_decode_error) } +#[tokio::test] +async fn lazy_config_acceptor_return_http() { + let (mut cstream, sstream) = tokio::io::duplex(1024); + + let (tx, rx) = oneshot::channel(); + + tokio::spawn(async move { + // This is write instead of write_all because of the short duplex size, which is necessarily + // symmetrical. We never finish writing because the LazyConfigAcceptor returns an error + let _ = cstream.write(b"not tls").await; + let mut buf = Vec::new(); + cstream.read_to_end(&mut buf).await.unwrap(); + tx.send(buf).unwrap(); + }); + + let acceptor = + LazyConfigAcceptor::new(rustls::server::Acceptor::default(), sstream).send_alert(false); + tokio::pin!(acceptor); + + let Ok(accept_result) = time::timeout(Duration::from_secs(3), acceptor.as_mut()).await else { + panic!("timeout"); + }; + + assert!(accept_result.is_err()); + let mut io = acceptor.take_io().unwrap(); + io.write_all(b"HTTP/1.1 400 Invalid Input\r\n\r\n\r\nNot TLS\n") + .await + .unwrap(); + io.shutdown().await.unwrap(); + + let Ok(Ok(received)) = time::timeout(Duration::from_secs(3), rx).await else { + panic!("failed to receive"); + }; + + let recv = b"HTTP/1.1 400 Invalid Input\r\n\r\n\r\nNot TLS\n"; + assert_eq!(received, recv) +} + +#[tokio::test] +async fn lazy_config_acceptor_manual_alert() { + let (mut cstream, sstream) = tokio::io::duplex(2); + + let (tx, rx) = oneshot::channel(); + + tokio::spawn(async move { + // This is write instead of write_all because of the short duplex size, which is necessarily + // symmetrical. We never finish writing because the LazyConfigAcceptor returns an error + let _ = cstream.write(b"not tls").await; + let mut buf = Vec::new(); + cstream.read_to_end(&mut buf).await.unwrap(); + tx.send(buf).unwrap(); + }); + + let acceptor = + LazyConfigAcceptor::new(rustls::server::Acceptor::default(), sstream).send_alert(false); + tokio::pin!(acceptor); + + let Ok(accept_result) = time::timeout(Duration::from_secs(3), acceptor.as_mut()).await else { + panic!("timeout"); + }; + + assert!(accept_result.is_err()); + acceptor.write_alert().await.unwrap(); + let Ok(Ok(received)) = time::timeout(Duration::from_secs(3), rx).await else { + panic!("failed to receive"); + }; + + let fatal_alert_decode_error = b"\x15\x03\x03\x00\x02\x02\x32"; + assert_eq!(received, fatal_alert_decode_error) +} + #[tokio::test] async fn handshake_flush_pending() -> io::Result<()> { pass_impl(utils::FlushWrapper::new, false).await