From a59f6f813d2eec743b5fb3d38cfd8885418bb0ca Mon Sep 17 00:00:00 2001 From: dharjeezy Date: Tue, 7 Oct 2025 23:21:42 +0100 Subject: [PATCH] introduce max message size in webrtc config --- src/transport/webrtc/config.rs | 2 ++ src/transport/webrtc/connection.rs | 10 +++++++--- src/transport/webrtc/mod.rs | 7 +++++++ src/transport/webrtc/opening.rs | 6 +++++- src/transport/webrtc/util.rs | 12 +++++++----- 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/transport/webrtc/config.rs b/src/transport/webrtc/config.rs index b9314010..5428ca77 100644 --- a/src/transport/webrtc/config.rs +++ b/src/transport/webrtc/config.rs @@ -32,6 +32,7 @@ pub struct Config { /// /// How many datagrams can the buffer between `WebRtcTransport` and a connection handler hold. pub datagram_buffer_size: usize, + pub max_message_size: usize, } impl Default for Config { @@ -41,6 +42,7 @@ impl Default for Config { .parse() .expect("valid multiaddress")], datagram_buffer_size: 2048, + max_message_size: 512 * 1024, } } } diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index 6c1e5746..12ba021c 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -199,6 +199,8 @@ pub struct WebRtcConnection { /// Substream handles. handles: SubstreamHandleSet, + /// Max message size + max_message_size: usize } impl WebRtcConnection { @@ -212,6 +214,7 @@ impl WebRtcConnection { protocol_set: ProtocolSet, endpoint: Endpoint, dgram_rx: Receiver>, + max_message_size: usize ) -> Self { Self { rtc, @@ -225,6 +228,7 @@ impl WebRtcConnection { pending_outbound: HashMap::new(), channels: HashMap::new(), handles: SubstreamHandleSet::new(), + max_message_size } } @@ -318,7 +322,7 @@ impl WebRtcConnection { "handle opening inbound substream", ); - let payload = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; + let payload = WebRtcMessage::decode(&data, self.max_message_size)?.payload.ok_or(Error::InvalidData)?; let (response, negotiated) = match webrtc_listener_negotiate( &mut self.protocol_set.protocols().iter(), payload.into(), @@ -385,7 +389,7 @@ impl WebRtcConnection { "handle opening outbound substream", ); - let rtc_message = WebRtcMessage::decode(&data) + let rtc_message = WebRtcMessage::decode(&data, self.max_message_size) .map_err(|err| SubstreamError::NegotiationError(err.into()))?; let message = rtc_message.payload.ok_or(SubstreamError::NegotiationError( ParseError::InvalidData.into(), @@ -445,7 +449,7 @@ impl WebRtcConnection { channel_id: ChannelId, data: Vec, ) -> crate::Result<()> { - let message = WebRtcMessage::decode(&data)?; + let message = WebRtcMessage::decode(&data, self.max_message_size)?; tracing::trace!( target: LOG_TARGET, diff --git a/src/transport/webrtc/mod.rs b/src/transport/webrtc/mod.rs index bc060102..c1541102 100644 --- a/src/transport/webrtc/mod.rs +++ b/src/transport/webrtc/mod.rs @@ -135,6 +135,10 @@ pub(crate) struct WebRtcTransport { /// Datagram buffer size. datagram_buffer_size: usize, + /// Max Webrtc message size + + max_message_size: usize, + /// Connected peers. open: HashMap, @@ -412,6 +416,7 @@ impl WebRtcTransport { self.context.keypair.clone(), source, self.listen_address, + self.max_message_size, ); self.opening.insert(source, connection); @@ -486,6 +491,7 @@ impl TransportBuilder for WebRtcTransport { timeouts: HashMap::new(), pending_events: VecDeque::new(), datagram_buffer_size: config.datagram_buffer_size, + max_message_size: config.max_message_size }, listen_multi_addresses, )) @@ -573,6 +579,7 @@ impl Transport for WebRtcTransport { protocol_set, endpoint, rx, + self.max_message_size, ); self.open.insert( source, diff --git a/src/transport/webrtc/opening.rs b/src/transport/webrtc/opening.rs index d2fa19c9..66035588 100644 --- a/src/transport/webrtc/opening.rs +++ b/src/transport/webrtc/opening.rs @@ -110,6 +110,8 @@ pub struct OpeningWebRtcConnection { /// Local address. local_address: SocketAddr, + /// Max message size + max_message_size: usize, } /// Connection state. @@ -151,6 +153,7 @@ impl OpeningWebRtcConnection { id_keypair: Keypair, peer_address: SocketAddr, local_address: SocketAddr, + max_message_size: usize ) -> OpeningWebRtcConnection { tracing::trace!( target: LOG_TARGET, @@ -167,6 +170,7 @@ impl OpeningWebRtcConnection { id_keypair, peer_address, local_address, + max_message_size } } @@ -253,7 +257,7 @@ impl OpeningWebRtcConnection { return Err(Error::InvalidState); }; - let message = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; + let message = WebRtcMessage::decode(&data, self.max_message_size)?.payload.ok_or(Error::InvalidData)?; let public_key = context.get_remote_public_key(&message)?; let remote_peer_id = PeerId::from_public_key(&public_key); diff --git a/src/transport/webrtc/util.rs b/src/transport/webrtc/util.rs index 55917afc..c6180ffd 100644 --- a/src/transport/webrtc/util.rs +++ b/src/transport/webrtc/util.rs @@ -72,9 +72,9 @@ impl WebRtcMessage { } /// Decode payload into [`WebRtcMessage`]. - pub fn decode(payload: &[u8]) -> Result { + pub fn decode(payload: &[u8], max_message_size: usize) -> Result { // TODO: https://github.com/paritytech/litep2p/issues/352 set correct size - let mut codec = UnsignedVarint::new(None); + let mut codec = UnsignedVarint::new(Some(max_message_size)); let mut data = bytes::BytesMut::from(payload); let result = codec .decode(&mut data) @@ -95,10 +95,12 @@ impl WebRtcMessage { mod tests { use super::*; + const TEST_MAX_SIZE: usize = 512 * 1024; + #[test] fn with_payload_no_flags() { let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec()); - let decoded = WebRtcMessage::decode(&message).unwrap(); + let decoded = WebRtcMessage::decode(&message, TEST_MAX_SIZE).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); assert_eq!(decoded.flags, None); @@ -107,7 +109,7 @@ mod tests { #[test] fn with_payload_and_flags() { let message = WebRtcMessage::encode_with_flags("Hello, world!".as_bytes().to_vec(), 1i32); - let decoded = WebRtcMessage::decode(&message).unwrap(); + let decoded = WebRtcMessage::decode(&message, TEST_MAX_SIZE).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); assert_eq!(decoded.flags, Some(1i32)); @@ -116,7 +118,7 @@ mod tests { #[test] fn no_payload_with_flags() { let message = WebRtcMessage::encode_with_flags(vec![], 2i32); - let decoded = WebRtcMessage::decode(&message).unwrap(); + let decoded = WebRtcMessage::decode(&message, TEST_MAX_SIZE).unwrap(); assert_eq!(decoded.payload, None); assert_eq!(decoded.flags, Some(2i32));