Skip to content

Commit

Permalink
net: Retries based on expect values are pointless
Browse files Browse the repository at this point in the history
  • Loading branch information
ohsayan committed Apr 7, 2024
1 parent 337ea9e commit b961e84
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 130 deletions.
24 changes: 4 additions & 20 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ libc = "0.2.153"
# external deps
rand = "0.8.5"
tokio = { version = "1.37.0", features = ["test-util"] }
skytable = { git = "https://github.com/skytable/client-rust.git", branch = "feature/pipeline-batch" }
skytable = { git = "https://github.com/skytable/client-rust.git" }

[features]
nightly = []
Expand Down
29 changes: 9 additions & 20 deletions server/src/engine/net/protocol/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,7 @@ pub enum HandshakeResult<'a> {
/// Finished handshake
Completed(CHandshake<'a>),
/// Update handshake state
///
/// **NOTE:** expect does not take into account the current amount of buffered data (hence the unbuffered part must be computed!)
ChangeState {
new_state: HandshakeState,
expect: usize,
},
ChangeState { new_state: HandshakeState },
/// An error occurred
Error(ProtocolError),
}
Expand Down Expand Up @@ -271,17 +266,15 @@ impl<'a> CHandshake<'a> {
/// Resume from the initial state (nothing buffered yet)
fn resume_initial(scanner: &mut BufferedScanner<'a>) -> HandshakeResult<'a> {
// get our block
if cfg!(debug_assertions) {
if scanner.remaining() < Self::INITIAL_READ {
return HandshakeResult::ChangeState {
new_state: HandshakeState::Initial,
expect: Self::INITIAL_READ,
};
}
} else {
assert!(scanner.remaining() >= Self::INITIAL_READ);
if scanner.remaining() < Self::INITIAL_READ {
return HandshakeResult::ChangeState {
new_state: HandshakeState::Initial,
};
}
let buf: [u8; CHandshake::INITIAL_READ] = unsafe { scanner.next_chunk() };
let buf: [u8; CHandshake::INITIAL_READ] = unsafe {
// UNSAFE(@ohsayan): validated in earlier branch
scanner.next_chunk()
};
let invalid_first_byte = buf[0] != Self::CLIENT_HELLO;
let invalid_hs_version = buf[1] > HandshakeVersion::MAX_DSCR;
let invalid_proto_version = buf[2] > ProtocolVersion::MAX_DSCR;
Expand Down Expand Up @@ -350,7 +343,6 @@ impl<'a> CHandshake<'a> {
uname_l,
pwd_l,
},
expect: (uname_l + pwd_l),
}
}
}
Expand All @@ -366,7 +358,6 @@ impl<'a> CHandshake<'a> {
// we need more data
return HandshakeResult::ChangeState {
new_state: HandshakeState::StaticBlock(static_header),
expect: static_header.auth_mode.min_payload_bytes(),
};
}
// we seem to have enough data for this auth mode
Expand All @@ -379,7 +370,6 @@ impl<'a> CHandshake<'a> {
ScannerDecodeResult::NeedMore => {
return HandshakeResult::ChangeState {
new_state: HandshakeState::StaticBlock(static_header),
expect: AuthMode::Password.min_payload_bytes(), // 2 for uname_l and 2 for pwd_l
};
}
ScannerDecodeResult::Value(v) => v as usize,
Expand All @@ -402,7 +392,6 @@ impl<'a> CHandshake<'a> {
// newline missing (or maybe there's more?)
return HandshakeResult::ChangeState {
new_state: HandshakeState::ExpectingMetaForVariableBlock { static_hs, uname_l },
expect: uname_l + 2, // space for username + password len
};
}
ScannerDecodeResult::Error => {
Expand Down
177 changes: 98 additions & 79 deletions server/src/engine/net/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ use {
self::{
exchange::{Exchange, ExchangeResult, ExchangeState, Pipeline},
handshake::{
AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState,
HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode,
AuthMode, DataExchangeMode, HandshakeResult, HandshakeState, HandshakeVersion,
ProtocolError, ProtocolVersion, QueryMode,
},
},
super::{IoResult, QueryLoopResult, Socket},
Expand Down Expand Up @@ -107,14 +107,43 @@ impl ClientLocalState {
}
}

/*
read loop
*/

macro_rules! read_loop {
($con:expr, $buf:expr, $conn_closed:expr, $conn_reset:expr, $body:block) => {
loop {
let read_many = $con.read_buf($buf).await?;
if read_many == 0 {
if $buf.is_empty() {
return $conn_closed;
} else {
return $conn_reset;
}
}
$body
}
};
}

/*
handshake
*/

#[inline(always)]
async fn write_handshake_error<S: Socket>(
con: &mut BufWriter<S>,
e: ProtocolError,
) -> IoResult<()> {
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
con.write_all(&hs_err_packet).await
}

#[derive(Debug, PartialEq)]
enum PostHandshake {
enum HandshakeCompleteResult {
Okay(ClientLocalState),
Error(ProtocolError),
Error,
ConnectionClosedFin,
ConnectionClosedRst,
}
Expand All @@ -123,40 +152,34 @@ async fn do_handshake<S: Socket>(
con: &mut BufWriter<S>,
buf: &mut BytesMut,
global: &Global,
) -> IoResult<PostHandshake> {
let mut expected = CHandshake::INITIAL_READ;
) -> IoResult<HandshakeCompleteResult> {
let mut state = HandshakeState::default();
let mut cursor = 0;
let handshake;
loop {
let read_many = con.read_buf(buf).await?;
if read_many == 0 {
if buf.is_empty() {
return Ok(PostHandshake::ConnectionClosedFin);
} else {
return Ok(PostHandshake::ConnectionClosedRst);
}
}
if buf.len() < expected {
continue;
}
let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) };
match handshake::CHandshake::resume_with(&mut scanner, state) {
HandshakeResult::Completed(hs) => {
handshake = hs;
cursor = scanner.cursor();
break;
}
HandshakeResult::ChangeState { new_state, expect } => {
expected = expect;
state = new_state;
cursor = scanner.cursor();
}
HandshakeResult::Error(e) => {
return Ok(PostHandshake::Error(e));
read_loop!(
con,
buf,
Ok(HandshakeCompleteResult::ConnectionClosedFin),
Ok(HandshakeCompleteResult::ConnectionClosedRst),
{
let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) };
match handshake::CHandshake::resume_with(&mut scanner, state) {
HandshakeResult::Completed(hs) => {
handshake = hs;
cursor = scanner.cursor();
break;
}
HandshakeResult::ChangeState { new_state } => {
state = new_state;
cursor = scanner.cursor();
}
HandshakeResult::Error(e) => {
write_handshake_error(con, e).await?;
return Ok(HandshakeCompleteResult::Error);
}
}
}
}
);
// check handshake
if cfg!(debug_assertions) {
assert_eq!(
Expand All @@ -181,7 +204,7 @@ async fn do_handshake<S: Socket>(
{
okay @ (VerifyUser::Okay | VerifyUser::OkayRoot) => {
let hs = handshake.hs_static();
let ret = Ok(PostHandshake::Okay(ClientLocalState::new(
let ret = Ok(HandshakeCompleteResult::Okay(ClientLocalState::new(
uname.into(),
okay.is_root(),
hs,
Expand All @@ -194,7 +217,8 @@ async fn do_handshake<S: Socket>(
}
Err(_) => {}
};
Ok(PostHandshake::Error(ProtocolError::RejectAuth))
write_handshake_error(con, ProtocolError::RejectAuth).await?;
Ok(HandshakeCompleteResult::Error)
}

/*
Expand All @@ -217,60 +241,55 @@ pub(super) async fn query_loop<S: Socket>(
) -> IoResult<QueryLoopResult> {
// handshake
let mut client_state = match do_handshake(con, buf, global).await? {
PostHandshake::Okay(hs) => hs,
PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
PostHandshake::Error(e) => {
// failed to handshake; we'll close the connection
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
con.write_all(&hs_err_packet).await?;
return Ok(QueryLoopResult::HSFailed);
}
HandshakeCompleteResult::Okay(hs) => hs,
HandshakeCompleteResult::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
HandshakeCompleteResult::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
HandshakeCompleteResult::Error => return Ok(QueryLoopResult::HSFailed),
};
// done handshaking
con.write_all(b"H\x00\x00\x00").await?;
con.flush().await?;
let mut state = ExchangeState::default();
let mut cursor = 0;
loop {
if con.read_buf(buf).await? == 0 {
if buf.is_empty() {
return Ok(QueryLoopResult::Fin);
} else {
return Ok(QueryLoopResult::Rst);
}
}
match Exchange::try_complete(
unsafe {
// UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl
BufferedScanner::new_with_cursor(&buf, cursor)
},
state,
) {
Ok((result, new_cursor)) => match result {
ExchangeResult::NewState(new_state) => {
state = new_state;
cursor = new_cursor;
}
ExchangeResult::Simple(query) => {
exec_simple(con, &mut client_state, global, query).await?;
read_loop!(
con,
buf,
Ok(QueryLoopResult::Fin),
Ok(QueryLoopResult::Rst),
{
match Exchange::try_complete(
unsafe {
// UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl
BufferedScanner::new_with_cursor(&buf, cursor)
},
state,
) {
Ok((result, new_cursor)) => match result {
ExchangeResult::NewState(new_state) => {
state = new_state;
cursor = new_cursor;
}
ExchangeResult::Simple(query) => {
exec_simple(con, &mut client_state, global, query).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
ExchangeResult::Pipeline(pipe) => {
exec_pipe(con, &mut client_state, global, pipe).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
},
Err(_) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8()
as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
ExchangeResult::Pipeline(pipe) => {
exec_pipe(con, &mut client_state, global, pipe).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
},
Err(_) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
}
}
);
}

/*
Expand Down
Loading

0 comments on commit b961e84

Please sign in to comment.