diff --git a/Cargo.toml b/Cargo.toml index 8b73a048..8e87379d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "quick-protobuf" description = "A pure Rust protobuf (de)serializer. Quick." -version = "0.2.0" +version = "0.3.0" authors = ["Johann Tuffe "] keywords = ["protobuf", "parser"] license = "MIT" diff --git a/Changelog.md b/Changelog.md index 825f3be3..f4758c48 100644 --- a/Changelog.md +++ b/Changelog.md @@ -15,6 +15,8 @@ - feat: break codegen when reserved fields conflict - feat: support imports in proto files - feat: support packages by encapsulating into rust modules +- feat: support map +- refactor: major refatorings of codegen ## 0.2.0 - feat: do not allocate for bytes and string field types diff --git a/benches/perftest.rs b/benches/perftest.rs index f8b0023a..38275c6f 100644 --- a/benches/perftest.rs +++ b/benches/perftest.rs @@ -131,7 +131,7 @@ fn generate_optional_messages() -> Vec { perfbench!(generate_optional_messages, TestOptionalMessages, write_optional_messages, read_optional_messages); fn generate_strings() -> Vec> { - let mut s = "hello world from quick-protobuf!!!".split('_').cycle().map(|s| Cow::Borrowed(s)); + let mut s = "hello world from quick-protobuf!!!".split(' ').cycle().map(|s| Cow::Borrowed(s)); (1..100).map(|_| TestStrings { s1: s.by_ref().next(), s2: s.by_ref().next(), @@ -142,7 +142,7 @@ fn generate_strings() -> Vec> { perfbench!(generate_strings, TestStrings, write_strings, read_strings); fn generate_small_bytes() -> Vec> { - let mut s = "hello world from quick-protobuf!!!".split('_').cycle() + let mut s = "hello world from quick-protobuf!!!".split(' ').cycle() .map(|s| Cow::Borrowed(s.as_bytes())); (1..800).map(|_| TestBytes { b1: s.by_ref().next() }) .collect() @@ -151,7 +151,7 @@ fn generate_small_bytes() -> Vec> { perfbench!(generate_small_bytes, TestBytes, write_small_bytes, read_small_bytes); fn generate_large_bytes() -> Vec> { - let mut s = "hello world from quick-protobuf!!!".split('_').cycle().map(|s| s.as_bytes()); + let mut s = "hello world from quick-protobuf!!!".split(' ').cycle().map(|s| s.as_bytes()); (1..30).map(|_| TestBytes { b1: Some(Cow::Owned(s.by_ref().take(500).fold(Vec::new(), |mut cur, nxt| { cur.extend_from_slice(nxt); @@ -162,6 +162,15 @@ fn generate_large_bytes() -> Vec> { perfbench!(generate_large_bytes, TestBytes, write_large_bytes, read_large_bytes); +fn generate_map() -> Vec> { + let mut s = "hello world from quick-protobuf!!!".split(' ').cycle(); + (1..30).map(|_| TestMap { + value: s.by_ref().take(500).map(|s| (Cow::Owned(s.to_string()), s.len() as u32)).collect() + }).collect() +} + +perfbench!(generate_map, TestMap, write_map, read_map); + fn generate_all() -> Vec> { vec![PerftestData { test1: generate_test1(), @@ -172,6 +181,7 @@ fn generate_all() -> Vec> { test_repeated_packed_int32: generate_repeated_packed_int32(), test_small_bytearrays: generate_small_bytes(), test_large_bytearrays: generate_large_bytes(), + test_map: generate_map(), }] } diff --git a/benches/perftest_data/mod.rs b/benches/perftest_data/mod.rs index 5858d629..c360fefd 100644 --- a/benches/perftest_data/mod.rs +++ b/benches/perftest_data/mod.rs @@ -6,6 +6,7 @@ use std::io::{Write}; use std::borrow::Cow; +use std::collections::HashMap; use quick_protobuf::{MessageWrite, BytesReader, Writer, Result}; use quick_protobuf::sizeofs::*; @@ -30,11 +31,11 @@ impl Test1 { impl MessageWrite for Test1 { fn get_size(&self) -> usize { - self.value.as_ref().map_or(0, |m| 1 + sizeof_int32(*m)) + self.value.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.value { r.write_int32_with_tag(8, *s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.value { w.write_with_tag(8, |w| w.write_int32(*s))?; } Ok(()) } } @@ -60,11 +61,11 @@ impl TestRepeatedBool { impl MessageWrite for TestRepeatedBool { fn get_size(&self) -> usize { - self.values.iter().map(|s| 1 + sizeof_bool(*s)).sum::() + self.values.iter().map(|s| 1 + sizeof_varint(*s as u64)).sum::() } - fn write_message(&self, r: &mut Writer) -> Result<()> { - for s in &self.values { r.write_bool_with_tag(8, *s)? } + fn write_message(&self, w: &mut Writer) -> Result<()> { + for s in &self.values { w.write_with_tag(8, |w| w.write_bool(*s))?; } Ok(()) } } @@ -90,11 +91,11 @@ impl TestRepeatedPackedInt32 { impl MessageWrite for TestRepeatedPackedInt32 { fn get_size(&self) -> usize { - if self.values.is_empty() { 0 } else { 1 + sizeof_var_length(self.values.iter().map(|s| sizeof_int32(*s)).sum::()) } + if self.values.is_empty() { 0 } else { 1 + sizeof_len(self.values.iter().map(|s| sizeof_varint(*s as u64)).sum::()) } } - fn write_message(&self, r: &mut Writer) -> Result<()> { - r.write_packed_repeated_field_with_tag(10, &self.values, |r, m| r.write_int32(*m), &|m| sizeof_int32(*m))?; + fn write_message(&self, w: &mut Writer) -> Result<()> { + w.write_packed_with_tag(10, &self.values, |w, m| w.write_int32(*m), &|m| sizeof_varint(*m as u64))?; Ok(()) } } @@ -124,15 +125,15 @@ impl TestRepeatedMessages { impl MessageWrite for TestRepeatedMessages { fn get_size(&self) -> usize { - self.messages1.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.messages2.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.messages3.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() + self.messages1.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.messages2.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.messages3.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() } - fn write_message(&self, r: &mut Writer) -> Result<()> { - for s in &self.messages1 { r.write_message_with_tag(10, s)? } - for s in &self.messages2 { r.write_message_with_tag(18, s)? } - for s in &self.messages3 { r.write_message_with_tag(26, s)? } + fn write_message(&self, w: &mut Writer) -> Result<()> { + for s in &self.messages1 { w.write_with_tag(10, |w| w.write_message(s))?; } + for s in &self.messages2 { w.write_with_tag(18, |w| w.write_message(s))?; } + for s in &self.messages3 { w.write_with_tag(26, |w| w.write_message(s))?; } Ok(()) } } @@ -162,15 +163,15 @@ impl TestOptionalMessages { impl MessageWrite for TestOptionalMessages { fn get_size(&self) -> usize { - self.message1.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.get_size())) - + self.message2.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.get_size())) - + self.message3.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.get_size())) + self.message1.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) + + self.message2.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) + + self.message3.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.message1 { r.write_message_with_tag(10, &**s)?; } - if let Some(ref s) = self.message2 { r.write_message_with_tag(18, &**s)?; } - if let Some(ref s) = self.message3 { r.write_message_with_tag(26, &**s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.message1 { w.write_with_tag(10, |w| w.write_message(&**s))?; } + if let Some(ref s) = self.message2 { w.write_with_tag(18, |w| w.write_message(&**s))?; } + if let Some(ref s) = self.message3 { w.write_with_tag(26, |w| w.write_message(&**s))?; } Ok(()) } } @@ -187,9 +188,9 @@ impl<'a> TestStrings<'a> { let mut msg = Self::default(); while !r.is_eof() { match r.next_tag(bytes) { - Ok(10) => msg.s1 = Some(Cow::Borrowed(r.read_string(bytes)?)), - Ok(18) => msg.s2 = Some(Cow::Borrowed(r.read_string(bytes)?)), - Ok(26) => msg.s3 = Some(Cow::Borrowed(r.read_string(bytes)?)), + Ok(10) => msg.s1 = Some(r.read_string(bytes).map(Cow::Borrowed)?), + Ok(18) => msg.s2 = Some(r.read_string(bytes).map(Cow::Borrowed)?), + Ok(26) => msg.s3 = Some(r.read_string(bytes).map(Cow::Borrowed)?), Ok(t) => { r.read_unknown(bytes, t)?; } Err(e) => return Err(e), } @@ -200,15 +201,15 @@ impl<'a> TestStrings<'a> { impl<'a> MessageWrite for TestStrings<'a> { fn get_size(&self) -> usize { - self.s1.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) - + self.s2.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) - + self.s3.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) + self.s1.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) + + self.s2.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) + + self.s3.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.s1 { r.write_string_with_tag(10, s)?; } - if let Some(ref s) = self.s2 { r.write_string_with_tag(18, s)?; } - if let Some(ref s) = self.s3 { r.write_string_with_tag(26, s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.s1 { w.write_with_tag(10, |w| w.write_string(&**s))?; } + if let Some(ref s) = self.s2 { w.write_with_tag(18, |w| w.write_string(&**s))?; } + if let Some(ref s) = self.s3 { w.write_with_tag(26, |w| w.write_string(&**s))?; } Ok(()) } } @@ -223,7 +224,7 @@ impl<'a> TestBytes<'a> { let mut msg = Self::default(); while !r.is_eof() { match r.next_tag(bytes) { - Ok(10) => msg.b1 = Some(Cow::Borrowed(r.read_bytes(bytes)?)), + Ok(10) => msg.b1 = Some(r.read_bytes(bytes).map(Cow::Borrowed)?), Ok(t) => { r.read_unknown(bytes, t)?; } Err(e) => return Err(e), } @@ -234,11 +235,44 @@ impl<'a> TestBytes<'a> { impl<'a> MessageWrite for TestBytes<'a> { fn get_size(&self) -> usize { - self.b1.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) + self.b1.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.b1 { r.write_bytes_with_tag(10, s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.b1 { w.write_with_tag(10, |w| w.write_bytes(&**s))?; } + Ok(()) + } +} + +#[derive(Debug, Default, PartialEq, Clone)] +pub struct TestMap<'a> { + pub value: HashMap, u32>, +} + +impl<'a> TestMap<'a> { + pub fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(10) => { + let (key, value) = r.read_map(bytes, |r, bytes| r.read_string(bytes).map(Cow::Borrowed), |r, bytes| r.read_uint32(bytes))?; + msg.value.insert(key, value); + } + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for TestMap<'a> { + fn get_size(&self) -> usize { + self.value.iter().map(|(k, v)| 1 + sizeof_len(2 + sizeof_len(k.len()) + sizeof_varint(*v as u64))).sum::() + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + for (k, v) in self.value.iter() { w.write_with_tag(10, |w| w.write_map(2 + sizeof_len(k.len()) + sizeof_varint(*v as u64), 10, |w| w.write_string(&**k), 16, |w| w.write_uint32(*v)))?; } Ok(()) } } @@ -253,6 +287,7 @@ pub struct PerftestData<'a> { pub test_repeated_packed_int32: Vec, pub test_small_bytearrays: Vec>, pub test_large_bytearrays: Vec>, + pub test_map: Vec>, } impl<'a> PerftestData<'a> { @@ -268,6 +303,7 @@ impl<'a> PerftestData<'a> { Ok(50) => msg.test_repeated_packed_int32.push(r.read_message(bytes, TestRepeatedPackedInt32::from_reader)?), Ok(58) => msg.test_small_bytearrays.push(r.read_message(bytes, TestBytes::from_reader)?), Ok(66) => msg.test_large_bytearrays.push(r.read_message(bytes, TestBytes::from_reader)?), + Ok(74) => msg.test_map.push(r.read_message(bytes, TestMap::from_reader)?), Ok(t) => { r.read_unknown(bytes, t)?; } Err(e) => return Err(e), } @@ -278,25 +314,27 @@ impl<'a> PerftestData<'a> { impl<'a> MessageWrite for PerftestData<'a> { fn get_size(&self) -> usize { - self.test1.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_repeated_bool.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_repeated_messages.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_optional_messages.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_strings.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_repeated_packed_int32.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_small_bytearrays.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_large_bytearrays.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() + self.test1.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_repeated_bool.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_repeated_messages.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_optional_messages.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_strings.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_repeated_packed_int32.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_small_bytearrays.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_large_bytearrays.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_map.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() } - fn write_message(&self, r: &mut Writer) -> Result<()> { - for s in &self.test1 { r.write_message_with_tag(10, s)? } - for s in &self.test_repeated_bool { r.write_message_with_tag(18, s)? } - for s in &self.test_repeated_messages { r.write_message_with_tag(26, s)? } - for s in &self.test_optional_messages { r.write_message_with_tag(34, s)? } - for s in &self.test_strings { r.write_message_with_tag(42, s)? } - for s in &self.test_repeated_packed_int32 { r.write_message_with_tag(50, s)? } - for s in &self.test_small_bytearrays { r.write_message_with_tag(58, s)? } - for s in &self.test_large_bytearrays { r.write_message_with_tag(66, s)? } + fn write_message(&self, w: &mut Writer) -> Result<()> { + for s in &self.test1 { w.write_with_tag(10, |w| w.write_message(s))?; } + for s in &self.test_repeated_bool { w.write_with_tag(18, |w| w.write_message(s))?; } + for s in &self.test_repeated_messages { w.write_with_tag(26, |w| w.write_message(s))?; } + for s in &self.test_optional_messages { w.write_with_tag(34, |w| w.write_message(s))?; } + for s in &self.test_strings { w.write_with_tag(42, |w| w.write_message(s))?; } + for s in &self.test_repeated_packed_int32 { w.write_with_tag(50, |w| w.write_message(s))?; } + for s in &self.test_small_bytearrays { w.write_with_tag(58, |w| w.write_message(s))?; } + for s in &self.test_large_bytearrays { w.write_with_tag(66, |w| w.write_message(s))?; } + for s in &self.test_map { w.write_with_tag(74, |w| w.write_message(s))?; } Ok(()) } } diff --git a/benches/perftest_data/perftest_data.proto b/benches/perftest_data/perftest_data.proto index fa860c83..3150d468 100644 --- a/benches/perftest_data/perftest_data.proto +++ b/benches/perftest_data/perftest_data.proto @@ -32,6 +32,10 @@ message TestBytes { optional bytes b1 = 1; } +message TestMap { + map value = 1; +} + message PerftestData { repeated Test1 test1 = 1; repeated TestRepeatedBool test_repeated_bool = 2; @@ -41,4 +45,5 @@ message PerftestData { repeated TestRepeatedPackedInt32 test_repeated_packed_int32 = 6; repeated TestBytes test_small_bytearrays = 7; repeated TestBytes test_large_bytearrays = 8; + repeated TestMap test_map = 9; } diff --git a/benches/rust-protobuf/perftest_data_quick.rs b/benches/rust-protobuf/perftest_data_quick.rs index 137d84a9..042d1454 100644 --- a/benches/rust-protobuf/perftest_data_quick.rs +++ b/benches/rust-protobuf/perftest_data_quick.rs @@ -2,6 +2,7 @@ #![allow(non_snake_case)] #![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] use std::io::{Write}; use std::borrow::Cow; @@ -29,11 +30,11 @@ impl Test1 { impl MessageWrite for Test1 { fn get_size(&self) -> usize { - self.value.as_ref().map_or(0, |m| 1 + sizeof_int32(*m)) + self.value.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.value { r.write_int32_with_tag(8, *s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.value { w.write_with_tag(8, |w| w.write_int32(*s))?; } Ok(()) } } @@ -59,11 +60,11 @@ impl TestRepeatedBool { impl MessageWrite for TestRepeatedBool { fn get_size(&self) -> usize { - self.values.iter().map(|s| 1 + sizeof_bool(*s)).sum::() + self.values.iter().map(|s| 1 + sizeof_varint(*s as u64)).sum::() } - fn write_message(&self, r: &mut Writer) -> Result<()> { - for s in &self.values { r.write_bool_with_tag(8, *s)? } + fn write_message(&self, w: &mut Writer) -> Result<()> { + for s in &self.values { w.write_with_tag(8, |w| w.write_bool(*s))?; } Ok(()) } } @@ -89,11 +90,11 @@ impl TestRepeatedPackedInt32 { impl MessageWrite for TestRepeatedPackedInt32 { fn get_size(&self) -> usize { - if self.values.is_empty() { 0 } else { 1 + sizeof_var_length(self.values.iter().map(|s| sizeof_int32(*s)).sum::()) } + if self.values.is_empty() { 0 } else { 1 + sizeof_len(self.values.iter().map(|s| sizeof_varint(*s as u64)).sum::()) } } - fn write_message(&self, r: &mut Writer) -> Result<()> { - r.write_packed_repeated_field_with_tag(10, &self.values, |r, m| r.write_int32(*m), &|m| sizeof_int32(*m))?; + fn write_message(&self, w: &mut Writer) -> Result<()> { + w.write_packed_with_tag(10, &self.values, |w, m| w.write_int32(*m), &|m| sizeof_varint(*m as u64))?; Ok(()) } } @@ -123,15 +124,15 @@ impl TestRepeatedMessages { impl MessageWrite for TestRepeatedMessages { fn get_size(&self) -> usize { - self.messages1.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.messages2.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.messages3.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() + self.messages1.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.messages2.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.messages3.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() } - fn write_message(&self, r: &mut Writer) -> Result<()> { - for s in &self.messages1 { r.write_message_with_tag(10, s)? } - for s in &self.messages2 { r.write_message_with_tag(18, s)? } - for s in &self.messages3 { r.write_message_with_tag(26, s)? } + fn write_message(&self, w: &mut Writer) -> Result<()> { + for s in &self.messages1 { w.write_with_tag(10, |w| w.write_message(s))?; } + for s in &self.messages2 { w.write_with_tag(18, |w| w.write_message(s))?; } + for s in &self.messages3 { w.write_with_tag(26, |w| w.write_message(s))?; } Ok(()) } } @@ -161,15 +162,15 @@ impl TestOptionalMessages { impl MessageWrite for TestOptionalMessages { fn get_size(&self) -> usize { - self.message1.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.get_size())) - + self.message2.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.get_size())) - + self.message3.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.get_size())) + self.message1.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) + + self.message2.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) + + self.message3.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.message1 { r.write_message_with_tag(10, &**s)?; } - if let Some(ref s) = self.message2 { r.write_message_with_tag(18, &**s)?; } - if let Some(ref s) = self.message3 { r.write_message_with_tag(26, &**s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.message1 { w.write_with_tag(10, |w| w.write_message(&**s))?; } + if let Some(ref s) = self.message2 { w.write_with_tag(18, |w| w.write_message(&**s))?; } + if let Some(ref s) = self.message3 { w.write_with_tag(26, |w| w.write_message(&**s))?; } Ok(()) } } @@ -186,9 +187,9 @@ impl<'a> TestStrings<'a> { let mut msg = Self::default(); while !r.is_eof() { match r.next_tag(bytes) { - Ok(10) => msg.s1 = Some(Cow::Borrowed(r.read_string(bytes)?)), - Ok(18) => msg.s2 = Some(Cow::Borrowed(r.read_string(bytes)?)), - Ok(26) => msg.s3 = Some(Cow::Borrowed(r.read_string(bytes)?)), + Ok(10) => msg.s1 = Some(r.read_string(bytes).map(Cow::Borrowed)?), + Ok(18) => msg.s2 = Some(r.read_string(bytes).map(Cow::Borrowed)?), + Ok(26) => msg.s3 = Some(r.read_string(bytes).map(Cow::Borrowed)?), Ok(t) => { r.read_unknown(bytes, t)?; } Err(e) => return Err(e), } @@ -199,15 +200,15 @@ impl<'a> TestStrings<'a> { impl<'a> MessageWrite for TestStrings<'a> { fn get_size(&self) -> usize { - self.s1.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) - + self.s2.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) - + self.s3.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) + self.s1.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) + + self.s2.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) + + self.s3.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.s1 { r.write_string_with_tag(10, s)?; } - if let Some(ref s) = self.s2 { r.write_string_with_tag(18, s)?; } - if let Some(ref s) = self.s3 { r.write_string_with_tag(26, s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.s1 { w.write_with_tag(10, |w| w.write_string(&**s))?; } + if let Some(ref s) = self.s2 { w.write_with_tag(18, |w| w.write_string(&**s))?; } + if let Some(ref s) = self.s3 { w.write_with_tag(26, |w| w.write_string(&**s))?; } Ok(()) } } @@ -222,7 +223,7 @@ impl<'a> TestBytes<'a> { let mut msg = Self::default(); while !r.is_eof() { match r.next_tag(bytes) { - Ok(10) => msg.b1 = Some(Cow::Borrowed(r.read_bytes(bytes)?)), + Ok(10) => msg.b1 = Some(r.read_bytes(bytes).map(Cow::Borrowed)?), Ok(t) => { r.read_unknown(bytes, t)?; } Err(e) => return Err(e), } @@ -233,11 +234,11 @@ impl<'a> TestBytes<'a> { impl<'a> MessageWrite for TestBytes<'a> { fn get_size(&self) -> usize { - self.b1.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) + self.b1.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.b1 { r.write_bytes_with_tag(10, s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.b1 { w.write_with_tag(10, |w| w.write_bytes(&**s))?; } Ok(()) } } @@ -277,25 +278,25 @@ impl<'a> PerftestData<'a> { impl<'a> MessageWrite for PerftestData<'a> { fn get_size(&self) -> usize { - self.test1.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_repeated_bool.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_repeated_messages.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_optional_messages.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_strings.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_repeated_packed_int32.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_small_bytearrays.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() - + self.test_large_bytearrays.iter().map(|s| 1 + sizeof_var_length(s.get_size())).sum::() + self.test1.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_repeated_bool.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_repeated_messages.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_optional_messages.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_strings.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_repeated_packed_int32.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_small_bytearrays.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + + self.test_large_bytearrays.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() } - fn write_message(&self, r: &mut Writer) -> Result<()> { - for s in &self.test1 { r.write_message_with_tag(10, s)? } - for s in &self.test_repeated_bool { r.write_message_with_tag(18, s)? } - for s in &self.test_repeated_messages { r.write_message_with_tag(26, s)? } - for s in &self.test_optional_messages { r.write_message_with_tag(34, s)? } - for s in &self.test_strings { r.write_message_with_tag(42, s)? } - for s in &self.test_repeated_packed_int32 { r.write_message_with_tag(50, s)? } - for s in &self.test_small_bytearrays { r.write_message_with_tag(58, s)? } - for s in &self.test_large_bytearrays { r.write_message_with_tag(66, s)? } + fn write_message(&self, w: &mut Writer) -> Result<()> { + for s in &self.test1 { w.write_with_tag(10, |w| w.write_message(s))?; } + for s in &self.test_repeated_bool { w.write_with_tag(18, |w| w.write_message(s))?; } + for s in &self.test_repeated_messages { w.write_with_tag(26, |w| w.write_message(s))?; } + for s in &self.test_optional_messages { w.write_with_tag(34, |w| w.write_message(s))?; } + for s in &self.test_strings { w.write_with_tag(42, |w| w.write_message(s))?; } + for s in &self.test_repeated_packed_int32 { w.write_with_tag(50, |w| w.write_message(s))?; } + for s in &self.test_small_bytearrays { w.write_with_tag(58, |w| w.write_message(s))?; } + for s in &self.test_large_bytearrays { w.write_with_tag(66, |w| w.write_message(s))?; } Ok(()) } } diff --git a/codegen/src/parser.rs b/codegen/src/parser.rs index 15ce608d..d2afc840 100644 --- a/codegen/src/parser.rs +++ b/codegen/src/parser.rs @@ -1,7 +1,7 @@ use std::str; use std::path::{Path, PathBuf}; -use types::{Frequency, Field, Message, Enumerator, FileDescriptor, Syntax}; +use types::{Frequency, Field, Message, Enumerator, FileDescriptor, Syntax, FieldType}; use nom::{multispace, digit}; fn is_word(b: u8) -> bool { @@ -61,30 +61,59 @@ named!(frequency, tag!("repeated") => { |_| Frequency::Repeated } | tag!("required") => { |_| Frequency::Required } )); +named!(field_type, + alt!(tag!("int32") => { |_| FieldType::Int32 } | + tag!("int64") => { |_| FieldType::Int64 } | + tag!("uint32") => { |_| FieldType::Uint32 } | + tag!("uint64") => { |_| FieldType::Uint64 } | + tag!("sint32") => { |_| FieldType::Sint32 } | + tag!("sint64") => { |_| FieldType::Sint64 } | + tag!("fixed32") => { |_| FieldType::Fixed32 } | + tag!("sfixed32") => { |_| FieldType::Sfixed32 } | + tag!("fixed64") => { |_| FieldType::Fixed64 } | + tag!("sfixed64") => { |_| FieldType::Sfixed64 } | + tag!("bool") => { |_| FieldType::Bool } | + tag!("string") => { |_| FieldType::String_ } | + tag!("bytes") => { |_| FieldType::Bytes } | + tag!("float") => { |_| FieldType::Float } | + tag!("double") => { |_| FieldType::Double } | + map_field => { |(k, v)| FieldType::Map(Box::new((k, v))) } | + word => { |w| FieldType::Message(w) })); + +named!(map_field<(FieldType, FieldType)>, + do_parse!(tag!("map") >> many0!(br) >> tag!("<") >> many0!(br) >> + key: field_type >> many0!(br) >> tag!(",") >> many0!(br) >> + value: field_type >> tag!(">") >> + ((key, value)) )); + named!(message_field, - do_parse!(frequency: opt!(frequency) >> many1!(br) >> - typ: word >> many1!(br) >> + do_parse!(frequency: opt!(frequency) >> many0!(br) >> + typ: field_type >> many1!(br) >> name: word >> many0!(br) >> tag!("=") >> many0!(br) >> number: map_res!(map_res!(digit, str::from_utf8), str::FromStr::from_str) >> many0!(br) >> key_vals: many0!(key_val) >> tag!(";") >> (Field { - name: name, - frequency: frequency.unwrap_or(Frequency::Optional), - typ: typ, - number: number, - default: key_vals.iter().find(|&&(k, _)| k == "default") - .map(|&(_, v)| v.to_string()), - packed: key_vals.iter().find(|&&(k, _)| k == "packed") - .map(|&(_, v)| str::FromStr::from_str(v) - .expect("Cannot parse Packed value")), - boxed: false, - deprecated: key_vals.iter().find(|&&(k, _)| k == "deprecated") - .map_or(false, |&(_, v)| str::FromStr::from_str(v) - .expect("Cannot parse Deprecated value")), + name: name, + frequency: frequency.unwrap_or(Frequency::Optional), + typ: typ, + number: number, + default: key_vals.iter() + .find(|&&(k, _)| k == "default") + .map(|&(_, v)| v.to_string()), + packed: key_vals.iter() + .find(|&&(k, _)| k == "packed") + .map(|&(_, v)| str::FromStr::from_str(v) + .expect("Cannot parse Packed value")), + boxed: false, + deprecated: key_vals.iter() + .find(|&&(k, _)| k == "deprecated") + .map_or(false, |&(_, v)| str::FromStr::from_str(v) + .expect("Cannot parse Deprecated value")), }) )); enum MessageEvent { Message(Message), + Enumerator(Enumerator), Field(Field), ReservedNums(Vec), ReservedNames(Vec), @@ -95,6 +124,7 @@ named!(message_event, alt!(reserved_nums => { |r| MessageEvent::Re reserved_names => { |r| MessageEvent::ReservedNames(r) } | message_field => { |f| MessageEvent::Field(f) } | message => { |m| MessageEvent::Message(m) } | + enumerator => { |e| MessageEvent::Enumerator(e) } | br => { |_| MessageEvent::Ignore })); named!(message_events<(String, Vec)>, @@ -114,6 +144,7 @@ named!(message, MessageEvent::ReservedNums(r) => msg.reserved_nums = Some(r), MessageEvent::ReservedNames(r) => msg.reserved_names = Some(r), MessageEvent::Message(m) => msg.messages.push(m), + MessageEvent::Enumerator(e) => msg.enums.push(e), MessageEvent::Ignore => (), } } @@ -275,4 +306,26 @@ mod test { assert!(mess.messages.len() == 1); } } + + #[test] + fn test_map() { + let msg = r#"message A + { + optional map b = 1; + }"#; + + let mess = message(msg.as_bytes()); + if let ::nom::IResult::Done(_, mess) = mess { + assert_eq!(1, mess.fields.len()); + match mess.fields[0].typ { + FieldType::Map(ref f) => match &**f { + &(FieldType::String_, FieldType::Int32) => (), + ref f => panic!("Expecting Map found {:?}", f), + }, + ref f => panic!("Expecting map, got {:?}", f), + } + } else { + panic!("Could not parse map message"); + } + } } diff --git a/codegen/src/types.rs b/codegen/src/types.rs index 7581f8d7..e009d708 100644 --- a/codegen/src/types.rs +++ b/codegen/src/types.rs @@ -34,330 +34,377 @@ pub enum Frequency { Required, } -#[derive(Debug, Clone)] -pub struct Field { - pub name: String, - pub frequency: Frequency, - pub typ: String, - pub number: i32, - pub default: Option, - pub packed: Option, - pub boxed: bool, - pub deprecated: bool, +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum FieldType { + Int32, + Int64, + Uint32, + Uint64, + Sint32, + Sint64, + Bool, + Enum(String), + Fixed64, + Sfixed64, + Double, + String_, + Bytes, + Message(String), + Fixed32, + Sfixed32, + Float, + Map(Box<(FieldType, FieldType)>), } -impl Field { - fn packed(&self) -> bool { - self.packed.unwrap_or(false) - } +impl FieldType { fn is_numeric(&self) -> bool { - match &*self.typ { - "int32" | "sint32" | "sfixed32" | - "int64" | "sint64" | "sfixed64" | - "uint32" | "fixed32" | - "uint64" | "fixed64" | - "float" | "double" => true, + match *self { + FieldType::Int32 | + FieldType::Int64 | + FieldType::Uint32 | + FieldType::Uint64 | + FieldType::Sint32 | + FieldType::Sint64 | + FieldType::Fixed64 | + FieldType::Sfixed64 | + FieldType::Double | + FieldType::Fixed32 | + FieldType::Sfixed32 | + FieldType::Float => true, _ => false, } } - /// searches if the message must be boxed - fn is_leaf(&self, leaf_messages: &[String], msgs: &[Message]) -> bool { - match self.frequency { - Frequency::Repeated | Frequency::Required => return true, - Frequency::Optional if !self.is_message(msgs) => true, - _ => { - let typ = match self.typ.rfind('.') { - Some(p) => &self.typ[p + 1..], - None => &self.typ[..], - }; - leaf_messages.iter().any(|m| &*m == &typ) - }, + fn is_cow(&self) -> bool { + match *self { + FieldType::Bytes | FieldType::String_ => true, + _ => false, } } - fn is_message(&self, msgs: &[Message]) -> bool { - self.find_message(msgs).is_some() - } - - fn is_enum(&self, msgs: &[Message]) -> bool { - self.get_type(msgs) == "enum" - } - - fn is_fixed_size(&self, msgs: &[Message]) -> bool { - match self.wire_type_num_non_packed(msgs) { - 1 | 5 => true, + fn is_map(&self) -> bool { + match *self { + FieldType::Map(_) => true, _ => false, } } - fn is_cow(&self) -> bool { - match &*self.typ { - "bytes" | "string" => true, + fn wire_type_num(&self, packed: bool) -> u32 { + if packed { + 2 + } else { + self.wire_type_num_non_packed() + } + } + + fn wire_type_num_non_packed(&self) -> u32 { + match *self { + FieldType::Int32 | FieldType::Sint32 | FieldType::Int64 | + FieldType::Sint64 | FieldType::Uint32 | FieldType::Uint64 | + FieldType::Bool | FieldType::Enum(_) => 0, + FieldType::Fixed64 | FieldType::Sfixed64 | FieldType::Double => 1, + FieldType::String_ | FieldType::Bytes | + FieldType::Message(_) | FieldType::Map(_) => 2, + FieldType::Fixed32 | FieldType::Sfixed32 | FieldType::Float => 5, + } + } + + fn proto_type(&self) -> &str { + match *self { + FieldType::Int32 => "int32", + FieldType::Sint32 => "sint32", + FieldType::Int64 => "int64", + FieldType::Sint64 => "sint64", + FieldType::Uint32 => "uint32", + FieldType::Uint64 => "uint64", + FieldType::Bool => "bool", + FieldType::Enum(_) => "enum", + FieldType::Fixed32 => "fixed32", + FieldType::Sfixed32 => "sfixed32", + FieldType::Float => "float", + FieldType::Fixed64 => "fixed64", + FieldType::Sfixed64 => "sfixed64", + FieldType::Double => "double", + FieldType::String_ => "string", + FieldType::Bytes => "bytes", + FieldType::Message(_) => "message", + FieldType::Map(_) => "map", + } + } + + fn is_fixed_size(&self) -> bool { + match self.wire_type_num_non_packed() { + 1 | 5 => true, _ => false, } } + /// Searches for message corresponding to the current type + /// + /// Searches first basic name then within nested messages fn find_message<'a, 'b>(&'a self, msgs: &'b [Message]) -> Option<&'b Message> { + match *self { + FieldType::Message(ref m) => { + let mut found = match m.rfind('.') { + Some(p) => { + let package = &m[..p]; + let name = &m[(p + 1)..]; + msgs.iter().find(|m| m.package == package && m.name == name) + }, + None => msgs.iter().find(|m2| m2.name == &m[..]), + }; - let mut found = match self.typ.rfind('.') { - Some(p) => { - let package = &self.typ[..p]; - let name = &self.typ[(p + 1)..]; - msgs.iter().find(|m| m.package == package && m.name == name) - }, - None => msgs.iter().find(|m| m.name == self.typ), - }; - - if found.is_none() { - // recursively search into nested messages - for m in msgs { - found = self.find_message(&m.messages); - if found.is_some() { break; } - } - } - - found - } - - fn find_enum<'a, 'b>(&'a self, enums: &'b [Enumerator]) -> Option<&'b Enumerator> { - enums.iter().find(|m| m.name == self.typ) - } + if found.is_none() { + // recursively search into nested messages + for m in msgs { + found = self.find_message(&m.messages); + if found.is_some() { break; } + } + } - fn has_unregular_default(&self, enums: &[Enumerator], msgs: &[Message]) -> bool { - match self.default { - None => false, - Some(ref d) => match &*self.rust_type(msgs) { - "i32" | "i64" | "u32" | "u64" | "f32" | "f64" => d.parse::().unwrap() != 0., - "bool" => *d != "false", - "Cow<'a, str>" => *d != "\"\"", - "Cow<'a, [u8]>" => *d != "[]", - t => self.find_enum(enums).map_or(false, |e| t != e.fields[0].0), - } + found + }, + _ => None, } } fn has_lifetime(&self, msgs: &[Message]) -> bool { - if self.is_cow() { return true; } - self.find_message(msgs).map_or(false, |m| m.has_lifetime(msgs)) + match *self { + FieldType::String_ | FieldType::Bytes => true, // Cow + FieldType::Message(_) => self.find_message(msgs).map_or(false, |m| m.has_lifetime(msgs)), + FieldType::Map(ref m) => { + let &(ref key, ref value) = &**m; + key.has_lifetime(msgs) || value.has_lifetime(msgs) + } + _ => false, + } } fn rust_type(&self, msgs: &[Message]) -> String { - match &*self.typ { - "int32" | "sint32" | "sfixed32" => "i32".to_string(), - "int64" | "sint64" | "sfixed64" => "i64".to_string(), - "uint32" | "fixed32" => "u32".to_string(), - "uint64" | "fixed64" => "u64".to_string(), - "float" => "f32".to_string(), - "double" => "f64".to_string(), - "string" => "Cow<'a, str>".to_string(), - "bytes" => "Cow<'a, [u8]>".to_string(), - t => match self.find_message(msgs) { + match *self { + FieldType::Int32 | FieldType::Sint32 | FieldType::Sfixed32 => "i32".to_string(), + FieldType::Int64 | FieldType::Sint64 | FieldType::Sfixed64 => "i64".to_string(), + FieldType::Uint32 | FieldType::Fixed32 => "u32".to_string(), + FieldType::Uint64 | FieldType::Fixed64 => "u64".to_string(), + FieldType::Double => "f64".to_string(), + FieldType::Float => "f32".to_string(), + FieldType::String_ => "Cow<'a, str>".to_string(), + FieldType::Bytes => "Cow<'a, [u8]>".to_string(), + FieldType::Bool => "bool".to_string(), + FieldType::Enum(ref e) => e.replace(".", "::"), + FieldType::Message(ref msg) => match self.find_message(msgs) { Some(m) => { let lifetime = if m.has_lifetime(msgs) { "<'a>" } else { "" }; - let package = m.package.split('.').filter(|p| !p.is_empty()) - .map(|p| format!("mod_{}::", p)).collect::(); - format!("{}{}{}", package, m.name, lifetime) + format!("{}{}{}", m.get_modules(), m.name, lifetime) }, - None => t.replace(".", "::"), // enum + None => unreachable!(format!("Could not find message {}", msg)), + }, + FieldType::Map(ref t) => { + let &(ref key, ref value) = &**t; + format!("HashMap<{}, {}>", key.rust_type(msgs), value.rust_type(msgs)) } } } - fn wire_type_num(&self, msgs: &[Message]) -> u32 { - if self.packed() { - 2 - } else { - self.wire_type_num_non_packed(msgs) + /// Returns the relevant function to read the data, both for regular and Cow wrapped + fn read_fn(&self, msgs: &[Message]) -> (String, String) { + match *self { + FieldType::Message(ref msg) => match self.find_message(msgs) { + Some(m) => { + let m = format!("r.read_message(bytes, {}{}::from_reader)", m.get_modules(), m.name); + (m.clone(), m) + } + None => unreachable!(format!("Could not find message {}", msg)), + }, + FieldType::Map(_) => unreachable!("There should be a special case for maps"), + FieldType::String_ | FieldType::Bytes => { + let m = format!("r.read_{}(bytes)", self.proto_type()); + let cow = format!("{}.map(Cow::Borrowed)", m); + (m, cow) + }, + _ => { + let m = format!("r.read_{}(bytes)", self.proto_type()); + (m.clone(), m) + } } } - fn wire_type_num_non_packed(&self, msgs: &[Message]) -> u32 { - match &*self.typ { - "int32" | "sint32" | "int64" | "sint64" | - "uint32" | "uint64" | "bool" | "enum" => 0, - "fixed64" | "sfixed64" | "double" => 1, - "fixed32" | "sfixed32" | "float" => 5, - "string" | "bytes" => 2, - _ => if self.is_message(msgs) { 2 } else { 0 /* enum */ } - } - } + fn get_size(&self, s: &str) -> String { + match *self { + FieldType::Int32 | FieldType::Sint32 | FieldType::Int64 | + FieldType::Sint64 | FieldType::Uint32 | FieldType::Uint64 | + FieldType::Bool | FieldType::Enum(_) => format!("sizeof_varint(*{} as u64)", s), + + FieldType::Fixed64 | FieldType::Sfixed64 | FieldType::Double => "8".to_string(), + FieldType::Fixed32 | FieldType::Sfixed32 | FieldType::Float => "4".to_string(), + + FieldType::String_ | FieldType::Bytes => format!("sizeof_len({}.len())", s), - fn get_type(&self, msgs: &[Message]) -> &str { - match &*self.typ { - "int32" | "sint32" | "int64" | "sint64" | - "uint32" | "uint64" | "bool" | "fixed64" | - "sfixed64" | "double" | "fixed32" | "sfixed32" | - "float" | "bytes" | "string" => &self.typ, - _ => if self.is_message(msgs) { "message" } else { "enum" }, + FieldType::Message(_) => format!("sizeof_len({}.get_size())", s), + + FieldType::Map(ref m) => { + let &(ref k, ref v) = &**m; + format!("2 + {} + {}", k.get_size("k"), v.get_size("v")) + } } } - fn read_fn(&self, msgs: &[Message]) -> String { - match self.find_message(msgs) { - Some(m) if m.package.is_empty()=> { - format!("read_message(bytes, {}::from_reader)", m.name) + fn get_write(&self, s: &str, boxed: bool) -> String { + match *self { + FieldType::Enum(_) => format!("write_enum(*{} as i32)", s), + + FieldType::Int32 | FieldType::Sint32 | FieldType::Int64 | + FieldType::Sint64 | FieldType::Uint32 | FieldType::Uint64 | + FieldType::Bool | + FieldType::Fixed64 | FieldType::Sfixed64 | FieldType::Double | + FieldType::Fixed32 | FieldType::Sfixed32 | FieldType::Float => { + format!("write_{}(*{})", self.proto_type(), s) }, - Some(m) => { - format!("read_message(bytes, {}{}::from_reader)", - m.package.split('.').map(|p| format!("mod_{}::", p)).collect::(), m.name) - } - None => { - format!("read_{}(bytes)", self.get_type(msgs)) + + FieldType::String_ => format!("write_string(&**{})", s), + FieldType::Bytes => format!("write_bytes(&**{})", s), + + FieldType::Message(_) if boxed => format!("write_message(&**{})", s), + FieldType::Message(_) => format!("write_message({})", s), + + FieldType::Map(ref m) => { + let &(ref k, ref v) = &**m; + format!("write_map({}, {}, |w| w.{}, {}, |w| w.{})", + self.get_size(""), + tag(1, k, false), k.get_write("k", false), + tag(2, v, false), v.get_write("v", false)) } } } +} + +#[derive(Debug, Clone)] +pub struct Field { + pub name: String, + pub frequency: Frequency, + pub typ: FieldType, + pub number: i32, + pub default: Option, + pub packed: Option, + pub boxed: bool, + pub deprecated: bool, +} - fn tag(&self, msgs: &[Message]) -> u32 { - (self.number as u32) << 3 | self.wire_type_num(msgs) +impl Field { + fn packed(&self) -> bool { + self.packed.unwrap_or(false) } - fn write_definition(&self, w: &mut W, msgs: &[Message]) -> Result<()> { - match self.frequency { - Frequency::Optional => { - if self.boxed { - writeln!(w, " pub {}: Option>,", self.name, self.rust_type(msgs))? - } else { - if self.default.is_none() { - writeln!(w, " pub {}: Option<{}>,", self.name, self.rust_type(msgs))? - } else { - writeln!(w, " pub {}: {},", self.name, self.rust_type(msgs))? + /// searches if the message must be boxed + fn is_leaf(&self, leaf_messages: &[String]) -> bool { + match self.typ { + FieldType::Message(ref s) => { + match self.frequency { + Frequency::Repeated | Frequency::Required => true, + _ => { + let typ = match s.rfind('.') { + Some(p) => &s[p + 1..], + None => &s[..], + }; + leaf_messages.iter().any(|m| &*m == &typ) } } } - Frequency::Repeated => writeln!(w, " pub {}: Vec<{}>,", self.name, self.rust_type(msgs))?, - Frequency::Required => writeln!(w, " pub {}: {},", self.name, self.rust_type(msgs))?, + _ => true, } - Ok(()) } - fn write_match_tag_owned(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + fn tag(&self) -> u32 { + tag(self.number as u32, &self.typ, self.packed()) + } + + fn write_definition(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + write!(w, " pub {}: ", self.name)?; match self.frequency { - Frequency::Optional => { - if self.boxed { - writeln!(w, "Ok({}) => msg.{} = Some(Box::new(r.{}?)),", - self.tag(msgs), self.name, self.read_fn(msgs))? - } else { - if self.default.is_none() { - writeln!(w, "Ok({}) => msg.{} = Some(r.{}?),", - self.tag(msgs), self.name, self.read_fn(msgs))? - } else { - writeln!(w, "Ok({}) => msg.{} = r.{}?,", - self.tag(msgs), self.name, self.read_fn(msgs))? - } - } - } - Frequency::Repeated => { - if self.packed() { - writeln!(w, "Ok({}) => msg.{} = r.read_packed(bytes, |r, bytes| r.{})?,", - self.tag(msgs), self.name, self.read_fn(msgs))? - } else { - writeln!(w, "Ok({}) => msg.{}.push(r.{}?),", - self.tag(msgs), self.name, self.read_fn(msgs))? - } - } - Frequency::Required => { - writeln!(w, "Ok({}) => msg.{} = r.{}?,", - self.tag(msgs), self.name, self.read_fn(msgs))? - } + Frequency::Optional if self.boxed => writeln!(w, "Option>,", self.typ.rust_type(msgs))?, + Frequency::Optional if self.default.is_some() => writeln!(w, "{},", self.typ.rust_type(msgs))?, + Frequency::Optional => writeln!(w, "Option<{}>,", self.typ.rust_type(msgs))?, + Frequency::Repeated => writeln!(w, "Vec<{}>,", self.typ.rust_type(msgs))?, + Frequency::Required => writeln!(w, "{},", self.typ.rust_type(msgs))?, } Ok(()) } - fn write_match_tag_borrowed(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + fn write_match_tag(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + + // special case for FieldType::Map: destructure tuple before inserting in HashMap + if let FieldType::Map(ref m) = self.typ { + let &(ref key, ref value) = &**m; + + writeln!(w, " Ok({}) => {{", self.tag())?; + writeln!(w, " let (key, value) = r.read_map(bytes, |r, bytes| {}, |r, bytes| {})?;", + key.read_fn(msgs).1, value.read_fn(msgs).1)?; + writeln!(w, " msg.{}.insert(key, value);", self.name)?; + writeln!(w, " }}")?; + return Ok(()); + } + + let (val, val_cow) = self.typ.read_fn(msgs); + let name = &self.name; + write!(w, " Ok({}) => ", self.tag())?; match self.frequency { - Frequency::Optional => { - if self.boxed { - writeln!(w, "Ok({}) => msg.{} = Some(Box::new(Cow::Borrowed(r.{}?))),", - self.tag(msgs), self.name, self.read_fn(msgs))? - } else { - if self.default.is_none() { - writeln!(w, "Ok({}) => msg.{} = Some(Cow::Borrowed(r.{}?)),", - self.tag(msgs), self.name, self.read_fn(msgs))? - } else { - writeln!(w, "Ok({}) => msg.{} = Cow::Borrowed(r.{}?),", - self.tag(msgs), self.name, self.read_fn(msgs))? - } - } - } - Frequency::Repeated => { - if self.packed() { - writeln!(w, "Ok({}) => msg.{} = r.read_packed(bytes, |r, bytes| r.{})?,", - self.tag(msgs), self.name, self.read_fn(msgs))? - } else { - writeln!(w, "Ok({}) => msg.{}.push(Cow::Borrowed(r.{}?)),", - self.tag(msgs), self.name, self.read_fn(msgs))? - } - } - Frequency::Required => { - writeln!(w, "Ok({}) => msg.{} = Cow::Borrowed(r.{}?),", - self.tag(msgs), self.name, self.read_fn(msgs))? - } + Frequency::Required => writeln!(w, "msg.{} = {}?,", name, val_cow)?, + Frequency::Optional if self.boxed => writeln!(w, "msg.{} = Some(Box::new({}?)),", name, val)?, + Frequency::Optional if self.default.is_some() => writeln!(w, "msg.{} = {}?,", name, val_cow)?, + Frequency::Optional => writeln!(w, "msg.{} = Some({}?),", name, val_cow)?, + Frequency::Repeated if self.packed() => { + writeln!(w, "msg.{} = r.read_packed(bytes, |r, bytes| {})?,", name, val)?; + }, + Frequency::Repeated => writeln!(w, "msg.{}.push({}?),", name, val)?, } Ok(()) } - fn write_get_size(&self, w: &mut W, msgs: &[Message], is_first: bool) -> Result<()> { + fn write_get_size(&self, w: &mut W, is_first: bool) -> Result<()> { if is_first { write!(w, " ")?; } else { write!(w, " + ")?; } + let tag_size = sizeof_varint(self.tag()); match self.frequency { - Frequency::Required => { - self.write_inner_get_size(w, msgs, &format!("self.{}", self.name), "")?; - writeln!(w, "")?; + Frequency::Required if self.typ.is_map() => { + writeln!(w, "self.{}.iter().map(|(k, v)| {} + sizeof_len({})).sum::()", + self.name, tag_size, self.typ.get_size(""))?; } + Frequency::Required => writeln!(w, "{} + {}", tag_size, self.typ.get_size(&format!("&self.{}", self.name)))?, Frequency::Optional => { match self.default.as_ref() { None => { - if self.is_fixed_size(msgs) { - write!(w, "self.{}.as_ref().map_or(0, |_| ", self.name)?; + write!(w, "self.{}.as_ref().map_or(0, ", self.name)?; + if self.typ.is_fixed_size() { + writeln!(w, "|_| {} + {})", tag_size, self.typ.get_size(""))?; } else { - write!(w, "self.{}.as_ref().map_or(0, |m| ", self.name)?; + writeln!(w, "|m| {} + {})", tag_size, self.typ.get_size("m"))?; } - self.write_inner_get_size(w, msgs, "m", "*")?; - writeln!(w, ")")?; } Some(d) => { - write!(w, "if self.{} == {} {{ 0 }} else {{", self.name, d)?; - self.write_inner_get_size(w, msgs, &format!("self.{}", self.name), "")?; - writeln!(w, "}}")?; + write!(w, "if self.{} == {} {{ 0 }} else {{ {} + {} }}", + self.name, d, tag_size, self.typ.get_size(&format!("&self.{}", self.name)))?; } } } Frequency::Repeated => { - let tag_size = sizeof_varint(self.tag(msgs)); - let get_type = self.get_type(msgs); - let as_enum = if self.is_enum(msgs) { " as i32" } else { "" }; if self.packed() { - write!(w, "if self.{}.is_empty() {{ 0 }} else {{ ", self.name)?; - match self.wire_type_num_non_packed(msgs) { - 0 => write!(w, "{} + sizeof_var_length(self.{}.iter().map(|s| sizeof_{}(*s{})).sum::())", - tag_size, self.name, get_type, as_enum)?, - 1 => write!(w, "{} + sizeof_var_length(self.{}.len() * 8)", tag_size, self.name)?, - 5 => write!(w, "{} + sizeof_var_length(self.{}.len() * 4)", tag_size, self.name)?, - 2 => { - let len = if self.is_message(msgs) { "get_size" } else { "len" }; - write!(w, "{} + sizeof_var_length(self.{}.iter().map(|s| sizeof_var_length(s.{}())).sum::())", - tag_size, self.name, len)?; - } - e => panic!("expecting wire type number, got: {}", e), + write!(w, "if self.{}.is_empty() {{ 0 }} else {{ {} + ", self.name, tag_size)?; + match self.typ.wire_type_num_non_packed() { + 1 => writeln!(w, "sizeof_len(self.{}.len() * 8) }}", self.name)?, + 5 => writeln!(w, "sizeof_len(self.{}.len() * 4) }}", self.name)?, + _ => writeln!(w, "sizeof_len(self.{}.iter().map(|s| {}).sum::()) }}", + self.name, self.typ.get_size("s"))?, } - writeln!(w, " }}")?; } else { - match self.wire_type_num_non_packed(msgs) { - 0 => writeln!(w, "self.{}.iter().map(|s| {} + sizeof_{}(*s{})).sum::()", - self.name, tag_size, get_type, as_enum)?, + match self.typ.wire_type_num_non_packed() { 1 => writeln!(w, "({} + 8) * self.{}.len()", tag_size, self.name)?, 5 => writeln!(w, "({} + 4) * self.{}.len()", tag_size, self.name)?, - 2 => { - let len = if self.is_message(msgs) { "get_size" } else { "len" }; - writeln!(w, "self.{}.iter().map(|s| {} + sizeof_var_length(s.{}())).sum::()", - self.name, tag_size, len)?; - } - e => panic!("expecting wire type number, got: {}", e), + _ => writeln!(w, "self.{}.iter().map(|s| {} + {}).sum::()", + self.name, tag_size, self.typ.get_size("s"))?, } } } @@ -365,79 +412,33 @@ impl Field { Ok(()) } - fn write_inner_get_size(&self, w: &mut W, msgs: &[Message], s: &str, as_ref: &str) -> Result<()> { - let tag_size = sizeof_varint(self.tag(msgs)); - match self.wire_type_num_non_packed(msgs) { - 0 => { - let get_type = self.get_type(msgs); - let as_enum = if self.is_enum(msgs) { " as i32" } else { "" }; - write!(w, "{} + sizeof_{}({}{}{})", tag_size, get_type, as_ref, s, as_enum)? - }, - 1 => write!(w, "{} + 8", tag_size)?, - 5 => write!(w, "{} + 4", tag_size)?, - 2 => { - let len = if self.is_message(msgs) { "get_size" } else { "len" }; - if self.packed() { - write!(w, "if s.is_empty() {{ 0 }} else {{ {} + sizeof_var_length({}.{}()) }}", tag_size, s, len)?; - } else { - write!(w, "{} + sizeof_var_length({}.{}())", tag_size, s, len)?; - } - } - e => panic!("expecting wire type number, got: {}", e), - } - Ok(()) - } - - fn write_write(&self, w: &mut W, msgs: &[Message]) -> Result<()> { - let tag = self.tag(msgs); - let use_ref = self.wire_type_num_non_packed(msgs) == 2; - let get_type = self.get_type(msgs); - let as_enum = if self.is_enum(msgs) { " as i32" } else { "" }; + fn write_write(&self, w: &mut W) -> Result<()> { match self.frequency { + Frequency::Required if self.typ.is_map() => { + writeln!(w, " for (k, v) in self.{}.iter() {{ w.write_with_tag({}, |w| w.{})?; }}", + self.name, self.tag(), self.typ.get_write("", false))?; + } Frequency::Required => { - let r = if use_ref { "&" } else { "" }; - writeln!(w, " r.write_{}_with_tag({}, {}self.{}{})?;", get_type, tag, r, self.name, as_enum)?; - }, - Frequency::Optional => { - let r = if use_ref { - if self.boxed { "&**" } else { "" } - } else { - "*" - }; - match self.default.as_ref() { - None => { - writeln!(w, " if let Some(ref s) = self.{} {{ r.write_{}_with_tag({}, {}s{})?; }}", - self.name, get_type, tag, r, as_enum)?; - }, - Some(d) => { - writeln!(w, " if self.{} != {} {{ r.write_{}_with_tag({}, self.{0}{})?; }}", - self.name, d, get_type, tag, as_enum)?; - } + writeln!(w, " w.write_with_tag({}, |w| w.{})?;", + self.tag(), self.typ.get_write(&format!("&self.{}", self.name), self.boxed))?; + } + Frequency::Optional => match self.default.as_ref() { + None => { + writeln!(w, " if let Some(ref s) = self.{} {{ w.write_with_tag({}, |w| w.{})?; }}", + self.name, self.tag(), self.typ.get_write("s", self.boxed))?; + }, + Some(d) => { + writeln!(w, " if self.{} != {} {{ w.write_with_tag({}, |w| w.{})?; }}", + self.name, d, self.tag(), self.typ.get_write(&format!("&self.{}", self.name), self.boxed))?; } + }, + Frequency::Repeated if self.packed() => { + writeln!(w, " w.write_packed_with_tag({}, &self.{}, |w, m| w.{}, &|m| {})?;", + self.tag(), self.name, self.typ.get_write("m", self.boxed), self.typ.get_size("m"))? } Frequency::Repeated => { - if self.packed() { - match get_type { - "message" => { - writeln!(w, " r.write_packed_repeated_field_with_tag({}, &self.{}, |r, m| r.write_{}({}m{}), \ - &|m| sizeof_var_length(m.get_size()))?;", - tag, self.name, get_type, if use_ref { "" } else { "*" }, as_enum)? - }, - "bytes" | "string" => { - writeln!(w, " r.write_packed_repeated_field_with_tag({}, &self.{}, |r, m| r.write_{}({}m{}), \ - &|m| sizeof_var_length(m.len()))?;", - tag, self.name, get_type, if use_ref { "" } else { "*" }, as_enum)? - }, - t => { - writeln!(w, " r.write_packed_repeated_field_with_tag({}, &self.{}, |r, m| r.write_{}({}m{}), \ - &|m| sizeof_{}(*m))?;", - tag, self.name, get_type, if use_ref { "" } else { "*" }, as_enum, t)? - }, - } - } else { - writeln!(w, " for s in &self.{} {{ r.write_{}_with_tag({}, {}s{})? }}", - self.name, get_type, tag, if use_ref { "" } else { "*" }, as_enum)?; - } + writeln!(w, " for s in &self.{} {{ w.write_with_tag({}, |w| w.{})?; }}", + self.name, self.tag(), self.typ.get_write("s", self.boxed))?; } } Ok(()) @@ -446,31 +447,66 @@ impl Field { #[derive(Debug, Clone, Default)] pub struct Message { - pub messages: Vec, pub name: String, pub fields: Vec, pub reserved_nums: Option>, pub reserved_names: Option>, pub imported: bool, - pub package: String, // package from imports + nested classes + pub package: String, // package from imports + nested items + pub messages: Vec, // nested messages + pub enums: Vec, // nested enums } impl Message { - - fn is_leaf(&self, leaf_messages: &[String], msgs: &[Message]) -> bool { - self.imported || self.fields.iter().all(|f| f.is_leaf(leaf_messages, msgs) || f.deprecated) + fn is_leaf(&self, leaf_messages: &[String]) -> bool { + self.imported || self.fields.iter().all(|f| f.is_leaf(leaf_messages) || f.deprecated) } fn has_lifetime(&self, msgs: &[Message]) -> bool { - self.fields.iter().any(|f| f.typ != self.name && f.has_lifetime(msgs)) + self.fields.iter().any(|f| match f.typ { + FieldType::Message(ref m) if &m[..] == self.name => false, + ref t => t.has_lifetime(msgs), + }) } - fn write_definition(&self, w: &mut W, enums: &[Enumerator], msgs: &[Message]) -> Result<()> { - if self.can_derive_default(enums, msgs) { - writeln!(w, "#[derive(Debug, Default, PartialEq, Clone)]")?; - } else { - writeln!(w, "#[derive(Debug, PartialEq, Clone)]")?; + fn get_modules(&self) -> String { + self.package + .split('.').filter(|p| !p.is_empty()) + .map(|p| format!("mod_{}::", p)) + .collect() + } + + fn write(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + println!("Writing message {}{}", self.get_modules(), self.name); + writeln!(w, "")?; + self.write_definition(w, msgs)?; + writeln!(w, "")?; + self.write_impl_message_read(w, msgs)?; + writeln!(w, "")?; + self.write_impl_message_write(w, msgs)?; + + if !self.messages.is_empty() { + writeln!(w, "")?; + writeln!(w, "pub mod mod_{} {{", self.name)?; + writeln!(w, "")?; + if self.messages.iter().any(|m| m.fields.iter().any(|f| f.typ.is_cow())) { + writeln!(w, "use std::borrow::Cow;")?; + } + if self.messages.iter().any(|m| m.fields.iter().any(|f| f.typ.is_map())) { + writeln!(w, "use std::collections::HashMap;")?; + } + writeln!(w, "use super::*;")?; + for m in &self.messages { + m.write(w, msgs)?; + } + writeln!(w, "")?; + writeln!(w, "}}")?; } + Ok(()) + } + + fn write_definition(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + writeln!(w, "#[derive(Debug, Default, PartialEq, Clone)]")?; if self.has_lifetime(msgs) { writeln!(w, "pub struct {}<'a> {{", self.name)?; } else { @@ -483,11 +519,7 @@ impl Message { Ok(()) } - fn can_derive_default(&self, enums: &[Enumerator], msgs: &[Message]) -> bool { - self.fields.iter().all(|f| f.deprecated || !f.has_unregular_default(enums, msgs)) - } - - fn write_impl_message_read(&self, w: &mut W, enums: &[Enumerator], msgs: &[Message]) -> Result<()> { + fn write_impl_message_read(&self, w: &mut W, msgs: &[Message]) -> Result<()> { if self.has_lifetime(msgs) { writeln!(w, "impl<'a> {}<'a> {{", self.name)?; writeln!(w, " pub fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result {{")?; @@ -499,12 +531,7 @@ impl Message { writeln!(w, " while !r.is_eof() {{")?; writeln!(w, " match r.next_tag(bytes) {{")?; for f in self.fields.iter().filter(|f| !f.deprecated) { - write!(w, " ")?; - if f.is_cow() { - f.write_match_tag_borrowed(w, msgs)?; - } else { - f.write_match_tag_owned(w, msgs)?; - } + f.write_match_tag(w, msgs)?; } writeln!(w, " Ok(t) => {{ r.read_unknown(bytes, t)?; }}")?; writeln!(w, " Err(e) => return Err(e),")?; @@ -514,10 +541,9 @@ impl Message { writeln!(w, " }}")?; writeln!(w, "}}")?; - if !self.can_derive_default(enums, msgs) { -// writeln!(w, "")?; -// self.write_impl_default(w, msgs)?; - } + // TODO: write impl default when special default? + // alternatively set the default value directly when reading + Ok(()) } @@ -527,51 +553,32 @@ impl Message { } else { writeln!(w, "impl MessageWrite for {} {{", self.name)?; } - self.write_get_size(w, msgs)?; + self.write_get_size(w)?; writeln!(w, "")?; - self.write_write_message(w, msgs)?; + self.write_write_message(w)?; writeln!(w, "}}")?; Ok(()) } - fn write_get_size(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + fn write_get_size(&self, w: &mut W) -> Result<()> { writeln!(w, " fn get_size(&self) -> usize {{")?; for (i, f) in self.fields.iter().filter(|f| !f.deprecated).enumerate() { - f.write_get_size(w, msgs, i == 0)?; + f.write_get_size(w, i == 0)?; } writeln!(w, " }}")?; Ok(()) } - fn write_write_message(&self, w: &mut W, msgs: &[Message]) -> Result<()> { - writeln!(w, " fn write_message(&self, r: &mut Writer) -> Result<()> {{")?; + fn write_write_message(&self, w: &mut W) -> Result<()> { + writeln!(w, " fn write_message(&self, w: &mut Writer) -> Result<()> {{")?; for f in self.fields.iter().filter(|f| !f.deprecated) { - f.write_write(w, msgs)?; + f.write_write(w)?; } writeln!(w, " Ok(())")?; writeln!(w, " }}")?; Ok(()) } -// fn write_impl_default(&self, w: &mut W, msgs: &[Message]) -> IoResult<()> { -// writeln!(w, "impl Default for {} {{", self.name)?; -// writeln!(w, " fn default() -> Self {{")?; -// writeln!(w, " {} {{", self.name)?; -// for f in self.fields.iter().filter(|f| !f.deprecated) { -// match f.default { -// None => writeln!(w, " {}::default(),", f.rust_type())?, -// Some(ref d) => if msgs.iter().any(|m| m.name == f.typ) { -// writeln!(w, " {}: {},", f.name, d)? -// } else { -// writeln!(w, " {}: {}::{},", f.name, f.typ, d)? -// } -// } -// } -// writeln!(w, " }}")?; -// writeln!(w, " }}")?; -// writeln!(w, "}}") -// } - fn sanity_checks(&self) -> Result<()> { // checks for reserved fields for f in &self.fields { @@ -596,6 +603,22 @@ impl Message { } } + /// Searches for a matching message in all message + /// + /// If none is found, + fn set_enums(&mut self, msgs: &[Message]) { + for f in &mut self.fields { + if f.typ.find_message(&msgs).is_none() { + if let FieldType::Message(m) = f.typ.clone() { + f.typ = FieldType::Enum(m); + f.boxed = false; + } + } + } + for m in &mut self.messages { + m.set_enums(msgs); + } + } } #[derive(Debug, Clone)] @@ -661,6 +684,7 @@ impl FileDescriptor { break_cycles(&mut desc.messages, &mut leaf_messages); desc.sanity_checks(in_file.as_ref())?; + desc.set_enums(); desc.set_defaults(); let name = in_file.as_ref().file_name().and_then(|e| e.to_str()).unwrap(); @@ -724,6 +748,14 @@ impl FileDescriptor { } fn set_defaults(&mut self) { + // set map fields as required (they are equivalent to repeated message) + for m in &mut self.messages { + for f in &mut m.fields { + if let FieldType::Map(_) = f.typ { + f.frequency = Frequency::Required; + } + } + } // if proto3, then changes several defaults if let Syntax::Proto3 = self.syntax { for m in &mut self.messages { @@ -733,7 +765,7 @@ impl FileDescriptor { f.packed = Some(true); } } - if f.default.is_none() && f.is_numeric() { + if f.default.is_none() && f.typ.is_numeric() { f.default = Some("0".to_string()); } } @@ -741,6 +773,15 @@ impl FileDescriptor { } } + fn set_enums(&mut self) { + // this is very inefficient but we don't care ... + let msgs = self.messages.clone(); + + for m in &mut self.messages { + m.set_enums(&msgs); + } + } + fn write(&self, w: &mut W, filename: &str) -> Result<()> { println!("Found {} messages, and {} enums", self.messages.len(), self.enums.len()); self.write_headers(w, filename)?; @@ -766,9 +807,12 @@ impl FileDescriptor { fn write_uses(&self, w: &mut W) -> Result<()> { writeln!(w, "use std::io::{{Write}};")?; - if self.messages.iter().any(|m| m.has_lifetime(&self.messages)) { + if self.messages.iter().any(|m| m.fields.iter().any(|f| f.typ.is_cow())) { writeln!(w, "use std::borrow::Cow;")?; } + if self.messages.iter().any(|m| m.fields.iter().any(|f| f.typ.is_map())) { + writeln!(w, "use std::collections::HashMap;")?; + } writeln!(w, "use quick_protobuf::{{MessageWrite, BytesReader, Writer, Result}};")?; writeln!(w, "use quick_protobuf::sizeofs::*;")?; Ok(()) @@ -829,31 +873,7 @@ impl FileDescriptor { fn write_messages(&self, w: &mut W) -> Result<()> { for m in self.messages.iter().filter(|m| !m.imported) { - println!("Writing message {}", m.name); - writeln!(w, "")?; - m.write_definition(w, &self.enums, &self.messages)?; - writeln!(w, "")?; - m.write_impl_message_read(w, &self.enums, &self.messages)?; - writeln!(w, "")?; - m.write_impl_message_write(w, &self.messages)?; - - if !m.messages.is_empty() { - writeln!(w, "")?; - writeln!(w, "pub mod mod_{} {{", m.name)?; - writeln!(w, "")?; - writeln!(w, "use super::*;")?; - for m_sub in &m.messages { - println!("Writing message mod_{}::{}", m.name, m_sub.name); - writeln!(w, "")?; - m_sub.write_definition(w, &self.enums, &self.messages)?; - writeln!(w, "")?; - m_sub.write_impl_message_read(w, &self.enums, &self.messages)?; - writeln!(w, "")?; - m_sub.write_impl_message_write(w, &self.messages)?; - } - writeln!(w, "")?; - writeln!(w, "}}")?; - } + m.write(w, &self.messages)?; } Ok(()) } @@ -863,6 +883,9 @@ fn get_imported_path, Q: AsRef>(in_file: P, import: Q) -> P in_file.as_ref().parent().map_or_else(|| import.as_ref().into(), |p| p.join(import.as_ref())) } +/// Breaks cycles by adding boxes when necessary +/// +/// Cycles means one Message calls itself at some point fn break_cycles(messages: &mut [Message], leaf_messages: &mut Vec) { for m in messages.iter_mut() { @@ -876,7 +899,7 @@ fn break_cycles(messages: &mut [Message], leaf_messages: &mut Vec) { let len = undef_messages.len(); let mut new_undefs = Vec::new(); for i in undef_messages { - if messages[i].is_leaf(&leaf_messages, &messages) { + if messages[i].is_leaf(&leaf_messages) { leaf_messages.push(message_names[i].clone()) } else { new_undefs.push(i); @@ -889,7 +912,7 @@ fn break_cycles(messages: &mut [Message], leaf_messages: &mut Vec) { { let mut m = messages[k].clone(); for f in m.fields.iter_mut() { - if !f.is_leaf(&leaf_messages, &messages) { + if !f.is_leaf(&leaf_messages) { f.boxed = true; } } @@ -899,3 +922,7 @@ fn break_cycles(messages: &mut [Message], leaf_messages: &mut Vec) { } } +/// Calculates the tag value +fn tag(number: u32, typ: &FieldType, packed: bool) -> u32 { + number << 3 | typ.wire_type_num(packed) +} diff --git a/examples/codegen/data_types.proto b/examples/codegen/data_types.proto index bd4fb31d..0412799a 100644 --- a/examples/codegen/data_types.proto +++ b/examples/codegen/data_types.proto @@ -33,6 +33,7 @@ message FooMessage { optional a.b.ImportedMessage f_imported = 21; optional BazMessage f_baz = 22; optional BazMessage.Nested f_nested = 23; + map f_map = 24; } message BazMessage { diff --git a/examples/codegen/data_types.rs b/examples/codegen/data_types.rs index 9a3c6474..d6b77fa4 100644 --- a/examples/codegen/data_types.rs +++ b/examples/codegen/data_types.rs @@ -6,6 +6,7 @@ use std::io::{Write}; use std::borrow::Cow; +use std::collections::HashMap; use quick_protobuf::{MessageWrite, BytesReader, Writer, Result}; use quick_protobuf::sizeofs::*; @@ -54,11 +55,11 @@ impl BarMessage { impl MessageWrite for BarMessage { fn get_size(&self) -> usize { - 1 + sizeof_int32(self.b_required_int32) + 1 + sizeof_varint(*&self.b_required_int32 as u64) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - r.write_int32_with_tag(8, self.b_required_int32)?; + fn write_message(&self, w: &mut Writer) -> Result<()> { + w.write_with_tag(8, |w| w.write_int32(*&self.b_required_int32))?; Ok(()) } } @@ -88,6 +89,7 @@ pub struct FooMessage<'a> { pub f_imported: Option, pub f_baz: Option, pub f_nested: Option, + pub f_map: HashMap, i32>, } impl<'a> FooMessage<'a> { @@ -109,8 +111,8 @@ impl<'a> FooMessage<'a> { Ok(101) => msg.f_sfixed32 = Some(r.read_sfixed32(bytes)?), Ok(105) => msg.f_double = Some(r.read_double(bytes)?), Ok(117) => msg.f_float = Some(r.read_float(bytes)?), - Ok(122) => msg.f_bytes = Some(Cow::Borrowed(r.read_bytes(bytes)?)), - Ok(130) => msg.f_string = Some(Cow::Borrowed(r.read_string(bytes)?)), + Ok(122) => msg.f_bytes = Some(r.read_bytes(bytes).map(Cow::Borrowed)?), + Ok(130) => msg.f_string = Some(r.read_string(bytes).map(Cow::Borrowed)?), Ok(138) => msg.f_self_message = Some(Box::new(r.read_message(bytes, FooMessage::from_reader)?)), Ok(146) => msg.f_bar_message = Some(r.read_message(bytes, BarMessage::from_reader)?), Ok(152) => msg.f_repeated_int32.push(r.read_int32(bytes)?), @@ -118,6 +120,10 @@ impl<'a> FooMessage<'a> { Ok(170) => msg.f_imported = Some(r.read_message(bytes, mod_a::mod_b::ImportedMessage::from_reader)?), Ok(178) => msg.f_baz = Some(r.read_message(bytes, BazMessage::from_reader)?), Ok(186) => msg.f_nested = Some(r.read_message(bytes, mod_BazMessage::Nested::from_reader)?), + Ok(194) => { + let (key, value) = r.read_map(bytes, |r, bytes| r.read_string(bytes).map(Cow::Borrowed), |r, bytes| r.read_int32(bytes))?; + msg.f_map.insert(key, value); + } Ok(t) => { r.read_unknown(bytes, t)?; } Err(e) => return Err(e), } @@ -128,55 +134,57 @@ impl<'a> FooMessage<'a> { impl<'a> MessageWrite for FooMessage<'a> { fn get_size(&self) -> usize { - self.f_int32.as_ref().map_or(0, |m| 1 + sizeof_int32(*m)) - + self.f_int64.as_ref().map_or(0, |m| 1 + sizeof_int64(*m)) - + self.f_uint32.as_ref().map_or(0, |m| 1 + sizeof_uint32(*m)) - + self.f_uint64.as_ref().map_or(0, |m| 1 + sizeof_uint64(*m)) - + self.f_sint32.as_ref().map_or(0, |m| 1 + sizeof_sint32(*m)) - + self.f_sint64.as_ref().map_or(0, |m| 1 + sizeof_sint64(*m)) - + self.f_bool.as_ref().map_or(0, |m| 1 + sizeof_bool(*m)) - + self.f_FooEnum.as_ref().map_or(0, |m| 1 + sizeof_enum(*m as i32)) + self.f_int32.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + + self.f_int64.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + + self.f_uint32.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + + self.f_uint64.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + + self.f_sint32.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + + self.f_sint64.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + + self.f_bool.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + + self.f_FooEnum.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + self.f_fixed64.as_ref().map_or(0, |_| 1 + 8) + self.f_sfixed64.as_ref().map_or(0, |_| 1 + 8) + self.f_fixed32.as_ref().map_or(0, |_| 1 + 4) + self.f_sfixed32.as_ref().map_or(0, |_| 1 + 4) + self.f_double.as_ref().map_or(0, |_| 1 + 8) + self.f_float.as_ref().map_or(0, |_| 1 + 4) - + self.f_bytes.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.len())) - + self.f_string.as_ref().map_or(0, |m| 2 + sizeof_var_length(m.len())) - + self.f_self_message.as_ref().map_or(0, |m| 2 + sizeof_var_length(m.get_size())) - + self.f_bar_message.as_ref().map_or(0, |m| 2 + sizeof_var_length(m.get_size())) - + self.f_repeated_int32.iter().map(|s| 2 + sizeof_int32(*s)).sum::() - + if self.f_repeated_packed_int32.is_empty() { 0 } else { 2 + sizeof_var_length(self.f_repeated_packed_int32.iter().map(|s| sizeof_int32(*s)).sum::()) } - + self.f_imported.as_ref().map_or(0, |m| 2 + sizeof_var_length(m.get_size())) - + self.f_baz.as_ref().map_or(0, |m| 2 + sizeof_var_length(m.get_size())) - + self.f_nested.as_ref().map_or(0, |m| 2 + sizeof_var_length(m.get_size())) + + self.f_bytes.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) + + self.f_string.as_ref().map_or(0, |m| 2 + sizeof_len(m.len())) + + self.f_self_message.as_ref().map_or(0, |m| 2 + sizeof_len(m.get_size())) + + self.f_bar_message.as_ref().map_or(0, |m| 2 + sizeof_len(m.get_size())) + + self.f_repeated_int32.iter().map(|s| 2 + sizeof_varint(*s as u64)).sum::() + + if self.f_repeated_packed_int32.is_empty() { 0 } else { 2 + sizeof_len(self.f_repeated_packed_int32.iter().map(|s| sizeof_varint(*s as u64)).sum::()) } + + self.f_imported.as_ref().map_or(0, |m| 2 + sizeof_len(m.get_size())) + + self.f_baz.as_ref().map_or(0, |m| 2 + sizeof_len(m.get_size())) + + self.f_nested.as_ref().map_or(0, |m| 2 + sizeof_len(m.get_size())) + + self.f_map.iter().map(|(k, v)| 2 + sizeof_len(2 + sizeof_len(k.len()) + sizeof_varint(*v as u64))).sum::() } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.f_int32 { r.write_int32_with_tag(8, *s)?; } - if let Some(ref s) = self.f_int64 { r.write_int64_with_tag(16, *s)?; } - if let Some(ref s) = self.f_uint32 { r.write_uint32_with_tag(24, *s)?; } - if let Some(ref s) = self.f_uint64 { r.write_uint64_with_tag(32, *s)?; } - if let Some(ref s) = self.f_sint32 { r.write_sint32_with_tag(40, *s)?; } - if let Some(ref s) = self.f_sint64 { r.write_sint64_with_tag(48, *s)?; } - if let Some(ref s) = self.f_bool { r.write_bool_with_tag(56, *s)?; } - if let Some(ref s) = self.f_FooEnum { r.write_enum_with_tag(64, *s as i32)?; } - if let Some(ref s) = self.f_fixed64 { r.write_fixed64_with_tag(73, *s)?; } - if let Some(ref s) = self.f_sfixed64 { r.write_sfixed64_with_tag(81, *s)?; } - if let Some(ref s) = self.f_fixed32 { r.write_fixed32_with_tag(93, *s)?; } - if let Some(ref s) = self.f_sfixed32 { r.write_sfixed32_with_tag(101, *s)?; } - if let Some(ref s) = self.f_double { r.write_double_with_tag(105, *s)?; } - if let Some(ref s) = self.f_float { r.write_float_with_tag(117, *s)?; } - if let Some(ref s) = self.f_bytes { r.write_bytes_with_tag(122, s)?; } - if let Some(ref s) = self.f_string { r.write_string_with_tag(130, s)?; } - if let Some(ref s) = self.f_self_message { r.write_message_with_tag(138, &**s)?; } - if let Some(ref s) = self.f_bar_message { r.write_message_with_tag(146, s)?; } - for s in &self.f_repeated_int32 { r.write_int32_with_tag(152, *s)? } - r.write_packed_repeated_field_with_tag(162, &self.f_repeated_packed_int32, |r, m| r.write_int32(*m), &|m| sizeof_int32(*m))?; - if let Some(ref s) = self.f_imported { r.write_message_with_tag(170, s)?; } - if let Some(ref s) = self.f_baz { r.write_message_with_tag(178, s)?; } - if let Some(ref s) = self.f_nested { r.write_message_with_tag(186, s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.f_int32 { w.write_with_tag(8, |w| w.write_int32(*s))?; } + if let Some(ref s) = self.f_int64 { w.write_with_tag(16, |w| w.write_int64(*s))?; } + if let Some(ref s) = self.f_uint32 { w.write_with_tag(24, |w| w.write_uint32(*s))?; } + if let Some(ref s) = self.f_uint64 { w.write_with_tag(32, |w| w.write_uint64(*s))?; } + if let Some(ref s) = self.f_sint32 { w.write_with_tag(40, |w| w.write_sint32(*s))?; } + if let Some(ref s) = self.f_sint64 { w.write_with_tag(48, |w| w.write_sint64(*s))?; } + if let Some(ref s) = self.f_bool { w.write_with_tag(56, |w| w.write_bool(*s))?; } + if let Some(ref s) = self.f_FooEnum { w.write_with_tag(64, |w| w.write_enum(*s as i32))?; } + if let Some(ref s) = self.f_fixed64 { w.write_with_tag(73, |w| w.write_fixed64(*s))?; } + if let Some(ref s) = self.f_sfixed64 { w.write_with_tag(81, |w| w.write_sfixed64(*s))?; } + if let Some(ref s) = self.f_fixed32 { w.write_with_tag(93, |w| w.write_fixed32(*s))?; } + if let Some(ref s) = self.f_sfixed32 { w.write_with_tag(101, |w| w.write_sfixed32(*s))?; } + if let Some(ref s) = self.f_double { w.write_with_tag(105, |w| w.write_double(*s))?; } + if let Some(ref s) = self.f_float { w.write_with_tag(117, |w| w.write_float(*s))?; } + if let Some(ref s) = self.f_bytes { w.write_with_tag(122, |w| w.write_bytes(&**s))?; } + if let Some(ref s) = self.f_string { w.write_with_tag(130, |w| w.write_string(&**s))?; } + if let Some(ref s) = self.f_self_message { w.write_with_tag(138, |w| w.write_message(&**s))?; } + if let Some(ref s) = self.f_bar_message { w.write_with_tag(146, |w| w.write_message(s))?; } + for s in &self.f_repeated_int32 { w.write_with_tag(152, |w| w.write_int32(*s))?; } + w.write_packed_with_tag(162, &self.f_repeated_packed_int32, |w, m| w.write_int32(*m), &|m| sizeof_varint(*m as u64))?; + if let Some(ref s) = self.f_imported { w.write_with_tag(170, |w| w.write_message(s))?; } + if let Some(ref s) = self.f_baz { w.write_with_tag(178, |w| w.write_message(s))?; } + if let Some(ref s) = self.f_nested { w.write_with_tag(186, |w| w.write_message(s))?; } + for (k, v) in self.f_map.iter() { w.write_with_tag(194, |w| w.write_map(2 + sizeof_len(k.len()) + sizeof_varint(*v as u64), 10, |w| w.write_string(&**k), 16, |w| w.write_int32(*v)))?; } Ok(()) } } @@ -202,11 +210,11 @@ impl BazMessage { impl MessageWrite for BazMessage { fn get_size(&self) -> usize { - self.nested.as_ref().map_or(0, |m| 1 + sizeof_var_length(m.get_size())) + self.nested.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.nested { r.write_message_with_tag(10, s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.nested { w.write_with_tag(10, |w| w.write_message(s))?; } Ok(()) } } @@ -236,11 +244,11 @@ impl Nested { impl MessageWrite for Nested { fn get_size(&self) -> usize { - 1 + sizeof_int32(self.f_nested) + 1 + sizeof_varint(*&self.f_nested as u64) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - r.write_int32_with_tag(8, self.f_nested)?; + fn write_message(&self, w: &mut Writer) -> Result<()> { + w.write_with_tag(8, |w| w.write_int32(*&self.f_nested))?; Ok(()) } } diff --git a/examples/codegen/data_types_import.rs b/examples/codegen/data_types_import.rs index db45eaef..e35dbb9d 100644 --- a/examples/codegen/data_types_import.rs +++ b/examples/codegen/data_types_import.rs @@ -32,11 +32,11 @@ impl ImportedMessage { impl MessageWrite for ImportedMessage { fn get_size(&self) -> usize { - self.i.as_ref().map_or(0, |m| 1 + sizeof_bool(*m)) + self.i.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) } - fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.i { r.write_bool_with_tag(8, *s)?; } + fn write_message(&self, w: &mut Writer) -> Result<()> { + if let Some(ref s) = self.i { w.write_with_tag(8, |w| w.write_bool(*s))?; } Ok(()) } } diff --git a/examples/codegen_example.rs b/examples/codegen_example.rs index f5fbb539..2c33f734 100644 --- a/examples/codegen_example.rs +++ b/examples/codegen_example.rs @@ -34,6 +34,9 @@ fn main() { // nested messages are encapsulated into a rust module mod_Message f_nested: Some(data_types::mod_BazMessage::Nested { f_nested: 2 }), + // a map! + f_map: vec![(Cow::Borrowed("foo"), 1), (Cow::Borrowed("bar"), 2)].into_iter().collect(), + // Each message implements Default ... which makes it much easier ..FooMessage::default() }; diff --git a/src/errors.rs b/src/errors.rs index 4517581a..c77f9d28 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -28,5 +28,9 @@ error_chain! { display("error while parsing message: {}", s) } + Map(tag: u8) { + description("unexpected map tag") + display("expecting a tag number 1 or 2, got {}", tag) + } } } diff --git a/src/reader.rs b/src/reader.rs index fdfc83cf..ec2367cf 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -319,9 +319,35 @@ impl BytesReader { self.read_len(bytes, read) } + /// Reads a map item: (key, value) + pub fn read_map<'a, K, V, F, G>(&mut self, + bytes: &'a[u8], + mut read_key: F, + mut read_val: G) -> Result<(K, V)> + where F: FnMut(&mut BytesReader, &'a[u8]) -> Result, + G: FnMut(&mut BytesReader, &'a[u8]) -> Result, + K: ::std::fmt::Debug + Default, + V: ::std::fmt::Debug + Default, + { + self.read_len(bytes, |r, bytes| { + let mut k = K::default(); + let mut v = V::default(); + while !r.is_eof() { + let t = r.read_u8(bytes); + match t >> 3 { + 1 => k = read_key(r, bytes)?, + 2 => v = read_val(r, bytes)?, + t => return Err(ErrorKind::Map(t).into()), + } + } + Ok((k, v)) + }) + } + /// Reads unknown data, based on its tag value (which itself gives us the wire_type value) #[inline] pub fn read_unknown(&mut self, bytes: &[u8], tag_value: u32) -> Result<()> { + println!("reading unknown {}", tag_value); match (tag_value & 0x7) as u8 { WIRE_TYPE_VARINT => { self.read_varint64(bytes)?; }, WIRE_TYPE_FIXED64 => self.start += 8, diff --git a/src/sizeofs.rs b/src/sizeofs.rs index b2137994..98ef7df3 100644 --- a/src/sizeofs.rs +++ b/src/sizeofs.rs @@ -25,7 +25,7 @@ pub fn sizeof_varint(v: u64) -> usize { /// /// The total size is the varint encoded length size plus the length itself /// https://developers.google.com/protocol-buffers/docs/encoding -pub fn sizeof_var_length(len: usize) -> usize { +pub fn sizeof_len(len: usize) -> usize { sizeof_varint(len as u64) + len } diff --git a/src/writer.rs b/src/writer.rs index fb6e8b2e..a9eb7980 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -159,7 +159,7 @@ impl Writer { } /// Writes packed repeated field: length first then the chunk of data - pub fn write_packed_repeated_field(&mut self, v: &[M], mut write: F, size: &S) -> Result<()> + pub fn write_packed(&mut self, v: &[M], mut write: F, size: &S) -> Result<()> where F: FnMut(&mut Self, &M) -> Result<()>, S: Fn(&M) -> usize, { @@ -179,8 +179,8 @@ impl Writer { /// `item_size` is internally used to compute the total length /// As the length is fixed (and the same as rust internal representation, we can directly dump /// all data at once - pub fn write_packed_fixed_size(&mut self, v: &[M], item_size: usize) -> Result<()> { - let len = v.len() * item_size; + pub fn write_packed_fixed_size(&mut self, v: &[M]) -> Result<()> { + let len = v.len() * ::std::mem::size_of::(); let bytes = unsafe { ::std::slice::from_raw_parts(v as *const [M] as *const M as *const u8, len) }; self.write_bytes(bytes) } @@ -192,105 +192,22 @@ impl Writer { m.write_message(self) } - /// Writes tag then `int32` - pub fn write_int32_with_tag(&mut self, tag: u32, v: i32) -> Result<()> { - self.write_tag(tag)?; - self.write_varint(v as u64) - } - - /// Writes tag then `int64` - pub fn write_int64_with_tag(&mut self, tag: u32, v: i64) -> Result<()> { - self.write_tag(tag)?; - self.write_varint(v as u64) - } - - /// Writes tag then `uint32` - pub fn write_uint32_with_tag(&mut self, tag: u32, v: u32) -> Result<()> { - self.write_tag(tag)?; - self.write_varint(v as u64) - } - - /// Writes tag then `uint64` - pub fn write_uint64_with_tag(&mut self, tag: u32, v: u64) -> Result<()> { - self.write_tag(tag)?; - self.write_varint(v) - } - - /// Writes tag then `sint32` - pub fn write_sint32_with_tag(&mut self, tag: u32, v: i32) -> Result<()> { - self.write_tag(tag)?; - self.write_sint32(v) - } - - /// Writes tag then `sint64` - pub fn write_sint64_with_tag(&mut self, tag: u32, v: i64) -> Result<()> { - self.write_tag(tag)?; - self.write_sint64(v) - } - - /// Writes tag then `fixed64` - pub fn write_fixed64_with_tag(&mut self, tag: u32, v: u64) -> Result<()> { - self.write_tag(tag)?; - self.inner.write_u64::(v).map_err(|e| e.into()) - } - - /// Writes tag then `fixed32` - pub fn write_fixed32_with_tag(&mut self, tag: u32, v: u32) -> Result<()> { - self.write_tag(tag)?; - self.inner.write_u32::(v).map_err(|e| e.into()) - } - - /// Writes tag then `sfixed64` - pub fn write_sfixed64_with_tag(&mut self, tag: u32, v: i64) -> Result<()> { - self.write_tag(tag)?; - self.inner.write_i64::(v).map_err(|e| e.into()) - } - - /// Writes tag then `sfixed32` - pub fn write_sfixed32_with_tag(&mut self, tag: u32, v: i32) -> Result<()> { - self.write_tag(tag)?; - self.inner.write_i32::(v).map_err(|e| e.into()) - } - - /// Writes tag then `float` - pub fn write_float_with_tag(&mut self, tag: u32, v: f32) -> Result<()> { - self.write_tag(tag)?; - self.inner.write_f32::(v).map_err(|e| e.into()) - } - - /// Writes tag then `double` - pub fn write_double_with_tag(&mut self, tag: u32, v: f64) -> Result<()> { - self.write_tag(tag)?; - self.inner.write_f64::(v).map_err(|e| e.into()) - } - - /// Writes tag then `bool` - pub fn write_bool_with_tag(&mut self, tag: u32, v: bool) -> Result<()> { - self.write_tag(tag)?; - self.write_varint(if v { 1 } else { 0 }) - } - - /// Writes tag then `bytes` - pub fn write_bytes_with_tag(&mut self, tag: u32, bytes: &[u8]) -> Result<()> { - self.write_tag(tag)?; - self.write_varint(bytes.len() as u64)?; - self.inner.write_all(bytes).map_err(|e| e.into()) - } - - /// Writes tag then `string` - pub fn write_string_with_tag(&mut self, tag: u32, s: &str) -> Result<()> { + /// Writes another item prefixed with tag + pub fn write_with_tag(&mut self, tag: u32, mut write: F) -> Result<()> + where F: FnMut(&mut Self) -> Result<()> + { self.write_tag(tag)?; - self.write_bytes(s.as_bytes()) + write(self) } /// Writes tag then repeated field /// /// If array is empty, then do nothing (do not even write the tag) - pub fn write_packed_repeated_field_with_tag(&mut self, - tag: u32, - v: &[M], - mut write: F, - size: &S) -> Result<()> + pub fn write_packed_with_tag(&mut self, + tag: u32, + v: &[M], + mut write: F, + size: &S) -> Result<()> where F: FnMut(&mut Self, &M) -> Result<()>, S: Fn(&M) -> usize, { @@ -323,15 +240,17 @@ impl Writer { self.write_bytes(bytes) } - /// Writes tag then message - pub fn write_message_with_tag(&mut self, tag: u32, m: &M) -> Result<()> { - self.write_tag(tag)?; - self.write_message(m) - } - - /// Writes tag then enum - pub fn write_enum_with_tag(&mut self, tag: u32, v: i32) -> Result<()> { - self.write_tag(tag)?; - self.write_int32(v) + /// Write entire map + pub fn write_map(&mut self, size: usize, + tag_key: u32, mut write_key: FK, + tag_val: u32, mut write_val: FV) -> Result<()> + where FK: FnMut(&mut Self) -> Result<()>, + FV: FnMut(&mut Self) -> Result<()>, + { + self.write_varint(size as u64)?; + self.write_tag(tag_key)?; + write_key(self)?; + self.write_tag(tag_val)?; + write_val(self) } } diff --git a/tests/write_read.rs b/tests/write_read.rs index 352c830a..40302643 100644 --- a/tests/write_read.rs +++ b/tests/write_read.rs @@ -1,5 +1,7 @@ extern crate quick_protobuf; +use std::collections::HashMap; +use std::borrow::Cow; use std::io::{Write}; use quick_protobuf::{BytesReader, Writer, MessageWrite, Result}; use quick_protobuf::sizeofs::*; @@ -118,8 +120,8 @@ impl MessageWrite for TestMessage { } fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.id { r.write_uint32_with_tag(10, *s)?; } - for s in &self.val { r.write_sint64_with_tag(18, *s)?; } + if let Some(ref s) = self.id { r.write_with_tag(10, |r| r.write_uint32(*s))?; } + for s in &self.val { r.write_with_tag(18, |r| r.write_sint64(*s))?; } Ok(()) } } @@ -168,12 +170,12 @@ impl<'a> TestMessageBorrow<'a> { impl<'a> MessageWrite for TestMessageBorrow<'a> { fn get_size(&self) -> usize { self.id.as_ref().map_or(0, |m| 1 + sizeof_uint32(*m)) - + self.val.iter().map(|m| 1 + sizeof_var_length(m.len())).sum::() + + self.val.iter().map(|m| 1 + sizeof_len(m.len())).sum::() } fn write_message(&self, r: &mut Writer) -> Result<()> { - if let Some(ref s) = self.id { r.write_uint32_with_tag(10, *s)?; } - for s in &self.val { r.write_string_with_tag(18, *s)?; } + if let Some(ref s) = self.id { r.write_with_tag(10, |r| r.write_uint32(*s))?; } + for s in &self.val { r.write_with_tag(18, |r| r.write_string(*s))?; } Ok(()) } } @@ -205,8 +207,32 @@ fn wr_packed_uint32(){ let mut buf = Vec::new(); { let mut w = Writer::new(&mut buf); - w.write_packed_repeated_field(&v, |r, m| r.write_uint32(*m), &|m| sizeof_uint32(*m)).unwrap(); + w.write_packed(&v, |r, m| r.write_uint32(*m), &|m| sizeof_uint32(*m)).unwrap(); } let mut r = BytesReader::from_bytes(&buf); assert_eq!(v, r.read_packed(&buf, |r, b| r.read_uint32(b)).unwrap()); } + +#[test] +fn wr_map(){ + let v = { + let mut v = HashMap::new(); + v.insert(Cow::Borrowed("foo"), 1i32); + v.insert(Cow::Borrowed("bar"), 2); + v + }; + let mut buf = Vec::new(); + { + let mut w = Writer::new(&mut buf); + for (k, v) in v.iter() { + w.write_map(2 + sizeof_len(k.len()) + sizeof_varint(*v as u64), 10, |w| w.write_string(&**k), 16, |w| w.write_int32(*v)).unwrap(); + } + } + let mut r = BytesReader::from_bytes(&buf); + let mut read_back = HashMap::new(); + while !r.is_eof() { + let (key, value) = r.read_map(&buf, |r, bytes| r.read_string(bytes).map(Cow::Borrowed), |r, bytes| r.read_int32(bytes)).unwrap(); + read_back.insert(key, value); + } + assert_eq!(v, read_back); +}