Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: change truncate_from_bits to from_bits #5773

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions comms/core/src/peer_manager/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub enum PeerManagerError {
InvalidPeerFeatures { bits: u32 },
#[error("Address {address} not found for peer {node_id}")]
AddressNotFoundError { address: Multiaddr, node_id: NodeId },
#[error("Protocol error: {0}")]
ProtocolError(String),
}

impl PeerManagerError {
Expand Down
5 changes: 4 additions & 1 deletion comms/core/src/peer_manager/peer_identity_claim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ impl TryFrom<PeerIdentityMsg> for PeerIdentityClaim {
if addresses.is_empty() {
return Err(PeerManagerError::PeerIdentityNoValidAddresses);
}
let features = PeerFeatures::from_bits_truncate(value.features);
let features = PeerFeatures::from_bits(value.features).ok_or(PeerManagerError::ProtocolError(format!(
"Invalid message flag, does not match any flags ({})",
value.features
)))?;

if let Some(signature) = value.identity_signature {
Ok(Self {
Expand Down
4 changes: 3 additions & 1 deletion comms/core/src/protocol/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use thiserror::Error;
pub enum ProtocolError {
#[error("IO error: {0}")]
IoError(#[from] io::Error),
#[error("Invalid flag: {0}")]
InvalidFlag(String),
#[error("The ProtocolId was longer than {}", u8::max_value())]
ProtocolIdTooLong,
#[error("Protocol negotiation failed because the peer did not accept any of the given protocols: {protocols}")]
Expand Down Expand Up @@ -56,7 +58,7 @@ impl ProtocolError {
ProtocolError::ProtocolOptimisticNegotiationFailed |
ProtocolError::NotificationSenderDisconnected => false,

ProtocolError::ProtocolIdTooLong => true,
ProtocolError::ProtocolIdTooLong | ProtocolError::InvalidFlag(_) => true,
}
}
}
5 changes: 4 additions & 1 deletion comms/core/src/protocol/negotiation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ where TSocket: AsyncRead + AsyncWrite + Unpin
// Len can never overflow the buffer because the buffer len is u8::MAX and the length delimiter
// is a u8. If that changes, then len should be checked for overflow
let len = u8::from_be_bytes([self.buf[0]]) as usize;
let flags = Flags::from_bits_truncate(u8::from_be_bytes([self.buf[1]]));
let flags = Flags::from_bits(u8::from_be_bytes([self.buf[1]])).ok_or(ProtocolError::InvalidFlag(format!(
"Does not match any flags ({})",
self.buf[1]
)))?;
self.socket.read_exact(&mut self.buf[0..len]).await?;
trace!(
target: LOG_TARGET,
Expand Down
40 changes: 30 additions & 10 deletions comms/core/src/protocol/rpc/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ use crate::{
NamedProtocolService,
Response,
RpcError,
RpcServerError,
RpcStatus,
RPC_CHUNKING_MAX_CHUNKS,
},
Expand Down Expand Up @@ -574,9 +575,13 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
}

let resp_flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
})?)
.ok_or(RpcStatus::protocol_error(&format!(
"invalid message flag, does not match any flags ({})",
resp.flags
)))?;
if !resp_flags.contains(RpcMessageFlags::ACK) {
warn!(
target: LOG_TARGET,
Expand Down Expand Up @@ -871,9 +876,12 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
if !status.is_ok() {
return Ok(Err(status));
}

let flags = match resp.flags() {
Ok(flags) => flags,
Err(e) => return Ok(Err(RpcError::ServerError(RpcServerError::ProtocolError(e)).into())),
};
let resp = Response {
flags: resp.flags(),
flags,
payload: resp.payload.into(),
};

Expand Down Expand Up @@ -925,9 +933,13 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
self.check_response(&resp)?;
let mut chunk_count = 1;
let mut last_chunk_flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
})?)
.ok_or(RpcStatus::protocol_error(&format!(
"invalid message flag, does not match any flags ({})",
resp.flags
)))?;
let mut last_chunk_size = resp.payload.len();
self.bytes_read += last_chunk_size;
loop {
Expand All @@ -950,9 +962,13 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
}

let msg = self.next().await?;
last_chunk_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(msg.flags).map_err(|_| {
last_chunk_flags = RpcMessageFlags::from_bits(u8::try_from(msg.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
})?)
.ok_or(RpcStatus::protocol_error(&format!(
"invalid message flag, does not match any flags ({})",
resp.flags
)))?;
last_chunk_size = msg.payload.len();
self.bytes_read += last_chunk_size;
self.check_response(&resp)?;
Expand All @@ -971,9 +987,13 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin
.map_err(|_| RpcStatus::protocol_error(&format!("invalid request_id: must be less than {}", u16::MAX)))?;

let flags =
RpcMessageFlags::from_bits_truncate(u8::try_from(resp.flags).map_err(|_| {
RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| {
RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX))
})?);
})?)
.ok_or(RpcStatus::protocol_error(&format!(
"invalid message flag, does not match any flags ({})",
resp.flags
)))?;
if flags.contains(RpcMessageFlags::ACK) {
return Err(RpcError::UnexpectedAckResponse);
}
Expand Down
20 changes: 16 additions & 4 deletions comms/core/src/protocol/rpc/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,14 @@ impl proto::rpc::RpcRequest {
Duration::from_secs(self.deadline)
}

pub fn flags(&self) -> RpcMessageFlags {
RpcMessageFlags::from_bits_truncate(u8::try_from(self.flags).unwrap())
pub fn flags(&self) -> Result<RpcMessageFlags, String> {
RpcMessageFlags::from_bits(
u8::try_from(self.flags).map_err(|_| format!("invalid message flag: must be less than {}", u8::MAX))?,
)
.ok_or(format!(
"invalid message flag, does not match any flags ({})",
self.flags
))
}
}

Expand Down Expand Up @@ -282,8 +288,14 @@ impl Default for RpcResponse {
}

impl proto::rpc::RpcResponse {
pub fn flags(&self) -> RpcMessageFlags {
RpcMessageFlags::from_bits_truncate(u8::try_from(self.flags).unwrap())
pub fn flags(&self) -> Result<RpcMessageFlags, String> {
RpcMessageFlags::from_bits(
u8::try_from(self.flags).map_err(|_| format!("invalid message flag: must be less than {}", u8::MAX))?,
)
.ok_or(format!(
"invalid message flag, does not match any flags ({})",
self.flags
))
}

pub fn is_fin(&self) -> bool {
Expand Down
20 changes: 15 additions & 5 deletions comms/core/src/protocol/rpc/server/chunking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ mod test {
assert_eq!(iter.total_chunks, 1);
let msgs = iter.collect::<Vec<_>>();
assert_eq!(msgs.len(), 1);
assert!(!RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[0].flags).unwrap()).is_more());
assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap())
.unwrap()
.is_more());
}

#[test]
Expand All @@ -225,7 +227,9 @@ mod test {
assert_eq!(iter.total_chunks, 1);
let msgs = iter.collect::<Vec<_>>();
assert_eq!(msgs.len(), 1);
assert!(!RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[0].flags).unwrap()).is_more());
assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap())
.unwrap()
.is_more());
}

#[test]
Expand Down Expand Up @@ -255,8 +259,14 @@ mod test {
use std::convert::TryFrom;
let iter = create(RPC_CHUNKING_THRESHOLD * 3);
let msgs = iter.collect::<Vec<_>>();
assert!(RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[0].flags).unwrap()).is_more());
assert!(RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[1].flags).unwrap()).is_more());
assert!(!RpcMessageFlags::from_bits_truncate(u8::try_from(msgs[2].flags).unwrap()).is_more());
assert!(RpcMessageFlags::from_bits(u8::try_from(msgs[0].flags).unwrap())
.unwrap()
.is_more());
assert!(RpcMessageFlags::from_bits(u8::try_from(msgs[1].flags).unwrap())
.unwrap()
.is_more());
assert!(!RpcMessageFlags::from_bits(u8::try_from(msgs[2].flags).unwrap())
.unwrap()
.is_more());
}
}
2 changes: 2 additions & 0 deletions comms/core/src/protocol/rpc/server/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ pub enum RpcServerError {
ReadStreamExceededDeadline,
#[error("Early close: {0}")]
EarlyClose(#[from] EarlyCloseError<BytesMut>),
#[error("Protocol error: {0}")]
ProtocolError(String),
}

impl RpcServerError {
Expand Down
30 changes: 28 additions & 2 deletions comms/core/src/protocol/rpc/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,13 @@ where
return Ok(());
}

let msg_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(decoded_msg.flags).unwrap());
let msg_flags = RpcMessageFlags::from_bits(u8::try_from(decoded_msg.flags).map_err(|_| {
RpcServerError::ProtocolError(format!("invalid message flag: must be less than {}", u8::MAX))
})?)
.ok_or(RpcServerError::ProtocolError(format!(
"invalid message flag, does not match any flags ({})",
decoded_msg.flags
)))?;

if msg_flags.contains(RpcMessageFlags::FIN) {
debug!(target: LOG_TARGET, "({}) Client sent FIN.", self.logging_context_string);
Expand Down Expand Up @@ -815,7 +821,27 @@ where
return Poll::Ready(Some(RpcServerError::UnexpectedIncomingMessageMalformed));
},
};
let msg_flags = RpcMessageFlags::from_bits_truncate(u8::try_from(decoded_msg.flags).unwrap());
let u8_bits = match u8::try_from(decoded_msg.flags) {
Ok(bits) => bits,
Err(err) => {
error!(target: LOG_TARGET, "Client send MALFORMED flags: {}", err);
return Poll::Ready(Some(RpcServerError::ProtocolError(format!(
"invalid message flag: must be less than {}",
u8::MAX
))));
},
};

let msg_flags = match RpcMessageFlags::from_bits(u8_bits) {
Some(flags) => flags,
None => {
error!(target: LOG_TARGET, "Client send MALFORMED flags: {}", u8_bits);
return Poll::Ready(Some(RpcServerError::ProtocolError(format!(
"invalid message flag, does not match any flags ({})",
u8_bits
))));
},
};
if msg_flags.is_fin() {
Poll::Ready(Some(RpcServerError::ClientInterruptedStream))
} else {
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl fmt::Display for dht::JoinMessage {
"JoinMessage(PK = {}, {} Addresses, Features = {:?})",
self.public_key.to_hex(),
self.addresses.len(),
PeerFeatures::from_bits_truncate(self.peer_features),
PeerFeatures::from_bits(self.peer_features),
)
}
}
Expand Down
Loading