diff --git a/CHANGELOG.md b/CHANGELOG.md index f135f7b9..ef60745d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ All changes in this project will be noted in this file. - Skyhash/2: Restored support for pipelines - Enable online (runtime) recovery of transactional failures due to disk errors +### Fixes + +- Fixed an issue where an incorrect handshake with multiple errors cause the client connection + to be terminated without yielding an error + ## Version 0.8.1 ### Additions diff --git a/server/src/engine/net/protocol/handshake.rs b/server/src/engine/net/protocol/handshake.rs index 7dd6ec22..8c4eddf5 100644 --- a/server/src/engine/net/protocol/handshake.rs +++ b/server/src/engine/net/protocol/handshake.rs @@ -50,6 +50,34 @@ pub enum ProtocolError { RejectAuth = 5, } +impl ProtocolError { + #[cold] + unsafe fn get_error( + invalid_first_byte: bool, + invalid_hs_version: bool, + invalid_proto_version: bool, + invalid_exchange_mode: bool, + invalid_query_mode: bool, + invalid_auth_mode: bool, + ) -> ProtocolError { + if invalid_first_byte { + ProtocolError::CorruptedHSPacket + } else if invalid_hs_version { + ProtocolError::RejectHSVersion + } else if invalid_proto_version { + ProtocolError::RejectProtocol + } else if invalid_exchange_mode { + ProtocolError::RejectExchangeMode + } else if invalid_query_mode { + ProtocolError::RejectQueryMode + } else if invalid_auth_mode { + ProtocolError::RejectAuth + } else { + impossible!() + } + } +} + /* handshake meta */ @@ -290,23 +318,17 @@ impl<'a> CHandshake<'a> { | invalid_query_mode | invalid_auth_mode, ) { - static ERROR: [ProtocolError; 6] = [ - ProtocolError::CorruptedHSPacket, - ProtocolError::RejectHSVersion, - ProtocolError::RejectProtocol, - ProtocolError::RejectExchangeMode, - ProtocolError::RejectQueryMode, - ProtocolError::RejectAuth, - ]; - return HandshakeResult::Error( - ERROR[((invalid_first_byte as u8 * 1) - | (invalid_hs_version as u8 * 2) - | (invalid_proto_version as u8 * 3) - | (invalid_exchange_mode as u8 * 4) - | (invalid_query_mode as u8 * 5) - | (invalid_auth_mode as u8) * 6) as usize - - 1usize], - ); + return HandshakeResult::Error(unsafe { + // UNSAFE(@ohsayan): it is guaranteed by the branch that one or more of these booleans are true + ProtocolError::get_error( + invalid_first_byte, + invalid_hs_version, + invalid_proto_version, + invalid_exchange_mode, + invalid_query_mode, + invalid_auth_mode, + ) + }); } // init header let static_header = CHandshakeStatic::new( diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index ed3b5448..66b7bf4a 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -61,6 +61,42 @@ const STATIC_HANDSHAKE_WITH_AUTH: CHandshakeStatic = CHandshakeStatic::new( handshake with no state changes */ +#[test] +fn handshake_with_multiple_errors() { + for (bad_hs, error) in [ + // all incorrect + ( + b"H\xFF\xFF\xFF\xFF\xFF5\n8\nsayanpass1234", + ProtocolError::RejectHSVersion, + ), + // protocol and continuing bytes + ( + b"H\x00\xFF\xFF\xFF\xFF5\n8\nsayanpass1234", + ProtocolError::RejectProtocol, + ), + // xchg and continuing bytes + ( + b"H\x00\x00\xFF\xFF\xFF5\n8\nsayanpass1234", + ProtocolError::RejectExchangeMode, + ), + // qmode and continuing bytes + ( + b"H\x00\x00\x00\xFF\xFF5\n8\nsayanpass1234", + ProtocolError::RejectQueryMode, + ), + // auth + ( + b"H\x00\x00\x00\x00\xFF5\n8\nsayanpass1234", + ProtocolError::RejectAuth, + ), + ] { + assert_eq!( + CHandshake::resume_with(&mut BufferedScanner::new(bad_hs), HandshakeState::Initial), + HandshakeResult::Error(error) + ); + } +} + #[test] fn parse_staged_with_auth() { for i in 0..FULL_HANDSHAKE_WITH_AUTH.len() {