From 8a0dfcdad1280984ed040fa8023fd33eee2c2de7 Mon Sep 17 00:00:00 2001 From: Dom Williams Date: Tue, 10 Oct 2023 10:02:56 +0300 Subject: [PATCH] Add smallstring support with compact_str --- prost-derive/src/field/map.rs | 1 + prost-derive/src/field/scalar.rs | 34 +++++++++++++++- prost/Cargo.toml | 2 + prost/src/encoding.rs | 70 ++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 2 deletions(-) diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index 4855cc5c6..da01d41ea 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -367,6 +367,7 @@ fn key_ty_from_str(s: &str) -> Result { | scalar::Ty::Sfixed32 | scalar::Ty::Sfixed64 | scalar::Ty::Bool + | scalar::Ty::SmallString | scalar::Ty::String => Ok(ty), _ => bail!("invalid map key type: {}", s), } diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 6be16cd70..e807f4769 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -115,6 +115,13 @@ impl Field { let tag = self.tag; match self.kind { + Kind::Plain(DefaultValue::SmallString) => { + quote! { + if !#ident.is_empty(){ + #encode_fn(#tag, &#ident, buf); + } + } + } Kind::Plain(ref default) => { let default = default.typed(); quote! { @@ -169,6 +176,15 @@ impl Field { let tag = self.tag; match self.kind { + Kind::Plain(DefaultValue::SmallString) => { + quote! { + if !#ident.is_empty() { + #encoded_len_fn(#tag, &#ident) + } else { + 0 + } + } + } Kind::Plain(ref default) => { let default = default.typed(); quote! { @@ -193,7 +209,7 @@ impl Field { Kind::Plain(ref default) | Kind::Required(ref default) => { let default = default.typed(); match self.ty { - Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), + Ty::String | Ty::SmallString | Ty::Bytes(..) 
=> quote!(#ident.clear()), _ => quote!(#ident = #default), } } @@ -397,6 +413,7 @@ pub enum Ty { Sfixed64, Bool, String, + SmallString, Bytes(BytesTy), Enumeration(Path), } @@ -441,6 +458,7 @@ impl Ty { Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("smallstring") => Ty::SmallString, Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), Meta::NameValue(MetaNameValue { ref path, @@ -486,6 +504,7 @@ impl Ty { "sfixed64" => Ty::Sfixed64, "bool" => Ty::Bool, "string" => Ty::String, + "smallstring" => Ty::SmallString, "bytes" => Ty::Bytes(BytesTy::Vec), s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => { let s = &s[enumeration_len..].trim(); @@ -522,6 +541,7 @@ impl Ty { Ty::Sfixed64 => "sfixed64", Ty::Bool => "bool", Ty::String => "string", + Ty::SmallString => "smallstring", Ty::Bytes(..) => "bytes", Ty::Enumeration(..) => "enum", } @@ -531,6 +551,7 @@ impl Ty { pub fn rust_type(&self) -> TokenStream { match self { Ty::String => quote!(::prost::alloc::string::String), + Ty::SmallString => quote!(::compact_str::CompactString), Ty::Bytes(ty) => ty.rust_type(), _ => self.rust_ref_type(), } @@ -553,6 +574,7 @@ impl Ty { Ty::Sfixed64 => quote!(i64), Ty::Bool => quote!(bool), Ty::String => quote!(&str), + Ty::SmallString => quote!(&str), Ty::Bytes(..) => quote!(&[u8]), Ty::Enumeration(..) => quote!(i32), } @@ -567,7 +589,7 @@ impl Ty { /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). 
pub fn is_numeric(&self) -> bool { - !matches!(self, Ty::String | Ty::Bytes(..)) + !matches!(self, Ty::String | Ty::SmallString | Ty::Bytes(..)) } } @@ -609,6 +631,7 @@ pub enum DefaultValue { U64(u64), Bool(bool), String(String), + SmallString, Bytes(Vec), Enumeration(TokenStream), Path(Path), @@ -773,6 +796,7 @@ impl DefaultValue { Ty::Bool => DefaultValue::Bool(false), Ty::String => DefaultValue::String(String::new()), + Ty::SmallString => DefaultValue::SmallString, Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), } @@ -784,6 +808,9 @@ impl DefaultValue { quote!(::prost::alloc::string::String::new()) } DefaultValue::String(ref value) => quote!(#value.into()), + DefaultValue::SmallString => { + quote!(::compact_str::CompactString::default()) + } DefaultValue::Bytes(ref value) if value.is_empty() => { quote!(::core::default::Default::default()) } @@ -799,6 +826,8 @@ impl DefaultValue { pub fn typed(&self) -> TokenStream { if let DefaultValue::Enumeration(_) = *self { quote!(#self as i32) + } else if let DefaultValue::SmallString = *self { + quote!(Default::default()) } else { quote!(#self) } @@ -816,6 +845,7 @@ impl ToTokens for DefaultValue { DefaultValue::U64(value) => value.to_tokens(tokens), DefaultValue::Bool(value) => value.to_tokens(tokens), DefaultValue::String(ref value) => value.to_tokens(tokens), + DefaultValue::SmallString => "".to_tokens(tokens), DefaultValue::Bytes(ref value) => { let byte_str = LitByteStr::new(value, Span::call_site()); tokens.append_all(quote!(#byte_str as &[u8])); diff --git a/prost/Cargo.toml b/prost/Cargo.toml index 8dacdd3b3..e8d816c21 100644 --- a/prost/Cargo.toml +++ b/prost/Cargo.toml @@ -27,10 +27,12 @@ derive = ["dep:prost-derive"] prost-derive = ["derive"] # deprecated, please use derive feature instead no-recursion-limit = [] std = [] +smallstring = ["compact_str"] [dependencies] bytes = { version = "1", default-features = false } 
prost-derive = { version = "0.12.4", path = "../prost-derive", optional = true }
+compact_str = { version = "0.7.1", optional = true }
 
 [dev-dependencies]
 criterion = { version = "0.4", default-features = false }
diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs
index 88d4fe891..78daa2518 100644
--- a/prost/src/encoding.rs
+++ b/prost/src/encoding.rs
@@ -873,6 +873,76 @@ pub mod string {
     }
 }
 
+/// Length-delimited encoding and decoding of `compact_str::CompactString` values.
+/// Only compiled when the `smallstring` feature is enabled.
+#[cfg(feature = "smallstring")]
+pub mod smallstring {
+    use super::*;
+
+    /// Writes the key for `tag`, the byte length of `value` as a varint, then its bytes.
+    pub fn encode<B>(tag: u32, value: &str, buf: &mut B)
+    where
+        B: BufMut,
+    {
+        encode_key(tag, WireType::LengthDelimited, buf);
+        encode_varint(value.len() as u64, buf);
+        buf.put_slice(value.as_bytes());
+    }
+
+    /// Decodes a length-delimited payload from `buf` and stores it in `value`.
+    ///
+    /// The payload is copied into a scratch `Vec<u8>` instead of being written through
+    /// `value`'s own storage: a `CompactString` may keep its bytes inline or in a buffer
+    /// that `Vec` did not allocate, so viewing it through `Vec::from_raw_parts` is unsound.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error from `bytes::merge_one_copy` and fails when the payload
+    /// is not valid UTF-8. `value` is left untouched whenever an error is returned.
+    pub fn merge<B>(
+        wire_type: WireType,
+        value: &mut compact_str::CompactString,
+        buf: &mut B,
+        ctx: DecodeContext,
+    ) -> Result<(), DecodeError>
+    where
+        B: Buf,
+    {
+        let mut scratch: Vec<u8> = Vec::new();
+        bytes::merge_one_copy(wire_type, &mut scratch, buf, ctx)?;
+        *value = compact_str::CompactString::from_utf8(&scratch)
+            .map_err(|_| DecodeError::new("invalid string value: data is not UTF-8 encoded"))?;
+        Ok(())
+    }
+
+    length_delimited!(compact_str::CompactString);
+
+    #[cfg(test)]
+    mod test {
+        use compact_str::{CompactString, ToCompactString};
+        use proptest::prelude::*;
+
+        use super::super::test::{check_collection_type, check_type};
+        use super::*;
+
+        proptest! {
+            #[test]
+            fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
+                let value = value.to_compact_string();
+                super::test::check_type(value, tag, WireType::LengthDelimited,
+                                        encode, merge, |tag, s| encoded_len(tag, &s.to_compact_string()))?;
+            }
+            #[test]
+            fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
+                let value = value.into_iter().map(CompactString::from).collect();
+                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
+                                                   encode_repeated, merge_repeated,
+                                                   encoded_len_repeated)?;
+            }
+        }
+    }
+}
+
 pub trait BytesAdapter: sealed::BytesAdapter {}
 
 mod sealed {