Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add small string optimisation with compact_str #1051

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions prost-derive/src/field/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
| scalar::Ty::Sfixed32
| scalar::Ty::Sfixed64
| scalar::Ty::Bool
| scalar::Ty::SmallString
| scalar::Ty::String => Ok(ty),
_ => bail!("invalid map key type: {}", s),
}
Expand Down
34 changes: 32 additions & 2 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ impl Field {
let tag = self.tag;

match self.kind {
Kind::Plain(DefaultValue::SmallString) => {
quote! {
if !#ident.is_empty(){
#encode_fn(#tag, &#ident, buf);
}
}
}
Kind::Plain(ref default) => {
let default = default.typed();
quote! {
Expand Down Expand Up @@ -170,6 +177,15 @@ impl Field {
let tag = self.tag;

match self.kind {
Kind::Plain(DefaultValue::SmallString) => {
quote! {
if !#ident.is_empty() {
#encoded_len_fn(#tag, &#ident)
} else {
0
}
}
}
Kind::Plain(ref default) => {
let default = default.typed();
quote! {
Expand All @@ -194,7 +210,7 @@ impl Field {
Kind::Plain(ref default) | Kind::Required(ref default) => {
let default = default.typed();
match self.ty {
Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
Ty::String | Ty::SmallString | Ty::Bytes(..) => quote!(#ident.clear()),
_ => quote!(#ident = #default),
}
}
Expand Down Expand Up @@ -398,6 +414,7 @@ pub enum Ty {
Sfixed64,
Bool,
String,
SmallString,
Bytes(BytesTy),
Enumeration(Path),
}
Expand Down Expand Up @@ -442,6 +459,7 @@ impl Ty {
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
Meta::Path(ref name) if name.is_ident("smallstring") => Ty::SmallString,
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
Meta::NameValue(MetaNameValue {
ref path,
Expand Down Expand Up @@ -487,6 +505,7 @@ impl Ty {
"sfixed64" => Ty::Sfixed64,
"bool" => Ty::Bool,
"string" => Ty::String,
"smallstring" => Ty::SmallString,
"bytes" => Ty::Bytes(BytesTy::Vec),
s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
let s = &s[enumeration_len..].trim();
Expand Down Expand Up @@ -523,6 +542,7 @@ impl Ty {
Ty::Sfixed64 => "sfixed64",
Ty::Bool => "bool",
Ty::String => "string",
Ty::SmallString => "smallstring",
Ty::Bytes(..) => "bytes",
Ty::Enumeration(..) => "enum",
}
Expand All @@ -532,6 +552,7 @@ impl Ty {
pub fn rust_type(&self) -> TokenStream {
match self {
Ty::String => quote!(::prost::alloc::string::String),
Ty::SmallString => quote!(::compact_str::CompactString),
Ty::Bytes(ty) => ty.rust_type(),
_ => self.rust_ref_type(),
}
Expand All @@ -554,6 +575,7 @@ impl Ty {
Ty::Sfixed64 => quote!(i64),
Ty::Bool => quote!(bool),
Ty::String => quote!(&str),
Ty::SmallString => quote!(&str),
Ty::Bytes(..) => quote!(&[u8]),
Ty::Enumeration(..) => quote!(i32),
}
Expand All @@ -568,7 +590,7 @@ impl Ty {

/// Returns false if the scalar type is length delimited (i.e., `string`, `smallstring`, or `bytes`).
pub fn is_numeric(&self) -> bool {
!matches!(self, Ty::String | Ty::Bytes(..))
!matches!(self, Ty::String | Ty::SmallString | Ty::Bytes(..))
}
}

Expand Down Expand Up @@ -610,6 +632,7 @@ pub enum DefaultValue {
U64(u64),
Bool(bool),
String(String),
SmallString,
Bytes(Vec<u8>),
Enumeration(TokenStream),
Path(Path),
Expand Down Expand Up @@ -774,6 +797,7 @@ impl DefaultValue {

Ty::Bool => DefaultValue::Bool(false),
Ty::String => DefaultValue::String(String::new()),
Ty::SmallString => DefaultValue::SmallString,
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
}
Expand All @@ -785,6 +809,9 @@ impl DefaultValue {
quote!(::prost::alloc::string::String::new())
}
DefaultValue::String(ref value) => quote!(#value.into()),
DefaultValue::SmallString => {
quote!(::compact_str::CompactString::default())
}
DefaultValue::Bytes(ref value) if value.is_empty() => {
quote!(::core::default::Default::default())
}
Expand All @@ -800,6 +827,8 @@ impl DefaultValue {
pub fn typed(&self) -> TokenStream {
if let DefaultValue::Enumeration(_) = *self {
quote!(#self as i32)
} else if let DefaultValue::SmallString = *self {
quote!(Default::default())
} else {
quote!(#self)
}
Expand All @@ -817,6 +846,7 @@ impl ToTokens for DefaultValue {
DefaultValue::U64(value) => value.to_tokens(tokens),
DefaultValue::Bool(value) => value.to_tokens(tokens),
DefaultValue::String(ref value) => value.to_tokens(tokens),
DefaultValue::SmallString => "".to_tokens(tokens),
DefaultValue::Bytes(ref value) => {
let byte_str = LitByteStr::new(value, Span::call_site());
tokens.append_all(quote!(#byte_str as &[u8]));
Expand Down
2 changes: 2 additions & 0 deletions prost/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ derive = ["dep:prost-derive"]
prost-derive = ["derive"] # deprecated, please use derive feature instead
no-recursion-limit = []
std = []
smallstring = ["compact_str"]

[dependencies]
bytes = { version = "1", default-features = false }
prost-derive = { version = "0.12.6", path = "../prost-derive", optional = true }
compact_str = { version = "0.8.0-beta", optional = true }

[dev-dependencies]
criterion = { version = "0.5", default-features = false }
Expand Down
70 changes: 70 additions & 0 deletions prost/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,76 @@ pub mod string {
}
}

/// Length-delimited encoding helpers for `compact_str::CompactString` fields.
///
/// Only compiled when the `smallstring` feature is enabled.
#[cfg(feature = "smallstring")]
pub mod smallstring {
    use super::*;

    /// Writes `value` to `buf` as a length-delimited field: the field key for
    /// `tag`, the byte length as a varint, then the raw UTF-8 bytes.
    pub fn encode<B>(tag: u32, value: &str, buf: &mut B)
    where
        B: BufMut,
    {
        encode_key(tag, WireType::LengthDelimited, buf);
        encode_varint(value.len() as u64, buf);
        buf.put_slice(value.as_bytes());
    }

    /// Decodes one length-delimited field from `buf` and stores it in `value`.
    ///
    /// The previous contents of `value` are replaced on success and left
    /// untouched on failure.
    ///
    /// # Errors
    ///
    /// Propagates any error from `bytes::merge_one_copy`, and returns a
    /// `DecodeError` if the decoded bytes are not valid UTF-8.
    pub fn merge<B>(
        wire_type: WireType,
        value: &mut compact_str::CompactString,
        buf: &mut B,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError>
    where
        B: Buf,
    {
        // Decode into a buffer this function owns. Building a `Vec` on top of
        // the `CompactString`'s own pointer (via `Vec::from_raw_parts`) is
        // unsound: the string's bytes may live inline in the value itself, or
        // in a heap block that was not allocated with `Vec<u8>`'s layout, and
        // the decoder is free to grow (reallocate) the `Vec` it is handed.
        let mut decoded: Vec<u8> = Vec::new();
        bytes::merge_one_copy(wire_type, &mut decoded, buf, ctx)?;

        // Validate UTF-8 before touching `value`, so a failed decode cannot
        // leave a partially written string behind.
        match compact_str::CompactString::from_utf8(&decoded) {
            Ok(s) => {
                *value = s;
                Ok(())
            }
            Err(_) => Err(DecodeError::new(
                "invalid string value: data is not UTF-8 encoded",
            )),
        }
    }

    length_delimited!(compact_str::CompactString);

    #[cfg(test)]
    mod test {
        use compact_str::{CompactString, ToCompactString};
        use proptest::prelude::*;

        use super::super::test::{check_collection_type, check_type};
        use super::*;

        proptest! {
            // Round-trips a single arbitrary string through encode/merge.
            #[test]
            fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
                let value = value.to_compact_string();
                super::test::check_type(value, tag, WireType::LengthDelimited,
                                        encode, merge, |tag, s| encoded_len(tag, &s.to_compact_string()))?;
            }
            // Round-trips a repeated field of arbitrary strings.
            #[test]
            fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
                let value = value.into_iter().map(CompactString::from).collect();
                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
                                                   encode_repeated, merge_repeated,
                                                   encoded_len_repeated)?;
            }
        }
    }
}

pub trait BytesAdapter: sealed::BytesAdapter {}

mod sealed {
Expand Down
Loading