From 9a0bc392b3011cabc214f01b41314acab33658c2 Mon Sep 17 00:00:00 2001 From: Alyssa Haroldsen Date: Wed, 23 Aug 2023 11:05:31 -0700 Subject: [PATCH] Implement v0.6 Optional Bytes This makes a few changes: - It changes generated messages to reference message innards as a type in `__runtime` instead of branching on what fields should be there. That results in much less bifurcation in gencode and lets runtime-agnostic code reference raw message innards. - It adds a generic mechanism for creating vtable-based mutators. These vtables point to thunks generated for interacting with C++ or upb fields. Right now, the design results in 2-word (msg+vtable) mutators for C++ and 3-word mutators (msg+arena+vtable) for UPB. See upb.rs for an explanation of the design options. I chose the `RawMessage+&Arena` design for mutator data as opposed to a `&MessageInner` design because it did not result in extra-indirection layout changes for message mutators. We could revisit this in the future with performance data, since this results in all field mutators being 3 words large instead of the register-friendly 2 words. - And lastly, as a nearby change that touches on many of the same topics, it adds some extra SAFETY comments for Send/Sync in message gencode. 
PiperOrigin-RevId: 559483437 --- rust/BUILD | 2 + rust/cpp.rs | 43 +- rust/internal.rs | 3 + rust/shared.rs | 1 + rust/string.rs | 145 ++++--- rust/test/BUILD | 64 +++ rust/test/cpp/interop/main.rs | 6 +- rust/test/shared/BUILD | 34 +- rust/test/shared/accessors_proto3_test.rs | 67 +++- rust/test/shared/accessors_test.rs | 99 ++++- rust/test/shared/serialization_test.rs | 2 +- rust/upb.rs | 68 +++- rust/vtable.rs | 378 ++++++++++++++++++ .../compiler/rust/accessors/singular_bytes.cc | 123 ++++-- .../rust/accessors/singular_message.cc | 2 +- .../rust/accessors/singular_scalar.cc | 10 +- src/google/protobuf/compiler/rust/message.cc | 79 ++-- 17 files changed, 969 insertions(+), 157 deletions(-) create mode 100644 rust/vtable.rs diff --git a/rust/BUILD b/rust/BUILD index 889e00ee54d0..4feba32638d8 100644 --- a/rust/BUILD +++ b/rust/BUILD @@ -57,6 +57,7 @@ rust_library( "shared.rs", "string.rs", "upb.rs", + "vtable.rs", ], crate_root = "shared.rs", rustc_flags = ["--cfg=upb_kernel"], @@ -92,6 +93,7 @@ rust_library( "proxied.rs", "shared.rs", "string.rs", + "vtable.rs", ], crate_root = "shared.rs", rustc_flags = ["--cfg=cpp_kernel"], diff --git a/rust/cpp.rs b/rust/cpp.rs index 46cead7cf71c..58939a3970c5 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -30,7 +30,7 @@ // Rust Protobuf runtime using the C++ kernel. -use crate::__internal::RawArena; +use crate::__internal::{Private, RawArena, RawMessage}; use std::alloc::Layout; use std::cell::UnsafeCell; use std::fmt; @@ -156,6 +156,47 @@ impl fmt::Debug for SerializedData { } } +pub type BytesPresentMutData<'msg> = crate::vtable::RawVTableOptionalMutatorData<'msg, [u8]>; +pub type BytesAbsentMutData<'msg> = crate::vtable::RawVTableOptionalMutatorData<'msg, [u8]>; +pub type InnerBytesMut<'msg> = crate::vtable::RawVTableMutator<'msg, [u8]>; + +/// The raw contents of every generated message. 
+#[derive(Debug)] +pub struct MessageInner { + pub msg: RawMessage, +} + +/// Mutators that point to their original message use this to do so. +/// +/// Since C++ messages manage their own memory, this can just copy the +/// `RawMessage` instead of referencing an arena like UPB must. +/// +/// Note: even though this type is `Copy`, it should only be copied by +/// protobuf internals that can maintain mutation invariants. +#[derive(Clone, Copy, Debug)] +pub struct MutatorMessageRef<'msg> { + msg: RawMessage, + _phantom: PhantomData<&'msg mut ()>, +} +impl<'msg> MutatorMessageRef<'msg> { + #[allow(clippy::needless_pass_by_ref_mut)] // Sound construction requires mutable access. + pub fn new(_private: Private, msg: &'msg mut MessageInner) -> Self { + MutatorMessageRef { msg: msg.msg, _phantom: PhantomData } + } + + pub fn msg(&self) -> RawMessage { + self.msg + } +} + +pub fn copy_bytes_in_arena_if_needed_by_runtime<'a>( + _msg_ref: MutatorMessageRef<'a>, + val: &'a [u8], +) -> &'a [u8] { + // Nothing to do, the message manages its own string memory for C++. + val +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/internal.rs b/rust/internal.rs index 7423a3cb812a..c1d8dca3f6de 100644 --- a/rust/internal.rs +++ b/rust/internal.rs @@ -32,6 +32,9 @@ //! exposed to through the `protobuf` path but must be public for use by //! generated code. +pub use crate::vtable::{ + new_vtable_field_entry, BytesMutVTable, BytesOptionalMutVTable, RawVTableMutator, +}; use std::slice; /// Used to protect internal-only items from being used accidentally. diff --git a/rust/shared.rs b/rust/shared.rs index dfee341a5f83..5abd8be06265 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -67,6 +67,7 @@ pub mod __runtime; mod optional; mod proxied; mod string; +mod vtable; /// An error that happened during deserialization. 
#[derive(Debug, Clone)] diff --git a/rust/string.rs b/rust/string.rs index 143d2849b496..fe788710678b 100644 --- a/rust/string.rs +++ b/rust/string.rs @@ -32,7 +32,8 @@ #![allow(dead_code)] #![allow(unused)] -use crate::__internal::Private; +use crate::__internal::{Private, PtrAndLen, RawMessage}; +use crate::__runtime::{BytesAbsentMutData, BytesPresentMutData, InnerBytesMut}; use crate::{Mut, MutProxy, Proxied, ProxiedWithPresence, SettableValue, View, ViewProxy}; use std::borrow::Cow; use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd}; @@ -43,10 +44,6 @@ use std::iter; use std::ops::{Deref, DerefMut}; use utf8::Utf8Chunks; -/// This type will be replaced by something else in a future revision. -// TODO(b/285309330): remove this and any `impl`s using it. -pub type Todo<'msg> = (std::convert::Infallible, std::marker::PhantomData<&'msg mut ()>); - /// A mutator for `bytes` fields - this type is `protobuf::Mut<'msg, [u8]>`. /// /// This type implements `Deref`, so many operations are @@ -62,9 +59,29 @@ pub type Todo<'msg> = (std::convert::Infallible, std::marker::PhantomData<&'msg /// recommended to instead build a `Vec` or `String` and pass that directly /// to `set`, which will reuse the allocation if supported by the runtime. #[derive(Debug)] -pub struct BytesMut<'msg>(Todo<'msg>); +pub struct BytesMut<'msg> { + inner: InnerBytesMut<'msg>, +} + +// SAFETY: +// - Protobuf Rust messages don't allow shared mutation across threads. +// - Protobuf Rust messages don't share arenas. +// - All access that touches an arena occurs behind a `&mut`. +// - All mutators that store an arena are `!Send`. +unsafe impl Sync for BytesMut<'_> {} impl<'msg> BytesMut<'msg> { + /// Constructs a new `BytesMut` from its internal, runtime-dependent part. + #[doc(hidden)] + pub fn from_inner(_private: Private, inner: InnerBytesMut<'msg>) -> Self { + Self { inner } + } + + /// Gets the current value of the field. 
+ pub fn get(&self) -> &[u8] { + self.as_view() + } + /// Sets the byte string to the given `val`, cloning any borrowed data. /// /// This method accepts both owned and borrowed byte strings; if the runtime @@ -78,7 +95,7 @@ impl<'msg> BytesMut<'msg> { /// /// Has no effect if `new_len` is larger than the current `len`. pub fn truncate(&mut self, new_len: usize) { - todo!("b/285309330") + self.inner.truncate(new_len) } /// Clears the byte string to the empty string. @@ -93,7 +110,7 @@ impl<'msg> BytesMut<'msg> { /// `BytesMut::clear` results in the accessor returning an empty string /// while `FieldEntry::clear` results in the non-empty default. /// - /// However, for a proto3 `bytes` that has implicit presence, there is no + /// However, for a proto3 `bytes` that have implicit presence, there is no /// distinction between these states: unset `bytes` is the same as empty /// `bytes` and the default is always the empty string. /// @@ -117,7 +134,7 @@ impl Deref for BytesMut<'_> { impl AsRef<[u8]> for BytesMut<'_> { fn as_ref(&self) -> &[u8] { - todo!("b/285309330") + unsafe { self.inner.get() } } } @@ -126,45 +143,20 @@ impl Proxied for [u8] { type Mut<'msg> = BytesMut<'msg>; } -impl<'msg> ViewProxy<'msg> for Todo<'msg> { - type Proxied = [u8]; - fn as_view(&self) -> &[u8] { - unreachable!() - } - fn into_view<'shorter>(self) -> &'shorter [u8] - where - 'msg: 'shorter, - { - unreachable!() - } -} - -impl<'msg> MutProxy<'msg> for Todo<'msg> { - fn as_mut(&mut self) -> BytesMut<'msg> { - unreachable!() - } - fn into_mut<'shorter>(self) -> BytesMut<'shorter> - where - 'msg: 'shorter, - { - unreachable!() - } -} - impl ProxiedWithPresence for [u8] { - type PresentMutData<'msg> = Todo<'msg>; - type AbsentMutData<'msg> = Todo<'msg>; + type PresentMutData<'msg> = BytesPresentMutData<'msg>; + type AbsentMutData<'msg> = BytesAbsentMutData<'msg>; fn clear_present_field<'a>( present_mutator: Self::PresentMutData<'a>, ) -> Self::AbsentMutData<'a> { - todo!("b/285309330") + 
present_mutator.clear() } fn set_absent_to_default<'a>( absent_mutator: Self::AbsentMutData<'a>, ) -> Self::PresentMutData<'a> { - todo!("b/285309330") + absent_mutator.set_absent_to_default() } } @@ -194,48 +186,89 @@ impl<'msg> ViewProxy<'msg> for BytesMut<'msg> { where 'msg: 'shorter, { - todo!("b/285309330") + self.inner.get() } } impl<'msg> MutProxy<'msg> for BytesMut<'msg> { fn as_mut(&mut self) -> BytesMut<'_> { - todo!("b/285309330") + BytesMut { inner: self.inner } } fn into_mut<'shorter>(self) -> BytesMut<'shorter> where 'msg: 'shorter, { - todo!("b/285309330") + BytesMut { inner: self.inner } } } -impl SettableValue<[u8]> for &'_ [u8] { +impl<'bytes> SettableValue<[u8]> for &'bytes [u8] { fn set_on(self, _private: Private, mutator: BytesMut<'_>) { - todo!("b/285309330") + // SAFETY: this is a `bytes` field with no restriction on UTF-8. + unsafe { mutator.inner.set(self) } + } + + fn set_on_absent( + self, + _private: Private, + absent_mutator: <[u8] as ProxiedWithPresence>::AbsentMutData<'_>, + ) -> <[u8] as ProxiedWithPresence>::PresentMutData<'_> { + // SAFETY: this is a `bytes` field with no restriction on UTF-8. + unsafe { absent_mutator.set(self) } + } + + fn set_on_present( + self, + _private: Private, + present_mutator: <[u8] as ProxiedWithPresence>::PresentMutData<'_>, + ) { + // SAFETY: this is a `bytes` field with no restriction on UTF-8. + unsafe { + present_mutator.set(self); + } } } -impl SettableValue<[u8]> for &'_ [u8; N] { - fn set_on(self, _private: Private, mutator: BytesMut<'_>) { - self[..].set_on(Private, mutator) - } +macro_rules! 
impl_forwarding_settable_value { + ($proxied:ty, $self:ident => $self_forwarding_expr:expr) => { + fn set_on($self, _private: Private, mutator: BytesMut<'_>) { + ($self_forwarding_expr).set_on(Private, mutator) + } + + fn set_on_absent( + $self, + _private: Private, + absent_mutator: <$proxied as ProxiedWithPresence>::AbsentMutData<'_>, + ) -> <$proxied as ProxiedWithPresence>::PresentMutData<'_> { + ($self_forwarding_expr).set_on_absent(Private, absent_mutator) + } + + fn set_on_present( + $self, + _private: Private, + present_mutator: <$proxied as ProxiedWithPresence>::PresentMutData<'_>, + ) { + ($self_forwarding_expr).set_on_present(Private, present_mutator) + } + }; +} + +impl<'a, const N: usize> SettableValue<[u8]> for &'a [u8; N] { + // forward to `self[..]` + impl_forwarding_settable_value!([u8], self => &self[..]); } impl SettableValue<[u8]> for Vec { - fn set_on(self, _private: Private, mutator: BytesMut<'_>) { - todo!("b/285309330") - } + // TODO(b/293956360): Investigate taking ownership of this when allowed by the + // runtime. + impl_forwarding_settable_value!([u8], self => &self[..]); } impl SettableValue<[u8]> for Cow<'_, [u8]> { - fn set_on(self, _private: Private, mutator: BytesMut<'_>) { - match self { - Cow::Borrowed(s) => s.set_on(Private, mutator), - Cow::Owned(v) => v.set_on(Private, mutator), - } - } + // TODO(b/293956360): Investigate taking ownership of this when allowed by the + // runtime. 
+ impl_forwarding_settable_value!([u8], self => &self[..]); } impl Hash for BytesMut<'_> { diff --git a/rust/test/BUILD b/rust/test/BUILD index 2ff1bc85437e..84566ffc7252 100644 --- a/rust/test/BUILD +++ b/rust/test/BUILD @@ -10,6 +10,8 @@ UNITTEST_PROTO_TARGET = "//src/google/protobuf:test_protos" UNITTEST_CC_PROTO_TARGET = "//src/google/protobuf:cc_test_protos" UNITTEST_PROTO3_TARGET = "//src/google/protobuf:test_protos" UNITTEST_PROTO3_CC_TARGET = "//src/google/protobuf:cc_test_protos" +UNITTEST_PROTO3_OPTIONAL_TARGET = "//src/google/protobuf:test_protos" +UNITTEST_PROTO3_OPTIONAL_CC_TARGET = "//src/google/protobuf:cc_test_protos" alias( name = "unittest_cc_proto", @@ -46,6 +48,68 @@ rust_cc_proto_library( deps = [UNITTEST_CC_PROTO_TARGET], ) +rust_proto_library( + name = "unittest_proto3_rust_proto", + testonly = True, + visibility = [ + "//rust/test/shared:__subpackages__", + ], + deps = [ + UNITTEST_PROTO3_TARGET, + ], +) + +rust_cc_proto_library( + name = "unittest_proto3_cc_rust_proto", + testonly = True, + visibility = [ + "//rust/test/cpp:__subpackages__", + "//rust/test/shared:__subpackages__", + ], + deps = [UNITTEST_PROTO3_CC_TARGET], +) + +rust_upb_proto_library( + name = "unittest_proto3_upb_rust_proto", + testonly = True, + visibility = [ + "//rust/test/cpp:__subpackages__", + "//rust/test/shared:__subpackages__", + ], + deps = [UNITTEST_PROTO3_TARGET], +) + +rust_proto_library( + name = "unittest_proto3_optional_rust_proto", + testonly = True, + visibility = [ + "//rust/test/shared:__subpackages__", + ], + deps = [ + UNITTEST_PROTO3_OPTIONAL_TARGET, + ], +) + +rust_cc_proto_library( + name = "unittest_proto3_optional_cc_rust_proto", + testonly = True, + visibility = [ + "//rust/test/cpp:__subpackages__", + "//rust/test/shared:__subpackages__", + ], + deps = [UNITTEST_PROTO3_OPTIONAL_CC_TARGET], +) + +rust_upb_proto_library( + name = "unittest_proto3_optional_upb_rust_proto", + testonly = True, + visibility = [ + "//rust/test/cpp:__subpackages__", + 
"//rust/test/shared:__subpackages__", + ], + deps = [UNITTEST_PROTO3_OPTIONAL_TARGET], +) + proto_library( name = "parent_proto", srcs = ["parent.proto"], diff --git a/rust/test/cpp/interop/main.rs b/rust/test/cpp/interop/main.rs index 31584303fbf8..1ab03f5060c3 100644 --- a/rust/test/cpp/interop/main.rs +++ b/rust/test/cpp/interop/main.rs @@ -64,7 +64,7 @@ fn mutate_message_in_cpp() { let mut msg2 = TestAllTypes::new(); msg2.optional_int64_set(Some(42)); - msg2.optional_bytes_set(Some(b"something mysterious")); + msg2.optional_bytes_mut().set(b"something mysterious"); msg2.optional_bool_set(Some(false)); proto_assert_eq!(msg1, msg2); @@ -74,7 +74,7 @@ fn mutate_message_in_cpp() { fn deserialize_in_rust() { let mut msg1 = TestAllTypes::new(); msg1.optional_int64_set(Some(-1)); - msg1.optional_bytes_set(Some(b"some cool data I guess")); + msg1.optional_bytes_mut().set(b"some cool data I guess"); let serialized = unsafe { SerializeTestAllTypes(msg1.__unstable_cpp_repr_grant_permission_to_break()) }; @@ -87,7 +87,7 @@ fn deserialize_in_rust() { fn deserialize_in_cpp() { let mut msg1 = TestAllTypes::new(); msg1.optional_int64_set(Some(-1)); - msg1.optional_bytes_set(Some(b"some cool data I guess")); + msg1.optional_bytes_mut().set(b"some cool data I guess"); let data = msg1.serialize(); let msg2 = unsafe { diff --git a/rust/test/shared/BUILD b/rust/test/shared/BUILD index ac1486e27ddd..1ade4ce36466 100644 --- a/rust/test/shared/BUILD +++ b/rust/test/shared/BUILD @@ -87,25 +87,51 @@ rust_test( rust_test( name = "accessors_cpp_test", srcs = ["accessors_test.rs"], - deps = ["//rust/test:unittest_cc_rust_proto"], + aliases = { + "//rust:protobuf_cpp": "protobuf", + }, + deps = [ + "//rust:protobuf_cpp", + "//rust/test:unittest_cc_rust_proto", + ], ) rust_test( name = "accessors_upb_test", srcs = ["accessors_test.rs"], - deps = ["//rust/test:unittest_upb_rust_proto"], + aliases = { + "//rust:protobuf_upb": "protobuf", + }, + deps = [ + "//rust:protobuf_upb", + 
"//rust/test:unittest_upb_rust_proto", + ], ) rust_test( name = "accessors_proto3_cpp_test", srcs = ["accessors_proto3_test.rs"], - deps = ["//rust/test:unittest_proto3_cc_rust_proto"], + aliases = { + "//rust:protobuf_cpp": "protobuf", + }, + deps = [ + "//rust:protobuf_cpp", + "//rust/test:unittest_proto3_cc_rust_proto", + "//rust/test:unittest_proto3_optional_cc_rust_proto", + ], ) rust_test( name = "accessors_proto3_upb_test", srcs = ["accessors_proto3_test.rs"], - deps = ["//rust/test:unittest_proto3_upb_rust_proto"], + aliases = { + "//rust:protobuf_upb": "protobuf", + }, + deps = [ + "//rust:protobuf_upb", + "//rust/test:unittest_proto3_optional_upb_rust_proto", + "//rust/test:unittest_proto3_upb_rust_proto", + ], ) rust_test( diff --git a/rust/test/shared/accessors_proto3_test.rs b/rust/test/shared/accessors_proto3_test.rs index 9a002b68ebd2..2789f11aef49 100644 --- a/rust/test/shared/accessors_proto3_test.rs +++ b/rust/test/shared/accessors_proto3_test.rs @@ -30,7 +30,9 @@ /// Tests covering accessors for singular bool, int32, int64, and bytes fields /// on proto3. +use protobuf::Optional; use unittest_proto3::proto3_unittest::TestAllTypes; +use unittest_proto3_optional::proto2_unittest::TestProto3Optional; #[test] fn test_fixed32_accessors() { @@ -45,18 +47,75 @@ fn test_fixed32_accessors() { } #[test] -fn test_optional_bytes_accessors() { +fn test_bytes_accessors() { let mut msg = TestAllTypes::new(); // Note: even though its named 'optional_bytes' the field is actually not proto3 // optional, so it does not support presence. 
assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_mut().get(), b""); - msg.optional_bytes_set(Some(b"accessors_test")); + msg.optional_bytes_mut().set(b"accessors_test"); assert_eq!(msg.optional_bytes(), b"accessors_test"); + assert_eq!(msg.optional_bytes_mut().get(), b"accessors_test"); + + { + let s = Vec::from(&b"hello world"[..]); + msg.optional_bytes_mut().set(&s[..]); + } + assert_eq!(msg.optional_bytes(), b"hello world"); + assert_eq!(msg.optional_bytes_mut().get(), b"hello world"); - msg.optional_bytes_set(None); + msg.optional_bytes_mut().clear(); assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_mut().get(), b""); - msg.optional_bytes_set(Some(b"")); + msg.optional_bytes_mut().set(b""); assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_mut().get(), b""); +} + +#[test] +fn test_optional_bytes_accessors() { + let mut msg = TestProto3Optional::new(); + assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_opt(), Optional::Unset(&b""[..])); + assert_eq!(msg.optional_bytes_mut().get(), b""); + assert!(msg.optional_bytes_mut().is_unset()); + + { + let s = Vec::from(&b"hello world"[..]); + msg.optional_bytes_mut().set(&s[..]); + } + assert_eq!(msg.optional_bytes(), b"hello world"); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b"hello world"[..])); + assert!(msg.optional_bytes_mut().is_set()); + assert_eq!(msg.optional_bytes_mut().get(), b"hello world"); + + msg.optional_bytes_mut().or_default().set(b"accessors_test"); + assert_eq!(msg.optional_bytes(), b"accessors_test"); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b"accessors_test"[..])); + assert!(msg.optional_bytes_mut().is_set()); + assert_eq!(msg.optional_bytes_mut().get(), b"accessors_test"); + assert_eq!(msg.optional_bytes_mut().or_default().get(), b"accessors_test"); + + msg.optional_bytes_mut().clear(); + assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_opt(), 
Optional::Unset(&b""[..])); + assert!(msg.optional_bytes_mut().is_unset()); + + msg.optional_bytes_mut().set(b""); + assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b""[..])); + + msg.optional_bytes_mut().clear(); + msg.optional_bytes_mut().or_default(); + assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b""[..])); + + msg.optional_bytes_mut().or_default().set(b"\xffbinary\x85non-utf8"); + assert_eq!(msg.optional_bytes(), b"\xffbinary\x85non-utf8"); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b"\xffbinary\x85non-utf8"[..])); + assert!(msg.optional_bytes_mut().is_set()); + assert_eq!(msg.optional_bytes_mut().get(), b"\xffbinary\x85non-utf8"); + assert_eq!(msg.optional_bytes_mut().or_default().get(), b"\xffbinary\x85non-utf8"); } diff --git a/rust/test/shared/accessors_test.rs b/rust/test/shared/accessors_test.rs index 55e700d35ee7..32679ea808f4 100644 --- a/rust/test/shared/accessors_test.rs +++ b/rust/test/shared/accessors_test.rs @@ -28,7 +28,9 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -/// Tests covering accessors for singular bool, int32, int64, and bytes fields. +//! Tests covering accessors for singular bool, int32, int64, and bytes fields. 
+ +use protobuf::Optional; use unittest_proto::proto2_unittest::TestAllTypes; #[test] @@ -216,16 +218,93 @@ fn test_optional_bool_accessors() { #[test] fn test_optional_bytes_accessors() { let mut msg = TestAllTypes::new(); - assert_eq!(msg.optional_bytes_opt(), None); - - msg.optional_bytes_set(Some(b"accessors_test")); - assert_eq!(msg.optional_bytes_opt().unwrap(), b"accessors_test"); - - msg.optional_bytes_set(None); - assert_eq!(msg.optional_bytes_opt(), None); + assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_opt(), Optional::Unset(&b""[..])); + assert_eq!(msg.optional_bytes_mut().get(), b""); + assert!(msg.optional_bytes_mut().is_unset()); + + { + let s = Vec::from(&b"hello world"[..]); + msg.optional_bytes_mut().set(&s[..]); + } + assert_eq!(msg.optional_bytes(), b"hello world"); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b"hello world"[..])); + assert!(msg.optional_bytes_mut().is_set()); + assert_eq!(msg.optional_bytes_mut().get(), b"hello world"); + + msg.optional_bytes_mut().or_default().set(b"accessors_test"); + assert_eq!(msg.optional_bytes(), b"accessors_test"); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b"accessors_test"[..])); + assert!(msg.optional_bytes_mut().is_set()); + assert_eq!(msg.optional_bytes_mut().get(), b"accessors_test"); + assert_eq!(msg.optional_bytes_mut().or_default().get(), b"accessors_test"); + + msg.optional_bytes_mut().clear(); + assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_opt(), Optional::Unset(&b""[..])); + assert!(msg.optional_bytes_mut().is_unset()); + + msg.optional_bytes_mut().set(b""); + assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b""[..])); + + msg.optional_bytes_mut().clear(); + msg.optional_bytes_mut().or_default(); + assert_eq!(msg.optional_bytes(), b""); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b""[..])); + + msg.optional_bytes_mut().or_default().set(b"\xffbinary\x85non-utf8"); + 
assert_eq!(msg.optional_bytes(), b"\xffbinary\x85non-utf8"); + assert_eq!(msg.optional_bytes_opt(), Optional::Set(&b"\xffbinary\x85non-utf8"[..])); + assert!(msg.optional_bytes_mut().is_set()); + assert_eq!(msg.optional_bytes_mut().get(), b"\xffbinary\x85non-utf8"); + assert_eq!(msg.optional_bytes_mut().or_default().get(), b"\xffbinary\x85non-utf8"); +} - msg.optional_bytes_set(Some(b"")); - assert_eq!(msg.optional_bytes_opt().unwrap(), b""); +#[test] +fn test_nonempty_default_bytes_accessors() { + let mut msg = TestAllTypes::new(); + assert_eq!(msg.default_bytes(), b"world"); + assert_eq!(msg.default_bytes_opt(), Optional::Unset(&b"world"[..])); + assert_eq!(msg.default_bytes_mut().get(), b"world"); + assert!(msg.default_bytes_mut().is_unset()); + + { + let s = String::from("hello world"); + msg.default_bytes_mut().set(s.as_bytes()); + } + assert_eq!(msg.default_bytes(), b"hello world"); + assert_eq!(msg.default_bytes_opt(), Optional::Set(&b"hello world"[..])); + assert!(msg.default_bytes_mut().is_set()); + assert_eq!(msg.default_bytes_mut().get(), b"hello world"); + + msg.default_bytes_mut().or_default().set(b"accessors_test"); + assert_eq!(msg.default_bytes(), b"accessors_test"); + assert_eq!(msg.default_bytes_opt(), Optional::Set(&b"accessors_test"[..])); + assert!(msg.default_bytes_mut().is_set()); + assert_eq!(msg.default_bytes_mut().get(), b"accessors_test"); + assert_eq!(msg.default_bytes_mut().or_default().get(), b"accessors_test"); + + msg.default_bytes_mut().clear(); + assert_eq!(msg.default_bytes(), b"world"); + assert_eq!(msg.default_bytes_opt(), Optional::Unset(&b"world"[..])); + assert!(msg.default_bytes_mut().is_unset()); + + msg.default_bytes_mut().set(b""); + assert_eq!(msg.default_bytes(), b""); + assert_eq!(msg.default_bytes_opt(), Optional::Set(&b""[..])); + + msg.default_bytes_mut().clear(); + msg.default_bytes_mut().or_default(); + assert_eq!(msg.default_bytes(), b"world"); + assert_eq!(msg.default_bytes_opt(), Optional::Set(&b"world"[..])); 
+ + msg.default_bytes_mut().or_default().set(b"\xffbinary\x85non-utf8"); + assert_eq!(msg.default_bytes(), b"\xffbinary\x85non-utf8"); + assert_eq!(msg.default_bytes_opt(), Optional::Set(&b"\xffbinary\x85non-utf8"[..])); + assert!(msg.default_bytes_mut().is_set()); + assert_eq!(msg.default_bytes_mut().get(), b"\xffbinary\x85non-utf8"); + assert_eq!(msg.default_bytes_mut().or_default().get(), b"\xffbinary\x85non-utf8"); } #[test] diff --git a/rust/test/shared/serialization_test.rs b/rust/test/shared/serialization_test.rs index ddfce520ce2f..c265c166aaa3 100644 --- a/rust/test/shared/serialization_test.rs +++ b/rust/test/shared/serialization_test.rs @@ -35,7 +35,7 @@ fn serialize_deserialize_message() { let mut msg = TestAllTypes::new(); msg.optional_int64_set(Some(42)); msg.optional_bool_set(Some(true)); - msg.optional_bytes_set(Some(b"serialize deserialize test")); + msg.optional_bytes_mut().set(b"serialize deserialize test"); let serialized = msg.serialize(); diff --git a/rust/upb.rs b/rust/upb.rs index a887052a4e9f..957f4268d52c 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -30,7 +30,7 @@ //! UPB FFI wrapper code for use by Rust Protobuf. -use crate::__internal::RawArena; +use crate::__internal::{Private, RawArena, RawMessage}; use std::alloc; use std::alloc::Layout; use std::cell::UnsafeCell; @@ -200,6 +200,72 @@ impl fmt::Debug for SerializedData { } } +// TODO(b/293919363): Investigate replacing this with direct access to UPB bits. +pub type BytesPresentMutData<'msg> = crate::vtable::RawVTableOptionalMutatorData<'msg, [u8]>; +pub type BytesAbsentMutData<'msg> = crate::vtable::RawVTableOptionalMutatorData<'msg, [u8]>; +pub type InnerBytesMut<'msg> = crate::vtable::RawVTableMutator<'msg, [u8]>; + +/// The raw contents of every generated message. +#[derive(Debug)] +pub struct MessageInner { + pub msg: RawMessage, + pub arena: Arena, +} + +/// Mutators that point to their original message use this to do so. 
+/// +/// Since UPB expects runtimes to manage their own arenas, this needs to have +/// access to an `Arena`. +/// +/// This has two possible designs: +/// - Store two pointers here, `RawMessage` and `&'msg Arena`. This doesn't +/// place any restriction on the layout of generated messages and their +/// mutators. This makes a vtable-based mutator three pointers, which can no +/// longer be returned in registers on most platforms. +/// - Store one pointer here, `&'msg MessageInner`, where `MessageInner` stores +/// a `RawMessage` and an `Arena`. This would require all generated messages +/// to store `MessageInner`, and since their mutators need to be able to +/// generate `BytesMut`, would also require `BytesMut` to store a `&'msg +/// MessageInner` since they can't store an owned `Arena`. +/// +/// Note: even though this type is `Copy`, it should only be copied by +/// protobuf internals that can maintain mutation invariants. +#[derive(Clone, Copy, Debug)] +pub struct MutatorMessageRef<'msg> { + msg: RawMessage, + arena: &'msg Arena, +} + +impl<'msg> MutatorMessageRef<'msg> { + #[doc(hidden)] + #[allow(clippy::needless_pass_by_ref_mut)] // Sound construction requires mutable access. + pub fn new(_private: Private, msg: &'msg mut MessageInner) -> Self { + MutatorMessageRef { msg: msg.msg, arena: &msg.arena } + } + + pub fn msg(&self) -> RawMessage { + self.msg + } +} + +pub fn copy_bytes_in_arena_if_needed_by_runtime<'a>( + msg_ref: MutatorMessageRef<'a>, + val: &'a [u8], +) -> &'a [u8] { + // SAFETY: the alignment of `[u8]` is less than `UPB_MALLOC_ALIGN`. + let new_alloc = unsafe { msg_ref.arena.alloc(Layout::for_value(val)) }; + debug_assert_eq!(new_alloc.len(), val.len()); + + let start: *mut u8 = new_alloc.as_mut_ptr().cast(); + // SAFETY: + // - `new_alloc` is writeable for `val.len()` bytes. + // - After the copy, `new_alloc` is initialized for `val.len()` bytes. 
+ unsafe { + val.as_ptr().copy_to_nonoverlapping(start, val.len()); + &*(new_alloc as *mut _ as *mut [u8]) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/vtable.rs b/rust/vtable.rs new file mode 100644 index 000000000000..293710bcde93 --- /dev/null +++ b/rust/vtable.rs @@ -0,0 +1,378 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2023 Google LLC. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google LLC. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +use crate::__internal::{Private, PtrAndLen, RawMessage}; +use crate::__runtime::{copy_bytes_in_arena_if_needed_by_runtime, MutatorMessageRef}; +use crate::{ + AbsentField, FieldEntry, Mut, MutProxy, Optional, PresentField, Proxied, ProxiedWithPresence, + View, ViewProxy, +}; +use std::fmt::{self, Debug}; + +/// A proxied type that can use a vtable to provide get/set access for a +/// present field. +/// +/// This vtable should consist of `unsafe fn`s that call thunks that operate on +/// `RawMessage`. The structure of this vtable is different per proxied type. +pub trait ProxiedWithRawVTable: Proxied { + /// The vtable for get/set access, stored in static memory. + type VTable: Debug + 'static; + + fn make_view(_private: Private, mut_inner: RawVTableMutator<'_, Self>) -> View<'_, Self>; + fn make_mut(_private: Private, inner: RawVTableMutator<'_, Self>) -> Mut<'_, Self>; +} + +/// A proxied type that can use a vtable to provide get/set/clear access for +/// an optional field. +/// +/// This vtable should consist of `unsafe fn`s that call thunks that operate on +/// `RawMessage`. The structure of this vtable is different per-proxied type. +pub trait ProxiedWithRawOptionalVTable: ProxiedWithRawVTable + ProxiedWithPresence { + /// The vtable for get/set/clear, must contain `Self::VTable`. + type OptionalVTable: Debug + 'static; + + /// Cast from a static reference of `OptionalVTable` to `VTable`. + /// This should mean `OptionalVTable` contains a `VTable`. + fn upcast_vtable( + _private: Private, + optional_vtable: &'static Self::OptionalVTable, + ) -> &'static Self::VTable; +} + +/// Constructs a new field entry from a raw message, a vtable for manipulation, +/// and an eager check for whether the value is present or not. +/// +/// # Safety +/// - `msg_ref` must be valid to provide as an argument for `vtable`'s methods +/// for `'msg`. +/// - If given `msg_ref` as an argument, any values returned by `vtable` methods +/// must be valid for `'msg`. 
+/// - Operations on the vtable must be thread-compatible.
+#[doc(hidden)]
+pub unsafe fn new_vtable_field_entry<'msg, T: ProxiedWithRawOptionalVTable + ?Sized>(
+    _private: Private,
+    msg_ref: MutatorMessageRef<'msg>,
+    optional_vtable: &'static T::OptionalVTable,
+    is_set: bool,
+) -> FieldEntry<'msg, T>
+where
+    T: ProxiedWithPresence<
+        PresentMutData<'msg> = RawVTableOptionalMutatorData<'msg, T>,
+        AbsentMutData<'msg> = RawVTableOptionalMutatorData<'msg, T>,
+    >,
+{
+    let data = RawVTableOptionalMutatorData { msg_ref, vtable: optional_vtable };
+    if is_set {
+        Optional::Set(PresentField::from_inner(Private, data))
+    } else {
+        Optional::Unset(AbsentField::from_inner(Private, data))
+    }
+}
+
+/// The internal implementation type for a vtable-based `protobuf::Mut`.
+///
+/// This stores the two components necessary to mutate the field:
+/// borrowed message data and a vtable reference.
+///
+/// The borrowed message data varies per runtime: C++ needs a message pointer,
+/// while UPB needs a message pointer and an `&Arena`.
+///
+/// Implementations of `ProxiedWithRawVTable` implement get/set
+/// on top of `RawVTableMutator`, and the top-level mutator (e.g.
+/// `BytesMut`) calls these methods.
+///
+/// [`RawVTableOptionalMutatorData`] is similar, but also includes the
+/// capabilities needed for optional fields (has/clear).
+pub struct RawVTableMutator<'msg, T: ProxiedWithRawVTable + ?Sized> {
+    msg_ref: MutatorMessageRef<'msg>,
+    vtable: &'static T::VTable,
+}
+
+// These use manual impls instead of derives to avoid unnecessary bounds on `T`.
+// This problem is referred to as "perfect derive".
+// https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/
+impl<'msg, T: ProxiedWithRawVTable + ?Sized> Clone for RawVTableMutator<'msg, T> {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+impl<'msg, T: ProxiedWithRawVTable + ?Sized> Copy for RawVTableMutator<'msg, T> {}
+
+impl<'msg, T: ProxiedWithRawVTable + ?Sized> Debug for RawVTableMutator<'msg, T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("RawVTableMutator")
+            .field("msg_ref", &self.msg_ref)
+            .field("vtable", &self.vtable)
+            .finish()
+    }
+}
+
+impl<'msg, T: ProxiedWithRawVTable + ?Sized> RawVTableMutator<'msg, T> {
+    /// # Safety
+    /// - `msg_ref` must be valid to provide as an argument for `vtable`'s
+    ///   methods for `'msg`.
+    /// - If given `msg_ref` as an argument, any values returned by `vtable`
+    ///   methods must be valid for `'msg`.
+    #[doc(hidden)]
+    pub unsafe fn new(
+        _private: Private,
+        msg_ref: MutatorMessageRef<'msg>,
+        vtable: &'static T::VTable,
+    ) -> Self {
+        RawVTableMutator { msg_ref, vtable }
+    }
+}
+
+/// [`RawVTableMutator`], but also includes has/clear.
+///
+/// This is used as the `PresentMutData` and `AbsentMutData` for `impl
+/// ProxiedWithPresence for T`. In that implementation, `clear_present_field`
+/// and `set_absent_to_default` will use methods implemented on
+/// `RawVTableOptionalMutatorData` to do the setting and clearing.
+///
+/// This has the same representation for "present" and "absent" data;
+/// differences like default values are obviated by the vtable.
+pub struct RawVTableOptionalMutatorData<'msg, T: ProxiedWithRawOptionalVTable + ?Sized> {
+    msg_ref: MutatorMessageRef<'msg>,
+    vtable: &'static T::OptionalVTable,
+}
+
+unsafe impl<'msg, T: ProxiedWithRawOptionalVTable + ?Sized> Sync
+    for RawVTableOptionalMutatorData<'msg, T>
+{
+}
+
+// These use manual impls instead of derives to avoid unnecessary bounds on `T`.
+// This problem is referred to as "perfect derive".
+// https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/
+impl<'msg, T: ProxiedWithRawOptionalVTable + ?Sized> Clone
+    for RawVTableOptionalMutatorData<'msg, T>
+{
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+impl<'msg, T: ProxiedWithRawOptionalVTable + ?Sized> Copy
+    for RawVTableOptionalMutatorData<'msg, T>
+{
+}
+
+impl<'msg, T: ProxiedWithRawOptionalVTable + ?Sized> Debug
+    for RawVTableOptionalMutatorData<'msg, T>
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("RawVTableOptionalMutatorData")
+            .field("msg_ref", &self.msg_ref)
+            .field("vtable", &self.vtable)
+            .finish()
+    }
+}
+
+impl<'msg, T: ProxiedWithRawOptionalVTable + ?Sized> RawVTableOptionalMutatorData<'msg, T> {
+    /// # Safety
+    /// - `msg_ref` must be valid to provide as an argument for `vtable`'s
+    ///   methods for `'msg`.
+    /// - If given `msg_ref` as an argument, any values returned by `vtable`
+    ///   methods must be valid for `'msg`.
+    #[doc(hidden)]
+    pub unsafe fn new(
+        _private: Private,
+        msg_ref: MutatorMessageRef<'msg>,
+        vtable: &'static T::OptionalVTable,
+    ) -> Self {
+        Self { msg_ref, vtable }
+    }
+
+    fn into_raw_mut(self) -> RawVTableMutator<'msg, T> {
+        RawVTableMutator { msg_ref: self.msg_ref, vtable: T::upcast_vtable(Private, self.vtable) }
+    }
+}
+
+impl<'msg, T: ProxiedWithRawOptionalVTable + ?Sized + 'msg> ViewProxy<'msg>
+    for RawVTableOptionalMutatorData<'msg, T>
+{
+    type Proxied = T;
+
+    fn as_view(&self) -> View<'_, T> {
+        T::make_view(Private, self.into_raw_mut())
+    }
+
+    fn into_view<'shorter>(self) -> View<'shorter, T>
+    where
+        'msg: 'shorter,
+    {
+        T::make_view(Private, self.into_raw_mut())
+    }
+}
+
+// Note: though this raw value implements `MutProxy`, its `as_mut` is only valid
+// when the field is known to be present. `FieldEntry` enforces this in its
+// design: `AbsentField { inner: RawVTableOptionalMutatorData }` does not
+// implement `MutProxy`.
+impl<'msg, T: ProxiedWithRawOptionalVTable + ?Sized + 'msg> MutProxy<'msg>
+    for RawVTableOptionalMutatorData<'msg, T>
+{
+    fn as_mut(&mut self) -> Mut<'_, T> {
+        T::make_mut(Private, self.into_raw_mut())
+    }
+
+    fn into_mut<'shorter>(self) -> Mut<'shorter, T>
+    where
+        'msg: 'shorter,
+    {
+        T::make_mut(Private, self.into_raw_mut())
+    }
+}
+
+impl ProxiedWithRawVTable for [u8] {
+    type VTable = BytesMutVTable;
+
+    fn make_view(_private: Private, mut_inner: RawVTableMutator<'_, Self>) -> View<'_, Self> {
+        mut_inner.get()
+    }
+
+    fn make_mut(_private: Private, inner: RawVTableMutator<'_, Self>) -> Mut<'_, Self> {
+        crate::string::BytesMut::from_inner(Private, inner)
+    }
+}
+
+impl ProxiedWithRawOptionalVTable for [u8] {
+    type OptionalVTable = BytesOptionalMutVTable;
+    fn upcast_vtable(
+        _private: Private,
+        optional_vtable: &'static Self::OptionalVTable,
+    ) -> &'static Self::VTable {
+        &optional_vtable.base
+    }
+}
+
+/// A generic get/set thunk vtable for a present `bytes` or `string` field.
+#[doc(hidden)]
+#[derive(Debug)]
+pub struct BytesMutVTable {
+    pub(crate) setter: unsafe extern "C" fn(msg: RawMessage, val: *const u8, len: usize),
+    pub(crate) getter: unsafe extern "C" fn(msg: RawMessage) -> PtrAndLen,
+}
+
+/// A generic thunk vtable for mutating an `optional` `bytes` or `string` field.
+#[derive(Debug)]
+pub struct BytesOptionalMutVTable {
+    pub(crate) base: BytesMutVTable,
+    pub(crate) clearer: unsafe extern "C" fn(msg: RawMessage),
+    pub(crate) default: &'static [u8],
+}
+
+impl BytesMutVTable {
+    #[doc(hidden)]
+    pub const fn new(
+        _private: Private,
+        getter: unsafe extern "C" fn(msg: RawMessage) -> PtrAndLen,
+        setter: unsafe extern "C" fn(msg: RawMessage, val: *const u8, len: usize),
+    ) -> Self {
+        Self { getter, setter }
+    }
+}
+
+impl BytesOptionalMutVTable {
+    /// # Safety
+    /// The `default` value must be UTF-8 if required by
+    /// the runtime and this is for a `string` field.
+    #[doc(hidden)]
+    pub const unsafe fn new(
+        _private: Private,
+        getter: unsafe extern "C" fn(msg: RawMessage) -> PtrAndLen,
+        setter: unsafe extern "C" fn(msg: RawMessage, val: *const u8, len: usize),
+        clearer: unsafe extern "C" fn(msg: RawMessage),
+        default: &'static [u8],
+    ) -> Self {
+        Self { base: BytesMutVTable { getter, setter }, clearer, default }
+    }
+}
+
+impl<'msg> RawVTableMutator<'msg, [u8]> {
+    pub(crate) fn get(self) -> &'msg [u8] {
+        // SAFETY:
+        // - `msg_ref` is valid for `'msg` as promised by the caller of `new`.
+        // - The caller of `new` promised that values returned by `vtable` methods,
+        //   including this `PtrAndLen`, are valid for `'msg`.
+        unsafe { (self.vtable.getter)(self.msg_ref.msg()).as_ref() }
+    }
+
+    /// # Safety
+    /// - `msg_ref` must be valid for `'msg`.
+    /// - If this is for a `string` field, `val` must be valid UTF-8 if the
+    ///   runtime requires it.
+    pub(crate) unsafe fn set(self, val: &[u8]) {
+        let val = copy_bytes_in_arena_if_needed_by_runtime(self.msg_ref, val);
+        // SAFETY:
+        // - `msg_ref` is valid for `'msg` as promised by the caller of `new`.
+        unsafe { (self.vtable.setter)(self.msg_ref.msg(), val.as_ptr(), val.len()) }
+    }
+
+    pub(crate) fn truncate(&self, len: usize) {
+        if len == 0 {
+            // SAFETY: The empty string is valid UTF-8.
+            unsafe {
+                self.set(b"");
+            }
+            return;
+        }
+        todo!("b/294252563")
+    }
+}
+
+impl<'msg> RawVTableOptionalMutatorData<'msg, [u8]> {
+    /// Sets an absent `bytes`/`string` field to its default value.
+    pub(crate) fn set_absent_to_default(self) -> Self {
+        // SAFETY: The default value is UTF-8 if required by the
+        // runtime as promised by the caller of `BytesOptionalMutVTable::new`.
+        unsafe { self.set(self.vtable.default) }
+    }
+
+    /// # Safety
+    /// - If this is a `string` field, `val` must be valid UTF-8 if required by
+    ///   the runtime.
+    pub(crate) unsafe fn set(self, val: &[u8]) -> Self {
+        let val = copy_bytes_in_arena_if_needed_by_runtime(self.msg_ref, val);
+        // SAFETY:
+        // - `msg_ref` is valid for `'msg` as promised by the caller of `new`.
+        unsafe { (self.vtable.base.setter)(self.msg_ref.msg(), val.as_ptr(), val.len()) }
+        self
+    }
+
+    pub(crate) fn clear(self) -> Self {
+        // SAFETY:
+        // - `msg_ref` is valid for `'msg` as promised by the caller of `new`.
+        // - The clearer thunk returns nothing, so there is no returned value whose
+        //   validity for `'msg` needs to be upheld.
+        unsafe { (self.vtable.clearer)(self.msg_ref.msg()) }
+        self
+    }
+}
diff --git a/src/google/protobuf/compiler/rust/accessors/singular_bytes.cc b/src/google/protobuf/compiler/rust/accessors/singular_bytes.cc
index 3deb5a61e4b0..089ff63467cc 100644
--- a/src/google/protobuf/compiler/rust/accessors/singular_bytes.cc
+++ b/src/google/protobuf/compiler/rust/accessors/singular_bytes.cc
@@ -28,6 +28,9 @@
 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#include <string> + +#include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/rust/accessors/accessor_generator.h" @@ -41,45 +44,107 @@ namespace compiler { namespace rust { void SingularBytes::InMsgImpl(Context field) const { + std::string hazzer_thunk = Thunk(field, "has"); + std::string getter_thunk = Thunk(field, "get"); + std::string setter_thunk = Thunk(field, "set"); field.Emit( { {"field", field.desc().name()}, - {"hazzer_thunk", Thunk(field, "has")}, - {"getter_thunk", Thunk(field, "get")}, - {"setter_thunk", Thunk(field, "set")}, - {"clearer_thunk", Thunk(field, "clear")}, - {"getter_opt", + {"hazzer_thunk", hazzer_thunk}, + {"getter_thunk", getter_thunk}, + {"setter_thunk", setter_thunk}, + {"field_optional_getter", [&] { if (!field.desc().is_optional()) return; if (!field.desc().has_presence()) return; - field.Emit({}, R"rs( - pub fn $field$_opt(&self) -> Option<&[u8]> { - if !unsafe { $hazzer_thunk$(self.msg) } { - return None; - } - unsafe { - Some($getter_thunk$(self.msg).as_ref()) - } - })rs"); + field.Emit({{"hazzer_thunk", hazzer_thunk}, + {"getter_thunk", getter_thunk}}, + R"rs( + pub fn $field$_opt(&self) -> $pb$::Optional<&[u8]> { + unsafe { + $pb$::Optional::new( + $getter_thunk$(self.inner.msg).as_ref(), + $hazzer_thunk$(self.inner.msg) + ) + } + } + )rs"); + }}, + {"field_mutator_getter", + [&] { + if (field.desc().has_presence()) { + field.Emit( + { + {"field", field.desc().name()}, + {"default_val", + absl::CHexEscape(field.desc().default_value_string())}, + {"hazzer_thunk", hazzer_thunk}, + {"getter_thunk", getter_thunk}, + {"setter_thunk", setter_thunk}, + {"clearer_thunk", Thunk(field, "clear")}, + }, + R"rs( + pub fn $field$_mut(&mut self) -> $pb$::FieldEntry<'_, [u8]> { + static VTABLE: $pbi$::BytesOptionalMutVTable = unsafe { + $pbi$::BytesOptionalMutVTable::new( + $pbi$::Private, + $getter_thunk$, + $setter_thunk$, + $clearer_thunk$, +
b"$default_val$", + ) + }; + unsafe { + let has = $hazzer_thunk$(self.inner.msg); + $pbi$::new_vtable_field_entry( + $pbi$::Private, + $pbr$::MutatorMessageRef::new( + $pbi$::Private, &mut self.inner), + &VTABLE, + has, + ) + } + } + )rs"); + } else { + field.Emit({{"field", field.desc().name()}, + {"getter_thunk", getter_thunk}, + {"setter_thunk", setter_thunk}}, + R"rs( + pub fn $field$_mut(&mut self) -> $pb$::BytesMut<'_> { + static VTABLE: $pbi$::BytesMutVTable = unsafe { + $pbi$::BytesMutVTable::new( + $pbi$::Private, + $getter_thunk$, + $setter_thunk$, + ) + }; + unsafe { + $pb$::BytesMut::from_inner( + $pbi$::Private, + $pbi$::RawVTableMutator::new( + $pbi$::Private, + $pbr$::MutatorMessageRef::new( + $pbi$::Private, &mut self.inner), + &VTABLE, + ) + ) + } + } + )rs"); + } }}, }, R"rs( - pub fn r#$field$(&self) -> &[u8] { - unsafe { $getter_thunk$(self.msg).as_ref() } + pub fn r#$field$(&self) -> &[u8] { + unsafe { + $getter_thunk$(self.inner.msg).as_ref() } - $getter_opt$ - pub fn $field$_set(&mut self, val: Option<&[u8]>) { - match val { - Some(val) => - if val.len() == 0 { - unsafe { $setter_thunk$(self.msg, $std$::ptr::null(), 0) } - } else { - unsafe { $setter_thunk$(self.msg, val.as_ptr(), val.len()) } - }, - None => unsafe { $clearer_thunk$(self.msg) }, - } - } - )rs"); + } + + $field_optional_getter$ + $field_mutator_getter$ + )rs"); } void SingularBytes::InExternC(Context field) const { diff --git a/src/google/protobuf/compiler/rust/accessors/singular_message.cc b/src/google/protobuf/compiler/rust/accessors/singular_message.cc index c939714639c8..038abe78d00d 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_message.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_message.cc @@ -46,7 +46,7 @@ void SingularMessage::InMsgImpl(Context field) const { R"rs( // inMsgImpl pub fn r#$field$(&self) -> $Msg$View { - $Msg$View { msg: self.msg, _phantom: std::marker::PhantomData } + $Msg$View { msg: self.inner.msg, _phantom: 
std::marker::PhantomData } } )rs"); } diff --git a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc index 6be8f44eefc2..93a5d41b67b3 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc @@ -50,7 +50,7 @@ void SingularScalar::InMsgImpl(Context field) const { [&] { field.Emit({}, R"rs( pub fn r#$field$(&self) -> $Scalar$ { - unsafe { $getter_thunk$(self.msg) } + unsafe { $getter_thunk$(self.inner.msg) } } )rs"); }}, @@ -60,10 +60,10 @@ void SingularScalar::InMsgImpl(Context field) const { if (!field.desc().has_presence()) return; field.Emit({}, R"rs( pub fn r#$field$_opt(&self) -> Option<$Scalar$> { - if !unsafe { $hazzer_thunk$(self.msg) } { + if !unsafe { $hazzer_thunk$(self.inner.msg) } { return None; } - Some(unsafe { $getter_thunk$(self.msg) }) + Some(unsafe { $getter_thunk$(self.inner.msg) }) } )rs"); }}, @@ -77,8 +77,8 @@ void SingularScalar::InMsgImpl(Context field) const { pub fn $field$_set(&mut self, val: Option<$Scalar$>) { match val { - Some(val) => unsafe { $setter_thunk$(self.msg, val) }, - None => unsafe { $clearer_thunk$(self.msg) }, + Some(val) => unsafe { $setter_thunk$(self.inner.msg, val) }, + None => unsafe { $clearer_thunk$(self.inner.msg) }, } } )rs"); diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 64027f97ab1f..6f9ce5a43e16 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -45,33 +45,12 @@ namespace protobuf { namespace compiler { namespace rust { namespace { -void MessageStructFields(Context msg) { - switch (msg.opts().kernel) { - case Kernel::kCpp: - msg.Emit(R"rs( - msg: $pbi$::RawMessage, - )rs"); - return; - - case Kernel::kUpb: - msg.Emit(R"rs( - msg: $pbi$::RawMessage, - //~ rustc incorrectly thinks this field is never read, even though - //~ 
it has a destructor! - #[allow(dead_code)] - arena: $pbr$::Arena, - )rs"); - return; - } - - ABSL_LOG(FATAL) << "unreachable"; -} void MessageNew(Context msg) { switch (msg.opts().kernel) { case Kernel::kCpp: msg.Emit({{"new_thunk", Thunk(msg, "new")}}, R"rs( - Self { msg: unsafe { $new_thunk$() } } + Self { inner: $pbr$::MessageInner { msg: unsafe { $new_thunk$() } } } )rs"); return; @@ -79,8 +58,10 @@ void MessageNew(Context msg) { msg.Emit({{"new_thunk", Thunk(msg, "new")}}, R"rs( let arena = $pbr$::Arena::new(); Self { - msg: unsafe { $new_thunk$(arena.raw()) }, - arena, + inner: $pbr$::MessageInner { + msg: unsafe { $new_thunk$(arena.raw()) }, + arena, + } } )rs"); return; @@ -93,7 +74,7 @@ void MessageSerialize(Context msg) { switch (msg.opts().kernel) { case Kernel::kCpp: msg.Emit({{"serialize_thunk", Thunk(msg, "serialize")}}, R"rs( - unsafe { $serialize_thunk$(self.msg) } + unsafe { $serialize_thunk$(self.inner.msg) } )rs"); return; @@ -102,7 +83,7 @@ void MessageSerialize(Context msg) { let arena = $pbr$::Arena::new(); let mut len = 0; unsafe { - let data = $serialize_thunk$(self.msg, arena.raw(), &mut len); + let data = $serialize_thunk$(self.inner.msg, arena.raw(), &mut len); $pbr$::SerializedData::from_raw_parts(arena, data, len) } )rs"); @@ -126,7 +107,7 @@ void MessageDeserialize(Context msg) { data.len(), ); - $deserialize_thunk$(self.msg, data) + $deserialize_thunk$(self.inner.msg, data) }; success.then_some(()).ok_or($pb$::ParseError) )rs"); @@ -143,9 +124,9 @@ void MessageDeserialize(Context msg) { None => Err($pb$::ParseError), Some(msg) => { // This assignment causes self.arena to be dropped and to deallocate - // any previous message pointed/owned to by self.msg. - self.arena = arena; - self.msg = msg; + // any previous message pointed/owned to by self.inner.msg. 
+ self.inner.arena = arena; + self.inner.msg = msg; Ok(()) } } @@ -200,7 +181,7 @@ void MessageDrop(Context msg) { } msg.Emit({{"delete_thunk", Thunk(msg, "delete")}}, R"rs( - unsafe { $delete_thunk$(self.msg); } + unsafe { $delete_thunk$(self.inner.msg); } )rs"); } } // namespace @@ -213,7 +194,6 @@ void GenerateRs(Context msg) { msg.Emit( { {"Msg", msg.desc().name()}, - {"Msg.fields", [&] { MessageStructFields(msg); }}, {"Msg::new", [&] { MessageNew(msg); }}, {"Msg::serialize", [&] { MessageSerialize(msg); }}, {"Msg::deserialize", [&] { MessageDeserialize(msg); }}, @@ -264,12 +244,15 @@ void GenerateRs(Context msg) { #[allow(non_camel_case_types)] #[derive(Debug)] pub struct $Msg$ { - $Msg.fields$ + inner: $pbr$::MessageInner } + // SAFETY: + // - `$Msg$` does not provide shared mutation with its arena. + // - `$Msg$Mut` is not `Send`, and so even in the presence of mutator + // splitting, synchronous access of an arena that would conflict with + // field access is impossible. unsafe impl Sync for $Msg$ {} - unsafe impl Sync for $Msg$View<'_> {} - unsafe impl Send for $Msg$View<'_> {} impl $pb$::Proxied for $Msg$ { type View<'a> = $Msg$View<'a>; @@ -283,6 +266,15 @@ void GenerateRs(Context msg) { _phantom: $Phantom$<&'a ()>, } + // SAFETY: + // - `$Msg$View` does not perform any mutation. + // - While a `$Msg$View` exists, a `$Msg$Mut` can't exist to mutate + // the arena that would conflict with field access. + // - `$Msg$Mut` is not `Send`, and so even in the presence of mutator + // splitting, synchronous access of an arena is impossible. 
+ unsafe impl Sync for $Msg$View<'_> {} + unsafe impl Send for $Msg$View<'_> {} + impl<'a> $pb$::ViewProxy<'a> for $Msg$View<'a> { type Proxied = $Msg$; @@ -303,15 +295,18 @@ void GenerateRs(Context msg) { #[derive(Debug, Copy, Clone)] #[allow(dead_code)] pub struct $Msg$Mut<'a> { - msg: $pbi$::RawMessage, - _phantom: $Phantom$<&'a mut ()>, + inner: $pbr$::MutatorMessageRef<'a>, } + // SAFETY: + // - `$Msg$Mut` does not perform any shared mutation. + // - `$Msg$Mut` is not `Send`, and so even in the presence of mutator + // splitting, synchronous access of an arena is impossible. unsafe impl Sync for $Msg$Mut<'_> {} impl<'a> $pb$::MutProxy<'a> for $Msg$Mut<'a> { fn as_mut(&mut self) -> $pb$::Mut<'_, $Msg$> { - $Msg$Mut { msg: self.msg, _phantom: self._phantom } + $Msg$Mut { inner: self.inner } } fn into_mut<'shorter>(self) -> $pb$::Mut<'shorter, $Msg$> where 'a : 'shorter { self } } @@ -319,10 +314,10 @@ void GenerateRs(Context msg) { impl<'a> $pb$::ViewProxy<'a> for $Msg$Mut<'a> { type Proxied = $Msg$; fn as_view(&self) -> $pb$::View<'_, $Msg$> { - $Msg$View { msg: self.msg, _phantom: std::marker::PhantomData } + $Msg$View { msg: self.inner.msg(), _phantom: std::marker::PhantomData } } fn into_view<'shorter>(self) -> $pb$::View<'shorter, $Msg$> where 'a: 'shorter { - $Msg$View { msg: self.msg, _phantom: std::marker::PhantomData } + $Msg$View { msg: self.inner.msg(), _phantom: std::marker::PhantomData } } } @@ -363,10 +358,10 @@ void GenerateRs(Context msg) { msg.Emit({{"Msg", msg.desc().name()}}, R"rs( impl $Msg$ { pub fn __unstable_wrap_cpp_grant_permission_to_break(msg: $pbi$::RawMessage) -> Self { - Self { msg } + Self { inner: $pbr$::MessageInner { msg } } } pub fn __unstable_cpp_repr_grant_permission_to_break(&mut self) -> $pbi$::RawMessage { - self.msg + self.inner.msg } } )rs");