diff --git a/relay_client/src/client/stream.rs b/relay_client/src/client/stream.rs index 48d6615..185c59f 100644 --- a/relay_client/src/client/stream.rs +++ b/relay_client/src/client/stream.rs @@ -175,7 +175,7 @@ impl ClientStream { Payload::Response(response) => { let id = response.id(); - if !id.is_valid() { + if id.is_zero() { return match response { Response::Error(response) => { Some(StreamEvent::InboundError(Error::Rpc { diff --git a/relay_rpc/src/domain.rs b/relay_rpc/src/domain.rs index 62035bf..4694aa9 100644 --- a/relay_rpc/src/domain.rs +++ b/relay_rpc/src/domain.rs @@ -143,11 +143,18 @@ new_type!( ); impl MessageId { - pub fn is_valid(&self) -> bool { + /// Minimum allowed value of a [`MessageId`]. + const MIN: Self = Self(1000000000); + + pub(crate) fn validate(&self) -> bool { + self.0 >= Self::MIN.0 + } + + pub fn is_zero(&self) -> bool { // Message ID `0` is used when the client request failed to parse for whatever // reason, and the server doesn't know the message ID of that request, but still // wants to communicate the error. - self.0 != 0 + self.0 == 0 } } diff --git a/relay_rpc/src/rpc.rs b/relay_rpc/src/rpc.rs index 5ccc4d8..f327ba4 100644 --- a/relay_rpc/src/rpc.rs +++ b/relay_rpc/src/rpc.rs @@ -37,6 +37,9 @@ pub enum ValidationError { #[error("Subscription ID decoding failed: {0}")] SubscriptionIdDecoding(DecodingError), + #[error("Invalid request ID")] + RequestId, + #[error("Invalid JSON RPC version")] JsonRpcVersion, @@ -701,6 +704,10 @@ impl Request { /// Validates the request payload. pub fn validate(&self) -> Result<(), ValidationError> { + if !self.id.validate() { + return Err(ValidationError::RequestId); + } + if self.jsonrpc.as_ref() != JSON_RPC_VERSION_STR { return Err(ValidationError::JsonRpcVersion); } diff --git a/relay_rpc/src/rpc/tests.rs b/relay_rpc/src/rpc/tests.rs index f7be09d..861ce25 100644 --- a/relay_rpc/src/rpc/tests.rs +++ b/relay_rpc/src/rpc/tests.rs @@ -194,13 +194,27 @@ fn deserialize_batch_methods() { #[test] fn validation() { // Valid data. - let id = MessageId::from(1); + let id = MessageId::from(1234567890); let jsonrpc: Arc = "2.0".into(); let message: Arc = "0".repeat(512).into(); let topic = Topic::from("c4163cf65859106b3f5435fc296e7765411178ed452d1c30337a6230138c9840"); let subscription_id = SubscriptionId::from("c4163cf65859106b3f5435fc296e7765411178ed452d1c30337a6230138c9841"); + // Invalid request ID. + let request = Request { + id: MessageId::new(1), + jsonrpc: jsonrpc.clone(), + params: Params::Publish(Publish { + topic: topic.clone(), + message: message.clone(), + ttl_secs: 0, + tag: 0, + prompt: false, + }), + }; + assert_eq!(request.validate(), Err(ValidationError::RequestId)); + // Invalid JSONRPC version. let request = Request { id,