Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: allow acceptor to send alerts after error #1811

Merged
merged 2 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions examples/src/bin/server_acceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,27 @@ fn main() {
// connection.
let accepted = loop {
acceptor.read_tls(&mut stream).unwrap();
if let Some(accepted) = acceptor.accept().unwrap() {
break accepted;

match acceptor.accept() {
Ok(Some(accepted)) => break accepted,
Ok(None) => continue,
Err((e, mut alert)) => {
alert.write(&mut stream).unwrap();
panic!("error accepting connection: {e}");
}
cpu marked this conversation as resolved.
Show resolved Hide resolved
}
};

// Generate a server config for the accepted connection, optionally customizing the
// configuration based on the client hello.
let config = test_pki.server_config(&crl_path, accepted.client_hello());
let mut conn = accepted
.into_connection(config)
.unwrap();
let mut conn = match accepted.into_connection(config) {
Ok(conn) => conn,
Err((e, mut alert)) => {
alert.write(&mut stream).unwrap();
panic!("error completing accepting connection: {e}");
}
};

// Proceed with handling the ServerConnection
// Important: We do no error handling here, but you should!
Expand Down
4 changes: 2 additions & 2 deletions provider-example/examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ fn main() {
conn.write_tls(&mut stream).unwrap();
conn.complete_io(&mut stream).unwrap();
}
Err(e) => {
eprintln!("{}", e);
Err((err, _)) => {
eprintln!("{err}");
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion rustls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ pub mod server {
Accepted, ServerConfig, ServerConnectionData, UnbufferedServerConnection,
};
#[cfg(feature = "std")]
pub use server_conn::{Acceptor, ReadEarlyData, ServerConnection};
pub use server_conn::{AcceptedAlert, Acceptor, ReadEarlyData, ServerConnection};
pub use server_conn::{ClientHello, ProducesTickets, ResolvesServerCert};

/// Dangerous configuration that should be audited and used with extreme care.
Expand Down
92 changes: 74 additions & 18 deletions rustls/src/server/server_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@
use crate::error::Error;
use crate::server::hs;
use crate::suites::ExtractedSecrets;
use crate::vecbuf::ChunkVecBuffer;

use alloc::boxed::Box;
use alloc::sync::Arc;
Expand Down Expand Up @@ -759,25 +760,36 @@
/// Returns `Ok(Some(accepted))` if the connection has been accepted. Call
/// `accepted.into_connection()` to continue. Do not call this function again.
///
/// Returns `Err(err)` if an error occurred. Do not call this function again.
pub fn accept(&mut self) -> Result<Option<Accepted>, Error> {
/// Returns `Err((err, alert))` if an error occurred. If an alert is returned, the
/// application should call `alert.write()` to send the alert to the client. It should
/// not call `accept()` again.
pub fn accept(&mut self) -> Result<Option<Accepted>, (Error, AcceptedAlert)> {
let mut connection = match self.inner.take() {
Some(conn) => conn,
None => {
return Err(Error::General("Acceptor polled after completion".into()));
return Err((
Error::General("Acceptor polled after completion".into()),
AcceptedAlert::empty(),
));
}
};

let message = match connection.first_handshake_message()? {
Some(msg) => msg,
None => {
let message = match connection.first_handshake_message() {
Ok(Some(msg)) => msg,
Ok(None) => {
self.inner = Some(connection);
return Ok(None);
}
Err(err) => return Err((err, AcceptedAlert::from(connection))),
};

let (_, sig_schemes) =
hs::process_client_hello(&message, false, &mut Context::from(&mut connection))?;
let mut cx = Context::from(&mut connection);
let sig_schemes = match hs::process_client_hello(&message, false, &mut cx) {
Ok((_, sig_schemes)) => sig_schemes,
Err(err) => {
return Err((err, AcceptedAlert::from(connection)));
}
};

Ok(Some(Accepted {
connection,
Expand All @@ -786,9 +798,39 @@
}))
}
}

/// Represents a TLS alert resulting from handling the client's `ClientHello` message.
///
/// When [`Acceptor::accept()`] returns an error, it yields an `AcceptedAlert` such that the
/// application can communicate failure to the client via [`AcceptedAlert::write()`].
pub struct AcceptedAlert(ChunkVecBuffer);

impl AcceptedAlert {
pub(super) fn empty() -> Self {
Self(ChunkVecBuffer::new(None))
}

/// Send the alert to the client.
pub fn write(&mut self, wr: &mut dyn io::Write) -> Result<usize, io::Error> {
self.0.write_to(wr)
}
}

impl From<ConnectionCommon<ServerConnectionData>> for AcceptedAlert {
fn from(conn: ConnectionCommon<ServerConnectionData>) -> Self {
Self(conn.core.common_state.sendable_tls)
}
}

impl Debug for AcceptedAlert {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("AcceptedAlert").finish()
}

Check warning on line 828 in rustls/src/server/server_conn.rs

View check run for this annotation

Codecov / codecov/patch

rustls/src/server/server_conn.rs#L826-L828

Added lines #L826 - L828 were not covered by tests
}
}

#[cfg(feature = "std")]
pub use connection::{Acceptor, ReadEarlyData, ServerConnection};
pub use connection::{AcceptedAlert, Acceptor, ReadEarlyData, ServerConnection};

/// Unbuffered version of `ServerConnection`
///
Expand Down Expand Up @@ -859,21 +901,29 @@
/// [`sign::CertifiedKey`] that should be used for the session. Returns an error if
/// configuration-dependent validation of the received `ClientHello` message fails.
#[cfg(feature = "std")]
pub fn into_connection(mut self, config: Arc<ServerConfig>) -> Result<ServerConnection, Error> {
self.connection
.set_max_fragment_size(config.max_fragment_size)?;
pub fn into_connection(
mut self,
config: Arc<ServerConfig>,
) -> Result<ServerConnection, (Error, AcceptedAlert)> {
if let Err(err) = self
.connection
.set_max_fragment_size(config.max_fragment_size)
{
// We have a connection here, but it won't contain an alert since the error
// is with the fragment size configured in the `ServerConfig`.
return Err((err, AcceptedAlert::empty()));

Check warning on line 914 in rustls/src/server/server_conn.rs

View check run for this annotation

Codecov / codecov/patch

rustls/src/server/server_conn.rs#L914

Added line #L914 was not covered by tests
}

self.connection.enable_secret_extraction = config.enable_secret_extraction;

let state = hs::ExpectClientHello::new(config, Vec::new());
let mut cx = hs::ServerContext::from(&mut self.connection);

let new = state.with_certified_key(
self.sig_schemes,
Self::client_hello_payload(&self.message),
&self.message,
&mut cx,
)?;
let ch = Self::client_hello_payload(&self.message);
let new = match state.with_certified_key(self.sig_schemes, ch, &self.message, &mut cx) {
Ok(new) => new,
Err(err) => return Err((err, AcceptedAlert::from(self.connection))),
};

self.connection.replace_state(new);
Ok(ServerConnection {
Expand All @@ -893,6 +943,12 @@
}
}

impl Debug for Accepted {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Accepted").finish()
}

Check warning on line 949 in rustls/src/server/server_conn.rs

View check run for this annotation

Codecov / codecov/patch

rustls/src/server/server_conn.rs#L947-L949

Added lines #L947 - L949 were not covered by tests
}

struct Accepting;

impl State<ServerConnectionData> for Accepting {
Expand Down
60 changes: 53 additions & 7 deletions rustls/tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ use rustls::internal::msgs::message::{
use rustls::server::{ClientHello, ParsedCertificate, ResolvesServerCert};
use rustls::SupportedCipherSuite;
use rustls::{
sign, AlertDescription, CertificateError, ConnectionCommon, ContentType, Error, KeyLog,
PeerIncompatible, PeerMisbehaved, SideData,
sign, AlertDescription, CertificateError, ConnectionCommon, ContentType, Error, InvalidMessage,
KeyLog, PeerIncompatible, PeerMisbehaved, SideData,
};
use rustls::{CipherSuite, ProtocolVersion, SignatureScheme};
use rustls::{ClientConfig, ClientConnection};
Expand Down Expand Up @@ -5412,8 +5412,8 @@ fn test_acceptor() {
io::ErrorKind::Other,
);
assert_eq!(
acceptor.accept().err(),
Some(Error::General("Acceptor polled after completion".into()))
acceptor.accept().err().unwrap().0,
Error::General("Acceptor polled after completion".into())
);

let mut acceptor = Acceptor::default();
Expand All @@ -5422,24 +5422,70 @@ fn test_acceptor() {
.read_tls(&mut &buf[..3])
.unwrap(); // incomplete message
assert!(acceptor.accept().unwrap().is_none());

acceptor
.read_tls(&mut [0x80, 0x00].as_ref())
.unwrap(); // invalid message (len = 32k bytes)
assert!(acceptor.accept().is_err());
let (err, mut alert) = acceptor.accept().unwrap_err();
assert_eq!(err, Error::InvalidMessage(InvalidMessage::MessageTooLarge));
let mut alert_content = Vec::new();
let _ = alert.write(&mut alert_content);
let expected = build_alert(AlertLevel::Fatal, AlertDescription::DecodeError, &[]);
assert_eq!(alert_content, expected);

let mut acceptor = Acceptor::default();
// Minimal valid 1-byte application data message is not a handshake message
acceptor
.read_tls(&mut [0x17, 0x03, 0x03, 0x00, 0x01, 0x00].as_ref())
.unwrap();
assert!(acceptor.accept().is_err());
let (err, mut alert) = acceptor.accept().unwrap_err();
assert!(matches!(err, Error::InappropriateMessage { .. }));
let mut alert_content = Vec::new();
let _ = alert.write(&mut alert_content);
assert!(alert_content.is_empty()); // We do not expect an alert for this condition.

let mut acceptor = Acceptor::default();
// Minimal 1-byte ClientHello message is not a legal handshake message
acceptor
.read_tls(&mut [0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00].as_ref())
.unwrap();
assert!(acceptor.accept().is_err());
let (err, mut alert) = acceptor.accept().unwrap_err();
assert!(matches!(err, Error::InvalidMessage(InvalidMessage::MissingData(_))));
let mut alert_content = Vec::new();
let _ = alert.write(&mut alert_content);
let expected = build_alert(AlertLevel::Fatal, AlertDescription::DecodeError, &[]);
assert_eq!(alert_content, expected);
}

#[test]
fn test_acceptor_rejected_handshake() {
use rustls::server::Acceptor;

let client_config = finish_client_config(KeyType::Ed25519, ClientConfig::builder_with_provider(provider::default_provider().into())
.with_protocol_versions(&[&rustls::version::TLS13])
.unwrap());
let mut client = ClientConnection::new(client_config.into(), server_name("localhost")).unwrap();
let mut buf = Vec::new();
client.write_tls(&mut buf).unwrap();

let server_config = finish_server_config(KeyType::Ed25519, ServerConfig::builder_with_provider(provider::default_provider().into())
.with_protocol_versions(&[&rustls::version::TLS12])
.unwrap());
let mut acceptor = Acceptor::default();
acceptor
.read_tls(&mut buf.as_slice())
.unwrap();
let accepted = acceptor.accept().unwrap().unwrap();
let ch = accepted.client_hello();
assert_eq!(ch.server_name(), Some("localhost"));

let (err, mut alert) = accepted.into_connection(server_config.into()).unwrap_err();
assert_eq!(err, Error::PeerIncompatible(PeerIncompatible::Tls12NotOfferedOrEnabled));

let mut alert_content = Vec::new();
let _ = alert.write(&mut alert_content);
let expected = build_alert(AlertLevel::Fatal, AlertDescription::ProtocolVersion, &[]);
assert_eq!(alert_content, expected);
}

#[test]
Expand Down
Loading