Skip to content

Commit

Permalink
fix: change truncate_from_bits to from_bits (#5773)
Browse files Browse the repository at this point in the history
Description
---
This removes all occurrences of `truncate_from_bits` to `from_bits`

Motivation and Context
---
`truncate_from_bits` will truncate all unknown bits and may cause bits
to be interpreted as the wrong flag. This changes to the much more
strict `from_bits` which forces peers to use the correct bits and only
the correct bits.
  • Loading branch information
SWvheerden committed Sep 15, 2023
1 parent a9b730a commit fb18078
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 25 deletions.
2 changes: 2 additions & 0 deletions comms/core/src/peer_manager/error.rs
Expand Up @@ -53,6 +53,8 @@ pub enum PeerManagerError {
InvalidPeerFeatures { bits: u32 },
#[error("Address {address} not found for peer {node_id}")]
AddressNotFoundError { address: Multiaddr, node_id: NodeId },
#[error("Protocol error: {0}")]
ProtocolError(String),
}

impl PeerManagerError {
Expand Down
5 changes: 4 additions & 1 deletion comms/core/src/peer_manager/peer_identity_claim.rs
Expand Up @@ -66,7 +66,10 @@ impl TryFrom<PeerIdentityMsg> for PeerIdentityClaim {
if addresses.is_empty() {
return Err(PeerManagerError::PeerIdentityNoValidAddresses);
}
let features = PeerFeatures::from_bits_truncate(value.features);
let features = PeerFeatures::from_bits(value.features).ok_or(PeerManagerError::ProtocolError(format!(
"Invalid message flag, does not match any flags ({})",
value.features
)))?;

if let Some(signature) = value.identity_signature {
Ok(Self {
Expand Down
4 changes: 3 additions & 1 deletion comms/core/src/protocol/error.rs
Expand Up @@ -29,6 +29,8 @@ use thiserror::Error;
pub enum ProtocolError {
#[error("IO error: {0}")]
IoError(#[from] io::Error),
#[error("Invalid flag: {0}")]
InvalidFlag(String),
#[error("The ProtocolId was longer than {}", u8::max_value())]
ProtocolIdTooLong,
#[error("Protocol negotiation failed because the peer did not accept any of the given protocols: {protocols}")]
Expand Down Expand Up @@ -56,7 +58,7 @@ impl ProtocolError {
ProtocolError::ProtocolOptimisticNegotiationFailed |
ProtocolError::NotificationSenderDisconnected => false,

ProtocolError::ProtocolIdTooLong => true,
ProtocolError::ProtocolIdTooLong | ProtocolError::InvalidFlag(_) => true,
}
}
}
5 changes: 4 additions & 1 deletion comms/core/src/protocol/negotiation.rs
Expand Up @@ -189,7 +189,10 @@ where TSocket: AsyncRead + AsyncWrite + Unpin
// Len can never overflow the buffer because the buffer len is u8::MAX and the length delimiter
// is a u8. If that changes, then len should be checked for overflow
let len = u8::from_be_bytes([self.buf[0]]) as usize;
let flags = Flags::from_bits_truncate(u8::from_be_bytes([self.buf[1]]));
let flags = Flags::from_bits(u8::from_be_bytes([self.buf[1]])).ok_or(ProtocolError::InvalidFlag(format!(
"Does not match any flags ({})",
self.buf[1]
)))?;
self.socket.read_exact(&mut self.buf[0..len]).await?;
trace!(
target: LOG_TARGET,
Expand Down
40 changes: 30 additions & 10 deletions comms/core/src/protocol/rpc/client/mod.rs
Expand Up @@ -72,6 +72,7 @@ use crate::{
NamedProtocolService,
Response,
RpcError,
RpcServerError,
RpcStatus,
RPC_CHUNKING_MAX_CHUNKS,
},
Expand Down Expand Up @@ -574,9 +575,13 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
}

let resp_flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
})?)
.ok_or(RpcStatus::protocol_error(&format!(
"invalid message flag, does not match any flags ({})",
resp.flags
)))?;
if !resp_flags.contains(RpcMessageFlags::ACK) {
warn!(
target: LOG_TARGET,
Expand Down Expand Up @@ -871,9 +876,12 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
if !status.is_ok() {
return Ok(Err(status));
}

let flags = match resp.flags() {
Ok(flags) => flags,
Err(e) => return Ok(Err(RpcError::ServerError(RpcServerError::ProtocolError(e)).into())),
};
let resp = Response {
flags: resp.flags(),
flags,
payload: resp.payload.into(),
};

Expand Down Expand Up @@ -925,9 +933,13 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
self.check_response(&resp)?;
let mut chunk_count = 1;
let mut last_chunk_flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
})?)
.ok_or(RpcStatus::protocol_error(&format!(
"invalid message flag, does not match any flags ({})",
resp.flags
)))?;
let mut last_chunk_size = resp.payload.len();
self.bytes_read += last_chunk_size;
loop {
Expand All @@ -950,9 +962,13 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
}

let msg = self.next().await?;
last_chunk_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(msg.flags).map_err(|_| {
last_chunk_flags = RpcMessageFlags::from_bits(u8::try_from(msg.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
})?)
.ok_or(RpcStatus::protocol_error(&format!(
"invalid message flag, does not match any flags ({})",
resp.flags
)))?;
last_chunk_size = msg.payload.len();
self.bytes_read += last_chunk_size;
self.check_response(&resp)?;
Expand All @@ -971,9 +987,13 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
.map_err(|_| RpcStatus::protocol_error(&format!("invalid request_id: must be less than {}", u16::MAX)))?;

let flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
})?)
.ok_or(RpcStatus::protocol_error(&format!(
"invalid message flag, does not match any flags ({})",
resp.flags
)))?;
if flags.contains(RpcMessageFlags::ACK) {
return Err(RpcError::UnexpectedAckResponse);
}
Expand Down
20 changes: 16 additions & 4 deletions comms/core/src/protocol/rpc/message.rs
Expand Up @@ -232,8 +232,14 @@ impl proto::rpc::RpcRequest {
Duration::from_secs(self.deadline)
}

pub fn flags(&self) -> RpcMessageFlags {
RpcMessageFlags::from_bits_truncate(u8::try_from(self.flags).unwrap())
pub fn flags(&self) -> Result<RpcMessageFlags, String> {
RpcMessageFlags::from_bits(
u8::try_from(self.flags).map_err(|_| format!("invalid message flag: must be less than {}", u8::MAX))?,
)
.ok_or(format!(
"invalid message flag, does not match any flags ({})",
self.flags
))
}
}

Expand Down Expand Up @@ -282,8 +288,14 @@ impl Default for RpcResponse {
}

impl proto::rpc::RpcResponse {
pub fn flags(&self) -> RpcMessageFlags {
RpcMessageFlags::from_bits_truncate(u8::try_from(self.flags).unwrap())
pub fn flags(&self) -> Result<RpcMessageFlags, String> {
RpcMessageFlags::from_bits(
u8::try_from(self.flags).map_err(|_| format!("invalid message flag: must be less than {}", u8::MAX))?,
)
.ok_or(format!(
"invalid message flag, does not match any flags ({})",
self.flags
))
}

pub fn is_fin(&self) -> bool {
Expand Down
20 changes: 15 additions & 5 deletions comms/core/src/protocol/rpc/server/chunking.rs
Expand Up @@ -216,7 +216,9 @@ mod test {
assert_eq!(iter.total_chunks, 1);
let msgs = iter.collect::<Vec<_>>();
assert_eq!(msgs.len(), 1);
assert!(!RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[0].flags).unwrap()).is_more());
assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap())
.unwrap()
.is_more());
}

#[test]
Expand All @@ -225,7 +227,9 @@ mod test {
assert_eq!(iter.total_chunks, 1);
let msgs = iter.collect::<Vec<_>>();
assert_eq!(msgs.len(), 1);
assert!(!RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[0].flags).unwrap()).is_more());
assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap())
.unwrap()
.is_more());
}

#[test]
Expand Down Expand Up @@ -255,8 +259,14 @@ mod test {
use std::convert::TryFrom;
let iter = create(RPC_CHUNKING_THRESHOLD * 3);
let msgs = iter.collect::<Vec<_>>();
assert!(RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[0].flags).unwrap()).is_more());
assert!(RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[1].flags).unwrap()).is_more());
assert!(!RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[2].flags).unwrap()).is_more());
assert!(RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap())
.unwrap()
.is_more());
assert!(RpcMessageFlags::from_bits(u8::try_from(msgs[1].flags).unwrap())
.unwrap()
.is_more());
assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[2].flags).unwrap())
.unwrap()
.is_more());
}
}
2 changes: 2 additions & 0 deletions comms/core/src/protocol/rpc/server/error.rs
Expand Up @@ -62,6 +62,8 @@ pub enum RpcServerError {
ReadStreamExceededDeadline,
#[error("Early close: {0}")]
EarlyClose(#[from] EarlyCloseError<BytesMut>),
#[error("Protocol error: {0}")]
ProtocolError(String),
}

impl RpcServerError {
Expand Down
30 changes: 28 additions & 2 deletions comms/core/src/protocol/rpc/server/mod.rs
Expand Up @@ -629,7 +629,13 @@ where
return Ok(());
}

let msg_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(decoded_msg.flags).unwrap());
let msg_flags = RpcMessageFlags::from_bits(u8::try_from(decoded_msg.flags).map_err(|_| {
RpcServerError::ProtocolError(format!("invalid message flag: must be less than {}", u8::MAX))
})?)
.ok_or(RpcServerError::ProtocolError(format!(
"invalid message flag, does not match any flags ({})",
decoded_msg.flags
)))?;

if msg_flags.contains(RpcMessageFlags::FIN) {
debug!(target: LOG_TARGET, "({}) Client sent FIN.", self.logging_context_string);
Expand Down Expand Up @@ -815,7 +821,27 @@ where
return Poll::Ready(Some(RpcServerError::UnexpectedIncomingMessageMalformed));
},
};
let msg_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(decoded_msg.flags).unwrap());
let u8_bits = match u8::try_from(decoded_msg.flags) {
Ok(bits) => bits,
Err(err) => {
error!(target: LOG_TARGET, "Client send MALFORMED flags: {}", err);
return Poll::Ready(Some(RpcServerError::ProtocolError(format!(
"invalid message flag: must be less than {}",
u8::MAX
))));
},
};

let msg_flags = match RpcMessageFlags::from_bits(u8_bits) {
Some(flags) => flags,
None => {
error!(target: LOG_TARGET, "Client send MALFORMED flags: {}", u8_bits);
return Poll::Ready(Some(RpcServerError::ProtocolError(format!(
"invalid message flag, does not match any flags ({})",
u8_bits
))));
},
};
if msg_flags.is_fin() {
Poll::Ready(Some(RpcServerError::ClientInterruptedStream))
} else {
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/proto/mod.rs
Expand Up @@ -85,7 +85,7 @@ impl fmt::Display for dht::JoinMessage {
"JoinMessage(PK = {}, {} Addresses, Features = {:?})",
self.public_key.to_hex(),
self.addresses.len(),
PeerFeatures::from_bits_truncate(self.peer_features),
PeerFeatures::from_bits(self.peer_features),
)
}
}
Expand Down

0 comments on commit fb18078

Please sign in to comment.