feat: rpc response message chunking (#3336)
Description
---
Adds server-side "chunking" protocol for large RPC responses.

Given a payload exceeding some threshold, comms RPC will split the message payload into a number (maximum 16) of
RPC messages and stream them to the client. The client recombines the chunk payloads before emitting the full
response to the caller.
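
The client-side recombination can be sketched as follows. This is a minimal illustration, not the code in this commit: `Frame` and `reassemble` are hypothetical names, and the `more` field stands in for the "more chunks follow" flag that the client checks via `is_more()`. Errors are plain strings here for brevity.

```rust
// Maximum number of chunks a single response may span (from the description).
const MAX_CHUNKS: usize = 16;

/// Hypothetical stand-in for a decoded RPC response frame.
struct Frame {
    /// True when further chunks follow this one (assumed flag).
    more: bool,
    payload: Vec<u8>,
}

/// Concatenates chunk payloads until a frame without the `more` flag arrives.
fn reassemble<I>(frames: I) -> Result<Vec<u8>, String>
where
    I: IntoIterator<Item = Frame>,
{
    let mut payload = Vec::new();
    let mut count = 0usize;
    for frame in frames {
        count += 1;
        payload.extend(frame.payload);
        if !frame.more {
            // Final chunk: the payload is complete.
            return Ok(payload);
        }
        if count >= MAX_CHUNKS {
            // The sender announced more chunks than the protocol allows.
            return Err(format!("exceeded maximum of {} chunks", MAX_CHUNKS));
        }
    }
    Err("stream ended before the final chunk".to_string())
}
```

Bounding the chunk count keeps a misbehaving peer from making the client buffer an unbounded response.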

BREAKING CHANGE: Backward compatible on the client side only: an upgraded node may request from an older node and process
its responses. An older node cannot process chunks, so a response that contains any will be invalid to the older node. Otherwise, if the payload is below the chunking threshold, the response is exactly as before.

Motivation and Context
---
Large streamed message payloads (such as those seen in block sync) are often ~1.2 MiB over the wire.
The [bandwidth-delay product](https://www.wikiwand.com/en/Bandwidth-delay_product) is often low or very low for connections over Tor, so large message frames can take many seconds to be completely received.

This PR breaks message payloads into smaller chunks, so that less time passes between receiving full frames on the client side.

The algorithm for chunking is as follows:
```
THRESHOLD = 256 KiB
SIZE_LIMIT = 384 KiB
MAX_CHUNKS = 16
```
a. `payload <= THRESHOLD`: emit a single frame
b. `payload > THRESHOLD && payload <= SIZE_LIMIT`: emit a single frame containing the full payload
c. `payload > SIZE_LIMIT`: emit one or more frames, each containing the next `THRESHOLD` bytes, with the final frame as per (a) or (b).

This prevents an unnecessary, very small trailing chunk when a payload is only a handful of bytes over the threshold,
e.g. 256 KiB (`THRESHOLD`) + 1 byte results in _one_ chunk, not _two_. 384 KiB (`SIZE_LIMIT`) + 1 KiB results
in 2 frames of 256 KiB and 129 KiB respectively. This means each chunk has a maximum size of `SIZE_LIMIT`, and the
trailing/last chunk of a chunked response has a minimum size of `SIZE_LIMIT - THRESHOLD` (128 KiB).
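
The rules above can be sketched as a small splitting function. This is an illustration under stated assumptions rather than the server code from this commit: `chunk_payload` is a hypothetical name, the constants assume binary units (KiB), and returning an error when more than `MAX_CHUNKS` frames would be needed is an assumed behaviour.

```rust
// Constants from the description, assuming binary units (KiB).
const THRESHOLD: usize = 256 * 1024;
const SIZE_LIMIT: usize = 384 * 1024;
const MAX_CHUNKS: usize = 16;

/// Splits `payload` into chunk slices following rules (a) to (c).
/// Erroring on more than MAX_CHUNKS chunks is an assumption of this sketch.
fn chunk_payload(payload: &[u8]) -> Result<Vec<&[u8]>, String> {
    let mut chunks = Vec::new();
    let mut rest = payload;
    // Rule (c): while more than SIZE_LIMIT remains, emit the next THRESHOLD bytes.
    while rest.len() > SIZE_LIMIT {
        let (head, tail) = rest.split_at(THRESHOLD);
        chunks.push(head);
        rest = tail;
    }
    // Rules (a) and (b): whatever remains fits in a single final frame.
    chunks.push(rest);
    if chunks.len() > MAX_CHUNKS {
        return Err(format!(
            "payload needs {} chunks, maximum is {}",
            chunks.len(),
            MAX_CHUNKS
        ));
    }
    Ok(chunks)
}
```

Under these assumptions, a payload of `THRESHOLD + 1` bytes yields one chunk, and a payload of `SIZE_LIMIT` + 1 KiB yields chunks of 256 KiB and 129 KiB, matching the examples in the description.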

How Has This Been Tested?
---

Some new unit tests, existing large message tests, manual archival sync using force sync
sdbondi committed Sep 13, 2021
1 parent 8512939 commit 496ff14
Showing 17 changed files with 577 additions and 233 deletions.
@@ -1003,11 +1003,7 @@ impl tari_rpc::base_node_server::BaseNode for BaseNodeGrpcServer {
.state_info
.get_block_sync_info()
.map(|info| {
let node_ids = info
.sync_peers
.iter()
.map(|x| x.to_string().as_bytes().to_vec())
.collect();
let node_ids = info.sync_peers.iter().map(|x| x.to_string().into_bytes()).collect();
tari_rpc::SyncInfoResponse {
tip_height: info.tip_height,
local_height: info.local_height,
1 change: 0 additions & 1 deletion applications/tari_base_node/src/main.rs
@@ -1,4 +1,3 @@
#![recursion_limit = "1024"]
// Copyright 2019. The Tari Project
//
// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
@@ -98,9 +98,9 @@ impl wallet_server::Wallet for WalletGrpcServer {
async fn identify(&self, _: Request<GetIdentityRequest>) -> Result<Response<GetIdentityResponse>, Status> {
let identity = self.wallet.comms.node_identity();
Ok(Response::new(GetIdentityResponse {
public_key: identity.public_key().to_string().as_bytes().to_vec(),
public_key: identity.public_key().to_string().into_bytes(),
public_address: identity.public_address().to_string(),
node_id: identity.node_id().to_string().as_bytes().to_vec(),
node_id: identity.node_id().to_string().into_bytes(),
}))
}

@@ -1546,8 +1546,7 @@ struct KeyManagerStateUpdateSql {
impl Encryptable<Aes256Gcm> for KeyManagerStateSql {
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), Error> {
let encrypted_master_key = encrypt_bytes_integral_nonce(&cipher, self.master_key.clone())?;
let encrypted_branch_seed =
encrypt_bytes_integral_nonce(&cipher, self.branch_seed.clone().as_bytes().to_vec())?;
let encrypted_branch_seed = encrypt_bytes_integral_nonce(&cipher, self.branch_seed.clone().into_bytes())?;
self.master_key = encrypted_master_key;
self.branch_seed = encrypted_branch_seed.to_hex();
Ok(())
2 changes: 1 addition & 1 deletion base_layer/wallet/src/storage/sqlite_db.rs
@@ -588,7 +588,7 @@ impl ClientKeyValueSql {
impl Encryptable<Aes256Gcm> for ClientKeyValueSql {
#[allow(unused_assignments)]
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> {
let encrypted_value = encrypt_bytes_integral_nonce(&cipher, self.clone().value.as_bytes().to_vec())?;
let encrypted_value = encrypt_bytes_integral_nonce(&cipher, self.value.as_bytes().to_vec())?;
self.value = encrypted_value.to_hex();
Ok(())
}
@@ -1028,8 +1028,7 @@ impl InboundTransactionSql {

impl Encryptable<Aes256Gcm> for InboundTransactionSql {
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> {
let encrypted_protocol =
encrypt_bytes_integral_nonce(&cipher, self.receiver_protocol.clone().as_bytes().to_vec())?;
let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.receiver_protocol.as_bytes().to_vec())?;
self.receiver_protocol = encrypted_protocol.to_hex();
Ok(())
}
@@ -1211,8 +1210,7 @@ impl OutboundTransactionSql {

impl Encryptable<Aes256Gcm> for OutboundTransactionSql {
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> {
let encrypted_protocol =
encrypt_bytes_integral_nonce(&cipher, self.sender_protocol.clone().as_bytes().to_vec())?;
let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.sender_protocol.as_bytes().to_vec())?;
self.sender_protocol = encrypted_protocol.to_hex();
Ok(())
}
@@ -1534,8 +1532,7 @@ impl CompletedTransactionSql {

impl Encryptable<Aes256Gcm> for CompletedTransactionSql {
fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> {
let encrypted_protocol =
encrypt_bytes_integral_nonce(&cipher, self.transaction_protocol.clone().as_bytes().to_vec())?;
let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.transaction_protocol.as_bytes().to_vec())?;
self.transaction_protocol = encrypted_protocol.to_hex();
Ok(())
}
2 changes: 1 addition & 1 deletion comms/dht/src/store_forward/saf_handler/task.rs
@@ -599,7 +599,7 @@ mod test {
dht_header: DhtMessageHeader,
stored_at: NaiveDateTime,
) -> StoredMessage {
let body = message.as_bytes().to_vec();
let body = message.into_bytes();
let body_hash = hex::to_hex(&Challenge::new().chain(body.clone()).finalize());
StoredMessage {
id: 1,
4 changes: 2 additions & 2 deletions comms/src/proto/rpc.proto
@@ -16,7 +16,7 @@ message RpcRequest {
uint64 deadline = 4;

// The message payload
bytes message = 10;
bytes payload = 10;
}

// Message type for all RPC responses
@@ -29,7 +29,7 @@ message RpcResponse {
uint32 flags = 3;

// The message payload. If the status is non-zero, this contains additional error details.
bytes message = 10;
bytes payload = 10;
}

// Message sent by the client when negotiating an RPC session. A server may close the substream if it does
11 changes: 7 additions & 4 deletions comms/src/protocol/rpc/body.rs
@@ -177,6 +177,10 @@ impl BodyBytes {
pub fn into_vec(self) -> Vec<u8> {
self.0.map(|bytes| bytes.to_vec()).unwrap_or_else(Vec::new)
}

pub fn into_bytes(self) -> Option<Bytes> {
self.0
}
}

#[allow(clippy::from_over_into)]
@@ -186,10 +190,9 @@ impl Into<Bytes> for BodyBytes {
}
}

#[allow(clippy::from_over_into)]
impl Into<Vec<u8>> for BodyBytes {
fn into(self) -> Vec<u8> {
self.into_vec()
impl From<BodyBytes> for Vec<u8> {
fn from(body: BodyBytes) -> Self {
body.into_vec()
}
}

160 changes: 111 additions & 49 deletions comms/src/protocol/rpc/client.rs
@@ -34,6 +34,7 @@ use crate::{
Response,
RpcError,
RpcStatus,
RPC_CHUNKING_MAX_CHUNKS,
},
ProtocolId,
},
@@ -239,7 +240,7 @@ where TClient: From<RpcClient> + NamedProtocolService
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub struct RpcClientConfig {
pub deadline: Option<Duration>,
pub deadline_grace_period: Duration,
@@ -489,7 +490,8 @@ impl RpcClientWorker {
self.protocol_name(),
start.elapsed()
);
let resp = match self.read_reply().await {
let mut reader = RpcResponseReader::new(&mut self.framed, self.config, 0);
let resp = match reader.read_ack().await {
Ok(resp) => resp,
Err(RpcError::ReplyTimeout) => {
debug!(
@@ -529,7 +531,7 @@ impl RpcClientWorker {
Ok(())
}

#[tracing::instrument(name = "rpc_do_request_response", skip(self, reply))]
#[tracing::instrument(name = "rpc_do_request_response", skip(self, reply, request), fields(request_method = ?request.method, request_size = request.message.len()))]
async fn do_request_response(
&mut self,
request: BaseRequest<Bytes>,
@@ -542,7 +544,7 @@ impl RpcClientWorker {
method,
deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0),
flags: 0,
message: request.message.to_vec(),
payload: request.message.to_vec(),
};

debug!(target: LOG_TARGET, "Sending request: {}", req);
@@ -575,14 +577,14 @@ impl RpcClientWorker {
}

loop {
let resp = match self.read_reply().await {
let resp = match self.read_response(request_id).await {
Ok(resp) => {
let latency = start.elapsed();
event!(Level::TRACE, "Message received");
trace!(
target: LOG_TARGET,
"Received response ({} byte(s)) from request #{} (protocol = {}, method={}) in {:.0?}",
resp.message.len(),
resp.payload.len(),
request_id,
self.protocol_name(),
method,
@@ -617,12 +619,19 @@ impl RpcClientWorker {
break;
},
Err(err) => {
event!(Level::ERROR, "Errored:{}", err);
event!(
Level::WARN,
"Request {} (method={}) returned an error after {:.0?}: {}",
request_id,
method,
start.elapsed(),
err
);
return Err(err);
},
};

match Self::convert_to_result(resp, request_id) {
match Self::convert_to_result(resp) {
Ok(Ok(resp)) => {
// The consumer may drop the receiver before all responses are received.
// We just ignore that as we still want obey the protocol and receive messages until the FIN flag or
@@ -665,27 +674,10 @@ impl RpcClientWorker {
Ok(())
}

async fn read_reply(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
// Wait until the timeout, allowing an extra grace period to account for latency
let next_msg_fut = match self.config.timeout_with_grace_period() {
Some(timeout) => Either::Left(time::timeout(timeout, self.framed.next())),
None => Either::Right(self.framed.next().map(Ok)),
};

let result = tokio::select! {
biased;
_ = &mut self.shutdown_signal => {
return Err(RpcError::ClientClosed);
}
result = next_msg_fut => result,
};

match result {
Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?),
Ok(Some(Err(err))) => Err(err.into()),
Ok(None) => Err(RpcError::ServerClosedRequest),
Err(_) => Err(RpcError::ReplyTimeout),
}
async fn read_response(&mut self, request_id: u16) -> Result<proto::rpc::RpcResponse, RpcError> {
let mut reader = RpcResponseReader::new(&mut self.framed, self.config, request_id);
let resp = reader.read_response().await?;
Ok(resp)
}

fn next_request_id(&mut self) -> u16 {
@@ -695,33 +687,15 @@ impl RpcClientWorker {
next_id
}

fn convert_to_result(
resp: proto::rpc::RpcResponse,
request_id: u16,
) -> Result<Result<Response<Bytes>, RpcStatus>, RpcError> {
let resp_id = u16::try_from(resp.request_id)
.map_err(|_| RpcStatus::protocol_error(format!("invalid request_id: must be less than {}", u16::MAX)))?;

let flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8);
if flags.contains(RpcMessageFlags::ACK) {
return Err(RpcError::UnexpectedAckResponse);
}

if resp_id != request_id {
return Err(RpcError::ResponseIdDidNotMatchRequest {
expected: request_id,
actual: resp.request_id as u16,
});
}

fn convert_to_result(resp: proto::rpc::RpcResponse) -> Result<Result<Response<Bytes>, RpcStatus>, RpcError> {
let status = RpcStatus::from(&resp);
if !status.is_ok() {
return Ok(Err(status));
}

let resp = Response {
flags: resp.flags(),
message: resp.message.into(),
payload: resp.payload.into(),
};

Ok(Ok(resp))
@@ -736,3 +710,91 @@ pub enum ClientRequest {
GetLastRequestLatency(oneshot::Sender<Option<Duration>>),
SendPing(oneshot::Sender<Result<Duration, RpcStatus>>),
}

struct RpcResponseReader<'a> {
framed: &'a mut CanonicalFraming<Substream>,
config: RpcClientConfig,
request_id: u16,
}
impl<'a> RpcResponseReader<'a> {
pub fn new(framed: &'a mut CanonicalFraming<Substream>, config: RpcClientConfig, request_id: u16) -> Self {
Self {
framed,
config,
request_id,
}
}

pub async fn read_response(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
let mut resp = self.next().await?;
self.check_response(&resp)?;
let mut chunk_count = 1;
let mut last_chunk_flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8);
let mut last_chunk_size = resp.payload.len();
loop {
trace!(
target: LOG_TARGET,
"Chunk {} received (flags={:?}, {} bytes, {} total)",
chunk_count,
last_chunk_flags,
last_chunk_size,
resp.payload.len()
);
if !last_chunk_flags.is_more() {
return Ok(resp);
}

if chunk_count >= RPC_CHUNKING_MAX_CHUNKS {
return Err(RpcError::ExceededMaxChunkCount {
expected: RPC_CHUNKING_MAX_CHUNKS,
});
}

let msg = self.next().await?;
last_chunk_flags = RpcMessageFlags::from_bits_truncate(msg.flags as u8);
last_chunk_size = msg.payload.len();
self.check_response(&msg)?;
resp.payload.extend(msg.payload);
chunk_count += 1;
}
}

pub async fn read_ack(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
let resp = self.next().await?;
Ok(resp)
}

fn check_response(&self, resp: &proto::rpc::RpcResponse) -> Result<(), RpcError> {
let resp_id = u16::try_from(resp.request_id)
.map_err(|_| RpcStatus::protocol_error(format!("invalid request_id: must be less than {}", u16::MAX)))?;

let flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8);
if flags.contains(RpcMessageFlags::ACK) {
return Err(RpcError::UnexpectedAckResponse);
}

if resp_id != self.request_id {
return Err(RpcError::ResponseIdDidNotMatchRequest {
expected: self.request_id,
actual: resp.request_id as u16,
});
}

Ok(())
}

async fn next(&mut self) -> Result<proto::rpc::RpcResponse, RpcError> {
// Wait until the timeout, allowing an extra grace period to account for latency
let next_msg_fut = match self.config.timeout_with_grace_period() {
Some(timeout) => Either::Left(time::timeout(timeout, self.framed.next())),
None => Either::Right(self.framed.next().map(Ok)),
};

match next_msg_fut.await {
Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?),
Ok(Some(Err(err))) => Err(err.into()),
Ok(None) => Err(RpcError::ServerClosedRequest),
Err(_) => Err(RpcError::ReplyTimeout),
}
}
}
2 changes: 2 additions & 0 deletions comms/src/protocol/rpc/error.rs
@@ -65,6 +65,8 @@ pub enum RpcError {
InvalidPingResponse,
#[error("Unexpected ACK response. This is likely because of a previous ACK timeout")]
UnexpectedAckResponse,
#[error("Attempted to send more than {expected} payload chunks")]
ExceededMaxChunkCount { expected: usize },
#[error(transparent)]
UnknownError(#[from] anyhow::Error),
}
