Skip to content

Commit

Permalink
Add smallstring support with compact_str
Browse files Browse the repository at this point in the history
  • Loading branch information
DomWilliamsEE committed May 7, 2024
1 parent baddf98 commit 8a0dfcd
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 2 deletions.
1 change: 1 addition & 0 deletions prost-derive/src/field/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
| scalar::Ty::Sfixed32
| scalar::Ty::Sfixed64
| scalar::Ty::Bool
| scalar::Ty::SmallString
| scalar::Ty::String => Ok(ty),
_ => bail!("invalid map key type: {}", s),
}
Expand Down
34 changes: 32 additions & 2 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ impl Field {
let tag = self.tag;

match self.kind {
Kind::Plain(DefaultValue::SmallString) => {
quote! {
if !#ident.is_empty(){
#encode_fn(#tag, &#ident, buf);
}
}
}
Kind::Plain(ref default) => {
let default = default.typed();
quote! {
Expand Down Expand Up @@ -169,6 +176,15 @@ impl Field {
let tag = self.tag;

match self.kind {
Kind::Plain(DefaultValue::SmallString) => {
quote! {
if !#ident.is_empty() {
#encoded_len_fn(#tag, &#ident)
} else {
0
}
}
}
Kind::Plain(ref default) => {
let default = default.typed();
quote! {
Expand All @@ -193,7 +209,7 @@ impl Field {
Kind::Plain(ref default) | Kind::Required(ref default) => {
let default = default.typed();
match self.ty {
Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
Ty::String | Ty::SmallString | Ty::Bytes(..) => quote!(#ident.clear()),
_ => quote!(#ident = #default),
}
}
Expand Down Expand Up @@ -397,6 +413,7 @@ pub enum Ty {
Sfixed64,
Bool,
String,
SmallString,
Bytes(BytesTy),
Enumeration(Path),
}
Expand Down Expand Up @@ -441,6 +458,7 @@ impl Ty {
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
Meta::Path(ref name) if name.is_ident("smallstring") => Ty::SmallString,
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
Meta::NameValue(MetaNameValue {
ref path,
Expand Down Expand Up @@ -486,6 +504,7 @@ impl Ty {
"sfixed64" => Ty::Sfixed64,
"bool" => Ty::Bool,
"string" => Ty::String,
"smallstring" => Ty::SmallString,
"bytes" => Ty::Bytes(BytesTy::Vec),
s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
let s = &s[enumeration_len..].trim();
Expand Down Expand Up @@ -522,6 +541,7 @@ impl Ty {
Ty::Sfixed64 => "sfixed64",
Ty::Bool => "bool",
Ty::String => "string",
Ty::SmallString => "smallstring",
Ty::Bytes(..) => "bytes",
Ty::Enumeration(..) => "enum",
}
Expand All @@ -531,6 +551,7 @@ impl Ty {
pub fn rust_type(&self) -> TokenStream {
match self {
Ty::String => quote!(::prost::alloc::string::String),
Ty::SmallString => quote!(::compact_str::CompactString),
Ty::Bytes(ty) => ty.rust_type(),
_ => self.rust_ref_type(),
}
Expand All @@ -553,6 +574,7 @@ impl Ty {
Ty::Sfixed64 => quote!(i64),
Ty::Bool => quote!(bool),
Ty::String => quote!(&str),
Ty::SmallString => quote!(&str),
Ty::Bytes(..) => quote!(&[u8]),
Ty::Enumeration(..) => quote!(i32),
}
Expand All @@ -567,7 +589,7 @@ impl Ty {

/// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
pub fn is_numeric(&self) -> bool {
!matches!(self, Ty::String | Ty::Bytes(..))
!matches!(self, Ty::String | Ty::SmallString | Ty::Bytes(..))
}
}

Expand Down Expand Up @@ -609,6 +631,7 @@ pub enum DefaultValue {
U64(u64),
Bool(bool),
String(String),
SmallString,
Bytes(Vec<u8>),
Enumeration(TokenStream),
Path(Path),
Expand Down Expand Up @@ -773,6 +796,7 @@ impl DefaultValue {

Ty::Bool => DefaultValue::Bool(false),
Ty::String => DefaultValue::String(String::new()),
Ty::SmallString => DefaultValue::SmallString,
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
}
Expand All @@ -784,6 +808,9 @@ impl DefaultValue {
quote!(::prost::alloc::string::String::new())
}
DefaultValue::String(ref value) => quote!(#value.into()),
DefaultValue::SmallString => {
quote!(::compact_str::CompactString::default())
}
DefaultValue::Bytes(ref value) if value.is_empty() => {
quote!(::core::default::Default::default())
}
Expand All @@ -799,6 +826,8 @@ impl DefaultValue {
pub fn typed(&self) -> TokenStream {
if let DefaultValue::Enumeration(_) = *self {
quote!(#self as i32)
} else if let DefaultValue::SmallString = *self {
quote!(Default::default())
} else {
quote!(#self)
}
Expand All @@ -816,6 +845,7 @@ impl ToTokens for DefaultValue {
DefaultValue::U64(value) => value.to_tokens(tokens),
DefaultValue::Bool(value) => value.to_tokens(tokens),
DefaultValue::String(ref value) => value.to_tokens(tokens),
DefaultValue::SmallString => "".to_tokens(tokens),
DefaultValue::Bytes(ref value) => {
let byte_str = LitByteStr::new(value, Span::call_site());
tokens.append_all(quote!(#byte_str as &[u8]));
Expand Down
2 changes: 2 additions & 0 deletions prost/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ derive = ["dep:prost-derive"]
prost-derive = ["derive"] # deprecated, please use derive feature instead
no-recursion-limit = []
std = []
smallstring = ["compact_str"]

[dependencies]
bytes = { version = "1", default-features = false }
prost-derive = { version = "0.12.4", path = "../prost-derive", optional = true }
compact_str = { version = "0.7.1", optional = true }

[dev-dependencies]
criterion = { version = "0.4", default-features = false }
Expand Down
70 changes: 70 additions & 0 deletions prost/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,76 @@ pub mod string {
}
}

#[cfg(feature = "smallstring")]
pub mod smallstring {
use super::*;
use core::mem::ManuallyDrop;

pub fn encode<B>(tag: u32, value: &str, buf: &mut B)
where
B: BufMut,
{
encode_key(tag, WireType::LengthDelimited, buf);
encode_varint(value.len() as u64, buf);
buf.put_slice(value.as_bytes());
}
pub fn merge<B>(
wire_type: WireType,
value: &mut compact_str::CompactString,
buf: &mut B,
ctx: DecodeContext,
) -> Result<(), DecodeError>
where
B: Buf,
{
unsafe {
let mut fake_vec = ManuallyDrop::new(Vec::from_raw_parts(
value.as_mut_ptr(),
value.len(),
value.capacity(),
));

bytes::merge_one_copy(wire_type, &mut *fake_vec, buf, ctx)?;
match compact_str::CompactString::from_utf8(&*fake_vec) {
Ok(s) => {
*value = s;
Ok(())
}
Err(_) => Err(DecodeError::new(
"invalid string value: data is not UTF-8 encoded",
)),
}
}
}

length_delimited!(compact_str::CompactString);

#[cfg(test)]
mod test {
use compact_str::{CompactString, ToCompactString};
use proptest::prelude::*;

use super::super::test::{check_collection_type, check_type};
use super::*;

proptest! {
#[test]
fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
let value = value.to_compact_string();
super::test::check_type(value, tag, WireType::LengthDelimited,
encode, merge, |tag, s| encoded_len(tag, &s.to_compact_string()))?;
}
#[test]
fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
let value = value.into_iter().map(CompactString::from).collect();
super::test::check_collection_type(value, tag, WireType::LengthDelimited,
encode_repeated, merge_repeated,
encoded_len_repeated)?;
}
}
}
}

pub trait BytesAdapter: sealed::BytesAdapter {}

mod sealed {
Expand Down

0 comments on commit 8a0dfcd

Please sign in to comment.